├── .gitignore ├── LICENSE ├── Readme.md ├── config ├── Readme.md ├── callbacks │ └── default.yaml ├── data_augmentation │ ├── crop.yaml │ ├── only_norm.yaml │ ├── only_tensor.yaml │ ├── randaugment_crop_flip.yaml │ ├── randaugment_crop_hflip.yaml │ ├── randaugment_flip.yaml │ ├── randaugment_hflip.yaml │ ├── randaugment_light_flip.yaml │ ├── randaugment_light_hflip.yaml │ ├── randaugment_nonorm_flip.yaml │ ├── randaugment_nonorm_hflip.yaml │ ├── randaugment_scale_crop_flip.yaml │ ├── randaugment_scale_crop_hflip.yaml │ ├── scale_crop_VOC2010.yaml │ ├── scale_crop_flip.yaml │ └── scale_crop_hflip.yaml ├── dataset │ ├── Cityscapes.yaml │ ├── Cityscapes_coarse.yaml │ ├── Cityscapes_fine_coarse.yaml │ ├── PennFudan.yaml │ ├── VOC2010_Context.yaml │ └── VOC2010_Context_60.yaml ├── environment │ ├── cluster.yaml │ └── local.yaml ├── experiment │ ├── Cityscapes.yaml │ ├── Cityscapes_coarse.yaml │ ├── Cityscapes_fine_coarse.yaml │ ├── PennFudan.yaml │ ├── VOC2010_Context.yaml │ ├── VOC2010_Context_60.yaml │ └── default.yaml ├── logger │ └── tensorboard.yaml ├── lr_scheduler │ ├── polynomial.yaml │ ├── polynomial_epoch.yaml │ ├── polynomial_epoch_warmup.yaml │ └── polynomial_warmup.yaml ├── metric │ ├── MAP.yaml │ ├── mean_Dice.yaml │ ├── mean_Dice_Class.yaml │ ├── mean_IoU.yaml │ └── mean_IoU_Class.yaml ├── model │ ├── DeepLabv3.yaml │ ├── FCN.yaml │ ├── Mask_RCNN.yaml │ ├── Mask_RCNN_RMI.yaml │ ├── UNet.yaml │ ├── hrnet.yaml │ ├── hrnet_ocr.yaml │ ├── hrnet_ocr_aspp.yaml │ └── hrnet_ocr_ms.yaml ├── optimizer │ ├── ADAMW.yaml │ ├── MADGRAD.yaml │ └── SGD.yaml ├── testing.yaml ├── trainer │ ├── InstSeg.yaml │ └── SemSeg.yaml └── training.yaml ├── datasets ├── Cityscapes │ ├── Cityscapes.py │ ├── Cityscapes_coarse.py │ ├── Cityscapes_fine_coarse.py │ ├── process_Cityscapes.py │ └── process_Cityscapes_coarse.py ├── DataModules.py ├── PennFudan │ └── PennFudan.py └── VOC2010_Context │ ├── VOC2010_Context.py │ └── process_VOC2010_Context.py ├── imgs ├── Data.png ├── 
Epochs_Batch_Size.png ├── Further.png ├── Logos │ ├── DKFZ_Logo.png │ ├── HI_Logo.png │ └── HI_Title.png ├── Lossfunctions.png ├── Mixed_Precision.png ├── Models_Basic.png ├── RMI_Loss.png ├── Time_Complexity.png └── VOC2010.png ├── models ├── DeepLabv3.py ├── FCN.py ├── Mask_RCNN.py ├── Mask_RCNN_RMI_loss.py ├── UNet.py ├── backbones │ └── hrnet_backbone.py ├── hrnet.py ├── hrnet_ocr.py ├── hrnet_ocr_aspp.py ├── hrnet_ocr_ms.py └── model_ensemble.py ├── pretrained └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── src ├── augmentations.py ├── callbacks.py ├── loss │ ├── DC_CE_Loss.py │ ├── Dice_Loss.py │ ├── rmi.py │ └── rmi_utils.py ├── loss_function.py ├── lr_scheduler.py ├── metric.py ├── utils.py └── visualization.py ├── testing.py ├── tools ├── Readme.md ├── dataset_stats.py ├── lr_finder.py ├── predict.py ├── show_data.py └── show_prediction.py ├── trainers ├── Instance_Segmentation_Trainer.py └── Semantic_Segmentation_Trainer.py └── training.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | *.jpg 3 | *.sh 4 | *.ckpt 5 | *.pth 6 | *__pycache__* 7 | logs/ 8 | dataset_stats/ 9 | lightning_logs/ 10 | .idea/ 11 | stuff.py 12 | config/environment/local_2.yaml 13 | config/environment/cluster_2.yaml 14 | -------------------------------------------------------------------------------- /config/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | CALLBACKS: 3 | time_callback: 4 | _target_: src.callbacks.TimeCallback 5 | lr_monitor: 6 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 7 | logging_interval: 'step' 8 | tqdm_progressbar: 9 | #_target_: pytorch_lightning.callbacks.progress.TQDMProgressBar 10 | _target_: src.callbacks.customTQDMProgressBar 11 | refresh_rate: 1 12 | model_summary: 13 | _target_: pytorch_lightning.callbacks.ModelSummary 14 | max_depth: 2 
-------------------------------------------------------------------------------- /config/data_augmentation/crop.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | TRAIN: 8 | - Compose: 9 | transforms: 10 | - RandomCrop: 11 | height: ${AUGMENTATIONS.crop_size[0]} 12 | width: ${AUGMENTATIONS.crop_size[1]} 13 | - Normalize: 14 | mean: ${AUGMENTATIONS.mean} 15 | std: ${AUGMENTATIONS.std} 16 | - ToTensorV2: 17 | VALIDATION: 18 | - Compose: 19 | transforms: 20 | - Normalize: 21 | mean: ${AUGMENTATIONS.mean} 22 | std: ${AUGMENTATIONS.std} 23 | - ToTensorV2: 24 | TEST: ${AUGMENTATIONS.VALIDATION} 25 | -------------------------------------------------------------------------------- /config/data_augmentation/only_norm.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | mean: [ 0.485, 0.456, 0.406 ] 4 | std: [ 0.229, 0.224, 0.225 ] 5 | TRAIN: 6 | - Compose: 7 | transforms: 8 | - Normalize: 9 | mean: ${AUGMENTATIONS.mean} 10 | std: ${AUGMENTATIONS.std} 11 | - ToTensorV2: 12 | VALIDATION: 13 | - Compose: 14 | transforms: 15 | - Normalize: 16 | mean: ${AUGMENTATIONS.mean} 17 | std: ${AUGMENTATIONS.std} 18 | - ToTensorV2: 19 | TEST: ${AUGMENTATIONS.VALIDATION} 20 | -------------------------------------------------------------------------------- /config/data_augmentation/only_tensor.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | TRAIN: 4 | - Compose: 5 | transforms: 6 | - ToTensorV2: 7 | VALIDATION: 8 | - Compose: 9 | transforms: 10 | - ToTensorV2: 11 | TEST: ${AUGMENTATIONS.VALIDATION} 12 | -------------------------------------------------------------------------------- 
/config/data_augmentation/randaugment_crop_flip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 3 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandomCrop: 13 | height: ${AUGMENTATIONS.crop_size[0]} 14 | width: ${AUGMENTATIONS.crop_size[1]} 15 | - RandAugment: 16 | N: ${AUGMENTATIONS.N} 17 | M: ${AUGMENTATIONS.M} 18 | mode: 2 19 | p: 0.5 20 | - VerticalFlip: 21 | p: 0.25 22 | - HorizontalFlip: 23 | p: 0.25 24 | - Normalize: 25 | mean: ${AUGMENTATIONS.mean} 26 | std: ${AUGMENTATIONS.std} 27 | - ToTensorV2: 28 | VALIDATION: 29 | - Compose: 30 | transforms: 31 | - Normalize: 32 | mean: ${AUGMENTATIONS.mean} 33 | std: ${AUGMENTATIONS.std} 34 | - ToTensorV2: 35 | TEST: ${AUGMENTATIONS.VALIDATION} 36 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_crop_hflip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 3 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandomCrop: 13 | height: ${AUGMENTATIONS.crop_size[0]} 14 | width: ${AUGMENTATIONS.crop_size[1]} 15 | - RandAugment: 16 | N: ${AUGMENTATIONS.N} 17 | M: ${AUGMENTATIONS.M} 18 | mode: 2 19 | p: 0.5 20 | - HorizontalFlip: 21 | p: 0.5 22 | - Normalize: 23 | mean: ${AUGMENTATIONS.mean} 24 | std: ${AUGMENTATIONS.std} 25 | - ToTensorV2: 26 | VALIDATION: 27 | - Compose: 28 | transforms: 29 | - Normalize: 30 | mean: ${AUGMENTATIONS.mean} 31 | std: ${AUGMENTATIONS.std} 32 | - ToTensorV2: 33 | TEST: ${AUGMENTATIONS.VALIDATION} 34 | -------------------------------------------------------------------------------- 
/config/data_augmentation/randaugment_flip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 3 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandAugment: 13 | N: 3 14 | M: 3 15 | mode: 2 16 | p: 0.5 17 | - VerticalFlip: 18 | p: 0.25 19 | - HorizontalFlip: 20 | p: 0.25 21 | - Normalize: 22 | mean: ${AUGMENTATIONS.mean} 23 | std: ${AUGMENTATIONS.std} 24 | - ToTensorV2: 25 | VALIDATION: 26 | - Compose: 27 | transforms: 28 | - Normalize: 29 | mean: ${AUGMENTATIONS.mean} 30 | std: ${AUGMENTATIONS.std} 31 | - ToTensorV2: 32 | TEST: ${AUGMENTATIONS.VALIDATION} 33 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_hflip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 3 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandAugment: 13 | N: 3 14 | M: 3 15 | mode: 2 16 | p: 0.5 17 | - HorizontalFlip: 18 | p: 0.5 19 | - Normalize: 20 | mean: ${AUGMENTATIONS.mean} 21 | std: ${AUGMENTATIONS.std} 22 | - ToTensorV2: 23 | VALIDATION: 24 | - Compose: 25 | transforms: 26 | - Normalize: 27 | mean: ${AUGMENTATIONS.mean} 28 | std: ${AUGMENTATIONS.std} 29 | - ToTensorV2: 30 | TEST: ${AUGMENTATIONS.VALIDATION} 31 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_light_flip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 
] 7 | N: 3 8 | M: 5 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandAugment_light: 13 | N: ${AUGMENTATIONS.N} 14 | M: ${AUGMENTATIONS.M} 15 | mode: 2 16 | p: 0.5 17 | - VerticalFlip: 18 | p: 0.25 19 | - HorizontalFlip: 20 | p: 0.25 21 | - Normalize: 22 | mean: ${AUGMENTATIONS.mean} 23 | std: ${AUGMENTATIONS.std} 24 | - ToTensorV2: 25 | VALIDATION: 26 | - Compose: 27 | transforms: 28 | - Normalize: 29 | mean: ${AUGMENTATIONS.mean} 30 | std: ${AUGMENTATIONS.std} 31 | - ToTensorV2: 32 | TEST: ${AUGMENTATIONS.VALIDATION} 33 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_light_hflip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 5 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandAugment_light: 13 | N: ${AUGMENTATIONS.N} 14 | M: ${AUGMENTATIONS.M} 15 | mode: 2 16 | p: 0.5 17 | - HorizontalFlip: 18 | p: 0.5 19 | - Normalize: 20 | mean: ${AUGMENTATIONS.mean} 21 | std: ${AUGMENTATIONS.std} 22 | - ToTensorV2: 23 | VALIDATION: 24 | - Compose: 25 | transforms: 26 | - Normalize: 27 | mean: ${AUGMENTATIONS.mean} 28 | std: ${AUGMENTATIONS.std} 29 | - ToTensorV2: 30 | TEST: ${AUGMENTATIONS.VALIDATION} 31 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_nonorm_flip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | N: 3 4 | M: 3 5 | TRAIN: 6 | - Compose: 7 | transforms: 8 | - RandAugment: 9 | N: ${AUGMENTATIONS.N} 10 | M: ${AUGMENTATIONS.M} 11 | mode: 2 12 | p: 0.5 13 | - VerticalFlip: 14 | p: 0.25 15 | - HorizontalFlip: 16 | p: 0.25 17 | - ToTensorV2: 18 | VALIDATION: 19 | - Compose: 20 | transforms: 21 | - ToTensorV2: 22 | 
TEST: ${AUGMENTATIONS.VALIDATION} 23 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_nonorm_hflip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | N: 3 4 | M: 3 5 | TRAIN: 6 | - Compose: 7 | transforms: 8 | - RandAugment: 9 | N: ${AUGMENTATIONS.N} 10 | M: ${AUGMENTATIONS.M} 11 | mode: 2 12 | p: 0.5 13 | - HorizontalFlip: 14 | p: 0.5 15 | - ToTensorV2: 16 | VALIDATION: 17 | - Compose: 18 | transforms: 19 | - ToTensorV2: 20 | TEST: ${AUGMENTATIONS.VALIDATION} 21 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_scale_crop_flip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 3 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandomScale: 13 | scale_limit: ${AUGMENTATIONS.scale_limit} 14 | p: 1.0 15 | - RandomCrop: 16 | height: ${AUGMENTATIONS.crop_size[0]} 17 | width: ${AUGMENTATIONS.crop_size[1]} 18 | - RandAugment: 19 | N: ${AUGMENTATIONS.N} 20 | M: ${AUGMENTATIONS.M} 21 | mode: 2 22 | p: 0.5 23 | - VerticalFlip: 24 | p: 0.25 25 | - HorizontalFlip: 26 | p: 0.25 27 | - Normalize: 28 | mean: ${AUGMENTATIONS.mean} 29 | std: ${AUGMENTATIONS.std} 30 | - ToTensorV2: 31 | VALIDATION: 32 | - Compose: 33 | transforms: 34 | - Normalize: 35 | mean: ${AUGMENTATIONS.mean} 36 | std: ${AUGMENTATIONS.std} 37 | - ToTensorV2: 38 | TEST: ${AUGMENTATIONS.VALIDATION} 39 | -------------------------------------------------------------------------------- /config/data_augmentation/randaugment_scale_crop_hflip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | 
scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | N: 3 8 | M: 3 9 | TRAIN: 10 | - Compose: 11 | transforms: 12 | - RandomScale: 13 | scale_limit: ${AUGMENTATIONS.scale_limit} 14 | p: 1.0 15 | - RandomCrop: 16 | height: ${AUGMENTATIONS.crop_size[0]} 17 | width: ${AUGMENTATIONS.crop_size[1]} 18 | - RandAugment: 19 | N: ${AUGMENTATIONS.N} 20 | M: ${AUGMENTATIONS.M} 21 | mode: 2 22 | p: 0.5 23 | - HorizontalFlip: 24 | p: 0.5 25 | - Normalize: 26 | mean: ${AUGMENTATIONS.mean} 27 | std: ${AUGMENTATIONS.std} 28 | - ToTensorV2: 29 | VALIDATION: 30 | - Compose: 31 | transforms: 32 | - Normalize: 33 | mean: ${AUGMENTATIONS.mean} 34 | std: ${AUGMENTATIONS.std} 35 | - ToTensorV2: 36 | TEST: ${AUGMENTATIONS.VALIDATION} 37 | -------------------------------------------------------------------------------- /config/data_augmentation/scale_crop_VOC2010.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5,1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | TRAIN: 8 | - Compose: 9 | transforms: 10 | - SmallestMaxSize: 11 | max_size: ${AUGMENTATIONS.crop_size[0]} 12 | - RandomScale: 13 | scale_limit: ${AUGMENTATIONS.scale_limit} 14 | p: 1.0 15 | - RGBShift: 16 | r_shift_limit: 10 17 | g_shift_limit: 10 18 | b_shift_limit: 10 19 | - PadIfNeeded: 20 | min_height: ${AUGMENTATIONS.crop_size[0]} 21 | min_width: ${AUGMENTATIONS.crop_size[1]} 22 | border_mode: 0 #"cv2.BORDER_CONSTANT" 23 | value: 0 24 | mask_value: ${DATASET.IGNORE_INDEX} 25 | - RandomCrop: 26 | height: ${AUGMENTATIONS.crop_size[0]} 27 | width: ${AUGMENTATIONS.crop_size[1]} 28 | - HorizontalFlip: 29 | p: 0.5 30 | - Normalize: 31 | mean: ${AUGMENTATIONS.mean} 32 | std: ${AUGMENTATIONS.std} 33 | - ToTensorV2: 34 | VALIDATION: 35 | - Compose: 36 | transforms: 37 | - LongestMaxSize: 38 | max_size: 
${AUGMENTATIONS.crop_size[0]} 39 | - PadIfNeeded: 40 | min_height: ${AUGMENTATIONS.crop_size[0]} 41 | min_width: ${AUGMENTATIONS.crop_size[1]} 42 | border_mode: 0 #"cv2.BORDER_CONSTANT" 43 | value: 0 44 | mask_value: ${DATASET.IGNORE_INDEX} 45 | - Normalize: 46 | mean: ${AUGMENTATIONS.mean} 47 | std: ${AUGMENTATIONS.std} 48 | - ToTensorV2: 49 | TEST: 50 | - Compose: 51 | transforms: 52 | - Normalize: 53 | mean: ${AUGMENTATIONS.mean} 54 | std: ${AUGMENTATIONS.std} 55 | - ToTensorV2: 56 | -------------------------------------------------------------------------------- /config/data_augmentation/scale_crop_flip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | TRAIN: 8 | - Compose: 9 | transforms: 10 | - RandomScale: 11 | scale_limit: ${AUGMENTATIONS.scale_limit} 12 | p: 1.0 13 | - RandomCrop: 14 | height: ${AUGMENTATIONS.crop_size[0]} 15 | width: ${AUGMENTATIONS.crop_size[1]} 16 | - HorizontalFlip: 17 | p: 0.25 18 | - VerticalFlip: 19 | p: 0.25 20 | - Normalize: 21 | mean: ${AUGMENTATIONS.mean} 22 | std: ${AUGMENTATIONS.std} 23 | - ToTensorV2: 24 | VALIDATION: 25 | - Compose: 26 | transforms: 27 | - Normalize: 28 | mean: ${AUGMENTATIONS.mean} 29 | std: ${AUGMENTATIONS.std} 30 | - ToTensorV2: 31 | TEST: ${AUGMENTATIONS.VALIDATION} 32 | -------------------------------------------------------------------------------- /config/data_augmentation/scale_crop_hflip.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | AUGMENTATIONS: 3 | scale_limit: [-0.5, 1.0] 4 | crop_size: [512, 512] 5 | mean: [ 0.485, 0.456, 0.406 ] 6 | std: [ 0.229, 0.224, 0.225 ] 7 | TRAIN: 8 | - Compose: 9 | transforms: 10 | - RandomScale: 11 | scale_limit: ${AUGMENTATIONS.scale_limit} 12 | p: 1.0 13 | - RandomCrop: 14 | height: 
${AUGMENTATIONS.crop_size[0]} 15 | width: ${AUGMENTATIONS.crop_size[1]} 16 | - HorizontalFlip: 17 | p: 0.5 18 | - Normalize: 19 | mean: ${AUGMENTATIONS.mean} 20 | std: ${AUGMENTATIONS.std} 21 | - ToTensorV2: 22 | VALIDATION: 23 | - Compose: 24 | transforms: 25 | - Normalize: 26 | mean: ${AUGMENTATIONS.mean} 27 | std: ${AUGMENTATIONS.std} 28 | - ToTensorV2: 29 | TEST: ${AUGMENTATIONS.VALIDATION} 30 | -------------------------------------------------------------------------------- /config/dataset/Cityscapes.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # define and configure the dataset class which is called by hydra 4 | dataset: 5 | _target_: datasets.Cityscapes.Cityscapes.Cityscapes_dataset 6 | root: ${paths.Cityscapes} 7 | 8 | # additional information about the dataset 9 | DATASET: 10 | ##NEEDED PARAMETERS 11 | NAME: "Cityscapes" 12 | NUM_CLASSES: 19 13 | IGNORE_INDEX: 255 14 | #INFORMATION ABOUT DS, NOT NEEDED 15 | SIZE: 16 | TRAIN: 2975 17 | VAL: 500 18 | TEST: 1525 19 | #NEEDED IF WEIGHTED LOSSFUNCTIONS ARE USED 20 | CLASS_WEIGHTS: [ 0.8373, 0.918, 0.866, 1.0345,1.0166, 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529, 1.0507 ] 21 | #OPTIONAL - JUST IF CLASS LOGGING RESULTS ARE WANTED 22 | CLASS_LABELS: 23 | - road 24 | - sidewalk 25 | - building 26 | - wall 27 | - fence 28 | - pole 29 | - traffic light 30 | - traffic sign 31 | - vegetation 32 | - terrain 33 | - sky 34 | - person 35 | - rider 36 | - car 37 | - truck 38 | - bus 39 | - train 40 | - motorcycle 41 | - bicycle 42 | -------------------------------------------------------------------------------- /config/dataset/Cityscapes_coarse.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # define and configure the dataset class which is called by hydra 4 | dataset: 5 | _target_: 
datasets.Cityscapes.Cityscapes_coarse.Cityscapes_coarse_dataset 6 | root: ${paths.Cityscapes} 7 | 8 | # additional information about the dataset 9 | DATASET: 10 | NAME: "Cityscapes_coarse" 11 | NUM_CLASSES: 19 12 | IGNORE_INDEX: 255 13 | #NEEDED IF WEIGHTED LOSSFUNCTIONS ARE USED 14 | CLASS_WEIGHTS: [ 0.8373, 0.918, 0.866, 1.0345,1.0166, 15 | 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 16 | 0.9539, 0.9843,1.1116, 0.9037, 1.0865, 17 | 1.0955, 1.0865, 1.1529, 1.0507 ] 18 | #INFORMATION ABOUT DS, NOT NEEDED 19 | SIZE: 20 | TRAIN: 19997 21 | VAL: 500 22 | #OPTIONAL - JUST IF CLASS LOGGING RESULTS ARE WANTED 23 | CLASS_LABELS: 24 | - road 25 | - sidewalk 26 | - building 27 | - wall 28 | - fence 29 | - pole 30 | - traffic light 31 | - traffic sign 32 | - vegetation 33 | - terrain 34 | - sky 35 | - person 36 | - rider 37 | - car 38 | - truck 39 | - bus 40 | - train 41 | - motorcycle 42 | - bicycle 43 | -------------------------------------------------------------------------------- /config/dataset/Cityscapes_fine_coarse.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # define and configure the dataset class which is called by hydra 4 | dataset: 5 | _target_: datasets.Cityscapes.Cityscapes_fine_coarse.Cityscapes_fine_coarse_dataset 6 | root: ${paths.Cityscapes} 7 | 8 | # additional information about the dataset 9 | DATASET: 10 | NAME: "Cityscapes_fine_coarse" 11 | NUM_CLASSES: 19 12 | IGNORE_INDEX: 255 13 | #NEEDED IF WEIGHTED LOSSFUNCTIONS ARE USED 14 | CLASS_WEIGHTS: [ 0.8373, 0.918, 0.866, 1.0345,1.0166, 15 | 0.9969, 0.9754, 1.0489, 0.8786, 1.0023, 16 | 0.9539, 0.9843,1.1116, 0.9037, 1.0865, 17 | 1.0955, 1.0865, 1.1529, 1.0507 ] 18 | #INFORMATION ABOUT DS, NOT NEEDED 19 | SIZE: 20 | TRAIN: 22972 21 | VAL: 500 22 | #OPTIONAL - JUST IF CLASS LOGGING RESULTS ARE WANTED 23 | CLASS_LABELS: 24 | - road 25 | - sidewalk 26 | - building 27 | - wall 28 | - fence 29 | - pole 30 | - traffic light 31 | - traffic sign 32
| - vegetation 33 | - terrain 34 | - sky 35 | - person 36 | - rider 37 | - car 38 | - truck 39 | - bus 40 | - train 41 | - motorcycle 42 | - bicycle 43 | -------------------------------------------------------------------------------- /config/dataset/PennFudan.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # define and configure the dataset class which is called by hydra 4 | dataset: 5 | _target_: datasets.PennFudan.PennFudan.PennFudanDataset 6 | root: ${paths.PennFudan} 7 | 8 | # additional information about the dataset 9 | DATASET: 10 | ##NEEDED PARAMETERS 11 | NAME: "PennFudan" 12 | NUM_CLASSES: 2 13 | 14 | -------------------------------------------------------------------------------- /config/dataset/VOC2010_Context.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # define and configure the dataset class which is called by hydra 4 | dataset: 5 | _target_: datasets.VOC2010_Context.VOC2010_Context.VOC2010_Context_dataset 6 | root: ${paths.VOC2010_Context} 7 | ignore_index: ${DATASET.IGNORE_INDEX} 8 | num_classes: ${DATASET.NUM_CLASSES} 9 | 10 | # additional information about the dataset 11 | DATASET: 12 | NAME: "VOC2010_Context" 13 | NUM_CLASSES: 59 14 | IGNORE_INDEX: 255 15 | #INFORMATION ABOUT DS, NOT NEEDED 16 | SIZE: 17 | TRAIN: 4998 18 | VAL: 5103 19 | TEST: 0 20 | #OPTIONAL - JUST IF CLASS LOGGING RESULTS ARE WANTED 21 | CLASS_LABELS: 22 | - aeroplane 23 | - bag 24 | - bed 25 | - bedclothes 26 | - bench 27 | - bicycle 28 | - bird 29 | - boat 30 | - book 31 | - bottle 32 | - building 33 | - bus 34 | - cabinet 35 | - car 36 | - cat 37 | - ceiling 38 | - chair 39 | - cloth 40 | - computer 41 | - cow 42 | - cup 43 | - curtain 44 | - dog 45 | - door 46 | - fence 47 | - floor 48 | - flower 49 | - food 50 | - grass 51 | - ground 52 | - horse 53 | - keyboard 54 | - light 55 | - motorbike 56 | - mountain 57 | - mouse 58 | - 
person 59 | - plate 60 | - platform 61 | - pottedplant 62 | - road 63 | - rock 64 | - sheep 65 | - shelves 66 | - sidewalk 67 | - sign 68 | - sky 69 | - snow 70 | - sofa 71 | - table 72 | - track 73 | - train 74 | - tree 75 | - truck 76 | - tvmonitor 77 | - wall 78 | - water 79 | - window 80 | - wood 81 | 82 | -------------------------------------------------------------------------------- /config/dataset/VOC2010_Context_60.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # define and configure the dataset class which is called by hydra 4 | dataset: 5 | _target_: datasets.VOC2010_Context.VOC2010_Context.VOC2010_Context_dataset 6 | root: ${paths.VOC2010_Context} 7 | ignore_index: ${DATASET.IGNORE_INDEX} 8 | num_classes: ${DATASET.NUM_CLASSES} 9 | 10 | # additional information about the dataset 11 | DATASET: 12 | NAME: "VOC2010_Context_60" 13 | NUM_CLASSES: 60 14 | IGNORE_INDEX: 255 15 | #INFORMATION ABOUT DS, NOT NEEDED 16 | SIZE: 17 | TRAIN: 4998 18 | VAL: 5103 19 | TEST: 0 20 | #OPTIONAL - JUST IF CLASS LOGGING RESULTS ARE WANTED 21 | CLASS_LABELS: 22 | - background 23 | - aeroplane 24 | - bag 25 | - bed 26 | - bedclothes 27 | - bench 28 | - bicycle 29 | - bird 30 | - boat 31 | - book 32 | - bottle 33 | - building 34 | - bus 35 | - cabinet 36 | - car 37 | - cat 38 | - ceiling 39 | - chair 40 | - cloth 41 | - computer 42 | - cow 43 | - cup 44 | - curtain 45 | - dog 46 | - door 47 | - fence 48 | - floor 49 | - flower 50 | - food 51 | - grass 52 | - ground 53 | - horse 54 | - keyboard 55 | - light 56 | - motorbike 57 | - mountain 58 | - mouse 59 | - person 60 | - plate 61 | - platform 62 | - pottedplant 63 | - road 64 | - rock 65 | - sheep 66 | - shelves 67 | - sidewalk 68 | - sign 69 | - sky 70 | - snow 71 | - sofa 72 | - table 73 | - track 74 | - train 75 | - tree 76 | - truck 77 | - tvmonitor 78 | - wall 79 | - water 80 | - window 81 | - wood 82 | 83 | 
-------------------------------------------------------------------------------- /config/environment/cluster.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | CALLBACKS: 3 | tqdm_progressbar: ~ 4 | pl_trainer: 5 | enable_progress_bar: False 6 | enable_checkpointing: True 7 | 8 | LOGDIR: /dkfz/cluster/gpu/.../logs 9 | 10 | paths: 11 | Cityscapes: 12 | LABELS: /dkfz/cluster/gpu/.../cityscapes 13 | IMAGES: /dkfz/cluster/gpu/data/.../cityscapes 14 | VOC2010_Context: /dkfz/cluster/gpu/data/.../VOC2010_Context -------------------------------------------------------------------------------- /config/environment/local.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | #Path to Data of different Datasets 4 | paths: 5 | Cityscapes: /home/.../Datasets/cityscapes 6 | VOC2010_Context: /home/.../VOC2010_Context 7 | PennFudan: /home/.../Datasets/PennFudanPed 8 | -------------------------------------------------------------------------------- /config/experiment/Cityscapes.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Define which Dataset and which augmentation pipeline to use 4 | defaults: 5 | - override /data_augmentation: scale_crop_hflip 6 | - override /dataset: Cityscapes 7 | 8 | # Configure the augmentation pipeline 9 | AUGMENTATIONS: 10 | scale_limit: [-0.5, 1.0] 11 | crop_size: [512, 1024] 12 | mean: [ 0.485, 0.456, 0.406 ] 13 | std: [ 0.229, 0.224, 0.225 ] 14 | 15 | #Hyperparameters for Cityscapes Dataset 16 | batch_size: 6 # batch size per gpu for training 17 | val_batch_size: ${batch_size} # batch size per gpu for validation 18 | epochs: 400 # number of eposchs 19 | lr: 0.01 # learning rate for training (0.01445439770745928) 20 | momentum: 0.9 # momentum for optimizer 21 | weight_decay: 0.0005 # wd for optimizer 22 | lossfunction: [ "wCE", "wCE", "wCE", "wCE"] # list of 
lossfunctions 23 | lossweight: [1.0, 0.4, 0.05, 0.05] # corresponding weights for each loss function 24 | -------------------------------------------------------------------------------- /config/experiment/Cityscapes_coarse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Define which Dataset and which augmentation pipeline to use 4 | defaults: 5 | - override /data_augmentation: scale_crop_hflip 6 | - override /dataset: Cityscapes_coarse 7 | 8 | # Configure the augmentation pipeline 9 | AUGMENTATIONS: 10 | scale_limit: [-0.5, 1.0] 11 | crop_size: [512, 1024] 12 | mean: [ 0.485, 0.456, 0.406 ] 13 | std: [ 0.229, 0.224, 0.225 ] 14 | 15 | #Hyperparameters for Cityscapes Dataset 16 | batch_size: 6 # batch size per gpu for training 17 | val_batch_size: ${batch_size} # batch size per gpu for validation 18 | epochs: 400 # number of eposchs 19 | lr: 0.01 # learning rate for training (0.01445439770745928) 20 | momentum: 0.9 # momentum for optimizer 21 | weight_decay: 0.0005 # wd for optimizer 22 | lossfunction: [ "wCE", "wCE", "wCE", "wCE"] # list of lossfunctions 23 | lossweight: [1.0, 0.4, 0.05, 0.05] # corresponding weights for each loss function -------------------------------------------------------------------------------- /config/experiment/Cityscapes_fine_coarse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Define which Dataset and which augmentation pipeline to use 4 | defaults: 5 | - override /data_augmentation: scale_crop_hflip 6 | - override /dataset: Cityscapes_fine_coarse 7 | 8 | # Configure the augmentation pipeline 9 | AUGMENTATIONS: 10 | scale_limit: [-0.5, 1.0] 11 | crop_size: [512, 1024] 12 | mean: [ 0.485, 0.456, 0.406 ] 13 | std: [ 0.229, 0.224, 0.225 ] 14 | 15 | #Hyperparameters for Cityscapes Dataset 16 | batch_size: 6 # batch size per gpu for training 17 | val_batch_size: ${batch_size} # batch size per gpu 
for validation 18 | epochs: 400 # number of eposchs 19 | lr: 0.01 # learning rate for training (0.01445439770745928) 20 | momentum: 0.9 # momentum for optimizer 21 | weight_decay: 0.0005 # wd for optimizer 22 | lossfunction: [ "wCE", "wCE", "wCE", "wCE"] # list of lossfunctions 23 | lossweight: [1.0, 0.4, 0.05, 0.05] # corresponding weights for each loss function -------------------------------------------------------------------------------- /config/experiment/PennFudan.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | defaults: 3 | - override /data_augmentation: only_tensor 4 | - override /metric: MAP 5 | - override /trainer: InstSeg 6 | - override /dataset: PennFudan 7 | - override /model: Mask_RCNN 8 | 9 | AUGMENTATIONS: 10 | scale_limit: [-0.5, 1.0] 11 | crop_size: [512, 1024] 12 | mean: [ 0.485, 0.456, 0.406 ] 13 | std: [ 0.229, 0.224, 0.225 ] 14 | 15 | model: 16 | version: v1 17 | disable_transforms: False 18 | 19 | #Hyperparameters for the Dataset 20 | epochs: 200 # number of epochs 21 | lr: 0.005 # learning rate for training (0.01445439770745928) 22 | momentum: 0.9 # momentum for optimizer 23 | weight_decay: 0.0005 # wd for optimizer 24 | 25 | # how many example predictions should be logged 26 | num_example_predictions: 3 27 | -------------------------------------------------------------------------------- /config/experiment/VOC2010_Context.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # Define which Dataset and which augmentation pipeline to use 4 | defaults: 5 | - override /dataset: VOC2010_Context 6 | - override /data_augmentation: scale_crop_VOC2010 7 | 8 | # Configure the augmentation pipeline 9 | AUGMENTATIONS: 10 | scale_limit: [-0.5, 1.0] 11 | crop_size: [520, 520] 12 | mean: [ 0.485, 0.456, 0.406 ] 13 | std: [ 0.229, 0.224, 0.225 ] 14 | 15 | # Define the testing behaviour 16 | TESTING: 17 | SCALES: [ 0.5, 0.75, 
1.0, 1.25, 1.5, 1.75, 2.0 ] 18 | FLIP: False 19 | BINARY_FLIP: True 20 | OVERRIDES: 21 | - val_batch_size=1 22 | 23 | #Hyperparameters for VOC2010 Context Dataset 24 | batch_size: 8 # batch size per gpu for training 25 | val_batch_size: ${batch_size} # batch size per gpu for validation 26 | epochs: 200 # number of epochs 27 | lr: 0.004 # learning rate for training 28 | momentum: 0.9 # momentum for optimizer 29 | weight_decay: 0.0001 # wd for optimizer 30 | lossfunction: [ "CE", "CE", "CE", "CE"] # list of lossfunctions 31 | lossweight: [1.0, 0.4, 0.05, 0.05] # corresponding weights for each loss function 32 | -------------------------------------------------------------------------------- /config/experiment/VOC2010_Context_60.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # Define which Dataset and which augmentation pipeline to use 4 | defaults: 5 | - override /dataset: VOC2010_Context_60 6 | - override /data_augmentation: scale_crop_VOC2010 7 | 8 | # Configure the augmentation pipeline 9 | AUGMENTATIONS: 10 | scale_limit: [-0.5, 1.0] 11 | crop_size: [520, 520] 12 | mean: [ 0.485, 0.456, 0.406 ] 13 | std: [ 0.229, 0.224, 0.225 ] 14 | 15 | # Define the testing behaviour 16 | TESTING: 17 | SCALES: [ 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0 ] 18 | FLIP: False 19 | BINARY_FLIP: True 20 | OVERRIDES: 21 | - val_batch_size=1 22 | 23 | #Hyperparameters for VOC2010 Context Dataset 24 | batch_size: 8 # batch size per gpu for training 25 | val_batch_size: ${batch_size} # batch size per gpu for validation 26 | epochs: 200 # number of epochs 27 | lr: 0.004 # learning rate for training 28 | momentum: 0.9 # momentum for optimizer 29 | weight_decay: 0.0001 # wd for optimizer 30 | lossfunction: [ "CE", "CE", "CE", "CE"] # list of lossfunctions 31 | lossweight: [1.0, 0.4, 0.05, 0.05] # corresponding weights for each loss function 32 |
-------------------------------------------------------------------------------- /config/experiment/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Basic experiment for training. For some experiments (Cityscapes, VOC2010_Context they are 4 | # overwritten in config/experiment/*.yaml, so edit the parameters there if using these datasets 5 | num_workers: 10 # number of workers for dataloader 6 | batch_size: 6 # batch size per gpu for training 7 | val_batch_size: ${batch_size} # batch size per gpu for validation 8 | epochs: 400 # number of eposchs 9 | lr: 0.01 # learning rate for training (0.01445439770745928) 10 | momentum: 0.9 # momentum for optimizer 11 | weight_decay: 0.0005 # wd for optimizer 12 | lossfunction: [ "CE", "CE", "CE", "CE"] # list of lossfunctions 13 | lossweight: [1.0, 0.4, 0.05, 0.05] # corresponding weights for each loss function 14 | #seed: 1234 # seed everything 15 | pl_trainer: # parameters for the pytorch lightning trainers 16 | accelerator: 'gpu' # train on GPU 17 | devices: -1 # using all available GPUs 18 | max_epochs: ${epochs} # parsing the number of epochs which is defined as a hyperparameter 19 | precision: 16 # using Mixed Precision 20 | benchmark: True # using benchmark for faster training 21 | deterministic: False # deterministic does not work because of not deterministc CrossEntropyLoss Error - use "warn" instead 22 | enable_checkpointing: True # Enable/Disable Checkpointing 23 | # Some usefull pl parameters for debugging 24 | #limit_train_batches: 0.1 25 | #limit_val_batches: 0.5 26 | #limit_test_batches: 0.25 27 | #accumulate_grad_batches: 2 -------------------------------------------------------------------------------- /config/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | _target_ : pytorch_lightning.loggers.tensorboard.TensorBoardLogger 2 | save_dir: ${OUTPUT_DIR} 3 | name: "" 4 | 
version: "" 5 | default_hp_metric: False -------------------------------------------------------------------------------- /config/lr_scheduler/polynomial.yaml: -------------------------------------------------------------------------------- 1 | interval: step # when the scheduler should be called, at each step of each epoch 2 | frequency: 1 # how often should it be called 3 | monitor: metric_to_track # parameter for pytorch lightning to log the lr 4 | scheduler: #defining the scheduler class 5 | #_target_: src.lr_scheduler.polynomial_LR_scheduler_stepwise # path to the scheduler 6 | _target_: torch.optim.lr_scheduler.PolynomialLR # path to the scheduler 7 | power: 0.9 # arguments for the scheduler -------------------------------------------------------------------------------- /config/lr_scheduler/polynomial_epoch.yaml: -------------------------------------------------------------------------------- 1 | interval: epoch # when the scheduler should be called, at each step of each epoch 2 | frequency: 1 # how often should it be called 3 | monitor: metric_to_track # parameter for pytorch lightning to log the lr 4 | scheduler: #defining the scheduler class 5 | _target_: torch.optim.lr_scheduler.PolynomialLR # path to the scheduler 6 | max_epochs: ${epochs} 7 | power: 0.9 # arguments for the scheduler -------------------------------------------------------------------------------- /config/lr_scheduler/polynomial_epoch_warmup.yaml: -------------------------------------------------------------------------------- 1 | interval: epoch # when the scheduler should be called, at each step of each epoch 2 | frequency: 1 # how often should it be called 3 | monitor: metric_to_track # parameter for pytorch lightning to log the lr 4 | scheduler: #defining the scheduler class 5 | _target_: src.lr_scheduler.PolynomialLR_Warmstart # path to the scheduler 6 | power: 0.9 # arguments for the scheduler 7 | warmstart_iters: 0.0125 # number of warmup steps in % 
-------------------------------------------------------------------------------- /config/lr_scheduler/polynomial_warmup.yaml: -------------------------------------------------------------------------------- 1 | interval: step # when the scheduler should be called, at each step of each epoch 2 | frequency: 1 # how often should it be called 3 | monitor: metric_to_track # parameter for pytorch lightning to log the lr 4 | scheduler: #defining the scheduler class 5 | _target_: src.lr_scheduler.PolynomialLR_Warmstart # path to the scheduler 6 | power: 0.9 # arguments for the scheduler 7 | warmstart_iters: 0.0125 # number of warmup steps in % -------------------------------------------------------------------------------- /config/metric/MAP.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | METRIC: 3 | #NAME: map # which metric to optimize - mean Dice over the classes 4 | NAME: map # which metric to optimize - mean Dice over the classes 5 | train_metric: False # If also a train metric is wanted (in addition to a validation metric) 6 | call_global: True # If True metric is updated in each step and computed once at the end of the epoch 7 | call_stepwise: False # If True metric is computed in each step (usually one batch) and averaged over all steps - exclusively with call_per_img 8 | call_per_img: False # If True metric is computed for each image and averaged over all images - exclusively with call_stepwise 9 | METRICS: 10 | MAP: # Define the Metric 11 | _target_: torchmetrics.detection.mean_ap.MeanAveragePrecision # Metric Classs 12 | iou_type: segm 13 | -------------------------------------------------------------------------------- /config/metric/mean_Dice.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | METRIC: 3 | NAME: meanDice # which metric to optimize - mean Dice over the classes 4 | train_metric: False # If also a train metric is wanted 
(in addition to a validation metric) 5 | call_global: True # If True metric is updated in each step and computed once at the end of the epoch 6 | call_stepwise: False # If True metric is computed in each step (usually one batch) and averaged over all steps - exclusively with call_per_img 7 | call_per_img: False # If True metric is computed for each image and averaged over all images - exclusively with call_stepwise 8 | METRICS: 9 | meanDice: # Define the Metric 10 | _target_: src.metric.Dice # Metric Class 11 | num_classes: ${DATASET.NUM_CLASSES} # Number if Classes in the Dataset 12 | labels: ${DATASET.CLASS_LABELS} # Class Labels in the Dataset 13 | per_class: False # Return the mean Metric with or without the Metric for each Class -------------------------------------------------------------------------------- /config/metric/mean_Dice_Class.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | METRIC: 3 | NAME: meanDice # which metric to optimize - mean Dice over the classes 4 | train_metric: False # If also a train metric is wanted (in addition to a validation metric) 5 | call_global: True # If True metric is updated in each step and computed once at the end of the epoch 6 | call_stepwise: False # If True metric is computed in each step (usually one batch) and averaged over all steps - exclusively with call_per_img 7 | call_per_img: False # If True metric is computed for each image and averaged over all images - exclusively with call_stepwise 8 | METRICS: 9 | meanDice: # Define the Metric 10 | _target_: src.metric.Dice # Metric Class 11 | num_classes: ${DATASET.NUM_CLASSES} # Number if Classes in the Dataset 12 | labels: ${DATASET.CLASS_LABELS} # Class Labels in the Dataset 13 | per_class: True # Return the mean Metric with or without the Metric for each Class -------------------------------------------------------------------------------- /config/metric/mean_IoU.yaml: 
-------------------------------------------------------------------------------- 1 | #@package _global_ 2 | METRIC: 3 | NAME: meanIoU # which metric to optimize - mean IoU over the classes 4 | train_metric: False # If also a train metric is wanted (in addition to a validation metric) 5 | call_global: True # If True metric is updated in each step and computed once at the end of the epoch 6 | call_stepwise: False # If True metric is computed in each step (usually one batch) and averaged over all steps - exclusively with call_per_img 7 | call_per_img: False # If True metric is computed for each image and averaged over all images - exclusively with call_stepwise 8 | METRICS: 9 | meanIoU: # Define the Metric 10 | _target_: src.metric.IoU # Metric Class 11 | num_classes: ${DATASET.NUM_CLASSES} # Number of Classes in the Dataset 12 | labels: ${DATASET.CLASS_LABELS} # Class Labels in the Dataset 13 | per_class: False # Return the mean Metric with or without the Metric for each Class 14 | 15 | 16 | -------------------------------------------------------------------------------- /config/metric/mean_IoU_Class.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | METRIC: 3 | NAME: meanIoU # which metric to optimize - mean IoU over the classes 4 | train_metric: False # If also a train metric is wanted (in addition to a validation metric) 5 | call_global: True # If True metric is updated in each step and computed once at the end of the epoch 6 | call_stepwise: False # If True metric is computed in each step (usually one batch) and averaged over all steps - exclusively with call_per_img 7 | call_per_img: False # If True metric is computed for each image and averaged over all images - exclusively with call_stepwise 8 | METRICS: 9 | meanIoU: # Define the Metric 10 | _target_: src.metric.IoU # Metric Class 11 | num_classes: ${DATASET.NUM_CLASSES} # Number if Classes in the Dataset 12 | labels: ${DATASET.CLASS_LABELS} # 
Class Labels in the Dataset 13 | per_class: True # Return the mean Metric with or without the Metric for each Class -------------------------------------------------------------------------------- /config/model/DeepLabv3.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.DeepLabv3.get_seg_model #torchvision.models.segmentation.deeplabv3_resnet101 4 | num_classes: ${DATASET.NUM_CLASSES} 5 | pretrained: True 6 | aux_loss: True 7 | backbone: resnet101 #resnet50 or resnet101 8 | 9 | MODEL: 10 | NAME: DeepLabv3 -------------------------------------------------------------------------------- /config/model/FCN.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.FCN.get_seg_model #torchvision.models.segmentation.fcn_resnet101 4 | num_classes: ${DATASET.NUM_CLASSES} 5 | pretrained: True 6 | aux_loss: True 7 | backbone: resnet101 #resnet50 or resnet101 8 | 9 | MODEL: 10 | NAME: FCN -------------------------------------------------------------------------------- /config/model/Mask_RCNN.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.Mask_RCNN.get_model_50 4 | num_classes: ${DATASET.NUM_CLASSES} 5 | version: v2 6 | disable_transforms: True 7 | MODEL: 8 | NAME: Mask_RCNN -------------------------------------------------------------------------------- /config/model/Mask_RCNN_RMI.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.Mask_RCNN_RMI_loss.get_model_50 4 | num_classes: ${DATASET.NUM_CLASSES} 5 | version: v2 6 | disable_transforms: True 7 | MODEL: 8 | NAME: Mask_RCNN_RMI -------------------------------------------------------------------------------- /config/model/UNet.yaml: 
-------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.UNet.UNet 4 | n_channels: 3 5 | n_classes: ${DATASET.NUM_CLASSES} 6 | 7 | MODEL: 8 | NAME: UNet -------------------------------------------------------------------------------- /config/model/hrnet.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.hrnet.get_seg_model 4 | cfg: 5 | MODEL: ${MODEL} 6 | DATASET: 7 | NUM_CLASSES: ${DATASET.NUM_CLASSES} 8 | 9 | MODEL: 10 | NAME: hrnet 11 | PRETRAINED: true 12 | pretrained_on: ImageNet 13 | PRETRAINED_WEIGHTS: ${ORG_CWD}/pretrained/${MODEL.available_weights.${MODEL.pretrained_on}} 14 | available_weights: 15 | ImageNet: hrnetv2_w48_imagenet_pretrained.pth 16 | Paddle: HRNet_W48_C_ssld_pretrained.pth 17 | Mapillary: mapillary_ocrnet.HRNet_Mscale_fast-rattlesnake.pth 18 | ALIGN_CORNERS: False 19 | INPUT_CHANNELS: 3 20 | EXTRA: 21 | FINAL_CONV_KERNEL: 1 22 | STAGE1: 23 | NUM_MODULES: 1 24 | NUM_BRANCHES: 1 25 | BLOCK: BOTTLENECK 26 | NUM_BLOCKS: 27 | - 4 28 | NUM_CHANNELS: 29 | - 64 30 | FUSE_METHOD: SUM 31 | STAGE2: 32 | NUM_MODULES: 1 33 | NUM_BRANCHES: 2 34 | BLOCK: BASIC 35 | NUM_BLOCKS: 36 | - 4 37 | - 4 38 | NUM_CHANNELS: 39 | - 48 40 | - 96 41 | FUSE_METHOD: SUM 42 | STAGE3: 43 | NUM_MODULES: 4 44 | NUM_BRANCHES: 3 45 | BLOCK: BASIC 46 | NUM_BLOCKS: 47 | - 4 48 | - 4 49 | - 4 50 | NUM_CHANNELS: 51 | - 48 52 | - 96 53 | - 192 54 | FUSE_METHOD: SUM 55 | STAGE4: 56 | NUM_MODULES: 3 57 | NUM_BRANCHES: 4 58 | BLOCK: BASIC 59 | NUM_BLOCKS: 60 | - 4 61 | - 4 62 | - 4 63 | - 4 64 | NUM_CHANNELS: 65 | - 48 66 | - 96 67 | - 192 68 | - 384 69 | FUSE_METHOD: SUM 70 | -------------------------------------------------------------------------------- /config/model/hrnet_ocr.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: 
models.hrnet_ocr.get_seg_model 4 | cfg: 5 | MODEL: ${MODEL} 6 | DATASET: 7 | NUM_CLASSES: ${DATASET.NUM_CLASSES} 8 | 9 | MODEL: 10 | NAME: hrnet_ocr 11 | PRETRAINED: true 12 | pretrained_on: ImageNet 13 | PRETRAINED_WEIGHTS: ${ORG_CWD}/pretrained/${MODEL.available_weights.${MODEL.pretrained_on}} 14 | available_weights: 15 | ImageNet: hrnetv2_w48_imagenet_pretrained.pth 16 | Paddle: HRNet_W48_C_ssld_pretrained.pth 17 | Mapillary: mapillary_ocrnet.HRNet_Mscale_fast-rattlesnake.pth 18 | ALIGN_CORNERS: True 19 | INPUT_CHANNELS: 3 20 | OCR: 21 | MID_CHANNELS: 512 22 | KEY_CHANNELS: 256 23 | DROPOUT: 0.05 24 | SCALE: 1 25 | EXTRA: 26 | FINAL_CONV_KERNEL: 1 27 | STAGE1: 28 | NUM_MODULES: 1 29 | NUM_RANCHES: 1 30 | BLOCK: BOTTLENECK 31 | NUM_BLOCKS: 32 | - 4 33 | NUM_CHANNELS: 34 | - 64 35 | FUSE_METHOD: SUM 36 | STAGE2: 37 | NUM_MODULES: 1 38 | NUM_BRANCHES: 2 39 | BLOCK: BASIC 40 | NUM_BLOCKS: 41 | - 4 42 | - 4 43 | NUM_CHANNELS: 44 | - 48 45 | - 96 46 | FUSE_METHOD: SUM 47 | STAGE3: 48 | NUM_MODULES: 4 49 | NUM_BRANCHES: 3 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | - 4 55 | NUM_CHANNELS: 56 | - 48 57 | - 96 58 | - 192 59 | FUSE_METHOD: SUM 60 | STAGE4: 61 | NUM_MODULES: 3 62 | NUM_BRANCHES: 4 63 | BLOCK: BASIC 64 | NUM_BLOCKS: 65 | - 4 66 | - 4 67 | - 4 68 | - 4 69 | NUM_CHANNELS: 70 | - 48 71 | - 96 72 | - 192 73 | - 384 74 | FUSE_METHOD: SUM 75 | 76 | -------------------------------------------------------------------------------- /config/model/hrnet_ocr_aspp.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | model: 3 | _target_: models.hrnet_ocr_aspp.get_seg_model 4 | cfg: 5 | MODEL: ${MODEL} 6 | DATASET: 7 | NUM_CLASSES: ${DATASET.NUM_CLASSES} 8 | 9 | MODEL: 10 | NAME: hrnet_ocr_aspp 11 | PRETRAINED: true 12 | pretrained_on: ImageNet 13 | PRETRAINED_WEIGHTS: ${ORG_CWD}/pretrained/${MODEL.available_weights.${MODEL.pretrained_on}} 14 | available_weights: 15 | ImageNet: 
hrnetv2_w48_imagenet_pretrained.pth 16 | Paddle: HRNet_W48_C_ssld_pretrained.pth 17 | Mapillary: mapillary_ocrnet.HRNet_Mscale_fast-rattlesnake.pth 18 | ALIGN_CORNERS: True 19 | ASPP_BOT_CH: 256 20 | INPUT_CHANNELS: 3 21 | OCR: 22 | MID_CHANNELS: 512 23 | KEY_CHANNELS: 256 24 | DROPOUT: 0.05 25 | SCALE: 1 26 | EXTRA: 27 | FINAL_CONV_KERNEL: 1 28 | STAGE1: 29 | NUM_MODULES: 1 30 | NUM_RANCHES: 1 31 | BLOCK: BOTTLENECK 32 | NUM_BLOCKS: 33 | - 4 34 | NUM_CHANNELS: 35 | - 64 36 | FUSE_METHOD: SUM 37 | STAGE2: 38 | NUM_MODULES: 1 39 | NUM_BRANCHES: 2 40 | BLOCK: BASIC 41 | NUM_BLOCKS: 42 | - 4 43 | - 4 44 | NUM_CHANNELS: 45 | - 48 46 | - 96 47 | FUSE_METHOD: SUM 48 | STAGE3: 49 | NUM_MODULES: 4 50 | NUM_BRANCHES: 3 51 | BLOCK: BASIC 52 | NUM_BLOCKS: 53 | - 4 54 | - 4 55 | - 4 56 | NUM_CHANNELS: 57 | - 48 58 | - 96 59 | - 192 60 | FUSE_METHOD: SUM 61 | STAGE4: 62 | NUM_MODULES: 3 63 | NUM_BRANCHES: 4 64 | BLOCK: BASIC 65 | NUM_BLOCKS: 66 | - 4 67 | - 4 68 | - 4 69 | - 4 70 | NUM_CHANNELS: 71 | - 48 72 | - 96 73 | - 192 74 | - 384 75 | FUSE_METHOD: SUM 76 | 77 | -------------------------------------------------------------------------------- /config/model/hrnet_ocr_ms.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | model: 3 | _target_: models.hrnet_ocr_ms.get_seg_model 4 | cfg: 5 | MODEL: ${MODEL} 6 | DATASET: 7 | NUM_CLASSES: ${DATASET.NUM_CLASSES} 8 | 9 | TESTING: 10 | OVERRIDES: 11 | - MODEL.MSCALE_INFERENCE=True 12 | 13 | 14 | MODEL: 15 | NAME: hrnet_ocr_ms 16 | PRETRAINED: true 17 | pretrained_on: ImageNet 18 | PRETRAINED_WEIGHTS: ${ORG_CWD}/pretrained/${MODEL.available_weights.${MODEL.pretrained_on}} 19 | available_weights: 20 | ImageNet: hrnetv2_w48_imagenet_pretrained.pth 21 | Paddle: HRNet_W48_C_ssld_pretrained.pth 22 | Mapillary: mapillary_ocrnet.HRNet_Mscale_fast-rattlesnake.pth 23 | MSCALE_INFERENCE: False 24 | ALIGN_CORNERS: True 25 | MSCALE_LO_SCALE: 0.5 26 | SEGATTN_BOT_CH: 256 27 | 
N_SCALES: [ 0.5,1.0,2.0 ] 28 | MSCALE_DROPOUT: False 29 | MSCALE_INNER_3x3: True 30 | INPUT_CHANNELS: 3 31 | OCR: 32 | MID_CHANNELS: 512 33 | KEY_CHANNELS: 256 34 | DROPOUT: 0.05 35 | SCALE: 1 36 | EXTRA: 37 | FINAL_CONV_KERNEL: 1 38 | STAGE1: 39 | NUM_MODULES: 1 40 | NUM_RANCHES: 1 41 | BLOCK: BOTTLENECK 42 | NUM_BLOCKS: 43 | - 4 44 | NUM_CHANNELS: 45 | - 64 46 | FUSE_METHOD: SUM 47 | STAGE2: 48 | NUM_MODULES: 1 49 | NUM_BRANCHES: 2 50 | BLOCK: BASIC 51 | NUM_BLOCKS: 52 | - 4 53 | - 4 54 | NUM_CHANNELS: 55 | - 48 56 | - 96 57 | FUSE_METHOD: SUM 58 | STAGE3: 59 | NUM_MODULES: 4 60 | NUM_BRANCHES: 3 61 | BLOCK: BASIC 62 | NUM_BLOCKS: 63 | - 4 64 | - 4 65 | - 4 66 | NUM_CHANNELS: 67 | - 48 68 | - 96 69 | - 192 70 | FUSE_METHOD: SUM 71 | STAGE4: 72 | NUM_MODULES: 3 73 | NUM_BRANCHES: 4 74 | BLOCK: BASIC 75 | NUM_BLOCKS: 76 | - 4 77 | - 4 78 | - 4 79 | - 4 80 | NUM_CHANNELS: 81 | - 48 82 | - 96 83 | - 192 84 | - 384 85 | FUSE_METHOD: SUM 86 | 87 | -------------------------------------------------------------------------------- /config/optimizer/ADAMW.yaml: -------------------------------------------------------------------------------- 1 | # optimizer 2 | _target_: torch.optim.AdamW 3 | lr: ${lr} -------------------------------------------------------------------------------- /config/optimizer/MADGRAD.yaml: -------------------------------------------------------------------------------- 1 | # optimizer 2 | _target_: madgrad.MADGRAD 3 | lr: ${lr} 4 | momentum: ${momentum} 5 | weight_decay: 0 #${weight_decay} -------------------------------------------------------------------------------- /config/optimizer/SGD.yaml: -------------------------------------------------------------------------------- 1 | # optimizer 2 | _target_: torch.optim.SGD 3 | lr: ${lr} 4 | momentum: ${momentum} 5 | weight_decay: ${weight_decay} -------------------------------------------------------------------------------- /config/testing.yaml: 
-------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # Default List 4 | # Use training.yaml and only change/override a few parts 5 | defaults: 6 | - training 7 | - _self_ 8 | 9 | # (Required) Path to the Checkpoint which should be tested/validated 10 | ckpt_dir: ??? 11 | 12 | # Define Test Time Augmentation 13 | TESTING: 14 | SCALES: [1] 15 | FLIP: False 16 | BINARY_FLIP: False 17 | 18 | # Customizations of Hydra, change hydra/run/dir to ckpt_dir 19 | hydra: 20 | output_subdir: testing/hydra 21 | run: 22 | dir: ${ckpt_dir} 23 | sweep: 24 | dir: multi_run_${hydra.run.dir} 25 | subdir: ${hydra.job.num} 26 | job_logging: 27 | handlers: 28 | file: 29 | #filename: testing/${hydra.job.name}.log 30 | filename: ${hydra.job.name}.log -------------------------------------------------------------------------------- /config/trainer/InstSeg.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | trainermodule: 3 | _target_: trainers.Instance_Segmentation_Trainer.InstModel 4 | datamodule: 5 | instance_seg: True -------------------------------------------------------------------------------- /config/trainer/SemSeg.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | trainermodule: 3 | _target_: trainers.Semantic_Segmentation_Trainer.SegModel -------------------------------------------------------------------------------- /config/training.yaml: -------------------------------------------------------------------------------- 1 | #@package _global_ 2 | 3 | # Default List 4 | # defines which modules from which parameter group are used by default and in which order they are composed 5 | defaults: 6 | - _self_ 7 | - trainer: SemSeg # Which Trainer to use 8 | - metric: mean_IoU # Metric configuration 9 | - model: hrnet # Model 10 | - dataset: Cityscapes # Dataset 11 | - data_augmentation: only_norm # Data 
Augmentation 12 | - optimizer: SGD # Optimizer 13 | - lr_scheduler: polynomial # Learning rate scheduler 14 | - callbacks: default # Callbacks 15 | - logger: tensorboard # Logger 16 | - experiment/default # Load Default Hyperparameters 17 | - optional experiment: 18 | - environment: local # Environment 19 | 20 | - override hydra/hydra_logging: colorlog # Using colorlog plugin of hydra for logging 21 | - override hydra/job_logging: colorlog 22 | 23 | # Logging Related Parameters 24 | ORG_CWD: ${hydra:runtime.cwd} # Saving the original working dir 25 | OUTPUT_DIR: ${hydra:runtime.output_dir} 26 | name: "run" # Possibility for naming the experiment 27 | LOGDIR: logs/ # Default logging directory 28 | num_example_predictions: 2 # Save some example predictions during validation/testing 29 | # Just for visualization/inspection, set to 0 if not wanted 30 | 31 | # Defining the datamodule 32 | datamodule: 33 | _target_: datasets.DataModules.BaseDataModule # Base Data Module 34 | num_workers: ${num_workers} # Parsing all needed experiment 35 | batch_size: ${batch_size} 36 | val_batch_size: ${val_batch_size} 37 | augmentations: ${AUGMENTATIONS} # Parsing the Data augmentation, defined in the augmentation config 38 | dataset: ${dataset} # Parsing the Dataset defined in the dataset config 39 | 40 | # Defining the saving behavior. 
Only used when checkpointing is enabled in the pl_trainer 41 | ModelCheckpoint: 42 | _target_: src.callbacks.customModelCheckpoint # Custom checkpoint Callback with a few modifications 43 | monitor: "metric/${METRIC.NAME}" # Name of the metric during logging 44 | mode: "max" # min or max: should be metric me maximised or minimized 45 | filename: 'best_epoch_{epoch}__${METRIC.NAME}_{metric/${METRIC.NAME}:.4f}' 46 | auto_insert_metric_name: False # Needs to be false for better naming of checkpoint 47 | save_last: True # If the last checkpoint should be saved too 48 | 49 | # Customizations of Hydra 50 | hydra: 51 | output_subdir: hydra 52 | run: 53 | dir: ${LOGDIR}/${DATASET.NAME}/${MODEL.NAME}/${name}__${path_formatter:${hydra.job.override_dirname}}/${now:%Y-%m-%d_%H-%M-%S} 54 | # Example Dir: /.../Semantic_Segmentation/logs/Cityscapes/hrnet/baseline__lossfunction_CE/2022-02-14_15-42-43/ 55 | # using path_formatter with is defined in training.py to resolve problems which may occur with characters like [,],",or ',' in paths 56 | sweep: 57 | dir: multi_run_${hydra.run.dir} 58 | subdir: ${hydra.job.num} 59 | job: 60 | chdir: True 61 | config: 62 | override_dirname: 63 | kv_sep: "_" # do not use "=" to prevent problems when parsing paths 64 | item_sep: "__" 65 | exclude_keys: # excluding some key from ${hydra.job.override_dirname} 66 | - model # already used in the path 67 | - dataset # already used in the path 68 | - environment # no needed information for the experiments 69 | - finetune_from # to long 70 | - continue_from # to long 71 | - name # already used in the path 72 | - LOGDIR # no needed information for the experiments 73 | - pl_trainer.enable_checkpointing -------------------------------------------------------------------------------- /datasets/Cityscapes/Cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from collections import namedtuple 4 | 5 | import torch 6 | import 
torchvision.utils 7 | 8 | import cv2 9 | import albumentations as A 10 | from albumentations.pytorch import ToTensorV2 11 | from src.utils import get_logger 12 | 13 | log = get_logger(__name__) 14 | 15 | # some parts are taken from here: 16 | # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 17 | 18 | CityscapesClass = namedtuple( 19 | "CityscapesClass", 20 | [ 21 | "name", 22 | "id", 23 | "train_id", 24 | "category", 25 | "category_id", 26 | "has_instances", 27 | "ignore_in_eval", 28 | "color", 29 | ], 30 | ) 31 | # 34 classes 32 | classes_34 = [ 33 | CityscapesClass("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0)), 34 | CityscapesClass("ego vehicle", 1, 255, "void", 0, False, True, (0, 0, 0)), 35 | CityscapesClass("rectification border", 2, 255, "void", 0, False, True, (0, 0, 0)), 36 | CityscapesClass("out of roi", 3, 255, "void", 0, False, True, (0, 0, 0)), 37 | CityscapesClass("static", 4, 255, "void", 0, False, True, (0, 0, 0)), 38 | CityscapesClass("dynamic", 5, 255, "void", 0, False, True, (111, 74, 0)), 39 | CityscapesClass("ground", 6, 255, "void", 0, False, True, (81, 0, 81)), 40 | CityscapesClass("road", 7, 0, "flat", 1, False, False, (128, 64, 128)), 41 | CityscapesClass("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232)), 42 | CityscapesClass("parking", 9, 255, "flat", 1, False, True, (250, 170, 160)), 43 | CityscapesClass("rail track", 10, 255, "flat", 1, False, True, (230, 150, 140)), 44 | CityscapesClass("building", 11, 2, "construction", 2, False, False, (70, 70, 70)), 45 | CityscapesClass("wall", 12, 3, "construction", 2, False, False, (102, 102, 156)), 46 | CityscapesClass("fence", 13, 4, "construction", 2, False, False, (190, 153, 153)), 47 | CityscapesClass("guard rail", 14, 255, "construction", 2, False, True, (180, 165, 180)), 48 | CityscapesClass("bridge", 15, 255, "construction", 2, False, True, (150, 100, 100)), 49 | CityscapesClass("tunnel", 16, 255, "construction", 2, False, 
True, (150, 120, 90)), 50 | CityscapesClass("pole", 17, 5, "object", 3, False, False, (153, 153, 153)), 51 | CityscapesClass("polegroup", 18, 255, "object", 3, False, True, (153, 153, 153)), 52 | CityscapesClass("traffic light", 19, 6, "object", 3, False, False, (250, 170, 30)), 53 | CityscapesClass("traffic sign", 20, 7, "object", 3, False, False, (220, 220, 0)), 54 | CityscapesClass("vegetation", 21, 8, "nature", 4, False, False, (107, 142, 35)), 55 | CityscapesClass("terrain", 22, 9, "nature", 4, False, False, (152, 251, 152)), 56 | CityscapesClass("sky", 23, 10, "sky", 5, False, False, (70, 130, 180)), 57 | CityscapesClass("person", 24, 11, "human", 6, True, False, (220, 20, 60)), 58 | CityscapesClass("rider", 25, 12, "human", 6, True, False, (255, 0, 0)), 59 | CityscapesClass("car", 26, 13, "vehicle", 7, True, False, (0, 0, 142)), 60 | CityscapesClass("truck", 27, 14, "vehicle", 7, True, False, (0, 0, 70)), 61 | CityscapesClass("bus", 28, 15, "vehicle", 7, True, False, (0, 60, 100)), 62 | CityscapesClass("caravan", 29, 255, "vehicle", 7, True, True, (0, 0, 90)), 63 | CityscapesClass("trailer", 30, 255, "vehicle", 7, True, True, (0, 0, 110)), 64 | CityscapesClass("train", 31, 16, "vehicle", 7, True, False, (0, 80, 100)), 65 | CityscapesClass("motorcycle", 32, 17, "vehicle", 7, True, False, (0, 0, 230)), 66 | CityscapesClass("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32)), 67 | CityscapesClass("license plate", -1, -1, "vehicle", 7, False, True, (0, 0, 142)), 68 | ] 69 | 70 | # 19 classes 71 | classes_19 = [ 72 | CityscapesClass("road", 0, 0, "flat", 1, False, False, (128, 64, 128)), 73 | CityscapesClass("sidewalk", 1, 1, "flat", 1, False, False, (244, 35, 232)), 74 | CityscapesClass("building", 2, 2, "construction", 2, False, False, (70, 70, 70)), 75 | CityscapesClass("wall", 3, 3, "construction", 2, False, False, (102, 102, 156)), 76 | CityscapesClass("fence", 4, 4, "construction", 2, False, False, (190, 153, 153)), 77 | CityscapesClass("pole", 5, 
5, "object", 3, False, False, (153, 153, 153)), 78 | CityscapesClass("traffic light", 6, 6, "object", 3, False, False, (250, 170, 30)), 79 | CityscapesClass("traffic sign", 7, 7, "object", 3, False, False, (220, 220, 0)), 80 | CityscapesClass("vegetation", 8, 8, "nature", 4, False, False, (107, 142, 35)), 81 | CityscapesClass("terrain", 9, 9, "nature", 4, False, False, (152, 251, 152)), 82 | CityscapesClass("sky", 10, 10, "sky", 5, False, False, (70, 130, 180)), 83 | CityscapesClass("person", 11, 11, "human", 6, True, False, (220, 20, 60)), 84 | CityscapesClass("rider", 12, 12, "human", 6, True, False, (255, 0, 0)), 85 | CityscapesClass("car", 13, 13, "vehicle", 7, True, False, (0, 0, 142)), 86 | CityscapesClass("truck", 14, 14, "vehicle", 7, True, False, (0, 0, 70)), 87 | CityscapesClass("bus", 15, 15, "vehicle", 7, True, False, (0, 60, 100)), 88 | CityscapesClass("train", 16, 16, "vehicle", 7, True, False, (0, 80, 100)), 89 | CityscapesClass("motorcycle", 17, 17, "vehicle", 7, True, False, (0, 0, 230)), 90 | CityscapesClass("bicycle", 18, 18, "vehicle", 7, True, False, (119, 11, 32)), 91 | ] 92 | 93 | # mapping from 34 class setting to 19 classes 94 | ignore_label = 255 95 | label_mapping = { 96 | -1: ignore_label, 97 | 0: ignore_label, 98 | 1: ignore_label, 99 | 2: ignore_label, 100 | 3: ignore_label, 101 | 4: ignore_label, 102 | 5: ignore_label, 103 | 6: ignore_label, 104 | 7: 0, 105 | 8: 1, 106 | 9: ignore_label, 107 | 10: ignore_label, 108 | 11: 2, 109 | 12: 3, 110 | 13: 4, 111 | 14: ignore_label, 112 | 15: ignore_label, 113 | 16: ignore_label, 114 | 17: 5, 115 | 18: ignore_label, 116 | 19: 6, 117 | 20: 7, 118 | 21: 8, 119 | 22: 9, 120 | 23: 10, 121 | 24: 11, 122 | 25: 12, 123 | 26: 13, 124 | 27: 14, 125 | 28: 15, 126 | 29: ignore_label, 127 | 30: ignore_label, 128 | 31: 16, 129 | 32: 17, 130 | 33: 18, 131 | } 132 | 133 | 134 | # cityscapes dataset class 135 | class Cityscapes_dataset(torch.utils.data.Dataset): 136 | def __init__(self, root, split="train", 
transforms=None): 137 | # providing the possibility to have data and labels at different locations 138 | if isinstance(root, str): 139 | root_imgs = root 140 | root_labels = root 141 | else: 142 | root_imgs = root.IMAGES 143 | root_labels = root.LABELS 144 | 145 | # no test dataset for cityscapes so return the validation set instead 146 | if split == "test": 147 | split = "val" 148 | 149 | # building the paths 150 | imgs_path = os.path.join( 151 | root_imgs, 152 | "leftImg8bit_trainvaltest", 153 | "leftImg8bit", 154 | split, 155 | "*", 156 | "*_leftImg8bit.png", 157 | ) 158 | masks_path = os.path.join( 159 | root_labels, 160 | "gtFine_trainvaltest", 161 | "gtFine", 162 | split, 163 | "*", 164 | "*_gt*_labelIds_19classes.png", 165 | ) 166 | # elif num_classes==34: 167 | # masks_path=os.path.join( root_labels,"gtFine_trainvaltest" , "gtFine" , split , "*" , "*_gt*_labelIds.png" ) 168 | 169 | # save all paths in lists 170 | self.imgs = list(sorted(glob.glob(imgs_path))) 171 | self.masks = list(sorted(glob.glob(masks_path))) 172 | 173 | self.transforms = transforms 174 | log.info( 175 | "Dataset: Cityscape %s - %s images - %s masks", 176 | split, 177 | len(self.imgs), 178 | len(self.masks), 179 | ) 180 | 181 | def __getitem__(self, idx): 182 | # read image (opencv read images in bgr) and mask 183 | img = cv2.imread(self.imgs[idx]) 184 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 185 | 186 | mask = cv2.imread(self.masks[idx], -1) 187 | 188 | # apply albumentations transforms 189 | transformed = self.transforms(image=img, mask=mask) 190 | img = transformed["image"] 191 | mask = transformed["mask"] 192 | 193 | return img, mask # .long() 194 | 195 | def __len__(self): 196 | return len(self.imgs) 197 | 198 | 199 | if __name__ == "__main__": 200 | # define some transforms 201 | transforms = A.Compose( 202 | [ 203 | # A.RandomCrop(width=768, height=768), 204 | A.RandomScale(scale_limit=(-0.5, 1), always_apply=True, p=1.0), 205 | # A.PadIfNeeded(min_height=768,min_width=768), 
if __name__ == "__main__":
    # demo: build an augmentation pipeline and pull one sample from the dataset
    augmentations = A.Compose(
        [
            A.RandomScale(scale_limit=(-0.5, 1), always_apply=True, p=1.0),
            A.RandomCrop(width=1024, height=512, always_apply=True, p=1.0),
            A.RGBShift(p=1, r_shift_limit=10, g_shift_limit=10, b_shift_limit=10),
            A.GaussianBlur(),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensorV2(),
        ]
    )
    print(augmentations)

    # load the dataset from a local path
    cityscapes_root = "/home/l727r/Desktop/Datasets/cityscapes"
    train_set = Cityscapes_dataset(cityscapes_root, "train", transforms=augmentations)

    # fetch one augmented sample
    img, mask = train_set[100]
class Cityscapes_coarse_dataset(Cityscapes_dataset):
    """Coarsely annotated Cityscapes dataset.

    Subclass of ``Cityscapes_dataset``; only ``__init__`` is adapted: training
    uses the coarse ``train_extra`` annotations, validation falls back to the
    fine annotations.
    """

    def __init__(self, root, split="train", transforms=None):
        """
        Parameters
        ----------
        root : str or config object
            Dataset root, or an object with ``IMAGES``/``LABELS`` attributes
            when images and labels live at different locations.
        split : str
            "train", "val" or "test" ("test" falls back to "val").
        transforms : albumentations transform, optional
        """
        # providing the possibility to have data and labels at different locations
        if isinstance(root, str):
            root_imgs = root
            root_labels = root
        else:
            root_imgs = root.IMAGES
            root_labels = root.LABELS

        # no test dataset for cityscapes so return the validation set instead
        if split == "test":
            split = "val"

        # building the paths: coarse annotations for training, fine for validation
        if split == "train":
            imgs_path = os.path.join(
                root_imgs,
                "leftImg8bit_trainextra",
                "leftImg8bit",
                "train_extra",
                "*",
                "*_leftImg8bit.png",
            )
            masks_path = os.path.join(
                root_labels,
                "gtCoarse",
                "gtCoarse",
                "train_extra",
                "*",
                "*_gt*_labelIds_19classes.png",
            )
        elif split == "val":
            imgs_path = os.path.join(
                root_imgs,
                "leftImg8bit_trainvaltest",
                "leftImg8bit",
                split,
                "*",
                "*_leftImg8bit.png",
            )
            masks_path = os.path.join(
                root_labels,
                "gtFine_trainvaltest",
                "gtFine",
                split,
                "*",
                "*_gt*_labelIds_19classes.png",
            )
        else:
            # BUG FIX: previously any other split reached the glob below with
            # unbound imgs_path/masks_path and crashed with a NameError
            raise ValueError(f"Unknown split: {split}")

        # save all paths in lists
        self.imgs = list(sorted(glob.glob(imgs_path)))
        self.masks = list(sorted(glob.glob(masks_path)))

        # this image is corrupt, so exclude it
        troisdorf = (
            root_imgs
            + "/leftImg8bit_trainextra/leftImg8bit/train_extra/troisdorf/troisdorf_000000_000073_leftImg8bit.png"
        )
        if troisdorf in self.imgs:
            self.imgs.remove(troisdorf)

        self.transforms = transforms
        log.info(
            "Dataset: Cityscape %s - %s images - %s masks",
            split,
            len(self.imgs),
            len(self.masks),
        )


if __name__ == "__main__":
    # demo: minimal pipeline (normalize + to-tensor) and one sample
    transforms = A.Compose(
        [
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensorV2(),
        ]
    )

    cityscapesPath = "/home/l727r/Desktop/Cityscape"
    Cityscape_train = Cityscapes_coarse_dataset(cityscapesPath, "train", transforms=transforms)

    img, mask = Cityscape_train[2000]
class Cityscape_fine_coarse_dataset(Cityscapes_dataset):
    """Cityscapes dataset mixing fine and coarse annotations.

    Subclass of ``Cityscapes_dataset``; only ``__init__`` is adapted. Training
    combines the fine ``train`` split with a random portion of the coarse
    ``train_extra`` data; validation uses only the fine annotations.
    """

    def __init__(self, root, split="train", transforms=None, coarse_portion=1.0):
        """
        Parameters
        ----------
        root : str or config object
            Dataset root, or an object with ``IMAGES``/``LABELS`` attributes
            when images and labels live at different locations.
        split : str
            "train", "val" or "test" ("test" falls back to "val").
        transforms : albumentations transform, optional
        coarse_portion : float
            Fraction of the coarse data to include; clamped to [0, 1].
        """
        # providing the possibility to have data and labels at different locations
        if isinstance(root, str):
            root_imgs = root
            root_labels = root
        else:
            root_imgs = root.IMAGES
            root_labels = root.LABELS

        # no test dataset for cityscapes so return the validation set instead
        if split == "test":
            split = "val"

        if split == "train":
            # building the paths for fine and coarse images
            imgs_path_fine = os.path.join(
                root_imgs,
                "leftImg8bit_trainvaltest",
                "leftImg8bit",
                split,
                "*",
                "*_leftImg8bit.png",
            )
            imgs_path_coarse = os.path.join(
                root_imgs,
                "leftImg8bit_trainextra",
                "leftImg8bit",
                "train_extra",
                "*",
                "*_leftImg8bit.png",
            )

            # building the paths for fine and coarse masks
            masks_path_fine = os.path.join(
                root_labels,
                "gtFine_trainvaltest",
                "gtFine",
                split,
                "*",
                "*_gt*_labelIds_19classes.png",
            )
            masks_path_coarse = os.path.join(
                root_labels,
                "gtCoarse",
                "gtCoarse",
                "train_extra",
                "*",
                "*_gt*_labelIds_19classes.png",
            )

            # save all paths in lists
            imgs_fine = list(sorted(glob.glob(imgs_path_fine)))
            imgs_coarse = list(sorted(glob.glob(imgs_path_coarse)))

            # this image is corrupt, so exclude it
            troisdorf = (
                root_imgs
                + "/leftImg8bit_trainextra/leftImg8bit/train_extra/troisdorf/troisdorf_000000_000073_leftImg8bit.png"
            )
            if troisdorf in imgs_coarse:
                imgs_coarse.remove(troisdorf)
            masks_fine = list(sorted(glob.glob(masks_path_fine)))
            masks_coarse = list(sorted(glob.glob(masks_path_coarse)))

            # randomly select coarse_portion of the coarse data
            # BUG FIX: clamp to [0, 1] — previously only the lower bound was
            # clamped, and a portion > 1 made random.sample raise ValueError
            coarse_portion = min(max(coarse_portion, 0), 1)
            indices = random.sample(range(len(imgs_coarse)), int(len(imgs_coarse) * coarse_portion))
            indices.sort()
            imgs_coarse = [imgs_coarse[index] for index in indices]
            masks_coarse = [masks_coarse[index] for index in indices]

            # join fine and selected coarse data
            self.masks = masks_fine + masks_coarse
            self.imgs = imgs_fine + imgs_coarse

            # BUG FIX: the second per-group label in this log message read
            # "Fine" twice; the second group is the coarse data
            log.info(
                "Dataset: Cityscape %s (Coarse+Fine) | Total: %s images - %s masks | Fine: %s"
                " images - %s masks | Coarse: %s images - %s masks",
                split,
                len(self.imgs),
                len(self.masks),
                len(imgs_fine),
                len(masks_fine),
                len(imgs_coarse),
                len(masks_coarse),
            )

        elif split == "val":
            # for validation only the fine annotated data is used
            # building the paths
            imgs_path = os.path.join(
                root_imgs,
                "leftImg8bit_trainvaltest",
                "leftImg8bit",
                split,
                "*",
                "*_leftImg8bit.png",
            )
            masks_path = os.path.join(
                root_labels,
                "gtFine_trainvaltest",
                "gtFine",
                split,
                "*",
                "*_gt*_labelIds_19classes.png",
            )

            # save all paths in lists
            self.imgs = list(sorted(glob.glob(imgs_path)))
            self.masks = list(sorted(glob.glob(masks_path)))

            log.info(
                "Dataset: Cityscape %s (Coarse+Fine) | Total: %s images - %s masks",
                split,
                len(self.imgs),
                len(self.masks),
            )
        else:
            # previously an unknown split fell through with self.imgs unset
            raise ValueError(f"Unknown split: {split}")

        self.transforms = transforms


if __name__ == "__main__":
    # demo: minimal pipeline and one sample (negative portion exercises the clamp)
    transforms = A.Compose(
        [
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensorV2(),
        ]
    )

    cityscapesPath = "/home/l727r/Desktop/Cityscape"
    Cityscape_train = Cityscape_fine_coarse_dataset(
        cityscapesPath, "train", transforms=transforms, coarse_portion=-0.2
    )
    img, mask = Cityscape_train[2000]
if __name__ == "__main__":
    # remap the raw Cityscapes fine label ids to the 19 training ids and write
    # each result next to the original mask with a "_19classes" suffix
    parser = argparse.ArgumentParser()
    parser.add_argument("data_path", type=str)
    args = parser.parse_args()

    root = args.data_path

    for split in ("train", "val"):
        pattern = os.path.join(
            root, "gtFine_trainvaltest", "gtFine", split, "*", "*gtFine_labelIds.png"
        )
        for mask_file in tqdm(glob.glob(pattern)):
            out_file = mask_file.split(".png")[0] + "_19classes.png"
            raw_mask = cv2.imread(mask_file, -1)
            remapped = raw_mask.copy()
            # translate every raw id to its training id (or ignore_label)
            for raw_id, train_id in label_mapping.items():
                remapped[raw_mask == raw_id] = train_id
            cv2.imwrite(out_file, remapped)
if __name__ == "__main__":
    # remap the raw coarse Cityscapes label ids to the 19 training ids and write
    # each result next to the original mask with a "_19classes" suffix
    parser = argparse.ArgumentParser()
    parser.add_argument("data_path", type=str)
    args = parser.parse_args()

    root = args.data_path

    for split in ("train_extra",):
        pattern = os.path.join(root, "gtCoarse", "gtCoarse", split, "*", "*gtCoarse_labelIds.png")
        for mask_file in tqdm(glob.glob(pattern)):
            out_file = mask_file.split(".png")[0] + "_19classes.png"
            raw_mask = cv2.imread(mask_file, -1)
            remapped = raw_mask.copy()
            # translate every raw id to its training id (or ignore_label)
            for raw_id, train_id in label_mapping.items():
                remapped[raw_mask == raw_id] = train_id
            cv2.imwrite(out_file, remapped)
masks = masks.transpose((1, 2, 0)) 46 | if self.transforms is not None: 47 | transformed = self.transforms(image=img, mask=masks) 48 | img = transformed["image"] / 255 49 | masks = transformed["mask"].permute(2, 0, 1) 50 | 51 | # get bounding box coordinates for each mask 52 | num_objs = len(obj_ids) 53 | boxes = [] 54 | for i in range(num_objs): 55 | pos = np.where(masks[i]) 56 | xmin = np.min(pos[1]) 57 | xmax = np.max(pos[1]) 58 | ymin = np.min(pos[0]) 59 | ymax = np.max(pos[0]) 60 | boxes.append([xmin, ymin, xmax, ymax]) 61 | 62 | # convert everything into a torch.Tensor 63 | boxes = torch.as_tensor(boxes, dtype=torch.float32) 64 | # there is only one class 65 | labels = torch.ones((num_objs,), dtype=torch.int64) 66 | masks = torch.as_tensor(masks, dtype=torch.uint8) 67 | 68 | image_id = torch.tensor([idx]) 69 | area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) 70 | # suppose all instances are not crowd 71 | iscrowd = torch.zeros((num_objs,), dtype=torch.int64) 72 | 73 | target = {} 74 | target["boxes"] = boxes 75 | target["labels"] = labels 76 | target["masks"] = masks 77 | target["image_id"] = image_id 78 | target["area"] = area 79 | target["iscrowd"] = iscrowd 80 | 81 | return img, target 82 | 83 | def __len__(self): 84 | return len(self.imgs) 85 | -------------------------------------------------------------------------------- /datasets/VOC2010_Context/VOC2010_Context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | import torch 5 | 6 | import cv2 7 | import albumentations as A 8 | from albumentations.pytorch import ToTensorV2 9 | import numpy as np 10 | 11 | from src.utils import get_logger 12 | 13 | log = get_logger(__name__) 14 | 15 | ### NAME OF ALL CLASSES ### 16 | CLASSES = ( 17 | "background", 18 | "aeroplane", 19 | "bag", 20 | "bed", 21 | "bedclothes", 22 | "bench", 23 | "bicycle", 24 | "bird", 25 | "boat", 26 | "book", 27 | "bottle", 28 | "building", 29 | 
"bus", 30 | "cabinet", 31 | "car", 32 | "cat", 33 | "ceiling", 34 | "chair", 35 | "cloth", 36 | "computer", 37 | "cow", 38 | "cup", 39 | "curtain", 40 | "dog", 41 | "door", 42 | "fence", 43 | "floor", 44 | "flower", 45 | "food", 46 | "grass", 47 | "ground", 48 | "horse", 49 | "keyboard", 50 | "light", 51 | "motorbike", 52 | "mountain", 53 | "mouse", 54 | "person", 55 | "plate", 56 | "platform", 57 | "pottedplant", 58 | "road", 59 | "rock", 60 | "sheep", 61 | "shelves", 62 | "sidewalk", 63 | "sign", 64 | "sky", 65 | "snow", 66 | "sofa", 67 | "table", 68 | "track", 69 | "train", 70 | "tree", 71 | "truck", 72 | "tvmonitor", 73 | "wall", 74 | "water", 75 | "window", 76 | "wood", 77 | ) 78 | 79 | ### COLORMAPPING FOR EACH CLASS ### 80 | PALETTE = [ 81 | [120, 120, 120], 82 | [180, 120, 120], 83 | [6, 230, 230], 84 | [80, 50, 50], 85 | [4, 200, 3], 86 | [120, 120, 80], 87 | [140, 140, 140], 88 | [204, 5, 255], 89 | [230, 230, 230], 90 | [4, 250, 7], 91 | [224, 5, 255], 92 | [235, 255, 7], 93 | [150, 5, 61], 94 | [120, 120, 70], 95 | [8, 255, 51], 96 | [255, 6, 82], 97 | [143, 255, 140], 98 | [204, 255, 4], 99 | [255, 51, 7], 100 | [204, 70, 3], 101 | [0, 102, 200], 102 | [61, 230, 250], 103 | [255, 6, 51], 104 | [11, 102, 255], 105 | [255, 7, 71], 106 | [255, 9, 224], 107 | [9, 7, 230], 108 | [220, 220, 220], 109 | [255, 9, 92], 110 | [112, 9, 255], 111 | [8, 255, 214], 112 | [7, 255, 224], 113 | [255, 184, 6], 114 | [10, 255, 71], 115 | [255, 41, 10], 116 | [7, 255, 255], 117 | [224, 255, 8], 118 | [102, 8, 255], 119 | [255, 61, 6], 120 | [255, 194, 7], 121 | [255, 122, 8], 122 | [0, 255, 20], 123 | [255, 8, 41], 124 | [255, 5, 153], 125 | [6, 51, 255], 126 | [235, 12, 255], 127 | [160, 150, 20], 128 | [0, 163, 255], 129 | [140, 140, 140], 130 | [250, 10, 15], 131 | [20, 255, 0], 132 | [31, 255, 0], 133 | [255, 31, 0], 134 | [255, 224, 0], 135 | [153, 255, 0], 136 | [0, 0, 255], 137 | [255, 71, 0], 138 | [0, 235, 255], 139 | [0, 173, 255], 140 | [31, 0, 255], 141 | ] 
class VOC2010_Context_dataset(torch.utils.data.Dataset):
    """PASCAL VOC2010-Context semantic segmentation dataset (59/60 classes)."""

    def __init__(self, root, split="train", num_classes=60, ignore_index=255, transforms=None):
        """
        Parameters
        ----------
        root : str or config object
            Dataset root, or an object with ``IMAGES``/``LABELS`` attributes
            when images and labels live at different locations.
        split : str
            "train", "val" or "test" ("test" falls back to "val").
        num_classes : int
            60 keeps the background class; 59 drops it via reduce_num_classes.
        ignore_index : int
            Label value excluded from the loss.
        transforms : albumentations transform, optional
        """
        # data and labels may live in different locations
        if isinstance(root, str):
            root_imgs = root
            root_labels = root
        else:
            root_imgs = root.IMAGES
            root_labels = root.LABELS

        # note: self.split keeps the requested name even when "test" maps to "val"
        self.split = split
        self.ignore_index = ignore_index
        self.num_classes = num_classes
        if split == "test":
            split = "val"
        img_pattern = os.path.join(root_imgs, "Images", split, "*.jpg")

        mask_pattern = os.path.join(root_labels, "Annotations", split, "*.png")

        self.imgs = sorted(glob.glob(img_pattern))
        self.masks = sorted(glob.glob(mask_pattern))

        self.transforms = transforms
        log.info(
            "Dataset: VOC2010_Context %s - %s images - %s masks",
            split,
            len(self.imgs),
            len(self.masks),
        )

    def reduce_num_classes(self, mask):
        """Shift all labels down by one so the background becomes ignore_index."""
        shifted = mask - 1
        shifted[shifted == -1] = self.ignore_index
        return shifted

    def __getitem__(self, idx):
        """Return the (image, mask) pair at index ``idx`` after augmentation."""
        # OpenCV reads images in BGR order -> convert to RGB
        image = cv2.imread(self.imgs[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks[idx], -1)

        # drop the background class when the 59-class setup is requested
        if self.num_classes == 59:
            mask = self.reduce_num_classes(mask)

        # run image and mask through the albumentations pipeline together
        augmented = self.transforms(image=image, mask=mask)

        return augmented["image"], augmented["mask"].long()

    def __len__(self):
        """Number of images found on disk."""
        return len(self.imgs)
if __name__ == "__main__":
    # demo: build an augmentation pipeline and pull one sample from the dataset
    augmentations = A.Compose(
        [
            A.SmallestMaxSize(max_size=520),
            A.RandomScale(scale_limit=(-0.5, 1), always_apply=True, p=1.0),
            A.PadIfNeeded(min_height=520, min_width=520, border_mode=0, value=0, mask_value=255),
            A.GaussianBlur(p=1),
            A.HorizontalFlip(p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], always_apply=True),
            ToTensorV2(),
        ]
    )
    print(augmentations)

    voc_root = "/home/l727r/Desktop/Datasets/VOC2010_Context"
    train_set = VOC2010_Context_dataset(voc_root, "train", transforms=augmentations)

    # fetch one augmented sample
    img, mask = train_set[465]
"plate", 50 | "platform", 51 | "pottedplant", 52 | "road", 53 | "rock", 54 | "sheep", 55 | "shelves", 56 | "sidewalk", 57 | "sign", 58 | "sky", 59 | "snow", 60 | "sofa", 61 | "table", 62 | "track", 63 | "train", 64 | "tree", 65 | "truck", 66 | "tvmonitor", 67 | "wall", 68 | "water", 69 | "window", 70 | "wood", 71 | ) 72 | 73 | mapping = np.array( 74 | [ 75 | 0, 76 | 2, 77 | 9, 78 | 18, 79 | 19, 80 | 22, 81 | 23, 82 | 25, 83 | 31, 84 | 33, 85 | 34, 86 | 44, 87 | 45, 88 | 46, 89 | 59, 90 | 65, 91 | 68, 92 | 72, 93 | 80, 94 | 85, 95 | 98, 96 | 104, 97 | 105, 98 | 113, 99 | 115, 100 | 144, 101 | 158, 102 | 159, 103 | 162, 104 | 187, 105 | 189, 106 | 207, 107 | 220, 108 | 232, 109 | 258, 110 | 259, 111 | 260, 112 | 284, 113 | 295, 114 | 296, 115 | 308, 116 | 324, 117 | 326, 118 | 347, 119 | 349, 120 | 354, 121 | 355, 122 | 360, 123 | 366, 124 | 368, 125 | 397, 126 | 415, 127 | 416, 128 | 420, 129 | 424, 130 | 427, 131 | 440, 132 | 445, 133 | 454, 134 | 458, 135 | ] 136 | ) 137 | 138 | 139 | if __name__ == "__main__": 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("data_path", type=str) 142 | args = parser.parse_args() 143 | 144 | root_path = args.data_path 145 | 146 | outdir = os.path.join(root_path, "VOC2010_Context") 147 | if not os.path.exists(outdir): 148 | os.makedirs(os.path.join(outdir, "Annotations", "train")) 149 | os.makedirs(os.path.join(outdir, "Annotations", "val")) 150 | os.makedirs(os.path.join(outdir, "Images", "train")) 151 | os.makedirs(os.path.join(outdir, "Images", "val")) 152 | 153 | files = glob.glob(os.path.join(root_path, "trainval", "trainval", "*.mat")) 154 | 155 | with open( 156 | os.path.join( 157 | root_path, 158 | "VOCtrainval_03-May-2010", 159 | "VOCdevkit", 160 | "VOC2010", 161 | "ImageSets", 162 | "Main", 163 | "train.txt", 164 | ) 165 | ) as file: 166 | train = file.readlines() 167 | train = [line.rstrip() for line in train] 168 | with open( 169 | os.path.join( 170 | root_path, 171 | "VOCtrainval_03-May-2010", 172 | 
if __name__ == "__main__":
    # convert the VOC2010-Context .mat annotations to 60-class PNG masks and
    # copy the matching JPEG images into a train/val folder layout
    parser = argparse.ArgumentParser()
    parser.add_argument("data_path", type=str)
    args = parser.parse_args()

    root_path = args.data_path

    # create the output folder structure once
    outdir = os.path.join(root_path, "VOC2010_Context")
    if not os.path.exists(outdir):
        os.makedirs(os.path.join(outdir, "Annotations", "train"))
        os.makedirs(os.path.join(outdir, "Annotations", "val"))
        os.makedirs(os.path.join(outdir, "Images", "train"))
        os.makedirs(os.path.join(outdir, "Images", "val"))

    files = glob.glob(os.path.join(root_path, "trainval", "trainval", "*.mat"))

    # read the official split id lists; sets give O(1) membership tests below
    # (previously lists were scanned linearly for every file)
    with open(
        os.path.join(
            root_path,
            "VOCtrainval_03-May-2010",
            "VOCdevkit",
            "VOC2010",
            "ImageSets",
            "Main",
            "train.txt",
        )
    ) as fh:
        train = {line.rstrip() for line in fh}
    with open(
        os.path.join(
            root_path,
            "VOCtrainval_03-May-2010",
            "VOCdevkit",
            "VOC2010",
            "ImageSets",
            "Main",
            "val.txt",
        )
    ) as fh:
        val = {line.rstrip() for line in fh}

    print(len(train), len(val))
    print(len(files))

    print("## Covert Annoation Data ##")
    for mat_file in tqdm(files):
        # file id without directory and extension (avoid shadowing builtin `id`)
        sample_id = mat_file.split("/")[-1].split(".")[0]

        label_file = os.path.join(root_path, "trainval", "trainval", sample_id + ".mat")

        mat = scipy.io.loadmat(label_file)["LabelMap"]
        img = np.zeros(mat.shape)

        # translate each raw 459-class value to its 60-class index via `mapping`
        # (avoid shadowing builtin `map`)
        for value in np.unique(mat):
            matched = np.where(mapping == value)[0]
            if matched.size > 0:
                img[mat == value] = matched

        if sample_id in train:
            outfile = os.path.join(outdir, "Annotations", "train", sample_id + ".png")
            img_pil = Image.fromarray(np.uint8(img))
            img_pil.save(outfile)

        if sample_id in val:
            outfile = os.path.join(outdir, "Annotations", "val", sample_id + ".png")
            img_pil = Image.fromarray(np.uint8(img))
            img_pil.save(outfile)

    print("## Copy Image Data ##")
    for mat_file in tqdm(files):
        sample_id = mat_file.split("/")[-1].split(".")[0]

        img_file = os.path.join(
            root_path,
            "VOCtrainval_03-May-2010",
            "VOCdevkit",
            "VOC2010",
            "JPEGImages",
            sample_id + ".jpg",
        )

        if sample_id in train:
            outfile = os.path.join(outdir, "Images", "train", sample_id + ".jpg")

            shutil.copy(img_file, outfile)

        if sample_id in val:
            outfile = os.path.join(outdir, "Images", "val", sample_id + ".jpg")

            shutil.copy(img_file, outfile)
-------------------------------------------------------------------------------- /imgs/Epochs_Batch_Size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Epochs_Batch_Size.png -------------------------------------------------------------------------------- /imgs/Further.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Further.png -------------------------------------------------------------------------------- /imgs/Logos/DKFZ_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Logos/DKFZ_Logo.png -------------------------------------------------------------------------------- /imgs/Logos/HI_Logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Logos/HI_Logo.png -------------------------------------------------------------------------------- /imgs/Logos/HI_Title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Logos/HI_Title.png -------------------------------------------------------------------------------- /imgs/Lossfunctions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Lossfunctions.png -------------------------------------------------------------------------------- 
/imgs/Mixed_Precision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Mixed_Precision.png -------------------------------------------------------------------------------- /imgs/Models_Basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Models_Basic.png -------------------------------------------------------------------------------- /imgs/RMI_Loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/RMI_Loss.png -------------------------------------------------------------------------------- /imgs/Time_Complexity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/Time_Complexity.png -------------------------------------------------------------------------------- /imgs/VOC2010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/imgs/VOC2010.png -------------------------------------------------------------------------------- /models/DeepLabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.segmentation import deeplabv3_resnet101 3 | from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3 4 | from torchvision.models.segmentation.fcn import FCNHead 5 | 6 | 7 | def get_seg_model( 8 | num_classes: int, pretrained: bool, aux_loss: bool = None, backbone: 
def get_seg_model(
    num_classes: int, pretrained: bool, aux_loss: bool = None, backbone: str = "resnet101", **kwargs
) -> DeepLabV3:
    """
    Initialize a DeepLab Model with a resnet backbone
    First init the basic torchvision model
    to enable the use of pretrained weights with different number of classes, the last layer of the
    classifier + aux_classifier is adapted after the basic initialization

    Parameters
    ----------
    num_classes: int
    pretrained: bool
    aux_loss: aux_loss
    backbone: str
        resnet50 or resnet101
    kwargs

    Returns
    -------
    DeepLabV3 :
    """

    # load the deeplab model with the corresponding backbone
    if backbone == "resnet101":
        model = deeplabv3_resnet101(pretrained=pretrained, aux_loss=aux_loss, **kwargs)
    elif backbone == "resnet50":
        # BUG FIX: this branch previously built deeplabv3_resnet101 as well,
        # silently ignoring the requested resnet50 backbone
        from torchvision.models.segmentation import deeplabv3_resnet50

        model = deeplabv3_resnet50(pretrained=pretrained, aux_loss=aux_loss, **kwargs)
    else:
        # previously an unknown backbone crashed below with a NameError
        raise ValueError(f"Unsupported backbone: {backbone}")

    # to enable pretrained weights the last layer in the classifier head is adopted to match to the
    # number of classes after initializing of the model with pretrained weights
    in_channels = model.classifier[4].in_channels
    kernel_size = model.classifier[4].kernel_size
    stride = model.classifier[4].stride
    model.classifier[4] = torch.nn.Conv2d(in_channels, num_classes, kernel_size, stride)

    # the same is done for the aux_classifier if it exists
    # NOTE(review): the attribute can be present but None when aux_loss is
    # disabled, in which case hasattr() alone would pass and then subscript None
    if getattr(model, "aux_classifier", None) is not None:
        in_channels = model.aux_classifier[4].in_channels
        kernel_size = model.aux_classifier[4].kernel_size
        stride = model.aux_classifier[4].stride
        model.aux_classifier[4] = torch.nn.Conv2d(in_channels, num_classes, kernel_size, stride)

    return model
def get_seg_model(
    num_classes: int, pretrained: bool, aux_loss: bool = None, backbone: str = "resnet101", **kwargs
) -> FCN:
    """
    Initialize a FCN Model with a resnet backbone
    First init the basic torchvision model
    to enable the use of pretrained weights with different number of classes, the last layer of the
    classifier + aux_classifier is adapted after the basic initialization

    Parameters
    ----------
    num_classes: int
    pretrained: bool
    aux_loss: aux_loss
    backbone: str
        resnet50 or resnet101
    kwargs

    Returns
    -------
    FCN :
    """

    # load the fcn model with the corresponding backbone
    if backbone == "resnet101":
        model = fcn_resnet101(pretrained=pretrained, aux_loss=aux_loss, **kwargs)
    elif backbone == "resnet50":
        model = fcn_resnet50(pretrained=pretrained, aux_loss=aux_loss, **kwargs)
    else:
        # previously an unknown backbone crashed below with a NameError
        raise ValueError(f"Unsupported backbone: {backbone}")

    # to enable pretrained weights the last layer in the classifier head is adopted to match to the
    # number of classes after initializing of the model with pretrained weights
    in_channels = model.classifier[4].in_channels
    kernel_size = model.classifier[4].kernel_size
    stride = model.classifier[4].stride
    model.classifier[4] = torch.nn.Conv2d(in_channels, num_classes, kernel_size, stride)

    # the same is done for the aux_classifier if it exists
    # NOTE(review): the attribute can be present but None when aux_loss is
    # disabled, in which case hasattr() alone would pass and then subscript None
    if getattr(model, "aux_classifier", None) is not None:
        in_channels = model.aux_classifier[4].in_channels
        kernel_size = model.aux_classifier[4].kernel_size
        stride = model.aux_classifier[4].stride
        model.aux_classifier[4] = torch.nn.Conv2d(in_channels, num_classes, kernel_size, stride)

    return model
torch.nn.Conv2d(in_channels, num_classes, kernel_size, stride) 48 | 49 | return model 50 | -------------------------------------------------------------------------------- /models/Mask_RCNN.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 3 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor 4 | from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights 5 | from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights 6 | from torchvision.models.detection.transform import GeneralizedRCNNTransform 7 | 8 | 9 | # Disable the resize and normalize transform since we take care of this in the data augmentation pipeline 10 | class GeneralizedRCNNTransform_no_transform(GeneralizedRCNNTransform): 11 | def resize(self, image, target=None): 12 | return image, target 13 | 14 | def normalize(self, image): 15 | return image 16 | 17 | 18 | def get_model_50( 19 | num_classes, pretrained=True, version="v1", disable_transforms=False, *args, **kwargs 20 | ): 21 | # load an instance segmentation model pre-trained on COCO 22 | if pretrained: 23 | if version == "v1": 24 | model = models.detection.maskrcnn_resnet50_fpn( 25 | weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, box_detections_per_img=250 26 | ) 27 | elif version == "v2": 28 | model = models.detection.maskrcnn_resnet50_fpn_v2( 29 | weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT, box_detections_per_img=250 30 | ) 31 | else: 32 | if version == "v1": 33 | model = models.detection.maskrcnn_resnet50_fpn(box_detections_per_img=250) 34 | elif version == "v2": 35 | model = models.detection.maskrcnn_resnet50_fpn_v2(box_detections_per_img=250) 36 | 37 | # get the number of input features for the classifier 38 | in_features = model.roi_heads.box_predictor.cls_score.in_features 39 | # replace the pre-trained head with a new one 40 | 
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 41 | 42 | # now get the number of input features for the mask classifier 43 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels 44 | hidden_layer = 256 45 | # and replace the mask predictor with a new one 46 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes) 47 | 48 | if disable_transforms: 49 | model.transform = GeneralizedRCNNTransform_no_transform( 50 | min_size=None, max_size=None, image_mean=None, image_std=None 51 | ) 52 | 53 | return model 54 | -------------------------------------------------------------------------------- /models/Mask_RCNN_RMI_loss.py: -------------------------------------------------------------------------------- 1 | from torchvision.models.detection import roi_heads 2 | from torchvision.models.detection.roi_heads import project_masks_on_boxes 3 | 4 | from torchvision import models 5 | from torchvision.models.detection.faster_rcnn import FastRCNNPredictor 6 | from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor 7 | from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights 8 | from torchvision.models.detection import MaskRCNN_ResNet50_FPN_V2_Weights 9 | from torchvision.models.detection.transform import GeneralizedRCNNTransform 10 | import torch 11 | from src.loss.rmi import RMILoss 12 | 13 | 14 | class RMI_loss_dummy: 15 | def __init__(self, weights=None, ignore_index=255): 16 | if weights: 17 | weights = torch.FloatTensor(weights).cuda() 18 | self.lossfunction = RMILoss(num_classes=2, class_weights=weights, ignore_index=ignore_index) 19 | 20 | def __call__(self, mask_logits, proposals, gt_masks, gt_labels, mask_matched_idxs): 21 | discretization_size = mask_logits.shape[-1] 22 | mask_targets = [ 23 | project_masks_on_boxes(m, p, i, discretization_size) 24 | for m, p, i in zip(gt_masks, proposals, mask_matched_idxs) 25 | ] 26 | mask_targets = torch.cat(mask_targets, 
dim=0).type(torch.int64) 27 | 28 | loss = self.lossfunction(mask_logits, mask_targets) 29 | return loss 30 | 31 | 32 | roi_heads.maskrcnn_loss = RMI_loss_dummy() 33 | 34 | 35 | # Disable the resize and normalize transform since we take care of this in the data augmentation pipeline 36 | class GeneralizedRCNNTransform_no_transform(GeneralizedRCNNTransform): 37 | def resize(self, image, target=None): 38 | return image, target 39 | 40 | def normalize(self, image): 41 | return image 42 | 43 | 44 | def get_model_50( 45 | num_classes, pretrained=True, version="v1", disable_transforms=False, *args, **kwargs 46 | ): 47 | # load an instance segmentation model pre-trained on COCO 48 | if pretrained: 49 | if version == "v1": 50 | model = models.detection.maskrcnn_resnet50_fpn( 51 | weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT, box_detections_per_img=250 52 | ) 53 | elif version == "v2": 54 | model = models.detection.maskrcnn_resnet50_fpn_v2( 55 | weights=MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT, box_detections_per_img=250 56 | ) 57 | else: 58 | if version == "v1": 59 | model = models.detection.maskrcnn_resnet50_fpn(box_detections_per_img=250) 60 | elif version == "v2": 61 | model = models.detection.maskrcnn_resnet50_fpn_v2(box_detections_per_img=250) 62 | 63 | # get the number of input features for the classifier 64 | in_features = model.roi_heads.box_predictor.cls_score.in_features 65 | # replace the pre-trained head with a new one 66 | model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 67 | 68 | # now get the number of input features for the mask classifier 69 | in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels 70 | hidden_layer = 256 71 | # and replace the mask predictor with a new one 72 | model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes) 73 | 74 | if disable_transforms: 75 | model.transform = GeneralizedRCNNTransform_no_transform( 76 | min_size=None, max_size=None, 
image_mean=None, image_std=None 77 | ) 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /models/UNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ------------------------------------------------------------------------------ 3 | Code slightly adapted and mainly from: 4 | https://github.com/milesial/Pytorch-UNet 5 | - model https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_model.py 6 | - model src https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py 7 | ------------------------------------------------------------------------------ 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class DoubleConv(nn.Module): 16 | """(convolution => [BN] => ReLU) * 2""" 17 | 18 | def __init__(self, in_channels, out_channels, mid_channels=None): 19 | super().__init__() 20 | if not mid_channels: 21 | mid_channels = out_channels 22 | self.double_conv = nn.Sequential( 23 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), 24 | nn.BatchNorm2d(mid_channels), 25 | nn.ReLU(inplace=True), 26 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), 27 | nn.BatchNorm2d(out_channels), 28 | nn.ReLU(inplace=True), 29 | ) 30 | 31 | def forward(self, x): 32 | return self.double_conv(x) 33 | 34 | 35 | class Down(nn.Module): 36 | """Downscaling with maxpool then double conv""" 37 | 38 | def __init__(self, in_channels, out_channels): 39 | super().__init__() 40 | self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels)) 41 | 42 | def forward(self, x): 43 | return self.maxpool_conv(x) 44 | 45 | 46 | class Up(nn.Module): 47 | """Upscaling then double conv""" 48 | 49 | def __init__(self, in_channels, out_channels, bilinear=True): 50 | super().__init__() 51 | 52 | # if bilinear, use the normal convolutions to reduce the number of 
channels 53 | if bilinear: 54 | self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) 55 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 56 | else: 57 | self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) 58 | self.conv = DoubleConv(in_channels, out_channels) 59 | 60 | def forward(self, x1, x2): 61 | x1 = self.up(x1) 62 | # input is CHW 63 | diffY = x2.size()[2] - x1.size()[2] 64 | diffX = x2.size()[3] - x1.size()[3] 65 | 66 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) 67 | # if you have padding issues, see 68 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 69 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 70 | x = torch.cat([x2, x1], dim=1) 71 | return self.conv(x) 72 | 73 | 74 | class OutConv(nn.Module): 75 | def __init__(self, in_channels, out_channels): 76 | super(OutConv, self).__init__() 77 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 78 | 79 | def forward(self, x): 80 | return self.conv(x) 81 | 82 | 83 | class UNet(nn.Module): 84 | def __init__(self, n_channels, n_classes, bilinear=False): 85 | super(UNet, self).__init__() 86 | self.n_channels = n_channels 87 | self.n_classes = n_classes 88 | self.bilinear = bilinear 89 | 90 | self.inc = DoubleConv(n_channels, 64) 91 | self.down1 = Down(64, 128) 92 | self.down2 = Down(128, 256) 93 | self.down3 = Down(256, 512) 94 | factor = 2 if bilinear else 1 95 | self.down4 = Down(512, 1024 // factor) 96 | self.up1 = Up(1024, 512 // factor, bilinear) 97 | self.up2 = Up(512, 256 // factor, bilinear) 98 | self.up3 = Up(256, 128 // factor, bilinear) 99 | self.up4 = Up(128, 64, bilinear) 100 | self.outc = OutConv(64, n_classes) 101 | 102 | def forward(self, x): 103 | x1 = self.inc(x) 104 | x2 = self.down1(x1) 105 | x3 = self.down2(x2) 106 | x4 = self.down3(x3) 107 | x5 = 
self.down4(x4) 108 | x = self.up1(x5, x4) 109 | x = self.up2(x, x3) 110 | x = self.up3(x, x2) 111 | x = self.up4(x, x1) 112 | logits = self.outc(x) 113 | return logits 114 | -------------------------------------------------------------------------------- /models/model_ensemble.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import torch.nn as nn 4 | from omegaconf import OmegaConf, DictConfig 5 | import hydra 6 | import os 7 | import torch 8 | from src.utils import get_logger 9 | 10 | log = get_logger(__name__) 11 | 12 | 13 | class Ensemble(nn.Module): 14 | def __init__(self, ckpts): 15 | super(Ensemble, self).__init__() 16 | models = [] 17 | for ckpt in ckpts: 18 | 19 | # Init Model 20 | model_ckpt = OmegaConf.load(os.path.join(ckpt, "hparams.yaml")).model 21 | if hasattr(model_ckpt.cfg.MODEL, "PRETRAINED"): 22 | model_ckpt.cfg.MODEL.PRETRAINED = False 23 | model = hydra.utils.instantiate(model_ckpt) 24 | 25 | # Load State Dict 26 | ckpt_file = glob.glob(os.path.join(ckpt, "checkpoints", "best_*.ckpt"))[0] 27 | state_dict_ckpt = torch.load(ckpt_file, map_location={"cuda:0": "cpu"}) 28 | if "state_dict" in state_dict_ckpt.keys(): 29 | state_dict_ckpt = state_dict_ckpt["state_dict"] 30 | 31 | state_dict_ckpt = { 32 | k.replace("model.", "").replace("module.", ""): v 33 | for k, v in state_dict_ckpt.items() 34 | } 35 | model.load_state_dict(state_dict_ckpt) 36 | model.eval().cuda() 37 | models.append(model) 38 | 39 | log.info("{} loaded from ckpt {}".format(model_ckpt.cfg.MODEL.NAME, ckpt_file)) 40 | 41 | self.models = models 42 | 43 | def forward(self, x): 44 | out = None 45 | for m in self.models: 46 | if out is None: 47 | out = m(x)["out"] 48 | else: 49 | out += m(x)["out"] 50 | 51 | out_avg = out / len(self.models) 52 | 53 | return out_avg 54 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIC-DKFZ/semantic_segmentation/2bf503d94ee16e8910e7a28f553eecb7b6c28877/pretrained/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | preview = true -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Pytorch + Pytorch-Lightning 2 | # Adapt the cuda versions for torch and torchvision to your needs 3 | --find-links https://download.pytorch.org/whl/cu116/torch_stable.html 4 | torch>=1.13.1+cu116 5 | torchvision>=0.14.1+cu116 6 | torchmetrics==0.11.3 7 | pycocotools==2.0.5 8 | pytorch-lightning==2.0 9 | tensorboard>=2.12.0 10 | 11 | # Madgrad optimizer 12 | madgrad==1.2 13 | 14 | # Hydra 15 | hydra-core==1.3.2 16 | hydra-colorlog==1.2.0 17 | omegaconf==2.3.0 18 | 19 | # Packages for data handling, augmentation and visualization 20 | numpy>=1.23.2 21 | opencv-python>=4.6.0.66 22 | albumentations>=1.3.0 23 | pillow>=9.2.0 24 | matplotlib>=3.5.3 25 | 26 | # Code Formatter 27 | black>=22.12.0 28 | 29 | # Needed for loading .mat files for processing VOC2010_Context Dataset 30 | scipy>=1.9.0 31 | 32 | -------------------------------------------------------------------------------- /src/augmentations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import albumentations as A 3 | 4 | def RandAugment(N: int, M: int, p: float = 0.5, bins: int = 10, mode: int = 0) -> object: 5 | 6 | """ 7 | https://towardsdatascience.com/augmentation-methods-using-albumentations-and-pytorch-35cd135382f8 8 | # https://openreview.net/pdf?id=JrBfXaoxbA2 9 | # https://arxiv.org/pdf/1805.09501.pdf 10 | 
https://arxiv.org/pdf/1909.13719.pdf, 11 | Parameters 12 | ---------- 13 | N: int 14 | Number of Operations 15 | M: int 16 | Magnitude, intensity of the operations 17 | p: float, optional 18 | Probability for each operation to be applied 19 | bins: int, optional 20 | Number of Bins, max value for M 21 | mode: int, optional 22 | Border Mode for geometric transforms 23 | BORDER_CONSTANT = 0, 24 | BORDER_REPLICATE = 1, 25 | BORDER_REFLECT = 2, 26 | BORDER_WRAP = 3, 27 | BORDER_REFLECT_101 = 4, 28 | BORDER_TRANSPARENT = 5, 29 | BORDER_REFLECT101 = BORDER_REFLECT_101, 30 | BORDER_DEFAULT = BORDER_REFLECT_101, 31 | BORDER_ISOLATED = 16, 32 | """ 33 | """ 34 | Color Transformations 35 | """ 36 | # Contrast: 1=original, 0=gray, Range [max(0, 1 - contrast), 1 + contrast] 37 | contrast = np.linspace(0.0, 0.9, bins) 38 | # Brightness: 1=original, 0=black, Range [max(0, 1 - brightness), 1 + brightness] 39 | brightness = np.linspace(0.0, 0.9, bins) 40 | # Color: 1=original, 0=black-white image, Range [max(0, 1 - color), 1 + color] 41 | color = np.linspace(0.0, 0.9, bins) 42 | # Solariize: 255=original, 0=inverted, Range [solarize,255] 43 | solarize = np.linspace(255, 0, bins) 44 | # Posterize: Range [8,4] 45 | posterize = (8 - (np.arange(bins) / ((bins - 1) / 4)).round()).astype(int) 46 | # Sharpen 47 | sharpen = np.linspace(0.0, 2.0, bins) 48 | 49 | col_transforms = [ 50 | # Identity 51 | A.NoOp(p=p), 52 | # Contrast - equivalent to F.adjust_contrast(Image.fromarray(img), val) 53 | A.ColorJitter(brightness=0, contrast=contrast[M], saturation=0, hue=0, p=p), 54 | # Brightness - equivalent to F.adjust_brightness(Image.fromarray(img), val) 55 | A.ColorJitter(brightness=brightness[M], contrast=0, saturation=0, hue=0, p=p), 56 | # Color (~Saturation) - equivalent to F.adjust_saturation(Image.fromarray(img), val) 57 | A.ColorJitter(brightness=0, contrast=0, saturation=color[M], hue=0, p=p), 58 | # Solarize - eqilvalent to F.solarize(Image.fromarray(img), int(val)) 59 | 
A.Solarize(threshold=(solarize[M], 255), p=p), 60 | # Posterize - F.posterize(Image.fromarray(img), int(val)) 61 | A.Posterize(num_bits=int(posterize[M]), p=p), 62 | # Equalize - equivalent to F.equalize(Image.fromarray(img)) 63 | A.Equalize(mode="pil", p=p), 64 | # Invert - equalize to F.invert(Image.fromarray(img)) 65 | A.InvertImg(p=p), 66 | # Sharpen - no equivalent sharpen in albumentations compared to F.adjust_sharpness(Image.fromarray(img.copy()), val) 67 | # Replaced by Unsharpen mask + blure 68 | # A.UnsharpMask(sigma_limit=sharpen[M], alpha=(0.2, np.clip(sharpen, 0.2, 1.0)), p=p), 69 | A.UnsharpMask(sigma_limit=sharpen[M], p=p), 70 | A.Blur(blur_limit=3, p=p), 71 | ] 72 | """ 73 | Geometric Transformations 74 | """ 75 | # Shear X: 0=no shear, ~17=max degree, Range [-shear_x,shear_x] 76 | shear_x = np.linspace(0, np.degrees(np.arctan(0.3)), bins) 77 | # Shear Y: 0=no shear, ~17=max degree, Range [-shear_y,shear_y] 78 | shear_y = np.linspace(0, np.degrees(np.arctan(0.3)), bins) 79 | # Translate X: 0=no translation 0.2=max translation in %, Range [-translate_x,translate_x] 80 | translate_x = np.linspace(0, 0.2, bins) 81 | # Translate Y: 0=no translation 0.2=max translation in %, Range [-translate_y,translate_y] 82 | translate_y = np.linspace(0, 0.2, bins) 83 | # Rotate: 0=no rotation, 30=rotationabout 30 degree, Range [-rotation,rotation] 84 | rotation = np.linspace(0, 30, bins) 85 | 86 | geo_transforms = [ 87 | # Shear X - equivalen to F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[0, 0],scale=1.0,shear=[math.degrees(math.atan(val)),0.0],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 88 | A.Affine(shear={"x": (-shear_x[M], shear_x[M]), "y": 0}, p=p, mode=mode), 89 | # Shear Y - equivalen to F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[0, 0],scale=1.0,shear=[0.0, math.degrees(math.atan(val))],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 90 | 
A.Affine(shear={"x": 0, "y": (-shear_y[M], shear_y[M])}, p=p, mode=mode), 91 | # Translate X - F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[150, 0],scale=1.0,shear=[0.0, 0.0],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 92 | A.Affine( 93 | translate_percent={"x": (-translate_x[M], translate_x[M]), "y": 0}, p=p, mode=mode 94 | ), 95 | # Translate Y - F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[0,150],scale=1.0,shear=[0.0, 0.0],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 96 | A.Affine( 97 | translate_percent={"x": 0, "y": (-translate_y[M], translate_y[M])}, p=p, mode=mode 98 | ), 99 | # Rotate - equivalent to F.rotate(Image.fromarray(img.copy()),val,interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,) 100 | A.Affine(rotate=(-rotation[M], rotation[M]), p=p, mode=mode), 101 | ] 102 | 103 | """ 104 | Return RandAugment pipeline 105 | """ 106 | transforms = A.SomeOf(col_transforms + geo_transforms, n=N) 107 | return transforms 108 | 109 | 110 | def RandAugment_light(N: int, M: int, p: float = 0.5, bins: int = 10, mode: int = 0) -> object: 111 | 112 | """ 113 | subset of color augmentations of RandAugment 114 | https://towardsdatascience.com/augmentation-methods-using-albumentations-and-pytorch-35cd135382f8 115 | # https://openreview.net/pdf?id=JrBfXaoxbA2 116 | # https://arxiv.org/pdf/1805.09501.pdf 117 | Parameters 118 | ---------- 119 | N: int 120 | Number of Operations 121 | M: int 122 | Magnitude, intensity of the operations 123 | p: float, optional 124 | Probability for each operation to be applied 125 | bins: int, optional 126 | Number of Bins, max value for M 127 | mode: int, optional 128 | Border Mode for geometric transforms 129 | BORDER_CONSTANT = 0, 130 | BORDER_REPLICATE = 1, 131 | BORDER_REFLECT = 2, 132 | BORDER_WRAP = 3, 133 | BORDER_REFLECT_101 = 4, 134 | BORDER_TRANSPARENT = 5, 135 | BORDER_REFLECT101 = 
BORDER_REFLECT_101, 136 | BORDER_DEFAULT = BORDER_REFLECT_101, 137 | BORDER_ISOLATED = 16, 138 | """ 139 | """ 140 | Color Transformations 141 | """ 142 | # Contrast: 1=original, 0=gray, Range [max(0, 1 - contrast), 1 + contrast] 143 | contrast = np.linspace(0.0, 0.9, bins) 144 | # Brightness: 1=original, 0=black, Range [max(0, 1 - brightness), 1 + brightness] 145 | brightness = np.linspace(0.0, 0.9, bins) 146 | # Color: 1=original, 0=black-white image, Range [max(0, 1 - color), 1 + color] 147 | color = np.linspace(0.0, 0.9, bins) 148 | # Solariize: 255=original, 0=inverted, Range [solarize,255] 149 | solarize = np.linspace(255, 0, bins) 150 | # Posterize: Range [8,4] 151 | posterize = (8 - (np.arange(bins) / ((bins - 1) / 4)).round()).astype(int) 152 | # Sharpen 153 | sharpen = np.linspace(0.0, 2.0, bins) 154 | 155 | col_transforms = [ 156 | # Identity 157 | A.NoOp(p=p), 158 | # Contrast - equivalent to F.adjust_contrast(Image.fromarray(img), val) 159 | A.ColorJitter(brightness=0, contrast=contrast[M], saturation=0, hue=0, p=p), 160 | # Brightness - equivalent to F.adjust_brightness(Image.fromarray(img), val) 161 | A.ColorJitter(brightness=brightness[M], contrast=0, saturation=0, hue=0, p=p), 162 | # Color (~Saturation) - equivalent to F.adjust_saturation(Image.fromarray(img), val) 163 | A.ColorJitter(brightness=0, contrast=0, saturation=color[M], hue=0, p=p), 164 | # Solarize - eqilvalent to F.solarize(Image.fromarray(img), int(val)) 165 | # A.Solarize(threshold=(solarize[M], 255), p=p), 166 | # Posterize - F.posterize(Image.fromarray(img), int(val)) 167 | # A.Posterize(num_bits=int(posterize[M]), p=p), 168 | # Equalize - equivalent to F.equalize(Image.fromarray(img)) 169 | # A.Equalize(mode="pil", p=p), 170 | # Invert - equalize to F.invert(Image.fromarray(img)) 171 | # A.InvertImg(p=p), 172 | # Sharpen - no equivalent sharpen in albumentations compared to F.adjust_sharpness(Image.fromarray(img.copy()), val) 173 | # Replaced by Unsharpen mask + blure 174 | # 
A.UnsharpMask(sigma_limit=sharpen[M], alpha=(0.2, np.clip(sharpen, 0.2, 1.0)), p=p), 175 | A.UnsharpMask(sigma_limit=sharpen[M], p=p), 176 | A.Blur(blur_limit=3, p=p), 177 | ] 178 | """ 179 | Geometric Transformations 180 | """ 181 | # Shear X: 0=no shear, ~17=max degree, Range [-shear_x,shear_x] 182 | shear_x = np.linspace(0, np.degrees(np.arctan(0.3)), bins) 183 | # Shear Y: 0=no shear, ~17=max degree, Range [-shear_y,shear_y] 184 | shear_y = np.linspace(0, np.degrees(np.arctan(0.3)), bins) 185 | # Translate X: 0=no translation 0.2=max translation in %, Range [-translate_x,translate_x] 186 | translate_x = np.linspace(0, 0.2, bins) 187 | # Translate Y: 0=no translation 0.2=max translation in %, Range [-translate_y,translate_y] 188 | translate_y = np.linspace(0, 0.2, bins) 189 | # Rotate: 0=no rotation, 30=rotationabout 30 degree, Range [-rotation,rotation] 190 | rotation = np.linspace(0, 30, bins) 191 | 192 | geo_transforms = [ 193 | # Shear X - equivalen to F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[0, 0],scale=1.0,shear=[math.degrees(math.atan(val)),0.0],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 194 | A.Affine(shear={"x": (-shear_x[M], shear_x[M]), "y": 0}, p=p, mode=mode), 195 | # Shear Y - equivalen to F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[0, 0],scale=1.0,shear=[0.0, math.degrees(math.atan(val))],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 196 | A.Affine(shear={"x": 0, "y": (-shear_y[M], shear_y[M])}, p=p, mode=mode), 197 | # Translate X - F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[150, 0],scale=1.0,shear=[0.0, 0.0],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 198 | A.Affine( 199 | translate_percent={"x": (-translate_x[M], translate_x[M]), "y": 0}, p=p, mode=mode 200 | ), 201 | # Translate Y - 
F.affine(Image.fromarray(img.copy()),angle=0.0,translate=[0,150],scale=1.0,shear=[0.0, 0.0],interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,center=[0, 0],) 202 | A.Affine( 203 | translate_percent={"x": 0, "y": (-translate_y[M], translate_y[M])}, p=p, mode=mode 204 | ), 205 | # Rotate - equivalent to F.rotate(Image.fromarray(img.copy()),val,interpolation=torchvision.transforms.InterpolationMode.NEAREST,fill=None,) 206 | A.Affine(rotate=(-rotation[M], rotation[M]), p=p, mode=mode), 207 | ] 208 | 209 | """ 210 | Return RandAugment pipeline 211 | """ 212 | transforms = A.SomeOf(col_transforms + geo_transforms, n=N) 213 | return transforms 214 | -------------------------------------------------------------------------------- /src/callbacks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint 5 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 6 | from pytorch_lightning.callbacks.progress.tqdm_progress import _update_n 7 | from pytorch_lightning.callbacks.progress.tqdm_progress import convert_inf, Tqdm 8 | import sys 9 | from typing import Any, Dict, Optional, Union 10 | 11 | from pytorch_lightning.utilities.types import STEP_OUTPUT 12 | 13 | 14 | class customModelCheckpoint(ModelCheckpoint): 15 | """ 16 | Small modification on the ModelCheckpoint from pytorch lightning for renaming the last epoch 17 | """ 18 | 19 | def __init__(self, **kwargs): 20 | super( 21 | customModelCheckpoint, 22 | self, 23 | ).__init__(**kwargs) 24 | self.CHECKPOINT_NAME_LAST = "last_epoch_{epoch}" 25 | 26 | 27 | class customTQDMProgressBar(TQDMProgressBar): 28 | """ 29 | Small modification on the TQDMProgressBar class from pytorch lightning to get rid of the 30 | "v_num" entry and the printing bug during validation (linebreak + print in every step) 31 | 
https://stackoverflow.com/questions/59455268/how-to-disable-progress-bar-in-pytorch-lightning/66731318#66731318 32 | https://github.com/PyTorchLightning/pytorch-lightning/issues/765 33 | this is another solution to use the terminal as output console for Pycharm 34 | https://stackoverflow.com/questions/59455268/how-to-disable-progress-bar-in-pytorch-lightning 35 | """ 36 | 37 | def __init__(self, *args, **kwargs): 38 | super().__init__(*args, **kwargs) 39 | self.status = "None" 40 | 41 | def init_validation_tqdm(self): 42 | ### disable validation tqdm instead only use train_progress_bar### 43 | bar = tqdm( 44 | disable=True, 45 | ) 46 | return bar 47 | 48 | def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module) -> None: 49 | # reset progress, set status and update metric 50 | self.train_progress_bar.reset( 51 | convert_inf(self.total_train_batches + self.total_val_batches) 52 | ) 53 | self.train_progress_bar.initial = 0 54 | self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") 55 | self.status = "Training" 56 | self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) 57 | 58 | def on_validation_batch_end( 59 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0 60 | ): 61 | n = self.total_train_batches + batch_idx + 1 62 | if self._should_update(n, self.train_progress_bar.total): 63 | _update_n(self.train_progress_bar, n) 64 | 65 | def get_metrics(self, trainer, model): 66 | # don't show the version number 67 | items = super().get_metrics(trainer, model) 68 | items.pop("v_num", None) 69 | items["status"] = self.status 70 | 71 | return items 72 | 73 | def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"): 74 | if not trainer.sanity_checking: 75 | self.status = "Validation" 76 | self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) 77 | 78 | def on_validation_epoch_end( 79 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 80 | ) -> None: 
81 | self.status = "Done" 82 | self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) 83 | self.train_progress_bar.refresh() 84 | print("") 85 | 86 | 87 | class TimeCallback(Callback): 88 | """ 89 | Callback for measuring the time during train, validation and testing 90 | """ 91 | 92 | def __init__(self): 93 | self.t_train_start = torch.cuda.Event(enable_timing=True) 94 | self.t_val_start = torch.cuda.Event(enable_timing=True) 95 | self.t_test_start = torch.cuda.Event(enable_timing=True) 96 | self.end = torch.cuda.Event(enable_timing=True) 97 | self.t_train = [] 98 | self.t_val = [] 99 | 100 | def on_train_epoch_start(self, *args, **kwargs): 101 | self.t_train_start.record() 102 | 103 | def on_validation_epoch_start(self, trainer, *args, **kwargs): 104 | if not trainer.sanity_checking: 105 | self.end.record() 106 | torch.cuda.synchronize() 107 | 108 | train_time = self.t_train_start.elapsed_time(self.end) / 1000 109 | self.t_train.append(train_time) 110 | 111 | self.log( 112 | "Time/train_time", 113 | train_time, 114 | logger=True, 115 | sync_dist=True if trainer.num_devices > 1 else False, 116 | prog_bar=True, 117 | ) 118 | self.log( 119 | "Time/mTrainTime", 120 | np.mean(self.t_train), 121 | logger=True, 122 | sync_dist=True if trainer.num_devices > 1 else False, 123 | ) 124 | 125 | if not trainer.sanity_checking: 126 | self.t_val_start.record() 127 | 128 | def on_validation_epoch_end(self, trainer, *args, **kwargs): 129 | if not trainer.sanity_checking: 130 | self.end.record() 131 | torch.cuda.synchronize() 132 | 133 | val_time = self.t_val_start.elapsed_time(self.end) / 1000 134 | self.t_val.append(val_time) 135 | 136 | self.log( 137 | "Time/validation_time", 138 | val_time, 139 | logger=True, 140 | sync_dist=True if trainer.num_devices > 1 else False, 141 | prog_bar=True, 142 | ) 143 | self.log( 144 | "Time/mValTime", 145 | np.mean(self.t_val), 146 | logger=True, 147 | sync_dist=True if trainer.num_devices > 1 else False, 148 | ) 149 | 150 | 
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true/false positives/negatives between a (soft) prediction
    map and a ground-truth segmentation.

    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z)))
    or a one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))

    :param net_output: network predictions (probabilities or logits)
    :param gt: ground truth
    :param axes: axes to sum over; can be (, ) = no summation
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return: tuple (tp, fp, fn, tn)
    """
    if axes is None:
        # default: reduce over all spatial axes, keep batch and channel
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # label map without channel axis -> insert a singleton channel
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            # allocate directly on the prediction's device: replaces the
            # cuda-only special case and also works for cpu / mps etc.
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        # mask is (b, 1, ...): broadcasting over the channel axis is
        # numerically identical to the original unbind/stack round-trip
        tp = tp * mask
        fp = fp * mask
        fn = fn * mask
        tn = tn * mask

    if square:
        tp = tp**2
        fp = fp**2
        fn = fn**2
        tn = tn**2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn


def sum_tensor(inp, axes, keepdim=False):
    """
    Sum `inp` over `axes` (duplicates/unsorted entries tolerated).

    Replaces the per-axis Python loop with a single multi-dim `Tensor.sum`
    call; result is identical for both keepdim settings.
    """
    dims = tuple(int(a) for a in np.unique(axes))
    if not dims:
        return inp
    return inp.sum(dim=dims, keepdim=keepdim)
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """
    Compatibility wrapper around ``nn.CrossEntropyLoss``.

    Accepts targets that are float tensors and/or carry a singleton channel
    axis (shape (b, 1, ...) instead of (b, ...)) and normalises them before
    delegating to the parent class.
    """

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        squeezed = target
        if squeezed.ndim == input.ndim:
            # same rank as the prediction -> must be (b, 1, ...): drop channel
            assert squeezed.shape[1] == 1
            squeezed = squeezed[:, 0]
        # CrossEntropyLoss requires integer class indices
        return super().forward(input, squeezed.long())
    def forward(self, net_output, target):
        """
        Weighted sum of Soft-Dice loss and (robust) cross-entropy.

        target must be b, c, x, y(, z) with c=1 (a label map, not one-hot).

        :param net_output: raw network logits, shape (b, c, x, y(, z))
        :param target: label map, shape (b, 1, x, y(, z))
        :return: scalar loss tensor (weight_ce * CE + weight_dice * Dice)
        """
        if self.ignore_label is not None:
            assert target.shape[1] == 1, "not implemented for one hot encoding"
            # validity mask: 1 where the label is real, 0 where it is ignored
            # NOTE(review): `target[~mask] = 0` writes into the caller's tensor
            # in place — callers must pass a copy if they still need the
            # original labels (confirm at the call sites)
            mask = target != self.ignore_label
            target[~mask] = 0
            mask = mask.float()
        else:
            mask = None

        # a weight of exactly 0 skips the corresponding term entirely
        dc_loss = self.dc(net_output, target, loss_mask=mask) if self.weight_dice != 0 else 0
        if self.log_dice:
            # dc_loss is negative (SoftDiceLoss returns -dice), so -log(-dc) is defined
            dc_loss = -torch.log(-dc_loss)

        ce_loss = self.ce(net_output, target[:, 0].long()) if self.weight_ce != 0 else 0
        if self.ignore_label is not None:
            # ce was built with reduction="none": mask out ignored pixels and
            # renormalise by the number of valid ones
            ce_loss *= mask[:, 0]
            ce_loss = ce_loss.sum() / mask.sum()

        if self.aggregate == "sum":
            result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
        else:
            raise NotImplementedError("nah son")  # reserved for other stuff (later)
        return result
    def forward(self, net_output, target):
        """
        Sum of Soft-Dice loss and Top-k cross-entropy loss.

        :param net_output: raw network logits, shape (b, c, x, y(, z))
        :param target: label map, shape (b, 1, x, y(, z)) (one channel only)
        :return: scalar loss tensor (CE + Dice)
        """
        # dc_loss = self.dc(net_output, target)
        if self.ignore_label is not None:
            assert target.shape[1] == 1, "not implemented for one hot encoding"
            # validity mask: 1 where the label is real, 0 where it is ignored
            # NOTE(review): `target[~mask] = 0` mutates the caller's tensor in
            # place — callers must pass a copy if they still need the labels
            mask = target != self.ignore_label
            target[~mask] = 0
            mask = mask.float()
        else:
            mask = None
        dc_loss = self.dc(net_output, target, loss_mask=mask)
        # NOTE(review): unlike the Dice term, the TopK term receives no mask —
        # ignored pixels (remapped to class 0) still enter the top-k pool; confirm
        ce_loss = self.ce(net_output, target)
        if self.aggregate == "sum":
            result = ce_loss + dc_loss
        else:
            raise NotImplementedError("nah son")  # reserved for other stuff (later?)
        return result
def soft_dice_score(output, target, smooth=0.0, eps=1e-7, dims=None):
    """
    Soft (differentiable) Dice score between two tensors of identical shape.

    dice = (2 * |output ∩ target| + smooth) / (|output| + |target| + smooth)

    :param output: predictions, shape (N, NC, *)
    :param target: ground truth, same shape as output
    :param smooth: additive smoothing term for numerator and denominator
    :param eps: lower clamp on the denominator for numerical stability
    :param dims: axes to reduce over; None reduces everything to a scalar
    :return: per-kept-axis dice score, or a scalar if dims is None
    """
    assert output.size() == target.size()
    if dims is None:
        # full reduction -> scalar score
        overlap = torch.sum(output * target)
        total = torch.sum(output + target)
    else:
        overlap = torch.sum(output * target, dim=dims)
        total = torch.sum(output + target, dim=dims)
    return (2.0 * overlap + smooth) / (total + smooth).clamp_min(eps)
    def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor:
        """
        Compute the Dice loss, averaged over classes (and batch).

        :param y_pred: NxCxHxW (logits if from_logits, probabilities otherwise)
        :param y_true: NxHxW (class indices for multiclass mode)
        :return: scalar
        """
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            # Apply activations to get [0..1] class probabilities
            # Using Log-Exp as this gives more numerically stable result and does not cause vanishing gradient on
            # extreme values 0 and 1
            if self.mode == MULTICLASS_MODE:
                y_pred = y_pred.log_softmax(dim=1).exp()
            else:
                y_pred = F.logsigmoid(y_pred).exp()

        bs = y_true.size(0)
        num_classes = y_pred.size(1)
        # reduce over batch and spatial axes -> one score per channel/class
        dims = (0, 2)

        if self.mode == BINARY_MODE:
            # treat the whole prediction as a single channel
            y_true = y_true.view(bs, 1, -1)
            y_pred = y_pred.view(bs, 1, -1)

            if self.ignore_index is not None:
                # zero out ignored pixels in both tensors
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        if self.mode == MULTICLASS_MODE:
            y_true = y_true.view(bs, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask.unsqueeze(1)

                # `y_true * mask` remaps ignored pixels to class 0 so one_hot
                # stays in range; the second mask multiply removes them again
                y_true = F.one_hot((y_true * mask).to(torch.long), num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1)  # N, C, H*W
            else:
                y_true = F.one_hot(y_true, num_classes)  # N,H*W -> N,H*W, C
                y_true = y_true.permute(0, 2, 1)  # N, C, H*W

        if self.mode == MULTILABEL_MODE:
            y_true = y_true.view(bs, num_classes, -1)
            y_pred = y_pred.view(bs, num_classes, -1)

            if self.ignore_index is not None:
                mask = y_true != self.ignore_index
                y_pred = y_pred * mask
                y_true = y_true * mask

        scores = soft_dice_score(
            y_pred, y_true.type_as(y_pred), smooth=self.smooth, eps=self.eps, dims=dims
        )

        if self.log_loss:
            # Log Cosh Dice Loss
            # loss = -torch.log(scores.clamp_min(self.eps))
            loss = 1.0 - scores
            loss = torch.log(torch.cosh(loss))
        else:
            # Dice Loss
            loss = 1.0 - scores

        # Dice loss is undefined for empty classes (no true pixels)
        # So we zero contribution of channel that does not have true pixels
        # NOTE: A better workaround would be to use loss term `mean(y_pred)`
        # for this case, however it will be a modified jaccard loss

        mask = y_true.sum(dims) > 0
        loss *= mask.to(loss.dtype)

        if self.classes is not None:
            # restrict averaging to the requested subset of classes
            loss = loss[self.classes]

        return loss.mean()
def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True):
    """
    Collect the radius x radius neighbourhood of every valid spatial position.

    Args:
        labels_4D : labels, shape [N, C, H, W]
        probs_4D  : probabilities, shape [N, C, H, W]
        radius    : the square radius
        is_combine: if True, stack label and probability neighbourhoods into
            one tensor (labels first, then probs)
    Return:
        if is_combine: tensor of shape
            [N, C, 2 * radius * radius, H - (radius - 1), W - (radius - 1)]
        else: a pair of tensors, each of shape
            [N, C, radius * radius, H - (radius - 1), W - (radius - 1)]
    """
    # no padding: only positions whose full window fits are kept
    out_h = labels_4D.size(2) - (radius - 1)
    out_w = labels_4D.size(3) - (radius - 1)

    # one shifted view per window offset, row-major (dy outer, dx inner)
    label_windows = [
        labels_4D[:, :, dy : dy + out_h, dx : dx + out_w]
        for dy in range(radius)
        for dx in range(radius)
    ]
    prob_windows = [
        probs_4D[:, :, dy : dy + out_h, dx : dx + out_w]
        for dy in range(radius)
        for dx in range(radius)
    ]

    if is_combine:
        # for calculating RMI: one joint stack, labels before probs
        return torch.stack(label_windows + prob_windows, dim=2)
    # for other purposes: separate stacks
    return torch.stack(label_windows, dim=2), torch.stack(prob_windows, dim=2)
def log_det_by_cholesky(matrix):
    """
    Log-determinant of a batch of symmetric positive definite matrices.

    Uses the identity log det(A) = 2 * sum(log(diag(C))), where C is the
    lower Cholesky factor of A. A small constant (1e-8) is added to the
    diagonal entries before the log for numerical safety.

    Args:
        matrix: positive definite matrices, shape [N, C, D, D]
            (any number of leading batch dims works).
    Ref:
        https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py
    Return:
        tensor with the trailing (D, D) dims reduced away.
    """
    lower = torch.linalg.cholesky(matrix, upper=False)
    diag = torch.diagonal(lower, dim1=-2, dim2=-1)
    return 2.0 * torch.log(diag + 1e-8).sum(dim=-1)
def batch_low_tri_inv(L):
    """
    Batched inverse of lower triangular matrices.

    Replaces the original O(n^3) Python triple loop with a single call to
    ``torch.linalg.solve_triangular``: solving L X = I by forward
    substitution in native code yields X = L^{-1}. Mathematically identical
    to the loop version (up to float rounding) and far faster.

    Args:
        L : lower triangular matrices, shape [..., M, M]
    Returns:
        tensor of the same shape holding L^{-1} (also lower triangular)
    Ref:
        https://www.pugetsystems.com/labs/hpc/PyTorch-for-Scientific-Computing
    """
    n = L.shape[-1]
    # right-hand side: identity broadcast to the batch shape of L
    identity = torch.eye(n, dtype=L.dtype, device=L.device).expand(L.shape).clone()
    return torch.linalg.solve_triangular(L, identity, upper=False)
def get_loss_function_from_cfg(name_lf: str, cfg: DictConfig) -> list:
    """
    Instantiate the Lossfunction identified by name_lf
    Therefore get the needed parameters from the config

    Parameters
    ----------
    name_lf: str
        string identifier of the wanted loss function; one of
        "CE", "wCE", "RMI", "wRMI", "DC", "DC_CE", "TOPK", "DC_TOPK"
    cfg : DictConfig
        complete config; reads cfg.DATASET.NUM_CLASSES, optionally
        cfg.DATASET.IGNORE_INDEX and cfg.DATASET.CLASS_WEIGHTS

    Returns
    -------
    Lossfunction
        a callable taking (prediction, ground_truth)
        # NOTE(review): the `-> list` annotation does not match the returned
        # callable — confirm and correct in a follow-up

    Raises
    ------
    NotImplementedError
        if name_lf matches none of the known identifiers
    """
    num_classes = cfg.DATASET.NUM_CLASSES
    # fall back to torch's CrossEntropyLoss default ignore_index
    ignore_index = (
        cfg.DATASET.IGNORE_INDEX if has_not_empty_attr(cfg.DATASET, "IGNORE_INDEX") else -100
    )
    if name_lf == "CE":
        loss_function = torch.nn.CrossEntropyLoss(
            ignore_index=ignore_index
        )  # , label_smoothing=0.1)
    elif name_lf == "wCE":
        # NOTE(review): `.cuda()` assumes GPU training — confirm this is never
        # reached in a CPU-only run
        weights = torch.FloatTensor(cfg.DATASET.CLASS_WEIGHTS).cuda()
        loss_function = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, weight=weights)
    elif name_lf == "RMI":
        loss_function = RMILoss(num_classes=num_classes, ignore_index=ignore_index)

    elif name_lf == "wRMI":
        weights = torch.FloatTensor(cfg.DATASET.CLASS_WEIGHTS).cuda()
        loss_function = RMILoss(
            num_classes=num_classes, ignore_index=ignore_index, class_weights=weights
        )

    elif name_lf == "DC":
        loss_function = DiceLoss(mode="multiclass", ignore_index=ignore_index)

    elif name_lf == "DC_CE":
        DC_and_CE = DC_and_CE_loss(
            {"batch_dice": True, "smooth": 0, "do_bg": False},
            {"ignore_index": ignore_index},
            ignore_label=ignore_index,
        )
        # `.clone()` because DC_and_CE mutates its target in place; `[:, None]`
        # adds the channel axis the nnU-Net-style losses expect
        loss_function = lambda pred, gt: DC_and_CE(pred.clone(), gt[:, None].clone())
    elif name_lf == "TOPK":
        TopK = TopKLoss(ignore_index=ignore_index)
        loss_function = lambda pred, gt: TopK(pred.clone(), gt[:, None])
    elif name_lf == "DC_TOPK":
        DC_TopK = DC_and_topk_loss(
            {"batch_dice": True, "smooth": 0, "do_bg": False},
            {"ignore_index": ignore_index},
            ignore_label=ignore_index,
        )
        loss_function = lambda pred, gt: DC_TopK(pred.clone(), gt[:, None])
    else:
        raise NotImplementedError("No Lossfunction found for {}".format(name_lf))
    return loss_function
    def get_lr(self):
        """
        Learning rate for the current step: linear warm-up for the first
        `self.warmstart` steps, then polynomial decay over the remaining
        `self.total_iters` steps.
        """
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, "
                "please use `get_last_lr()`.",
                UserWarning,
            )
        # Warmstart: ramp linearly from base_lr / warmstart up to base_lr
        if self.last_epoch < self.warmstart:
            addrates = [(lr / (self.warmstart)) for lr in self.base_lrs]
            updated_lr = [
                addrates[i] * (self.last_epoch + 1)
                for i, group in enumerate(self.optimizer.param_groups)
            ]
            return updated_lr

        # outside the decay window: keep the current lr unchanged
        if self.last_epoch == 0 or self.last_epoch > self.total_iters:
            return [group["lr"] for group in self.optimizer.param_groups]

        # multiplicative polynomial decay: ratio of consecutive decay values,
        # applied to the lr of the previous step (telescoping product)
        # NOTE(review): `last_epoch` is not shifted by `warmstart` here while
        # `total_iters` was reduced by it in __init__ — confirm the intended
        # schedule matches _get_closed_form_lr
        decay_factor = (
            (1.0 - self.last_epoch / self.total_iters)
            / (1.0 - (self.last_epoch - 1) / self.total_iters)
        ) ** self.power
        return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/__init__.py 17 | 18 | Parameters 19 | ---------- 20 | name: str 21 | 22 | Returns 23 | ------- 24 | logging.Logger : 25 | """ 26 | logger = logging.getLogger(name) 27 | 28 | # this ensures all logging levels get marked with the rank zero decorator 29 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 30 | for level in ( 31 | "debug", 32 | "info", 33 | "warning", 34 | "error", 35 | "exception", 36 | "fatal", 37 | "critical", 38 | ): 39 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 40 | return logger 41 | 42 | 43 | @rank_zero_only 44 | def log_hyperparameters( 45 | config: DictConfig, 46 | model: pl.LightningModule, 47 | trainer: pl.Trainer, 48 | ) -> None: 49 | """ 50 | Controls which config parts are saved by Lightning loggers, additionally update hparams.yaml 51 | Taken and adopted from: 52 | https://github.com/ashleve/lightning-hydra-template/blob/main/src/utils/__init__.py 53 | 54 | Parameters 55 | ---------- 56 | config : DictConfig 57 | model : pl.LightningModule 58 | trainer: pl.Trainer 59 | """ 60 | hparams = {} 61 | 62 | # choose which parts of hydra config will be saved to loggers 63 | hparams["model"] = config.MODEL.NAME 64 | hparams["dataset"] = config.DATASET.NAME 65 | hparams["metric"] = model.metric_name 66 | 67 | avail_GPUS = torch.cuda.device_count() 68 | selected_GPUS = config.pl_trainer.devices 69 | hparams["num_gpus"] = int(num_gpus(avail_GPUS, selected_GPUS)) 70 | 71 | hparams["lossfunction"] = config.lossfunction 72 | hparams["optimizer"] = "" 73 | hparams["lr_scheduler"] = "" 74 | 75 | if hydra.core.hydra_config.HydraConfig.initialized(): 76 | cfg = hydra.core.hydra_config.HydraConfig.get() 77 | if has_not_empty_attr(cfg.runtime.choices, "optimizer"): 78 | hparams["optimizer"] = cfg.runtime.choices.optimizer 79 | if has_not_empty_attr(cfg.runtime.choices, "lr_scheduler"): 80 | hparams["lr_scheduler"] = 
def num_gpus(avail_GPUS: int, selected_GPUS: Any) -> int:
    """
    Translating the num_gpus of pytorch lightning trainers into a raw number of used gpus
    Needed since lightning enables to pass gpu as int, list or string

    Parameters
    ----------
    avail_GPUS : int
        how many gpus are available
    selected_GPUS : Any
        num_gpus input argument for the pytorch lightning trainers:
        -1 / "-1" (all), 0 / "0" / None (none), an int count,
        a list of device ids, or a comma-separated string of device ids

    Returns
    -------
    int :
        the number of used gpus

    Raises
    ------
    TypeError
        for unsupported selection types (previously this fell through and
        crashed with an UnboundLocalError)
    """
    # order matters: the sentinel values -1/0 must be matched before the
    # generic isinstance(int) branch
    if selected_GPUS in [-1, "-1"]:
        return avail_GPUS
    if selected_GPUS in [0, "0", None]:
        return 0
    if isinstance(selected_GPUS, int):
        return selected_GPUS
    if isinstance(selected_GPUS, list):
        return len(selected_GPUS)
    if isinstance(selected_GPUS, str):
        # e.g. "0,1,2" -> 3 devices
        return len(selected_GPUS.split(","))
    raise TypeError("Unsupported gpu selection: {!r}".format(selected_GPUS))


def first_from_dict(dictionary):
    """Return the first value of *dictionary* (insertion order)."""
    return list(dictionary.values())[0]
def has_not_empty_attr(obj: Any, attr: str) -> bool:
    """
    return True if obj contains attr and attr is not None, else returns False

    Intended for omegaconf DictConfig-style objects that support both
    attribute access (for hasattr) and item access (obj[attr]).
    Note that "empty" means None only: empty strings/lists still count as set.

    Parameters
    ----------
    obj : Any
    attr : str

    Returns
    -------
    bool :
    """
    # `is not None` replaces the original `!= None` (PEP 8: identity check
    # is the correct way to compare against None)
    return hasattr(obj, attr) and obj[attr] is not None
def show_mask_sem_seg(mask: torch.Tensor, cmap: list, output_type: str = "numpy"):
    """
    Color-encode a semantic segmentation mask of class ids into an RGB image

    Class ids outside the colormap (e.g. an ignore index) are rendered black

    Parameters
    ----------
    mask : torch.Tensor
        2d array of class ids, shape [w, h]
    cmap : list
        list of RGB values, one entry per class
    output_type : str, optional
        "numpy", "PIL" or "torch" - format of the returned image
    """
    ids = np.array(mask)
    height, width = ids.shape
    colored = np.zeros((height, width, 3), dtype=np.uint8)
    for cls in np.unique(ids):
        region = ids == cls
        # ids beyond the colormap (ignore label) keep the black background
        if cls < len(cmap):
            colored[region] = cmap[cls]
    colored = colored.astype(np.uint8)
    return convert_numpy_to(colored, output_type)
    def __init__(
        self,
        dataset: Dataset,
        cmap: np.ndarray,
        model: LightningModule = None,
        mean: list = None,
        std: list = None,
        segmentation: str = "semantic",
        axis: int = 1,
    ) -> None:
        """
        Visualize a Dataset in an opencv window.

        If a model is given, the prediction of the model on the dataset is
        shown as well.

        Parameters
        ----------
        dataset : Dataset
            dataset which should be visualized
        cmap : np.ndarray
            colormap to color the single classes, list of RGB values
        model : LightningModule, optional
            if given the model is used to generate predictions for images in dataset
        mean : list, optional
            if given the normalization is inverted during visualization --> nicer image
        std : list, optional
            if given the normalization is inverted during visualization --> nicer image
        segmentation : str, optional
            "semantic" or "instance", selects how masks/predictions are rendered
        axis : int, optional
            axis along which image and mask are concatenated in the window
            (1 = side by side, 0 = stacked vertically)
        """
        self.model = model
        self.dataset = dataset
        self.cmap = cmap
        self.mean = mean
        self.std = std
        self.segmentation = segmentation
        self.axis = axis
161 | return show_mask_sem_seg(mask, self.cmap, "numpy") 162 | elif self.segmentation == "instance": 163 | return show_mask_inst_seg(mask, img_shape, "numpy") 164 | 165 | def viz_correctness(self, pred: torch.Tensor, mask: torch.Tensor) -> np.ndarray: 166 | """ 167 | visualizing the correctness of the prediction (where pred is qual to mask) 168 | 169 | Parameters 170 | ---------- 171 | pred : torch.Tensor 172 | mask: torch.Tensor 173 | 174 | Returns 175 | ------- 176 | np.ndarray : 177 | """ 178 | cor = np.zeros(self.mask_np.shape, dtype=np.uint8) 179 | # where prediction and gt are equal 180 | x, y = np.where(pred == mask) 181 | # pixel which dont belong to a class (ignore index) 182 | x_ign, y_ign = np.where(mask > len(self.cmap)) 183 | 184 | cor[:, :] = [255, 0, 0] # Red for not equal pixel 185 | cor[x, y] = [0, 255, 0] # Green for equal pixel 186 | cor[x_ign, y_ign] = [0, 0, 0] # Black for ignored pixel 187 | return cor 188 | 189 | def update_window(self, *arg, **kwargs) -> None: 190 | """ 191 | Update the opencv Window when another image should be displayed (another img_id) 192 | Load Image and Mask and transform them into the correct format (opencv conform) 193 | (Optional) if a model is given also predict the image and colorize prediction 194 | """ 195 | img_id = cv2.getTrackbarPos("Image ID", "Window") 196 | 197 | # Load Image and Mask, transform image and colorize the mask 198 | img, mask = self.dataset[img_id] 199 | 200 | self.img_np = show_img(img, self.mean, self.std, "numpy") 201 | if self.segmentation == "semantic": 202 | self.mask_np = show_mask_sem_seg(mask, self.cmap, "numpy") 203 | elif self.segmentation == "instance": 204 | self.mask_np = show_mask_inst_seg(mask, img.shape[-2:], "numpy") 205 | 206 | # Predict the Image and colorize the prediction 207 | if self.model is not None: 208 | if self.segmentation == "semantic": 209 | pred = self.model(img.unsqueeze(0).cuda()) 210 | pred = torch.argmax(list(pred.values())[0].squeeze(), 
dim=0).detach().cpu() 211 | self.pred = self.color_mask(np.array(pred), img_shape=img.shape[-2:]) 212 | 213 | # Show Correctness of prediction 214 | self.cor = self.viz_correctness(pred, mask) 215 | elif self.segmentation == "instance": 216 | pred = self.model(img.unsqueeze(0).cuda())[0] 217 | pred = [{k: v.detach().cpu() for k, v in pred.items()}] 218 | self.pred = show_prediction_inst_seg(pred, img_shape=img.shape[-2:]) 219 | 220 | # update the the channel and alpha parameter and show the window 221 | self.update_channel_and_alpha() 222 | 223 | def update_channel_and_alpha(self, *arg, **kwargs) -> None: 224 | """ 225 | Select the correct Channel 226 | if -1 the channels 0:3 are used 227 | otherwise a single channel is used in grayscale on 3 channels 228 | """ 229 | if hasattr(self, "img_np"): 230 | channel_id = cv2.getTrackbarPos("Channel", "Window") 231 | 232 | # Select the correct Channel, if -1 use the channels 0:3 otherwise use a single one 233 | # and transform to RGB 234 | if channel_id == -1: 235 | self.img_np_chan = self.img_np[:, :, 0:3] 236 | else: 237 | self.img_np_chan = self.img_np[:, :, channel_id] 238 | self.img_np_chan = cv2.cvtColor(self.img_np_chan, cv2.COLOR_GRAY2RGB) 239 | 240 | # Update Alpha and udate Window 241 | self.update_alpha() 242 | 243 | def update_alpha(self, *arg, **kwargs) -> None: 244 | """ 245 | Display the image blended with the mask or prediction on the left, on the right the gt mask 246 | Alpha defines the weight of the blending 247 | Afterwards update the opencv image 248 | """ 249 | if hasattr(self, "img_np_chan"): 250 | alpha = cv2.getTrackbarPos("alpha", "Window") / 100 251 | 252 | # Blend the image with prediction 253 | if hasattr(self, "pred"): 254 | self.img_np_fig = cv2.addWeighted(self.img_np_chan, 1 - alpha, self.pred, alpha, 0.0) 255 | self.img_np_fig = self.update_corrects(self.img_np_fig) 256 | else: 257 | self.img_np_fig = cv2.addWeighted(self.img_np_chan, 1 - alpha, self.mask_np, alpha, 0.0) 258 | bg_map = 
np.all(self.mask_np == [255, 255, 255], axis=2) 259 | self.img_np_fig[bg_map] = self.img_np_fig[bg_map] 260 | # concat blended image and mask 261 | fig = np.concatenate((self.img_np_fig, self.mask_np), self.axis) 262 | # transform from RGB to BGR to match the cv2 order 263 | self.fig = cv2.cvtColor(fig, cv2.COLOR_RGB2BGR) 264 | # show image 265 | cv2.imshow("Window", self.fig) 266 | 267 | def update_corrects(self, img) -> None: 268 | """ 269 | Display the image blended with the mask or prediction on the left, on the right the gt mask 270 | Alpha defines the weight of the blending 271 | Afterwards update the opencv image 272 | """ 273 | alpha_cor = cv2.getTrackbarPos("correctness", "Window") / 100 274 | if alpha_cor > 0: 275 | 276 | img = cv2.addWeighted(img, 1 - alpha_cor, self.cor, alpha_cor, 0.0) 277 | return img 278 | -------------------------------------------------------------------------------- /testing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO) 4 | 5 | import os 6 | import glob 7 | import hydra 8 | from omegaconf import OmegaConf, DictConfig 9 | 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning import loggers as pl_loggers 12 | 13 | from trainers.Semantic_Segmentation_Trainer import SegModel 14 | from src.utils import has_not_empty_attr, log_hyperparameters, get_logger 15 | 16 | 17 | log = get_logger(__name__) 18 | 19 | 20 | @hydra.main(config_path="config", config_name="testing", version_base="1.3") 21 | def testing(cfg: DictConfig) -> None: 22 | """ 23 | Running the Testing/Validation 24 | Using the ckpt_dir as Working Directory 25 | Load the hydra overrides from the ckpt_dir 26 | Compose config from config/testing.yaml with overwrites from the checkpoint and the overwrites 27 | from commandline 28 | (Optional) include Overrides defined in the config (TRAINING.OVERRIDES) 29 | Load Model, Datamodule, Logger and Trainer 30 | Run 
@hydra.main(config_path="config", config_name="testing", version_base="1.3")
def testing(cfg: DictConfig) -> None:
    """
    Running the Testing/Validation
    Using the ckpt_dir as Working Directory
    Load the hydra overrides from the ckpt_dir
    Compose config from config/testing.yaml with overwrites from the checkpoint and the overwrites
    from commandline
    (Optional) include Overrides defined in the config (TESTING.OVERRIDES)
    Load Model, Datamodule, Logger and Trainer
    Run testing

    Parameters
    ----------
    cfg : DictConfig
        cfg given by hydra - build from config/testing.yaml + commandline arguments
    """
    # Save overrides from the commandline for the current run
    overrides_cl = hydra.core.hydra_config.HydraConfig.get().overrides.task
    # Load overrides from the experiment in the checkpoint dir
    # (assumes hydra/overrides.yaml exists inside the cwd - TODO confirm the
    # caller always runs this from a checkpoint directory)
    overrides_ckpt = OmegaConf.load(os.path.join("hydra", "overrides.yaml"))

    # Compose config by override with overrides_ckpt, afterwards override with overrides_cl
    # (commandline overrides win because they are applied last)
    cfg = hydra.compose(config_name="testing", overrides=overrides_ckpt + overrides_cl)

    # Get the TESTING.OVERRIDES to check if additional parameters should be changed
    if has_not_empty_attr(cfg, "TESTING"):
        if has_not_empty_attr(cfg.TESTING, "OVERRIDES"):
            overrides_test = cfg.TESTING.OVERRIDES
            # Compose config again including the new overrides; commandline
            # overrides still take the highest precedence
            cfg = hydra.compose(
                config_name="testing",
                overrides=overrides_ckpt + overrides_test + overrides_cl,
            )

    # Load the best checkpoint and load the model
    # NOTE(review): glob result is indexed directly - raises IndexError if no
    # "best_*" checkpoint exists; confirm this is the intended failure mode
    log.info("Working Directory: %s", os.getcwd())
    ckpt_file = glob.glob(os.path.join("checkpoints", "best_*"))[0]
    log.info("Checkpoint Directory: %s", ckpt_file)

    model = SegModel.load_from_checkpoint(ckpt_file, model_config=cfg, strict=False)

    # Load the datamodule
    dataModule = hydra.utils.instantiate(cfg.datamodule, _recursive_=False)

    # Instantiate callbacks (entries with a null value are skipped)
    callbacks = []
    for _, cb_conf in cfg.CALLBACKS.items():
        if cb_conf is not None:
            cb = hydra.utils.instantiate(cb_conf)
            callbacks.append(cb)

    # Tensorboard logger writing into a "testing" subfolder of the ckpt dir
    tb_logger = pl_loggers.TensorBoardLogger(
        save_dir="testing", name="", version="", default_hp_metric=False
    )

    # Parsing the pl_trainer args and instantiate the trainers
    trainer_args = getattr(cfg, "pl_trainer") if has_not_empty_attr(cfg, "pl_trainer") else {}
    trainer = Trainer(callbacks=callbacks, logger=tb_logger, **trainer_args)

    # Log experiment hyperparameters to all loggers
    log_hyperparameters(cfg, model, trainer)

    # Run testing/validation
    trainer.test(model, dataModule)

2 | 3 |

4 | 5 | 6 | # Tools 7 | 8 | The ``tools/`` folder contains some useful tools for developing and experimenting. 9 | It is not guaranteed that these tools will work for all kind of use-cases, datasets and datatypes but even then 10 | they can be used as a starting point and can be adapted with a few changes. 11 | These scripts support the general Hydra override syntax as described in [config](../config), together with some additional arguments (these can be used via the argparse syntax with a `--` prefix). 12 | For each tool use the ``-h`` flag (e.g. ``python tools/show_data.py -h``) to see all additional options. 13 | 14 | ### show_data.py 15 | Load and Visualize the pytorch dataset which is defined in the dataset config. Can be used for data inspection and to view different data augmentation pipelines. 16 | - dataset: Name of the dataset config (see [here](#selecting-a-dataset)) 17 | - --augmentation: Which augmentations to use: None (by default), train, val or test. If A.Normalization(std=...,mean=...) is part of the augmentations, this will be undone during visualization to get a better interpretable image 18 | - --split: which split to use: train, val or test Dataset (train by default) 19 | - --segmentation: type of segmentation - semantic or instance, depending on the dataset 20 | ````shell 21 | pyhton tools/show_data.py dataset= 22 | pyhton tools/show_data.py dataset=Cityscapes --split=val --augmentation=val 23 | ```` 24 | 25 | ### show_prediction.py 26 | Show the predictions of a trained model. Basically has the same syntax 27 | as the [validation/testing](#run-validationtesting), but visualizes the result instead of calculating 28 | metrics. The Test Dataset is used for predicting by default together with the train data-augmentations. For a 29 | nicer appearance the normalization operation is undone during visualization (not for prediction). 
30 | (Note, depending on the size of the input, the inference time of the model and the available hardware, 31 | there might be a delay when sliding through the images) 32 | - ckpt_dir: path to the checkpoint which should be used 33 | - --augmentation: Which augmentations to use: train, val or test (by default). If A.Normalization(std=...,mean=...) is part of the augmentations, this will be undone during visualization to get a better interpretable image 34 | - --split: which split to use: train, val or test Dataset (train by default) 35 | - --segmentation: type of segmentation - semantic or instance, depending on the dataset 36 | 37 | ````shell 38 | python tools/show_prediction.py ckpt_dir= 39 | python tools/show_prediction.py ckpt_dir=ckpt_dir="/../Semantic_Segmentation/logs/VOC2010_Context/hrnet/baseline_/2022-02-15_13-51-42" --split=test --augmentation=test 40 | ```` 41 | 42 | ### dataset_stats.py 43 | Getting some basic stats and visualizations about the dataset like: mean and std for each channel, appearances and ratio of classes and potential class weights. 44 | Also, the color encoding of classes is visualized (can be useful for the show_data or show_prediction scripts) 45 | The output will be saved in *dataset_stats/dataset_name/*. 
python tools/lr_finder.py
python tools/lr_finder.py dataset=Cityscapes model=hrnet --num_training_samples=300

69 |      70 | 71 |

72 | 73 | This Repository is developed and maintained by the Applied Computer Vision Lab (ACVL) 74 | of [Helmholtz Imaging](https://www.helmholtz-imaging.de/). -------------------------------------------------------------------------------- /tools/lr_finder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import logging 4 | import sys 5 | 6 | import matplotlib.pyplot as plt 7 | 8 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | logging.basicConfig(level=logging.INFO) 10 | 11 | from pytorch_lightning import loggers as pl_loggers 12 | from pytorch_lightning import Trainer 13 | import hydra 14 | import torch 15 | 16 | from trainers.Semantic_Segmentation_Trainer import SegModel 17 | 18 | from src.utils import has_true_attr, has_not_empty_attr, get_logger, num_gpus 19 | 20 | 21 | log = get_logger(__name__) 22 | 23 | 24 | def find_lr(overrides_cl: list, num_training_samples: int) -> None: 25 | """ 26 | Implementation for using Pytorch Lightning learning rate finder 27 | 28 | Parameters 29 | ---------- 30 | overrides_cl : list of strings 31 | arguments from the commandline to overwrite the config 32 | num_training_samples: int 33 | how many samples to use for lr finding 34 | """ 35 | # initialize hydra 36 | hydra.initialize(config_path="../config", version_base="1.1") 37 | 38 | overrides_cl.append("ORG_CWD=./") 39 | cfg = hydra.compose(config_name="training", overrides=overrides_cl) 40 | if os.getcwd().endswith("tools"): 41 | cfg.ORG_CWD="../" 42 | else: 43 | cfg.ORG_CWD = "/" 44 | callbacks = [] 45 | for _, cb_conf in cfg.CALLBACKS.items(): 46 | if cb_conf is not None: 47 | cb = hydra.utils.instantiate(cb_conf) 48 | callbacks.append(cb) 49 | # Adding a Checkpoint Callback if checkpointing is enabled 50 | if has_true_attr(cfg.pl_trainer, "enable_checkpointing"): 51 | cfg.pl_trainer.enable_checkpointing = False 52 | 53 | # Using tensorboard logger 54 | tb_logger = 
pl_loggers.TensorBoardLogger( 55 | save_dir=".", name="", version="", default_hp_metric=False 56 | ) 57 | 58 | # Logging information about gpu setup 59 | avail_GPUS = torch.cuda.device_count() 60 | selected_GPUS = cfg.pl_trainer.devices 61 | number_gpus = num_gpus(avail_GPUS, selected_GPUS) 62 | 63 | log.info("Available GPUs: %s - %s", avail_GPUS, torch.cuda.get_device_name()) 64 | log.info("Number of used GPUs: %s Selected GPUs: %s", number_gpus, cfg.pl_trainer.devices) 65 | log.info("CUDA version: %s", torch._C._cuda_getCompiledVersion()) 66 | 67 | # Defining the datamodule 68 | dataModule = hydra.utils.instantiate(cfg.datamodule, _recursive_=False) 69 | 70 | # Defining model and load checkpoint if wanted 71 | # cfg.finetune_from should be the path to a .ckpt file 72 | # if has_not_empty_attr(cfg, "finetune_from"): 73 | # log.info("finetune from: %s", cfg.finetune_from) 74 | # model = SegModel.load_from_checkpoint(cfg.finetune_from, strict=False, config=cfg) 75 | # else: 76 | # model = SegModel(config=cfg) 77 | if hasattr(cfg, "num_example_predictions"): 78 | cfg.num_example_predictions = 0 79 | if has_not_empty_attr(cfg, "finetune_from"): 80 | log.info("finetune from: %s", cfg.finetune_from) 81 | cfg.trainermodule._target_ += ".load_from_checkpoint" 82 | model = hydra.utils.call( 83 | cfg.trainermodule, cfg.finetune_from, strict=False, model_config=cfg, _recursive_=False 84 | ) 85 | # model = SegModel.load_from_checkpoint(cfg.finetune_from, strict=False, config=cfg) 86 | else: 87 | # model = SegModel(config=cfg) 88 | model = hydra.utils.instantiate(cfg.trainermodule, cfg, _recursive_=False) 89 | # Initializing trainers 90 | trainer_args = getattr(cfg, "pl_trainer") if has_not_empty_attr(cfg, "pl_trainer") else {} 91 | 92 | # ddp=DDPPlugin(find_unused_parameters=False) if number_gpus > 1 else None 93 | trainer = Trainer( 94 | callbacks=callbacks, 95 | logger=tb_logger, 96 | strategy="ddp" if number_gpus > 1 else None, 97 | sync_batchnorm=True if number_gpus > 1 
def predict_img(
    image,
    model,
    test_time_augmentation=True,
    save_probabilities=False,
):
    """
    Run inference for a single image, optionally with flip test time augmentation (TTA)

    Parameters
    ----------
    image : torch.Tensor
        input image of shape [C, H, W]; moved to gpu internally
    model : callable
        model returning logits of shape [1, num_classes, H, W]
    test_time_augmentation : bool, optional
        average the prediction over horizontal, vertical and combined flips
    save_probabilities : bool, optional
        additionally return the softmax probabilities

    Returns
    -------
    np.ndarray or (np.ndarray, np.ndarray) :
        argmax prediction of shape [H, W]; with save_probabilities also the
        softmax probabilities of shape [num_classes, H, W]
    """
    # Run inference
    with torch.no_grad():
        image = image.cuda()
        image = image.unsqueeze(0)
        prediction = model(image)

        if test_time_augmentation:
            # average with predictions on flipped inputs; each output is
            # flipped back so it aligns with the original orientation
            # (torch.flip returns a new tensor, so no .clone() is needed)
            prediction += torch.flip(model(torch.flip(image, [2])), [2])
            prediction += torch.flip(model(torch.flip(image, [3])), [3])
            prediction += torch.flip(model(torch.flip(image, [2, 3])), [2, 3])
            prediction /= 4

        # drop the batch dim -> [num_classes, H, W]
        prediction = prediction.squeeze(0).cpu()

    prediction_argmax = prediction.argmax(0).numpy()
    if not save_probabilities:
        return prediction_argmax
    # softmax over the class dimension, computed on the tensor directly
    # instead of the former numpy -> tensor round trip (this also drops the
    # private _stacklevel argument of F.softmax)
    prediction_softmax = F.softmax(prediction, dim=0).numpy()
    return prediction_argmax, prediction_softmax
def show_data(
    overrides_cl: list, augmentation: str, split: str, segmentation: str, axis: int
) -> None:
    """
    Visualizing a Dataset
    initializing the dataset defined in the config
    display img + mask using opencv

    Parameters
    ----------
    overrides_cl : list
        arguments from commandline to overwrite hydra config
    augmentation : str
        which augmentation pipeline to use: None, "train", "val" or "test"
    split : str
        which dataset split to use: "train", "val" or "test"
    segmentation : str
        "semantic" or "instance", depending on the dataset
    axis : int
        1 to show image and mask side by side, 0 to stack them vertically
    """
    # Init and Compose Hydra to get the config
    hydra.initialize(config_path="../config", version_base="1.3")
    cfg = hydra.compose(config_name="training", overrides=overrides_cl)

    # Define Colormap (RGB, one entry per class) and drop the alpha channel
    color_map = "viridis"
    cmap = np.array(cm.get_cmap(color_map, cfg.DATASET.NUM_CLASSES).colors * 255, dtype=np.uint8)[
        :, 0:3
    ]

    # allow adding keys to the config afterwards
    OmegaConf.set_struct(cfg, False)
    # pick the requested augmentation pipeline; None = only tensor conversion
    if augmentation is None:
        transforms = A.Compose([ToTensorV2()])
    elif augmentation == "train":
        transforms = get_augmentations_from_config(cfg.AUGMENTATIONS.TRAIN)[0]
    elif augmentation == "val":
        transforms = get_augmentations_from_config(cfg.AUGMENTATIONS.VALIDATION)[0]
    elif augmentation == "test":
        transforms = get_augmentations_from_config(cfg.AUGMENTATIONS.TEST)[0]

    dataset = hydra.utils.instantiate(cfg.dataset, split=split, transforms=transforms)

    # check if data is normalized, if yes redo this during visualization of the image
    mean = None
    std = None
    for t in transforms.transforms:
        if isinstance(t, A.Normalize):
            mean = t.mean
            std = t.std
            break

    visualizer = Visualizer(dataset, cmap, mean=mean, std=std, segmentation=segmentation, axis=axis)

    # Create the cv2 Window
    cv2.namedWindow("Window", cv2.WINDOW_NORMAL)
    cv2.resizeWindow("Window", 1200, 1200)

    # Create Trackbar for Image Id and alpha value
    cv2.createTrackbar("Image ID", "Window", 0, len(dataset) - 1, visualizer.update_window)
    cv2.createTrackbar("alpha", "Window", 0, 100, visualizer.update_alpha)

    # look at the first image to get the number of channels
    img, _ = dataset[0]
    if len(img.shape) == 2:
        # NOTE(review): grayscale images get channels=2 here so the Channel
        # trackbar max becomes 1 - confirm this is intended
        channels = 2
    else:
        channels = img.shape[0]
    # Create the Trackbar for the Channel (-1 = show channels 0:3 as RGB)
    cv2.createTrackbar("Channel", "Window", -1, channels - 1, visualizer.update_channel_and_alpha)
    cv2.setTrackbarMin("Channel", "Window", -1)
    cv2.setTrackbarPos("Channel", "Window", -1)

    # show the first image in window and start loop
    visualizer.update_window()
    print("press q to quit")
    while True:
        k = cv2.waitKey()
        if k == 113:  # "q" quits
            break
        elif k == 115:  # "s" saves the current image and mask to disk

            img_id = cv2.getTrackbarPos("Image ID", "Window")
            file_name = f"{cfg.DATASET.NAME}__ID{img_id}"
            os.makedirs("dataset_visualizations", exist_ok=True)

            print(f"Save {file_name}")

            # opencv expects BGR, the visualizer holds RGB
            img = cv2.cvtColor(visualizer.img_np_fig, cv2.COLOR_RGB2BGR)
            mask = cv2.cvtColor(visualizer.mask_np, cv2.COLOR_RGB2BGR)

            cv2.imwrite(os.path.join("dataset_visualizations", file_name + "__image.png"), img)
            cv2.imwrite(os.path.join("dataset_visualizations", file_name + "__mask.png"), mask)

    cv2.destroyAllWindows()
def show_prediction(
    overrides_cl: list, augmentation: str, split: str, segmentation: str, axis: int
) -> None:
    """
    Show Model Predictions
    Load Model and Dataset from the checkpoint(ckpt_dir)
    Show predictions of the model for the images in a small GUI

    Parameters
    ----------
    overrides_cl : list of strings
        arguments from the commandline to overwrite the config
        (must contain a "ckpt_dir=..." entry)
    augmentation : str
        which augmentation pipeline to use: "train", "val", "test" or None
    split : str
        which dataset split to use: "train", "val" or "test"
    segmentation : str
        "semantic" or "instance", selects the trainer module to load
    axis : int
        1 to show image and prediction side by side, 0 to stack them
    """
    # initialize hydra
    hydra.initialize(config_path="../config", version_base="1.1")

    # remember the original working dir before changing into the ckpt dir
    if os.getcwd().endswith("tools"):
        ORG_CWD = os.path.join(os.getcwd(),"..")
    else:
        ORG_CWD = os.getcwd()

    # extract ckpt_dir from the commandline overrides
    ckpt_dir = None
    for override in overrides_cl:
        if override.startswith("ckpt_dir"):
            ckpt_dir = override.split("=", 1)[1]
            break
    if ckpt_dir is None:
        log.error(
            "ckpt_dir has to be in th config. Run python show_prediction.py ckpt_dir="
        )
        quit()
    os.chdir(ckpt_dir)

    # load overrides from the experiment in the checkpoint dir
    overrides_ckpt = OmegaConf.load(os.path.join("hydra", "overrides.yaml"))
    # compose config by override with overrides_ckpt, afterwards override with overrides_cl
    cfg = hydra.compose(config_name="testing", overrides=overrides_ckpt + overrides_cl)

    # Get the TESTING.OVERRIDES to check if additional parameters should be changed
    if has_not_empty_attr(cfg, "TESTING"):
        if has_not_empty_attr(cfg.TESTING, "OVERRIDES"):
            overrides_test = cfg.TESTING.OVERRIDES
            # Compose config again with including the new overrides
            cfg = hydra.compose(
                config_name="testing",
                overrides=overrides_ckpt + overrides_test + overrides_cl,
            )

    # load the best checkpoint and load the model for the requested task
    # NOTE(review): IndexError if no "best_*" checkpoint exists - confirm
    cfg.ORG_CWD = ORG_CWD
    ckpt_file = glob.glob(os.path.join("checkpoints", "best_*"))[0]
    if segmentation == "semantic":
        model = SegModel.load_from_checkpoint(ckpt_file, model_config=cfg, strict=False).cuda()
    elif segmentation == "instance":
        model = InstModel.load_from_checkpoint(ckpt_file, model_config=cfg, strict=False).cuda()

    # allow adding keys to the config afterwards
    OmegaConf.set_struct(cfg, False)
    # pick the requested augmentation pipeline; fallback = tensor conversion only
    if augmentation == "train":
        transforms = get_augmentations_from_config(cfg.AUGMENTATIONS.TRAIN)[0]
    elif augmentation == "val":
        transforms = get_augmentations_from_config(cfg.AUGMENTATIONS.VALIDATION)[0]
    elif augmentation == "test":
        transforms = get_augmentations_from_config(cfg.AUGMENTATIONS.TEST)[0]
    else:
        transforms = A.Compose([ToTensorV2()])

    # instantiate dataset
    dataset = hydra.utils.instantiate(cfg.dataset, split=split, transforms=transforms)

    # check if data is normalized, if yes redo this during visualization of the image
    mean = None
    std = None
    for t in transforms.transforms:
        if isinstance(t, A.Normalize):
            mean = t.mean
            std = t.std
            break

    # define colormap (RGB, one entry per class, alpha channel dropped)
    color_map = "viridis"
    cmap = np.array(cm.get_cmap(color_map, cfg.DATASET.NUM_CLASSES).colors * 255, dtype=np.uint8)[
        :, 0:3
    ]

    # init visualizer
    visualizer = Visualizer(
        dataset, cmap, model, mean=mean, std=std, segmentation=segmentation, axis=axis
    )

    # create window
    cv2.namedWindow("Window", cv2.WINDOW_NORMAL)
    cv2.resizeWindow("Window", 1200, 1200)

    # Create Trackbar for Image Id and alpha value
    cv2.createTrackbar("Image ID", "Window", 0, len(dataset) - 1, visualizer.update_window)

    cv2.createTrackbar("alpha", "Window", 50, 100, visualizer.update_alpha)
    # correctness overlay only exists for semantic segmentation
    if segmentation == "semantic":
        cv2.createTrackbar("correctness", "Window", 0, 100, visualizer.update_alpha)

    # look at the first image to get the number of channels
    img, _ = dataset[0]
    if len(img.shape) == 2:
        channels = 2
    else:
        channels = img.shape[0]

    # Create the Trackbar for the Channel Parameter (-1 = channels 0:3 as RGB)
    cv2.createTrackbar("Channel", "Window", -1, channels - 1, visualizer.update_channel_and_alpha)
    cv2.setTrackbarMin("Channel", "Window", -1)
    cv2.setTrackbarPos("Channel", "Window", -1)

    # show the first image in window and start loop; inference runs inside
    # the callbacks, hence eval mode and no_grad around the event loop
    model.eval()
    with torch.no_grad():
        visualizer.update_window()
        print("press q to quit")
        while True:
            k = cv2.waitKey()
            if k == 113:  # "q" quits
                break
    cv2.destroyAllWindows()
"--augmentation", 165 | type=str, 166 | default="test", 167 | help="Which augmentations to use: train, val or test (by default)", 168 | ) 169 | parser.add_argument( 170 | "--split", 171 | type=str, 172 | default="test", 173 | help="which split to use: train, val or test (by default)", 174 | ) 175 | parser.add_argument( 176 | "--segmentation", 177 | type=str, 178 | default="semantic", 179 | help="semantic or instance, depending on the dataset", 180 | ) 181 | parser.add_argument( 182 | "--axis", 183 | type=int, 184 | default=1, 185 | help="1 for displaying images side by side, 0 for displaying images on top of each other", 186 | ) 187 | args, overrides = parser.parse_known_args() 188 | augmentation = args.augmentation 189 | split = args.split 190 | segmentation = args.segmentation 191 | axis = args.axis 192 | 193 | show_prediction(overrides, augmentation, split, segmentation, axis) 194 | -------------------------------------------------------------------------------- /trainers/Instance_Segmentation_Trainer.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from src.metric import MetricModule 7 | from src.utils import get_logger 8 | from trainers.Semantic_Segmentation_Trainer import SegModel 9 | from src.visualization import show_prediction_inst_seg, show_mask_inst_seg, show_img 10 | 11 | log = get_logger(__name__) 12 | 13 | 14 | class InstModel(SegModel): 15 | def __init__(self, model_config: DictConfig) -> None: 16 | """ 17 | __init__ the LightningModule 18 | instantiate the model and the metric(s) 19 | 20 | Parameters 21 | ---------- 22 | config : omegaconf.DictConfig 23 | """ 24 | super().__init__(model_config) 25 | if self.train_metric: 26 | log.info( 27 | "Training Metric for Instance Segmentation is not supported and is set to False" 28 | ) 29 | self.train_metric = False 30 | 31 | def forward(self, x: torch.Tensor, gt=None) -> 
dict: 32 | """ 33 | forward the input to the model 34 | 35 | Parameters 36 | ---------- 37 | x : torch.Tensor 38 | input to predict 39 | gt : dict of {str:torch.Tensor} 40 | target, only used for training to compute the loss 41 | 42 | Returns 43 | ------- 44 | if training 45 | torch.Tensor: training loss 46 | else: 47 | dict of {str:torch.Tensor} : 48 | prediction of the model containing masks, boxes, scores and labels 49 | """ 50 | if self.training: 51 | x = self.model(x, gt) 52 | else: 53 | x = self.model(x) 54 | 55 | return x 56 | 57 | def training_step(self, batch: list, batch_idx: int) -> torch.Tensor: 58 | """ 59 | Forward the image through the model and compute the loss 60 | (optional) update the metric stepwise of global (defined by metric_call parameter) 61 | 62 | Parameters 63 | ---------- 64 | batch : list of torch.Tensor 65 | contains img (shape==[batch_size,num_classes,w,h]) and mask (shape==[batch_size,w,h]) 66 | batch_idx : int 67 | index of the batch 68 | 69 | Returns 70 | ------- 71 | torch.Tensor : 72 | training loss 73 | """ 74 | # predict batch 75 | x, y_gt = batch 76 | loss_dict = self(x, y_gt) 77 | 78 | loss = sum(l for l in loss_dict.values()) 79 | 80 | # compute and log loss 81 | self.log( 82 | "Loss/training_loss", 83 | loss, 84 | on_step=True, 85 | on_epoch=True, 86 | logger=True, 87 | sync_dist=True if self.trainer.num_devices > 1 else False, 88 | ) 89 | 90 | return loss 91 | 92 | def validation_step(self, batch: list, batch_idx: int) -> torch.Tensor: 93 | """ 94 | Forward the image through the model and compute the loss 95 | update the metric stepwise of global (defined by metric_call parameter) 96 | 97 | Parameters 98 | ---------- 99 | batch : list of dicts 100 | batch_idx : int 101 | index of the batch 102 | 103 | Returns 104 | ------- 105 | """ 106 | 107 | # predict batch 108 | x, y_gt = batch 109 | y_pred = self(x) 110 | 111 | # update validation metric 112 | # self.update_metric(y_pred, y_gt, self.metric, prefix="metric/") 113 | 
self.update_metric(y_pred, y_gt, self.metric, prefix="metric/") 114 | 115 | # log some example predictions to tensorboard 116 | # ensure that exactly self.num_example_predictions examples are taken 117 | if self.global_rank == 0 and not self.trainer.sanity_checking: 118 | self.log_batch_prediction(x, y_pred, y_gt, batch_idx) 119 | 120 | def on_test_start(self) -> None: 121 | pass 122 | 123 | def test_step(self, batch: list, batch_idx: int) -> torch.Tensor: 124 | """ 125 | copy of validation step 126 | """ 127 | # predict batch 128 | x, y_gt = batch 129 | y_pred = self(x) 130 | 131 | # update validation metric 132 | self.update_metric(y_pred, y_gt, self.metric, prefix="metric_test/") 133 | 134 | # log some example predictions to tensorboard 135 | if self.global_rank == 0 and not self.trainer.sanity_checking: 136 | self.log_batch_prediction(x, y_pred, y_gt, batch_idx) 137 | 138 | def update_metric( 139 | self, y_pred: torch.Tensor, y_gt: torch.Tensor, metric: MetricModule, prefix: str = "" 140 | ): 141 | 142 | for y in y_pred: 143 | y["masks"] = y["masks"].squeeze(1) 144 | for i in range(0, len(y["masks"])): 145 | x = torch.where(y["masks"][i] >= 0.5, 1, 0) 146 | y["masks"][i] = x 147 | y["masks"] = y["masks"].type(torch.uint8) 148 | 149 | if self.metric_call_stepwise: 150 | # Log the metric result for each step 151 | metric_step = metric(y_pred, y_gt) 152 | # exclude nan since pl uses torch.mean for reduction, this way torch.nanmean is simulated 153 | metric_step = {k: v for k, v in metric_step.items() if not torch.isnan(v)} 154 | self.log_dict_epoch( 155 | metric_step, 156 | prefix=prefix, 157 | postfix="_stepwise", 158 | on_step=False, 159 | on_epoch=True, 160 | ) 161 | elif self.metric_call_per_img: 162 | # If metric should be called per img, iterate through the batch to compute and log the 163 | # metric for each img separately 164 | for yi_pred, yi_gt in zip(y_pred, y_gt): 165 | metric_step = metric(yi_pred.unsqueeze(0), yi_gt.unsqueeze(0)) 166 | # exclude nan 
since pl uses torch.mean for reduction, this way torch.nanmean is simulated 167 | metric_step = {k: v for k, v in metric_step.items() if not torch.isnan(v)} 168 | self.log_dict_epoch( 169 | metric_step, 170 | prefix=prefix, 171 | postfix="_per_img", 172 | on_step=False, 173 | on_epoch=True, 174 | ) 175 | elif self.metric_call_global: 176 | # Just update the metric 177 | metric.update(y_pred, y_gt) 178 | 179 | def get_loss(self, y_pred: dict, y_gt: torch.Tensor) -> torch.Tensor: 180 | pass 181 | 182 | def log_batch_prediction(self, imgs: list, preds: list, gts: list, batch_idx: int) -> None: 183 | """ 184 | logging example prediction and gt to tensorboard 185 | 186 | Parameters 187 | ---------- 188 | imgs: [torch.Tensor] 189 | pred : [dict] 190 | gt : [dict] 191 | batch_idx: int 192 | idx of the current batch, needed for naming of the predictions 193 | """ 194 | # Check if the current batch has to be logged, if yes how many images 195 | val_batch_size = self.trainer.datamodule.val_batch_size 196 | diff_to_show = self.num_example_predictions - (batch_idx * val_batch_size) 197 | if diff_to_show > 0: 198 | current_batche_size = len(imgs) 199 | # log the desired number of images 200 | for i in range(min(current_batche_size, diff_to_show)): 201 | img = imgs[i].detach().cpu() 202 | 203 | pred = preds[i] 204 | pred = [{k: v.detach().cpu() for k, v in pred.items()}] 205 | 206 | gt = gts[i] 207 | gt = [{k: v.detach().cpu() for k, v in gt.items()}] 208 | 209 | # colormap class labels and transform image 210 | pred = show_prediction_inst_seg(pred, img.shape[-2:], output_type="torch") 211 | gt = show_mask_inst_seg(gt[0], img.shape[-2:], output_type="torch") 212 | 213 | # Overlay mask with images, disabled since tensorboard logfiles get to large 214 | # img = show_img(img, mean=self.viz_mean, std=self.viz_std, output_type="torch") 215 | # alpha = 0.5 216 | # gt = (img * alpha + gt * (1 - alpha)).type(torch.uint8) 217 | # pred = (img * alpha + pred * (1 - 
import logging

logging.basicConfig(level=logging.INFO)

import os
import hydra
import torch
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies.ddp import DDPStrategy

from src.utils import (
    has_true_attr,
    has_not_empty_attr,
    get_logger,
    num_gpus,
    log_hyperparameters,
)
from omegaconf import DictConfig, OmegaConf

log = get_logger(__name__)


# OmegaConf resolver for preventing problems in the output path:
# removes all characters which can cause problems or are unwanted in a directory name
OmegaConf.register_new_resolver(
    "path_formatter",
    lambda s: s.replace("[", "")
    .replace("]", "")
    .replace("}", "")
    .replace("{", "")
    .replace(")", "")
    .replace("(", "")
    .replace(",", "_")
    .replace("=", "_")
    .replace("/", ".")
    .replace("+", ""),
)


@hydra.main(config_path="config", config_name="training", version_base="1.3")
def training_loop(cfg: DictConfig):
    """
    Running Training:
    import callbacks and initialize the logger,
    load model, datamodule and trainer, then train the model.

    Parameters
    ----------
    cfg : DictConfig
        cfg given by hydra - built from config/training.yaml + commandline arguments
    """
    log.info("Output Directory: %s", os.getcwd())
    # Seeding if given by config
    if has_not_empty_attr(cfg, "seed"):
        seed_everything(cfg.seed, workers=True)

    # Importing callbacks using hydra
    callbacks = []
    for _, cb_conf in cfg.CALLBACKS.items():
        if cb_conf is not None:
            callbacks.append(hydra.utils.instantiate(cb_conf))
    # Adding a Checkpoint Callback if checkpointing is enabled
    if has_true_attr(cfg.pl_trainer, "enable_checkpointing"):
        callbacks.append(hydra.utils.instantiate(cfg.ModelCheckpoint))

    # Instantiate the (tensorboard) logger
    tb_logger = hydra.utils.instantiate(cfg.logger)

    # Logging information about the gpu setup
    avail_GPUS = torch.cuda.device_count()
    selected_GPUS = cfg.pl_trainer.devices
    number_gpus = num_gpus(avail_GPUS, selected_GPUS)
    if avail_GPUS > 0:
        # get_device_name() raises on machines without CUDA devices, so guard it
        log.info("Available GPUs: %s - %s", avail_GPUS, torch.cuda.get_device_name())
    else:
        log.info("Available GPUs: 0 - running on CPU")
    log.info(
        "Number of used GPUs: %s    Selected GPUs: %s",
        number_gpus,
        cfg.pl_trainer.devices,
    )
    # torch.version.cuda is the public API for the compiled CUDA version
    # (replaces the private torch._C._cuda_getCompiledVersion())
    log.info("CUDA version: %s    Pytorch version: %s", torch.version.cuda, torch.__version__)

    # Defining the datamodule
    dataModule = hydra.utils.instantiate(cfg.datamodule, _recursive_=False)

    # Defining model and load checkpoint if wanted
    # cfg.finetune_from should be the path to a .ckpt file
    if has_not_empty_attr(cfg, "finetune_from"):
        log.info("finetune from: %s", cfg.finetune_from)
        cfg.trainermodule._target_ += ".load_from_checkpoint"
        model = hydra.utils.call(
            cfg.trainermodule, cfg.finetune_from, strict=False, model_config=cfg, _recursive_=False
        )
    else:
        model = hydra.utils.instantiate(cfg.trainermodule, cfg, _recursive_=False)

    # Initializing the trainer
    trainer_args = cfg.pl_trainer if has_not_empty_attr(cfg, "pl_trainer") else {}
    ddp = DDPStrategy(find_unused_parameters=False)
    trainer = Trainer(
        callbacks=callbacks,
        logger=tb_logger,
        strategy=ddp if number_gpus > 1 else "auto",
        sync_batchnorm=number_gpus > 1,
        **trainer_args
    )

    # Log experiment; the if-statement is needed to catch fast_dev_run
    if not has_true_attr(cfg.pl_trainer, "fast_dev_run"):
        log_hyperparameters(cfg, model, trainer)

    # Start training, optionally resuming from cfg.continue_from
    trainer.fit(
        model,
        dataModule,
        ckpt_path=cfg.continue_from if hasattr(cfg, "continue_from") else None,
    )


if __name__ == "__main__":
    training_loop()