├── utils ├── __init__.py ├── __pycache__ │ ├── config.cpython-38.pyc │ ├── logger.cpython-38.pyc │ ├── losses.cpython-38.pyc │ ├── pseudo.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── dataset.cpython-38.pyc │ ├── metrics.cpython-38.pyc │ ├── sampler.cpython-38.pyc │ ├── callbacks.cpython-38.pyc │ ├── collation.cpython-38.pyc │ ├── voxelizer.cpython-38.pyc │ ├── augmentations.cpython-38.pyc │ ├── online_logger.cpython-38.pyc │ └── dataset_online.cpython-38.pyc ├── _weights │ └── synthetichdl32e_full.npy ├── config.py ├── _resources │ ├── synthetic.yaml │ ├── synlidar.yaml │ ├── nuscenes.yaml │ └── semantic-kitti.yaml ├── metrics.py ├── sampler.py ├── online_logger.py ├── augmentations.py ├── callbacks.py ├── voxelizer.py ├── collation.py ├── losses.py └── logger.py ├── pic └── fig_framework1.png ├── __pycache__ └── tent.cpython-38.pyc ├── models ├── __pycache__ │ ├── lrf.cpython-38.pyc │ ├── common.cpython-38.pyc │ ├── network.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── resunet.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── minkunet.cpython-38.pyc │ ├── minkunet_ssl.cpython-38.pyc │ ├── minkunet_nobn.cpython-38.pyc │ └── residual_block.cpython-38.pyc ├── __init__.py ├── common.py ├── residual_block.py ├── lrf.py ├── network.py ├── resunet.py ├── minkunet_nobn.py ├── minkunet.py ├── resnet.py └── minkunet_ssl.py ├── pipelines ├── __pycache__ │ ├── trainer.cpython-38.pyc │ ├── __init__.cpython-38.pyc │ ├── base_pipeline.cpython-38.pyc │ ├── trainer_lighting.cpython-38.pyc │ ├── adaptation_online_single.cpython-38.pyc │ ├── adaptation_online_single_gpg.cpython-38.pyc │ ├── adaptation_online_single_tent.cpython-38.pyc │ └── adaptation_online_single_test.cpython-38.pyc ├── __init__.py ├── base_pipeline.py ├── trainer_lighting.py └── trainer.py ├── train.sh ├── configs ├── source │ ├── synlidar_source.yaml │ ├── synth4dkitti_source.yaml │ └── synth4dnusc_source.yaml └── adaptation │ ├── synth4d2nusc_adaptation_model_features.yaml │ ├── synth4d2nusc_adaptation_model_features_64to32.yaml │ ├── synth4d2nusc_adaptation_64to32_synth.yaml │ ├── synth4d2nusc_adaptation_64to32_syn4d.yaml │ ├── synlidar2kitti_adaptation_model_features.yaml │ ├── synth4d2nusc_adaptation_64to32_syn4d_test.yaml │ ├── synth4d2kitti_adaptation_model_features.yaml │ └── synlidar2kitti_adaptation_32to64.yaml ├── README.md ├── train_lighting.py └── adapt_online.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pic/fig_framework1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pic/fig_framework1.png -------------------------------------------------------------------------------- /__pycache__/tent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/__pycache__/tent.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/lrf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/lrf.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/losses.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/losses.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/pseudo.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/pseudo.cpython-38.pyc -------------------------------------------------------------------------------- /utils/_weights/synthetichdl32e_full.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/_weights/synthetichdl32e_full.npy -------------------------------------------------------------------------------- /models/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/network.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/network.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/resunet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/resunet.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/minkunet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/minkunet.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/callbacks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/callbacks.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/collation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/collation.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/voxelizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/voxelizer.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/minkunet_ssl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/minkunet_ssl.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/augmentations.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/augmentations.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/online_logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/online_logger.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/minkunet_nobn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/minkunet_nobn.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/residual_block.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/models/__pycache__/residual_block.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset_online.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/utils/__pycache__/dataset_online.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/base_pipeline.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/base_pipeline.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/trainer_lighting.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/trainer_lighting.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/adaptation_online_single.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/adaptation_online_single.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/adaptation_online_single_gpg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/adaptation_online_single_gpg.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/adaptation_online_single_tent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/adaptation_online_single_tent.cpython-38.pyc -------------------------------------------------------------------------------- /pipelines/__pycache__/adaptation_online_single_test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tpzou/HGL/HEAD/pipelines/__pycache__/adaptation_online_single_test.cpython-38.pyc -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.minkunet import MinkUNet34C, MinkUNet18 2 | from models.minkunet_ssl import MinkUNet18_SSL, MinkUNet18_HEADS, MinkUNet18_MCMC 3 | from models.resunet import ResUNetBN2C 4 | from models.minkunet_nobn import MinkUNet18NOBN 5 | 6 | __all__ = ['MinkUNet34C', 'MinkUNet18', 'MinkUNet18_SSL', 'MinkUNet18_MCMC', 'ResUNetBN2C', 7 | 'MinkUNet18NOBN'] 8 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | 3 | 4 | def get_norm(norm_type, num_feats, bn_momentum=0.05, D=-1): 5 | if norm_type == 'BN': 6 | return ME.MinkowskiBatchNorm(num_feats, momentum=bn_momentum) 7 | elif norm_type == 'IN': 8 | return ME.MinkowskiInstanceNorm(num_feats, dimension=D) 9 | else: 10 | raise ValueError(f'Type 
{norm_type} not defined')
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | gpuid=${1:-0}
2 | export CUDA_VISIBLE_DEVICES=$gpuid
3 | cd /home/XXXXXXX/A0_TTA-Point/HGL
4 | 
5 | note="normal"
6 | CUBLAS_WORKSPACE_CONFIG=:4096:8 python adapt_online.py --config_file configs/adaptation/synlidar2kitti_adaptation_model_features.yaml --note $note --use_prototype --use_pseudo_new --pseudo_th 0.7 --pseudo_knn 10 --score_weight --loss_use_score_weight --loss_method_num 1 --loss_eps 0.3 --use_all_pseudo
7 | 
--------------------------------------------------------------------------------
/utils/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | 
3 | 
4 | def get_config(name):
5 |     with open(name, 'r') as stream:
6 |         config_dict = yaml.safe_load(stream)
7 |     return Config(config_dict)
8 | 
9 | 
10 | class Config:
11 |     def __init__(self, in_dict: dict):
12 |         assert isinstance(in_dict, dict)
13 |         for key, val in in_dict.items():
14 |             if isinstance(val, (list, tuple)):
15 |                 setattr(self, key, [Config(x) if isinstance(x, dict) else x for x in val])
16 |             else:
17 |                 setattr(self, key, Config(val) if isinstance(val, dict) else val)
18 | 
--------------------------------------------------------------------------------
/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_pipeline import BasePipeline
2 | from .trainer import OneDomainTrainer
3 | from .trainer_lighting import PLTOneDomainTrainer
4 | from .adaptation_online_single import OnlineTrainer
5 | from .adaptation_online_single_test import OnlineTrainer_test
6 | from .adaptation_online_single_tent import OnlineTrainer_tent
7 | from .adaptation_online_single_gpg import OnlineTrainer_gpg
8 | 
9 | __all__ = ['BasePipeline', 'OneDomainTrainer',
10 |            'PLTOneDomainTrainer',
11 |            'OnlineTrainer', 'OnlineTrainer_test', 'OnlineTrainer_tent', 'OnlineTrainer_gpg']
12 | 
--------------------------------------------------------------------------------
/pipelines/base_pipeline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as plt
3 | import MinkowskiEngine as ME
4 | import open3d as o3d
5 | 
6 | 
7 | class BasePipeline(object):
8 | 
9 |     def __init__(self,
10 |                  model=None,
11 |                  loss=None,
12 |                  optimizer=None,
13 |                  scheduler=None):
14 | 
15 |         self.model = model
16 |         self.loss = loss
17 |         self.optimizer = optimizer
18 |         self.scheduler = scheduler
19 | 
20 |     def train(self):
21 |         raise NotImplementedError
22 | 
23 |     def single_gpu_train(self):
24 |         raise NotImplementedError
25 | 
26 |     def validate(self):
27 |         raise NotImplementedError
28 | 
29 |     def inference(self):
30 |         raise NotImplementedError
31 | 
32 | 
--------------------------------------------------------------------------------
/configs/source/synlidar_source.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   name: 'MinkUNet18'
3 |   in_feat_size: 1
4 |   out_classes: 7
5 | 
6 | dataset:
7 |   version: 'full'
8 |   name: 'SynLiDAR'
9 |   dataset_path: 'data/SynLiDAR/'
10 |   target_path: 'data/SemanticKITTI/data/sequences/'
11 |   voxel_size: 0.05
12 |   num_pts: 50000
13 |   ignore_label: -1
14 |   validate_target: false
15 |   augment_data: true
16 |   mapping_path: '_resources/synthetic.yaml'
17 | 
18 | 
19 | pipeline:
20 |   epochs: 100
21 |   gpus: [0]
22 | 
precision: 32 23 | loss: 'SoftDICELoss' 24 | seed: 1234 25 | save_dir: 'experiments/source/synlidar' 26 | 27 | dataloader: 28 | batch_size: 4 29 | num_workers: 16 30 | 31 | optimizer: 32 | name: 'Adam' 33 | lr: 0.01 34 | scheduler: true 35 | 36 | scheduler: 37 | scheduler_name: 'ExponentialLR' 38 | 39 | lightning: 40 | check_val_every_n_epoch: 3 41 | clear_cache_int: 1 42 | resume_checkpoint: null 43 | val_check_interval: 1.0 44 | num_sanity_val_steps: 0 45 | 46 | 47 | wandb: 48 | run_name: 'SOURCE_SYNLIDAR_SEG-PT1' 49 | project_name: 'amazing-project' 50 | entity_name: 'name' 51 | offline: false 52 | 53 | -------------------------------------------------------------------------------- /configs/source/synth4dkitti_source.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | 6 | dataset: 7 | version: 'full' 8 | name: 'SyntheticKITTI' 9 | dataset_path: 'data/Synth4D' 10 | target_path: 'data/SemanticKITTI/data/sequences/' 11 | voxel_size: 0.05 12 | num_pts: 50000 13 | ignore_label: -1 14 | validate_target: false 15 | augment_data: true 16 | mapping_path: '_resources/synthetic.yaml' 17 | 18 | 19 | pipeline: 20 | epochs: 100 21 | gpus: [0] 22 | precision: 32 23 | loss: 'SoftDICELoss' 24 | seed: 1234 25 | save_dir: 'experiments/source/synthkitti' 26 | 27 | dataloader: 28 | batch_size: 20 29 | num_workers: 16 30 | 31 | optimizer: 32 | name: 'Adam' 33 | lr: 0.01 34 | scheduler: true 35 | 36 | scheduler: 37 | scheduler_name: 'ExponentialLR' 38 | 39 | lightning: 40 | check_val_every_n_epoch: 5 41 | clear_cache_int: 1 42 | resume_checkpoint: null 43 | val_check_interval: 1.0 44 | num_sanity_val_steps: 0 45 | 46 | 47 | wandb: 48 | run_name: 'SOURCE_SYNTHKITTI_SEG-PT1' 49 | project_name: 'amazing-project' 50 | entity_name: 'name' 51 | offline: false 52 | 53 | -------------------------------------------------------------------------------- /configs/source/synth4dnusc_source.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | 6 | dataset: 7 | version: 'full' 8 | name: 'SyntheticNuScenes' 9 | dataset_path: '/DATA3/XXXX/TTA_Point/' 10 | target_path: '/DATA2/nuScenes/v1.0-trainval' 11 | voxel_size: 0.05 12 | num_pts: 50000 13 | ignore_label: -1 14 | validate_target: false 15 | augment_data: true 16 | mapping_path: '_resources/synthetic.yaml' 17 | 18 | 19 | pipeline: 20 | epochs: 100 21 | gpus: [0] 22 | precision: 32 23 | loss: 'SoftDICELoss' 24 | seed: 1234 25 | save_dir: 'experiments/source/synthnusc' 26 | 27 | dataloader: 28 | batch_size: 16 29 | num_workers: 16 30 | 31 | optimizer: 32 | name: 'Adam' 33 | lr: 0.01 34 | scheduler: true 35 | 36 | scheduler: 37 | scheduler_name: 'ExponentialLR' 38 | 39 | lightning: 40 | check_val_every_n_epoch: 10 41 | clear_cache_int: 1 42 | resume_checkpoint: null 43 | val_check_interval: 1.0 44 | num_sanity_val_steps: 0 45 | 46 | 47 | wandb: 48 | run_name: 'SOURCE_SYNTHNUSC_SEG-PT1' 49 | project_name: 'amazing-project' 50 | entity_name: 'name' 51 | offline: false 52 | 53 | -------------------------------------------------------------------------------- /utils/_resources/synthetic.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 
3 |   0 : 'noise'
4 |   1 : 'building'
5 |   2 : 'fences'
6 |   3 : 'other'
7 |   4 : 'pedestrian'
8 |   5 : 'pole'
9 |   6 : 'roadlines'
10 |   7 : 'road'
11 |   8 : 'sidewalk'
12 |   9 : 'vegetation'
13 |   10 : 'vehicle'
14 |   11 : 'wall'
15 |   12 : 'trafficsign'
16 |   13 : 'sky'
17 |   14 : 'ground'
18 |   15 : 'bridge'
19 |   16 : 'railtrack'
20 |   17 : 'guardrail'
21 |   18 : 'trafficlight'
22 |   19 : 'static'
23 |   20 : 'dynamic'
24 |   21 : 'water'
25 |   22 : 'terrain'
26 | 
27 | # content: # as a ratio with the total number of points
28 | # classes that are indistinguishable from single scan or inconsistent in
29 | # ground truth are mapped to their closest equivalent
30 | learning_map:
31 |   0 : -1
32 |   1 : 5
33 |   2 : 5
34 |   3 : -1
35 |   4 : 1
36 |   5 : 5
37 |   6 : 2
38 |   7 : 2
39 |   8 : 3
40 |   9 : 6
41 |   10 : 0
42 |   11 : 5
43 |   12 : 5
44 |   13 : -1
45 |   14 : -1
46 |   15 : 5
47 |   16 : 5
48 |   17 : 5
49 |   18 : -1
50 |   19 : -1
51 |   20 : -1
52 |   21 : -1
53 |   22 : 4
54 | 
55 | 
56 | learning_map_inv: # inverse of previous map
57 |   -1 : 0
58 |   0 : 1
59 |   1 : 4
60 |   2 : 7
61 |   3 : 8
62 |   4 : 22
63 |   5 : 1
64 |   6 : 9
65 | 
66 | learning_ignore: # Ignore classes
67 |   -1: True # "unlabeled", and others ignored
68 |   0: False
69 |   1: False
70 |   2: False
71 |   3: False
72 |   4: False
73 |   5: False
74 |   6: False
75 | 
76 | 
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sklearn.metrics as metrics
3 | import numpy as np
4 | import pdb
5 | 
6 | 
7 | def filtered_accuracy(preds, gt):
8 | 
9 |     valid_idx = torch.logical_not(gt == 0)
10 |     gt = gt[valid_idx] - 1
11 |     preds = preds[valid_idx]
12 | 
13 |     acc = metrics.accuracy_score(gt.numpy(), preds.numpy())
14 | 
15 |     return acc
16 | 
17 | 
18 | def confusion_matrix(scores, labels, num_classes=7):
19 |     r"""
20 |     Compute the confusion matrix of one batch
21 |     Parameters
22 |     ----------
23 |     scores: torch.LongTensor, shape (B?, N)
24 |         predicted labels (hard class indices, not raw per-class scores)
25 |     labels: torch.LongTensor, shape (B?, N)
26 |         ground truth labels
27 |     Returns
28 |     -------
29 |     confusion matrix of this batch
30 |     """
31 | 
32 |     predictions = scores.data
33 |     labels = labels.data
34 | 
35 |     conf_m = torch.zeros((num_classes, num_classes), dtype=torch.int32)
36 | 
37 |     for label in range(num_classes):
38 |         for pred in range(num_classes):
39 |             conf_m[label][pred] = torch.sum(
40 |                 torch.logical_and(labels == label, predictions == pred))
41 |     return conf_m
42 | 
43 | 
44 | def iou_from_confusion(conf_m):
45 |     per_class_iou = [conf_m[i][i] / (conf_m[i].sum() + conf_m[:, i].sum() - conf_m[i][i]) for i in range(conf_m.shape[0])]
46 |     return per_class_iou, torch.hstack(per_class_iou).mean()
47 | 
--------------------------------------------------------------------------------
/configs/adaptation/synth4d2nusc_adaptation_model_features.yaml:
--------------------------------------------------------------------------------
1 | model:
2 |   name: 'MinkUNet18'
3 |   in_feat_size: 1
4 |   out_classes: 7
5 |   drop_prob: 0.5
6 | 
7 | dataset:
8 |   version: 'full'
9 |   name: 'nuScenes'
10 |   dataset_path: '/DATA2/nuScenes/Data'
11 |   split_path: null
12 |   target_path: ''
13 |   voxel_size: 0.05
14 |   num_pts: 50000
15 |   ignore_label: -1
16 |   validate_target: false
17 |   augment_data: false
18 |   max_time_window: 1
19 |   oracle_pts: 0
20 |   mapping_path: '_resources/nuscenes.yaml'
21 | 
22 | pipeline:
23 |   epochs: 1
24 |   gpu: 0
25 |   precision: 32
26 |   loss: 'SoftDICELoss'
27 |   ssl_loss: 
'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synth4dnusc/' 33 | source_model: 'pretrained_models/source/synth4dnusc/epoch=99-step=12499.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.1 50 | propagation_size: 5 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'model_features' 54 | 55 | dataloader: 56 | stream_batch_size: 1 57 | adaptation_batch_size: 1 58 | num_workers: 10 59 | 60 | optimizer: 61 | name: 'Adam' 62 | lr: 0.001 63 | scheduler: false 64 | 65 | scheduler: 66 | scheduler_name: 'ExponentialLR' 67 | 68 | trainer: 69 | save_checkpoint_every: 200 70 | clear_cache_int: 1 71 | num_sanity_val_steps: 0 72 | 73 | 74 | wandb: 75 | run_name: 'HGL-Synth2NUSCENES' 76 | project_name: 'amazing-project' 77 | entity_name: 'name' 78 | offline: false 79 | 80 | -------------------------------------------------------------------------------- /configs/adaptation/synth4d2nusc_adaptation_model_features_64to32.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'nuScenes' 10 | dataset_path: '/DATA2/nuScenes/Data' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 1 19 | oracle_pts: 0 20 | mapping_path: '_resources/nuscenes.yaml' 21 | 22 | pipeline: 23 | epochs: 1 24 | gpu: 0 25 | precision: 32 26 | loss: 'SoftDICELoss' 27 | ssl_loss: 'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synth4dnusc/' 33 | source_model: 'pretrained_models/source/synth4dkitti/epoch=99-step=12499.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 
40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.1 50 | propagation_size: 5 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'model_features' 54 | 55 | dataloader: 56 | stream_batch_size: 1 57 | adaptation_batch_size: 1 58 | num_workers: 10 59 | 60 | optimizer: 61 | name: 'Adam' 62 | lr: 0.001 63 | scheduler: false 64 | 65 | scheduler: 66 | scheduler_name: 'ExponentialLR' 67 | 68 | trainer: 69 | save_checkpoint_every: 200 70 | clear_cache_int: 1 71 | num_sanity_val_steps: 0 72 | 73 | 74 | wandb: 75 | run_name: 'HGL-Synth2NUSCENES' 76 | project_name: 'amazing-project' 77 | entity_name: 'name' 78 | offline: false 79 | 80 | -------------------------------------------------------------------------------- /configs/adaptation/synth4d2nusc_adaptation_64to32_synth.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'nuScenes' 10 | dataset_path: '/DATA2/nuScenes/Data' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 1 19 | oracle_pts: 0 20 | mapping_path: '_resources/nuscenes.yaml' 21 | 22 | pipeline: 23 | epochs: 1 24 | gpu: 0 25 | precision: 32 26 | loss: 'SoftDICELoss' 27 | ssl_loss: 'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synth4dnusc/A0_64to32_synth/' 33 | source_model: 'pretrained_models/source/synlidar/epoch=77-step=73631.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 
40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.1 50 | propagation_size: 5 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'geometric_features' 54 | 55 | dataloader: 56 | stream_batch_size: 1 57 | adaptation_batch_size: 1 58 | num_workers: 10 59 | 60 | optimizer: 61 | name: 'Adam' 62 | lr: 0.001 63 | scheduler: false 64 | 65 | scheduler: 66 | scheduler_name: 'ExponentialLR' 67 | 68 | trainer: 69 | save_checkpoint_every: 200 70 | clear_cache_int: 1 71 | num_sanity_val_steps: 0 72 | 73 | 74 | wandb: 75 | run_name: 'HGL-Synth2NUSCENES' 76 | project_name: 'amazing-project' 77 | entity_name: 'name' 78 | offline: false 79 | 80 | -------------------------------------------------------------------------------- /configs/adaptation/synth4d2nusc_adaptation_64to32_syn4d.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'nuScenes' 10 | dataset_path: '/DATA2/nuScenes/Data' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 1 19 | oracle_pts: 0 20 | mapping_path: '_resources/nuscenes.yaml' 21 | 22 | pipeline: 23 | epochs: 1 24 | gpu: 0 25 | precision: 32 26 | loss: 'SoftDICELoss' 27 | ssl_loss: 'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synth4dnusc/A0_64to32_syn4d/' 33 | source_model: 'pretrained_models/source/synth4dkitti/epoch=99-step=12499.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 
40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.1 50 | propagation_size: 5 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'geometric_features' 54 | 55 | dataloader: 56 | stream_batch_size: 1 57 | adaptation_batch_size: 1 58 | num_workers: 10 59 | 60 | optimizer: 61 | name: 'Adam' 62 | lr: 0.001 63 | scheduler: false 64 | 65 | scheduler: 66 | scheduler_name: 'ExponentialLR' 67 | 68 | trainer: 69 | save_checkpoint_every: 200 70 | clear_cache_int: 1 71 | num_sanity_val_steps: 0 72 | 73 | 74 | wandb: 75 | run_name: 'HGL-Synth2NUSCENES' 76 | project_name: 'amazing-project' 77 | entity_name: 'name' 78 | offline: false 79 | 80 | -------------------------------------------------------------------------------- /configs/adaptation/synlidar2kitti_adaptation_model_features.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'SemanticKITTI' 10 | dataset_path: '/DATA2/kitti/odometry/sequences' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 5 19 | oracle_pts: 0 20 | mapping_path: '_resources/semantic-kitti.yaml' 21 | 22 | pipeline: 23 | epochs: 1 24 | gpu: 0 25 | precision: 32 26 | loss: 'SoftDICELoss' 27 | ssl_loss: 'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synlidar2kitti' 33 | source_model: 'pretrained_models/source/synlidar/epoch=77-step=73631.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 
40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.01 50 | propagation_size: 10 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'model_features' 54 | 55 | dataloader: 56 | stream_batch_size: 1 57 | adaptation_batch_size: 1 58 | num_workers: 10 59 | 60 | optimizer: 61 | name: 'Adam' 62 | lr: 0.001 63 | scheduler: false 64 | 65 | scheduler: 66 | scheduler_name: 'ExponentialLR' 67 | 68 | trainer: 69 | save_checkpoint_every: 200 70 | clear_cache_int: 1 71 | num_sanity_val_steps: 0 72 | 73 | 74 | wandb: 75 | run_name: 'HGL-Synlidar2KITTI' 76 | project_name: 'amazing-project' 77 | entity_name: 'name' 78 | offline: false 79 | 80 | -------------------------------------------------------------------------------- /configs/adaptation/synth4d2nusc_adaptation_64to32_syn4d_test.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'nuScenes' 10 | dataset_path: '/DATA2/nuScenes/Data' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 1 19 | oracle_pts: 0 20 | mapping_path: '_resources/nuscenes.yaml' 21 | 22 | pipeline: 23 | epochs: 1 24 | gpu: 0 25 | precision: 32 26 | loss: 'SoftDICELoss' 27 | ssl_loss: 'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synth4dnusc/A0_64to32_syn4d/' 33 | source_model: 'pretrained_models/source/synth4dkitti/epoch=99-step=12499.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 
40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.1 50 | propagation_size: 5 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'geometric_features' 54 | 55 | dataloader: 56 | stream_batch_size: 1 57 | adaptation_batch_size: 1 58 | num_workers: 10 59 | 60 | optimizer: 61 | name: 'Adam' 62 | lr: 0.001 63 | scheduler: false 64 | 65 | scheduler: 66 | scheduler_name: 'ExponentialLR' 67 | 68 | trainer: 69 | save_checkpoint_every: 200 70 | clear_cache_int: 1 71 | num_sanity_val_steps: 0 72 | 73 | 74 | wandb: 75 | run_name: 'HGL-Synth2NUSCENES' 76 | project_name: 'amazing-project' 77 | entity_name: 'name' 78 | offline: false 79 | 80 | -------------------------------------------------------------------------------- /configs/adaptation/synth4d2kitti_adaptation_model_features.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'SemanticKITTI' 10 | dataset_path: '/DATA2/kitti/odometry/sequences' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 5 19 | oracle_pts: 0 20 | mapping_path: '_resources/semantic-kitti.yaml' 21 | 22 | 23 | pipeline: 24 | epochs: 1 25 | gpu: 0 26 | precision: 32 27 | loss: 'SoftDICELoss' 28 | ssl_loss: 'CosineSimilarity' 29 | eps: 0.25 30 | ssl_beta: 0.5 31 | segmentation_beta: 1.0 32 | seed: 1234 33 | save_dir: 'experiments/HGL/synth4d2kitti' 34 | source_model: 'pretrained_models/source/synth4dkitti/epoch=99-step=12499.ckpt' 35 | student_model: null 36 | topk_matches: -1 37 | random_time_window: false 38 | freeze_list: null 39 | topk_pseudo: null 40 | th_pseudo: 0. 
41 | delayed_freeze_list: null 42 | delayed_freeze_frames: null 43 | is_double: true 44 | is_pseudo: true 45 | use_mcmc: true 46 | sub_epoch: 1 47 | num_mc_iterations: 5 48 | top_class: 0 49 | propagate: false 50 | top_p: 0.01 51 | propagation_size: 10 52 | metric: 'mcmc_cbst' 53 | use_matches: false 54 | propagation_method: 'model_features' 55 | 56 | dataloader: 57 | stream_batch_size: 1 58 | adaptation_batch_size: 1 59 | num_workers: 10 60 | 61 | optimizer: 62 | name: 'Adam' 63 | lr: 0.001 64 | scheduler: false 65 | 66 | scheduler: 67 | scheduler_name: 'ExponentialLR' 68 | 69 | trainer: 70 | save_checkpoint_every: 200 71 | clear_cache_int: 1 72 | num_sanity_val_steps: 0 73 | 74 | 75 | wandb: 76 | run_name: 'HGL-Synth2NUSCENES' 77 | project_name: 'amazing-project' 78 | entity_name: 'name' 79 | offline: false 80 | 81 | -------------------------------------------------------------------------------- /configs/adaptation/synlidar2kitti_adaptation_32to64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'MinkUNet18' 3 | in_feat_size: 1 4 | out_classes: 7 5 | drop_prob: 0.5 6 | 7 | dataset: 8 | version: 'full' 9 | name: 'SemanticKITTI' 10 | dataset_path: '/DATA2/kitti/odometry/sequences' 11 | split_path: null 12 | target_path: '' 13 | voxel_size: 0.05 14 | num_pts: 50000 15 | ignore_label: -1 16 | validate_target: false 17 | augment_data: false 18 | max_time_window: 5 19 | oracle_pts: 0 20 | mapping_path: '_resources/semantic-kitti.yaml' 21 | 22 | pipeline: 23 | epochs: 1 24 | gpu: 0 25 | precision: 32 26 | loss: 'SoftDICELoss' 27 | ssl_loss: 'CosineSimilarity' 28 | eps: 0.25 29 | ssl_beta: 0.5 30 | segmentation_beta: 1.0 31 | seed: 1234 32 | save_dir: 'experiments/HGL/synlidar2kitti' 33 | source_model: 'pretrained_models/source/synth4dnusc/epoch=99-step=12499.ckpt' 34 | student_model: null 35 | topk_matches: -1 36 | random_time_window: false 37 | freeze_list: null 38 | topk_pseudo: null 39 | th_pseudo: 0. 
40 | delayed_freeze_list: null 41 | delayed_freeze_frames: null 42 | is_double: true 43 | is_pseudo: true 44 | use_mcmc: true 45 | sub_epoch: 1 46 | num_mc_iterations: 5 47 | top_class: 0 48 | propagate: true 49 | top_p: 0.01 50 | propagation_size: 10 51 | metric: 'mcmc_cbst' 52 | use_matches: false 53 | propagation_method: 'geometric_features' 54 | # propagation_method: 'model_features' 55 | 56 | dataloader: 57 | stream_batch_size: 1 58 | adaptation_batch_size: 1 59 | num_workers: 10 60 | 61 | optimizer: 62 | name: 'Adam' 63 | lr: 0.001 64 | scheduler: false 65 | 66 | scheduler: 67 | scheduler_name: 'ExponentialLR' 68 | 69 | trainer: 70 | save_checkpoint_every: 200 71 | clear_cache_int: 1 72 | num_sanity_val_steps: 0 73 | 74 | 75 | wandb: 76 | run_name: 'HGL-Synlidar2KITTI' 77 | project_name: 'amazing-project' 78 | entity_name: 'name' 79 | offline: false 80 | 81 | -------------------------------------------------------------------------------- /models/residual_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.common import get_norm 4 | 5 | import MinkowskiEngine as ME 6 | import MinkowskiEngine.MinkowskiFunctional as MEF 7 | 8 | 9 | class BasicBlockBase(nn.Module): 10 | expansion = 1 11 | NORM_TYPE = 'BN' 12 | 13 | def __init__(self, 14 | inplanes, 15 | planes, 16 | stride=1, 17 | dilation=1, 18 | downsample=None, 19 | bn_momentum=0.1, 20 | D=3): 21 | super(BasicBlockBase, self).__init__() 22 | 23 | self.conv1 = ME.MinkowskiConvolution( 24 | inplanes, planes, kernel_size=3, stride=stride, dimension=D) 25 | self.norm1 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) 26 | self.conv2 = ME.MinkowskiConvolution( 27 | planes, 28 | planes, 29 | kernel_size=3, 30 | stride=1, 31 | dilation=dilation, 32 | bias=False, 33 | dimension=D) 34 | self.norm2 = get_norm(self.NORM_TYPE, planes, bn_momentum=bn_momentum, D=D) 35 | self.downsample = downsample 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.norm1(out) 42 | out = MEF.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.norm2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = MEF.relu(out) 52 | 53 | return out 54 | 55 | 56 | class BasicBlockBN(BasicBlockBase): 57 | NORM_TYPE = 'BN' 58 | 59 | 60 | class BasicBlockIN(BasicBlockBase): 61 | NORM_TYPE = 'IN' 62 | 63 | 64 | def get_block(norm_type, 65 | inplanes, 66 | planes, 67 | stride=1, 68 | dilation=1, 69 | downsample=None, 70 | bn_momentum=0.1, 71 | D=3): 72 | if norm_type == 'BN': 73 | return BasicBlockBN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) 74 | elif norm_type == 'IN': 75 | return BasicBlockIN(inplanes, planes, stride, dilation, downsample, bn_momentum, D) 76 | else: 77 | raise ValueError(f'Type {norm_type}, not defined') -------------------------------------------------------------------------------- /utils/_resources/synlidar.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 3 | 0 : "unlabeled" 4 | 1: "car" 5 | 2: "pick-up" 6 | 3: "truck" 7 | 4: "bus" 8 | 5: "bicycle" 9 | 6: "motorcycle" 10 | 7: "other-vehicle" 11 | 8: "road" 12 | 9: "sidewalk" 13 | 10: "parking" 14 | 11: "other-ground" 15 | 12: "female" 16 | 13: "male" 17 | 14: "kid" 18 | 15: "crowd" # multiple person that are very close 19 | 16: "bicyclist" 20 | 17: "motorcyclist" 21 | 18: "building" 22 | 19: "other-structure" 23 | 20: "vegetation" 24 | 21: "trunk" 25 | 22: "terrain" 26 | 23: "traffic-sign" 27 | 24: "pole" 28 | 25: "traffic-cone" 29 | 26: "fence" 30 | 27: "garbage-can" 31 | 28: "electric-box" 32 | 29: "table" 33 | 30: "chair" 34 | 31: "bench" 35 | 32: "other-object" 36 | 37 | learning_map: 38 | 0: -1 # "unlabeled" 39 | 1: 0 # "car" 40 | 2: 0 # "pick-up" 41 | 3: -1 # "truck" 42 | 4: -1 # "bus" 43 | 5: -1 # "bicycle" 44 | 6: -1 # "motorcycle" 45 | 7: -1 # "other-vehicle" 46 | 8: 2 # "road" 47 | 9: 3 # "sidewalk" 48 | 10: 2 # "parking" 49 | 11: -1 # "other-ground" 50 | 12: 1 # "female" 51 | 13: 1 # "male" 52 | 14: 1 # "kid" 53 | 15: 1 # "crowd" 54 | 16: -1 # "bicyclist" 55 | 17: -1 # "motorcyclist" 56 | 18: 5 # "building" 57 | 19: -1 # "other-structure" 58 | 20: 6 # "vegetation" 59 | 21: 6 # "trunk" 60 | 22: 4 # "terrain" 61 | 23: 5 # "traffic-sign" 62 | 24: 5 # "pole" 63 | 25: -1 # "traffic-cone" 64 | 26: 5 # "fence" 65 | 27: -1 # "garbage-can" 66 | 28: -1 # "electric-box" 67 | 29: -1 # "table" 68 | 30: -1 # "chair" 69 | 31: -1 # "bench" 70 | 32: -1 # "other-object" 71 | 72 | learning_map_inv: # inverse of previous map 73 | -1: 0 # "unlabeled", and others ignored 74 | 0: 1 # "vehicle" 75 | 1: 12 # "person" 76 | 2: 8 # "road" 77 | 3: 9 # "sidewalk" 78 | 4: 22 # "terrain" 79 | 5: 18 # "manmade" 80 | 6: 10 # "vegetation" 81 | 82 | learning_ignore: # Ignore classes 83 | -1: True # "unlabeled", and others ignored 84 | 0: False # "vehicle" 85 | 1: False # "pedestrian" 86 | 2: False # "road" 87 | 3: False # "sidewalk" 88 | 4: False # "terrain" 89 | 5: False # "manmade" 90 | 6: False # "vegetation" 91 | -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.utils.data.sampler import Sampler 4 | from typing import Iterator, Optional, Sequence, List, TypeVar, Generic, Sized 5 | 6 | 7 | class SequentialSampler(Sampler[int]): 8 | r"""Samples elements sequentially, always in the same order. 
9 | Args: 10 | data_source (Dataset): dataset to sample from 11 | """ 12 | data_source: Sized 13 | 14 | def __init__(self, data_source: Sized, is_adapt=False, max_time_wdw=None, adapt_batchsize=None) -> None: 15 | self.data_source = data_source 16 | self.is_adapt = is_adapt 17 | self.adapt_batchsize = adapt_batchsize 18 | self.max_time_wdw = max_time_wdw 19 | 20 | def __iter__(self) -> Iterator[int]: 21 | if not self.is_adapt: 22 | return iter(range(len(self.data_source))) 23 | else: 24 | if self.max_time_wdw is None: 25 | return iter(range(self.adapt_batchsize-1, len(self.data_source))) 26 | else: 27 | return iter(range(self.max_time_wdw, len(self.data_source))) 28 | 29 | def __len__(self) -> int: 30 | return len(self.data_source) 31 | 32 | 33 | class BatchSampler(Sampler[List[int]]): 34 | 35 | def __init__(self, sampler: Sampler[int], batch_size: int) -> None: 36 | # Since collections.abc.Iterable does not check for `__getitem__`, which 37 | # is one way for an object to be an iterable, we don't do an `isinstance` 38 | # check here. 39 | if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \ 40 | batch_size <= 0: 41 | raise ValueError("batch_size should be a positive integer value, " 42 | "but got batch_size={}".format(batch_size)) 43 | 44 | self.sampler = sampler 45 | self.batch_size = batch_size 46 | 47 | def __iter__(self) -> Iterator[List[int]]: 48 | # batch = [] 49 | # for idx in self.sampler: 50 | # batch.append(idx) 51 | # if len(batch) == self.batch_size: 52 | # yield batch 53 | # batch = [] 54 | # if len(batch) > 0 and not self.drop_first: 55 | # yield batch 56 | batch = [] 57 | for idx in self.sampler: 58 | idx += 1 59 | batch.extend([b for b in range(idx-self.batch_size, idx)]) 60 | if len(batch) == self.batch_size: 61 | print('--> BATCHED', batch) 62 | yield batch 63 | batch = [] 64 | 65 | def __len__(self) -> int: 66 | # Can only be called if self.sampler has __len__ implemented 67 | # We cannot enforce this condition, so we turn off typechecking for the 68 | # implementation below. 
69 | # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 70 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size -------------------------------------------------------------------------------- /utils/online_logger.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import csv 3 | import os 4 | import torch 5 | 6 | 7 | class OnlineWandbLogger(object): 8 | 9 | def __init__(self, 10 | project, 11 | entity, 12 | name, 13 | offline=True, 14 | config=None): 15 | 16 | super().__init__() 17 | 18 | self.project = project 19 | self.entity = entity 20 | self.name = name 21 | self.offline = offline 22 | 23 | if self.offline: 24 | os.environ['WANDB_MODE'] = 'offline' 25 | 26 | self.run = wandb.init(project=project, 27 | name=name, 28 | config=config) 29 | 30 | 31 | 32 | self.sequence = None 33 | 34 | def set_sequence(self, sequence): 35 | self.sequence = sequence 36 | 37 | def log(self, results_dict, step=None): 38 | mapped_dict = {os.path.join(str(self.sequence), k): v for k, v in results_dict.items()} 39 | self.run.log(mapped_dict, step=step) 40 | 41 | 42 | class OnlineCSVLogger(object): 43 | 44 | def __init__(self, 45 | save_dir, 46 | version='logs'): 47 | 48 | super().__init__() 49 | 50 | self.save_dir = save_dir 51 | self.version = version 52 | 53 | os.mkdir(os.path.join(self.save_dir, self.version)) 54 | 55 | self.metrics = [] 56 | 57 | self.metrics_file_path = os.path.join(self.save_dir, self.version) 58 | self.sequence = None 59 | 60 | def set_sequence(self, sequence): 61 | self.sequence = sequence 62 | 63 | def log(self, results_dict, step=None): 64 | self.log_metrics(results_dict, step) 65 | self.save() 66 | 67 | def log_metrics(self, metrics_dict, step=None) -> None: 68 | """Record metrics""" 69 | 70 | def _handle_value(value): 71 | if isinstance(value, torch.Tensor): 72 | return value.item() 73 | return value 74 | 75 | if step is None: 76 | step = len(self.metrics) 77 | 78 | metrics = {k: _handle_value(v) for k, v in metrics_dict.items()} 79 | metrics["step"] = step 80 | self.metrics.append(metrics) 81 | 82 | def save(self) -> None: 83 | """Save recorded hparams and metrics into files""" 84 | 85 | if not self.metrics: 86 | return 87 | 88 | last_m = {} 89 | for m in self.metrics: 90 | last_m.update(m) 91 | metrics_keys = list(last_m.keys()) 92 | 93 | log_path = os.path.join(self.metrics_file_path, str(self.sequence)+'.csv') 94 | 95 | with open(log_path, "w", newline="") as f: 96 | self.writer = csv.DictWriter(f, fieldnames=metrics_keys) 97 | self.writer.writeheader() 98 | self.writer.writerows(self.metrics) 99 | 100 | 101 | -------------------------------------------------------------------------------- /utils/_resources/nuscenes.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
2 | labels: 3 | 0 : 'noise' 4 | 1 : 'animal' 5 | 2 : 'human.pedestrian.adult' 6 | 3 : 'human.pedestrian.child' 7 | 4 : 'human.pedestrian.construction_worker' 8 | 5 : 'human.pedestrian.personal_mobility' 9 | 6 : 'human.pedestrian.police_officer' 10 | 7 : 'human.pedestrian.stroller' 11 | 8 : 'human.pedestrian.wheelchair' 12 | 9 : 'movable_object.barrier' 13 | 10 : 'movable_object.debris' 14 | 11 : 'movable_object.pushable_pullable' 15 | 12 : 'movable_object.trafficcone' 16 | 13 : 'static_object.bicycle_rack' 17 | 14 : 'vehicle.bicycle' 18 | 15 : 'vehicle.bus.bendy' 19 | 16 : 'vehicle.bus.rigid' 20 | 17 : 'vehicle.car' 21 | 18 : 'vehicle.construction' 22 | 19 : 'vehicle.emergency.ambulance' 23 | 20 : 'vehicle.emergency.police' 24 | 21 : 'vehicle.motorcycle' 25 | 22 : 'vehicle.trailer' 26 | 23 : 'vehicle.truck' 27 | 24 : 'flat.driveable_surface' 28 | 25 : 'flat.other' 29 | 26 : 'flat.sidewalk' 30 | 27 : 'flat.terrain' 31 | 28 : 'static.manmade' 32 | 29 : 'static.other' 33 | 30 : 'static.vegetation' 34 | 31 : 'vehicle.ego' 35 | 36 | content: # as a ratio with the total number of points $ mini nusc 37 | 0 : 0.001862411429 38 | 1 : 0.000004865757637 39 | 2 : 0.001948534888 40 | 3 : 0.000008724027852 41 | 4 : 0.0001259973709 42 | 5 : 0.000007881894868 43 | 6 : 0.000008275854075 44 | 7 : 0.000007959602418 45 | 8 : 0.00001099471475 46 | 9 : 0.008407871974 47 | 10 : 0.00006041400582 48 | 11 : 0.0006493468772 49 | 12 : 0.0006652480105 50 | 13 : 0.0001473967651 51 | 14 : 0.0001277213942 52 | 15 : 0.0003229950459 53 | 16 : 0.003837756326 54 | 17 : 0.03443006399 55 | 18 : 0.001368388391 56 | 19 : 0.000002004131929 57 | 20 : 0.000005384410354 58 | 21 : 0.0003861803198 59 | 22 : 0.004434309958 60 | 23 : 0.01431389697 61 | 24 : 0.2863965057 62 | 25 : 0.00773390355 63 | 26 : 0.06342875245 64 | 27 : 0.06351212452 65 | 28 : 0.1609974505 66 | 29 : 0.00007383572614 67 | 30 : 0.1107615163 68 | 31 : 0.3045689783 69 | 70 | # classes that are indistinguishable from single scan or inconsistent in 71 | # ground truth are mapped to their closest equivalent 72 | learning_map: 73 | 0 : -1 74 | 1 : -1 75 | 2 : 1 76 | 3 : 1 77 | 4 : 1 78 | 5 : 1 79 | 6 : 1 80 | 7 : 1 81 | 8 : 1 82 | 9 : -1 83 | 10 : -1 84 | 11 : -1 85 | 12 : -1 86 | 13 : -1 87 | 14 : -1 88 | 15 : -1 89 | 16 : -1 90 | 17 : 0 91 | 18 : -1 92 | 19 : -1 93 | 20 : 0 94 | 21 : -1 95 | 22 : -1 96 | 23 : -1 97 | 24 : 2 98 | 25 : -1 99 | 26 : 3 100 | 27 : 4 101 | 28 : 5 102 | 29 : -1 103 | 30 : 6 104 | 31 : -1 105 | 106 | 107 | learning_map_inv: # inverse of previous map 108 | -1: 0 109 | 0 : 17 110 | 1 : 2 111 | 2 : 24 112 | 3 : 26 113 | 4 : 27 114 | 5 : 28 115 | 6 : 30 116 | 117 | learning_ignore: # Ignore classes 118 | -1: True # "unlabeled", and others ignored 119 | 0: False 120 | 1: False 121 | 2: False 122 | 3: False 123 | 4: False 124 | 5: False 125 | 6: False 126 | 127 | -------------------------------------------------------------------------------- /models/lrf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | 4 | 5 | class lrf(): 6 | ''' 7 | This is our re-implementation (+adaptation) of the LRF computed in: 8 | Z. Gojcic, C. Zhou, J. Wegner, and W. 
Andreas,
9 |     “The perfect match: 3D point cloud matching with smoothed densities,”
10 |     CVPR, 2019
11 |     '''
12 |     def __init__(self, pcd, pcd_tree, lrf_kernel, patch_size, viz=False):
13 | 
14 |         self.pcd = pcd
15 |         self.pcd_tree = pcd_tree
16 |         self.do_viz = viz
17 |         self.patch_kernel = lrf_kernel
18 |         self.patch_size = patch_size
19 | 
20 |     def get(self, pt):
21 | 
22 |         _, patch_idx, _ = self.pcd_tree.search_radius_vector_3d(pt, self.patch_kernel)
23 | 
24 |         ptnn = np.asarray(self.pcd.points)[patch_idx[1:], :].T
25 |         ptall = np.asarray(self.pcd.points)[patch_idx, :].T
26 | 
27 |         # eq. 3
28 |         ptnn_cov = 1 / len(ptnn) * np.dot((ptnn - pt[:, np.newaxis]), (ptnn - pt[:, np.newaxis]).T)
29 | 
30 |         if len(patch_idx) < self.patch_kernel / 2:
31 |             _, patch_idx, _ = self.pcd_tree.search_knn_vector_3d(pt, self.patch_kernel)
32 | 
33 |         # The normalized (unit “length”) eigenvectors, s.t. the column v[:,i] is the eigenvector corresponding to the eigenvalue w[i].
34 |         a, v = np.linalg.eig(ptnn_cov)
35 |         smallest_eigenvalue_idx = np.argmin(a)
36 |         np_hat = v[:, smallest_eigenvalue_idx]
37 | 
38 |         # eq. 4
39 |         zp = np_hat if np.sum(np.dot(np_hat, pt[:, np.newaxis] - ptnn)) > 0 else - np_hat
40 | 
41 |         v = (ptnn - pt[:, np.newaxis]) - (np.dot((ptnn - pt[:, np.newaxis]).T, zp[:, np.newaxis]) * zp).T
42 |         alpha = (self.patch_kernel - np.linalg.norm(pt[:, np.newaxis] - ptnn, axis=0)) ** 2
43 |         beta = np.dot((ptnn - pt[:, np.newaxis]).T, zp[:, np.newaxis]).squeeze() ** 2
44 | 
45 |         # eq. 5
46 |         xp = 1 / np.linalg.norm(np.dot(v, (alpha * beta)[:, np.newaxis])) * np.dot(v, (alpha * beta)[:, np.newaxis])
47 |         xp = xp.squeeze()
48 | 
49 |         yp = np.cross(xp, zp)
50 | 
51 |         lRg = np.asarray([xp, yp, zp]).T
52 | 
53 |         # rotate w.r.t. the local frame and centre at zero using the chosen point
54 |         ptall = (lRg.T @ (ptall - pt[:, np.newaxis])).T
55 | 
56 |         # this is our normalisation
57 |         ptall /= self.patch_kernel
58 | 
59 |         T = np.zeros((4, 4))
60 |         T[-1, -1] = 1
61 |         T[:3, :3] = lRg
62 |         T[:3, -1] = pt
63 | 
64 |         # visualise patch and local reference frame
65 |         if self.do_viz:
66 |             self.pcd.paint_uniform_color([.3, .3, .3])
67 |             self.pcd.estimate_normals()
68 |             np.asarray(self.pcd.colors)[patch_idx[1:]] = [0, 1, 0]
69 |             local_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=0.2)
70 |             local_frame.transform(T)
71 |             o3d.visualization.draw_geometries([self.pcd, local_frame])
72 | 
73 |         # to make sure that there are at least self.patch_size points, pad with zeros if not
74 |         if ptall.shape[0] < self.patch_size:
75 |             ptall = np.concatenate((ptall, np.zeros((self.patch_size - ptall.shape[0], 3))))
76 | 
77 |         inds = np.random.choice(ptall.shape[0], self.patch_size, replace=False)
78 | 
79 |         return ptall[inds], pt, T
--------------------------------------------------------------------------------
/models/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | class STN3d(nn.Module):
6 | 
7 |     def __init__(self):
8 |         super(STN3d, self).__init__()
9 | 
10 |         self.conv1 = nn.Sequential(nn.Conv1d(3, 256, 1),
11 |                                    nn.BatchNorm1d(256),
12 |                                    nn.ReLU())
13 | 
14 |         self.conv2 = nn.Sequential(nn.Conv1d(256, 512, 1),
15 |                                    nn.BatchNorm1d(512),
16 |                                    nn.ReLU())
17 | 
18 |         self.conv3 = nn.Sequential(nn.Conv1d(512, 1024, 1),
19 |                                    nn.BatchNorm1d(1024))
20 | 
21 |         self.fc1 = nn.Sequential(nn.Linear(1024, 512),
22 |                                  nn.BatchNorm1d(512),
23 |                                  nn.ReLU())
24 | 
25 |         self.fc2 = nn.Sequential(nn.Linear(512, 256),
26 |                                  nn.BatchNorm1d(256),
27 | 
nn.ReLU()) 28 | 29 | self.fc3 = nn.Sequential(nn.Linear(256, 9)) 30 | 31 | def forward(self, x): 32 | 33 | batchsize = x.size()[0] 34 | 35 | x = self.conv1(x) 36 | x = self.conv2(x) 37 | x = self.conv3(x) 38 | 39 | x = torch.max(x, 2, keepdim=True)[0] 40 | x = x.view(-1, 1024) 41 | 42 | x = self.fc1(x) 43 | x = self.fc2(x) 44 | x = self.fc3(x) 45 | 46 | iden = torch.Tensor([1, 0, 0, 0, 1, 0, 0, 0, 1]).view(1, 9).repeat(batchsize, 1) 47 | if x.is_cuda: 48 | iden = iden.cuda() 49 | x = x + iden 50 | 51 | x = x.view(-1, 3, 3) 52 | 53 | return x 54 | 55 | 56 | class PointNetFeature(nn.Module): 57 | 58 | def __init__(self, dim=32, l2norm=True, tnet=True): 59 | super(PointNetFeature, self).__init__() 60 | 61 | self.l2norm = l2norm 62 | self.tnet = tnet 63 | 64 | self.stn3d = STN3d() 65 | 66 | self.conv1 = nn.Sequential(nn.Conv1d(3, 256, 1), 67 | nn.BatchNorm1d(256), 68 | nn.ReLU()) 69 | 70 | self.conv2 = nn.Sequential(nn.Conv1d(256, 512, 1), 71 | nn.BatchNorm1d(512), 72 | nn.ReLU()) 73 | 74 | self.conv3 = nn.Sequential(nn.Conv1d(512, 1024, 1), 75 | nn.BatchNorm1d(1024)) 76 | 77 | self.fc1 = nn.Sequential(nn.Linear(1024, 512), 78 | nn.BatchNorm1d(512), 79 | nn.ReLU()) 80 | 81 | self.fc2 = nn.Sequential(nn.Linear(512, 256), 82 | nn.Dropout(p=0.3), 83 | nn.BatchNorm1d(256), 84 | nn.ReLU()) 85 | 86 | self.fc3 = nn.Sequential(nn.Linear(256, dim)) 87 | 88 | def _forward(self, x): 89 | 90 | if self.tnet: 91 | trans = self.stn3d(x) 92 | xtrans = torch.bmm(trans, x) 93 | else: 94 | xtrans = x 95 | 96 | x = self.conv1(xtrans) 97 | x = self.conv2(x) 98 | x = self.conv3(x) 99 | 100 | mx, amx = torch.max(x, 2, keepdim=True) 101 | x = mx.view(-1, 1024) 102 | 103 | x = self.fc1(x) 104 | x = self.fc2(x) 105 | x = self.fc3(x) 106 | 107 | if self.l2norm: 108 | if self.tnet: 109 | return F.normalize(x, p=2, dim=1), xtrans, trans, mx, amx 110 | else: 111 | return F.normalize(x, p=2, dim=1), mx, amx 112 | else: 113 | return x, xtrans, trans, mx, amx 114 | 115 | def forward(self, xa, xp=torch.Tensor([]), trans=False): 116 | 117 | if xp.nelement() == 0: 118 | if trans or not self.tnet: 119 | out, mx, amx = self._forward(xa) 120 | return out, mx, amx 121 | else: 122 | out, _, _, mx, amx = self._forward(xa) 123 | return out, mx, amx 124 | else: 125 | if self.tnet: 126 | out1a, out1b, out1c, _, _ = self._forward(xa) 127 | out2a, out2b, out2c, _, _ = self._forward(xp) 128 | return out1a, out1b, out1c, out2a, out2b, out2c 129 | 130 | else: 131 | return self._forward(xa), self._forward(xp) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **HGL: Hierarchical Geometry Learning for Test-time Adaptation in 3D Point Cloud Segmentation [ECCV2024 Oral]** 2 | 3 | The official implementation of our work "HGL: Hierarchical Geometry Learning for Test-time Adaptation in 3D Point Cloud Segmentation". 4 | 5 | ![image](https://github.com/tpzou/HGL/blob/master/pic/fig_framework1.png) 6 | 7 | ## Introduction 8 | 3D point cloud segmentation has received significant interest for its growing applications. However, the generalization ability of models suffers in dynamic scenarios due to the distribution shift between test and training data. To promote robustness and adaptability across diverse scenarios, test-time adaptation (TTA) has recently been introduced. 
Nevertheless, most existing TTA methods are developed for images, and the few approaches applicable to point clouds ignore the inherent hierarchical geometric structures in point cloud streams, i.e., local (point-level), global (object-level), and temporal (frame-level) structures. In this paper, we delve into TTA in 3D point cloud segmentation and propose a novel Hierarchical Geometry Learning (HGL) framework. HGL comprises three complementary modules from local, global to temporal learning in a bottom-up manner. Technically, we first construct a local geometry learning module for pseudo-label generation. Next, we build prototypes from the global geometry perspective for pseudo-label fine-tuning. Furthermore, we introduce a temporal consistency regularization module to mitigate negative transfer. Extensive experiments on four datasets demonstrate the effectiveness and superiority of our HGL. Remarkably, on the SynLiDAR to SemanticKITTI task, HGL achieves an overall mIoU of 46.91%, improving GIPSO by 3.0% and significantly reducing the required adaptation time by 80%.
9 |
10 | ### Environment
11 | - [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine)
12 | - [open3d 0.13.0](http://www.open3d.org)
13 | - [KNN-CUDA](https://github.com/unlimblue/KNN_CUDA)
14 | - [pytorch-lightning 1.4.1](https://www.pytorchlightning.ai)
15 | - [wandb](https://docs.wandb.ai/quickstart)
16 | - [nuscenes-devkit](https://github.com/nutonomy/nuscenes-devkit)
17 | - tqdm
18 | - pickle
19 |
20 | ## NOTE !!!:
21 | - __This code is the initial version and I haven't had enough time to trim and optimize it. I will update it after the CVPR 2025 deadline.__ The reason for publishing it in advance is to facilitate further exploration of the 3D TTA problem by other researchers. The code is mainly based on the [GIPSO](https://github.com/saltoricristiano/gipso-sfouda) framework, and our main changes are [use_pseudo_new(LGL)](https://github.com/tpzou/HGL/blob/91523a12301c38cc8f436fd5a07ac0ee866d0685/pipelines/adaptation_online_single.py#L566), [use_prototype(GFG)](https://github.com/tpzou/HGL/blob/91523a12301c38cc8f436fd5a07ac0ee866d0685/pipelines/adaptation_online_single.py#L709C26-L709C39), [score_weight(TGR)](https://github.com/tpzou/HGL/blob/91523a12301c38cc8f436fd5a07ac0ee866d0685/pipelines/adaptation_online_single.py#L840C34-L840C46) and [SoftDICELoss](https://github.com/tpzou/HGL/blob/91523a12301c38cc8f436fd5a07ac0ee866d0685/utils/losses.py#L120C7-L120C19).
22 | - I will keep working on optimizing the code; in the meantime, feel free to contact me if you have any questions!
23 |
24 | ## Source training
25 |
26 | To train the source model on SynLiDAR:
27 | ```
28 | python train_lighting.py --config_file configs/source/synlidar_source.yaml
29 | ```
30 | For Synth4D-KITTI, use ``--config_file configs/source/synth4dkitti_source.yaml``.
31 |
32 | For Synth4D-nuScenes, use ``--config_file configs/source/synth4dnusc_source.yaml``.
33 |
34 | ## Pretrained models
35 |
36 | We use the pretrained models on Synth4D-KITTI, Synth4D-nuScenes and SynLiDAR provided by [GIPSO](https://github.com/saltoricristiano/gipso-sfouda). You can find the models [here](https://drive.google.com/file/d/1gT6KN1pYWj800qX54jAjWl5VGrHs8Owc/view?usp=sharing).
37 | For model performance, please refer to the main paper.
38 |
39 | After downloading the pretrained models, decompress them into ```/pretrained_models```.
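For reference, here is a minimal sketch of restoring one of these checkpoints for inference. The files saved by `SourceCheckpoint` in `utils/callbacks.py` are plain `state_dict`s, so loading reduces to `load_state_dict`; the model class, channel sizes and file name below are assumptions — take the real values from the matching `configs/source/*.yaml`:

```
import torch
import models

# Assumed values: read config.model.name, in_feat_size and out_classes
# from the corresponding source config instead of hard-coding them.
model = getattr(models, 'MinkUNet34C')(1, 7)
state_dict = torch.load('pretrained_models/<checkpoint>.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()
```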
40 |
41 |
42 | ## Adaptation to target
43 |
44 | To adapt the source model from SynLiDAR to the target domain SemanticKITTI:
45 |
46 | ```
47 | sh train.sh
48 | ```
49 | If you want to save point clouds for later visualization, add ``--save_predictions``; they will be saved in ```pipeline.save_dir```.
50 |
51 | ## Thanks
52 | We thank the open-source projects [Minkowski-Engine](https://github.com/NVIDIA/MinkowskiEngine) and [GIPSO](https://github.com/saltoricristiano/gipso-sfouda).
53 |
54 |
55 |
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/utils/_resources/semantic-kitti.yaml:
--------------------------------------------------------------------------------
1 | # This file is covered by the LICENSE file in the root of this project.
2 | labels:
3 | 0 : "unlabeled"
4 | 1 : "outlier"
5 | 10: "car"
6 | 11: "bicycle"
7 | 13: "bus"
8 | 15: "motorcycle"
9 | 16: "on-rails"
10 | 18: "truck"
11 | 20: "other-vehicle"
12 | 30: "person"
13 | 31: "bicyclist"
14 | 32: "motorcyclist"
15 | 40: "road"
16 | 44: "parking"
17 | 48: "sidewalk"
18 | 49: "other-ground"
19 | 50: "building"
20 | 51: "fence"
21 | 52: "other-structure"
22 | 60: "lane-marking"
23 | 70: "vegetation"
24 | 71: "trunk"
25 | 72: "terrain"
26 | 80: "pole"
27 | 81: "traffic-sign"
28 | 99: "other-object"
29 | 252: "moving-car"
30 | 253: "moving-bicyclist"
31 | 254: "moving-person"
32 | 255: "moving-motorcyclist"
33 | 256: "moving-on-rails"
34 | 257: "moving-bus"
35 | 258: "moving-truck"
36 | 259: "moving-other-vehicle"
37 |
38 | content: # as a ratio with the total number of points
39 | 0: 0.018889854628292943
40 | 1: 0.0002937197336781505
41 | 10: 0.040818519255974316
42 | 11: 0.00016609538710764618
43 | 13: 2.7879693665067774e-05
44 | 15: 0.00039838616015114444
45 | 16: 0.0
46 | 18: 0.0020633612104619787
47 | 20: 0.0016218197275284021
48 | 30: 0.00017698551338515307
49 | 31: 1.1065903904919655e-08
50 | 32: 5.532951952459828e-09
51 | 40: 0.1987493871255525
52 | 44: 0.014717169549888214
53 | 48: 0.14392298360372
54 | 49: 0.0039048553037472045
55 | 50: 0.1326861944777486
56 | 51: 0.0723592229456223
57 | 52: 0.002395131480328884
58 | 60: 4.7084144280367186e-05
59 | 70: 0.26681502148037506
60 | 71: 0.006035012012626033
61 | 72: 0.07814222006271769
62 | 80: 0.002855498193863172
63 | 81: 0.0006155958086189918
64 | 99: 0.009923127583046915
65 | 252: 0.001789309418528068
66 | 253: 0.00012709999297008662
67 | 254: 0.00016059776092534436
68 | 255: 3.745553104802113e-05
69 | 256: 0.0
70 | 257: 0.00011351574470342043
71 | 258: 0.00010157861367183268
72 | 259: 4.3840131989471124e-05
73 | # classes that are indistinguishable from single scan or inconsistent in
74 | # ground truth are mapped to their closest equivalent
75 | learning_map:
76 | 0 : -1 # "unlabeled"
77 | 1 : -1 # "outlier" mapped to "unlabeled" --------------------------mapped
78 | 10: 0 # "car"
79 | 11: -1 # "bicycle"
80 | 13: -1 # "bus" mapped to "other-vehicle" --------------------------mapped
81 | 15: -1 # "motorcycle"
82 | 16: -1 # "on-rails" mapped to "other-vehicle" ---------------------mapped
83 | 18: -1 # "truck"
84 | 20: -1 # "other-vehicle"
85 | 30: 1 # "person"
86 | 31: -1 # "bicyclist"
87 | 32: -1 # "motorcyclist"
88 | 40: 2 # "road"
89 | 44: 2 # "parking"
90 | 48: 3 # "sidewalk"
91 | 49: -1 # "other-ground"
92 | 50: 5 # "building"
93 | 51: 5 # "fence"
94 | 52: -1 # "other-structure" mapped to "unlabeled" ------------------mapped
95 | 60: 2 # "lane-marking" to "road" ---------------------------------mapped
96 | 70: 6 # "vegetation"
97 | 71: 6 # "trunk"
98 | 72: 4 # "terrain"
99 | 80: 5 # "pole"
100 | 81: 5 # "traffic-sign"
101 | 99: -1 # "other-object" to "unlabeled" ----------------------------mapped
102 | 252: 0 # "moving-car" to "car" ------------------------------------mapped
103 | 253: -1 # "moving-bicyclist" to "bicyclist" ------------------------mapped
104 | 254: 1 # "moving-person" to "person" ------------------------------mapped
105 | 255: -1 # "moving-motorcyclist" to "motorcyclist" ------------------mapped
106 | 256: -1 # "moving-on-rails" mapped to "other-vehicle" --------------mapped
107 | 257: -1 # "moving-bus" mapped to "other-vehicle" -------------------mapped
108 | 258: -1 # "moving-truck" to "truck" --------------------------------mapped
109 | 259: -1 # "moving-other-vehicle" to "other-vehicle" ----------------mapped
110 | learning_map_inv: # inverse of previous map
111 | -1: 0 # "unlabeled", and others ignored
112 | 0: 10 # "vehicle"
113 | 1: 30 # "person"
114 | 2: 40 # "road"
115 | 3: 48 # "sidewalk"
116 | 4: 72 # "terrain"
117 | 5: 50 # "manmade"
118 | 6: 70 # "vegetation"
119 |
120 | learning_ignore: # Ignore classes
121 | -1: True # "unlabeled", and others ignored
122 | 0: False # "vehicle"
123 | 1: False # "pedestrian"
124 | 2: False # "road"
125 | 3: False # "sidewalk"
126 | 4: False # "terrain"
127 | 5: False # "manmade"
128 | 6: False # "vegetation"
129 |
--------------------------------------------------------------------------------
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import logging
4 | import numpy as np
5 | import scipy
6 | import scipy.ndimage
7 | import scipy.interpolate
8 | import torch
9 |
10 | import MinkowskiEngine as ME
11 |
12 |
13 | # A sparse tensor consists of coordinates and associated features.
14 | # You must apply augmentation to both.
15 |
16 | ##############################
17 | # Coordinate transformations
18 | ##############################
19 | class RandomDropout(object):
20 |
21 | def __init__(self, dropout_ratio=0.2, dropout_application_ratio=0.5):
22 | """
23 | dropout_ratio: fraction of points to drop; dropout_application_ratio: probability of applying the dropout
24 | """
25 | self.dropout_ratio = dropout_ratio
26 | self.dropout_application_ratio = dropout_application_ratio
27 |
28 | def __call__(self, coords, feats, labels):
29 | if random.random() < self.dropout_application_ratio: # apply the dropout only with this probability
30 | N = len(coords)
31 | inds = np.random.choice(N, int(N * (1 - self.dropout_ratio)), replace=False)
32 | return coords[inds], feats[inds], labels[inds]
33 | return coords, feats, labels
34 |
35 |
36 | class RandomHorizontalFlip(object):
37 |
38 | def __init__(self, upright_axis, is_temporal):
39 | """
40 | upright_axis: axis index among x,y,z, i.e. 2 for z
41 | """
42 | self.is_temporal = is_temporal
43 | self.D = 4 if is_temporal else 3
44 | self.upright_axis = {'x': 0, 'y': 1, 'z': 2}[upright_axis.lower()]
45 | # Use the rest of the axes for flipping.
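# (when is_temporal is True, D = 4 and the set below also contains index 3, the temporal axis)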
46 | self.horz_axes = set(range(self.D)) - set([self.upright_axis])
47 |
48 | def __call__(self, coords, feats, labels):
49 | if random.random() < 0.95:
50 | for curr_ax in self.horz_axes:
51 | if random.random() < 0.5:
52 | coord_max = np.max(coords[:, curr_ax])
53 | coords[:, curr_ax] = coord_max - coords[:, curr_ax]
54 | return coords, feats, labels
55 |
56 |
57 | class ElasticDistortion:
58 |
59 | def __init__(self, distortion_params):
60 | self.distortion_params = distortion_params
61 |
62 | def elastic_distortion(self, coords, feats, labels, granularity, magnitude):
63 | """Apply elastic distortion on sparse coordinate space.
64 | pointcloud: numpy array of (number of points, at least 3 spatial dims)
65 | granularity: size of the noise grid (in same scale[m/cm] as the voxel grid)
66 | magnitude: noise multiplier
67 | """
68 | blurx = np.ones((3, 1, 1, 1)).astype('float32') / 3
69 | blury = np.ones((1, 3, 1, 1)).astype('float32') / 3
70 | blurz = np.ones((1, 1, 3, 1)).astype('float32') / 3
71 | coords_min = coords.min(0)
72 |
73 | # Create Gaussian noise tensor of the size given by granularity.
74 | noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
75 | noise = np.random.randn(*noise_dim, 3).astype(np.float32)
76 |
77 | # Smoothing.
78 | for _ in range(2):
79 | noise = scipy.ndimage.convolve(noise, blurx, mode='constant', cval=0)
80 | noise = scipy.ndimage.convolve(noise, blury, mode='constant', cval=0)
81 | noise = scipy.ndimage.convolve(noise, blurz, mode='constant', cval=0)
82 |
83 | # Trilinear interpolate noise filters for each spatial dimension.
84 | ax = [
85 | np.linspace(d_min, d_max, d)
86 | for d_min, d_max, d in zip(coords_min - granularity, coords_min + granularity *
87 | (noise_dim - 2), noise_dim)
88 | ]
89 | interp = scipy.interpolate.RegularGridInterpolator(ax, noise, bounds_error=0, fill_value=0)
90 | coords += interp(coords) * magnitude
91 | return coords, feats, labels
92 |
93 | def __call__(self, coords, feats, labels):
94 | if self.distortion_params is not None:
95 | if random.random() < 0.95:
96 | for granularity, magnitude in self.distortion_params:
97 | coords, feats, labels = self.elastic_distortion(coords, feats, labels, granularity,
98 | magnitude)
99 | return coords, feats, labels
100 |
101 |
102 | class Compose(object):
103 | """Composes several transforms together."""
104 |
105 | def __init__(self, transforms):
106 | self.transforms = transforms
107 |
108 | def __call__(self, *args):
109 | for t in self.transforms:
110 | args = t(*args)
111 | return args
112 |
--------------------------------------------------------------------------------
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 | from pytorch_lightning.callbacks import Callback
4 | from pytorch_lightning.utilities import rank_zero_only
5 | import json # used by SOLOCheckpointer.save_args below
6 | import os
7 | from argparse import ArgumentParser, Namespace
8 | from pathlib import Path
9 | from typing import Optional, Union
10 |
11 |
12 | class SourceCheckpoint(Callback):
13 | @rank_zero_only
14 | def on_save_checkpoint(self, trainer, pl_module, checkpoint):
15 |
16 | checkpoint_filename = (
17 | "-".join(
18 | ["source",
19 | pl_module.hparams.training_dataset.name,
20 | str(trainer.current_epoch)]
21 | )
22 | + ".pth"
23 | )
24 | os.makedirs(os.path.join(trainer.weights_save_path, 'source_checkpoints'), exist_ok=True)
25 | checkpoint_path = os.path.join(trainer.weights_save_path,
'source_checkpoints', checkpoint_filename) 26 | torch.save(pl_module.model.state_dict(), checkpoint_path) 27 | 28 | 29 | # credits to https://github.com/vturrisi/solo-learn/blob/main/solo/utils/checkpointer.py 30 | class SOLOCheckpointer(Callback): 31 | def __init__( 32 | self, 33 | args: Namespace, 34 | logdir: Union[str, Path] = Path("trained_models"), 35 | frequency: int = 1, 36 | keep_previous_checkpoints: bool = False, 37 | ): 38 | """Custom checkpointer callback that stores checkpoints in an easier to access way. 39 | Args: 40 | args (Namespace): namespace object containing at least an attribute name. 41 | logdir (Union[str, Path], optional): base directory to store checkpoints. 42 | Defaults to "trained_models". 43 | frequency (int, optional): number of epochs between each checkpoint. Defaults to 1. 44 | keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not. 45 | Defaults to False. 46 | """ 47 | 48 | super().__init__() 49 | 50 | self.args = args 51 | self.logdir = Path(logdir) 52 | self.frequency = frequency 53 | self.keep_previous_checkpoints = keep_previous_checkpoints 54 | 55 | @staticmethod 56 | def add_checkpointer_args(parent_parser: ArgumentParser): 57 | """Adds user-required arguments to a parser. 58 | Args: 59 | parent_parser (ArgumentParser): parser to add new args to. 60 | """ 61 | 62 | parser = parent_parser.add_argument_group("checkpointer") 63 | parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path) 64 | parser.add_argument("--checkpoint_frequency", default=1, type=int) 65 | return parent_parser 66 | 67 | def initial_setup(self, trainer: pl.Trainer): 68 | """Creates the directories and does the initial setup needed. 69 | Args: 70 | trainer (pl.Trainer): pytorch lightning trainer object. 71 | """ 72 | 73 | if trainer.logger is None: 74 | version = None 75 | else: 76 | version = str(trainer.logger.version) 77 | if version is not None: 78 | self.path = self.logdir / version 79 | self.ckpt_placeholder = f"{self.args.name}-{version}" + "-ep={}.ckpt" 80 | else: 81 | self.path = self.logdir 82 | self.ckpt_placeholder = f"{self.args.name}" + "-ep={}.ckpt" 83 | self.last_ckpt: Optional[str] = None 84 | 85 | # create logging dirs 86 | if trainer.is_global_zero: 87 | os.makedirs(self.path, exist_ok=True) 88 | 89 | def save_args(self, trainer: pl.Trainer): 90 | """Stores arguments into a json file. 91 | Args: 92 | trainer (pl.Trainer): pytorch lightning trainer object. 93 | """ 94 | 95 | if trainer.is_global_zero: 96 | args = vars(self.args) 97 | json_path = self.path / "args.json" 98 | json.dump(args, open(json_path, "w"), default=lambda o: "") 99 | 100 | def save(self, trainer: pl.Trainer): 101 | """Saves current checkpoint. 102 | Args: 103 | trainer (pl.Trainer): pytorch lightning trainer object. 104 | """ 105 | 106 | if trainer.is_global_zero and not trainer.running_sanity_check: 107 | epoch = trainer.current_epoch # type: ignore 108 | ckpt = self.path / self.ckpt_placeholder.format(epoch) 109 | trainer.save_checkpoint(ckpt) 110 | 111 | if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints: 112 | os.remove(self.last_ckpt) 113 | self.last_ckpt = ckpt 114 | 115 | def on_train_start(self, trainer: pl.Trainer, _): 116 | """Executes initial setup and saves arguments. 117 | Args: 118 | trainer (pl.Trainer): pytorch lightning trainer object. 
119 | """
120 |
121 | self.initial_setup(trainer)
122 | self.save_args(trainer)
123 |
124 | def on_validation_end(self, trainer: pl.Trainer, _):
125 | """Tries to save current checkpoint at the end of each validation epoch.
126 | Args:
127 | trainer (pl.Trainer): pytorch lightning trainer object.
128 | """
129 |
130 | epoch = trainer.current_epoch # type: ignore
131 | if epoch % self.frequency == 0:
132 | self.save(trainer)
--------------------------------------------------------------------------------
/train_lighting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 | import numpy as np
5 |
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from pytorch_lightning import Trainer
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | from pytorch_lightning.loggers import WandbLogger
11 | from utils.logger import CSVLogger
12 | import MinkowskiEngine as ME
13 |
14 | import models
15 | from utils.dataset import get_dataset
16 | from utils.config import get_config
17 | from utils.collation import CollateFN
18 | from utils.callbacks import SourceCheckpoint
19 | from pipelines import PLTOneDomainTrainer
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--config_file",
23 | default="configs/source/synth4dkitti_source.yaml",
24 | type=str,
25 | help="Path to config file")
26 |
27 | # AUG_DICT = {'RandomDropout': [0.2, 0.5]}
28 | AUG_DICT = None
29 |
30 |
31 | def train(config):
32 |
33 | def get_dataloader(dataset, shuffle=False, pin_memory=True):
34 | return DataLoader(dataset,
35 | batch_size=config.pipeline.dataloader.batch_size,
36 | collate_fn=CollateFN(),
37 | shuffle=shuffle,
38 | num_workers=config.pipeline.dataloader.num_workers,
39 | pin_memory=pin_memory)
40 | try:
41 | mapping_path = config.dataset.mapping_path
42 | except AttributeError: # fall back to the default class mapping path
43 | mapping_path = None
44 |
45 | training_dataset, validation_dataset, target_dataset = get_dataset(dataset_name=config.dataset.name,
46 | dataset_path=config.dataset.dataset_path,
47 | voxel_size=config.dataset.voxel_size,
48 | augment_data=config.dataset.augment_data,
49 | aug_parameters=AUG_DICT,
50 | version=config.dataset.version,
51 | sub_num=config.dataset.num_pts,
52 | get_target=config.dataset.validate_target,
53 | target_dataset_path=config.dataset.target_path,
54 | num_classes=config.model.out_classes,
55 | ignore_label=config.dataset.ignore_label,
56 | mapping_path=mapping_path)
57 |
58 | training_dataloader = get_dataloader(training_dataset, shuffle=True)
59 | validation_dataloader = get_dataloader(validation_dataset, shuffle=False)
60 |
61 | if target_dataset is not None:
62 | target_dataloader = get_dataloader(target_dataset, shuffle=False)
63 | validation_dataloader = [validation_dataloader, target_dataloader]
64 | else:
65 | validation_dataloader = [validation_dataloader]
66 |
67 | # coords = [N, [x, y, z]], feats = [N, f]; with intensity features f = [i], i.e. [x, y, z, i]
68 | # model = MinkUNet34C(1, 8)
69 | Model = getattr(models, config.model.name)
70 | model = Model(config.model.in_feat_size, config.model.out_classes)
71 |
72 | model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)
73 |
74 | pl_module = PLTOneDomainTrainer(training_dataset=training_dataset,
75 | validation_dataset=validation_dataset,
76 | model=model,
77 | criterion=config.pipeline.loss,
78 | optimizer_name=config.pipeline.optimizer.name,
79 | batch_size=config.pipeline.dataloader.batch_size,
80 |
val_batch_size=config.pipeline.dataloader.batch_size,
81 | lr=config.pipeline.optimizer.lr,
82 | num_classes=config.model.out_classes,
83 | train_num_workers=config.pipeline.dataloader.num_workers,
84 | val_num_workers=config.pipeline.dataloader.num_workers,
85 | clear_cache_int=config.pipeline.lightning.clear_cache_int,
86 | scheduler_name=config.pipeline.scheduler.scheduler_name)
87 |
88 | run_time = time.strftime("%Y_%m_%d_%H:%M", time.gmtime())
89 | if config.pipeline.wandb.run_name is not None:
90 | run_name = run_time + '_' + config.pipeline.wandb.run_name
91 | else:
92 | run_name = run_time
93 |
94 | save_dir = os.path.join(config.pipeline.save_dir, run_name)
95 |
96 | wandb_logger = WandbLogger(project=config.pipeline.wandb.project_name,
97 | name=run_name,
98 | offline=config.pipeline.wandb.offline)
99 | csv_logger = CSVLogger(save_dir=save_dir,
100 | name=run_name,
101 | version='logs')
102 |
103 | loggers = [wandb_logger, csv_logger]
104 |
105 | checkpoint_callback = [ModelCheckpoint(dirpath=os.path.join(save_dir, 'checkpoints'), save_top_k=-1),
106 | SourceCheckpoint()]
107 |
108 | trainer = Trainer(max_epochs=config.pipeline.epochs,
109 | gpus=config.pipeline.gpus,
110 | accelerator="ddp",
111 | default_root_dir=config.pipeline.save_dir,
112 | weights_save_path=save_dir,
113 | precision=config.pipeline.precision,
114 | logger=loggers,
115 | check_val_every_n_epoch=config.pipeline.lightning.check_val_every_n_epoch,
116 | val_check_interval=1.0,
117 | num_sanity_val_steps=0,
118 | resume_from_checkpoint=config.pipeline.lightning.resume_checkpoint,
119 | callbacks=checkpoint_callback)
120 |
121 | trainer.fit(pl_module,
122 | train_dataloaders=training_dataloader,
123 | val_dataloaders=validation_dataloader)
124 |
125 |
126 | if __name__ == '__main__':
127 | args = parser.parse_args()
128 |
129 | config = get_config(args.config_file)
130 |
131 | # fix random seed
132 | os.environ['PYTHONHASHSEED'] = str(config.pipeline.seed)
133 | np.random.seed(config.pipeline.seed)
134 | torch.manual_seed(config.pipeline.seed)
135 | torch.cuda.manual_seed(config.pipeline.seed)
136 | torch.backends.cudnn.benchmark = True
137 |
138 | train(config)
139 |
--------------------------------------------------------------------------------
/utils/voxelizer.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import open3d as o3d
3 | import numpy as np
4 | import MinkowskiEngine as ME
5 | from scipy.linalg import expm, norm
6 | import os
7 |
8 | # Rotation matrix around axis by angle theta
9 | def M(axis, theta):
10 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta))
11 |
12 |
13 | class Voxelizer:
14 |
15 | def __init__(self,
16 | voxel_size=0.05,
17 | clip_bound=None,
18 | use_augmentation=False,
19 | scale_augmentation_bound=None,
20 | rotation_augmentation_bound=None,
21 | translation_augmentation_ratio_bound=None,
22 | ignore_label=255):
23 | """
24 | Args:
25 | voxel_size: side length of a voxel
26 | clip_bound: boundary of the voxelizer; points outside the bound will be deleted.
27 | Expects either None or an array like ((-100, 100), (-100, 100), (-100, 100)).
28 | scale_augmentation_bound: None or (0.9, 1.1)
29 | rotation_augmentation_bound: None or ((np.pi / 6, np.pi / 6), None, None) for the 3 axes.
30 | Use random order of x, y, z to prevent bias.
31 | translation_augmentation_ratio_bound: e.g. ((-5, 5), (0, 0), (-10, 10))
32 | ignore_label: label assigned for ignore (not a training label).
33 | """
34 | self.voxel_size = voxel_size
35 | self.clip_bound = clip_bound
36 | if ignore_label is not None:
37 | self.ignore_label = ignore_label
38 | else:
39 | self.ignore_label = -100
40 | # Augmentation
41 | self.use_augmentation = use_augmentation
42 | self.scale_augmentation_bound = scale_augmentation_bound
43 | self.rotation_augmentation_bound = rotation_augmentation_bound
44 | self.translation_augmentation_ratio_bound = translation_augmentation_ratio_bound
45 |
46 | def get_transformation_matrix(self):
47 | voxelization_matrix, rotation_matrix = np.eye(4), np.eye(4)
48 |
49 | # Transform pointcloud coordinate to voxel coordinate.
50 | # 1. Random rotation
51 | rot_mat = np.eye(3)
52 | if self.use_augmentation and self.rotation_augmentation_bound is not None:
53 | if isinstance(self.rotation_augmentation_bound, collections.abc.Iterable):
54 | rot_mats = []
55 | for axis_ind, rot_bound in enumerate(self.rotation_augmentation_bound):
56 | theta = 0
57 | axis = np.zeros(3)
58 | axis[axis_ind] = 1
59 | if rot_bound is not None:
60 | theta = np.random.uniform(*rot_bound)
61 | rot_mats.append(M(axis, theta))
62 | # Use random order
63 | np.random.shuffle(rot_mats)
64 | rot_mat = rot_mats[0] @ rot_mats[1] @ rot_mats[2]
65 | else:
66 | raise ValueError()
67 | rotation_matrix[:3, :3] = rot_mat
68 | # 2. Scale and translate to the voxel space.
69 | scale = 1
70 | if self.use_augmentation and self.scale_augmentation_bound is not None:
71 | scale *= np.random.uniform(*self.scale_augmentation_bound)
72 | np.fill_diagonal(voxelization_matrix[:3, :3], scale)
73 |
74 | # 3. Translate
75 | if self.use_augmentation and self.translation_augmentation_ratio_bound is not None:
76 | tr = [np.random.uniform(*t) for t in self.translation_augmentation_ratio_bound]
77 | rotation_matrix[:3, 3] = tr
78 | # Get final transformation matrix.
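# (the scale lives in voxelization_matrix, rotation and translation in rotation_matrix;
# voxelize() composes them as rotation_matrix @ voxelization_matrix before applying them)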
79 | return voxelization_matrix, rotation_matrix 80 | 81 | def clip(self, coords, center=None, trans_aug_ratio=None): 82 | bound_min = np.min(coords, 0).astype(float) 83 | bound_max = np.max(coords, 0).astype(float) 84 | bound_size = bound_max - bound_min 85 | if center is None: 86 | center = bound_min + bound_size * 0.5 87 | if trans_aug_ratio is not None: 88 | trans = np.multiply(trans_aug_ratio, bound_size) 89 | center += trans 90 | lim = self.clip_bound 91 | 92 | if isinstance(self.clip_bound, (int, float)): 93 | if bound_size.max() < self.clip_bound: 94 | return None 95 | else: 96 | clip_inds = ((coords[:, 0] >= (-lim + center[0])) & \ 97 | (coords[:, 0] < (lim + center[0])) & \ 98 | (coords[:, 1] >= (-lim + center[1])) & \ 99 | (coords[:, 1] < (lim + center[1])) & \ 100 | (coords[:, 2] >= (-lim + center[2])) & \ 101 | (coords[:, 2] < (lim + center[2]))) 102 | return clip_inds 103 | 104 | # Clip points outside the limit 105 | clip_inds = ((coords[:, 0] >= (lim[0][0] + center[0])) & \ 106 | (coords[:, 0] < (lim[0][1] + center[0])) & \ 107 | (coords[:, 1] >= (lim[1][0] + center[1])) & \ 108 | (coords[:, 1] < (lim[1][1] + center[1])) & \ 109 | (coords[:, 2] >= (lim[2][0] + center[2])) & \ 110 | (coords[:, 2] < (lim[2][1] + center[2]))) 111 | return clip_inds 112 | 113 | def voxelize(self, coords, feats, labels, center=None): 114 | 115 | assert coords.shape[1] == 3 and coords.shape[0] == feats.shape[0] and coords.shape[0] 116 | # if self.clip_bound is not None: 117 | # trans_aug_ratio = np.zeros(3) 118 | # if self.use_augmentation and self.translation_augmentation_ratio_bound is not None: 119 | # for axis_ind, trans_ratio_bound in enumerate(self.translation_augmentation_ratio_bound): 120 | # trans_aug_ratio[axis_ind] = np.random.uniform(*trans_ratio_bound) 121 | # 122 | # clip_inds = self.clip(coords, center, trans_aug_ratio) 123 | # if clip_inds is not None: 124 | # coords, feats = coords[clip_inds], feats[clip_inds] 125 | # if labels is not None: 126 | # labels = labels[clip_inds] 127 | 128 | M_v, M_r = self.get_transformation_matrix() 129 | rigid_transformation = M_v 130 | # Apply transformations 131 | if self.use_augmentation: 132 | # Get rotation and scale 133 | rigid_transformation = M_r @ rigid_transformation 134 | 135 | homo_coords = np.hstack((coords, np.ones((coords.shape[0], 1), dtype=coords.dtype))) 136 | # coords = np.floor(homo_coords @ rigid_transformation.T[:, :3]) 137 | coords = homo_coords @ rigid_transformation.T[:, :3] 138 | 139 | # key = self.hash(coords_aug) # floor happens by astype(np.uint64) 140 | coords, feats, labels = ME.utils.sparse_quantize(coords, 141 | feats, 142 | labels=labels, 143 | ignore_label=self.ignore_label, 144 | quantization_size=self.voxel_size) 145 | return coords, feats, labels -------------------------------------------------------------------------------- /models/resunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | import MinkowskiEngine.MinkowskiFunctional as MEF 4 | from models.common import get_norm 5 | 6 | from models.residual_block import get_block 7 | 8 | 9 | class ResUNet2(ME.MinkowskiNetwork): 10 | NORM_TYPE = None 11 | BLOCK_NORM_TYPE = 'BN' 12 | CHANNELS = [None, 32, 64, 128, 256] 13 | TR_CHANNELS = [None, 32, 64, 64, 128] 14 | 15 | # To use the model, must call initialize_coords before forward pass. 
16 | # Once data is processed, call clear to reset the model before calling initialize_coords 17 | def __init__(self, 18 | in_channels=3, 19 | out_channels=32, 20 | bn_momentum=0.1, 21 | normalize_feature=None, 22 | conv1_kernel_size=None, 23 | D=3): 24 | ME.MinkowskiNetwork.__init__(self, D) 25 | NORM_TYPE = self.NORM_TYPE 26 | BLOCK_NORM_TYPE = self.BLOCK_NORM_TYPE 27 | CHANNELS = self.CHANNELS 28 | TR_CHANNELS = self.TR_CHANNELS 29 | self.normalize_feature = normalize_feature 30 | self.conv1 = ME.MinkowskiConvolution( 31 | in_channels=in_channels, 32 | out_channels=CHANNELS[1], 33 | kernel_size=conv1_kernel_size, 34 | stride=1, 35 | dilation=1, 36 | bias=False, 37 | dimension=D) 38 | self.norm1 = get_norm(NORM_TYPE, CHANNELS[1], bn_momentum=bn_momentum, D=D) 39 | 40 | self.block1 = get_block( 41 | BLOCK_NORM_TYPE, CHANNELS[1], CHANNELS[1], bn_momentum=bn_momentum, D=D) 42 | 43 | self.conv2 = ME.MinkowskiConvolution( 44 | in_channels=CHANNELS[1], 45 | out_channels=CHANNELS[2], 46 | kernel_size=3, 47 | stride=2, 48 | dilation=1, 49 | bias=False, 50 | dimension=D) 51 | self.norm2 = get_norm(NORM_TYPE, CHANNELS[2], bn_momentum=bn_momentum, D=D) 52 | 53 | self.block2 = get_block( 54 | BLOCK_NORM_TYPE, CHANNELS[2], CHANNELS[2], bn_momentum=bn_momentum, D=D) 55 | 56 | self.conv3 = ME.MinkowskiConvolution( 57 | in_channels=CHANNELS[2], 58 | out_channels=CHANNELS[3], 59 | kernel_size=3, 60 | stride=2, 61 | dilation=1, 62 | bias=False, 63 | dimension=D) 64 | self.norm3 = get_norm(NORM_TYPE, CHANNELS[3], bn_momentum=bn_momentum, D=D) 65 | 66 | self.block3 = get_block( 67 | BLOCK_NORM_TYPE, CHANNELS[3], CHANNELS[3], bn_momentum=bn_momentum, D=D) 68 | 69 | self.conv4 = ME.MinkowskiConvolution( 70 | in_channels=CHANNELS[3], 71 | out_channels=CHANNELS[4], 72 | kernel_size=3, 73 | stride=2, 74 | dilation=1, 75 | bias=False, 76 | dimension=D) 77 | self.norm4 = get_norm(NORM_TYPE, CHANNELS[4], bn_momentum=bn_momentum, D=D) 78 | 79 | self.block4 = get_block( 80 | BLOCK_NORM_TYPE, CHANNELS[4], CHANNELS[4], bn_momentum=bn_momentum, D=D) 81 | 82 | self.conv4_tr = ME.MinkowskiConvolutionTranspose( 83 | in_channels=CHANNELS[4], 84 | out_channels=TR_CHANNELS[4], 85 | kernel_size=3, 86 | stride=2, 87 | dilation=1, 88 | bias=False, 89 | dimension=D) 90 | self.norm4_tr = get_norm(NORM_TYPE, TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 91 | 92 | self.block4_tr = get_block( 93 | BLOCK_NORM_TYPE, TR_CHANNELS[4], TR_CHANNELS[4], bn_momentum=bn_momentum, D=D) 94 | 95 | self.conv3_tr = ME.MinkowskiConvolutionTranspose( 96 | in_channels=CHANNELS[3] + TR_CHANNELS[4], 97 | out_channels=TR_CHANNELS[3], 98 | kernel_size=3, 99 | stride=2, 100 | dilation=1, 101 | bias=False, 102 | dimension=D) 103 | self.norm3_tr = get_norm(NORM_TYPE, TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 104 | 105 | self.block3_tr = get_block( 106 | BLOCK_NORM_TYPE, TR_CHANNELS[3], TR_CHANNELS[3], bn_momentum=bn_momentum, D=D) 107 | 108 | self.conv2_tr = ME.MinkowskiConvolutionTranspose( 109 | in_channels=CHANNELS[2] + TR_CHANNELS[3], 110 | out_channels=TR_CHANNELS[2], 111 | kernel_size=3, 112 | stride=2, 113 | dilation=1, 114 | bias=False, 115 | dimension=D) 116 | self.norm2_tr = get_norm(NORM_TYPE, TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 117 | 118 | self.block2_tr = get_block( 119 | BLOCK_NORM_TYPE, TR_CHANNELS[2], TR_CHANNELS[2], bn_momentum=bn_momentum, D=D) 120 | 121 | self.conv1_tr = ME.MinkowskiConvolution( 122 | in_channels=CHANNELS[1] + TR_CHANNELS[2], 123 | out_channels=TR_CHANNELS[1], 124 | kernel_size=1, 125 | stride=1, 126 | 
dilation=1, 127 | bias=False, 128 | dimension=D) 129 | 130 | # self.block1_tr = BasicBlockBN(TR_CHANNELS[1], TR_CHANNELS[1], bn_momentum=bn_momentum, D=D) 131 | 132 | self.final = ME.MinkowskiConvolution( 133 | in_channels=TR_CHANNELS[1], 134 | out_channels=out_channels, 135 | kernel_size=1, 136 | stride=1, 137 | dilation=1, 138 | bias=True, 139 | dimension=D) 140 | 141 | def forward(self, x): 142 | out_s1 = self.conv1(x) 143 | out_s1 = self.norm1(out_s1) 144 | out_s1 = self.block1(out_s1) 145 | out = MEF.relu(out_s1) 146 | 147 | out_s2 = self.conv2(out) 148 | out_s2 = self.norm2(out_s2) 149 | out_s2 = self.block2(out_s2) 150 | out = MEF.relu(out_s2) 151 | 152 | out_s4 = self.conv3(out) 153 | out_s4 = self.norm3(out_s4) 154 | out_s4 = self.block3(out_s4) 155 | out = MEF.relu(out_s4) 156 | 157 | out_s8 = self.conv4(out) 158 | out_s8 = self.norm4(out_s8) 159 | out_s8 = self.block4(out_s8) 160 | out = MEF.relu(out_s8) 161 | 162 | out = self.conv4_tr(out) 163 | out = self.norm4_tr(out) 164 | out = self.block4_tr(out) 165 | out_s4_tr = MEF.relu(out) 166 | 167 | out = ME.cat(out_s4_tr, out_s4) 168 | 169 | out = self.conv3_tr(out) 170 | out = self.norm3_tr(out) 171 | out = self.block3_tr(out) 172 | out_s2_tr = MEF.relu(out) 173 | 174 | out = ME.cat(out_s2_tr, out_s2) 175 | 176 | out = self.conv2_tr(out) 177 | out = self.norm2_tr(out) 178 | out = self.block2_tr(out) 179 | out_s1_tr = MEF.relu(out) 180 | 181 | out = ME.cat(out_s1_tr, out_s1) 182 | out = self.conv1_tr(out) 183 | out = MEF.relu(out) 184 | out = self.final(out) 185 | 186 | if self.normalize_feature: 187 | return ME.SparseTensor( 188 | out.F / torch.norm(out.F, p=2, dim=1, keepdim=True), 189 | coordinate_map_key=out.coordinate_map_key, 190 | coordinate_manager=out.coordinate_manager) 191 | else: 192 | return out 193 | 194 | 195 | class ResUNetBN2(ResUNet2): 196 | NORM_TYPE = 'BN' 197 | 198 | 199 | class ResUNetBN2B(ResUNet2): 200 | NORM_TYPE = 'BN' 201 | CHANNELS = [None, 32, 64, 128, 256] 202 | TR_CHANNELS = [None, 64, 64, 64, 64] 203 | 204 | 205 | class ResUNetBN2C(ResUNet2): 206 | NORM_TYPE = 'BN' 207 | CHANNELS = [None, 32, 64, 128, 256] 208 | TR_CHANNELS = [None, 64, 64, 64, 128] 209 | 210 | 211 | class ResUNetBN2D(ResUNet2): 212 | NORM_TYPE = 'BN' 213 | CHANNELS = [None, 32, 64, 128, 256] 214 | TR_CHANNELS = [None, 64, 64, 128, 128] 215 | 216 | 217 | class ResUNetBN2E(ResUNet2): 218 | NORM_TYPE = 'BN' 219 | CHANNELS = [None, 128, 128, 128, 256] 220 | TR_CHANNELS = [None, 64, 128, 128, 128] 221 | 222 | 223 | class ResUNetIN2(ResUNet2): 224 | NORM_TYPE = 'BN' 225 | BLOCK_NORM_TYPE = 'IN' 226 | 227 | 228 | class ResUNetIN2B(ResUNetBN2B): 229 | NORM_TYPE = 'BN' 230 | BLOCK_NORM_TYPE = 'IN' 231 | 232 | 233 | class ResUNetIN2C(ResUNetBN2C): 234 | NORM_TYPE = 'BN' 235 | BLOCK_NORM_TYPE = 'IN' 236 | 237 | 238 | class ResUNetIN2D(ResUNetBN2D): 239 | NORM_TYPE = 'BN' 240 | BLOCK_NORM_TYPE = 'IN' 241 | 242 | 243 | class ResUNetIN2E(ResUNetBN2E): 244 | NORM_TYPE = 'BN' 245 | BLOCK_NORM_TYPE = 'IN' -------------------------------------------------------------------------------- /utils/collation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | 4 | 5 | # def minkowski_collate_fn(list_data): 6 | # r""" 7 | # Collation function for MinkowskiEngine.SparseTensor that creates batched 8 | # coordinates given a list of dictionaries. 
9 | # """
10 | # coordinates_batch, features_batch, labels_batch = ME.utils.batch_sparse_collate(
11 | # [d["coordinates"] for d in list_data],
12 | # [d["features"] for d in list_data],
13 | # [d["labels"] for d in list_data])
14 | # return {
15 | # "coordinates": coordinates_batch,
16 | # "features": features_batch,
17 | # "labels": labels_batch,
18 | # }
19 |
20 | class CollateFN:
21 | def __init__(self, device=None):
22 | self.device = device
23 |
24 | def __call__(self, list_data):
25 | r"""
26 | Collation function for MinkowskiEngine.SparseTensor that creates batched
27 | coordinates given a list of dictionaries.
28 | """
29 | # batched_coords = [d['coordinates'] for d in list_data]
30 | # batched_coords = ME.utils.batched_coordinates(batched_coords, dtype=torch.float32)
31 |
32 | list_data = [(torch.from_numpy(d["coordinates"]).to(self.device), d["features"].to(self.device), d["labels"]) for d in list_data]
33 |
34 | coordinates_batch, features_batch, labels_batch = ME.utils.SparseCollation(dtype=torch.float32,
35 | device=self.device)(list_data)
36 | return {"coordinates": coordinates_batch,
37 | "features": features_batch,
38 | "labels": labels_batch}
39 |
40 |
41 | class CollateMixed:
42 | def __init__(self, device=None):
43 | self.device = device
44 |
45 | def __call__(self, list_data):
46 | r"""
47 | Collation function for MinkowskiEngine.SparseTensor that creates batched
48 | coordinates given a list of dictionaries.
49 | """
50 | actual_matches_list = [d["matches"] for d in list_data]
51 | next_matches_list = [d["next_matches"] for d in list_data]
52 |
53 | matches_list = [None] * (len(actual_matches_list)+len(next_matches_list))
54 | matches_list[::2] = actual_matches_list
55 | matches_list[1::2] = next_matches_list
56 |
57 | matches = torch.cat(matches_list)
58 |
59 | match_idx0 = torch.where(matches == 0)[0]
60 | match_idx1 = torch.where(matches == 1)[0]
61 |
62 | next_list_data = [(d["next_coordinates"].to(self.device), d["next_features"].to(self.device), d["next_labels"].to(self.device)) for d in list_data]
63 | actual_list_data = [(d["coordinates"].to(self.device), d["features"].to(self.device), d["labels"].to(self.device)) for d in list_data]
64 |
65 | batch_list = [None] * (len(next_list_data)+len(actual_list_data))
66 | batch_list[::2] = actual_list_data
67 | batch_list[1::2] = next_list_data # interleave current and next frames, mirroring matches_list above
68 |
69 | coordinates_batch, features_batch, labels_batch = ME.utils.SparseCollation(dtype=torch.float32)(batch_list)
70 |
71 | return {"coordinates": coordinates_batch,
72 | "features": features_batch,
73 | "labels": labels_batch,
74 | "matches": matches,
75 | "fwd_match": match_idx0,
76 | "bck_match": match_idx1}
77 |
78 |
79 | class CollateSeparated:
80 | def __init__(self, device=None):
81 | self.device = device
82 |
83 | def __call__(self, list_data):
84 | r"""
85 | Collation function for MinkowskiEngine.SparseTensor that creates batched
86 | coordinates given a list of dictionaries.
87 | """ 88 | 89 | matches_list0 = [] 90 | matches_list1 = [] 91 | num_data = len(list_data) 92 | 93 | list_data0 = [] 94 | list_data1 = [] 95 | 96 | list_num_pts0 = [] 97 | list_num_pts1 = [] 98 | 99 | list_all = [] 100 | list_selected = [] 101 | 102 | list_global0 = [] 103 | list_global1 = [] 104 | 105 | start_pts0 = 0 106 | start_pts1 = 0 107 | 108 | for d in range(num_data): 109 | 110 | # shift for minkowski forward and append 111 | matches_list0.append(list_data[d]["matches0"] + start_pts0) 112 | matches_list1.append(list_data[d]["matches1"] + start_pts1) 113 | 114 | start_pts0 += list_data[d]["num_pts0"] 115 | start_pts1 += list_data[d]["num_pts1"] 116 | 117 | list_num_pts0.append(list_data[d]["num_pts0"]) 118 | list_num_pts1.append(list_data[d]["num_pts1"]) 119 | 120 | list_data0.append((list_data[d]["coordinates"].to(self.device), list_data[d]["features"].to(self.device), list_data[d]["labels"].to(self.device))) 121 | list_data1.append((list_data[d]["next_coordinates"].to(self.device), list_data[d]["next_features"].to(self.device), list_data[d]["next_labels"].to(self.device))) 122 | 123 | list_all.append(list_data[d]["coordinates_all"].to(self.device)) 124 | 125 | list_selected.append(list_data[d]["sampled_idx"].to(self.device)) 126 | 127 | list_global0.append(list_data[d]["global_pts"].to(self.device)) 128 | list_global1.append(list_data[d]["global_next_pts"].to(self.device)) 129 | 130 | 131 | # concatenate 132 | matches_list0 = torch.cat(matches_list0) 133 | matches_list1 = torch.cat(matches_list1) 134 | 135 | # concatenate 136 | list_global0 = torch.cat(list_global0) 137 | list_global1 = torch.cat(list_global1) 138 | 139 | # sparse collation for t0 140 | coordinates_batch0, features_batch0, labels_batch0 = ME.utils.SparseCollation(dtype=torch.float32, 141 | device=self.device)(list_data0) 142 | # sparse collation for t1 143 | coordinates_batch1, features_batch1, labels_batch1 = ME.utils.SparseCollation(dtype=torch.float32, 144 | device=self.device)(list_data1) 145 | 146 | return {"coordinates0": coordinates_batch0, 147 | "features0": features_batch0, 148 | "labels0": labels_batch0, 149 | "coordinates1": coordinates_batch1, 150 | "features1": features_batch1, 151 | "labels1": labels_batch1, 152 | "matches0": matches_list0, 153 | "matches1": matches_list1, 154 | "num_pts0": list_num_pts0, 155 | "num_pts1": list_num_pts1, 156 | "coordinates_all": list_all, 157 | "sampled_idx": list_selected, 158 | "global_pts0": list_global0, 159 | "global_pts1": list_global1} 160 | 161 | 162 | class CollateStream: 163 | def __init__(self, device=None): 164 | self.device = device 165 | 166 | def __call__(self, list_data): 167 | r""" 168 | Collation function for MinkowskiEngine.SparseTensor that creates batched 169 | coordinates given a list of dictionaries. 
170 | """ 171 | # batched_coords = [d['coordinates'] for d in list_data] 172 | # batched_coords = ME.utils.batched_coordinates(batched_coords, dtype=torch.float32) 173 | 174 | batch_data = [] 175 | batch_global = [] 176 | 177 | for d in list_data: 178 | batch_data.append((d["coordinates"], d["features"], d["labels"])) 179 | batch_global.append(d['global_points']) 180 | 181 | coordinates_batch, features_batch, labels_batch = ME.utils.SparseCollation(dtype=torch.float32, 182 | device=self.device)(batch_data) 183 | return {"coordinates": coordinates_batch, 184 | "features": features_batch, 185 | "labels": labels_batch, 186 | "global_points": batch_global} 187 | -------------------------------------------------------------------------------- /pipelines/trainer_lighting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import MinkowskiEngine as ME 5 | from utils.losses import CELoss, SoftCELoss, DICELoss, SoftDICELoss 6 | import pytorch_lightning as pl 7 | from sklearn.metrics import jaccard_score 8 | 9 | 10 | class PLTOneDomainTrainer(pl.core.LightningModule): 11 | r""" 12 | Segmentation Module for MinkowskiEngine for training on one domain. 13 | """ 14 | 15 | def __init__(self, 16 | model, 17 | training_dataset, 18 | validation_dataset, 19 | optimizer_name='SGD', 20 | criterion='CELoss', 21 | lr=1e-3, 22 | batch_size=12, 23 | weight_decay=1e-5, 24 | momentum=0.9, 25 | val_batch_size=6, 26 | train_num_workers=10, 27 | val_num_workers=10, 28 | num_classes=7, 29 | clear_cache_int=2, 30 | scheduler_name=None): 31 | 32 | super().__init__() 33 | for name, value in vars().items(): 34 | if name != "self": 35 | setattr(self, name, value) 36 | 37 | if criterion == 'CELoss': 38 | self.criterion = CELoss(ignore_label=self.training_dataset.ignore_label, 39 | weight=None) 40 | 41 | elif criterion == 'WCELoss': 42 | self.criterion = CELoss(ignore_label=self.training_dataset.ignore_label, 43 | weight=self.training_dataset.weights) 44 | 45 | elif criterion == 'SoftCELoss': 46 | self.criterion = SoftCELoss(ignore_label=self.training_dataset.ignore_label) 47 | 48 | elif criterion == 'DICELoss': 49 | self.criterion = DICELoss(ignore_label=self.training_dataset.ignore_label) 50 | elif criterion == 'SoftDICELoss': 51 | self.criterion = SoftDICELoss(ignore_label=self.training_dataset.ignore_label) 52 | else: 53 | raise NotImplementedError 54 | 55 | self.ignore_label = self.training_dataset.ignore_label 56 | 57 | self.save_hyperparameters(ignore='model') 58 | 59 | def training_step(self, batch, batch_idx): 60 | stensor = ME.SparseTensor(coordinates=batch["coordinates"].int(), features=batch["features"]) 61 | # Must clear cache at regular interval 62 | if self.global_step % self.clear_cache_int == 0: 63 | torch.cuda.empty_cache() 64 | 65 | out = self.model(stensor).F 66 | labels = batch['labels'].long() 67 | 68 | loss, per_class_loss = self.criterion(out, labels, return_class=True) 69 | 70 | _, preds = out.max(1) 71 | 72 | iou_tmp = jaccard_score(preds.detach().cpu().numpy(), labels.cpu().numpy(), average=None, 73 | labels=np.arange(0, self.num_classes), 74 | zero_division=0.) 
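# note: sklearn's jaccard_score signature is (y_true, y_pred); predictions are passed
# first here, which is harmless since per-class IoU is symmetric in its two arguments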
75 | 76 | present_labels, class_occurs = np.unique(labels.cpu().numpy(), return_counts=True) 77 | present_labels = present_labels[present_labels != self.ignore_label] 78 | present_names = self.training_dataset.class2names[present_labels].tolist() 79 | present_names = [os.path.join('training', p + '_iou') for p in present_names] 80 | results_dict = dict(zip(present_names, iou_tmp.tolist())) 81 | 82 | present_names = [os.path.join('training', p + '_loss') for p in present_names] 83 | results_dict.update(dict(zip(present_names, per_class_loss.tolist()))) 84 | 85 | results_dict['training/loss'] = loss 86 | results_dict['training/iou'] = np.mean(iou_tmp[present_labels]) 87 | results_dict['training/lr'] = self.trainer.optimizers[0].param_groups[0]["lr"] 88 | results_dict['training/epoch'] = self.current_epoch 89 | 90 | for k, v in results_dict.items(): 91 | self.log( 92 | name=k, 93 | value=v, 94 | logger=True, 95 | on_step=False, 96 | on_epoch=True, 97 | sync_dist=True, 98 | rank_zero_only=True 99 | ) 100 | return loss 101 | 102 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 103 | phase = ['validation', 'target'] 104 | phase = phase[dataloader_idx] 105 | stensor = ME.SparseTensor(coordinates=batch["coordinates"].int(), features=batch["features"]) 106 | # Must clear cache at regular interval 107 | if self.global_step % self.clear_cache_int == 0: 108 | torch.cuda.empty_cache() 109 | 110 | out = self.model(stensor).F 111 | labels = batch['labels'].long() 112 | 113 | loss = self.criterion(out, labels) 114 | _, preds = out.max(1) 115 | 116 | iou_tmp = jaccard_score(preds.detach().cpu().numpy(), labels.cpu().numpy(), average=None, 117 | labels=np.arange(0, self.num_classes), 118 | zero_division=0.) 119 | 120 | present_labels, class_occurs = np.unique(labels.cpu().numpy(), return_counts=True) 121 | present_labels = present_labels[present_labels != self.ignore_label] 122 | present_names = self.training_dataset.class2names[present_labels].tolist() 123 | present_names = [os.path.join(phase, p + '_iou') for p in present_names] 124 | results_dict = dict(zip(present_names, iou_tmp.tolist())) 125 | 126 | results_dict[f'{phase}/loss'] = loss 127 | results_dict[f'{phase}/iou'] = np.mean(iou_tmp[present_labels]) 128 | 129 | for k, v in results_dict.items(): 130 | self.log( 131 | name=k, 132 | value=v, 133 | logger=True, 134 | on_step=False, 135 | on_epoch=True, 136 | sync_dist=True, 137 | add_dataloader_idx=False 138 | ) 139 | return results_dict 140 | 141 | def configure_optimizers(self): 142 | if self.scheduler_name is None: 143 | if self.optimizer_name == 'SGD': 144 | optimizer = torch.optim.SGD(self.model.parameters(), 145 | lr=self.lr, 146 | momentum=self.momentum, 147 | weight_decay=self.weight_decay) 148 | elif self.optimizer_name == 'Adam': 149 | optimizer = torch.optim.Adam(self.model.parameters(), 150 | lr=self.lr, 151 | weight_decay=self.weight_decay) 152 | else: 153 | raise NotImplementedError 154 | 155 | return optimizer 156 | else: 157 | if self.optimizer_name == 'SGD': 158 | optimizer = torch.optim.SGD(self.model.parameters(), 159 | lr=self.lr, 160 | momentum=self.momentum, 161 | weight_decay=self.weight_decay) 162 | elif self.optimizer_name == 'Adam': 163 | optimizer = torch.optim.Adam(self.model.parameters(), 164 | lr=self.lr, 165 | weight_decay=self.weight_decay) 166 | else: 167 | raise NotImplementedError 168 | 169 | if self.scheduler_name == 'CosineAnnealingLR': 170 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10) 171 | elif self.scheduler_name 
== 'ExponentialLR':
172 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
173 | elif self.scheduler_name == 'CyclicLR':
174 | scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=self.lr/10000, max_lr=self.lr,
175 | step_size_up=5, mode="triangular2")
176 |
177 | else:
178 | raise NotImplementedError
179 |
180 | return [optimizer], [scheduler]
181 |
182 |
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from typing import Optional
6 |
7 |
8 | class CELoss(nn.Module):
9 | def __init__(self, ignore_label=None, weight=None):
10 | super().__init__()
11 | if weight is not None:
12 | weight = torch.from_numpy(weight).float()
13 | print(f'-----> Using weighted CE Loss with weights: {weight}')
14 |
15 | self.loss = nn.CrossEntropyLoss(ignore_index=ignore_label, weight=weight)
16 | self.ignored_label = ignore_label
17 |
18 | def forward(self, preds, gt):
19 |
20 | loss = self.loss(preds, gt)
21 | return loss
22 |
23 |
24 | class SoftCELoss(nn.Module):
25 | def __init__(self, ignore_label=None):
26 | super().__init__()
27 |
28 | self.ignore_label = ignore_label
29 |
30 | @staticmethod
31 | def soft_ce(preds, gt):
32 | log_probs = F.log_softmax(preds, dim=1)
33 | loss = -(gt * log_probs).sum() / preds.shape[0]
34 | return loss
35 |
36 | def forward(self, preds, gt):
37 | bs, num_pts, num_classes = preds.shape
38 |
39 | preds = preds.view(-1, num_classes)
40 |
41 | gt = gt.view(-1)
42 | if self.ignore_label is not None:
43 | valid_idx = torch.logical_not(self.ignore_label == gt)
44 | preds = preds[valid_idx]
45 | gt = gt[valid_idx]
46 |
47 | return self.soft_ce(preds, gt)
48 |
49 |
50 | class DICELoss(nn.Module):
51 |
52 | def __init__(self, ignore_label=None, powerize=True, use_tmask=True):
53 | super(DICELoss, self).__init__()
54 |
55 | if ignore_label is not None:
56 | self.ignore_label = torch.tensor(ignore_label)
57 | else:
58 | self.ignore_label = ignore_label
59 |
60 | self.powerize = powerize
61 | self.use_tmask = use_tmask
62 |
63 | def forward(self, output, target):
64 | input_device = output.device
65 | # temporary workaround to avoid NaNs: compute on CPU
66 | output = output.cpu()
67 | target = target.cpu()
68 |
69 | if self.ignore_label is not None:
70 | valid_idx = torch.logical_not(target == self.ignore_label)
71 | target = target[valid_idx]
72 | output = output[valid_idx, :]
73 |
74 | target = F.one_hot(target, num_classes=output.shape[1])
75 | output = F.softmax(output, dim=-1)
76 |
77 | intersection = (output * target).sum(dim=0)
78 | if self.powerize:
79 | union = (output.pow(2).sum(dim=0) + target.sum(dim=0)) + 1e-12
80 | else:
81 | union = (output.sum(dim=0) + target.sum(dim=0)) + 1e-12
82 | if self.use_tmask:
83 | tmask = (target.sum(dim=0) > 0).int()
84 | else:
85 | tmask = torch.ones(target.shape[1]).int()
86 |
87 | iou = (tmask * 2 * intersection / union).sum(dim=0) / (tmask.sum(dim=0) + 1e-12)
88 |
89 | dice_loss = 1 - iou.mean()
90 |
91 | return dice_loss.to(input_device)
92 |
93 |
94 | def get_soft(t_vector, eps=0.25):
95 |
96 | max_val = 1 - eps
97 | min_val = eps / (t_vector.shape[-1] - 1)
98 |
99 | t_soft = torch.empty(t_vector.shape)
100 | t_soft[t_vector == 0] = min_val
101 | t_soft[t_vector == 1] = max_val
102 |
103 | return t_soft
104 |
105 | def get_soft_new(t_vector, score, eps=0.25):
106 | eps = eps * (1 - score)
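# confidence-aware smoothing: a high score shrinks eps, so reliable pseudo-labels get
# sharper soft targets, while low-confidence points are smoothed more heavily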
107 | max_val = 1 - eps
108 | min_val = eps / (t_vector.shape[-1] - 1)
109 |
110 | min_val = min_val.unsqueeze(1).expand_as(t_vector)
111 | max_val = max_val.unsqueeze(1).expand_as(t_vector)
112 |
113 | t_soft = torch.empty(t_vector.shape).cuda()
114 | t_soft[t_vector == 0] = min_val[t_vector == 0]
115 | t_soft[t_vector == 1] = max_val[t_vector == 1]
116 |
117 | return t_soft
118 |
119 |
120 | class SoftDICELoss(nn.Module):
121 |
122 | def __init__(self, ignore_label=None, powerize=True, use_tmask=True,
123 | neg_range=False, eps=0.):
124 | super(SoftDICELoss, self).__init__()
125 |
126 | if ignore_label is not None:
127 | self.ignore_label = torch.tensor(ignore_label)
128 | else:
129 | self.ignore_label = ignore_label
130 | self.powerize = powerize
131 | self.use_tmask = use_tmask
132 | self.neg_range = neg_range
133 | self.eps = eps
134 |
135 | def forward(self, output, target, return_class=False, score=None, loss_method_num=0):
136 | input_device = output.device
137 | # note: unlike DICELoss above, tensors stay on their original device here
138 |
139 |
140 |
141 | if self.ignore_label is not None:
142 | valid_idx = torch.logical_not(target == self.ignore_label)
143 | target = target[valid_idx]
144 | output = output[valid_idx, :]
145 |
146 | if score is not None:
147 | score = score.squeeze()
148 | score = score[valid_idx]
149 |
150 | target_onehot = F.one_hot(target, num_classes=output.shape[1])
151 |
152 | if score is not None and (loss_method_num==1 or loss_method_num==3):
153 | target_soft = get_soft_new(target_onehot, score, eps=self.eps).cuda()
154 | else:
155 | target_soft = get_soft(target_onehot, eps=self.eps).cuda()
156 |
157 | # target_soft = get_soft(target_onehot, eps=self.eps).cuda()
158 |
159 | output = F.softmax(output, dim=-1)
160 |
161 | if score is not None and (loss_method_num==2 or loss_method_num==3):
162 | intersection = (output * target_soft * score.unsqueeze(1).expand_as(output)).sum(dim=0)
163 | else:
164 | intersection = (output * target_soft).sum(dim=0)
165 |
166 | # intersection = (output * target_soft).sum(dim=0)
167 |
168 | if self.powerize:
169 | union = (output.pow(2).sum(dim=0) + target_soft.sum(dim=0)) + 1e-12
170 | else:
171 | union = (output.sum(dim=0) + target_soft.sum(dim=0)) + 1e-12
172 | if self.use_tmask:
173 | tmask = (target_onehot.sum(dim=0) > 0).int()
174 | else:
175 | tmask = torch.ones(target_onehot.shape[1]).int()
176 |
177 | iou = (tmask * 2 * intersection / union).sum(dim=0) / (tmask.sum(dim=0) + 1e-12)
178 | iou_class = tmask * 2 * intersection / union
179 |
180 | if self.neg_range:
181 | dice_loss = -iou.mean()
182 | dice_class = -iou_class
183 | else:
184 | dice_loss = 1 - iou.mean()
185 | dice_class = 1 - iou_class
186 | if return_class:
187 | return dice_loss.to(input_device), dice_class
188 | else:
189 | return dice_loss.to(input_device)
190 |
191 |
192 | class HLoss(nn.Module):
193 | def __init__(self):
194 | super(HLoss, self).__init__()
195 |
196 | def forward(self, x):
197 | b = F.softmax(x, dim=-1) * F.log_softmax(x, dim=-1)
198 | b = -1.0 * b.sum(dim=-1)
199 | return b
200 |
201 |
202 | class SCELoss(torch.nn.Module):
203 | def __init__(self, alpha, beta, num_classes=10, reduction='mean', ignore_label=None):
204 | super(SCELoss, self).__init__()
205 | self.device = 'cpu'
206 | self.alpha = alpha
207 | self.beta = beta
208 | self.num_classes = num_classes
209 | self.reduction = reduction
210 | self.ignore_label = ignore_label
211 | self.cross_entropy = torch.nn.CrossEntropyLoss(reduction=reduction)
212 |
213 | def forward(self, pred,
labels): 214 | 215 | pred = pred.cpu() 216 | labels = labels.cpu() 217 | if self.ignore_label is not None: 218 | valid_idx = torch.logical_not(labels == self.ignore_label) 219 | pred = pred[valid_idx] 220 | labels = labels[valid_idx] 221 | 222 | # CCE 223 | ce = self.cross_entropy(pred, labels) 224 | 225 | # RCE 226 | pred = F.softmax(pred, dim=-1) 227 | pred = torch.clamp(pred, min=1e-4, max=1.0) 228 | label_one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float() 229 | label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0) 230 | rce = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1)) 231 | 232 | if self.reduction == 'mean': 233 | rce = rce.mean() 234 | # Loss 235 | loss = self.alpha * ce + self.beta * rce 236 | return loss 237 | -------------------------------------------------------------------------------- /models/minkunet_nobn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 
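# MinkUNet variants with all batch-normalization layers removed: BasicBlockNOBN omits
# the norm layers of the standard residual block, and the network is assembled with the
# norm-free helper _make_layer_nobn inherited from ResNetBase (models/resnet.py).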
24 | import torch 25 | import torch.nn as nn 26 | from torch.optim import SGD 27 | 28 | import MinkowskiEngine as ME 29 | 30 | from models.resnet import ResNetBase 31 | 32 | 33 | class BasicBlockNOBN(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, 37 | inplanes, 38 | planes, 39 | stride=1, 40 | dilation=1, 41 | downsample=None, 42 | bn_momentum=0.1, 43 | dimension=-1): 44 | super(BasicBlockNOBN, self).__init__() 45 | assert dimension > 0 46 | 47 | self.conv1 = ME.MinkowskiConvolution( 48 | inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension) 49 | self.conv2 = ME.MinkowskiConvolution( 50 | planes, planes, kernel_size=3, stride=1, dilation=dilation, dimension=dimension) 51 | self.relu = ME.MinkowskiReLU(inplace=True) 52 | self.downsample = downsample 53 | 54 | def forward(self, x): 55 | residual = x 56 | out = self.conv1(x) 57 | out = self.relu(out) 58 | out = self.conv2(out) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | class MinkUNetBaseNoBN(ResNetBase): 69 | BLOCK = None 70 | PLANES = None 71 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 72 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 73 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 74 | INIT_DIM = 32 75 | OUT_TENSOR_STRIDE = 1 76 | 77 | # To use the model, must call initialize_coords before forward pass. 78 | # Once data is processed, call clear to reset the model before calling 79 | # initialize_coords 80 | def __init__(self, in_channels, out_channels, D=3): 81 | ResNetBase.__init__(self, in_channels, out_channels, D) 82 | 83 | def network_initialization(self, in_channels, out_channels, D): 84 | # Output of the first conv concated to conv6 85 | self.inplanes = self.INIT_DIM 86 | self.conv0p1s1 = ME.MinkowskiConvolution( 87 | in_channels, self.inplanes, kernel_size=5, dimension=D) 88 | 89 | self.conv1p1s2 = ME.MinkowskiConvolution( 90 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 91 | 92 | self.block1 = self._make_layer_nobn(self.BLOCK, self.PLANES[0], self.LAYERS[0]) 93 | 94 | self.conv2p2s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 95 | 96 | self.block2 = self._make_layer_nobn(self.BLOCK, self.PLANES[1], self.LAYERS[1]) 97 | 98 | self.conv3p4s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 99 | 100 | self.block3 = self._make_layer_nobn(self.BLOCK, self.PLANES[2], self.LAYERS[2]) 101 | 102 | self.conv4p8s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 103 | 104 | self.block4 = self._make_layer_nobn(self.BLOCK, self.PLANES[3], self.LAYERS[3]) 105 | 106 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, self.PLANES[4], kernel_size=2, 107 | stride=2, dimension=D) 108 | 109 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion 110 | self.block5 = self._make_layer_nobn(self.BLOCK, self.PLANES[4], self.LAYERS[4]) 111 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, self.PLANES[5], 112 | kernel_size=2, stride=2, dimension=D) 113 | 114 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion 115 | self.block6 = self._make_layer_nobn(self.BLOCK, self.PLANES[5],self.LAYERS[5]) 116 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, self.PLANES[6], 117 | kernel_size=2, stride=2, dimension=D) 118 | 119 | self.inplanes = self.PLANES[6] + 
self.PLANES[0] * self.BLOCK.expansion 120 | self.block7 = self._make_layer_nobn(self.BLOCK, self.PLANES[6], self.LAYERS[6]) 121 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose(self.inplanes, self.PLANES[7], 122 | kernel_size=2, stride=2, dimension=D) 123 | 124 | self.inplanes = self.PLANES[7] + self.INIT_DIM 125 | 126 | self.block8 = self._make_layer_nobn(self.BLOCK, self.PLANES[7], self.LAYERS[7]) 127 | 128 | self.final = ME.MinkowskiConvolution(self.PLANES[7] * self.BLOCK.expansion, 129 | out_channels, 130 | kernel_size=1, 131 | bias=True, 132 | dimension=D) 133 | 134 | self.relu = ME.MinkowskiReLU(inplace=True) 135 | 136 | def forward(self, x, is_seg=True): 137 | out = self.conv0p1s1(x) 138 | out_p1 = self.relu(out) 139 | 140 | out = self.conv1p1s2(out_p1) 141 | out = self.relu(out) 142 | out_b1p2 = self.block1(out) 143 | 144 | out = self.conv2p2s2(out_b1p2) 145 | out = self.relu(out) 146 | out_b2p4 = self.block2(out) 147 | 148 | out = self.conv3p4s2(out_b2p4) 149 | out = self.relu(out) 150 | out_b3p8 = self.block3(out) 151 | 152 | # tensor_stride=16 153 | out = self.conv4p8s2(out_b3p8) 154 | out = self.relu(out) 155 | out_bottle = self.block4(out) 156 | 157 | # tensor_stride=8 158 | out = self.convtr4p16s2(out_bottle) 159 | out = self.relu(out) 160 | 161 | out = ME.cat(out, out_b3p8) 162 | out = self.block5(out) 163 | 164 | # tensor_stride=4 165 | out = self.convtr5p8s2(out) 166 | out = self.relu(out) 167 | 168 | out = ME.cat(out, out_b2p4) 169 | out = self.block6(out) 170 | 171 | # tensor_stride=2 172 | out = self.convtr6p4s2(out) 173 | out = self.relu(out) 174 | 175 | out = ME.cat(out, out_b1p2) 176 | out = self.block7(out) 177 | 178 | # tensor_stride=1 179 | out = self.convtr7p2s2(out) 180 | out = self.relu(out) 181 | 182 | out = ME.cat(out, out_p1) 183 | out = self.block8(out) 184 | 185 | if is_seg: 186 | return self.final(out) 187 | else: 188 | return out, out_bottle 189 | 190 | 191 | class MinkUNet14NOBN(MinkUNetBaseNoBN): 192 | BLOCK = BasicBlockNOBN 193 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 194 | 195 | 196 | class MinkUNet18NOBN(MinkUNetBaseNoBN): 197 | BLOCK = BasicBlockNOBN 198 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 199 | 200 | 201 | class MinkUNet34NOBN(MinkUNetBaseNoBN): 202 | BLOCK = BasicBlockNOBN 203 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 204 | 205 | 206 | if __name__ == '__main__': 207 | 208 | # loss and network 209 | criterion = nn.CrossEntropyLoss() 210 | net = MinkUNet18NOBN(in_channels=3, out_channels=7, D=3) 211 | print(net) -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
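# --- Editor's note (not part of this file): a quick smoke test for the
# BN-free network defined in models/minkunet_nobn.py above. Dropping
# BatchNorm means no batch statistics are used, which matters when adapting
# online with very small batches. The point count and grid extent below are
# illustrative.
import torch
import MinkowskiEngine as ME
from models.minkunet_nobn import MinkUNet18NOBN

net = MinkUNet18NOBN(in_channels=3, out_channels=7, D=3)
coords = ME.utils.batched_coordinates([torch.rand(1000, 3) * 100])  # one cloud
feats = torch.rand(1000, 3)
x = ME.SparseTensor(feats, coordinates=coords)
logits = net(x)                              # per-voxel segmentation logits
backbone, bottleneck = net(x, is_seg=False)  # backbone features instead
# --- End of editor's note.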
14 | """ 15 | CSV logger 16 | ---------- 17 | CSV logger for basic experiment logging that does not require opening ports 18 | """ 19 | import csv 20 | import logging 21 | import os 22 | from argparse import Namespace 23 | from typing import Any, Dict, Optional, Union 24 | 25 | import torch 26 | 27 | from pytorch_lightning.core.saving import save_hparams_to_yaml 28 | from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment 29 | from pytorch_lightning.utilities import rank_zero_warn 30 | from pytorch_lightning.utilities.distributed import rank_zero_only 31 | 32 | log = logging.getLogger(__name__) 33 | 34 | 35 | class ExperimentWriter: 36 | r""" 37 | Experiment writer for CSVLogger. 38 | Currently supports to log hyperparameters and metrics in YAML and CSV 39 | format, respectively. 40 | Args: 41 | log_dir: Directory for the experiment logs 42 | """ 43 | 44 | NAME_HPARAMS_FILE = "hparams.yaml" 45 | NAME_METRICS_FILE = "metrics.csv" 46 | 47 | def __init__(self, log_dir: str) -> None: 48 | self.hparams = {} 49 | self.metrics = [] 50 | 51 | self.log_dir = log_dir 52 | if os.path.exists(self.log_dir) and os.listdir(self.log_dir): 53 | rank_zero_warn( 54 | f"Experiment logs directory {self.log_dir} exists and is not empty." 55 | " Previous log files in this directory will be deleted when the new ones are saved!" 56 | ) 57 | os.makedirs(self.log_dir, exist_ok=True) 58 | 59 | self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) 60 | 61 | def log_hparams(self, params: Dict[str, Any]) -> None: 62 | """Record hparams""" 63 | self.hparams.update(params) 64 | 65 | def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None: 66 | """Record metrics""" 67 | 68 | def _handle_value(value): 69 | if isinstance(value, torch.Tensor): 70 | return value.item() 71 | return value 72 | 73 | if step is None: 74 | step = len(self.metrics) 75 | 76 | metrics = {k: _handle_value(v) for k, v in metrics_dict.items()} 77 | metrics["step"] = step 78 | self.metrics.append(metrics) 79 | 80 | def save(self) -> None: 81 | """Save recorded hparams and metrics into files""" 82 | # hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) 83 | # save_hparams_to_yaml(hparams_file, self.hparams) 84 | 85 | if not self.metrics: 86 | return 87 | 88 | last_m = {} 89 | for m in self.metrics: 90 | last_m.update(m) 91 | metrics_keys = list(last_m.keys()) 92 | 93 | with open(self.metrics_file_path, "w", newline="") as f: 94 | self.writer = csv.DictWriter(f, fieldnames=metrics_keys) 95 | self.writer.writeheader() 96 | self.writer.writerows(self.metrics) 97 | 98 | 99 | class CSVLogger(LightningLoggerBase): 100 | r""" 101 | Log to local file system in yaml and CSV format. 102 | Logs are saved to ``os.path.join(save_dir, name, version)``. 103 | Example: 104 | 105 | Args: 106 | save_dir: Save directory 107 | name: Experiment name. Defaults to ``'default'``. 108 | version: Experiment version. If version is not specified the logger inspects the save 109 | directory for existing versions, then automatically assigns the next available version. 110 | prefix: A string to put at the beginning of metric keys. 
111 | """ 112 | 113 | LOGGER_JOIN_CHAR = "-" 114 | 115 | def __init__( 116 | self, 117 | save_dir: str, 118 | name: Optional[str] = "default", 119 | version: Optional[Union[int, str]] = None, 120 | prefix: str = "", 121 | ): 122 | super().__init__() 123 | self._save_dir = save_dir 124 | self._name = name or "" 125 | self._version = version 126 | self._prefix = prefix 127 | self._experiment = None 128 | 129 | @property 130 | def root_dir(self) -> str: 131 | """ 132 | Parent directory for all checkpoint subdirectories. 133 | If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used 134 | and the checkpoint will be saved in "save_dir/version_dir" 135 | """ 136 | if not self.name: 137 | return self.save_dir 138 | return os.path.join(self.save_dir, self.name) 139 | 140 | @property 141 | def log_dir(self) -> str: 142 | """ 143 | The log directory for this run. By default, it is named 144 | ``'version_${self.version}'`` but it can be overridden by passing a string value 145 | for the constructor's version parameter instead of ``None`` or an int. 146 | """ 147 | # create a pseudo standard path ala test-tube 148 | version = self.version if isinstance(self.version, str) else f"version_{self.version}" 149 | log_dir = os.path.join(self.root_dir, version) 150 | return log_dir 151 | 152 | @property 153 | def save_dir(self) -> Optional[str]: 154 | """ 155 | The current directory where logs are saved. 156 | Returns: 157 | The path to current directory where logs are saved. 158 | """ 159 | return self._save_dir 160 | 161 | @property 162 | @rank_zero_experiment 163 | def experiment(self) -> ExperimentWriter: 164 | r""" 165 | Actual ExperimentWriter object. To use ExperimentWriter features in your 166 | :class:`~pytorch_lightning.core.lightning.LightningModule` do the following. 167 | Example:: 168 | self.logger.experiment.some_experiment_writer_function() 169 | """ 170 | if self._experiment: 171 | return self._experiment 172 | 173 | os.makedirs(self.root_dir, exist_ok=True) 174 | self._experiment = ExperimentWriter(log_dir=self.log_dir) 175 | return self._experiment 176 | 177 | @rank_zero_only 178 | def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: 179 | params = self._convert_params(params) 180 | self.experiment.log_hparams(params) 181 | 182 | @rank_zero_only 183 | def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: 184 | metrics = self._add_prefix(metrics) 185 | self.experiment.log_metrics(metrics, step) 186 | 187 | @rank_zero_only 188 | def save(self) -> None: 189 | super().save() 190 | self.experiment.save() 191 | 192 | @rank_zero_only 193 | def finalize(self, status: str) -> None: 194 | self.save() 195 | 196 | @property 197 | def name(self) -> str: 198 | """ 199 | Gets the name of the experiment. 200 | Returns: 201 | The name of the experiment. 202 | """ 203 | return self._name 204 | 205 | @property 206 | def version(self) -> int: 207 | """ 208 | Gets the version of the experiment. 209 | Returns: 210 | The version of the experiment if it is specified, else the next version. 
211 | """ 212 | if self._version is None: 213 | self._version = self._get_next_version() 214 | return self._version 215 | 216 | def _get_next_version(self): 217 | root_dir = os.path.join(self._save_dir, self.name) 218 | 219 | if not os.path.isdir(root_dir): 220 | log.warning("Missing logger folder: %s", root_dir) 221 | return 0 222 | 223 | existing_versions = [] 224 | for d in os.listdir(root_dir): 225 | if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): 226 | existing_versions.append(int(d.split("_")[1])) 227 | 228 | if len(existing_versions) == 0: 229 | return 0 230 | 231 | return max(existing_versions) + 1 232 | -------------------------------------------------------------------------------- /pipelines/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wandb 3 | import time 4 | import logging 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import MinkowskiEngine as ME 8 | import numpy as np 9 | 10 | from pipelines.base_pipeline import BasePipeline 11 | from utils.metrics import filtered_accuracy, confusion_matrix, iou_from_confusion 12 | 13 | try: 14 | import pytorch_lightning as plt 15 | except ImportError: 16 | raise ImportError( 17 | "Please install requirements with `pip install open3d pytorch_lightning`." 18 | ) 19 | 20 | 21 | class OneDomainTrainer(BasePipeline): 22 | 23 | def __init__(self, 24 | model=None, 25 | training_dataset=None, 26 | validation_dataset=None, 27 | loss=None, 28 | optimizer=None, 29 | scheduler=None, 30 | save_dir=None): 31 | 32 | # init super 33 | super().__init__(model=model, 34 | loss=loss, 35 | optimizer=optimizer, 36 | scheduler=scheduler) 37 | 38 | # datasets 39 | self.training_dataset = training_dataset 40 | self.validation_dataset = validation_dataset 41 | 42 | self.training_loader = None 43 | self.validation_loader = None 44 | 45 | # dirs 46 | self.save_dir = save_dir 47 | # logs 48 | self.use_wandb = None 49 | 50 | # device 51 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 52 | 53 | # for saving 54 | self.best_acc = 0. 
55 | 56 | def single_gpu_train(self, 57 | epochs=100, 58 | lr=1e-3, 59 | batch_size=4, 60 | use_wandb=False, 61 | run_name=None, 62 | save_every=10, 63 | collation=None): 64 | 65 | # init logs 66 | run_time = time.strftime("%Y_%m_%d_%H:%M", time.gmtime()) 67 | if run_name is not None: 68 | run_name = run_time + '_' + run_name 69 | else: 70 | run_name = run_time 71 | 72 | self.save_dir = os.path.join(self.save_dir, run_name) 73 | self.use_wandb = use_wandb 74 | 75 | os.makedirs(os.path.join(self.save_dir, 'weights'), exist_ok=True) 76 | 77 | # logging 78 | log_path = os.path.join(self.save_dir, 'logs') 79 | os.makedirs(log_path, exist_ok=True) 80 | log_file = os.path.join(log_path, 'train.log') 81 | 82 | logging.basicConfig(filename=log_file, level=logging.INFO, force=True) 83 | logging.info(f'RUN_NAME {run_name}') 84 | logging.info(f'Logging in this file {log_file}') 85 | 86 | quick_configs = {'run_name': run_name, 87 | 'save_dir': self.save_dir, 88 | 'loss': str(type(self.loss)), 89 | 'epochs': epochs, 90 | 'lr': lr, 91 | 'lr_decay': self.scheduler.gamma, 92 | 'batchsize': batch_size} 93 | 94 | logging.info(f'CONFIGS: {quick_configs}') 95 | 96 | # use or not wandb 97 | if use_wandb: 98 | 99 | wandb.init(project="cvpr2022-online-seg", entity="unitn-mhug-csalto", 100 | name=run_name) 101 | 102 | 103 | logging.info('WANDB enabled') 104 | 105 | # init dataloaders 106 | self.training_loader = DataLoader(self.training_dataset, 107 | batch_size=batch_size, 108 | collate_fn=collation, 109 | shuffle=True) 110 | 111 | self.validation_loader = DataLoader(self.validation_dataset, 112 | batch_size=batch_size, 113 | collate_fn=collation, 114 | shuffle=False) 115 | 116 | logging.info(f'Training started at {time.strftime("%H:%M:%S", time.gmtime())}') 117 | 118 | for epoch in range(epochs): 119 | 120 | logging.info(f'=======> Epoch {epoch}') 121 | 122 | start_time = time.time() 123 | 124 | train_loss, train_acc, train_iou = self.train() 125 | 126 | ep_time = time.time() - start_time 127 | logging.info(f'=======> Epoch {epoch} ended') 128 | logging.info(f'=======> Training loss {train_loss}') 129 | logging.info(f'=======> Training acc {train_acc}') 130 | logging.info(f'=======> Training IoU {train_iou}') 131 | logging.info(f'=======> Time {ep_time}') 132 | 133 | if self.use_wandb: 134 | wandb.log({'Training loss': train_loss, 135 | 'Training accuracy': train_acc, 136 | 'Training iou': train_iou}) 137 | 138 | if epoch % save_every == 0: 139 | 140 | val_loss, val_acc, val_iou = self.validate() 141 | 142 | logging.info(f'******* Validation ******') 143 | logging.info(f'=======> Validation IoU {val_iou}') 144 | logging.info(f'=======> Validation acc {val_acc}') 145 | logging.info(f'=======> Validation loss {val_loss}') 146 | logging.info('**************************') 147 | 148 | if self.use_wandb: 149 | wandb.log({'Validation accuracy': val_acc, 150 | 'Validation loss': val_loss, 151 | 'Validation IoU': val_iou}) 152 | 153 | self.save_model(epoch, val_acc) 154 | 155 | self.scheduler.step() 156 | 157 | def train(self): 158 | self.model.train() 159 | training_losses = [] 160 | training_labels = [] 161 | training_preds = [] 162 | training_confusions = [] 163 | 164 | for t_idx, train_data in enumerate(self.training_loader): 165 | loss, train_pred = self.train_step(train_data) 166 | 167 | training_losses.append(loss.cpu().detach().numpy()) 168 | training_labels.append(train_data["labels"].cpu()) 169 | training_preds.append(train_pred.cpu()) 170 | conf_m = torch.from_numpy(confusion_matrix(train_pred.cpu(), 
(train_data["labels"].cpu()))).unsqueeze(0) 171 | training_confusions.append(conf_m) 172 | 173 | training_loss_mean = np.mean(training_losses) 174 | training_preds = torch.cat(training_preds).view(-1) 175 | training_labels = torch.cat(training_labels).view(-1) 176 | 177 | training_acc = filtered_accuracy(training_preds, training_labels) 178 | train_iou_per_class, training_iou = iou_from_confusion(torch.cat(training_confusions)) 179 | 180 | return training_loss_mean, training_acc, training_iou 181 | 182 | def train_step(self, batch): 183 | self.optimizer.zero_grad() 184 | stensor = ME.SparseTensor(coordinates=batch['coordinates'], features=batch['features']) 185 | out = self.model(stensor).F 186 | 187 | loss = self.loss(out, batch['labels'].long()) 188 | loss.backward() 189 | self.optimizer.step() 190 | 191 | _, preds = out.max(1) 192 | 193 | return loss, preds 194 | 195 | def validate(self): 196 | self.model.eval() 197 | validation_labels = [] 198 | validation_preds = [] 199 | validation_loss = [] 200 | validation_confusions = [] 201 | 202 | with torch.no_grad(): 203 | 204 | for v_idx, val_data in enumerate(self.validation_loader): 205 | val_loss, pred = self.validation_step(val_data) 206 | 207 | validation_labels.append(val_data["labels"].cpu()) 208 | validation_preds.append(pred.cpu()) 209 | conf_m = torch.from_numpy(confusion_matrix(pred.cpu(), val_data["labels"].cpu())).unsqueeze(0) 210 | validation_confusions.append(conf_m) 211 | 212 | validation_loss.append(val_loss) 213 | 214 | validation_preds = torch.cat(validation_preds).view(-1) 215 | validation_labels = torch.cat(validation_labels).view(-1) 216 | val_acc = filtered_accuracy(validation_preds, validation_labels) 217 | val_loss = np.mean(validation_loss) 218 | val_iou_per_class, val_iou = iou_from_confusion(torch.cat(validation_confusions)) 219 | 220 | return val_loss, val_acc, val_iou 221 | 222 | def validation_step(self, batch): 223 | stensor = ME.SparseTensor(coordinates=batch['coordinates'], features=batch['features']) 224 | 225 | out = self.model(stensor).F 226 | 227 | loss = self.loss(out, batch['labels'].long()) 228 | _, preds = out.max(1) 229 | # torch.cuda.empty_cache() 230 | 231 | return loss, preds 232 | 233 | def save_model(self, epoch, acc): 234 | 235 | torch.save({"state_dict": self.model.state_dict(), 236 | "optimizer": self.optimizer.state_dict(), 237 | "scheduler": self.scheduler.state_dict()}, 238 | os.path.join(self.save_dir, 'weights', f'checkpoint_{epoch}.pth')) 239 | 240 | if acc > self.best_acc: 241 | 242 | torch.save({"state_dict": self.model.state_dict(), 243 | "optimizer": self.optimizer.state_dict(), 244 | "scheduler": self.scheduler.state_dict()}, 245 | os.path.join(self.save_dir, 'weights', f'best.pth')) 246 | 247 | self.best_acc = acc 248 | -------------------------------------------------------------------------------- /models/minkunet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | import torch 25 | import torch.nn as nn 26 | from torch.optim import SGD 27 | 28 | import MinkowskiEngine as ME 29 | 30 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 31 | 32 | from models.resnet import ResNetBase 33 | 34 | 35 | class MinkUNetBase(ResNetBase): 36 | BLOCK = None 37 | PLANES = None 38 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 39 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 40 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 41 | INIT_DIM = 32 42 | OUT_TENSOR_STRIDE = 1 43 | 44 | # To use the model, must call initialize_coords before forward pass. 
45 | # Once data is processed, call clear to reset the model before calling 46 | # initialize_coords 47 | def __init__(self, in_channels, out_channels, D=3): 48 | ResNetBase.__init__(self, in_channels, out_channels, D) 49 | 50 | def network_initialization(self, in_channels, out_channels, D): 51 | # Output of the first conv concated to conv6 52 | self.inplanes = self.INIT_DIM 53 | self.conv0p1s1 = ME.MinkowskiConvolution( 54 | in_channels, self.inplanes, kernel_size=5, dimension=D) 55 | 56 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 57 | 58 | self.conv1p1s2 = ME.MinkowskiConvolution( 59 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 60 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 61 | 62 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], 63 | self.LAYERS[0]) 64 | 65 | self.conv2p2s2 = ME.MinkowskiConvolution( 66 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 67 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) 68 | 69 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], 70 | self.LAYERS[1]) 71 | 72 | self.conv3p4s2 = ME.MinkowskiConvolution( 73 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 74 | 75 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) 76 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], 77 | self.LAYERS[2]) 78 | 79 | self.conv4p8s2 = ME.MinkowskiConvolution( 80 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 81 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) 82 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], 83 | self.LAYERS[3]) 84 | 85 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( 86 | self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D) 87 | self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) 88 | 89 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion 90 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], 91 | self.LAYERS[4]) 92 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( 93 | self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D) 94 | self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) 95 | 96 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion 97 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], 98 | self.LAYERS[5]) 99 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( 100 | self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D) 101 | self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) 102 | 103 | self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion 104 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], 105 | self.LAYERS[6]) 106 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( 107 | self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D) 108 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) 109 | 110 | self.inplanes = self.PLANES[7] + self.INIT_DIM 111 | self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], 112 | self.LAYERS[7]) 113 | 114 | self.final = ME.MinkowskiConvolution( 115 | self.PLANES[7] * self.BLOCK.expansion, 116 | out_channels, 117 | kernel_size=1, 118 | bias=True, 119 | dimension=D) 120 | self.relu = ME.MinkowskiReLU(inplace=True) 121 | 122 | def forward(self, x, is_seg=True): 123 | out = self.conv0p1s1(x) 124 | out = self.bn0(out) 125 | out_p1 = self.relu(out) 126 | 127 | out = self.conv1p1s2(out_p1) 128 | out = self.bn1(out) 129 | out = self.relu(out) 130 | out_b1p2 = self.block1(out) 131 | 132 | out = self.conv2p2s2(out_b1p2) 133 | 
out = self.bn2(out) 134 | out = self.relu(out) 135 | out_b2p4 = self.block2(out) 136 | 137 | out = self.conv3p4s2(out_b2p4) 138 | out = self.bn3(out) 139 | out = self.relu(out) 140 | out_b3p8 = self.block3(out) 141 | 142 | # tensor_stride=16 143 | out = self.conv4p8s2(out_b3p8) 144 | out = self.bn4(out) 145 | out = self.relu(out) 146 | out_bottle = self.block4(out) 147 | 148 | # tensor_stride=8 149 | out = self.convtr4p16s2(out_bottle) 150 | out = self.bntr4(out) 151 | out = self.relu(out) 152 | 153 | out = ME.cat(out, out_b3p8) 154 | out = self.block5(out) 155 | 156 | # tensor_stride=4 157 | out = self.convtr5p8s2(out) 158 | out = self.bntr5(out) 159 | out = self.relu(out) 160 | 161 | out = ME.cat(out, out_b2p4) 162 | out = self.block6(out) 163 | 164 | # tensor_stride=2 165 | out = self.convtr6p4s2(out) 166 | out = self.bntr6(out) 167 | out = self.relu(out) 168 | 169 | out = ME.cat(out, out_b1p2) 170 | out = self.block7(out) 171 | 172 | # tensor_stride=1 173 | out = self.convtr7p2s2(out) 174 | out = self.bntr7(out) 175 | out = self.relu(out) 176 | 177 | out = ME.cat(out, out_p1) 178 | out = self.block8(out) 179 | 180 | if is_seg: 181 | return self.final(out) 182 | else: 183 | return out, out_bottle 184 | 185 | 186 | class MinkUNet14(MinkUNetBase): 187 | BLOCK = BasicBlock 188 | LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) 189 | 190 | 191 | class MinkUNet18(MinkUNetBase): 192 | BLOCK = BasicBlock 193 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 194 | 195 | 196 | class MinkUNet34(MinkUNetBase): 197 | BLOCK = BasicBlock 198 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 199 | 200 | 201 | class MinkUNet50(MinkUNetBase): 202 | BLOCK = Bottleneck 203 | LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) 204 | 205 | 206 | class MinkUNet101(MinkUNetBase): 207 | BLOCK = Bottleneck 208 | LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) 209 | 210 | 211 | class MinkUNet14A(MinkUNet14): 212 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 213 | 214 | 215 | class MinkUNet14B(MinkUNet14): 216 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 217 | 218 | 219 | class MinkUNet14C(MinkUNet14): 220 | PLANES = (32, 64, 128, 256, 192, 192, 128, 128) 221 | 222 | 223 | class MinkUNet14D(MinkUNet14): 224 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 225 | 226 | 227 | class MinkUNet18A(MinkUNet18): 228 | PLANES = (32, 64, 128, 256, 128, 128, 96, 96) 229 | 230 | 231 | class MinkUNet18B(MinkUNet18): 232 | PLANES = (32, 64, 128, 256, 128, 128, 128, 128) 233 | 234 | 235 | class MinkUNet18D(MinkUNet18): 236 | PLANES = (32, 64, 128, 256, 384, 384, 384, 384) 237 | 238 | 239 | class MinkUNet34A(MinkUNet34): 240 | PLANES = (32, 64, 128, 256, 256, 128, 64, 64) 241 | 242 | 243 | class MinkUNet34B(MinkUNet34): 244 | PLANES = (32, 64, 128, 256, 256, 128, 64, 32) 245 | 246 | 247 | class MinkUNet34C(MinkUNet34): 248 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 249 | 250 | 251 | if __name__ == '__main__': 252 | from tests.python.common import data_loader 253 | # loss and network 254 | criterion = nn.CrossEntropyLoss() 255 | net = MinkUNet14A(in_channels=3, out_channels=5, D=2) 256 | print(net) 257 | 258 | # a data loader must return a tuple of coords, features, and labels. 
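# --- Editor's note: the block below is the stock MinkowskiEngine smoke test:
# ten SGD steps on toy batches fetched from tests.python.common.data_loader,
# followed by a save/load round trip. It only verifies that the network runs
# end to end; it is not a meaningful training run.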
259 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 260 | 261 | net = net.to(device) 262 | optimizer = SGD(net.parameters(), lr=1e-2) 263 | 264 | for i in range(10): 265 | optimizer.zero_grad() 266 | 267 | # Get new data 268 | coords, feat, label = data_loader(is_classification=False) 269 | input = ME.SparseTensor(feat, coordinates=coords, device=device) 270 | label = label.to(device) 271 | 272 | # Forward 273 | output = net(input) 274 | 275 | # Loss 276 | loss = criterion(output.F, label) 277 | print('Iteration: ', i, ', Loss: ', loss.item()) 278 | 279 | # Gradient 280 | loss.backward() 281 | optimizer.step() 282 | 283 | # Saving and loading a network 284 | torch.save(net.state_dict(), 'test.pth') 285 | net.load_state_dict(torch.load('test.pth')) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 
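# --- Editor's note (not part of this file): a sketch for restoring a
# MinkUNet18 checkpoint written by pipelines/trainer.py above, which saves a
# dict with "state_dict", "optimizer" and "scheduler" keys. The path and
# channel sizes below are illustrative.
import torch
from models.minkunet import MinkUNet18

net = MinkUNet18(in_channels=1, out_channels=7, D=3)
ckpt = torch.load('experiments/run/weights/best.pth', map_location='cpu')
net.load_state_dict(ckpt['state_dict'])
net.eval()  # inference mode: freezes BatchNorm running statistics
# --- End of editor's note.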
24 | import os 25 | from urllib.request import urlretrieve 26 | import numpy as np 27 | 28 | import torch 29 | import torch.nn as nn 30 | from torch.optim import SGD 31 | 32 | try: 33 | import open3d as o3d 34 | except ImportError: 35 | raise ImportError("Please install open3d with `pip install open3d`.") 36 | 37 | import MinkowskiEngine as ME 38 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 39 | 40 | # 41 | # if not os.path.isfile("1.ply"): 42 | # print('Downloading an example pointcloud...') 43 | # urlretrieve("https://bit.ly/3c2iLhg", "1.ply") 44 | # 45 | 46 | def load_file(file_name): 47 | pcd = o3d.io.read_point_cloud(file_name) 48 | coords = np.array(pcd.points) 49 | colors = np.array(pcd.colors) 50 | return coords, colors, pcd 51 | 52 | 53 | class ResNetBase(nn.Module): 54 | BLOCK = None 55 | LAYERS = () 56 | INIT_DIM = 64 57 | PLANES = (64, 128, 256, 512) 58 | 59 | def __init__(self, in_channels, out_channels, D=3): 60 | nn.Module.__init__(self) 61 | self.D = D 62 | assert self.BLOCK is not None 63 | 64 | self.network_initialization(in_channels, out_channels, D) 65 | self.weight_initialization() 66 | 67 | def network_initialization(self, in_channels, out_channels, D): 68 | 69 | self.inplanes = self.INIT_DIM 70 | self.conv1 = nn.Sequential( 71 | ME.MinkowskiConvolution( 72 | in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D 73 | ), 74 | ME.MinkowskiInstanceNorm(self.inplanes), 75 | ME.MinkowskiReLU(inplace=True), 76 | ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D), 77 | ) 78 | 79 | self.layer1 = self._make_layer( 80 | self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2 81 | ) 82 | self.layer2 = self._make_layer( 83 | self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2 84 | ) 85 | self.layer3 = self._make_layer( 86 | self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2 87 | ) 88 | self.layer4 = self._make_layer( 89 | self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2 90 | ) 91 | 92 | self.conv5 = nn.Sequential( 93 | ME.MinkowskiDropout(), 94 | ME.MinkowskiConvolution( 95 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D 96 | ), 97 | ME.MinkowskiInstanceNorm(self.inplanes), 98 | ME.MinkowskiGELU(), 99 | ) 100 | 101 | self.glob_pool = ME.MinkowskiGlobalMaxPooling() 102 | 103 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 104 | 105 | def weight_initialization(self): 106 | for m in self.modules(): 107 | if isinstance(m, ME.MinkowskiConvolution): 108 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") 109 | 110 | if isinstance(m, ME.MinkowskiBatchNorm): 111 | nn.init.constant_(m.bn.weight, 1) 112 | nn.init.constant_(m.bn.bias, 0) 113 | 114 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): 115 | downsample = None 116 | if stride != 1 or self.inplanes != planes * block.expansion: 117 | downsample = nn.Sequential( 118 | ME.MinkowskiConvolution( 119 | self.inplanes, 120 | planes * block.expansion, 121 | kernel_size=1, 122 | stride=stride, 123 | dimension=self.D, 124 | ), 125 | ME.MinkowskiBatchNorm(planes * block.expansion), 126 | ) 127 | layers = [] 128 | layers.append( 129 | block( 130 | self.inplanes, 131 | planes, 132 | stride=stride, 133 | dilation=dilation, 134 | downsample=downsample, 135 | dimension=self.D, 136 | ) 137 | ) 138 | self.inplanes = planes * block.expansion 139 | for i in range(1, blocks): 140 | layers.append( 141 | block( 142 | self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D 143 | ) 144 | 
) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def _make_layer_nobn(self, block, planes, blocks, stride=1, dilation=1): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | ME.MinkowskiConvolution( 153 | self.inplanes, 154 | planes * block.expansion, 155 | kernel_size=1, 156 | stride=stride, 157 | dimension=self.D, 158 | ), 159 | ) 160 | layers = [] 161 | layers.append( 162 | block( 163 | self.inplanes, 164 | planes, 165 | stride=stride, 166 | dilation=dilation, 167 | downsample=downsample, 168 | dimension=self.D, 169 | ) 170 | ) 171 | self.inplanes = planes * block.expansion 172 | for i in range(1, blocks): 173 | layers.append( 174 | block( 175 | self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D 176 | ) 177 | ) 178 | 179 | return nn.Sequential(*layers) 180 | 181 | def forward(self, x: ME.SparseTensor): 182 | x = self.conv1(x) 183 | x = self.layer1(x) 184 | x = self.layer2(x) 185 | x = self.layer3(x) 186 | x = self.layer4(x) 187 | x = self.conv5(x) 188 | x = self.glob_pool(x) 189 | return self.final(x) 190 | 191 | 192 | class ResNet14(ResNetBase): 193 | BLOCK = BasicBlock 194 | LAYERS = (1, 1, 1, 1) 195 | 196 | 197 | class ResNet18(ResNetBase): 198 | BLOCK = BasicBlock 199 | LAYERS = (2, 2, 2, 2) 200 | 201 | 202 | class ResNet34(ResNetBase): 203 | BLOCK = BasicBlock 204 | LAYERS = (3, 4, 6, 3) 205 | 206 | 207 | class ResNet50(ResNetBase): 208 | BLOCK = Bottleneck 209 | LAYERS = (3, 4, 6, 3) 210 | 211 | 212 | class ResNet101(ResNetBase): 213 | BLOCK = Bottleneck 214 | LAYERS = (3, 4, 23, 3) 215 | 216 | 217 | class ResFieldNetBase(ResNetBase): 218 | def network_initialization(self, in_channels, out_channels, D): 219 | field_ch = 32 220 | field_ch2 = 64 221 | self.field_network = nn.Sequential( 222 | ME.MinkowskiSinusoidal(in_channels, field_ch), 223 | ME.MinkowskiBatchNorm(field_ch), 224 | ME.MinkowskiReLU(inplace=True), 225 | ME.MinkowskiLinear(field_ch, field_ch), 226 | ME.MinkowskiBatchNorm(field_ch), 227 | ME.MinkowskiReLU(inplace=True), 228 | ME.MinkowskiToSparseTensor(), 229 | ) 230 | self.field_network2 = nn.Sequential( 231 | ME.MinkowskiSinusoidal(field_ch + in_channels, field_ch2), 232 | ME.MinkowskiBatchNorm(field_ch2), 233 | ME.MinkowskiReLU(inplace=True), 234 | ME.MinkowskiLinear(field_ch2, field_ch2), 235 | ME.MinkowskiBatchNorm(field_ch2), 236 | ME.MinkowskiReLU(inplace=True), 237 | ME.MinkowskiToSparseTensor(), 238 | ) 239 | 240 | ResNetBase.network_initialization(self, field_ch2, out_channels, D) 241 | 242 | def forward(self, x: ME.TensorField): 243 | otensor = self.field_network(x) 244 | otensor2 = self.field_network2(otensor.cat_slice(x)) 245 | return ResNetBase.forward(self, otensor2) 246 | 247 | 248 | class ResFieldNet14(ResFieldNetBase): 249 | BLOCK = BasicBlock 250 | LAYERS = (1, 1, 1, 1) 251 | 252 | 253 | class ResFieldNet18(ResFieldNetBase): 254 | BLOCK = BasicBlock 255 | LAYERS = (2, 2, 2, 2) 256 | 257 | 258 | class ResFieldNet34(ResFieldNetBase): 259 | BLOCK = BasicBlock 260 | LAYERS = (3, 4, 6, 3) 261 | 262 | 263 | class ResFieldNet50(ResFieldNetBase): 264 | BLOCK = Bottleneck 265 | LAYERS = (3, 4, 6, 3) 266 | 267 | 268 | class ResFieldNet101(ResFieldNetBase): 269 | BLOCK = Bottleneck 270 | LAYERS = (3, 4, 23, 3) 271 | 272 | 273 | if __name__ == "__main__": 274 | # loss and network 275 | voxel_size = 0.02 276 | N_labels = 10 277 | 278 | criterion = nn.CrossEntropyLoss() 279 | net = ResNet14(in_channels=3, out_channels=N_labels, D=3) 280 | print(net) 281 | 282 | 
# a data loader must return a tuple of coords, features, and labels. 283 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 284 | 285 | net = net.to(device) 286 | optimizer = SGD(net.parameters(), lr=1e-2) 287 | 288 | coords, colors, pcd = load_file("1.ply") 289 | coords = torch.from_numpy(coords) 290 | # Get new data 291 | coordinates = ME.utils.batched_coordinates( 292 | [coords / voxel_size, coords / 2 / voxel_size, coords / 4 / voxel_size], 293 | dtype=torch.float32, 294 | ) 295 | features = torch.rand((len(coordinates), 3), device=device) 296 | for i in range(10): 297 | optimizer.zero_grad() 298 | 299 | input = ME.SparseTensor(features, coordinates, device=device) 300 | dummy_label = torch.randint(0, N_labels, (3,), device=device) 301 | 302 | # Forward 303 | output = net(input) 304 | 305 | # Loss 306 | loss = criterion(output.F, dummy_label) 307 | print("Iteration: ", i, ", Loss: ", loss.item()) 308 | 309 | # Gradient 310 | loss.backward() 311 | optimizer.step() 312 | 313 | # Saving and loading a network 314 | torch.save(net.state_dict(), "test.pth") 315 | net.load_state_dict(torch.load("test.pth")) -------------------------------------------------------------------------------- /adapt_online.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | import random 6 | import torch 7 | 8 | import models 9 | from models import MinkUNet18_HEADS, MinkUNet18_MCMC 10 | from utils.config import get_config 11 | from utils.collation import CollateSeparated, CollateFN 12 | from utils.dataset_online import get_online_dataset 13 | from utils.online_logger import OnlineWandbLogger, OnlineCSVLogger 14 | from pipelines import OnlineTrainer 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--config_file", 19 | default="configs/deva/nuscenes_sequence.yaml", 20 | type=str, 21 | help="Path to config file") 22 | parser.add_argument("--split_size", 23 | default=4071, 24 | type=int, 25 | help="Num frames per sub sequence (SemanticKITTI only)") 26 | parser.add_argument("--drop_prob", 27 | default=None, 28 | type=float, 29 | help="Dropout prob MCMC") 30 | parser.add_argument("--save_predictions", 31 | default=False, 32 | action='store_true') 33 | parser.add_argument("--note", 34 | default=None, 35 | type=str) 36 | 37 | 38 | parser.add_argument("--use_pseudo_new", 39 | default=False, 40 | action='store_true') 41 | parser.add_argument("--use_prototype", 42 | default=False, 43 | action='store_true') 44 | parser.add_argument("--use_all_pseudo", 45 | default=False, 46 | action='store_true') 47 | parser.add_argument("--score_weight", 48 | default=False, 49 | action='store_true') 50 | parser.add_argument("--loss_use_score_weight", 51 | default=False, 52 | action='store_true') 53 | 54 | 55 | parser.add_argument("--without_pre_eval_synlidar2kitti", 56 | default=False, 57 | action='store_true') 58 | parser.add_argument("--without_pre_eval_synth4d2kitti", 59 | default=False, 60 | action='store_true') 61 | parser.add_argument("--without_pre_eval_synth4dnusc", 62 | default=False, 63 | action='store_true') 64 | 65 | 66 | parser.add_argument("--kitti_sim", 67 | default=False, 68 | action='store_true') 69 | parser.add_argument("--only_certainty", 70 | default=False, 71 | action='store_true') 72 | parser.add_argument("--only_purity", 73 | default=False, 74 | action='store_true') 75 | parser.add_argument("--without_reload", 76 | default=False, 77 | action='store_true') 78 | 
parser.add_argument("--save_gem_predictions", 79 | default=False, 80 | action='store_true') 81 | parser.add_argument("--sample_pos", 82 | default=False, 83 | action='store_true') 84 | parser.add_argument("--coord_weight", 85 | default=False, 86 | action='store_true') 87 | parser.add_argument("--use_hard_label", 88 | default=False, 89 | action='store_true') 90 | parser.add_argument("--BMD_prototype", 91 | default=False, 92 | action='store_true') 93 | parser.add_argument("--only_use_BMD_prototype", 94 | default=False, 95 | action='store_true') 96 | parser.add_argument("--score_weight_new", 97 | default=False, 98 | action='store_true') 99 | parser.add_argument("--use_ema", 100 | default=False, 101 | action='store_true') 102 | parser.add_argument("--use_pre_label", 103 | default=False, 104 | action='store_true') 105 | parser.add_argument("--without_ssl_loss", 106 | default=False, 107 | action='store_true') 108 | parser.add_argument("--only_use_prototype", 109 | default=False, 110 | action='store_true') 111 | 112 | 113 | parser.add_argument("--lr", 114 | default=0.0, 115 | type=float) 116 | parser.add_argument("--ssl_beta", 117 | default=1.0, 118 | type=float) 119 | parser.add_argument("--pseudo_th", 120 | default=0.5, 121 | type=float) 122 | parser.add_argument("--loss_eps", 123 | default=0.25, 124 | type=float) 125 | parser.add_argument("--segmentation_beta", 126 | default=1.0, 127 | type=float) 128 | parser.add_argument("--max_time_window", 129 | default=0, 130 | type=int) 131 | parser.add_argument("--loss_method_num", 132 | default=0, 133 | type=int) 134 | parser.add_argument("--pre_label_num", 135 | default=2, 136 | type=int) 137 | parser.add_argument("--pre_label_knn", 138 | default=1, 139 | type=int) 140 | parser.add_argument("--pseudo_knn", 141 | default=5, 142 | type=int) 143 | parser.add_argument("--seed", 144 | default=1234, 145 | type=int) 146 | 147 | AUG_DICT = None 148 | 149 | 150 | def get_mini_config(main_c): 151 | return dict(time_window=main_c.dataset.max_time_window, 152 | mcmc_it=main_c.pipeline.num_mc_iterations, 153 | metric=main_c.pipeline.metric, 154 | cbst_p=main_c.pipeline.top_p, 155 | th_pseudo=main_c.pipeline.th_pseudo, 156 | top_class=main_c.pipeline.top_class, 157 | propagation_size=main_c.pipeline.propagation_size, 158 | drop_prob=main_c.model.drop_prob) 159 | 160 | 161 | def train(config, split_size=4071, save_preds=False, args=None): 162 | 163 | mapping_path = config.dataset.mapping_path 164 | 165 | 166 | if args.max_time_window != 0: 167 | config.dataset.max_time_window = args.max_time_window 168 | 169 | eval_dataset = get_online_dataset(dataset_name=config.dataset.name, 170 | dataset_path=config.dataset.dataset_path, 171 | voxel_size=config.dataset.voxel_size, 172 | augment_data=config.dataset.augment_data, 173 | max_time_wdw=config.dataset.max_time_window, 174 | version=config.dataset.version, 175 | sub_num=config.dataset.num_pts, 176 | ignore_label=config.dataset.ignore_label, 177 | split_size=split_size, 178 | mapping_path=mapping_path, 179 | num_classes=config.model.out_classes, 180 | args=args) 181 | 182 | adapt_dataset = get_online_dataset(dataset_name=config.dataset.name, 183 | dataset_path=config.dataset.dataset_path, 184 | voxel_size=config.dataset.voxel_size, 185 | augment_data=config.dataset.augment_data, 186 | max_time_wdw=config.dataset.max_time_window, 187 | version=config.dataset.version, 188 | sub_num=config.dataset.num_pts, 189 | ignore_label=config.dataset.ignore_label, 190 | split_size=split_size, 191 | mapping_path=mapping_path, 192 | 
num_classes=config.model.out_classes, 193 | args=args) 194 | 195 | Model = getattr(models, config.model.name) 196 | model = Model(config.model.in_feat_size, config.model.out_classes) 197 | 198 | if config.model.name == 'MinkUNet18': 199 | model = MinkUNet18_HEADS(model) 200 | 201 | if config.pipeline.is_double: 202 | source_model = Model(config.model.in_feat_size, config.model.out_classes) 203 | if config.pipeline.use_mcmc: 204 | if args.drop_prob is not None: 205 | config.model.drop_prob = args.drop_prob 206 | 207 | source_model = MinkUNet18_MCMC(source_model, p_drop=config.model.drop_prob) 208 | else: 209 | source_model = None 210 | 211 | if config.pipeline.delayed_freeze_list is not None: 212 | delayed_list = dict(zip(config.pipeline.delayed_freeze_list, config.pipeline.delayed_freeze_frames)) 213 | else: 214 | delayed_list = None 215 | 216 | 217 | run_time = time.strftime("%Y_%m_%d_%H:%M", time.gmtime()) 218 | if config.pipeline.wandb.run_name is not None: 219 | run_name = run_time + '_' + config.pipeline.wandb.run_name 220 | else: 221 | run_name = run_time 222 | 223 | mini_configs = get_mini_config(config) 224 | 225 | if args.note is not None: 226 | run_name += f'_{args.note}' 227 | else: 228 | for k, v in mini_configs.items(): 229 | run_name += f'_{str(k)}:{str(v)}' 230 | 231 | save_dir = os.path.join(config.pipeline.save_dir, run_name) 232 | args.save_dir = save_dir 233 | # save_dir += "_normal_test" 234 | os.makedirs(save_dir, exist_ok=True) 235 | 236 | wandb_logger = OnlineWandbLogger(project=config.pipeline.wandb.project_name, 237 | entity=config.pipeline.wandb.entity_name, 238 | name=run_name, 239 | offline=config.pipeline.wandb.offline, 240 | config=mini_configs) 241 | 242 | csv_logger = OnlineCSVLogger(save_dir=save_dir, 243 | version='logs') 244 | 245 | loggers = [wandb_logger, csv_logger] 246 | 247 | if args.lr != 0.0: 248 | config.pipeline.optimizer.lr = args.lr 249 | trainer = OnlineTrainer( 250 | eval_dataset=eval_dataset, 251 | adapt_dataset=adapt_dataset, 252 | model=model, 253 | num_classes=config.model.out_classes, 254 | source_model=source_model, 255 | criterion=config.pipeline.loss, 256 | epsilon=config.pipeline.eps, 257 | ssl_criterion=config.pipeline.ssl_loss, 258 | ssl_beta=config.pipeline.ssl_beta, 259 | seg_beta=config.pipeline.segmentation_beta, 260 | optimizer_name=config.pipeline.optimizer.name, 261 | adaptation_batch_size=config.pipeline.dataloader.adaptation_batch_size, 262 | stream_batch_size=config.pipeline.dataloader.stream_batch_size, 263 | lr=config.pipeline.optimizer.lr, 264 | clear_cache_int=config.pipeline.trainer.clear_cache_int, 265 | scheduler_name=config.pipeline.scheduler.scheduler_name, 266 | train_num_workers=config.pipeline.dataloader.num_workers, 267 | val_num_workers=config.pipeline.dataloader.num_workers, 268 | use_random_wdw=config.pipeline.random_time_window, 269 | freeze_list=config.pipeline.freeze_list, 270 | delayed_freeze_list=delayed_list, 271 | num_mc_iterations=config.pipeline.num_mc_iterations, 272 | 273 | collate_fn_eval=CollateFN(), 274 | collate_fn_adapt=CollateSeparated(), 275 | device=config.pipeline.gpu, 276 | default_root_dir=config.pipeline.save_dir, 277 | weights_save_path=os.path.join(save_dir, 'checkpoints'), 278 | loggers=loggers, 279 | save_checkpoint_every=config.pipeline.trainer.save_checkpoint_every, 280 | source_checkpoint=config.pipeline.source_model, 281 | student_checkpoint=config.pipeline.student_model, 282 | is_double=config.pipeline.is_double, 283 | is_pseudo=config.pipeline.is_pseudo, 284 | 
use_mcmc=config.pipeline.use_mcmc, 285 | sub_epochs=config.pipeline.sub_epoch, 286 | save_predictions=save_preds, 287 | args=args,) 288 | 289 | trainer.adapt_double() 290 | 291 | def set_random_seed(seed=0): 292 | 293 | random.seed(seed) 294 | np.random.seed(seed) 295 | torch.manual_seed(seed) 296 | torch.cuda.manual_seed(seed) 297 | torch.cuda.manual_seed_all(seed) 298 | 299 | os.environ['PYTHONHASHSEED'] = str(seed) 300 | torch.backends.cudnn.benchmark = False 301 | torch.backends.cudnn.deterministic = True 302 | 303 | if __name__ == '__main__': 304 | args = parser.parse_args() 305 | 306 | config = get_config(args.config_file) 307 | 308 | set_random_seed(args.seed) 309 | train(config, split_size=args.split_size, save_preds=args.save_predictions, args=args) 310 | -------------------------------------------------------------------------------- /models/minkunet_ssl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | 25 | import torch.nn as nn 26 | 27 | import MinkowskiEngine as ME 28 | 29 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 30 | 31 | from models.resnet import ResNetBase 32 | 33 | 34 | class MinkUNetBase(ResNetBase): 35 | BLOCK = None 36 | PLANES = None 37 | DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) 38 | LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) 39 | PLANES = (32, 64, 128, 256, 256, 128, 96, 96) 40 | INIT_DIM = 32 41 | OUT_TENSOR_STRIDE = 1 42 | 43 | # To use the model, must call initialize_coords before forward pass. 
44 | # Once data is processed, call clear to reset the model before calling 45 | # initialize_coords 46 | def __init__(self, in_channels, out_channels, D=3): 47 | ResNetBase.__init__(self, in_channels, out_channels, D) 48 | 49 | def network_initialization(self, in_channels, out_channels, D): 50 | # Output of the first conv concated to conv6 51 | self.inplanes = self.INIT_DIM 52 | self.conv0p1s1 = ME.MinkowskiConvolution( 53 | in_channels, self.inplanes, kernel_size=5, dimension=D) 54 | 55 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 56 | 57 | self.conv1p1s2 = ME.MinkowskiConvolution( 58 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 59 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 60 | 61 | self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], 62 | self.LAYERS[0]) 63 | 64 | self.conv2p2s2 = ME.MinkowskiConvolution( 65 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 66 | self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) 67 | 68 | self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], 69 | self.LAYERS[1]) 70 | 71 | self.conv3p4s2 = ME.MinkowskiConvolution( 72 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 73 | 74 | self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) 75 | self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], 76 | self.LAYERS[2]) 77 | 78 | self.conv4p8s2 = ME.MinkowskiConvolution( 79 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) 80 | self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) 81 | self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], 82 | self.LAYERS[3]) 83 | 84 | self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( 85 | self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D) 86 | self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) 87 | 88 | self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.expansion 89 | self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], 90 | self.LAYERS[4]) 91 | self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( 92 | self.inplanes, self.PLANES[5], kernel_size=2, stride=2, dimension=D) 93 | self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) 94 | 95 | self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.expansion 96 | self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], 97 | self.LAYERS[5]) 98 | self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( 99 | self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D) 100 | self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) 101 | 102 | self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.expansion 103 | self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], 104 | self.LAYERS[6]) 105 | self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( 106 | self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D) 107 | self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) 108 | 109 | self.inplanes = self.PLANES[7] + self.INIT_DIM 110 | self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], 111 | self.LAYERS[7]) 112 | 113 | self.final = ME.MinkowskiConvolution( 114 | self.PLANES[7] * self.BLOCK.expansion, 115 | out_channels, 116 | kernel_size=1, 117 | bias=True, 118 | dimension=D) 119 | self.relu = ME.MinkowskiReLU(inplace=True) 120 | 121 | # define self-sup heads 122 | self.proj_dim = self.PLANES[7] * self.BLOCK.expansion 123 | 124 | # projector head 125 | self.encoder = nn.Sequential( 126 | ME.MinkowskiConvolution( 127 | self.PLANES[7] * self.BLOCK.expansion, 128 | self.proj_dim, 129 | kernel_size=1, 130 | bias=True, 131 | 
        # define self-sup heads
        self.proj_dim = self.PLANES[7] * self.BLOCK.expansion

        # projector head
        self.encoder = nn.Sequential(
            ME.MinkowskiConvolution(
                self.PLANES[7] * self.BLOCK.expansion,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=D),
            ME.MinkowskiBatchNorm(self.proj_dim),
            self.relu,
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=D),
            ME.MinkowskiBatchNorm(self.proj_dim))

        # predictor head
        self.predictor = nn.Sequential(
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=D),
            ME.MinkowskiBatchNorm(self.proj_dim),
            self.relu,
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=D))

    def _forward(self, x):
        out = self.conv0p1s1(x)
        out = self.bn0(out)
        out_p1 = self.relu(out)

        out = self.conv1p1s2(out_p1)
        out = self.bn1(out)
        out = self.relu(out)
        out_b1p2 = self.block1(out)

        out = self.conv2p2s2(out_b1p2)
        out = self.bn2(out)
        out = self.relu(out)
        out_b2p4 = self.block2(out)

        out = self.conv3p4s2(out_b2p4)
        out = self.bn3(out)
        out = self.relu(out)
        out_b3p8 = self.block3(out)

        # tensor_stride=16
        out = self.conv4p8s2(out_b3p8)
        out = self.bn4(out)
        out = self.relu(out)
        out_bottle = self.block4(out)

        # tensor_stride=8
        out = self.convtr4p16s2(out_bottle)
        out = self.bntr4(out)
        out = self.relu(out)

        out = ME.cat(out, out_b3p8)
        out = self.block5(out)

        # tensor_stride=4
        out = self.convtr5p8s2(out)
        out = self.bntr5(out)
        out = self.relu(out)

        out = ME.cat(out, out_b2p4)
        out = self.block6(out)

        # tensor_stride=2
        out = self.convtr6p4s2(out)
        out = self.bntr6(out)
        out = self.relu(out)

        out = ME.cat(out, out_b1p2)
        out = self.block7(out)

        # tensor_stride=1
        out = self.convtr7p2s2(out)
        out = self.bntr7(out)
        out = self.relu(out)

        out = ME.cat(out, out_p1)
        out = self.block8(out)

        return out, out_bottle

    def _forward_heads(self, x):
        out_seg = self.final(x)
        out_en = self.encoder(x)
        out_pred = self.predictor(out_en)
        return out_seg, out_en, out_pred

    def forward(self, x, is_train=True):
        if is_train:
            x0, x1 = x

            # future: t0 -> t1
            out_backbone0, out_bottle0 = self._forward(x0)

            # past: t1 -> t0
            out_backbone1, out_bottle1 = self._forward(x1)

            # future pred
            out_seg0, out_en0, out_pred0 = self._forward_heads(out_backbone0)
            # past pred
            out_seg1, out_en1, out_pred1 = self._forward_heads(out_backbone1)

            return out_seg0.F, out_en0.F, out_pred0.F, out_backbone0.F, out_bottle0, \
                out_seg1.F, out_en1.F, out_pred1.F, out_backbone1.F, out_bottle1
        else:
            # forward in backbone
            out_backbone, out_bottle = self._forward(x)

            # forward in final
            out_seg = self.final(out_backbone)

            return out_seg.F, out_backbone.F, out_bottle


class MinkUNet18_SSL(MinkUNetBase):
    BLOCK = BasicBlock
    LAYERS = (2, 2, 2, 2, 2, 2, 2, 2)
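

# --- Editor's note: hedged usage sketch, not part of the original file ---
# MinkUNet18_SSL consumes a (t0, t1) pair of sparse tensors during training
# and a single tensor at inference. The commented snippet below illustrates
# the call pattern; the channel counts are placeholders, and it assumes
# models.resnet.ResNetBase invokes network_initialization in its constructor,
# as the stock MinkowskiEngine examples do.
#
#   import torch
#
#   model = MinkUNet18_SSL(in_channels=1, out_channels=19, D=3)
#   coords = torch.randint(0, 100, (1000, 3))   # random coords, illustration only
#   feats = torch.ones(1000, 1)
#   coords, feats = ME.utils.sparse_collate([coords], [feats])
#   x0 = ME.SparseTensor(features=feats, coordinates=coords)
#   x1 = ME.SparseTensor(features=feats, coordinates=coords)
#
#   # training: ten outputs (seg / proj / pred / backbone / bottleneck per frame)
#   outs = model((x0, x1), is_train=True)
#   # inference: segmentation logits, backbone features, bottleneck tensor
#   seg_logits, backbone_feats, bottle = model(x0, is_train=False)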


class MinkUNet18_HEADS(nn.Module):

    def __init__(self, seg_model):
        super().__init__()
        self.seg_model = seg_model

        # define self-sup heads
        self.proj_dim = self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion

        self.relu = ME.MinkowskiReLU(inplace=True)

        # projector head
        self.encoder = nn.Sequential(
            ME.MinkowskiConvolution(
                self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D),
            ME.MinkowskiBatchNorm(self.proj_dim),
            self.relu,
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D),
            ME.MinkowskiBatchNorm(self.proj_dim))

        # alternative projector head kept from development (disabled):
        # self.encoder = nn.Sequential(
        #     ME.MinkowskiConvolution(
        #         self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion,
        #         self.proj_dim,
        #         kernel_size=1,
        #         bias=True,
        #         dimension=self.seg_model.D),
        #     ME.MinkowskiBatchNorm(self.proj_dim),
        #     self.relu,
        #     ME.MinkowskiConvolution(
        #         self.proj_dim,
        #         self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion,
        #         kernel_size=1,
        #         bias=True,
        #         dimension=self.seg_model.D))

        # predictor head
        self.predictor = nn.Sequential(
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D),
            ME.MinkowskiBatchNorm(self.proj_dim),
            self.relu,
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D))

    def _forward_heads(self, x):
        out_seg = self.seg_model.final(x)
        out_en = self.encoder(x)
        out_pred = self.predictor(out_en)
        return out_seg, out_en, out_pred

    def forward(self, x, is_train=True):
        if is_train:
            x0, x1 = x

            # future: t0 -> t1
            out_backbone0, out_bottle0 = self.seg_model(x0, is_seg=False)

            # past: t1 -> t0
            out_backbone1, out_bottle1 = self.seg_model(x1, is_seg=False)

            # future pred
            out_seg0, out_en0, out_pred0 = self._forward_heads(out_backbone0)
            # past pred
            out_seg1, out_en1, out_pred1 = self._forward_heads(out_backbone1)

            return out_seg0.F, out_en0.F, out_pred0.F, out_backbone0.F, out_bottle0, \
                out_seg1.F, out_en1.F, out_pred1.F, out_backbone1.F, out_bottle1
        else:
            # forward in backbone
            out_backbone, out_bottle = self.seg_model(x, is_seg=False)

            # forward in final
            out_seg = self.seg_model.final(out_backbone)

            return out_seg.F, out_backbone.F, out_bottle
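

# --- Editor's note: hedged sketch, not part of the original file ---
# MinkUNet18_HEADS attaches fresh projector/predictor heads to an existing
# segmentation network instead of building its own backbone. The wrapped
# seg_model is assumed to expose PLANES, BLOCK, D, final, and a
# forward(x, is_seg=False) returning (backbone_features, bottleneck) --
# presumably the interface of the MinkUNet variants in models/minkunet.py,
# not of the MinkUNetBase defined above.
#
#   seg_model = ...  # e.g. a source-pretrained segmentation checkpoint
#   ssl_model = MinkUNet18_HEADS(seg_model)
#   outs = ssl_model((x0, x1), is_train=True)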


class MinkUNet18_BYOL(nn.Module):

    def __init__(self, seg_model):
        super().__init__()
        self.seg_model = seg_model

        # define self-sup heads
        self.proj_dim = self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion

        self.relu = ME.MinkowskiReLU(inplace=True)

        # projector head
        self.encoder = nn.Sequential(
            ME.MinkowskiConvolution(
                self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D),
            ME.MinkowskiBatchNorm(self.proj_dim),
            self.relu,
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D),
            ME.MinkowskiBatchNorm(self.proj_dim))

        # alternative projector head kept from development (disabled):
        # self.encoder = nn.Sequential(
        #     ME.MinkowskiConvolution(
        #         self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion,
        #         self.proj_dim,
        #         kernel_size=1,
        #         bias=True,
        #         dimension=self.seg_model.D),
        #     ME.MinkowskiBatchNorm(self.proj_dim),
        #     self.relu,
        #     ME.MinkowskiConvolution(
        #         self.proj_dim,
        #         self.seg_model.PLANES[7] * self.seg_model.BLOCK.expansion,
        #         kernel_size=1,
        #         bias=True,
        #         dimension=self.seg_model.D))

        # predictor head
        self.predictor = nn.Sequential(
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D),
            ME.MinkowskiBatchNorm(self.proj_dim),
            self.relu,
            ME.MinkowskiConvolution(
                self.proj_dim,
                self.proj_dim,
                kernel_size=1,
                bias=True,
                dimension=self.seg_model.D))

    def _forward_heads(self, x):
        out_seg = self.seg_model.final(x)
        out_en = self.encoder(x)
        out_pred = self.predictor(out_en)
        return out_seg, out_en, out_pred

    def forward(self, x, is_train=True, momentum=True):
        if is_train:
            # pick the view for this branch: the momentum branch consumes the
            # first element of the pair, the online branch the second
            if momentum:
                x0, _ = x
            else:
                _, x0 = x

            out_backbone0, out_bottle0 = self.seg_model(x0, is_seg=False)

            out_seg0, out_en0, out_pred0 = self._forward_heads(out_backbone0)

            return out_seg0.F, out_en0.F, out_pred0.F, out_backbone0.F, out_bottle0
        else:
            # forward in backbone
            out_backbone, out_bottle = self.seg_model(x, is_seg=False)

            # forward in final
            out_seg = self.seg_model.final(out_backbone)

            return out_seg.F, out_backbone.F, out_bottle


class MinkUNet18_MCMC(nn.Module):

    def __init__(self, seg_model, p_drop=0.5):
        super().__init__()
        self.seg_model = seg_model

        self.dropout = ME.MinkowskiDropout(p=p_drop)

    def forward(self, x, is_train=True):
        # forward in backbone, with dropout applied to the backbone features
        out_backbone, out_bottle = self.seg_model(x, is_seg=False)
        out_backbone = self.dropout(out_backbone)
        # forward in final
        out_seg = self.seg_model.final(out_backbone)

        return out_seg.F, out_backbone.F, out_bottle
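

# --- Editor's note: hedged sketch, not part of the original file ---
# MinkUNet18_MCMC applies dropout to the backbone features on every forward
# pass, which supports Monte Carlo dropout uncertainty estimates: keep the
# dropout module in training mode, run several stochastic passes, and
# aggregate the softmaxed logits. mc_dropout_predict below is an
# illustration, not part of the repository; the pass count is arbitrary.
#
#   import torch
#
#   @torch.no_grad()
#   def mc_dropout_predict(mcmc_model, x, n_passes=10):
#       mcmc_model.dropout.train()  # keep dropout stochastic at inference
#       probs = []
#       for _ in range(n_passes):
#           logits, _, _ = mcmc_model(x, is_train=False)
#           probs.append(torch.softmax(logits, dim=-1))
#       probs = torch.stack(probs, dim=0)
#       return probs.mean(0), probs.std(0)  # per-point mean and spread
--------------------------------------------------------------------------------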