├── README.md
├── builder
│   ├── __init__.py
│   ├── data_builder.py
│   ├── loss_builder.py
│   └── model_builder.py
├── config
│   ├── __init__.py
│   ├── config.py
│   ├── label_mapping
│   │   ├── semantic-kitti-all.yaml
│   │   ├── semantic-kitti-multiscan.yaml
│   │   └── semantic-kitti.yaml
│   ├── semantickitti-multiscan.yaml
│   └── semantickitti.yaml
├── dataloader with label rectification
│   ├── dataloader
│   │   ├── __init__.py
│   │   ├── dataset_semantickitti.py
│   │   └── pc_dataset.py
│   └── visualize_voxel_label.py
├── dataloader
│   ├── __init__.py
│   ├── dataset_semantickitti.py
│   └── pc_dataset.py
├── imgs
│   └── pipline.png
├── network
│   ├── __init__.py
│   ├── conv_base.py
│   ├── cylinder_fea_generator.py
│   ├── cylinder_spconv_3d.py
│   └── segmentator_3d_asymm_spconv.py
├── test_scpnet_comp.py
├── train_scpnet_comp.py
└── utils
    ├── __init__.py
    ├── load_save_util.py
    ├── log_util.py
    ├── lovasz_losses.py
    ├── metric_util.py
    └── np_ioueval.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## SCPNet: Semantic Scene Completion on Point Cloud (CVPR 2023, Highlight)


## News
- **2023-05** Preliminary code is released.
- **2023-02** Our SCPNet is accepted by CVPR 2023 (**Highlight**)!
- **2022-11** Our method ranks **1st** in the **SemanticKITTI Semantic Scene Completion Challenge**, with mIoU = 36.7.
- SCPNet comprises a novel completion sub-network without an encoder-decoder structure, and a segmentation sub-network obtained by replacing the cylindrical partition of Cylinder3D with a conventional cubic partition.


## Installation

### Requirements
- PyTorch >= 1.10
- pyyaml
- Cython
- tqdm
- numba
- numpy-indexed
- [torch-scatter](https://github.com/rusty1s/pytorch_scatter)
- [spconv](https://github.com/tyjiang1997/spconv1.0) (tested with spconv==1.0 and cuda==11.3)

## Data Preparation

### SemanticKITTI
```
./
├──
├── ...
└── path_to_data_shown_in_config/
    └── sequences/
        ├── 00/
        │   ├── velodyne/
        │   │   ├── 000000.bin
        │   │   ├── 000001.bin
        │   │   └── ...
        │   ├── labels/
        │   │   ├── 000000.label
        │   │   ├── 000001.label
        │   │   └── ...
        │   └── voxels/
        │       ├── 000000.bin
        │       ├── 000000.label
        │       ├── 000000.invalid
        │       ├── 000000.occluded
        │       ├── 000001.bin
        │       ├── 000001.label
        │       ├── 000001.invalid
        │       ├── 000001.occluded
        │       └── ...
        ├── 08/ # for validation
        ├── 11/ # 11-21 for testing
        └── 21/
            └── ...
```
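The `voxels/*.bin`, `*.invalid` and `*.occluded` files are bit-packed 256×256×32 occupancy grids, and `voxels/*.label` stores one uint16 class id per voxel (0 = empty), following the standard SemanticKITTI completion format. A downloaded sequence can be sanity-checked with a minimal sketch like the one below; the helper functions are ad hoc (they simply mirror the `unpack` helper in `dataloader/pc_dataset.py`, which unpacks bits MSB-first) and the path is the placeholder from the tree above:
```
import numpy as np

VOXEL_DIMS = (256, 256, 32)  # SemanticKITTI completion grid

def read_occupancy(path):
    # *.bin / *.invalid / *.occluded store one bit per voxel, MSB first
    compressed = np.fromfile(path, dtype=np.uint8)
    return np.unpackbits(compressed).reshape(VOXEL_DIMS)

def read_label(path):
    # *.label stores one uint16 semantic class id per voxel (0 = empty)
    return np.fromfile(path, dtype=np.uint16).reshape(VOXEL_DIMS)

voxel_dir = "path_to_data_shown_in_config/sequences/00/voxels/"
occ = read_occupancy(voxel_dir + "000000.bin")
lab = read_label(voxel_dir + "000000.label")
print(occ.sum(), np.unique(lab))  # occupied-voxel count and the raw label ids
```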
## Test
We take evaluation on the SemanticKITTI test set (single-scan) as an example.
1. Download the pre-trained models and put them in `./model_load_dir`.
2. Set `val_data_loader > imageset: "test"` in the configuration file `config/semantickitti-multiscan.yaml`.
3. Generate predictions on the SemanticKITTI test set:
```
CUDA_VISIBLE_DEVICES=0 python -u test_scpnet_comp.py
```
The model predictions will be saved in `./out_scpnet/test` by default.

## Train
1. Set `val_data_loader > imageset: "val"` in the configuration file `config/semantickitti-multiscan.yaml`, so that evaluation during training runs on the labelled validation split.
2. Train the network by running the training script:
```
CUDA_VISIBLE_DEVICES=0 python -u train_scpnet_comp.py
```

## Citation

If you use this code, please cite the following publication:
```
@inproceedings{scpnet,
   title = {SCPNet: Semantic Scene Completion on Point Cloud},
   author = {Xia, Zhaoyang and Liu, Youquan and Li, Xin and Zhu, Xinge and Ma, Yuexin and Li, Yikang and Hou, Yuenan and Qiao, Yu},
   booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
   year = {2023}
}
```

## Acknowledgements
We thank the authors of these codebases: [Cylinder3D](https://github.com/xinge008/Cylinder3D), [PVKD](https://github.com/cardwing/Codes-for-PVKD) and [spconv](https://github.com/traveller59/spconv).

--------------------------------------------------------------------------------
/builder/__init__.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: Xinge
# @file: __init__.py

--------------------------------------------------------------------------------
/builder/data_builder.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: Xinge
# @file: data_builder.py

import torch
from dataloader.dataset_semantickitti import get_model_class, collate_fn_BEV, collate_fn_BEV_tta, collate_fn_BEV_ms, collate_fn_BEV_ms_tta
from dataloader.pc_dataset import get_pc_model_class


def build(dataset_config,
          train_dataloader_config,
          val_dataloader_config,
          grid_size=[480, 360, 32],
          use_tta=False,
          use_multiscan=False,
          use_waymo=False):
    data_path = train_dataloader_config["data_path"]
    train_imageset = train_dataloader_config["imageset"]
    val_imageset = val_dataloader_config["imageset"]
    train_ref = train_dataloader_config["return_ref"]
    val_ref = val_dataloader_config["return_ref"]

    label_mapping = dataset_config["label_mapping"]

    SemKITTI = get_pc_model_class(dataset_config['pc_dataset_type'])

    nusc = None
    if "nusc" in dataset_config['pc_dataset_type']:
        from nuscenes import NuScenes
        nusc = NuScenes(version='v1.0-trainval', dataroot=data_path, verbose=True)

    train_pt_dataset = SemKITTI(data_path, imageset=train_imageset,
                                return_ref=train_ref, label_mapping=label_mapping, nusc=nusc)
    val_pt_dataset = SemKITTI(data_path, imageset=val_imageset,
                              return_ref=val_ref, label_mapping=label_mapping, nusc=nusc)

    # train_dataset = get_model_class(dataset_config['dataset_type'])(
    #     train_pt_dataset,
    #     grid_size=grid_size,
    #     flip_aug=True,
    #     fixed_volume_space=dataset_config['fixed_volume_space'],
    #     max_volume_space=dataset_config['max_volume_space'],
    #     min_volume_space=dataset_config['min_volume_space'],
    #     ignore_label=dataset_config["ignore_label"],
    #     rotate_aug=True,
    #     scale_aug=True,
    #     transform_aug=True
    # )
    # if use_tta:
    #     val_dataset = get_model_class(dataset_config['dataset_type'])(
    #         val_pt_dataset,
    #         grid_size=grid_size,
    #         flip_aug=True,
    #         fixed_volume_space=dataset_config['fixed_volume_space'],
    #         max_volume_space=dataset_config['max_volume_space'],
    #         min_volume_space=dataset_config['min_volume_space'],
    #         ignore_label=dataset_config["ignore_label"],
    #         rotate_aug=True,
    #         scale_aug=True,
    #         return_test=True,
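    #         (Note: with use_tta enabled, cylinder_dataset returns voting=4
    #         flipped copies of every validation sample, with flip_type tied to
    #         vote_idx in get_single_sample; collate_fn_BEV_tta and
    #         collate_fn_BEV_ms_tta in dataloader/dataset_semantickitti.py then
    #         flatten those copies back into a single batch.)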
    #         use_tta=True,
    #     )
    dataAug = 0  # set to 1 to enable rotation/flip augmentation for training
    if dataAug:
        train_dataset = get_model_class(dataset_config['dataset_type'])(
            train_pt_dataset,
            grid_size=grid_size,
            rotate_aug=True,
            flip_aug=True,
            ignore_label=dataset_config["ignore_label"],
            fixed_volume_space=dataset_config['fixed_volume_space'],
            max_volume_space=dataset_config['max_volume_space'],
            min_volume_space=dataset_config['min_volume_space'],
            return_test=True,
        )
    else:
        train_dataset = get_model_class(dataset_config['dataset_type'])(
            train_pt_dataset,
            grid_size=grid_size,
            rotate_aug=False,
            flip_aug=False,
            ignore_label=dataset_config["ignore_label"],
            fixed_volume_space=dataset_config['fixed_volume_space'],
            max_volume_space=dataset_config['max_volume_space'],
            min_volume_space=dataset_config['min_volume_space'],
            return_test=True,
        )
    if use_tta:
        if dataAug:
            val_dataset = get_model_class(dataset_config['dataset_type'])(
                val_pt_dataset,
                grid_size=grid_size,
                rotate_aug=False,  # True
                flip_aug=False,  # True
                ignore_label=dataset_config["ignore_label"],
                fixed_volume_space=dataset_config['fixed_volume_space'],
                max_volume_space=dataset_config['max_volume_space'],
                min_volume_space=dataset_config['min_volume_space'],
                return_test=True,
            )
        else:
            val_dataset = get_model_class(dataset_config['dataset_type'])(
                val_pt_dataset,
                grid_size=grid_size,
                rotate_aug=False,
                flip_aug=False,
                ignore_label=dataset_config["ignore_label"],
                fixed_volume_space=dataset_config['fixed_volume_space'],
                max_volume_space=dataset_config['max_volume_space'],
                min_volume_space=dataset_config['min_volume_space'],
                return_test=True,
            )
        if use_multiscan:
            collate_fn_BEV_tmp = collate_fn_BEV_ms_tta
        else:
            collate_fn_BEV_tmp = collate_fn_BEV_tta
    else:
        val_dataset = get_model_class(dataset_config['dataset_type'])(
            val_pt_dataset,
            grid_size=grid_size,
            fixed_volume_space=dataset_config['fixed_volume_space'],
            max_volume_space=dataset_config['max_volume_space'],
            min_volume_space=dataset_config['min_volume_space'],
            ignore_label=dataset_config["ignore_label"],
            return_test=True,
        )
        if use_multiscan or use_waymo:
            collate_fn_BEV_tmp = collate_fn_BEV_ms
        else:
            collate_fn_BEV_tmp = collate_fn_BEV

    train_dataset_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                       batch_size=train_dataloader_config["batch_size"],
                                                       collate_fn=collate_fn_BEV_tmp,
                                                       shuffle=train_dataloader_config["shuffle"],
                                                       num_workers=train_dataloader_config["num_workers"])
    val_dataset_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                                     batch_size=val_dataloader_config["batch_size"],
                                                     collate_fn=collate_fn_BEV_tmp,
                                                     shuffle=val_dataloader_config["shuffle"],
                                                     num_workers=val_dataloader_config["num_workers"])

    # the same tuple is returned whether or not TTA is enabled
    return train_dataset_loader, val_dataset_loader, val_pt_dataset
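Taken together, the three builder modules are wired up roughly as follows. This is a minimal sketch rather than an excerpt from `train_scpnet_comp.py`; the config path and the `use_tta`/`use_multiscan` flag values are assumptions:
```
from builder import data_builder, loss_builder, model_builder
from config.config import load_config_data

configs = load_config_data("config/semantickitti-multiscan.yaml")

my_model = model_builder.build(configs['model_params'])
train_loader, val_loader, val_pt_dataset = data_builder.build(
    configs['dataset_params'],
    configs['train_data_loader'],
    configs['val_data_loader'],
    grid_size=configs['model_params']['output_shape'],
    use_tta=False,
    use_multiscan=True)

# with both flags set, build() returns (CrossEntropyLoss, lovasz_softmax)
ce_loss, lovasz_fn = loss_builder.build(
    wce=True, lovasz=True,
    num_class=configs['model_params']['num_class'],
    ignore_label=configs['dataset_params']['ignore_label'])
```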
--------------------------------------------------------------------------------
/builder/loss_builder.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: Xinge
# @file: loss_builder.py

import torch
from utils.lovasz_losses import lovasz_softmax


def build(wce=True, lovasz=True, num_class=20, ignore_label=0):

    loss_funs = torch.nn.CrossEntropyLoss(ignore_index=ignore_label)

    if wce and lovasz:
        return loss_funs, lovasz_softmax
    elif wce and not lovasz:
        return loss_funs
    elif not wce and lovasz:
        return lovasz_softmax
    else:
        raise NotImplementedError

--------------------------------------------------------------------------------
/builder/model_builder.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: Xinge
# @file: model_builder.py

from network.cylinder_spconv_3d import get_model_class
from network.segmentator_3d_asymm_spconv import Asymm_3d_spconv
from network.cylinder_fea_generator import cylinder_fea


def build(model_config):
    output_shape = model_config['output_shape']
    num_class = model_config['num_class']
    num_input_features = model_config['num_input_features']
    use_norm = model_config['use_norm']
    init_size = model_config['init_size']
    fea_dim = model_config['fea_dim']
    out_fea_dim = model_config['out_fea_dim']

    cylinder_3d_spconv_seg = Asymm_3d_spconv(
        output_shape=output_shape,
        use_norm=use_norm,
        num_input_features=num_input_features,
        init_size=init_size,
        nclasses=num_class)

    cy_fea_net = cylinder_fea(grid_size=output_shape,
                              fea_dim=fea_dim,
                              out_pt_fea_dim=out_fea_dim,
                              fea_compre=num_input_features)

    model = get_model_class(model_config["model_architecture"])(
        cylin_model=cy_fea_net,
        segmentator_spconv=cylinder_3d_spconv_seg,
        sparse_shape=output_shape
    )

    return model

--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: Xinge
# @file: __init__.py

--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
# -*- coding:utf-8 -*-
# author: Xinge

from pathlib import Path

from strictyaml import Bool, Float, Int, Map, Seq, Str, as_document, load

model_params = Map(
    {
        "model_architecture": Str(),
        "output_shape": Seq(Int()),
        "fea_dim": Int(),
        "out_fea_dim": Int(),
        "num_class": Int(),
        "num_input_features": Int(),
        "use_norm": Bool(),
        "init_size": Int(),
    }
)

dataset_params = Map(
    {
        "dataset_type": Str(),
        "pc_dataset_type": Str(),
        "ignore_label": Int(),
        "return_test": Bool(),
        "fixed_volume_space": Bool(),
        "label_mapping": Str(),
        "max_volume_space": Seq(Float()),
        "min_volume_space": Seq(Float()),
    }
)


train_data_loader = Map(
    {
        "data_path": Str(),
        "imageset": Str(),
        "return_ref": Bool(),
        "batch_size": Int(),
        "shuffle": Bool(),
        "num_workers": Int(),
    }
)

val_data_loader = Map(
    {
        "data_path": Str(),
        "imageset": Str(),
        "return_ref": Bool(),
        "batch_size": Int(),
        "shuffle": Bool(),
        "num_workers": Int(),
    }
)


train_params = Map(
    {
        "model_load_path": Str(),
        "model_save_path": Str(),
        "checkpoint_every_n_steps": Int(),
        "max_num_epochs": Int(),
64 | "eval_every_n_steps": Int(), 65 | "learning_rate": Float() 66 | } 67 | ) 68 | 69 | schema_v4 = Map( 70 | { 71 | "format_version": Int(), 72 | "model_params": model_params, 73 | "dataset_params": dataset_params, 74 | "train_data_loader": train_data_loader, 75 | "val_data_loader": val_data_loader, 76 | "train_params": train_params, 77 | } 78 | ) 79 | 80 | 81 | SCHEMA_FORMAT_VERSION_TO_SCHEMA = {4: schema_v4} 82 | 83 | 84 | def load_config_data(path: str) -> dict: 85 | yaml_string = Path(path).read_text() 86 | cfg_without_schema = load(yaml_string, schema=None) 87 | schema_version = int(cfg_without_schema["format_version"]) 88 | if schema_version not in SCHEMA_FORMAT_VERSION_TO_SCHEMA: 89 | raise Exception(f"Unsupported schema format version: {schema_version}.") 90 | 91 | strict_cfg = load(yaml_string, schema=SCHEMA_FORMAT_VERSION_TO_SCHEMA[schema_version]) 92 | cfg: dict = strict_cfg.data 93 | return cfg 94 | 95 | 96 | def config_data_to_config(data): # type: ignore 97 | return as_document(data, schema_v4) 98 | 99 | 100 | def save_config_data(data: dict, path: str) -> None: 101 | cfg_document = config_data_to_config(data) 102 | with open(Path(path), "w") as f: 103 | f.write(cfg_document.as_yaml()) 104 | -------------------------------------------------------------------------------- /config/label_mapping/semantic-kitti-all.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 
0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 20 # "moving-car" 137 | 253: 21 # "moving-bicyclist" 138 | 254: 22 # "moving-person" 139 | 255: 23 # "moving-motorcyclist" 140 | 256: 24 # "moving-on-rails" mapped to "moving-other-vehicle" ------mapped 141 | 257: 24 # "moving-bus" mapped to "moving-other-vehicle" -----------mapped 142 | 258: 25 # "moving-truck" 143 | 259: 24 # "moving-other-vehicle" 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | 20: 252 # "moving-car" 166 | 21: 253 # "moving-bicyclist" 167 | 22: 254 # "moving-person" 168 | 23: 255 # "moving-motorcyclist" 169 | 24: 259 # "moving-other-vehicle" 170 | 25: 258 # "moving-truck" 171 | learning_ignore: # Ignore classes 172 | 0: True # "unlabeled", and others ignored 173 | 1: False # "car" 174 | 2: False # "bicycle" 175 | 3: False # "motorcycle" 176 | 4: False # "truck" 177 | 5: False # "other-vehicle" 178 | 6: False # "person" 179 | 7: False # "bicyclist" 180 | 8: False # "motorcyclist" 181 | 9: False # "road" 182 | 10: False # "parking" 183 | 11: False # "sidewalk" 184 | 12: False # "other-ground" 185 | 13: False # "building" 186 | 14: False # "fence" 187 | 15: False # "vegetation" 188 | 16: False # "trunk" 189 | 17: False # "terrain" 190 | 18: False # "pole" 191 | 19: False # "traffic-sign" 192 | 20: False # "moving-car" 193 | 21: False # 
"moving-bicyclist" 194 | 22: False # "moving-person" 195 | 23: False # "moving-motorcyclist" 196 | 24: False # "moving-other-vehicle" 197 | 25: False # "moving-truck" 198 | split: # sequence numbers 199 | train: 200 | - 0 201 | - 1 202 | - 2 203 | - 3 204 | - 4 205 | - 5 206 | - 6 207 | - 7 208 | - 9 209 | - 10 210 | valid: 211 | - 8 212 | test: 213 | - 11 214 | - 12 215 | - 13 216 | - 14 217 | - 15 218 | - 16 219 | - 17 220 | - 18 221 | - 19 222 | - 20 223 | - 21 224 | -------------------------------------------------------------------------------- /config/label_mapping/semantic-kitti-multiscan.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | 
learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /config/label_mapping/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /config/semantickitti-multiscan.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 256 11 | - 256 12 | - 32 13 | 14 | fea_dim: 7 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 32 18 | use_norm: True 19 | init_size: 32 20 | 21 | 22 | ################### 23 | ## Dataset options 24 | dataset_params: 25 | dataset_type: "voxel_dataset" 26 | pc_dataset_type: "SemKITTI_sk_multiscan" 27 | ignore_label: 255 28 | return_test: True 29 | fixed_volume_space: True 30 | label_mapping: "./config/label_mapping/semantic-kitti-multiscan.yaml" 31 | max_volume_space: 32 | - 51.2 33 | - 25.6 34 | - 4.4 35 | min_volume_space: 36 | - 0 37 | - -25.6 38 | - -2 39 | 40 | 41 | ################### 42 | ## Data_loader options 43 | train_data_loader: 44 | data_path: "/mnt/lustre/share_data/liuyouquan/semantickitti/sequences/" 45 | imageset: "train" 
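  # imageset picks the sequence split defined in the label-mapping yaml:
  # "train" (sequences 00-07, 09-10), "val" (08) or "test" (11-21)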
46 | return_ref: True 47 | batch_size: 2 48 | shuffle: True 49 | num_workers: 12 50 | 51 | val_data_loader: 52 | data_path: "/mnt/lustre/share_data/liuyouquan/semantickitti/sequences/" 53 | imageset: "val" 54 | # imageset: "test" 55 | return_ref: True 56 | batch_size: 1 #2 57 | shuffle: False 58 | num_workers: 12 59 | 60 | 61 | 62 | 63 | ################### 64 | ## Train params 65 | train_params: 66 | model_load_path: "./model_load_dir/" 67 | model_save_path: "./model_load_dir/" 68 | checkpoint_every_n_steps: 4599 69 | max_num_epochs: 40 70 | eval_every_n_steps: 1917 71 | learning_rate: 0.0015 72 | -------------------------------------------------------------------------------- /config/semantickitti.yaml: -------------------------------------------------------------------------------- 1 | # Config format schema number 2 | format_version: 4 3 | 4 | ################### 5 | ## Model options 6 | model_params: 7 | model_architecture: "cylinder_asym" 8 | 9 | output_shape: 10 | - 480 11 | - 360 12 | - 32 13 | 14 | fea_dim: 9 15 | out_fea_dim: 256 16 | num_class: 20 17 | num_input_features: 32 #16 18 | use_norm: True 19 | init_size: 32 #16 20 | 21 | 22 | ################### 23 | ## Dataset options 24 | dataset_params: 25 | dataset_type: "cylinder_dataset" 26 | pc_dataset_type: "SemKITTI_sk" 27 | ignore_label: 0 28 | return_test: False 29 | fixed_volume_space: True 30 | label_mapping: "./config/label_mapping/semantic-kitti.yaml" 31 | max_volume_space: 32 | - 50 33 | - 3.1415926 34 | - 2 35 | min_volume_space: 36 | - 0 37 | - -3.1415926 38 | - -4 39 | 40 | 41 | ################### 42 | ## Data_loader options 43 | train_data_loader: 44 | data_path: "/nvme/yuenan/semantickitti_dataset/sequences/" 45 | imageset: "train" 46 | return_ref: True 47 | batch_size: 2 48 | shuffle: True 49 | num_workers: 12 #4 50 | 51 | val_data_loader: 52 | data_path: "/nvme/yuenan/semantickitti_dataset/sequences/" 53 | imageset: "test" #"val" 54 | return_ref: True 55 | batch_size: 1 56 | shuffle: False 57 | num_workers: 12 #4 58 | 59 | 60 | ################### 61 | ## Train params 62 | train_params: 63 | model_load_path: "./model_load_dir/model_full_ft.pt" 64 | model_save_path: "./model_save_dir/model_tmp.pt" 65 | checkpoint_every_n_steps: 4599 66 | max_num_epochs: 20 #40 67 | eval_every_n_steps: 5000 #4599 68 | learning_rate: 0.002 #1 69 | -------------------------------------------------------------------------------- /dataloader with label rectification/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: __init__.py.py 4 | 5 | # from . 
import dataset_nuscenes -------------------------------------------------------------------------------- /dataloader with label rectification/dataloader/dataset_semantickitti.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | """ 5 | SemKITTI dataloader 6 | """ 7 | import numpy as np 8 | import torch 9 | import numba as nb 10 | from torch.utils import data 11 | 12 | REGISTERED_DATASET_CLASSES = {} 13 | 14 | 15 | def register_dataset(cls, name=None): 16 | global REGISTERED_DATASET_CLASSES 17 | if name is None: 18 | name = cls.__name__ 19 | assert name not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}" 20 | REGISTERED_DATASET_CLASSES[name] = cls 21 | return cls 22 | 23 | 24 | def get_model_class(name): 25 | global REGISTERED_DATASET_CLASSES 26 | assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}" 27 | return REGISTERED_DATASET_CLASSES[name] 28 | 29 | 30 | @register_dataset 31 | class voxel_dataset(data.Dataset): 32 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 33 | fixed_volume_space=False, max_volume_space=[50, 50, 1.5], min_volume_space=[-50, -50, -3]): 34 | 'Initialization' 35 | self.point_cloud_dataset = in_dataset 36 | self.grid_size = np.asarray(grid_size) 37 | self.rotate_aug = rotate_aug 38 | self.ignore_label = ignore_label 39 | self.return_test = return_test 40 | self.flip_aug = flip_aug 41 | self.fixed_volume_space = fixed_volume_space 42 | self.max_volume_space = max_volume_space 43 | self.min_volume_space = min_volume_space 44 | self.use_every_i_data = 5 45 | 46 | def __len__(self): 47 | 'Denotes the total number of samples' 48 | return len(self.point_cloud_dataset)//self.use_every_i_data 49 | 50 | def __getitem__(self, index): 51 | 'Generates one sample of data' 52 | 53 | index *= self.use_every_i_data 54 | 55 | data = self.point_cloud_dataset[index] 56 | if len(data) == 3: 57 | xyz, voxel_label, instance_label = data 58 | elif len(data) == 4: 59 | xyz, voxel_label, instance_label, sig = data 60 | if len(sig.shape) == 2: sig = np.squeeze(sig) 61 | elif len(data) == 5: 62 | xyz, voxel_label, instance_label, sig, origin_len = data 63 | if len(sig.shape) == 2: sig = np.squeeze(sig) 64 | else: 65 | raise Exception('Return invalid data tuple') 66 | 67 | # # random data augmentation by rotation 68 | # if self.rotate_aug: 69 | # rotate_rad = np.deg2rad(np.random.random() * 360) 70 | # c, s = np.cos(rotate_rad), np.sin(rotate_rad) 71 | # j = np.matrix([[c, s], [-s, c]]) 72 | # xyz[:, :2] = np.dot(xyz[:, :2], j) 73 | 74 | # # random data augmentation by flip x , y or x+y 75 | # if self.flip_aug: 76 | # flip_type = np.random.choice(4, 1) 77 | # if flip_type == 1: 78 | # xyz[:, 0] = -xyz[:, 0] 79 | # elif flip_type == 2: 80 | # xyz[:, 1] = -xyz[:, 1] 81 | # elif flip_type == 3: 82 | # xyz[:, :2] = -xyz[:, :2] 83 | 84 | 85 | if self.fixed_volume_space: 86 | max_bound = np.asarray(self.max_volume_space) 87 | min_bound = np.asarray(self.min_volume_space) 88 | else: 89 | max_bound = np.percentile(xyz, 100, axis=0) 90 | min_bound = np.percentile(xyz, 0, axis=0) 91 | 92 | ### Cut point cloud and segmentation label for valid range 93 | cut_point = 1 94 | if cut_point == 1: 95 | xyz0 = xyz 96 | for ci in range(3): 97 | xyz0[xyz[:, ci] < min_bound[ci], :] = 1000 98 | xyz0[xyz[:, ci] > max_bound[ci], :] = 1000 99 | valid_inds = xyz0[:, 0] != 1000 100 | xyz = xyz[valid_inds, 
:] 101 | sig = sig[valid_inds] 102 | instance_label = instance_label[valid_inds] 103 | 104 | # transpose centre coord for x axis 105 | trans_x = 1 106 | if trans_x: 107 | x_bias = (self.max_volume_space[0] - self.min_volume_space[0])/2 108 | min_bound[0] -= x_bias 109 | max_bound[0] -= x_bias 110 | xyz[:, 0] -= x_bias 111 | 112 | if len(data) == 5: 113 | origin_len = len(xyz) 114 | 115 | # get grid index 116 | crop_range = max_bound - min_bound 117 | cur_grid_size = self.grid_size 118 | 119 | intervals = crop_range / (cur_grid_size - 1) 120 | 121 | if (intervals == 0).any(): print("Zero interval!") 122 | 123 | grid_ind = (np.floor((np.clip(xyz, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 124 | 125 | # process voxel position 126 | dim_array = np.ones(len(self.grid_size) + 1, int) 127 | dim_array[0] = -1 128 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 129 | 130 | ## process labels 131 | # processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.ignore_label 132 | # label_voxel_pair = np.concatenate([grid_ind, labels], axis=1) 133 | # label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :] 134 | # processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair) 135 | 136 | # processed_label = voxel_label # voxel labels 137 | processed_label = label_rectification(grid_ind, voxel_label.copy(), instance_label) 138 | 139 | ## uncomment to save label for visualization 140 | # torch.save( 141 | # {'voxel_label_org': voxel_label, 142 | # 'voxel_label_rect': processed_label, 143 | # }, 144 | # f"label_{index}.pt" 145 | # ) 146 | 147 | data_tuple = (voxel_position, processed_label) 148 | 149 | # center data on each voxel for PTnet 150 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 151 | return_xyz = xyz - voxel_centers 152 | return_xyz = np.concatenate((return_xyz, xyz), axis=1) 153 | 154 | if len(data) == 3: 155 | return_fea = return_xyz 156 | elif len(data) == 4 or len(data) == 5: 157 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) # 7:xyz_bias + xyz + intensity 158 | 159 | if self.return_test: 160 | data_tuple += (grid_ind, voxel_label, return_fea, index) 161 | else: 162 | data_tuple += (grid_ind, voxel_label, return_fea) 163 | 164 | if len(data) == 5: 165 | data_tuple += (origin_len,) 166 | 167 | return data_tuple 168 | 169 | 170 | def label_rectification(grid_ind, voxel_label, instance_label, 171 | dynamic_classes=[1,4,5,6,7,8], 172 | voxel_shape=(256,256,32), 173 | ignore_class_label=255): 174 | 175 | segmentation_label = voxel_label[grid_ind[:,0], grid_ind[:,1], grid_ind[:,2]] 176 | 177 | for c in dynamic_classes: 178 | voxel_pos_class_c = (voxel_label==c).astype(int) 179 | instance_label_class_c = instance_label[segmentation_label==c].squeeze(1) 180 | 181 | if len(instance_label_class_c) == 0: 182 | pos_to_remove = voxel_pos_class_c 183 | 184 | elif len(instance_label_class_c) > 0 and np.sum(voxel_pos_class_c) > 0: 185 | mask_class_c = np.zeros(voxel_shape, dtype=int) 186 | point_pos_class_c = grid_ind[segmentation_label==c] 187 | uniq_instance_label_class_c = np.unique(instance_label_class_c) 188 | 189 | for i in uniq_instance_label_class_c: 190 | point_pos_instance_i = point_pos_class_c[instance_label_class_c==i] 191 | x_max, y_max, z_max = np.amax(point_pos_instance_i, axis=0) 192 | x_min, y_min, z_min = np.amin(point_pos_instance_i, axis=0) 193 | 194 | 
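                # mark the axis-aligned bounding box of instance i as valid;
                # voxels of dynamic class c that fall outside every instance
                # box (e.g. trails left behind by moving objects when scans
                # are aggregated) are reset to the ignore label below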
mask_class_c[x_min:x_max,y_min:y_max,z_min:z_max] = 1 195 | 196 | pos_to_remove = (voxel_pos_class_c - mask_class_c) > 0 197 | 198 | voxel_label[pos_to_remove] = ignore_class_label 199 | 200 | return voxel_label 201 | 202 | # transformation between Cartesian coordinates and polar coordinates 203 | def cart2polar(input_xyz): 204 | rho = np.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2) 205 | phi = np.arctan2(input_xyz[:, 1], input_xyz[:, 0]) 206 | return np.stack((rho, phi, input_xyz[:, 2]), axis=1) 207 | 208 | 209 | def polar2cat(input_xyz_polar): 210 | # print(input_xyz_polar.shape) 211 | x = input_xyz_polar[0] * np.cos(input_xyz_polar[1]) 212 | y = input_xyz_polar[0] * np.sin(input_xyz_polar[1]) 213 | return np.stack((x, y, input_xyz_polar[2]), axis=0) 214 | 215 | 216 | @register_dataset 217 | class cylinder_dataset(data.Dataset): 218 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 219 | fixed_volume_space=False, max_volume_space=[50, np.pi, 2], min_volume_space=[0, -np.pi, -4], 220 | scale_aug=False, 221 | transform_aug=False, trans_std=[0.1, 0.1, 0.1], 222 | min_rad=-np.pi / 4, max_rad=np.pi / 4, use_tta=False): 223 | self.point_cloud_dataset = in_dataset 224 | self.grid_size = np.asarray(grid_size) 225 | self.rotate_aug = rotate_aug 226 | self.flip_aug = flip_aug 227 | self.scale_aug = scale_aug 228 | self.ignore_label = ignore_label 229 | self.return_test = return_test 230 | self.fixed_volume_space = fixed_volume_space 231 | self.max_volume_space = max_volume_space 232 | self.min_volume_space = min_volume_space 233 | self.transform = transform_aug 234 | self.trans_std = trans_std 235 | 236 | self.noise_rotation = np.random.uniform(min_rad, max_rad) 237 | self.use_tta = use_tta 238 | 239 | def __len__(self): 240 | 'Denotes the total number of samples' 241 | return len(self.point_cloud_dataset) 242 | 243 | def rotation_points_single_angle(self, points, angle, axis=0): 244 | # points: [N, 3] 245 | rot_sin = np.sin(angle) 246 | rot_cos = np.cos(angle) 247 | if axis == 1: 248 | rot_mat_T = np.array( 249 | [[rot_cos, 0, -rot_sin], [0, 1, 0], [rot_sin, 0, rot_cos]], 250 | dtype=points.dtype) 251 | elif axis == 2 or axis == -1: 252 | rot_mat_T = np.array( 253 | [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]], 254 | dtype=points.dtype) 255 | elif axis == 0: 256 | rot_mat_T = np.array( 257 | [[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]], 258 | dtype=points.dtype) 259 | else: 260 | raise ValueError("axis should in range") 261 | 262 | return points @ rot_mat_T 263 | 264 | def __getitem__(self, index): 265 | 'Generates one sample of data' 266 | data = self.point_cloud_dataset[index] 267 | if self.use_tta: 268 | data_total = [] 269 | voting = 4 270 | for idx in range(voting): 271 | data_single_ori = self.get_single_sample(data, index, idx) 272 | data_total.append(data_single_ori) 273 | data_total = tuple(data_total) 274 | return data_total 275 | else: 276 | data_single = self.get_single_sample(data, index) 277 | return data_single 278 | 279 | def get_single_sample(self, data, index, vote_idx=0): 280 | if len(data) == 2: 281 | xyz, labels = data 282 | elif len(data) == 3: 283 | xyz, labels, sig = data 284 | if len(sig.shape) == 2: sig = np.squeeze(sig) 285 | elif len(data) == 4: 286 | xyz, labels, sig, origin_len = data 287 | if len(sig.shape) == 2: sig = np.squeeze(sig) 288 | else: 289 | raise Exception('Return invalid data tuple') 290 | 291 | # random data augmentation by rotation 292 | if 
self.rotate_aug: 293 | rotate_rad = np.deg2rad(np.random.random() * 90) - np.pi / 4 294 | c, s = np.cos(rotate_rad), np.sin(rotate_rad) 295 | j = np.matrix([[c, s], [-s, c]]) 296 | xyz[:, :2] = np.dot(xyz[:, :2], j) 297 | 298 | # random data augmentation by flip x , y or x+y 299 | if self.flip_aug: 300 | if self.use_tta: 301 | flip_type = vote_idx 302 | else: 303 | flip_type = np.random.choice(4, 1) 304 | if flip_type == 1: 305 | xyz[:, 0] = -xyz[:, 0] 306 | elif flip_type == 2: 307 | xyz[:, 1] = -xyz[:, 1] 308 | elif flip_type == 3: 309 | xyz[:, :2] = -xyz[:, :2] 310 | if self.scale_aug: 311 | noise_scale = np.random.uniform(0.95, 1.05) 312 | xyz[:, 0] = noise_scale * xyz[:, 0] 313 | xyz[:, 1] = noise_scale * xyz[:, 1] 314 | # convert coordinate into polar coordinates 315 | 316 | if self.transform: 317 | noise_translate = np.array([np.random.normal(0, self.trans_std[0], 1), 318 | np.random.normal(0, self.trans_std[1], 1), 319 | np.random.normal(0, self.trans_std[2], 1)]).T 320 | 321 | xyz[:, 0:3] += noise_translate 322 | 323 | xyz_pol = cart2polar(xyz) 324 | 325 | max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0) 326 | min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0) 327 | max_bound = np.max(xyz_pol[:, 1:], axis=0) 328 | min_bound = np.min(xyz_pol[:, 1:], axis=0) 329 | max_bound = np.concatenate(([max_bound_r], max_bound)) 330 | min_bound = np.concatenate(([min_bound_r], min_bound)) 331 | if self.fixed_volume_space: 332 | max_bound = np.asarray(self.max_volume_space) 333 | min_bound = np.asarray(self.min_volume_space) 334 | # get grid index 335 | crop_range = max_bound - min_bound 336 | cur_grid_size = self.grid_size 337 | intervals = crop_range / (cur_grid_size - 1) 338 | 339 | if (intervals == 0).any(): print("Zero interval!") 340 | grid_ind = (np.floor((np.clip(xyz_pol, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 341 | 342 | voxel_position = np.zeros(self.grid_size, dtype=np.float32) 343 | dim_array = np.ones(len(self.grid_size) + 1, int) 344 | dim_array[0] = -1 345 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 346 | voxel_position = polar2cat(voxel_position) 347 | 348 | processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.ignore_label 349 | label_voxel_pair = np.concatenate([grid_ind, labels], axis=1) 350 | label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :] 351 | processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair) 352 | data_tuple = (voxel_position, processed_label) 353 | 354 | # center data on each voxel for PTnet 355 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 356 | return_xyz = xyz_pol - voxel_centers 357 | return_xyz = np.concatenate((return_xyz, xyz_pol, xyz[:, :2]), axis=1) 358 | 359 | if len(data) == 2: 360 | return_fea = return_xyz 361 | elif len(data) == 3 or len(data) == 4: 362 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) 363 | 364 | if self.return_test: 365 | data_tuple += (grid_ind, labels, return_fea, index) 366 | else: 367 | data_tuple += (grid_ind, labels, return_fea) 368 | 369 | if len(data) == 4: 370 | data_tuple += (origin_len,) 371 | return data_tuple 372 | 373 | 374 | @register_dataset 375 | class polar_dataset(data.Dataset): 376 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 377 | fixed_volume_space=False, max_volume_space=[50, np.pi, 2], 
min_volume_space=[0, -np.pi, -4], 378 | scale_aug=False): 379 | self.point_cloud_dataset = in_dataset 380 | self.grid_size = np.asarray(grid_size) 381 | self.rotate_aug = rotate_aug 382 | self.flip_aug = flip_aug 383 | self.scale_aug = scale_aug 384 | self.ignore_label = ignore_label 385 | self.return_test = return_test 386 | self.fixed_volume_space = fixed_volume_space 387 | self.max_volume_space = max_volume_space 388 | self.min_volume_space = min_volume_space 389 | 390 | def __len__(self): 391 | 'Denotes the total number of samples' 392 | return len(self.point_cloud_dataset) 393 | 394 | def __getitem__(self, index): 395 | 'Generates one sample of data' 396 | data = self.point_cloud_dataset[index] 397 | if len(data) == 2: 398 | xyz, labels = data 399 | elif len(data) == 3: 400 | xyz, labels, sig = data 401 | if len(sig.shape) == 2: 402 | sig = np.squeeze(sig) 403 | else: 404 | raise Exception('Return invalid data tuple') 405 | 406 | # random data augmentation by rotation 407 | if self.rotate_aug: 408 | rotate_rad = np.deg2rad(np.random.random() * 45) - np.pi / 8 409 | c, s = np.cos(rotate_rad), np.sin(rotate_rad) 410 | j = np.matrix([[c, s], [-s, c]]) 411 | xyz[:, :2] = np.dot(xyz[:, :2], j) 412 | 413 | # random data augmentation by flip x , y or x+y 414 | if self.flip_aug: 415 | flip_type = np.random.choice(4, 1) 416 | if flip_type == 1: 417 | xyz[:, 0] = -xyz[:, 0] 418 | elif flip_type == 2: 419 | xyz[:, 1] = -xyz[:, 1] 420 | elif flip_type == 3: 421 | xyz[:, :2] = -xyz[:, :2] 422 | if self.scale_aug: 423 | noise_scale = np.random.uniform(0.95, 1.05) 424 | xyz[:, 0] = noise_scale * xyz[:, 0] 425 | xyz[:, 1] = noise_scale * xyz[:, 1] 426 | xyz_pol = cart2polar(xyz) 427 | 428 | max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0) 429 | min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0) 430 | max_bound = np.max(xyz_pol[:, 1:], axis=0) 431 | min_bound = np.min(xyz_pol[:, 1:], axis=0) 432 | max_bound = np.concatenate(([max_bound_r], max_bound)) 433 | min_bound = np.concatenate(([min_bound_r], min_bound)) 434 | if self.fixed_volume_space: 435 | max_bound = np.asarray(self.max_volume_space) 436 | min_bound = np.asarray(self.min_volume_space) 437 | # get grid index 438 | crop_range = max_bound - min_bound 439 | cur_grid_size = self.grid_size 440 | intervals = crop_range / (cur_grid_size - 1) 441 | 442 | if (intervals == 0).any(): print("Zero interval!") 443 | grid_ind = (np.floor((np.clip(xyz_pol, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 444 | 445 | voxel_position = np.zeros(self.grid_size, dtype=np.float32) 446 | dim_array = np.ones(len(self.grid_size) + 1, int) 447 | dim_array[0] = -1 448 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 449 | voxel_position = polar2cat(voxel_position) 450 | 451 | processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.ignore_label 452 | label_voxel_pair = np.concatenate([grid_ind, labels], axis=1) 453 | label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :] 454 | processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair) 455 | data_tuple = (voxel_position, processed_label) 456 | 457 | # center data on each voxel for PTnet 458 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 459 | return_xyz = xyz_pol - voxel_centers 460 | return_xyz = np.concatenate((return_xyz, xyz_pol, xyz[:, :2]), axis=1) 461 | 462 | if len(data) == 2: 463 | return_fea = return_xyz 464 | elif 
len(data) == 3: 465 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) 466 | 467 | if self.return_test: 468 | data_tuple += (grid_ind, labels, return_fea, index) 469 | else: 470 | data_tuple += (grid_ind, labels, return_fea) 471 | 472 | return data_tuple 473 | 474 | 475 | @nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False) 476 | def nb_process_label(processed_label, sorted_label_voxel_pair): 477 | label_size = 256 478 | counter = np.zeros((label_size,), dtype=np.uint16) 479 | counter[sorted_label_voxel_pair[0, 3]] = 1 480 | cur_sear_ind = sorted_label_voxel_pair[0, :3] 481 | for i in range(1, sorted_label_voxel_pair.shape[0]): 482 | cur_ind = sorted_label_voxel_pair[i, :3] 483 | if not np.all(np.equal(cur_ind, cur_sear_ind)): 484 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 485 | counter = np.zeros((label_size,), dtype=np.uint16) 486 | cur_sear_ind = cur_ind 487 | counter[sorted_label_voxel_pair[i, 3]] += 1 488 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 489 | return processed_label 490 | 491 | 492 | def collate_fn_BEV(data): 493 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 494 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 495 | grid_ind_stack = [d[2] for d in data] 496 | point_label = [d[3] for d in data] 497 | xyz = [d[4] for d in data] 498 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz 499 | 500 | '''def collate_fn_BEV(data): 501 | 502 | voxel_label = [] 503 | 504 | for da1 in data: 505 | for da2 in da1: 506 | voxel_label.append(da2[1]) 507 | 508 | voxel_label = np.stack(voxel_label).astype(np.int) 509 | 510 | grid_ind_stack = [] 511 | for da1 in data: 512 | for da2 in da1: 513 | grid_ind_stack.append(da2[2]) 514 | 515 | 516 | point_label = [] 517 | 518 | for da1 in data: 519 | for da2 in da1: 520 | point_label.append(da2[3]) 521 | 522 | xyz = [] 523 | 524 | for da1 in data: 525 | for da2 in da1: 526 | xyz.append(da2[4]) 527 | 528 | return xyz, torch.from_numpy(voxel_label), grid_ind_stack, point_label, xyz 529 | ''' 530 | 531 | '''def collate_fn_BEV_test_old(data): 532 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 533 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 534 | grid_ind_stack = [d[2] for d in data] 535 | point_label = [d[3] for d in data] 536 | xyz = [d[4] for d in data] 537 | index = [d[5] for d in data] 538 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index 539 | ''' 540 | 541 | def collate_fn_BEV_tta(data): 542 | 543 | voxel_label = [] 544 | 545 | for da1 in data: 546 | for da2 in da1: 547 | voxel_label.append(da2[1]) 548 | 549 | #voxel_label.astype(np.int) 550 | 551 | grid_ind_stack = [] 552 | for da1 in data: 553 | for da2 in da1: 554 | grid_ind_stack.append(da2[2]) 555 | 556 | 557 | 558 | point_label = [] 559 | 560 | for da1 in data: 561 | for da2 in da1: 562 | point_label.append(da2[3]) 563 | 564 | xyz = [] 565 | 566 | for da1 in data: 567 | for da2 in da1: 568 | xyz.append(da2[4]) 569 | index = [] 570 | for da1 in data: 571 | for da2 in da1: 572 | index.append(da2[5]) 573 | 574 | return xyz, xyz, grid_ind_stack, point_label, xyz, index 575 | 576 | 577 | def collate_fn_BEV_ms(data): 578 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 579 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 580 | grid_ind_stack = [d[2] 
492 | def collate_fn_BEV(data): 493 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 494 | label2stack = np.stack([d[1] for d in data]).astype(np.int64) 495 | grid_ind_stack = [d[2] for d in data] 496 | point_label = [d[3] for d in data] 497 | xyz = [d[4] for d in data] 498 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz 499 | 500 | '''def collate_fn_BEV(data): 501 | 502 | voxel_label = [] 503 | 504 | for da1 in data: 505 | for da2 in da1: 506 | voxel_label.append(da2[1]) 507 | 508 | voxel_label = np.stack(voxel_label).astype(np.int) 509 | 510 | grid_ind_stack = [] 511 | for da1 in data: 512 | for da2 in da1: 513 | grid_ind_stack.append(da2[2]) 514 | 515 | 516 | point_label = [] 517 | 518 | for da1 in data: 519 | for da2 in da1: 520 | point_label.append(da2[3]) 521 | 522 | xyz = [] 523 | 524 | for da1 in data: 525 | for da2 in da1: 526 | xyz.append(da2[4]) 527 | 528 | return xyz, torch.from_numpy(voxel_label), grid_ind_stack, point_label, xyz 529 | ''' 530 | 531 | '''def collate_fn_BEV_test_old(data): 532 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 533 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 534 | grid_ind_stack = [d[2] for d in data] 535 | point_label = [d[3] for d in data] 536 | xyz = [d[4] for d in data] 537 | index = [d[5] for d in data] 538 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index 539 | ''' 540 | 541 | def collate_fn_BEV_tta(data): 542 | 543 | voxel_label = [] 544 | 545 | for da1 in data: 546 | for da2 in da1: 547 | voxel_label.append(da2[1]) 548 | 549 | #voxel_label.astype(np.int) 550 | 551 | grid_ind_stack = [] 552 | for da1 in data: 553 | for da2 in da1: 554 | grid_ind_stack.append(da2[2]) 555 | 556 | 557 | 558 | point_label = [] 559 | 560 | for da1 in data: 561 | for da2 in da1: 562 | point_label.append(da2[3]) 563 | 564 | xyz = [] 565 | 566 | for da1 in data: 567 | for da2 in da1: 568 | xyz.append(da2[4]) 569 | index = [] 570 | for da1 in data: 571 | for da2 in da1: 572 | index.append(da2[5]) 573 | 574 | return xyz, xyz, grid_ind_stack, point_label, xyz, index 575 | 576 | 577 | def collate_fn_BEV_ms(data): 578 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 579 | label2stack = np.stack([d[1] for d in data]).astype(np.int64) 580 | grid_ind_stack = [d[2] for d in data] 581 | point_label = [d[3] for d in data] 582 | xyz = [d[4] for d in data] 583 | # origin_len = [d[5] for d in data] 584 | index = [d[5] for d in data] 585 | origin_len = [d[6] for d in data] 586 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len 587 | 588 | 589 | def collate_fn_BEV_ms_tta(data): 590 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 591 | label2stack = np.stack([d[1] for d in data]).astype(np.int64) 592 | grid_ind_stack = [d[2] for d in data] 593 | point_label = [d[3] for d in data] 594 | xyz = [d[4] for d in data] 595 | index = [d[5] for d in data] 596 | origin_len = [d[6] for d in data] 597 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len 598 | 599 | 600 | # def collate_fn_BEV_ms_tta(data): 601 | # 602 | # voxel_label = [] 603 | # 604 | # for da1 in data: 605 | # for da2 in da1: 606 | # voxel_label.append(da2[1]) 607 | # 608 | # #voxel_label.astype(np.int) 609 | # 610 | # grid_ind_stack = [] 611 | # for da1 in data: 612 | # for da2 in da1: 613 | # grid_ind_stack.append(da2[2]) 614 | # 615 | # point_label = [] 616 | # 617 | # for da1 in data: 618 | # for da2 in da1: 619 | # point_label.append(da2[3]) 620 | # 621 | # xyz = [] 622 | # 623 | # for da1 in data: 624 | # for da2 in da1: 625 | # xyz.append(da2[4]) 626 | # 627 | # index = [] 628 | # 629 | # for da1 in data: 630 | # for da2 in da1: 631 | # index.append(da2[5]) 632 | # 633 | # origin_len = [] 634 | # 635 | # for da1 in data: 636 | # for da2 in da1: 637 | # origin_len.append(da2[6]) 638 | # 639 | # return xyz, xyz, grid_ind_stack, point_label, xyz, index, origin_len 640 | -------------------------------------------------------------------------------- /dataloader with label rectification/dataloader/pc_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: pc_dataset.py 4 | 5 | import os 6 | import numpy as np 7 | from torch.utils import data 8 | import yaml 9 | import pickle 10 | import pathlib 11 | 12 | REGISTERED_PC_DATASET_CLASSES = {} 13 | 14 | 15 | def register_dataset(cls, name=None): 16 | global REGISTERED_PC_DATASET_CLASSES 17 | if name is None: 18 | name = cls.__name__ 19 | assert name not in REGISTERED_PC_DATASET_CLASSES, f"exist class: {REGISTERED_PC_DATASET_CLASSES}" 20 | REGISTERED_PC_DATASET_CLASSES[name] = cls 21 | return cls 22 | 23 | 24 | def get_pc_model_class(name): 25 | global REGISTERED_PC_DATASET_CLASSES 26 | assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}" 27 | return REGISTERED_PC_DATASET_CLASSES[name] 28 | 29 | 30 | @register_dataset 31 | class SemKITTI_sk(data.Dataset): 32 | def __init__(self, data_path, imageset='train', 33 | return_ref=False, label_mapping="semantic-kitti.yaml", nusc=None): 34 | self.return_ref = return_ref 35 | with open(label_mapping, 'r') as stream: 36 | semkittiyaml = yaml.safe_load(stream) 37 | self.learning_map = semkittiyaml['learning_map'] 38 | self.imageset = imageset 39 | if imageset == 'train': 40 | split = semkittiyaml['split']['train'] 41 | elif imageset == 'val': 42 | split = semkittiyaml['split']['valid'] 43 | elif imageset == 'test': 44 | split = semkittiyaml['split']['test'] 45 | else: 46 | raise Exception('Split must be train/val/test') 47 | 48 | self.im_idx = [] 49 | for i_folder in split: 50 | self.im_idx += 
absoluteFilePaths('/'.join([data_path, str(i_folder).zfill(2), 'velodyne'])) 51 | 52 | def __len__(self): 53 | 'Denotes the total number of samples' 54 | return len(self.im_idx) 55 | 56 | def __getitem__(self, index): 57 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) 58 | if self.imageset == 'test': 59 | annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1) 60 | else: 61 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label', 62 | dtype=np.uint32).reshape((-1, 1)) 63 | annotated_data = annotated_data & 0xFFFF # keep the lower 16 bits (semantic label); the upper 16 bits hold the instance id 64 | annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data) 65 | 66 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) 67 | if self.return_ref: 68 | data_tuple += (raw_data[:, 3],) 69 | return data_tuple 70 | 71 | 72 | def absoluteFilePaths(directory): 73 | for dirpath, _, filenames in os.walk(directory): 74 | filenames.sort() 75 | for f in filenames: 76 | yield os.path.abspath(os.path.join(dirpath, f)) 77 | 78 | 79 | def SemKITTI2train(label): 80 | if isinstance(label, list): 81 | return [SemKITTI2train_single(a) for a in label] 82 | else: 83 | return SemKITTI2train_single(label) 84 | 85 | 86 | def SemKITTI2train_single(label): 87 | remove_ind = label == 0 88 | label -= 1 89 | label[remove_ind] = 255 90 | return label 91 | 92 | 93 | def unpack(compressed): # from the semantic-kitti-api 94 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. ''' 95 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8) 96 | uncompressed[::8] = compressed[:] >> 7 & 1 97 | uncompressed[1::8] = compressed[:] >> 6 & 1 98 | uncompressed[2::8] = compressed[:] >> 5 & 1 99 | uncompressed[3::8] = compressed[:] >> 4 & 1 100 | uncompressed[4::8] = compressed[:] >> 3 & 1 101 | uncompressed[5::8] = compressed[:] >> 2 & 1 102 | uncompressed[6::8] = compressed[:] >> 1 & 1 103 | uncompressed[7::8] = compressed[:] & 1 104 | 105 | return uncompressed 106 | 107 | def get_eval_mask(labels, invalid_voxels): # from the semantic-kitti-api 108 | """ 109 | Ignore labels set to 255 and invalid voxels (the ones never hit by a laser ray, probed using ray tracing) 110 | :param labels: input ground truth voxels 111 | :param invalid_voxels: voxels ignored during evaluation since they lie beyond the scene that was captured by the laser 112 | :return: boolean mask to subsample the voxels to evaluate 113 | """ 114 | masks = np.ones_like(labels, dtype=bool) 115 | masks[labels == 255] = False 116 | masks[invalid_voxels == 1] = False 117 | 118 | return masks 119 | 120 | 
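Since the `voxels/*.bin` grids are bit-packed (eight voxels per byte, most significant bit first), `unpack` above is equivalent to `np.unpackbits` — a quick sanity check with a made-up byte, assuming the definition above is in scope:

```
import numpy as np

packed = np.array([0b10110001], dtype=np.uint8)
occupancy = unpack(packed)                       # -> [1 0 1 1 0 0 0 1]
assert np.array_equal(occupancy, np.unpackbits(packed))
```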
121 | from os.path import join 122 | @register_dataset 123 | class SemKITTI_sk_multiscan(data.Dataset): 124 | def __init__(self, data_path, imageset='train', return_ref=False, label_mapping="semantic-kitti-multiscan.yaml", nusc=None): 125 | self.return_ref = return_ref 126 | with open(label_mapping, 'r') as stream: 127 | semkittiyaml = yaml.safe_load(stream) 128 | ### remap completion label 129 | remapdict = semkittiyaml['learning_map'] 130 | # make lookup table for mapping 131 | maxkey = max(remapdict.keys()) 132 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32) 133 | remap_lut[list(remapdict.keys())] = list(remapdict.values()) 134 | # in completion we have to distinguish empty and invalid voxels. 135 | # Important: for voxels, 0 corresponds to "empty", not "unlabeled". 136 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' 137 | remap_lut[0] = 0 # only 'empty' stays 'empty'. 138 | self.comletion_remap_lut = remap_lut 139 | 140 | self.learning_map = semkittiyaml['learning_map'] 141 | self.imageset = imageset 142 | self.data_path = data_path 143 | if imageset == 'train': 144 | split = semkittiyaml['split']['train'] 145 | elif imageset == 'val': 146 | split = semkittiyaml['split']['valid'] 147 | elif imageset == 'test': 148 | split = semkittiyaml['split']['test'] 149 | else: 150 | raise Exception('Split must be train/val/test') 151 | 152 | multiscan = 4 # additional frames are fused with the target frame; hence multiscan+1 point clouds in total 153 | print('multiscan: %d' % multiscan) 154 | self.multiscan = multiscan 155 | self.im_idx = [] 156 | 157 | self.calibrations = [] 158 | self.times = [] 159 | self.poses = [] 160 | 161 | self.load_calib_poses() 162 | for i_folder in split: 163 | # velodyne path corresponding to voxel path 164 | complete_path = os.path.join(data_path, str(i_folder).zfill(2), "voxels") 165 | files = list(pathlib.Path(complete_path).glob('*.bin')) 166 | for filename in files: 167 | self.im_idx.append(str(filename).replace('voxels', 'velodyne')) 168 | 169 | 170 | 171 | def __len__(self): 172 | 'Denotes the total number of samples' 173 | return len(self.im_idx) 174 | 175 | def load_calib_poses(self): 176 | """ 177 | load calib poses and times. 178 | """ 179 | 180 | ########### 181 | # Load data 182 | ########### 183 | 184 | self.calibrations = [] 185 | self.times = [] 186 | self.poses = [] 187 | 188 | for seq in range(0, 22): 189 | seq_folder = join(self.data_path, str(seq).zfill(2)) 190 | 191 | # Read Calib 192 | self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt"))) 193 | 194 | # Read times 195 | self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32)) 196 | 197 | # Read poses 198 | poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1]) 199 | self.poses.append([pose.astype(np.float32) for pose in poses_f64]) 200 | 201 | def parse_calibration(self, filename): 202 | """ read calibration file with given filename 203 | 204 | Returns 205 | ------- 206 | dict 207 | Calibration matrices as 4x4 numpy arrays. 208 | """ 209 | calib = {} 210 | 211 | calib_file = open(filename) 212 | for line in calib_file: 213 | key, content = line.strip().split(":") 214 | values = [float(v) for v in content.strip().split()] 215 | 216 | pose = np.zeros((4, 4)) 217 | pose[0, 0:4] = values[0:4] 218 | pose[1, 0:4] = values[4:8] 219 | pose[2, 0:4] = values[8:12] 220 | pose[3, 3] = 1.0 221 | 222 | calib[key] = pose 223 | 224 | calib_file.close() 225 | 226 | return calib 227 | 228 | def parse_poses(self, filename, calibration): 229 | """ read poses file with per-scan poses from given filename 230 | 231 | Returns 232 | ------- 233 | list 234 | list of poses as 4x4 numpy arrays. 
235 | """ 236 | file = open(filename) 237 | 238 | poses = [] 239 | 240 | Tr = calibration["Tr"] 241 | Tr_inv = np.linalg.inv(Tr) 242 | 243 | for line in file: 244 | values = [float(v) for v in line.strip().split()] 245 | 246 | pose = np.zeros((4, 4)) 247 | pose[0, 0:4] = values[0:4] 248 | pose[1, 0:4] = values[4:8] 249 | pose[2, 0:4] = values[8:12] 250 | pose[3, 3] = 1.0 251 | 252 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 253 | 254 | return poses 255 | 256 | def fuse_multi_scan(self, points, pose0, pose): 257 | 258 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 259 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 260 | new_points = new_points[:, :3] 261 | new_coords = new_points - pose0[:3, 3] 262 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 263 | new_coords = np.hstack((new_coords, points[:, 3:])) 264 | 265 | return new_coords 266 | 267 | def __getitem__(self, index): 268 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) # point cloud 269 | origin_len = len(raw_data) 270 | 271 | number_idx = int(self.im_idx[index][-10:-4]) 272 | dir_idx = int(self.im_idx[index][-22:-20]) 273 | 274 | pose0 = self.poses[dir_idx][number_idx] 275 | 276 | if self.imageset == 'test': 277 | instance_label = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1) 278 | else: 279 | instance_label = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label', 280 | dtype=np.int32).reshape((-1, 1)) 281 | instance_label = instance_label & 0xFFFF # delete high 16 digits binary 282 | 283 | for fuse_idx in range(self.multiscan): 284 | plus_idx = fuse_idx + 1 285 | data_idx = number_idx + plus_idx 286 | path = self.im_idx[index][:-10] 287 | newpath2 = path + str(data_idx).zfill(6) + self.im_idx[index][-4:] 288 | voxel_path = path.replace('velodyne', 'labels') 289 | files = list(pathlib.Path(voxel_path).glob('*.label')) 290 | 291 | if data_idx < len(files): 292 | pose = self.poses[dir_idx][data_idx] 293 | 294 | raw_data2 = np.fromfile(newpath2, dtype=np.float32).reshape((-1, 4)) 295 | 296 | if self.imageset == 'test': 297 | instance_label2 = np.expand_dims(np.zeros_like(raw_data2[:, 0], dtype=int), axis=1) 298 | else: 299 | instance_label2 = np.fromfile(newpath2.replace('velodyne', 'labels')[:-3] + 'label', 300 | dtype=np.int32).reshape((-1, 1)) 301 | instance_label2 = instance_label2 & 0xFFFF # delete high 16 digits binary 302 | 303 | raw_data2 = self.fuse_multi_scan(raw_data2, pose0, pose) 304 | 305 | if len(raw_data2) != 0: 306 | raw_data = np.concatenate((raw_data, raw_data2), 0) 307 | instance_label = np.concatenate((instance_label, instance_label2), 0) 308 | 309 | if self.imageset == 'test': 310 | voxel_label = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1)) 311 | else: 312 | voxel_label = np.fromfile(self.im_idx[index].replace('velodyne', 'voxels')[:-3] + 'label', 313 | dtype=np.uint16).reshape((-1, 1)) # voxel labels 314 | 315 | voxel_label = self.comletion_remap_lut[voxel_label] 316 | voxel_label = voxel_label.reshape((256, 256, 32)) 317 | 318 | data_tuple = (raw_data[:, :3], voxel_label.astype(np.uint8), instance_label) # xyz, voxel labels 319 | 320 | if self.return_ref: 321 | data_tuple += (raw_data[:, 3], origin_len) # origin_len is used to indicate the length of target-scan 322 | 323 | return data_tuple 324 | 325 | 326 | # load Semantic KITTI class info 327 | def get_SemKITTI_label_name(label_mapping): 328 | with open(label_mapping, 'r') as stream: 329 | semkittiyaml 
= yaml.safe_load(stream) 330 | SemKITTI_label_name = dict() 331 | for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]: 332 | SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i] 333 | 334 | return SemKITTI_label_name 335 | 336 | 337 | def get_nuScenes_label_name(label_mapping): 338 | with open(label_mapping, 'r') as stream: 339 | nuScenesyaml = yaml.safe_load(stream) 340 | nuScenes_label_name = dict() 341 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]: 342 | val_ = nuScenesyaml['learning_map'][i] 343 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_] 344 | 345 | return nuScenes_label_name 346 | -------------------------------------------------------------------------------- /dataloader with label rectification/visualize_voxel_label.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import numpy as np 4 | import yaml 5 | 6 | import torch 7 | 8 | import open3d as o3d 9 | 10 | 11 | def get_cmap_semanticKITTI20(): 12 | colors = np.array([ 13 | [0 , 0 , 0, 255], 14 | [100, 150, 245, 255], 15 | [100, 230, 245, 255], 16 | [30, 60, 150, 255], 17 | [80, 30, 180, 255], 18 | [100, 80, 250, 255], 19 | [255, 30, 30, 255], 20 | [255, 40, 200, 255], 21 | [150, 30, 90, 255], 22 | [255, 0, 255, 255], 23 | [255, 150, 255, 255], 24 | [75, 0, 75, 255], 25 | [175, 0, 75, 255], 26 | [255, 200, 0, 255], 27 | [255, 120, 50, 255], 28 | [0, 175, 0, 255], 29 | [135, 60, 0, 255], 30 | [150, 240, 80, 255], 31 | [255, 240, 150, 255], 32 | [255, 0, 0, 255]]).astype(np.uint8) 33 | 34 | return colors 35 | 36 | 37 | if __name__ == "__main__": 38 | 39 | dataset_root = "./labels" 40 | voxels_dir = glob.glob(os.path.join(dataset_root, "*.pt")) 41 | voxels_dir.sort() 42 | 43 | max_volume_space = [51.2, 25.6, 4.4] 44 | min_volume_space = [0, -25.6, -2.0] 45 | 46 | 47 | with open("./config/semantic-kitti.yaml", 'r') as stream: 48 | semkittiyaml = yaml.safe_load(stream) 49 | 50 | learning_map = semkittiyaml['learning_map'].copy() 51 | for key, value in learning_map.items(): 52 | if key != 0 and value == 0: 53 | learning_map[key] = 255 54 | 55 | color_map = get_cmap_semanticKITTI20()[:,:3] 56 | 57 | for voxel_path in voxels_dir: 58 | 59 | data = torch.load(voxel_path) 60 | 61 | x = np.arange(256) 62 | y = np.arange(256) 63 | z = np.arange(32) 64 | 65 | Y, X, Z = np.meshgrid(x,y,z) 66 | 67 | full_voxel_coord = np.concatenate((X[...,None], Y[...,None], Z[...,None]), axis=-1) 68 | 69 | voxel_label_org = data['voxel_label_org'] 70 | voxel_label_rect = data['voxel_label_rect'] 71 | 72 | mask = (voxel_label_org!=0) & (voxel_label_org!=255) 73 | full_voxel_coord_org = full_voxel_coord[mask] 74 | voxel_label_org = voxel_label_org[mask] 75 | 76 | mask = (voxel_label_rect!=0) & (voxel_label_rect!=255) 77 | full_voxel_coord_rect = full_voxel_coord[mask] 78 | voxel_label_rect = voxel_label_rect[mask] 79 | 80 | 81 | voxel_label_org_o3d = o3d.geometry.PointCloud() 82 | voxel_label_org_o3d.points = o3d.utility.Vector3dVector(full_voxel_coord_org) 83 | voxel_label_org_o3d.colors = o3d.utility.Vector3dVector(color_map[voxel_label_org.astype(int)].astype(float)/255) 84 | 85 | voxel_label_rect_o3d = o3d.geometry.PointCloud() 86 | voxel_label_rect_o3d.points = o3d.utility.Vector3dVector(full_voxel_coord_rect) 87 | voxel_label_rect_o3d.colors = o3d.utility.Vector3dVector(color_map[voxel_label_rect.astype(int)].astype(float)/255) 88 | 89 | coord_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=5, 
origin=[0, 0, 0]) 90 | 91 | o3d.visualization.draw_geometries([coord_frame, 92 | voxel_label_org_o3d, 93 | ], 94 | zoom=0.1, 95 | front=[-np.sqrt(3)/2, 0, 1/2], 96 | lookat=[-50, 128, 178/np.sqrt(3)], 97 | up=[0, 0, 1], 98 | window_name ='voxel label original') 99 | 100 | 101 | print() 102 | 103 | o3d.visualization.draw_geometries([coord_frame, 104 | voxel_label_rect_o3d, 105 | ], 106 | zoom=0.1, 107 | front=[-np.sqrt(3)/2, 0, 1/2], 108 | lookat=[-50, 128, 178/np.sqrt(3)], 109 | up=[0, 0, 1], 110 | window_name ='voxel label rectified') 111 | 112 | print() -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: __init__.py.py 4 | 5 | # from . import dataset_nuscenes -------------------------------------------------------------------------------- /dataloader/dataset_semantickitti.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | """ 5 | SemKITTI dataloader 6 | """ 7 | import numpy as np 8 | import torch 9 | import numba as nb 10 | from torch.utils import data 11 | 12 | REGISTERED_DATASET_CLASSES = {} 13 | 14 | 15 | def register_dataset(cls, name=None): 16 | global REGISTERED_DATASET_CLASSES 17 | if name is None: 18 | name = cls.__name__ 19 | assert name not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}" 20 | REGISTERED_DATASET_CLASSES[name] = cls 21 | return cls 22 | 23 | 24 | def get_model_class(name): 25 | global REGISTERED_DATASET_CLASSES 26 | assert name in REGISTERED_DATASET_CLASSES, f"available class: {REGISTERED_DATASET_CLASSES}" 27 | return REGISTERED_DATASET_CLASSES[name] 28 | 29 | 30 | @register_dataset 31 | class voxel_dataset(data.Dataset): 32 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 33 | fixed_volume_space=False, max_volume_space=[50, 50, 1.5], min_volume_space=[-50, -50, -3]): 34 | 'Initialization' 35 | self.point_cloud_dataset = in_dataset 36 | self.grid_size = np.asarray(grid_size) 37 | self.rotate_aug = rotate_aug 38 | self.ignore_label = ignore_label 39 | self.return_test = return_test 40 | self.flip_aug = flip_aug 41 | self.fixed_volume_space = fixed_volume_space 42 | self.max_volume_space = max_volume_space 43 | self.min_volume_space = min_volume_space 44 | 45 | def __len__(self): 46 | 'Denotes the total number of samples' 47 | return len(self.point_cloud_dataset) 48 | 49 | def __getitem__(self, index): 50 | 'Generates one sample of data' 51 | data = self.point_cloud_dataset[index] 52 | if len(data) == 2: 53 | xyz, labels = data 54 | elif len(data) == 3: 55 | xyz, labels, sig = data 56 | if len(sig.shape) == 2: sig = np.squeeze(sig) 57 | elif len(data) == 4: 58 | xyz, labels, sig, origin_len = data 59 | if len(sig.shape) == 2: sig = np.squeeze(sig) 60 | else: 61 | raise Exception('Return invalid data tuple') 62 | 63 | # random data augmentation by rotation 64 | if self.rotate_aug: 65 | rotate_rad = np.deg2rad(np.random.random() * 360) 66 | c, s = np.cos(rotate_rad), np.sin(rotate_rad) 67 | j = np.matrix([[c, s], [-s, c]]) 68 | xyz[:, :2] = np.dot(xyz[:, :2], j) 69 | 70 | # random data augmentation by flip x , y or x+y 71 | if self.flip_aug: 72 | flip_type = np.random.choice(4, 1) 73 | if flip_type == 1: 74 | xyz[:, 0] = -xyz[:, 0] 75 | elif flip_type == 2: 76 | xyz[:, 1] = -xyz[:, 1] 77 | elif 
flip_type == 3: 78 | xyz[:, :2] = -xyz[:, :2] 79 | 80 | 81 | if self.fixed_volume_space: 82 | max_bound = np.asarray(self.max_volume_space) 83 | min_bound = np.asarray(self.min_volume_space) 84 | else: 85 | max_bound = np.percentile(xyz, 100, axis=0) 86 | min_bound = np.percentile(xyz, 0, axis=0) 87 | 88 | ### Crop the point cloud (and per-point data) to the valid range 89 | cut_point = 1 90 | if cut_point == 1: 91 | xyz0 = xyz.copy() 92 | for ci in range(3): 93 | xyz0[xyz[:, ci] < min_bound[ci], :] = 1000 94 | xyz0[xyz[:, ci] > max_bound[ci], :] = 1000 95 | valid_inds = xyz0[:, 0] != 1000 96 | xyz = xyz[valid_inds, :] 97 | if len(data) > 2: sig = sig[valid_inds] # sig only exists when return_ref is set 98 | 99 | # shift the x axis so the completion volume is centred on the sensor 100 | trans_x = 1 101 | if trans_x: 102 | x_bias = (self.max_volume_space[0] - self.min_volume_space[0])/2 103 | min_bound[0] -= x_bias 104 | max_bound[0] -= x_bias 105 | xyz[:, 0] -= x_bias 106 | 107 | if len(data) == 4: 108 | origin_len = len(xyz) 109 | 110 | # get grid index 111 | crop_range = max_bound - min_bound 112 | cur_grid_size = self.grid_size 113 | 114 | intervals = crop_range / (cur_grid_size - 1) 115 | 116 | if (intervals == 0).any(): print("Zero interval!") 117 | 118 | grid_ind = (np.floor((np.clip(xyz, min_bound, max_bound) - min_bound) / intervals)).astype(np.int64) 119 | 120 | # process voxel position 121 | dim_array = np.ones(len(self.grid_size) + 1, int) 122 | dim_array[0] = -1 123 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 124 | 125 | ## process labels 126 | # processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.ignore_label 127 | # label_voxel_pair = np.concatenate([grid_ind, labels], axis=1) 128 | # label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :] 129 | # processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair) 130 | 131 | processed_label = labels # voxel labels 132 | 133 | data_tuple = (voxel_position, processed_label) 134 | 135 | # center data on each voxel for PTnet 136 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 137 | return_xyz = xyz - voxel_centers 138 | return_xyz = np.concatenate((return_xyz, xyz), axis=1) 139 | 140 | if len(data) == 2: 141 | return_fea = return_xyz 142 | elif len(data) == 3 or len(data) == 4: 143 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) # 7 dims: xyz offset to the voxel centre + xyz + intensity 144 | 145 | if self.return_test: 146 | data_tuple += (grid_ind, labels, return_fea, index) 147 | else: 148 | data_tuple += (grid_ind, labels, return_fea) 149 | 150 | if len(data) == 4: 151 | data_tuple += (origin_len,) 152 | 153 | return data_tuple 154 | 155 | 156 | # transformation between Cartesian coordinates and polar coordinates 157 | def cart2polar(input_xyz): 158 | rho = np.sqrt(input_xyz[:, 0] ** 2 + input_xyz[:, 1] ** 2) 159 | phi = np.arctan2(input_xyz[:, 1], input_xyz[:, 0]) 160 | return np.stack((rho, phi, input_xyz[:, 2]), axis=1) 161 | 162 | 163 | def polar2cat(input_xyz_polar): 164 | # print(input_xyz_polar.shape) 165 | x = input_xyz_polar[0] * np.cos(input_xyz_polar[1]) 166 | y = input_xyz_polar[0] * np.sin(input_xyz_polar[1]) 167 | return np.stack((x, y, input_xyz_polar[2]), axis=0) 168 | 169 | 170 | @register_dataset 171 | class cylinder_dataset(data.Dataset): 172 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 173 | fixed_volume_space=False, max_volume_space=[50, np.pi, 2], min_volume_space=[0, -np.pi, -4], 
174 | scale_aug=False, 175 | transform_aug=False, trans_std=[0.1, 0.1, 0.1], 176 | min_rad=-np.pi / 4, max_rad=np.pi / 4, use_tta=False): 177 | self.point_cloud_dataset = in_dataset 178 | self.grid_size = np.asarray(grid_size) 179 | self.rotate_aug = rotate_aug 180 | self.flip_aug = flip_aug 181 | self.scale_aug = scale_aug 182 | self.ignore_label = ignore_label 183 | self.return_test = return_test 184 | self.fixed_volume_space = fixed_volume_space 185 | self.max_volume_space = max_volume_space 186 | self.min_volume_space = min_volume_space 187 | self.transform = transform_aug 188 | self.trans_std = trans_std 189 | 190 | self.noise_rotation = np.random.uniform(min_rad, max_rad) 191 | self.use_tta = use_tta 192 | 193 | def __len__(self): 194 | 'Denotes the total number of samples' 195 | return len(self.point_cloud_dataset) 196 | 197 | def rotation_points_single_angle(self, points, angle, axis=0): 198 | # points: [N, 3] 199 | rot_sin = np.sin(angle) 200 | rot_cos = np.cos(angle) 201 | if axis == 1: 202 | rot_mat_T = np.array( 203 | [[rot_cos, 0, -rot_sin], [0, 1, 0], [rot_sin, 0, rot_cos]], 204 | dtype=points.dtype) 205 | elif axis == 2 or axis == -1: 206 | rot_mat_T = np.array( 207 | [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]], 208 | dtype=points.dtype) 209 | elif axis == 0: 210 | rot_mat_T = np.array( 211 | [[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]], 212 | dtype=points.dtype) 213 | else: 214 | raise ValueError("axis should in range") 215 | 216 | return points @ rot_mat_T 217 | 218 | def __getitem__(self, index): 219 | 'Generates one sample of data' 220 | data = self.point_cloud_dataset[index] 221 | if self.use_tta: 222 | data_total = [] 223 | voting = 4 224 | for idx in range(voting): 225 | data_single_ori = self.get_single_sample(data, index, idx) 226 | data_total.append(data_single_ori) 227 | data_total = tuple(data_total) 228 | return data_total 229 | else: 230 | data_single = self.get_single_sample(data, index) 231 | return data_single 232 | 233 | def get_single_sample(self, data, index, vote_idx=0): 234 | if len(data) == 2: 235 | xyz, labels = data 236 | elif len(data) == 3: 237 | xyz, labels, sig = data 238 | if len(sig.shape) == 2: sig = np.squeeze(sig) 239 | elif len(data) == 4: 240 | xyz, labels, sig, origin_len = data 241 | if len(sig.shape) == 2: sig = np.squeeze(sig) 242 | else: 243 | raise Exception('Return invalid data tuple') 244 | 245 | # random data augmentation by rotation 246 | if self.rotate_aug: 247 | rotate_rad = np.deg2rad(np.random.random() * 90) - np.pi / 4 248 | c, s = np.cos(rotate_rad), np.sin(rotate_rad) 249 | j = np.matrix([[c, s], [-s, c]]) 250 | xyz[:, :2] = np.dot(xyz[:, :2], j) 251 | 252 | # random data augmentation by flip x , y or x+y 253 | if self.flip_aug: 254 | if self.use_tta: 255 | flip_type = vote_idx 256 | else: 257 | flip_type = np.random.choice(4, 1) 258 | if flip_type == 1: 259 | xyz[:, 0] = -xyz[:, 0] 260 | elif flip_type == 2: 261 | xyz[:, 1] = -xyz[:, 1] 262 | elif flip_type == 3: 263 | xyz[:, :2] = -xyz[:, :2] 264 | if self.scale_aug: 265 | noise_scale = np.random.uniform(0.95, 1.05) 266 | xyz[:, 0] = noise_scale * xyz[:, 0] 267 | xyz[:, 1] = noise_scale * xyz[:, 1] 268 | # convert coordinate into polar coordinates 269 | 270 | if self.transform: 271 | noise_translate = np.array([np.random.normal(0, self.trans_std[0], 1), 272 | np.random.normal(0, self.trans_std[1], 1), 273 | np.random.normal(0, self.trans_std[2], 1)]).T 274 | 275 | xyz[:, 0:3] += noise_translate 276 | 277 | xyz_pol = cart2polar(xyz) 278 | 279 
| max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0) 280 | min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0) 281 | max_bound = np.max(xyz_pol[:, 1:], axis=0) 282 | min_bound = np.min(xyz_pol[:, 1:], axis=0) 283 | max_bound = np.concatenate(([max_bound_r], max_bound)) 284 | min_bound = np.concatenate(([min_bound_r], min_bound)) 285 | if self.fixed_volume_space: 286 | max_bound = np.asarray(self.max_volume_space) 287 | min_bound = np.asarray(self.min_volume_space) 288 | # get grid index 289 | crop_range = max_bound - min_bound 290 | cur_grid_size = self.grid_size 291 | intervals = crop_range / (cur_grid_size - 1) 292 | 293 | if (intervals == 0).any(): print("Zero interval!") 294 | grid_ind = (np.floor((np.clip(xyz_pol, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 295 | 296 | voxel_position = np.zeros(self.grid_size, dtype=np.float32) 297 | dim_array = np.ones(len(self.grid_size) + 1, int) 298 | dim_array[0] = -1 299 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 300 | voxel_position = polar2cat(voxel_position) 301 | 302 | processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.ignore_label 303 | label_voxel_pair = np.concatenate([grid_ind, labels], axis=1) 304 | label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :] 305 | processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair) 306 | data_tuple = (voxel_position, processed_label) 307 | 308 | # center data on each voxel for PTnet 309 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 310 | return_xyz = xyz_pol - voxel_centers 311 | return_xyz = np.concatenate((return_xyz, xyz_pol, xyz[:, :2]), axis=1) 312 | 313 | if len(data) == 2: 314 | return_fea = return_xyz 315 | elif len(data) == 3 or len(data) == 4: 316 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) 317 | 318 | if self.return_test: 319 | data_tuple += (grid_ind, labels, return_fea, index) 320 | else: 321 | data_tuple += (grid_ind, labels, return_fea) 322 | 323 | if len(data) == 4: 324 | data_tuple += (origin_len,) 325 | return data_tuple 326 | 327 | 328 | @register_dataset 329 | class polar_dataset(data.Dataset): 330 | def __init__(self, in_dataset, grid_size, rotate_aug=False, flip_aug=False, ignore_label=255, return_test=False, 331 | fixed_volume_space=False, max_volume_space=[50, np.pi, 2], min_volume_space=[0, -np.pi, -4], 332 | scale_aug=False): 333 | self.point_cloud_dataset = in_dataset 334 | self.grid_size = np.asarray(grid_size) 335 | self.rotate_aug = rotate_aug 336 | self.flip_aug = flip_aug 337 | self.scale_aug = scale_aug 338 | self.ignore_label = ignore_label 339 | self.return_test = return_test 340 | self.fixed_volume_space = fixed_volume_space 341 | self.max_volume_space = max_volume_space 342 | self.min_volume_space = min_volume_space 343 | 344 | def __len__(self): 345 | 'Denotes the total number of samples' 346 | return len(self.point_cloud_dataset) 347 | 348 | def __getitem__(self, index): 349 | 'Generates one sample of data' 350 | data = self.point_cloud_dataset[index] 351 | if len(data) == 2: 352 | xyz, labels = data 353 | elif len(data) == 3: 354 | xyz, labels, sig = data 355 | if len(sig.shape) == 2: 356 | sig = np.squeeze(sig) 357 | else: 358 | raise Exception('Return invalid data tuple') 359 | 360 | # random data augmentation by rotation 361 | if self.rotate_aug: 362 | rotate_rad = np.deg2rad(np.random.random() * 45) - np.pi / 8 363 | c, s = 
np.cos(rotate_rad), np.sin(rotate_rad) 364 | j = np.matrix([[c, s], [-s, c]]) 365 | xyz[:, :2] = np.dot(xyz[:, :2], j) 366 | 367 | # random data augmentation by flip x , y or x+y 368 | if self.flip_aug: 369 | flip_type = np.random.choice(4, 1) 370 | if flip_type == 1: 371 | xyz[:, 0] = -xyz[:, 0] 372 | elif flip_type == 2: 373 | xyz[:, 1] = -xyz[:, 1] 374 | elif flip_type == 3: 375 | xyz[:, :2] = -xyz[:, :2] 376 | if self.scale_aug: 377 | noise_scale = np.random.uniform(0.95, 1.05) 378 | xyz[:, 0] = noise_scale * xyz[:, 0] 379 | xyz[:, 1] = noise_scale * xyz[:, 1] 380 | xyz_pol = cart2polar(xyz) 381 | 382 | max_bound_r = np.percentile(xyz_pol[:, 0], 100, axis=0) 383 | min_bound_r = np.percentile(xyz_pol[:, 0], 0, axis=0) 384 | max_bound = np.max(xyz_pol[:, 1:], axis=0) 385 | min_bound = np.min(xyz_pol[:, 1:], axis=0) 386 | max_bound = np.concatenate(([max_bound_r], max_bound)) 387 | min_bound = np.concatenate(([min_bound_r], min_bound)) 388 | if self.fixed_volume_space: 389 | max_bound = np.asarray(self.max_volume_space) 390 | min_bound = np.asarray(self.min_volume_space) 391 | # get grid index 392 | crop_range = max_bound - min_bound 393 | cur_grid_size = self.grid_size 394 | intervals = crop_range / (cur_grid_size - 1) 395 | 396 | if (intervals == 0).any(): print("Zero interval!") 397 | grid_ind = (np.floor((np.clip(xyz_pol, min_bound, max_bound) - min_bound) / intervals)).astype(np.int) 398 | 399 | voxel_position = np.zeros(self.grid_size, dtype=np.float32) 400 | dim_array = np.ones(len(self.grid_size) + 1, int) 401 | dim_array[0] = -1 402 | voxel_position = np.indices(self.grid_size) * intervals.reshape(dim_array) + min_bound.reshape(dim_array) 403 | voxel_position = polar2cat(voxel_position) 404 | 405 | processed_label = np.ones(self.grid_size, dtype=np.uint8) * self.ignore_label 406 | label_voxel_pair = np.concatenate([grid_ind, labels], axis=1) 407 | label_voxel_pair = label_voxel_pair[np.lexsort((grid_ind[:, 0], grid_ind[:, 1], grid_ind[:, 2])), :] 408 | processed_label = nb_process_label(np.copy(processed_label), label_voxel_pair) 409 | data_tuple = (voxel_position, processed_label) 410 | 411 | # center data on each voxel for PTnet 412 | voxel_centers = (grid_ind.astype(np.float32) + 0.5) * intervals + min_bound 413 | return_xyz = xyz_pol - voxel_centers 414 | return_xyz = np.concatenate((return_xyz, xyz_pol, xyz[:, :2]), axis=1) 415 | 416 | if len(data) == 2: 417 | return_fea = return_xyz 418 | elif len(data) == 3: 419 | return_fea = np.concatenate((return_xyz, sig[..., np.newaxis]), axis=1) 420 | 421 | if self.return_test: 422 | data_tuple += (grid_ind, labels, return_fea, index) 423 | else: 424 | data_tuple += (grid_ind, labels, return_fea) 425 | 426 | return data_tuple 427 | 428 | 429 | @nb.jit('u1[:,:,:](u1[:,:,:],i8[:,:])', nopython=True, cache=True, parallel=False) 430 | def nb_process_label(processed_label, sorted_label_voxel_pair): 431 | label_size = 256 432 | counter = np.zeros((label_size,), dtype=np.uint16) 433 | counter[sorted_label_voxel_pair[0, 3]] = 1 434 | cur_sear_ind = sorted_label_voxel_pair[0, :3] 435 | for i in range(1, sorted_label_voxel_pair.shape[0]): 436 | cur_ind = sorted_label_voxel_pair[i, :3] 437 | if not np.all(np.equal(cur_ind, cur_sear_ind)): 438 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = np.argmax(counter) 439 | counter = np.zeros((label_size,), dtype=np.uint16) 440 | cur_sear_ind = cur_ind 441 | counter[sorted_label_voxel_pair[i, 3]] += 1 442 | processed_label[cur_sear_ind[0], cur_sear_ind[1], cur_sear_ind[2]] = 
np.argmax(counter) 443 | return processed_label 444 | 445 | 446 | def collate_fn_BEV(data): 447 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 448 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 449 | grid_ind_stack = [d[2] for d in data] 450 | point_label = [d[3] for d in data] 451 | xyz = [d[4] for d in data] 452 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz 453 | 454 | '''def collate_fn_BEV(data): 455 | 456 | voxel_label = [] 457 | 458 | for da1 in data: 459 | for da2 in da1: 460 | voxel_label.append(da2[1]) 461 | 462 | voxel_label = np.stack(voxel_label).astype(np.int) 463 | 464 | grid_ind_stack = [] 465 | for da1 in data: 466 | for da2 in da1: 467 | grid_ind_stack.append(da2[2]) 468 | 469 | 470 | point_label = [] 471 | 472 | for da1 in data: 473 | for da2 in da1: 474 | point_label.append(da2[3]) 475 | 476 | xyz = [] 477 | 478 | for da1 in data: 479 | for da2 in da1: 480 | xyz.append(da2[4]) 481 | 482 | return xyz, torch.from_numpy(voxel_label), grid_ind_stack, point_label, xyz 483 | ''' 484 | 485 | '''def collate_fn_BEV_test_old(data): 486 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 487 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 488 | grid_ind_stack = [d[2] for d in data] 489 | point_label = [d[3] for d in data] 490 | xyz = [d[4] for d in data] 491 | index = [d[5] for d in data] 492 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index 493 | ''' 494 | 495 | def collate_fn_BEV_tta(data): 496 | 497 | voxel_label = [] 498 | 499 | for da1 in data: 500 | for da2 in da1: 501 | voxel_label.append(da2[1]) 502 | 503 | #voxel_label.astype(np.int) 504 | 505 | grid_ind_stack = [] 506 | for da1 in data: 507 | for da2 in da1: 508 | grid_ind_stack.append(da2[2]) 509 | 510 | 511 | 512 | point_label = [] 513 | 514 | for da1 in data: 515 | for da2 in da1: 516 | point_label.append(da2[3]) 517 | 518 | xyz = [] 519 | 520 | for da1 in data: 521 | for da2 in da1: 522 | xyz.append(da2[4]) 523 | index = [] 524 | for da1 in data: 525 | for da2 in da1: 526 | index.append(da2[5]) 527 | 528 | return xyz, xyz, grid_ind_stack, point_label, xyz, index 529 | 530 | 531 | def collate_fn_BEV_ms(data): 532 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 533 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 534 | grid_ind_stack = [d[2] for d in data] 535 | point_label = [d[3] for d in data] 536 | xyz = [d[4] for d in data] 537 | # origin_len = [d[5] for d in data] 538 | index = [d[5] for d in data] 539 | origin_len = [d[6] for d in data] 540 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len 541 | 542 | 543 | def collate_fn_BEV_ms_tta(data): 544 | data2stack = np.stack([d[0] for d in data]).astype(np.float32) 545 | label2stack = np.stack([d[1] for d in data]).astype(np.int) 546 | grid_ind_stack = [d[2] for d in data] 547 | point_label = [d[3] for d in data] 548 | xyz = [d[4] for d in data] 549 | index = [d[5] for d in data] 550 | origin_len = [d[6] for d in data] 551 | return torch.from_numpy(data2stack), torch.from_numpy(label2stack), grid_ind_stack, point_label, xyz, index, origin_len 552 | 553 | 554 | # def collate_fn_BEV_ms_tta(data): 555 | # 556 | # voxel_label = [] 557 | # 558 | # for da1 in data: 559 | # for da2 in da1: 560 | # voxel_label.append(da2[1]) 561 | # 562 | # #voxel_label.astype(np.int) 563 | # 564 | # grid_ind_stack = [] 
565 | # for da1 in data: 566 | # for da2 in da1: 567 | # grid_ind_stack.append(da2[2]) 568 | # 569 | # point_label = [] 570 | # 571 | # for da1 in data: 572 | # for da2 in da1: 573 | # point_label.append(da2[3]) 574 | # 575 | # xyz = [] 576 | # 577 | # for da1 in data: 578 | # for da2 in da1: 579 | # xyz.append(da2[4]) 580 | # 581 | # index = [] 582 | # 583 | # for da1 in data: 584 | # for da2 in da1: 585 | # index.append(da2[5]) 586 | # 587 | # origin_len = [] 588 | # 589 | # for da1 in data: 590 | # for da2 in da1: 591 | # origin_len.append(da2[6]) 592 | # 593 | # return xyz, xyz, grid_ind_stack, point_label, xyz, index, origin_len 594 | -------------------------------------------------------------------------------- /dataloader/pc_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: pc_dataset.py 4 | 5 | import os 6 | import numpy as np 7 | from torch.utils import data 8 | import yaml 9 | import pickle 10 | import pathlib 11 | 12 | REGISTERED_PC_DATASET_CLASSES = {} 13 | 14 | 15 | def register_dataset(cls, name=None): 16 | global REGISTERED_PC_DATASET_CLASSES 17 | if name is None: 18 | name = cls.__name__ 19 | assert name not in REGISTERED_PC_DATASET_CLASSES, f"exist class: {REGISTERED_PC_DATASET_CLASSES}" 20 | REGISTERED_PC_DATASET_CLASSES[name] = cls 21 | return cls 22 | 23 | 24 | def get_pc_model_class(name): 25 | global REGISTERED_PC_DATASET_CLASSES 26 | assert name in REGISTERED_PC_DATASET_CLASSES, f"available class: {REGISTERED_PC_DATASET_CLASSES}" 27 | return REGISTERED_PC_DATASET_CLASSES[name] 28 | 29 | 30 | @register_dataset 31 | class SemKITTI_sk(data.Dataset): 32 | def __init__(self, data_path, imageset='train', 33 | return_ref=False, label_mapping="semantic-kitti.yaml", nusc=None): 34 | self.return_ref = return_ref 35 | with open(label_mapping, 'r') as stream: 36 | semkittiyaml = yaml.safe_load(stream) 37 | self.learning_map = semkittiyaml['learning_map'] 38 | self.imageset = imageset 39 | if imageset == 'train': 40 | split = semkittiyaml['split']['train'] 41 | elif imageset == 'val': 42 | split = semkittiyaml['split']['valid'] 43 | elif imageset == 'test': 44 | split = semkittiyaml['split']['test'] 45 | else: 46 | raise Exception('Split must be train/val/test') 47 | 48 | self.im_idx = [] 49 | for i_folder in split: 50 | self.im_idx += absoluteFilePaths('/'.join([data_path, str(i_folder).zfill(2), 'velodyne'])) 51 | 52 | def __len__(self): 53 | 'Denotes the total number of samples' 54 | return len(self.im_idx) 55 | 56 | def __getitem__(self, index): 57 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) 58 | if self.imageset == 'test': 59 | annotated_data = np.expand_dims(np.zeros_like(raw_data[:, 0], dtype=int), axis=1) 60 | else: 61 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'labels')[:-3] + 'label', 62 | dtype=np.uint32).reshape((-1, 1)) 63 | annotated_data = annotated_data & 0xFFFF # delete high 16 digits binary 64 | annotated_data = np.vectorize(self.learning_map.__getitem__)(annotated_data) 65 | 66 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) 67 | if self.return_ref: 68 | data_tuple += (raw_data[:, 3],) 69 | return data_tuple 70 | 71 | 72 | def absoluteFilePaths(directory): 73 | for dirpath, _, filenames in os.walk(directory): 74 | filenames.sort() 75 | for f in filenames: 76 | yield os.path.abspath(os.path.join(dirpath, f)) 77 | 78 | 79 | def SemKITTI2train(label): 80 | if isinstance(label, 
list): 81 | return [SemKITTI2train_single(a) for a in label] 82 | else: 83 | return SemKITTI2train_single(label) 84 | 85 | 86 | def SemKITTI2train_single(label): 87 | remove_ind = label == 0 88 | label -= 1 89 | label[remove_ind] = 255 90 | return label 91 | 92 | 93 | def unpack(compressed): # from the semantic-kitti-api 94 | ''' given a bit encoded voxel grid, make a normal voxel grid out of it. ''' 95 | uncompressed = np.zeros(compressed.shape[0] * 8, dtype=np.uint8) 96 | uncompressed[::8] = compressed[:] >> 7 & 1 97 | uncompressed[1::8] = compressed[:] >> 6 & 1 98 | uncompressed[2::8] = compressed[:] >> 5 & 1 99 | uncompressed[3::8] = compressed[:] >> 4 & 1 100 | uncompressed[4::8] = compressed[:] >> 3 & 1 101 | uncompressed[5::8] = compressed[:] >> 2 & 1 102 | uncompressed[6::8] = compressed[:] >> 1 & 1 103 | uncompressed[7::8] = compressed[:] & 1 104 | 105 | return uncompressed 106 | 107 | def get_eval_mask(labels, invalid_voxels): # from the semantic-kitti-api 108 | """ 109 | Ignore labels set to 255 and invalid voxels (the ones never hit by a laser ray, probed using ray tracing) 110 | :param labels: input ground truth voxels 111 | :param invalid_voxels: voxels ignored during evaluation since they lie beyond the scene that was captured by the laser 112 | :return: boolean mask to subsample the voxels to evaluate 113 | """ 114 | masks = np.ones_like(labels, dtype=bool) 115 | masks[labels == 255] = False 116 | masks[invalid_voxels == 1] = False 117 | 118 | return masks 119 | 120 | 121 | from os.path import join 122 | @register_dataset 123 | class SemKITTI_sk_multiscan(data.Dataset): 124 | def __init__(self, data_path, imageset='train', return_ref=False, label_mapping="semantic-kitti-multiscan.yaml", nusc=None): 125 | self.return_ref = return_ref 126 | with open(label_mapping, 'r') as stream: 127 | semkittiyaml = yaml.safe_load(stream) 128 | ### remap completion label 129 | remapdict = semkittiyaml['learning_map'] 130 | # make lookup table for mapping 131 | maxkey = max(remapdict.keys()) 132 | remap_lut = np.zeros((maxkey + 100), dtype=np.int32) 133 | remap_lut[list(remapdict.keys())] = list(remapdict.values()) 134 | # in completion we have to distinguish empty and invalid voxels. 135 | # Important: for voxels, 0 corresponds to "empty", not "unlabeled". 136 | remap_lut[remap_lut == 0] = 255 # map 0 to 'invalid' 137 | remap_lut[0] = 0 # only 'empty' stays 'empty'. 138 | self.comletion_remap_lut = remap_lut 139 | 140 | self.learning_map = semkittiyaml['learning_map'] 141 | self.imageset = imageset 142 | self.data_path = data_path 143 | if imageset == 'train': 144 | split = semkittiyaml['split']['train'] 145 | elif imageset == 'val': 146 | split = semkittiyaml['split']['valid'] 147 | elif imageset == 'test': 148 | split = semkittiyaml['split']['test'] 149 | else: 150 | raise Exception('Split must be train/val/test') 151 | 152 | multiscan = 0 # additional frames are fused with target-frame. 
Hence, multiscan+1 point clouds in total 153 | print('multiscan: %d' %multiscan) 154 | self.multiscan = multiscan 155 | self.im_idx = [] 156 | 157 | self.calibrations = [] 158 | self.times = [] 159 | self.poses = [] 160 | 161 | self.load_calib_poses() 162 | for i_folder in split: 163 | # velodyne path corresponding to voxel path 164 | complete_path = os.path.join(data_path, str(i_folder).zfill(2), "voxels") 165 | files = list(pathlib.Path(complete_path).glob('*.bin')) 166 | for filename in files: 167 | self.im_idx.append(str(filename).replace('voxels', 'velodyne')) 168 | 169 | 170 | 171 | def __len__(self): 172 | 'Denotes the total number of samples' 173 | return len(self.im_idx) 174 | 175 | def load_calib_poses(self): 176 | """ 177 | load calib poses and times. 178 | """ 179 | 180 | ########### 181 | # Load data 182 | ########### 183 | 184 | self.calibrations = [] 185 | self.times = [] 186 | self.poses = [] 187 | 188 | for seq in range(0, 22): 189 | seq_folder = join(self.data_path, str(seq).zfill(2)) 190 | 191 | # Read Calib 192 | self.calibrations.append(self.parse_calibration(join(seq_folder, "calib.txt"))) 193 | 194 | # Read times 195 | self.times.append(np.loadtxt(join(seq_folder, 'times.txt'), dtype=np.float32)) 196 | 197 | # Read poses 198 | poses_f64 = self.parse_poses(join(seq_folder, 'poses.txt'), self.calibrations[-1]) 199 | self.poses.append([pose.astype(np.float32) for pose in poses_f64]) 200 | 201 | def parse_calibration(self, filename): 202 | """ read calibration file with given filename 203 | 204 | Returns 205 | ------- 206 | dict 207 | Calibration matrices as 4x4 numpy arrays. 208 | """ 209 | calib = {} 210 | 211 | calib_file = open(filename) 212 | for line in calib_file: 213 | key, content = line.strip().split(":") 214 | values = [float(v) for v in content.strip().split()] 215 | 216 | pose = np.zeros((4, 4)) 217 | pose[0, 0:4] = values[0:4] 218 | pose[1, 0:4] = values[4:8] 219 | pose[2, 0:4] = values[8:12] 220 | pose[3, 3] = 1.0 221 | 222 | calib[key] = pose 223 | 224 | calib_file.close() 225 | 226 | return calib 227 | 228 | def parse_poses(self, filename, calibration): 229 | """ read poses file with per-scan poses from given filename 230 | 231 | Returns 232 | ------- 233 | list 234 | list of poses as 4x4 numpy arrays. 
235 | """ 236 | file = open(filename) 237 | 238 | poses = [] 239 | 240 | Tr = calibration["Tr"] 241 | Tr_inv = np.linalg.inv(Tr) 242 | 243 | for line in file: 244 | values = [float(v) for v in line.strip().split()] 245 | 246 | pose = np.zeros((4, 4)) 247 | pose[0, 0:4] = values[0:4] 248 | pose[1, 0:4] = values[4:8] 249 | pose[2, 0:4] = values[8:12] 250 | pose[3, 3] = 1.0 251 | 252 | poses.append(np.matmul(Tr_inv, np.matmul(pose, Tr))) 253 | 254 | return poses 255 | 256 | def fuse_multi_scan(self, points, pose0, pose): 257 | 258 | hpoints = np.hstack((points[:, :3], np.ones_like(points[:, :1]))) 259 | new_points = np.sum(np.expand_dims(hpoints, 2) * pose.T, axis=1) 260 | new_points = new_points[:, :3] 261 | new_coords = new_points - pose0[:3, 3] 262 | new_coords = np.sum(np.expand_dims(new_coords, 2) * pose0[:3, :3], axis=1) 263 | new_coords = np.hstack((new_coords, points[:, 3:])) 264 | 265 | return new_coords 266 | 267 | def __getitem__(self, index): 268 | raw_data = np.fromfile(self.im_idx[index], dtype=np.float32).reshape((-1, 4)) # point cloud 269 | origin_len = len(raw_data) 270 | voxel_label = 1 271 | 272 | number_idx = int(self.im_idx[index][-10:-4]) 273 | dir_idx = int(self.im_idx[index][-22:-20]) 274 | 275 | pose0 = self.poses[dir_idx][number_idx] 276 | 277 | for fuse_idx in range(self.multiscan): 278 | plus_idx = fuse_idx + 1 279 | data_idx = number_idx + plus_idx 280 | path = self.im_idx[index][:-10] 281 | newpath2 = path + str(data_idx).zfill(6) + self.im_idx[index][-4:] 282 | voxel_path = path.replace('velodyne', 'labels') 283 | files = list(pathlib.Path(voxel_path).glob('*.label')) 284 | 285 | if data_idx < len(files): 286 | pose = self.poses[dir_idx][data_idx] 287 | 288 | raw_data2 = np.fromfile(newpath2, dtype=np.float32).reshape((-1, 4)) 289 | 290 | if voxel_label == 0: 291 | if self.imageset == 'test': 292 | annotated_data2 = np.expand_dims(np.zeros_like(raw_data2[:, 0], dtype=int), axis=1) 293 | else: 294 | annotated_data2 = np.fromfile(newpath2.replace('velodyne', 'labels')[:-3] + 'label', 295 | dtype=np.int32).reshape((-1, 1)) 296 | annotated_data2 = annotated_data2 & 0xFFFF # delete high 16 digits binary 297 | 298 | raw_data2 = self.fuse_multi_scan(raw_data2, pose0, pose) 299 | 300 | if len(raw_data2) != 0: 301 | raw_data = np.concatenate((raw_data, raw_data2), 0) 302 | if voxel_label == 0: 303 | annotated_data = np.concatenate((annotated_data, annotated_data2), 0) 304 | 305 | if self.imageset == 'test': 306 | annotated_data = np.zeros([256, 256, 32], dtype=int).reshape((-1, 1)) 307 | else: 308 | annotated_data = np.fromfile(self.im_idx[index].replace('velodyne', 'voxels')[:-3] + 'label', 309 | dtype=np.uint16).reshape((-1, 1)) # voxel labels 310 | 311 | annotated_data = self.comletion_remap_lut[annotated_data] 312 | annotated_data = annotated_data.reshape((256, 256, 32)) 313 | 314 | data_tuple = (raw_data[:, :3], annotated_data.astype(np.uint8)) # xyz, voxel labels 315 | 316 | if self.return_ref: 317 | data_tuple += (raw_data[:, 3], origin_len) # origin_len is used to indicate the length of target-scan 318 | 319 | return data_tuple 320 | 321 | 322 | # load Semantic KITTI class info 323 | def get_SemKITTI_label_name(label_mapping): 324 | with open(label_mapping, 'r') as stream: 325 | semkittiyaml = yaml.safe_load(stream) 326 | SemKITTI_label_name = dict() 327 | for i in sorted(list(semkittiyaml['learning_map'].keys()))[::-1]: 328 | SemKITTI_label_name[semkittiyaml['learning_map'][i]] = semkittiyaml['labels'][i] 329 | 330 | return SemKITTI_label_name 331 | 332 | 333 | 
def get_nuScenes_label_name(label_mapping): 334 | with open(label_mapping, 'r') as stream: 335 | nuScenesyaml = yaml.safe_load(stream) 336 | nuScenes_label_name = dict() 337 | for i in sorted(list(nuScenesyaml['learning_map'].keys()))[::-1]: 338 | val_ = nuScenesyaml['learning_map'][i] 339 | nuScenes_label_name[val_] = nuScenesyaml['labels_16'][val_] 340 | 341 | return nuScenes_label_name 342 | -------------------------------------------------------------------------------- /imgs/pipline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SCPNet/Codes-for-SCPNet/c0f55fa29e0b975454930f1237f5681327d53d09/imgs/pipline.png -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: __init__.py.py 4 | -------------------------------------------------------------------------------- /network/conv_base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import List, Tuple 3 | 4 | 5 | class SharedMLP(nn.Sequential): 6 | 7 | def __init__( 8 | self, 9 | args: List[int], 10 | *, 11 | bn: bool = False, 12 | activation=nn.ReLU(inplace=True), 13 | preact: bool = False, 14 | first: bool = False, 15 | name: str = "", 16 | instance_norm: bool = False, ): 17 | super().__init__() 18 | 19 | for i in range(len(args) - 1): 20 | self.add_module( 21 | name + 'layer{}'.format(i), 22 | Conv2d( 23 | args[i], 24 | args[i + 1], 25 | bn=(not first or not preact or (i != 0)) and bn, 26 | activation=activation 27 | if (not first or not preact or (i != 0)) else None, 28 | preact=preact, 29 | instance_norm=instance_norm 30 | ) 31 | ) 32 | 33 | 34 | class _ConvBase(nn.Sequential): 35 | 36 | def __init__( 37 | self, 38 | in_size, 39 | out_size, 40 | kernel_size, 41 | stride, 42 | padding, 43 | activation, 44 | bn, 45 | init, 46 | conv=None, 47 | batch_norm=None, 48 | bias=True, 49 | preact=False, 50 | name="", 51 | instance_norm=False, 52 | instance_norm_func=None 53 | ): 54 | super().__init__() 55 | 56 | bias = bias and (not bn) 57 | conv_unit = conv( 58 | in_size, 59 | out_size, 60 | kernel_size=kernel_size, 61 | stride=stride, 62 | padding=padding, 63 | bias=bias 64 | ) 65 | init(conv_unit.weight) 66 | if bias: 67 | nn.init.constant_(conv_unit.bias, 0) 68 | 69 | if bn: 70 | if not preact: 71 | bn_unit = batch_norm(out_size) 72 | else: 73 | bn_unit = batch_norm(in_size) 74 | if instance_norm: 75 | if not preact: 76 | in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False) 77 | else: 78 | in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False) 79 | 80 | if preact: 81 | if bn: 82 | self.add_module(name + 'bn', bn_unit) 83 | 84 | if activation is not None: 85 | self.add_module(name + 'activation', activation) 86 | 87 | if not bn and instance_norm: 88 | self.add_module(name + 'in', in_unit) 89 | 90 | self.add_module(name + 'conv', conv_unit) 91 | 92 | if not preact: 93 | if bn: 94 | self.add_module(name + 'bn', bn_unit) 95 | 96 | if activation is not None: 97 | self.add_module(name + 'activation', activation) 98 | 99 | if not bn and instance_norm: 100 | self.add_module(name + 'in', in_unit) 101 | 102 | 103 | class _BNBase(nn.Sequential): 104 | 105 | def __init__(self, in_size, batch_norm=None, name=""): 106 | super().__init__() 107 | self.add_module(name + 
"bn", batch_norm(in_size)) 108 | 109 | nn.init.constant_(self[0].weight, 1.0) 110 | nn.init.constant_(self[0].bias, 0) 111 | 112 | 113 | class BatchNorm1d(_BNBase): 114 | 115 | def __init__(self, in_size: int, *, name: str = ""): 116 | super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name) 117 | 118 | 119 | class BatchNorm2d(_BNBase): 120 | 121 | def __init__(self, in_size: int, name: str = ""): 122 | super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name) 123 | 124 | 125 | class BatchNorm3d(_BNBase): 126 | 127 | def __init__(self, in_size: int, name: str = ""): 128 | super().__init__(in_size, batch_norm=nn.BatchNorm3d, name=name) 129 | 130 | 131 | class Conv1d(_ConvBase): 132 | 133 | def __init__( 134 | self, 135 | in_size: int, 136 | out_size: int, 137 | *, 138 | kernel_size: int = 1, 139 | stride: int = 1, 140 | padding: int = 0, 141 | activation=nn.ReLU(inplace=True), 142 | bn: bool = False, 143 | init=nn.init.kaiming_normal_, 144 | bias: bool = True, 145 | preact: bool = False, 146 | name: str = "", 147 | instance_norm=False 148 | ): 149 | super().__init__( 150 | in_size, 151 | out_size, 152 | kernel_size, 153 | stride, 154 | padding, 155 | activation, 156 | bn, 157 | init, 158 | conv=nn.Conv1d, 159 | batch_norm=BatchNorm1d, 160 | bias=bias, 161 | preact=preact, 162 | name=name, 163 | instance_norm=instance_norm, 164 | instance_norm_func=nn.InstanceNorm1d 165 | ) 166 | 167 | 168 | class Conv2d(_ConvBase): 169 | 170 | def __init__( 171 | self, 172 | in_size: int, 173 | out_size: int, 174 | *, 175 | kernel_size: Tuple[int, int] = (1, 1), 176 | stride: Tuple[int, int] = (1, 1), 177 | padding: Tuple[int, int] = (0, 0), 178 | activation=nn.ReLU(inplace=True), 179 | bn: bool = False, 180 | init=nn.init.kaiming_normal_, 181 | bias: bool = True, 182 | preact: bool = False, 183 | name: str = "", 184 | instance_norm=False 185 | ): 186 | super().__init__( 187 | in_size, 188 | out_size, 189 | kernel_size, 190 | stride, 191 | padding, 192 | activation, 193 | bn, 194 | init, 195 | conv=nn.Conv2d, 196 | batch_norm=BatchNorm2d, 197 | bias=bias, 198 | preact=preact, 199 | name=name, 200 | instance_norm=instance_norm, 201 | instance_norm_func=nn.InstanceNorm2d 202 | ) 203 | 204 | 205 | class Conv3d(_ConvBase): 206 | 207 | def __init__( 208 | self, 209 | in_size: int, 210 | out_size: int, 211 | *, 212 | kernel_size: Tuple[int, int, int] = (1, 1, 1), 213 | stride: Tuple[int, int, int] = (1, 1, 1), 214 | padding: Tuple[int, int, int] = (0, 0, 0), 215 | activation=nn.ReLU(inplace=True), 216 | bn: bool = False, 217 | init=nn.init.kaiming_normal_, 218 | bias: bool = True, 219 | preact: bool = False, 220 | name: str = "", 221 | instance_norm=False 222 | ): 223 | super().__init__( 224 | in_size, 225 | out_size, 226 | kernel_size, 227 | stride, 228 | padding, 229 | activation, 230 | bn, 231 | init, 232 | conv=nn.Conv3d, 233 | batch_norm=BatchNorm3d, 234 | bias=bias, 235 | preact=preact, 236 | name=name, 237 | instance_norm=instance_norm, 238 | instance_norm_func=nn.InstanceNorm3d 239 | ) 240 | 241 | 242 | class FC(nn.Sequential): 243 | 244 | def __init__( 245 | self, 246 | in_size: int, 247 | out_size: int, 248 | *, 249 | activation=nn.ReLU(inplace=True), 250 | bn: bool = False, 251 | init=None, 252 | preact: bool = False, 253 | name: str = "" 254 | ): 255 | super().__init__() 256 | 257 | fc = nn.Linear(in_size, out_size, bias=not bn) 258 | if init is not None: 259 | init(fc.weight) 260 | if not bn: 261 | nn.init.constant(fc.bias, 0) 262 | 263 | if preact: 264 | if bn: 265 | 
self.add_module(name + 'bn', BatchNorm1d(in_size)) 266 | 267 | if activation is not None: 268 | self.add_module(name + 'activation', activation) 269 | 270 | self.add_module(name + 'fc', fc) 271 | 272 | if not preact: 273 | if bn: 274 | self.add_module(name + 'bn', BatchNorm1d(out_size)) 275 | 276 | if activation is not None: 277 | self.add_module(name + 'activation', activation) 278 | 279 | 280 | -------------------------------------------------------------------------------- /network/cylinder_fea_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch_scatter 8 | 9 | 10 | class cylinder_fea(nn.Module): 11 | 12 | def __init__(self, grid_size, fea_dim=3, 13 | out_pt_fea_dim=64, max_pt_per_encode=64, fea_compre=None): 14 | super(cylinder_fea, self).__init__() 15 | 16 | self.PPmodel = nn.Sequential( 17 | nn.BatchNorm1d(fea_dim), 18 | 19 | nn.Linear(fea_dim, 64), 20 | nn.BatchNorm1d(64), 21 | nn.ReLU(), 22 | 23 | nn.Linear(64, 128), 24 | nn.BatchNorm1d(128), 25 | nn.ReLU(), 26 | 27 | nn.Linear(128, 256), 28 | nn.BatchNorm1d(256), 29 | nn.ReLU(), 30 | 31 | nn.Linear(256, out_pt_fea_dim) 32 | ) 33 | 34 | self.max_pt = max_pt_per_encode 35 | self.fea_compre = fea_compre 36 | self.grid_size = grid_size 37 | kernel_size = 3 38 | self.local_pool_op = torch.nn.MaxPool2d(kernel_size, stride=1, 39 | padding=(kernel_size - 1) // 2, 40 | dilation=1) 41 | self.pool_dim = out_pt_fea_dim 42 | 43 | # point feature compression 44 | if self.fea_compre is not None: 45 | self.fea_compression = nn.Sequential( 46 | nn.Linear(self.pool_dim, self.fea_compre), 47 | nn.ReLU()) 48 | self.pt_fea_dim = self.fea_compre 49 | else: 50 | self.pt_fea_dim = self.pool_dim 51 | 52 | def forward(self, pt_fea, xy_ind): 53 | cur_dev = pt_fea[0].get_device() 54 | 55 | ### concate everything 56 | cat_pt_ind = [] 57 | for i_batch in range(len(xy_ind)): 58 | cat_pt_ind.append(F.pad(xy_ind[i_batch], (1, 0), 'constant', value=i_batch)) 59 | 60 | cat_pt_fea = torch.cat(pt_fea, dim=0) 61 | cat_pt_ind = torch.cat(cat_pt_ind, dim=0) 62 | pt_num = cat_pt_ind.shape[0] 63 | 64 | ### shuffle the data 65 | shuffled_ind = torch.randperm(pt_num, device=cur_dev) 66 | cat_pt_fea = cat_pt_fea[shuffled_ind, :] 67 | cat_pt_ind = cat_pt_ind[shuffled_ind, :] 68 | 69 | ### unique xy grid index 70 | unq, unq_inv, unq_cnt = torch.unique(cat_pt_ind, return_inverse=True, return_counts=True, dim=0) 71 | unq = unq.type(torch.int64) 72 | 73 | ### process feature 74 | processed_cat_pt_fea = self.PPmodel(cat_pt_fea) 75 | pooled_data = torch_scatter.scatter_max(processed_cat_pt_fea, unq_inv, dim=0)[0] 76 | 77 | if self.fea_compre: 78 | processed_pooled_data = self.fea_compression(pooled_data) 79 | else: 80 | processed_pooled_data = pooled_data 81 | 82 | return unq, processed_pooled_data 83 | -------------------------------------------------------------------------------- /network/cylinder_spconv_3d.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: cylinder_spconv_3d.py 4 | 5 | from torch import nn 6 | import torch 7 | 8 | REGISTERED_MODELS_CLASSES = {} 9 | 10 | 11 | def register_model(cls, name=None): 12 | global REGISTERED_MODELS_CLASSES 13 | if name is None: 14 | name = cls.__name__ 15 | assert name not in REGISTERED_MODELS_CLASSES, f"exist class: {REGISTERED_MODELS_CLASSES}" 16 | 
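# Registry sketch: a class decorated with @register_model is stored under its
# class name and retrieved by get_model_class at build time, e.g. for the
# cylinder_asym class defined below (argument values elided):
#   model_cls = get_model_class("cylinder_asym")
#   net = model_cls(cylin_model=..., segmentator_spconv=..., sparse_shape=...)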
REGISTERED_MODELS_CLASSES[name] = cls 17 | return cls 18 | 19 | 20 | def get_model_class(name): 21 | global REGISTERED_MODELS_CLASSES 22 | assert name in REGISTERED_MODELS_CLASSES, f"available class: {REGISTERED_MODELS_CLASSES}" 23 | return REGISTERED_MODELS_CLASSES[name] 24 | 25 | 26 | @register_model 27 | class cylinder_asym(nn.Module): 28 | def __init__(self, 29 | cylin_model, 30 | segmentator_spconv, 31 | sparse_shape, 32 | ): 33 | super().__init__() 34 | self.name = "cylinder_asym" 35 | 36 | self.cylinder_3d_generator = cylin_model 37 | 38 | self.cylinder_3d_spconv_seg = segmentator_spconv 39 | 40 | self.sparse_shape = sparse_shape 41 | 42 | def forward(self, train_pt_fea_ten, train_vox_ten, batch_size, val_grid=None, voting_num=4, use_tta=False): 43 | coords, features_3d = self.cylinder_3d_generator(train_pt_fea_ten, train_vox_ten) 44 | # train_pt_fea_ten: [batch_size, N1, 7] 45 | # train_vox_ten: [batch_size, N1, 3] 46 | 47 | if use_tta: 48 | batch_size *= voting_num 49 | 50 | spatial_features = self.cylinder_3d_spconv_seg(features_3d, coords, batch_size) # [batch_size, 20, 256, 256, 32] 51 | 52 | if use_tta: 53 | fused_predict = spatial_features[0, :] 54 | for idx in range(1, voting_num, 1): 55 | aug_predict = spatial_features[idx, :] 56 | aug_predict = torch.flip(aug_predict, dims=[2]) 57 | fused_predict += aug_predict 58 | return torch.unsqueeze(fused_predict, 0) 59 | else: 60 | return spatial_features 61 | -------------------------------------------------------------------------------- /network/segmentator_3d_asymm_spconv.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge, Xzy 3 | # @file: segmentator_3d_asymm_spconv.py 4 | 5 | import numpy as np 6 | import spconv 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, indice_key=None): 12 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False, indice_key=indice_key) 14 | 15 | 16 | def conv1x3(in_planes, out_planes, stride=1, indice_key=None): 17 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride, 18 | padding=(0, 1, 1), bias=False, indice_key=indice_key) 19 | 20 | 21 | def conv1x1x3(in_planes, out_planes, stride=1, indice_key=None): 22 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, 23 | padding=(0, 0, 1), bias=False, indice_key=indice_key) 24 | 25 | 26 | def conv1x3x1(in_planes, out_planes, stride=1, indice_key=None): 27 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, 28 | padding=(0, 1, 0), bias=False, indice_key=indice_key) 29 | 30 | 31 | def conv3x1x1(in_planes, out_planes, stride=1, indice_key=None): 32 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, 33 | padding=(1, 0, 0), bias=False, indice_key=indice_key) 34 | 35 | 36 | def conv3x1(in_planes, out_planes, stride=1, indice_key=None): 37 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, 38 | padding=(1, 0, 1), bias=False, indice_key=indice_key) 39 | 40 | 41 | def conv1x1(in_planes, out_planes, stride=1, indice_key=None): 42 | return spconv.SubMConv3d(in_planes, out_planes, kernel_size=1, stride=stride, 43 | padding=1, bias=False, indice_key=indice_key) 44 | 45 | 46 | class ResContextBlock(nn.Module): 47 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None): 48 | 
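# Asymmetric residual context block (Cylinder3D-style): each 3x3x3 convolution
# is factorized into a (1,3,3) kernel followed by a (3,1,3) kernel, spending
# 2x9 = 18 weights per channel pair instead of 27 while keeping a similar
# receptive field; two such factorized branches are summed as a residual pair.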
super(ResContextBlock, self).__init__() 49 | self.conv1 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef") 50 | self.bn0 = nn.BatchNorm1d(out_filters) 51 | self.act1 = nn.LeakyReLU() 52 | 53 | self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") 54 | self.bn0_2 = nn.BatchNorm1d(out_filters) 55 | self.act1_2 = nn.LeakyReLU() 56 | 57 | self.conv2 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef") 58 | self.act2 = nn.LeakyReLU() 59 | self.bn1 = nn.BatchNorm1d(out_filters) 60 | 61 | self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") 62 | self.act3 = nn.LeakyReLU() 63 | self.bn2 = nn.BatchNorm1d(out_filters) 64 | 65 | self.weight_initialization() 66 | 67 | def weight_initialization(self): 68 | for m in self.modules(): 69 | if isinstance(m, nn.BatchNorm1d): 70 | nn.init.constant_(m.weight, 1) 71 | nn.init.constant_(m.bias, 0) 72 | 73 | def forward(self, x): 74 | shortcut = self.conv1(x) 75 | shortcut.features = self.act1(shortcut.features) 76 | shortcut.features = self.bn0(shortcut.features) 77 | 78 | shortcut = self.conv1_2(shortcut) 79 | shortcut.features = self.act1_2(shortcut.features) 80 | shortcut.features = self.bn0_2(shortcut.features) 81 | 82 | resA = self.conv2(x) 83 | resA.features = self.act2(resA.features) 84 | resA.features = self.bn1(resA.features) 85 | 86 | resA = self.conv3(resA) 87 | resA.features = self.act3(resA.features) 88 | resA.features = self.bn2(resA.features) 89 | resA.features = resA.features + shortcut.features 90 | 91 | return resA 92 | 93 | 94 | class ResBlock(nn.Module): 95 | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1, 96 | pooling=True, drop_out=True, height_pooling=False, indice_key=None): 97 | super(ResBlock, self).__init__() 98 | self.pooling = pooling 99 | self.drop_out = drop_out 100 | 101 | self.conv1 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef") 102 | self.act1 = nn.LeakyReLU() 103 | self.bn0 = nn.BatchNorm1d(out_filters) 104 | 105 | self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef") 106 | self.act1_2 = nn.LeakyReLU() 107 | self.bn0_2 = nn.BatchNorm1d(out_filters) 108 | 109 | self.conv2 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef") 110 | self.act2 = nn.LeakyReLU() 111 | self.bn1 = nn.BatchNorm1d(out_filters) 112 | 113 | self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef") 114 | self.act3 = nn.LeakyReLU() 115 | self.bn2 = nn.BatchNorm1d(out_filters) 116 | 117 | if pooling: 118 | if height_pooling: 119 | self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=2, 120 | padding=1, indice_key=indice_key, bias=False) 121 | else: 122 | self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1), 123 | padding=1, indice_key=indice_key, bias=False) 124 | self.weight_initialization() 125 | 126 | def weight_initialization(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.BatchNorm1d): 129 | nn.init.constant_(m.weight, 1) 130 | nn.init.constant_(m.bias, 0) 131 | 132 | def forward(self, x): 133 | shortcut = self.conv1(x) 134 | shortcut.features = self.act1(shortcut.features) 135 | shortcut.features = self.bn0(shortcut.features) 136 | 137 | shortcut = self.conv1_2(shortcut) 138 | shortcut.features = self.act1_2(shortcut.features) 139 | shortcut.features = self.bn0_2(shortcut.features) 140 | 141 | resA = self.conv2(x) 142 | resA.features = self.act2(resA.features) 143 | resA.features 
= self.bn1(resA.features) 144 | 145 | resA = self.conv3(resA) 146 | resA.features = self.act3(resA.features) 147 | resA.features = self.bn2(resA.features) 148 | 149 | resA.features = resA.features + shortcut.features 150 | 151 | if self.pooling: 152 | resB = self.pool(resA) 153 | return resB, resA 154 | else: 155 | return resA 156 | 157 | 158 | class UpBlock(nn.Module): 159 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), indice_key=None, up_key=None): 160 | super(UpBlock, self).__init__() 161 | # self.drop_out = drop_out 162 | self.trans_dilao = conv3x3(in_filters, out_filters, indice_key=indice_key + "new_up") 163 | self.trans_act = nn.LeakyReLU() 164 | self.trans_bn = nn.BatchNorm1d(out_filters) 165 | 166 | self.conv1 = conv1x3(out_filters, out_filters, indice_key=indice_key) 167 | self.act1 = nn.LeakyReLU() 168 | self.bn1 = nn.BatchNorm1d(out_filters) 169 | 170 | self.conv2 = conv3x1(out_filters, out_filters, indice_key=indice_key) 171 | self.act2 = nn.LeakyReLU() 172 | self.bn2 = nn.BatchNorm1d(out_filters) 173 | 174 | self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key) 175 | self.act3 = nn.LeakyReLU() 176 | self.bn3 = nn.BatchNorm1d(out_filters) 177 | # self.dropout3 = nn.Dropout3d(p=dropout_rate) 178 | 179 | self.up_subm = spconv.SparseInverseConv3d(out_filters, out_filters, kernel_size=3, indice_key=up_key, 180 | bias=False) 181 | 182 | self.weight_initialization() 183 | 184 | def weight_initialization(self): 185 | for m in self.modules(): 186 | if isinstance(m, nn.BatchNorm1d): 187 | nn.init.constant_(m.weight, 1) 188 | nn.init.constant_(m.bias, 0) 189 | 190 | def forward(self, x, skip): 191 | upA = self.trans_dilao(x) 192 | upA.features = self.trans_act(upA.features) 193 | upA.features = self.trans_bn(upA.features) 194 | 195 | ## upsample 196 | upA = self.up_subm(upA) 197 | 198 | upA.features = upA.features + skip.features 199 | 200 | upE = self.conv1(upA) 201 | upE.features = self.act1(upE.features) 202 | upE.features = self.bn1(upE.features) 203 | 204 | upE = self.conv2(upE) 205 | upE.features = self.act2(upE.features) 206 | upE.features = self.bn2(upE.features) 207 | 208 | upE = self.conv3(upE) 209 | upE.features = self.act3(upE.features) 210 | upE.features = self.bn3(upE.features) 211 | 212 | return upE 213 | 214 | 215 | class ReconBlock(nn.Module): 216 | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None): 217 | super(ReconBlock, self).__init__() 218 | self.conv1 = conv3x1x1(in_filters, out_filters, indice_key=indice_key + "bef") 219 | self.bn0 = nn.BatchNorm1d(out_filters) 220 | self.act1 = nn.Sigmoid() 221 | 222 | self.conv1_2 = conv1x3x1(in_filters, out_filters, indice_key=indice_key + "bef") 223 | self.bn0_2 = nn.BatchNorm1d(out_filters) 224 | self.act1_2 = nn.Sigmoid() 225 | 226 | self.conv1_3 = conv1x1x3(in_filters, out_filters, indice_key=indice_key + "bef") 227 | self.bn0_3 = nn.BatchNorm1d(out_filters) 228 | self.act1_3 = nn.Sigmoid() 229 | 230 | def forward(self, x): 231 | shortcut = self.conv1(x) 232 | shortcut.features = self.bn0(shortcut.features) 233 | shortcut.features = self.act1(shortcut.features) 234 | 235 | shortcut2 = self.conv1_2(x) 236 | shortcut2.features = self.bn0_2(shortcut2.features) 237 | shortcut2.features = self.act1_2(shortcut2.features) 238 | 239 | shortcut3 = self.conv1_3(x) 240 | shortcut3.features = self.bn0_3(shortcut3.features) 241 | shortcut3.features = self.act1_3(shortcut3.features) 242 | shortcut.features = shortcut.features + shortcut2.features + 
shortcut3.features 243 | 244 | shortcut.features = shortcut.features * x.features 245 | 246 | return shortcut 247 | 248 | 249 | def extract_nonzero_features(x): 250 | device = x.device 251 | nonzero_index = torch.sum(torch.abs(x), dim=1).nonzero() 252 | coords = nonzero_index.type(torch.int32).to(device) 253 | channels = int(x.shape[1]) 254 | features = x.permute(0, 2, 3, 4, 1).reshape(-1, channels) 255 | features = features[torch.sum(torch.abs(features), dim=1).nonzero(), :] 256 | features = features.squeeze(1).to(device) 257 | coords, _, _ = torch.unique(coords, return_inverse=True, return_counts=True, dim=0) 258 | return coords, features 259 | 260 | 261 | class Asymm_3d_spconv(nn.Module): 262 | def __init__(self, 263 | output_shape, 264 | use_norm=True, 265 | num_input_features=128, 266 | nclasses=20, n_height=32, strict=False, init_size=16): 267 | super(Asymm_3d_spconv, self).__init__() 268 | self.nclasses = nclasses 269 | self.nheight = n_height 270 | self.strict = False 271 | 272 | sparse_shape = np.array(output_shape) 273 | print(sparse_shape) 274 | self.sparse_shape = sparse_shape 275 | 276 | ### Completion sub-network 277 | mybias = False # False 278 | chs = [init_size, init_size*1, init_size*1, init_size*1] 279 | self.a_conv1 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU()) 280 | self.a_conv2 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU()) 281 | self.a_conv3 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 5, 1, padding=2, bias=mybias), nn.ReLU()) 282 | self.a_conv4 = nn.Sequential(nn.Conv3d(chs[1], chs[1], 7, 1, padding=3, bias=mybias), nn.ReLU()) 283 | self.a_conv5 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 3, 1, padding=1, bias=mybias), nn.ReLU()) 284 | self.a_conv6 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 5, 1, padding=2, bias=mybias), nn.ReLU()) 285 | self.a_conv7 = nn.Sequential(nn.Conv3d(chs[1]*3, chs[1], 7, 1, padding=3, bias=mybias), nn.ReLU()) 286 | self.ch_conv1 = nn.Sequential(nn.Conv3d(chs[1]*7, chs[0], kernel_size=1, stride=1, bias=mybias), nn.ReLU()) 287 | self.res_1 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 3, 1, padding=1, bias=mybias), nn.ReLU()) 288 | self.res_2 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 5, 1, padding=2, bias=mybias), nn.ReLU()) 289 | self.res_3 = nn.Sequential(nn.Conv3d(chs[0], chs[0], 7, 1, padding=3, bias=mybias), nn.ReLU()) 290 | 291 | ### Segmentation sub-network 292 | self.downCntx = ResContextBlock(num_input_features, init_size, indice_key="pre") 293 | self.resBlock2 = ResBlock(init_size, 2 * init_size, 0.2, height_pooling=True, indice_key="down2") 294 | self.resBlock3 = ResBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=True, indice_key="down3") 295 | self.resBlock4 = ResBlock(4 * init_size, 8 * init_size, 0.2, pooling=True, height_pooling=False, 296 | indice_key="down4") 297 | self.resBlock5 = ResBlock(8 * init_size, 16 * init_size, 0.2, pooling=True, height_pooling=False, 298 | indice_key="down5") 299 | 300 | self.upBlock0 = UpBlock(16 * init_size, 16 * init_size, indice_key="up0", up_key="down5") 301 | self.upBlock1 = UpBlock(16 * init_size, 8 * init_size, indice_key="up1", up_key="down4") 302 | self.upBlock2 = UpBlock(8 * init_size, 4 * init_size, indice_key="up2", up_key="down3") 303 | self.upBlock3 = UpBlock(4 * init_size, 2 * init_size, indice_key="up3", up_key="down2") 304 | 305 | self.ReconNet = ReconBlock(2 * init_size, 2 * init_size, indice_key="recon") 306 | 307 | self.logits = spconv.SubMConv3d(4 * init_size, nclasses, indice_key="logit", 
kernel_size=3, stride=1, padding=1, 308 | bias=True) 309 | 310 | 311 | def forward(self, voxel_features, coors, batch_size): 312 | coors = coors.int() 313 | x_sparse = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size) 314 | x = x_sparse 315 | 316 | debug = 0 317 | 318 | # Spase to dense 319 | x_dense = x_sparse.dense() 320 | 321 | ### Completion sub-network by dense convolution 322 | x1 = self.a_conv1(x_dense) 323 | x2 = self.a_conv2(x1) 324 | x3 = self.a_conv3(x1) 325 | x4 = self.a_conv4(x1) 326 | t1 = torch.cat((x2, x3, x4), 1) 327 | x5 = self.a_conv5(t1) 328 | x6 = self.a_conv6(t1) 329 | x7 = self.a_conv7(t1) 330 | x = torch.cat((x1, x2, x3, x4, x5, x6, x7), 1) 331 | y0 = self.ch_conv1(x) 332 | y1 = self.res_1(x_dense) 333 | y2 = self.res_2(x_dense) 334 | y3 = self.res_3(x_dense) 335 | x = x_dense + y0 + y1 + y2 + y3 336 | 337 | # Dense to sparse 338 | coord, features = extract_nonzero_features(x) 339 | x = spconv.SparseConvTensor(features, coord.int(), self.sparse_shape, batch_size) # voxel features 340 | 341 | ### Segmentation sub-network by sparse convolution 342 | x = self.downCntx(x) 343 | down1c, down1b = self.resBlock2(x) 344 | down2c, down2b = self.resBlock3(down1c) 345 | down3c, down3b = self.resBlock4(down2c) 346 | down4c, down4b = self.resBlock5(down3c) 347 | 348 | up4e = self.upBlock0(down4c, down4b) 349 | up3e = self.upBlock1(up4e, down3b) 350 | up2e = self.upBlock2(up3e, down2b) 351 | up1e = self.upBlock3(up2e, down1b) 352 | 353 | up0e = self.ReconNet(up1e) 354 | 355 | up0e.features = torch.cat((up0e.features, up1e.features), 1) 356 | 357 | logits = self.logits(up0e) 358 | y = logits.dense() 359 | 360 | if debug == 1: 361 | print(y.shape) 362 | assert 1==0 363 | 364 | return y 365 | 366 | @staticmethod 367 | def _joining(encoder_features, x, concat): 368 | if concat: 369 | return torch.cat((encoder_features, x), dim=1) 370 | else: 371 | return encoder_features + x -------------------------------------------------------------------------------- /test_scpnet_comp.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge, Xzy 3 | # @file: train_cylinder_asym.py 4 | 5 | 6 | import os 7 | import time 8 | import argparse 9 | import sys 10 | import numpy as np 11 | import torch 12 | from tqdm import tqdm 13 | 14 | # from utils.metric_util import per_class_iu, fast_hist_crop 15 | from dataloader.pc_dataset import get_SemKITTI_label_name 16 | from builder import data_builder, model_builder, loss_builder 17 | from config.config import load_config_data 18 | 19 | from utils.load_save_util import load_checkpoint 20 | 21 | import warnings 22 | 23 | warnings.filterwarnings("ignore") 24 | import yaml 25 | 26 | 27 | def train2SemKITTI(input_label): 28 | # delete 0 label (uses uint8 trick : 0 - 1 = 255 ) 29 | return input_label + 1 30 | 31 | def main(args): 32 | pytorch_device = torch.device('cuda:0') 33 | 34 | config_path = args.config_path 35 | 36 | configs = load_config_data(config_path) 37 | 38 | dataset_config = configs['dataset_params'] 39 | train_dataloader_config = configs['train_data_loader'] 40 | val_dataloader_config = configs['val_data_loader'] 41 | 42 | val_batch_size = val_dataloader_config['batch_size'] 43 | train_batch_size = train_dataloader_config['batch_size'] 44 | 45 | model_config = configs['model_params'] 46 | train_hypers = configs['train_params'] 47 | 48 | grid_size = model_config['output_shape'] 49 | num_class = model_config['num_class'] 50 | ignore_label = 
dataset_config['ignore_label'] 51 | 52 | model_load_path = train_hypers['model_load_path'] 53 | 54 | SemKITTI_label_name = get_SemKITTI_label_name(dataset_config["label_mapping"]) 55 | unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1 56 | unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1] 57 | 58 | my_model = model_builder.build(model_config) 59 | model_load_path += 'iou26.6891_epoch19.pth' 60 | if os.path.exists(model_load_path): 61 | print('Load model from: %s' % model_load_path) 62 | my_model = load_checkpoint(model_load_path, my_model) 63 | else: 64 | print('No existing model, training model from scratch...') 65 | 66 | my_model.to(pytorch_device) 67 | 68 | _, test_dataset_loader, test_pt_dataset = data_builder.build(dataset_config, 69 | train_dataloader_config, 70 | val_dataloader_config, 71 | grid_size=grid_size, 72 | use_tta=True, 73 | use_multiscan=True) 74 | 75 | # testing 76 | dataset_name = val_dataloader_config["imageset"] 77 | output_path = 'out_scpnet/' + dataset_name 78 | 79 | if True: 80 | print('Generate predictions for test split') 81 | pbar = tqdm(total=len(test_dataset_loader)) 82 | time.sleep(10) 83 | ### learning map 84 | with open("config/label_mapping/semantic-kitti.yaml", 'r') as stream: 85 | semkittiyaml = yaml.safe_load(stream) 86 | # make lookup table for mapping 87 | learning_map_inv = semkittiyaml["learning_map_inv"] 88 | maxkey = max(learning_map_inv.keys()) 89 | # +100 hack making lut bigger just in case there are unknown labels 90 | remap_lut_First = np.zeros((maxkey + 100), dtype=np.int32) 91 | remap_lut_First[list(learning_map_inv.keys())] = list(learning_map_inv.values()) 92 | 93 | if True: 94 | if True: 95 | my_model.eval() 96 | with torch.no_grad(): 97 | for i_iter_test, (_, _, test_grid, _, test_pt_fea, test_index, origin_len) in enumerate( 98 | test_dataset_loader): 99 | 100 | test_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in 101 | test_pt_fea] 102 | test_grid_ten = [torch.from_numpy(i).to(pytorch_device) for i in test_grid] 103 | 104 | predict_labels = my_model(test_pt_fea_ten, test_grid_ten, val_batch_size, test_grid, use_tta=False) 105 | predict_labels = torch.argmax(predict_labels, dim=1) 106 | predict_labels = predict_labels.cpu().detach().numpy() 107 | if True: 108 | test_pred_label = np.squeeze(predict_labels) 109 | 110 | ### save prediction after remapping 111 | pred = test_pred_label 112 | pred = pred.astype(np.uint32) 113 | pred = pred.reshape((-1)) 114 | upper_half = pred >> 16 # get upper half for instances 115 | lower_half = pred & 0xFFFF # get lower half for semantics 116 | lower_half = remap_lut_First[lower_half] # do the remapping of semantics 117 | pred = (upper_half << 16) + lower_half # reconstruct full label 118 | pred = pred.astype(np.uint32) 119 | final_preds = pred.astype(np.uint16) 120 | 121 | save_dir = test_pt_dataset.im_idx[test_index[0]] 122 | _, dir2 = save_dir.split('/sequences/', 1) 123 | new_save_dir = output_path + '/sequences/' + dir2.replace('velodyne', 'predictions')[:-3] + 'label' 124 | # create the output directory if needed; exist_ok=True tolerates a 125 | # directory created meanwhile by another process, so no errno check 126 | # (and no errno import) is required 127 | os.makedirs(os.path.dirname(new_save_dir), exist_ok=True) 128 | 129 | 130 | final_preds.tofile(new_save_dir) 131 | 132 | pbar.update(1) 133 | del test_grid, test_pt_fea, test_grid_ten, test_index 134 | pbar.close() 135 | print('Predicted test labels are saved in %s. 
Need to be shifted to original label format before submitting to the Competition website.' % output_path) 136 | print('Remapping script can be found in semantic-kitti-api.') 137 | 138 | if __name__ == '__main__': 139 | # Training settings 140 | parser = argparse.ArgumentParser(description='') 141 | parser.add_argument('-y', '--config_path', default='config/semantickitti-multiscan.yaml') 142 | args = parser.parse_args() 143 | 144 | print(' '.join(sys.argv)) 145 | print(args) 146 | main(args) 147 | -------------------------------------------------------------------------------- /train_scpnet_comp.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge, Xzy 3 | # @file: train_cylinder_asym.py 4 | 5 | 6 | import os 7 | import time 8 | import argparse 9 | import sys 10 | import numpy as np 11 | import torch 12 | import torch.optim as optim 13 | from tqdm import tqdm 14 | 15 | # from utils.metric_util import per_class_iu, fast_hist_crop 16 | from dataloader.pc_dataset import get_SemKITTI_label_name, get_eval_mask, unpack 17 | from builder import data_builder, model_builder, loss_builder 18 | from config.config import load_config_data 19 | 20 | from utils.load_save_util import load_checkpoint 21 | 22 | import warnings 23 | from utils.np_ioueval import iouEval 24 | import yaml 25 | 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | def main(args): 30 | pytorch_device = torch.device('cuda:0') 31 | 32 | config_path = args.config_path 33 | 34 | configs = load_config_data(config_path) 35 | 36 | dataset_config = configs['dataset_params'] 37 | train_dataloader_config = configs['train_data_loader'] 38 | val_dataloader_config = configs['val_data_loader'] 39 | 40 | val_batch_size = val_dataloader_config['batch_size'] 41 | train_batch_size = train_dataloader_config['batch_size'] 42 | 43 | model_config = configs['model_params'] 44 | train_hypers = configs['train_params'] 45 | 46 | grid_size = model_config['output_shape'] 47 | num_class = model_config['num_class'] 48 | ignore_label = dataset_config['ignore_label'] 49 | 50 | model_load_path = train_hypers['model_load_path'] 51 | model_save_path = train_hypers['model_save_path'] 52 | 53 | SemKITTI_label_name = get_SemKITTI_label_name(dataset_config["label_mapping"]) 54 | unique_label = np.asarray(sorted(list(SemKITTI_label_name.keys())))[1:] - 1 55 | unique_label_str = [SemKITTI_label_name[x] for x in unique_label + 1] 56 | 57 | my_model = model_builder.build(model_config) 58 | model_load_path += '0.pth' 59 | model_save_path += '' 60 | if os.path.exists(model_load_path): 61 | print('Load model from: %s' % model_load_path) 62 | my_model = load_checkpoint(model_load_path, my_model) 63 | else: 64 | print('No existing model, training model from scratch...') 65 | 66 | if not os.path.exists(model_save_path): 67 | os.makedirs(model_save_path) 68 | print(model_save_path) 69 | 70 | my_model.to(pytorch_device) 71 | optimizer = optim.Adam(my_model.parameters(), lr=train_hypers["learning_rate"]) 72 | 73 | loss_func, lovasz_softmax = loss_builder.build(wce=True, lovasz=True, 74 | num_class=num_class, ignore_label=ignore_label) 75 | 76 | train_dataset_loader, val_dataset_loader, val_pt_dataset = data_builder.build(dataset_config, 77 | train_dataloader_config, 78 | val_dataloader_config, 79 | grid_size=grid_size, 80 | use_tta=False, 81 | use_multiscan=True) 82 | 83 | # training 84 | epoch = 0 85 | best_val_miou = 0 86 | my_model.train() 87 | global_iter = 0 88 | check_iter = 
train_hypers['eval_every_n_steps'] 89 | 90 | # learning map 91 | with open("config/label_mapping/semantic-kitti.yaml", 'r') as stream: 92 | semkittiyaml = yaml.safe_load(stream) 93 | class_strings = semkittiyaml["labels"] 94 | class_inv_remap = semkittiyaml["learning_map_inv"] 95 | 96 | while epoch < train_hypers['max_num_epochs']: 97 | loss_list = [] 98 | pbar = tqdm(total=len(train_dataset_loader)) 99 | time.sleep(10) 100 | # lr_scheduler.step(epoch) 101 | for i_iter, (_, train_vox_label, train_grid, _, train_pt_fea, train_index, origin_len) in enumerate(train_dataset_loader): 102 | 103 | if global_iter % check_iter == 0 and epoch > 0: 104 | my_model.eval() 105 | 106 | val_loss_list = [] 107 | val_method = 2 # 1-segmentation method, 2-completion method 108 | if val_method == 1: 109 | hist_list = [] 110 | else: 111 | evaluator = iouEval(num_class, []) 112 | with torch.no_grad(): 113 | for i_iter_val, (_, val_vox_label, val_grid, _, val_pt_fea, val_index, origin_len) in enumerate( 114 | val_dataset_loader): 115 | 116 | val_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in 117 | val_pt_fea] 118 | val_grid_ten = [torch.from_numpy(i).to(pytorch_device) for i in val_grid] 119 | 120 | for bat in range(val_batch_size): 121 | 122 | val_label_tensor = val_vox_label[bat,:].type(torch.LongTensor).to(pytorch_device) 123 | val_label_tensor = torch.unsqueeze(val_label_tensor, 0) 124 | predict_labels = my_model(val_pt_fea_ten, val_grid_ten, val_batch_size) 125 | loss = lovasz_softmax(torch.nn.functional.softmax(predict_labels).detach(), val_label_tensor, 126 | ignore=ignore_label) + loss_func(predict_labels.detach(), val_label_tensor) 127 | 128 | predict_labels = torch.argmax(predict_labels, dim=1) 129 | predict_labels = predict_labels.cpu().detach().numpy() 130 | predict_labels = np.squeeze(predict_labels) 131 | val_vox_label0 = val_vox_label[bat, :].cpu().detach().numpy() 132 | val_vox_label0 = np.squeeze(val_vox_label0) 133 | 134 | val_name = val_pt_dataset.im_idx[val_index[0]] 135 | 136 | invalid_name = val_name.replace('velodyne', 'voxels')[:-3]+'invalid' 137 | invalid_voxels = unpack(np.fromfile(invalid_name, dtype=np.uint8)) # voxel labels 138 | invalid_voxels = invalid_voxels.reshape((256, 256, 32)) 139 | masks = get_eval_mask(val_vox_label0, invalid_voxels) 140 | predict_labels = predict_labels[masks] 141 | val_vox_label0 = val_vox_label0[masks] 142 | 143 | evaluator.addBatch(predict_labels.astype(int), val_vox_label0.astype(int)) 144 | 145 | val_loss_list.append(loss.detach().cpu().numpy()) 146 | 147 | # my_model.train() 148 | print('Validation per class iou: ') 149 | _, class_jaccard = evaluator.getIoU() 150 | m_jaccard = class_jaccard[1:].mean() 151 | iou = class_jaccard 152 | val_miou = m_jaccard * 100 153 | ignore = [0] 154 | # print also classwise 155 | for i, jacc in enumerate(class_jaccard): 156 | if i not in ignore: 157 | print('IoU class {i:} [{class_str:}] = {jacc:.3f}'.format( 158 | i=i, class_str=class_strings[class_inv_remap[i]], jacc=jacc)) 159 | # compute remaining metrics. 
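# The completion score below collapses class labels to empty (0) vs. occupied
# (>0) and computes a binary occupancy IoU from the same confusion matrix
# (rows = predictions, cols = ground truth):
#   TP           = conf[1:, 1:].sum()        occupied in both pred and gt
#   TP + FP + FN = conf.sum() - conf[0, 0]   everything except true empties
#   completion IoU = TP / (TP + FP + FN)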
160 | conf = evaluator.get_confusion() 161 | acc_completion = (np.sum(conf[1:, 1:])) / (np.sum(conf) - conf[0, 0]) 162 | print('Current val completion iou is %.3f' % acc_completion) 163 | 164 | del val_vox_label, val_grid, val_pt_fea, val_pt_fea_ten, val_grid_ten, val_label_tensor 165 | 166 | # save model if performance is improved 167 | if best_val_miou < val_miou: 168 | best_val_miou = val_miou 169 | # save model with best val miou for completion 170 | model_save_name = model_save_path + ('iou%.4f_epoch%d.pth' % (val_miou, epoch)) 171 | torch.save(my_model.state_dict(), model_save_name) 172 | 173 | print('Current val miou is %.3f while the best val miou is %.3f' % 174 | (val_miou, best_val_miou)) 175 | print('Current val loss is %.3f' % 176 | (np.mean(val_loss_list))) 177 | 178 | my_model.train() 179 | 180 | train_pt_fea_ten = [torch.from_numpy(i).type(torch.FloatTensor).to(pytorch_device) for i in train_pt_fea] 181 | train_vox_ten = [torch.from_numpy(i).to(pytorch_device) for i in train_grid] 182 | point_label_tensor = train_vox_label.type(torch.LongTensor).to(pytorch_device) 183 | 184 | # forward + backward + optimize 185 | outputs = my_model(train_pt_fea_ten, train_vox_ten, point_label_tensor.shape[0]) 186 | loss = lovasz_softmax(torch.nn.functional.softmax(outputs), point_label_tensor, ignore=ignore_label) + loss_func( 187 | outputs, point_label_tensor) 188 | loss.backward() 189 | optimizer.step() 190 | loss_list.append(loss.item()) 191 | 192 | if global_iter % 1000 == 0: 193 | if len(loss_list) > 0: 194 | print('epoch %d iter %5d, loss: %.3f\n' % 195 | (epoch, i_iter, np.mean(loss_list))) 196 | else: 197 | print('loss error') 198 | 199 | optimizer.zero_grad() 200 | pbar.update(1) 201 | global_iter += 1 202 | if global_iter % check_iter == 0: 203 | if len(loss_list) > 0: 204 | print('epoch %d iter %5d, loss: %.3f\n' % 205 | (epoch, i_iter, np.mean(loss_list))) 206 | else: 207 | print('loss error') 208 | pbar.close() 209 | epoch += 1 210 | 211 | 212 | if __name__ == '__main__': 213 | # Training settings 214 | parser = argparse.ArgumentParser(description='') 215 | parser.add_argument('-y', '--config_path', default='config/semantickitti-multiscan.yaml') 216 | args = parser.parse_args() 217 | 218 | print(' '.join(sys.argv)) 219 | print(args) 220 | main(args) 221 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: __init__.py.py 4 | -------------------------------------------------------------------------------- /utils/load_save_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: load_save_util.py 4 | 5 | import torch 6 | 7 | 8 | # def load_checkpoint_old2(model_load_path, model): 9 | def load_checkpoint(model_load_path, model): 10 | pre_weight = torch.load(model_load_path) 11 | my_model_dict = model.state_dict() 12 | part_load = {} 13 | match_size = 0 14 | nomatch_size = 0 15 | for k in pre_weight.keys(): 16 | value = pre_weight[k] 17 | # str3 = 'seg_head.sparseModel.1.weight' 18 | # if k.find(str3) > 0: 19 | # value = value[:, :, 0, :] 20 | if k in my_model_dict and my_model_dict[k].shape == value.shape: 21 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 22 | match_size += 1 23 | part_load[k] = value 24 | else: 25 | # print("model shape:{}, pre 
shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 26 | Lvalue = len(value.shape) 27 | assert 1 <= Lvalue <= 5 28 | if len(value.shape) == 1: 29 | c = value.shape[0] 30 | cc = my_model_dict[k].shape[0] - c #int(c*0.5) 31 | if 0 < cc <= c: 32 | value = torch.cat([value, value[:cc]], dim=0) 33 | elif cc > c: 34 | value = torch.cat([value, value, value[:(cc-c)]], dim=0) 35 | elif cc < 0: 36 | value = value[:-cc] 37 | else: 38 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 39 | cs = value.shape 40 | ccs = [0]*Lvalue 41 | j = -1 42 | for ci in cs: 43 | j += 1 44 | ccs[j] = (my_model_dict[k].shape[j] - ci) 45 | if ccs[j] != 0: 46 | for m in range(Lvalue): 47 | if m != j: 48 | ccs[m] = value.shape[m] 49 | # print(ccs) 50 | if ccs[j] > 0: 51 | if ccs[j] > ci: 52 | ccs[j] = ccs[j] - ci 53 | if Lvalue == 5: 54 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2], :ccs[3], :ccs[4]]], dim=j) 55 | elif Lvalue == 4: 56 | # print(value[:ccs[0], :ccs[1], :ccs[2], :ccs[3]].shape) 57 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2], :ccs[3]]], dim=j) 58 | elif Lvalue == 3: 59 | value = torch.cat([value, value[:ccs[0], :ccs[1], :ccs[2]]], dim=j) 60 | elif Lvalue == 2: 61 | value = torch.cat([value, value[:ccs[0], :ccs[1]]], dim=j) 62 | ccs[j] = value.shape[j] 63 | elif ccs[j] < 0: 64 | # ccs[j] = -ccs[j] 65 | ccs[j] = my_model_dict[k].shape[j] 66 | if j == 0: 67 | value = value[:ccs[0], :] 68 | elif j == 1: 69 | if Lvalue == 2: 70 | value = value[:, :ccs[1]] 71 | else: 72 | value = value[:, :ccs[1], :] 73 | elif j == 2: 74 | if Lvalue == 3: 75 | value = value[:, :, :ccs[2]] 76 | else: 77 | value = value[:, :, :ccs[2], :] 78 | elif j == 3 and j <= Lvalue-1: 79 | if Lvalue == 4: 80 | value = value[:, :, :, :ccs[3]] 81 | else: 82 | value = value[:, :, :, :ccs[3], :] 83 | elif j == 4 and j <= Lvalue-1: 84 | value = value[:, :, :, :, :ccs[4]] 85 | 86 | nomatch_size += 1 87 | if my_model_dict[k].shape == value.shape: 88 | part_load[k] = value 89 | # print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 90 | # assert my_model_dict[k].shape == value.shape 91 | 92 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 93 | 94 | my_model_dict.update(part_load) 95 | model.load_state_dict(my_model_dict) 96 | # model.load_state_dict(my_model_dict, strict=False) # True 97 | 98 | return model 99 | 100 | 101 | def load_checkpoint_old2(model_load_path, model): 102 | my_model_dict = model.state_dict() 103 | pre_weight = torch.load(model_load_path) 104 | 105 | part_load = {} 106 | match_size = 0 107 | nomatch_size = 0 108 | for k in pre_weight.keys(): 109 | value = pre_weight[k] 110 | if k in my_model_dict and my_model_dict[k].shape == value.shape: 111 | #print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 112 | match_size += 1 113 | part_load[k] = value 114 | else: 115 | assert len(value.shape) == 1 or len(value.shape) == 5 116 | if len(value.shape) == 1: 117 | c = value.shape[0] 118 | cc = my_model_dict[k].shape[0] - c #int(c*0.5) 119 | if cc <= c: 120 | value = torch.cat([value, value[:cc]], dim=0) 121 | else: 122 | value = torch.cat([value, value, value[:(cc-c)]], dim=0) 123 | else: 124 | _, _, _, c1, c2 = value.shape 125 | cc1 = my_model_dict[k].shape[3] - c1 #int(c1*0.5) 126 | cc2 = my_model_dict[k].shape[4] - c2 #int(c2*0.5) 127 | if cc1 > 0 and cc1 <= c1: 128 | value1 = torch.cat([value, value[:, :, :, :cc1, :]], dim=3) 129 | 
elif cc1 > c1: 130 | value1 = torch.cat([value, value, value[:, :, :, :(cc1-c1), :]], dim=3) 131 | else: 132 | value1 = value 133 | if cc2 > 0 and cc2 <= c2: 134 | value = torch.cat([value1, value1[:, :, :, :, :cc2]], dim=4) 135 | elif cc2 > c2: 136 | value = torch.cat([value1, value1, value1[:, :, :, :, :(cc2-c2)]], dim=4) 137 | else: 138 | value = value1 139 | nomatch_size += 1 140 | part_load[k] = value 141 | assert my_model_dict[k].shape == value.shape 142 | #print("model shape:{}, pre shape:{}".format(str(my_model_dict[k].shape), str(value.shape))) 143 | 144 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 145 | 146 | my_model_dict.update(part_load) 147 | # model.load_state_dict(my_model_dict) 148 | model.load_state_dict(my_model_dict, strict=False) # True 149 | 150 | return model 151 | 152 | def load_checkpoint_old(model_load_path, model): 153 | my_model_dict = model.state_dict() 154 | pre_weight = torch.load(model_load_path) 155 | 156 | part_load = {} 157 | match_size = 0 158 | nomatch_size = 0 159 | for k in pre_weight.keys(): 160 | value = pre_weight[k] 161 | if k in my_model_dict and my_model_dict[k].shape == value.shape: 162 | # print("loading ", k) 163 | match_size += 1 164 | part_load[k] = value 165 | else: 166 | nomatch_size += 1 167 | 168 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 169 | 170 | my_model_dict.update(part_load) 171 | model.load_state_dict(my_model_dict) 172 | 173 | return model 174 | 175 | def load_checkpoint_1b1(model_load_path, model): 176 | my_model_dict = model.state_dict() 177 | pre_weight = torch.load(model_load_path) 178 | 179 | part_load = {} 180 | match_size = 0 181 | nomatch_size = 0 182 | 183 | pre_weight_list = [*pre_weight] 184 | my_model_dict_list = [*my_model_dict] 185 | 186 | for idx in range(len(pre_weight_list)): 187 | key_ = pre_weight_list[idx] 188 | key_2 = my_model_dict_list[idx] 189 | value_ = pre_weight[key_] 190 | if my_model_dict[key_2].shape == pre_weight[key_].shape: 191 | # print("loading ", k) 192 | match_size += 1 193 | part_load[key_2] = value_ 194 | else: 195 | print(key_) 196 | print(key_2) 197 | nomatch_size += 1 198 | 199 | print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size)) 200 | 201 | my_model_dict.update(part_load) 202 | model.load_state_dict(my_model_dict) 203 | 204 | return model 205 | -------------------------------------------------------------------------------- /utils/log_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: log_util.py 4 | 5 | 6 | def save_to_log(logdir, logfile, message): 7 | f = open(logdir + '/' + logfile, "a") 8 | f.write(message + '\n') 9 | f.close() 10 | return -------------------------------------------------------------------------------- /utils/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | 4 | 5 | """ 6 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 7 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 8 | """ 9 | 10 | from __future__ import print_function, division 11 | 12 | import torch 13 | from torch.autograd import Variable 14 | import torch.nn.functional as F 15 | import numpy as np 16 | try: 17 | from itertools import ifilterfalse 18 | except ImportError: # py3k 19 | from itertools import filterfalse as ifilterfalse 20 | 21 | def lovasz_grad(gt_sorted): 22 | 
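# Worked example (hand-computed) for gt_sorted = [1, 1, 0, 1]:
#   gts = 3; intersection = 3 - cumsum(gt)  = [2, 1, 1, 0]
#   union = 3 + cumsum(1 - gt)              = [3, 3, 4, 4]
#   jaccard = 1 - intersection / union      = [1/3, 2/3, 3/4, 1]
#   after differencing: grad = [1/3, 1/3, 1/12, 1/4], which sums to the
#   final Jaccard error of 1 for this ordering.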
""" 23 | Computes gradient of the Lovasz extension w.r.t sorted errors 24 | See Alg. 1 in paper 25 | """ 26 | p = len(gt_sorted) 27 | gts = gt_sorted.sum() 28 | intersection = gts - gt_sorted.float().cumsum(0) 29 | union = gts + (1 - gt_sorted).float().cumsum(0) 30 | jaccard = 1. - intersection / union 31 | if p > 1: # cover 1-pixel case 32 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 33 | return jaccard 34 | 35 | 36 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 37 | """ 38 | IoU for foreground class 39 | binary: 1 foreground, 0 background 40 | """ 41 | if not per_image: 42 | preds, labels = (preds,), (labels,) 43 | ious = [] 44 | for pred, label in zip(preds, labels): 45 | intersection = ((label == 1) & (pred == 1)).sum() 46 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 47 | if not union: 48 | iou = EMPTY 49 | else: 50 | iou = float(intersection) / float(union) 51 | ious.append(iou) 52 | iou = mean(ious) # mean accross images if per_image 53 | return 100 * iou 54 | 55 | 56 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 57 | """ 58 | Array of IoU for each (non ignored) class 59 | """ 60 | if not per_image: 61 | preds, labels = (preds,), (labels,) 62 | ious = [] 63 | for pred, label in zip(preds, labels): 64 | iou = [] 65 | for i in range(C): 66 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 67 | intersection = ((label == i) & (pred == i)).sum() 68 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 69 | if not union: 70 | iou.append(EMPTY) 71 | else: 72 | iou.append(float(intersection) / float(union)) 73 | ious.append(iou) 74 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 75 | return 100 * np.array(ious) 76 | 77 | 78 | # --------------------------- BINARY LOSSES --------------------------- 79 | 80 | 81 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 82 | """ 83 | Binary Lovasz hinge loss 84 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 85 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 86 | per_image: compute the loss per image instead of per batch 87 | ignore: void class id 88 | """ 89 | if per_image: 90 | loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 91 | for log, lab in zip(logits, labels)) 92 | else: 93 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 94 | return loss 95 | 96 | 97 | def lovasz_hinge_flat(logits, labels): 98 | """ 99 | Binary Lovasz hinge loss 100 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 101 | labels: [P] Tensor, binary ground truth labels (0 or 1) 102 | ignore: label to ignore 103 | """ 104 | if len(labels) == 0: 105 | # only void pixels, the gradients should be 0 106 | return logits.sum() * 0. 107 | signs = 2. * labels.float() - 1. 108 | errors = (1. 
- logits * Variable(signs)) 109 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 110 | perm = perm.data 111 | gt_sorted = labels[perm] 112 | grad = lovasz_grad(gt_sorted) 113 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 114 | return loss 115 | 116 | 117 | def flatten_binary_scores(scores, labels, ignore=None): 118 | """ 119 | Flattens predictions in the batch (binary case) 120 | Remove labels equal to 'ignore' 121 | """ 122 | scores = scores.view(-1) 123 | labels = labels.view(-1) 124 | if ignore is None: 125 | return scores, labels 126 | valid = (labels != ignore) 127 | vscores = scores[valid] 128 | vlabels = labels[valid] 129 | return vscores, vlabels 130 | 131 | 132 | class StableBCELoss(torch.nn.modules.Module): 133 | def __init__(self): 134 | super(StableBCELoss, self).__init__() 135 | def forward(self, input, target): 136 | neg_abs = - input.abs() 137 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 138 | return loss.mean() 139 | 140 | 141 | def binary_xloss(logits, labels, ignore=None): 142 | """ 143 | Binary Cross entropy loss 144 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 145 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 146 | ignore: void class id 147 | """ 148 | logits, labels = flatten_binary_scores(logits, labels, ignore) 149 | loss = StableBCELoss()(logits, Variable(labels.float())) 150 | return loss 151 | 152 | 153 | # --------------------------- MULTICLASS LOSSES --------------------------- 154 | 155 | 156 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 157 | """ 158 | Multi-class Lovasz-Softmax loss 159 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 160 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 161 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 162 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 163 | per_image: compute the loss per image instead of per batch 164 | ignore: void class labels 165 | """ 166 | if per_image: 167 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 168 | for prob, lab in zip(probas, labels)) 169 | else: 170 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 171 | return loss 172 | 173 | 174 | def lovasz_softmax_flat(probas, labels, classes='present'): 175 | """ 176 | Multi-class Lovasz-Softmax loss 177 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 178 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 179 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 180 | """ 181 | if probas.numel() == 0: 182 | # only void pixels, the gradients should be 0 183 | return probas * 0. 
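# For each class c: take the absolute errors |fg_c - p_c|, sort them in
# descending order, weight them by lovasz_grad of the correspondingly
# sorted ground truth, and sum via a dot product; the per-class results
# are then averaged by mean() at the end of this function.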
184 | C = probas.size(1) 185 | losses = [] 186 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 187 | for c in class_to_sum: 188 | fg = (labels == c).float() # foreground for class c 189 | if (classes == 'present' and fg.sum() == 0): 190 | continue 191 | if C == 1: 192 | if len(classes) > 1: 193 | raise ValueError('Sigmoid output possible only with 1 class') 194 | class_pred = probas[:, 0] 195 | else: 196 | class_pred = probas[:, c] 197 | errors = (Variable(fg) - class_pred).abs() 198 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 199 | perm = perm.data 200 | fg_sorted = fg[perm] 201 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 202 | return mean(losses) 203 | 204 | 205 | def flatten_probas(probas, labels, ignore=None): 206 | """ 207 | Flattens predictions in the batch 208 | """ 209 | if probas.dim() == 3: 210 | # assumes output of a sigmoid layer 211 | B, H, W = probas.size() 212 | probas = probas.view(B, 1, H, W) 213 | elif probas.dim() == 5: 214 | # 3D segmentation 215 | B, C, L, H, W = probas.size() 216 | probas = probas.contiguous().view(B, C, L, H*W) 217 | B, C, H, W = probas.size() 218 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 219 | labels = labels.view(-1) 220 | if ignore is None: 221 | return probas, labels 222 | valid = (labels != ignore) 223 | vprobas = probas[valid.nonzero().squeeze()] 224 | vlabels = labels[valid] 225 | return vprobas, vlabels 226 | 227 | def xloss(logits, labels, ignore=None): 228 | """ 229 | Cross entropy loss 230 | """ 231 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 232 | 233 | def jaccard_loss(probas, labels, ignore=None, smooth=100, bk_class=None): 234 | """ 235 | Something wrong with this loss 236 | Smoothed multi-class Jaccard (IoU) loss 237 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1) 238 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 239 | ignore: void class labels 240 | smooth: smoothing constant 241 | bk_class: optional class whose rows are zeroed out in the one-hot target 242 | 243 | """ 244 | vprobas, vlabels = flatten_probas(probas, labels, ignore) 245 | 246 | 247 | true_1_hot = torch.eye(vprobas.shape[1])[vlabels] 248 | 249 | if bk_class: 250 | one_hot_assignment = torch.ones_like(vlabels) 251 | one_hot_assignment[vlabels == bk_class] = 0 252 | one_hot_assignment = one_hot_assignment.float().unsqueeze(1) 253 | true_1_hot = true_1_hot*one_hot_assignment 254 | 255 | true_1_hot = true_1_hot.to(vprobas.device) 256 | intersection = torch.sum(vprobas * true_1_hot) 257 | cardinality = torch.sum(vprobas + true_1_hot) 258 | loss = (intersection + smooth / (cardinality - intersection + smooth)).mean() 259 | return (1-loss)*smooth 260 | 261 | def hinge_jaccard_loss(probas, labels, ignore=None, classes='present', hinge=0.1, smooth=100): 262 | """ 263 | Multi-class Hinge Jaccard loss 264 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 265 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 266 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 267 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 
268 | ignore: void class labels 269 | """ 270 | vprobas, vlabels = flatten_probas(probas, labels, ignore) 271 | C = vprobas.size(1) 272 | losses = [] 273 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 274 | for c in class_to_sum: 275 | if c in vlabels: 276 | c_sample_ind = vlabels == c 277 | cprobas = vprobas[c_sample_ind,:] 278 | non_c_ind =np.array([a for a in class_to_sum if a != c]) 279 | class_pred = cprobas[:,c] 280 | max_non_class_pred = torch.max(cprobas[:,non_c_ind],dim = 1)[0] 281 | TP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) + smooth 282 | FN = torch.sum(torch.clamp(max_non_class_pred - class_pred, min = -hinge)+hinge) 283 | 284 | if (~c_sample_ind).sum() == 0: 285 | FP = 0 286 | else: 287 | nonc_probas = vprobas[~c_sample_ind,:] 288 | class_pred = nonc_probas[:,c] 289 | max_non_class_pred = torch.max(nonc_probas[:,non_c_ind],dim = 1)[0] 290 | FP = torch.sum(torch.clamp(class_pred - max_non_class_pred, max = hinge)+1.) 291 | 292 | losses.append(1 - TP/(TP+FP+FN)) 293 | 294 | if len(losses) == 0: return 0 295 | return mean(losses) 296 | 297 | # --------------------------- HELPER FUNCTIONS --------------------------- 298 | def isnan(x): 299 | return x != x 300 | 301 | 302 | def mean(l, ignore_nan=False, empty=0): 303 | """ 304 | nanmean compatible with generators. 305 | """ 306 | l = iter(l) 307 | if ignore_nan: 308 | l = ifilterfalse(isnan, l) 309 | try: 310 | n = 1 311 | acc = next(l) 312 | except StopIteration: 313 | if empty == 'raise': 314 | raise ValueError('Empty mean') 315 | return empty 316 | for n, v in enumerate(l, 2): 317 | acc += v 318 | if n == 1: 319 | return acc 320 | return acc / n 321 | -------------------------------------------------------------------------------- /utils/metric_util.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # author: Xinge 3 | # @file: metric_util.py 4 | 5 | import numpy as np 6 | 7 | 8 | def fast_hist(pred, label, n): 9 | k = (label >= 0) & (label < n) 10 | bin_count = np.bincount( 11 | n * label[k].astype(int) + pred[k], minlength=n ** 2) 12 | return bin_count[:n ** 2].reshape(n, n) 13 | 14 | 15 | def per_class_iu(hist): 16 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 17 | 18 | 19 | def fast_hist_crop(output, target, unique_label): 20 | hist = fast_hist(output.flatten(), target.flatten(), np.max(unique_label) + 2) 21 | hist = hist[unique_label + 1, :] 22 | hist = hist[:, unique_label + 1] 23 | return hist 24 | -------------------------------------------------------------------------------- /utils/np_ioueval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This file is covered by the LICENSE file in the root of this project. 
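# Convention used throughout this evaluator: conf_matrix rows index
# predictions and columns index ground truth, so for a class c
#   tp = conf[c, c]
#   fp = conf[c, :].sum() - tp
#   fn = conf[:, c].sum() - tp
# and IoU(c) = tp / (tp + fp + fn).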
3 | 4 | import sys 5 | import numpy as np 6 | 7 | 8 | class iouEval: 9 | def __init__(self, n_classes, ignore=None): 10 | # classes 11 | self.n_classes = n_classes 12 | 13 | # What to include and ignore from the means 14 | self.ignore = np.array(ignore, dtype=np.int64) 15 | self.include = np.array( 16 | [n for n in range(self.n_classes) if n not in self.ignore], dtype=np.int64) 17 | print("[IOU EVAL] IGNORE: ", self.ignore) 18 | print("[IOU EVAL] INCLUDE: ", self.include) 19 | 20 | # reset the class counters 21 | self.reset() 22 | 23 | def num_classes(self): 24 | return self.n_classes 25 | 26 | def reset(self): 27 | self.conf_matrix = np.zeros((self.n_classes, 28 | self.n_classes), 29 | dtype=np.int64) 30 | 31 | def addBatch(self, x, y): # x=preds, y=targets 32 | # sizes should be matching 33 | x_row = x.reshape(-1) # de-batchify 34 | y_row = y.reshape(-1) # de-batchify 35 | 36 | # check 37 | assert(x_row.shape == y_row.shape) 38 | 39 | # create indexes 40 | idxs = tuple(np.stack((x_row, y_row), axis=0)) 41 | 42 | # make confusion matrix (cols = gt, rows = pred) 43 | np.add.at(self.conf_matrix, idxs, 1) 44 | 45 | def getStats(self): 46 | # remove fp from confusion on the ignore classes cols 47 | conf = self.conf_matrix.copy() 48 | conf[:, self.ignore] = 0 49 | 50 | # get the clean stats 51 | tp = np.diag(conf) 52 | fp = conf.sum(axis=1) - tp 53 | fn = conf.sum(axis=0) - tp 54 | return tp, fp, fn 55 | 56 | def getIoU(self): 57 | tp, fp, fn = self.getStats() 58 | intersection = tp 59 | union = tp + fp + fn + 1e-15 60 | iou = intersection / union 61 | iou_mean = (intersection[self.include] / union[self.include]).mean() 62 | return iou_mean, iou # returns "iou mean", "iou per class" ALL CLASSES 63 | 64 | def getacc(self): 65 | tp, fp, fn = self.getStats() 66 | total_tp = tp.sum() 67 | total = tp[self.include].sum() + fp[self.include].sum() + 1e-15 68 | acc_mean = total_tp / total 69 | return acc_mean # returns "acc mean" 70 | 71 | def get_confusion(self): 72 | return self.conf_matrix.copy() 73 | 74 | 75 | 76 | if __name__ == "__main__": 77 | # mock problem 78 | nclasses = 2 79 | ignore = [] 80 | 81 | # test with 2 squares and a known IOU 82 | lbl = np.zeros((7, 7), dtype=np.int64) 83 | argmax = np.zeros((7, 7), dtype=np.int64) 84 | 85 | # put squares 86 | lbl[2:4, 2:4] = 1 87 | argmax[3:5, 3:5] = 1 88 | 89 | # make evaluator 90 | eval = iouEval(nclasses, ignore) 91 | 92 | # run 93 | eval.addBatch(argmax, lbl) 94 | m_iou, iou = eval.getIoU() 95 | print("IoU: ", m_iou) 96 | print("IoU class: ", iou) 97 | m_acc = eval.getacc() 98 | print("Acc: ", m_acc) 99 | --------------------------------------------------------------------------------
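# Expected output for the mock problem in np_ioueval.py's __main__ block
# (hand-computed): the two 2x2 squares overlap in exactly one of the 49 cells, so
#   IoU class 1 = 1 / (4 + 4 - 1) = 1/7 ~ 0.143
#   IoU class 0 = 42 / 48 = 0.875
#   mean IoU ~ 0.509, mean Acc = 43/49 ~ 0.878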