├── README.md ├── configs ├── ADE20K │ ├── hrnetocr_contrastive_ADE20K.json │ └── upnswin_contrastive_ADE20K.json ├── CITYSCAPES │ └── hrnet_contrastive_CTS.json ├── __init__.py └── path_info.json ├── data ├── data.csv └── data.pkl ├── datasets ├── ADE20K.py ├── CaDIS.py ├── Cityscapes.py ├── Dataset_from_df.py ├── PascalC.py └── __init__.py ├── env_dgx.yml ├── losses ├── DenseContrastiveLossV2.py ├── DenseContrastiveLossV2_ms.py ├── LossWrapper.py ├── LovaszSoftmax.py ├── TwoScaleLoss.py └── __init__.py ├── main.py ├── managers ├── BaseManager.py ├── DeepLabv3_Manager.py ├── HRNet_Manager.py ├── LoggingManager.py ├── OCRNet_Manager.py └── __init__.py ├── misc └── figs │ └── fig1-01-01.png ├── models ├── DeepLabv3.py ├── HRNet.py ├── OCR.py ├── Projector.py ├── Swin.py ├── TTAWrapperSlide.py ├── TTA_wrapper.py ├── TTA_wrapper_CTS.py ├── TTA_wrapper_PC.py ├── Transformers.py ├── UPerNet.py ├── __init__.py └── hrnet_config.py └── utils ├── __init__.py ├── checkpoint_utils.py ├── config_parsers.py ├── datasets_info ├── ADE20K.py ├── CADIS.py ├── CITYSCAPES.py ├── PASCALC.py └── __init__.py ├── defaults.py ├── df_from_data.py ├── distributed.py ├── logger.py ├── lr_functions.py ├── metrics.py ├── np_transforms.py ├── optimizer_utils.py ├── pointrend_utils.py ├── repeat_factor_sampling.py ├── semi_utis.py ├── torch_transforms.py ├── torch_utils.py ├── transforms.py ├── tsne_visualization.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # ECCV2022 Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation 2 | Implementation of "Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation", to appear at ECCV 2022 3 | 4 | arxiv link : https://arxiv.org/abs/2203.13409 5 | 6 | 7 | ![fig](misc/figs/fig1-01-01.png) 8 | 9 | > [**Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation**](https://arxiv.org/abs/2203.13409), 10 | > [Theodoros 
Pissas](https://rvim.online/author/theodoros-pissas/), [Claudio S. Ravasio](https://rvim.online/author/claudio-ravasio/), [Lyndon Da Cruz](), [Christos Bergeles](https://rvim.online/author/christos-bergeles/)
11 | > 12 | > *arXiv technical report ([arXiv 2203.13409](https://arxiv.org/abs/2203.13409))* 13 | > 14 | > *ECCV 2022 ([proceedings]())* 15 | 16 | ## Log 17 | - 20/07 loss code public 18 | - Coming soon: pretrained checkpoints and more configs for more models 19 | 20 | 21 | ## Data and requirements 22 | 1) Download the datasets 23 | 2) Edit configs/path_info.json for your setup, adding a data_path and a log_path (see the example entries in path_info.json) 24 |
a) data_path should be the root directory of the datasets 25 |
b) log_path is the directory in which each run creates its own subdirectory for logs and checkpoints 26 | 3) Create a conda environment with PyTorch 1.7 and CUDA 10.0 27 | ```bash 28 | conda env create -f env_dgx.yml 29 | conda activate semseg 30 | ``` 31 | 32 | ## Train 33 | To train a model, most settings are specified in JSON configuration files found in ```configs```. 34 | Each model on each dataset uses its own config. A few settings are specified from the command line, 35 | and config settings can also be overridden from the command line (see main.py). 36 | The commands below start training on 4 GPUs with the settings used in the paper. 37 | 38 | Training with ResNet or HRNet backbones requires ImageNet initialization, which is handled by torchvision or downloaded from a URL, respectively. 39 | To train with Swin backbones we use the ImageNet checkpoints provided by the official implementation https://github.com/microsoft/Swin-Transformer/. 40 | These must be downloaded into a directory called pytorch_checkpoints, structured as follows: 41 | 42 | ``` 43 | pytorch_checkpoints/swin_imagenet/swin_tiny_patch4_window7_224.pth 44 | /swin_small_patch4_window7_224.pth 45 | /swin_base_patch4_window7_224.pth 46 | /swin_large_patch4_window7_224_22k.pth 47 | ``` 48 | Example commands to start training (d = CUDA device ids, p = multi-GPU training, bs = batch size, w = workers per GPU): 49 | - For HRNet on Cityscapes: 50 | ```bash 51 | python main.py -d 0,1,2,3 -p -u theo -c configs/CITYSCAPES/hrnet_contrastive_CTS.json -bs 12 -w 3 52 | ``` 53 | - For UPerNet SwinT on ADE20K: 54 | ```bash 55 | python main.py -d 0,1,2,3 -p -u theo -c configs/ADE20K/upnswin_contrastive_ADE20K.json -bs 16 -w 4 56 | ``` 57 | 58 | [//]: # (## Run a pretrained model) 59 | 60 | [//]: # (- Example of how to run inference with pretrained model:) 61 | 62 | [//]: # ( ```bash) 63 | 64 | [//]: # ( python main.py -d 0 -u theo -c configs/ADE20K/upnswin_contrastive_ADE20K.json -bs 1 -w 4 -m 
inference -cpt 20220303_230257_e1__upn_alignFalse_projFpn_swinT_sbn_DCms_cs_epochs127_bs16 -so) 65 | 66 | [//]: # ( ```) 67 | 68 | ## Licensing and copyright 69 | 70 | Please see the LICENSE file for details. 71 | 72 | ## Acknowledgements 73 | 74 | This project utilizes [timm] and the official implementation of [swin] Transformer. 75 | We thank the authors of those projects for open-sourcing their code and model weights. 76 | 77 | [timm]: https://github.com/rwightman/pytorch-image-models 78 | 79 | [swin]: https://github.com/microsoft/Swin-Transformer/ 80 | 81 | ## Citation 82 | If you found the paper or code useful please cite the following: 83 | 84 | ``` 85 | @misc{https://doi.org/10.48550/arxiv.2203.13409, 86 | doi = {10.48550/ARXIV.2203.13409}, 87 | url = {https://arxiv.org/abs/2203.13409}, 88 | author = {Pissas, Theodoros and Ravasio, Claudio S. and Da Cruz, Lyndon and Bergeles, Christos}, 89 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences}, 90 | title = {Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation}, 91 | publisher = {arXiv}, 92 | year = {2022}, 93 | copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International} 94 | } 95 | ``` 96 | -------------------------------------------------------------------------------- /configs/ADE20K/hrnetocr_contrastive_ADE20K.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "hrnOCR", 3 | "mode": "training", 4 | "manager": "OCRNet", 5 | "graph": { 6 | "model": "OCRNet", 7 | "backbone": "hrnet48", 8 | "sync_bn":true, 9 | "out_stride": 4, 10 | "pretrained": true, 11 | "align_corners": true, 12 | "ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": true, "before_context": true} 13 | }, 14 | 15 | "load_last": true, 16 | "tta":true, 17 | "tta_scales": [0.5, 0.75, 1.25, 1.5, 1.75], 18 | "load_checkpoint_": "no", 19 | 
"run_final_val": false, 20 | 21 | "data": { 22 | "num_workers":8, 23 | "dataset": "ADE20K", 24 | "use_relabeled": false, 25 | "blacklist": false, 26 | "experiment": 1, 27 | "split": ["train", "val"], 28 | "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"], 29 | "transform_values": {"crop_shape": [512, 512], "crop_class_max_ratio": 0.75, 30 | "scale_range": [0.5, 2]}, 31 | "transforms_val": ["torchvision_normalise"], 32 | "transform_values_val": {"min_side_length": 512, 33 | "crop_class_max_ratio": 0.75, 34 | "fit_stride_val": 32}, 35 | "batch_size": 16 36 | }, 37 | 38 | "loss": { 39 | "name": "LossWrapper", 40 | "label_scaling_mode": "nn", 41 | "dominant_mode": "all", 42 | "temperature": 0.1, 43 | "cross_scale_contrast": true, 44 | "weights": [1, 0.7, 0.4, 0.1], 45 | "scales": 4, 46 | "interm": {"name": "CrossEntropyLoss", "args": [], "weight": 0.4}, 47 | "final": {"name": "CrossEntropyLoss", "args": [], "weight": 1.0}, 48 | "losses": {"TwoScaleLoss": 1, "DenseContrastiveLossV2_ms": 0.1}, 49 | 50 | "losses___": {"TwoScaleLoss": 1}, 51 | 52 | "min_views_per_class": 5, 53 | "max_views_per_class": 2500, 54 | "max_features_total": 10000 55 | }, 56 | 57 | "train": { 58 | "learning_rate": 0.02, 59 | "lr_fct": "polynomial", 60 | "optim": "SGD", 61 | "lr_batchwise": true, 62 | "epochs": 120, 63 | "momentum": 0.9, 64 | "weight_decay": 0.0001 65 | }, 66 | "max_valid_imgs": 2, 67 | "valid_freq": 10, 68 | "log_every_n_epochs": 25, 69 | "cuda": true, 70 | "gpu_device": 0, 71 | "parallel": false, 72 | "seed": 100 73 | } -------------------------------------------------------------------------------- /configs/ADE20K/upnswin_contrastive_ADE20K.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "upn", 3 | "mode": "training", 4 | "manager": "OCRNet", 5 | "graph": { 6 | "model": "UPerNet", 7 | "backbone": "swinT", 8 | "sync_bn":true, 9 | "out_stride": 32, 10 | "pretrained": false, 11 
| "align_corners": false, 12 | "aux_head":{"in_index": 3, "dropout_rate": 0.1}, 13 | "dropout_rate" : 0.1, 14 | "ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": true, "position":"fpn"} 15 | }, 16 | 17 | "load_last": false, 18 | "tta":false, 19 | "tta_scales": [0.5, 0.75, 1.25, 1.5, 1.75], 20 | 21 | "data": { 22 | "num_workers":0, 23 | "dataset": "ADE20K", 24 | "use_relabeled": false, 25 | "blacklist": false, 26 | "experiment": 1, 27 | "split": ["train", "val"], 28 | "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"], 29 | "transform_values": {"crop_shape": [512, 512], "crop_class_max_ratio": 0.75, 30 | "scale_range": [0.5, 2]}, 31 | "transforms_val": ["resize_val", "torchvision_normalise"], 32 | "transform_values_val": {"min_side_length": 512, 33 | "crop_class_max_ratio": 0.75, 34 | "fit_stride_val": 32}, 35 | "batch_size": 16 36 | }, 37 | 38 | "loss": { 39 | "name": "LossWrapper", 40 | "label_scaling_mode": "nn", 41 | "dominant_mode": "all", 42 | "temperature": 0.1, 43 | "cross_scale_contrast": true, 44 | "weights": [1.0, 0.7, 0.4, 0.1], 45 | "scales": 4, 46 | "interm": {"name": "CrossEntropyLoss", "args": [], "weight": 0.4}, 47 | "final": {"name": "CrossEntropyLoss", "args": [], "weight": 1.0}, 48 | "losses": {"TwoScaleLoss": 1.0, "DenseContrastiveLossV2_ms": 0.1}, 49 | "losses__": {"TwoScaleLoss": 1.0}, 50 | "min_views_per_class": 5, 51 | "max_views_per_class": 2500, 52 | "max_features_total": 10000 53 | }, 54 | 55 | "train": { 56 | "lr_batchwise": true, 57 | "learning_rate": 0.00006, 58 | "lr_fct": "linear-warmup-polynomial", 59 | "lr_params": {"power": 1.0, 60 | "warmup_iters": 1500, 61 | "warmup_rate": 1e-6 , 62 | "min_lr": 0.0}, 63 | "optim": "AdamW", 64 | "epochs": 127, 65 | "epochs_bs12": 648, 66 | "momentum": 0.9, 67 | "betas": [0.9, 0.999], 68 | "weight_decay": 0.01, 69 | "opt_keys":{"absolute_pos_embed": {"wd_mult":0.0}, 70 | "norm": {"wd_mult":0.0}, 71 | 
"relative_position_bias_table":{"wd_mult": 0.0}} 72 | }, 73 | "max_valid_imgs": 2, 74 | "valid_freq": 10, 75 | "log_every_n_epochs": 10, 76 | "cuda": true, 77 | "gpu_device": 0, 78 | "parallel": false, 79 | "seed": 100 80 | } -------------------------------------------------------------------------------- /configs/CITYSCAPES/hrnet_contrastive_CTS.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "hrn", 3 | "mode": "training", 4 | "manager": "HRNet", 5 | "graph": { 6 | "model": "HRNet", 7 | "backbone": "hrnet48", 8 | "sync_bn":true, 9 | "out_stride": 4, 10 | "pretrained": false, 11 | "align_corners": true, 12 | "ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": true, "before_context": true} 13 | }, 14 | 15 | "load_last": true, 16 | "tta":true, 17 | "tta_scales": [0.75, 1.25, 1.5, 1.75, 2], 18 | "run_final_val": false, 19 | 20 | "data": { 21 | "num_workers":8, 22 | "dataset": "CITYSCAPES", 23 | "use_relabeled": false, 24 | "blacklist": false, 25 | "experiment": 1, 26 | "split": ["train", "val"], 27 | "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"], 28 | "transform_values": {"crop_shape": [512, 1024], "crop_class_max_ratio": 0.75, 29 | "scale_range": [0.5, 2]}, 30 | "transforms_val": ["torchvision_normalise"], 31 | "transform_values_val": {}, 32 | "batch_size": 12 33 | }, 34 | 35 | "loss": { 36 | "name": "LossWrapper", 37 | "label_scaling_mode": "nn", 38 | "dominant_mode": "all", 39 | "temperature": 0.1, 40 | "cross_scale_contrast": true, 41 | "weights": [1, 0.7, 0.4, 0.1], 42 | "scales": 4, 43 | "losses": {"CrossEntropyLoss": 1,"DenseContrastiveLossV2_ms": 0.1}, 44 | "losses___": {"CrossEntropyLoss": 1}, 45 | "min_views_per_class": 5, 46 | "max_views_per_class": 2500, 47 | "max_features_total": 10000 48 | }, 49 | "train": { 50 | "learning_rate": 0.01, 51 | "lr_fct": "polynomial", 52 | "optim": "SGD", 53 | "lr_batchwise": true, 54 | "epochs": 
484, 55 | "momentum": 0.9, 56 | "wd": 0.0005 57 | }, 58 | "valid_batch_size": 1, 59 | "max_valid_imgs":2, 60 | "valid_freq": 100, 61 | "log_every_n_epochs": 100, 62 | "cuda": true, 63 | "gpu_device": 0, 64 | "parallel": false, 65 | "seed": 0 66 | } -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/c511fbcde6ac53b72d663225bdf6dded022ca1ce/configs/__init__.py -------------------------------------------------------------------------------- /configs/path_info.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "theo_CTS": [ 4 | "D:\\datasets\\CITYSCAPES", 5 | "D:\\datasets\\CITYSCAPES\\logs" 6 | ], 7 | 8 | "theo_ADE20K": [ 9 | "D:\\datasets\\ADEChallengeData2016", 10 | "D:\\datasets\\ADEChallengeData2016\\logs" 11 | ] 12 | } 13 | -------------------------------------------------------------------------------- /data/data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/c511fbcde6ac53b72d663225bdf6dded022ca1ce/data/data.pkl -------------------------------------------------------------------------------- /datasets/ADE20K.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import namedtuple 4 | from typing import Union 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision.transforms import Compose, ToPILImage 7 | from PIL import Image, ImageFile 8 | from utils import DATASETS_INFO, remap_mask, printlog 9 | import numpy as np 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | import pathlib 12 | 13 | from utils import DATASETS_INFO, remap_mask, printlog, 
mask_to_colormap, get_remapped_colormap 14 | 15 | 16 | 17 | class ADE20K(Dataset): 18 | 19 | def __init__(self, root, transforms_dict, split:Union[str,list]='train', debug=False): 20 | """ 21 | 22 | :param root: path to the directory that contains "ADEChallengeData2016" (which holds the "images" and "annotations" directories) 23 | :param transforms_dict: see dataset_from_df.py 24 | :param split: any of "train", "test", "val", or the list ['train', 'val'] to train on train+val 25 | :param debug: if True, prints the label values and tensor shapes of each item after augmentation 26 | 27 | """ 28 | 29 | 30 | super(ADE20K, self).__init__() 31 | self.root = root 32 | self.common_transforms = Compose(transforms_dict['common']) 33 | self.img_transforms = Compose(transforms_dict['img']) 34 | self.lbl_transforms = Compose(transforms_dict['lbl']) 35 | # assert(mode in ("fine", "coarse")) 36 | valid_splits = ["train", "test", "val", ['train', 'val']] 37 | assert (split in valid_splits), f'split {split} is not in valid_splits {valid_splits}' 38 | # self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 39 | self.split = split # "train", "test", "val" or ['train', 'val'] 40 | self.debug = debug 41 | # self.target_type = target_type 42 | self.images = [] 43 | self.targets = [] 44 | # this can only take the following values so hardcoded 45 | self.dataset = 'ADE20K' 46 | self.experiment = 1 47 | self.img_suffix = '.jpg' 48 | self.target_suffix = '.png' 49 | 50 | 51 | if self.split == ['train', 'val']: 52 | # for training on train + val 53 | printlog('train set is train+val splits') 54 | for i, s in enumerate(self.split): 55 | self.images_dir = os.path.join(self.root, 'ADEChallengeData2016', 'images', s) 56 | self.targets_dir = os.path.join(self.root, 'ADEChallengeData2016', 'annotations', s) 57 | for image_filename in os.listdir(self.images_dir): 58 | img_path = os.path.join(self.images_dir, image_filename) 59 | target_path = os.path.join(self.targets_dir, 
image_filename.split(self.img_suffix)[-2] + self.target_suffix) 60 | self.images.append(img_path) 61 | self.targets.append(target_path) 62 | assert (pathlib.Path(self.images[-1]).exists() and pathlib.Path(self.targets[-1]).exists()) 63 | assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem) 64 | 65 | elif self.split == 'test': 66 | self.images_dir = os.path.join(self.root, 'ADEChallengeData2016', 'images', split) 67 | self.targets_dir = os.path.join(self.root,'ADEChallengeData2016', 'annotations', 'train') # dummy 68 | targets_dummy = os.listdir(self.targets_dir) 69 | for n, image_filename in enumerate(os.listdir(self.images_dir)): 70 | img_path = os.path.join(self.images_dir, image_filename) 71 | target_path = os.path.join(self.targets_dir, targets_dummy[n]) 72 | self.images.append(img_path) 73 | self.targets.append(target_path) 74 | assert (pathlib.Path(self.images[-1]).exists()) # and pathlib.Path(self.targets[-1]).exists()) 75 | # assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem) 76 | else: 77 | self.images_dir = os.path.join(self.root, 'ADEChallengeData2016', 'images', split) 78 | self.targets_dir = os.path.join(self.root, 'ADEChallengeData2016', 'annotations', split) 79 | for image_filename in os.listdir(self.images_dir): 80 | img_path = os.path.join(self.images_dir, image_filename) 81 | target_path = os.path.join(self.targets_dir, image_filename.split(self.img_suffix)[-2] + self.target_suffix) 82 | self.images.append(img_path) 83 | self.targets.append(target_path) 84 | assert (pathlib.Path(self.images[-1]).exists() and pathlib.Path(self.targets[-1]).exists()) 85 | assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem) 86 | printlog(f'ade20k data all found split = {self.split}, images {len(self.images)}, targets {len(self.targets)}') 87 | 88 | self.return_filename = False 89 | 90 | def __getitem__(self, index, ): 91 | """ 92 | Args: 93 | index (int): Index 94 | Returns: 95 
| tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 96 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 97 | """ 98 | 99 | image = Image.open(self.images[index]).convert('RGB') 100 | metadata = {'index': index} 101 | 102 | if self.split == 'test': 103 | target = remap_mask(np.ones(shape=np.array(image).shape[0:2], dtype=np.int32), 104 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32') 105 | 106 | else: 107 | target = Image.open(self.targets[index]) 108 | target = remap_mask(np.array(target), 109 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32') 110 | # print(index, ': ', np.unique(target)) 111 | target = Image.fromarray(target) 112 | # if 14 in np.unique(target).tolist(): # tracl 113 | # target.show() 114 | # target.close() 115 | # return 0 , 0 , 0 116 | 117 | image, target, metadata = self.common_transforms((image, target, metadata)) 118 | img_tensor = self.img_transforms(image) 119 | lbl_tensor = self.lbl_transforms(target).squeeze() 120 | 121 | if self.return_filename: 122 | metadata.update({'img_filename': self.images[index], 123 | 'target_filename': self.targets[index]}) 124 | 125 | if self.debug: 126 | # ToPILImage()(img_tensor).show() 127 | # ToPILImage()(lbl_tensor).show() 128 | # debug_lbl = mask_to_colormap(to_numpy(lbl_tensor), 129 | # get_remapped_colormap( 130 | # DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], 131 | # self.dataset), 132 | # from_network=True, experiment=self.experiment, 133 | # dataset=self.dataset)[..., ::-1] 134 | # 135 | # 136 | # 137 | # fn = metadata['target_filename'].split('\\')[-1] 138 | # p = pathlib.Path(r'C:\Users\Theodoros Pissas\Documents\tresorit\ADEChallengeData2016\ADEChallengeData2016\visuals\val/') 139 | # p1 = pathlib.Path(f'{fn}') 140 | # # ToPILImage()(lbl_tensor).save(f"{str(p/p1)}") 141 | # 142 | # 
cv2.imwrite(f"{str(p/p1)}", debug_lbl) 143 | print(f'\nafter aug index : {np.unique(lbl_tensor)} lbl {lbl_tensor.shape} image {img_tensor.shape} fname:{self.images[index]}') 144 | return img_tensor, lbl_tensor, metadata 145 | 146 | def __len__(self): 147 | return len(self.images) 148 | 149 | if __name__ == '__main__': 150 | import pathlib 151 | import torch 152 | from utils import parse_transform_lists 153 | import json 154 | import cv2 155 | from torch.nn import functional as F 156 | from utils import Pad, RandomResize, RandomCropImgLbl, Resize, FlipNP, to_numpy, pil_plot_tensor, to_comb_image 157 | from torchvision.transforms import ToTensor 158 | import PIL.Image as Image 159 | 160 | data_path = 'C:\\Users\\Theodoros Pissas\\Documents\\tresorit\\ADEChallengeData2016\\' 161 | d = {"dataset":'ADE20K', "experiment":1} 162 | path_to_config = '../configs/ADE20K/upnswin_contrastive_ADE20K.json' 163 | with open(path_to_config, 'r') as f: 164 | config = json.load(f) 165 | 166 | transforms_list = config['data']['transforms'] 167 | transforms_values = config['data']['transform_values'] 168 | if 'torchvision_normalise' in transforms_list: 169 | del transforms_list[-1] 170 | 171 | transforms_dict = parse_transform_lists(transforms_list, transforms_values, **d) 172 | transforms_list_val = config['data']['transforms_val'] 173 | transforms_values_val = config['data']['transform_values_val'] 174 | 175 | if 'torchvision_normalise' in transforms_list_val: 176 | del transforms_list_val[-1] 177 | transforms_dict_val = parse_transform_lists(transforms_list_val, transforms_values_val, **d) 178 | del transforms_list_val[0] 179 | train_set = ADE20K(root=data_path, 180 | debug=True, 181 | split=['train', 'val'], 182 | transforms_dict=transforms_dict) 183 | valid_set = ADE20K(root=data_path, 184 | debug=True, 185 | split='test', 186 | transforms_dict=transforms_dict_val) 187 | 188 | issues = [] 189 | valid_set.return_filename = True 190 | train_set.return_filename = True 191 | hs=[] 192 
| ws = [] 193 | for ret in valid_set: 194 | hs.append(ret[0].shape[1]) 195 | ws.append(ret[0].shape[2]) 196 | present_classes = torch.unique(ret[1]) 197 | print(ret[-1]) 198 | # elif 15 in present_classes: 199 | # issues.append([ret[-1], present_classes]) 200 | # print('bus found !!!! ') 201 | # print(present_classes) 202 | # pil_plot_tensor(ret[0], is_rgb=True) 203 | # pil_plot_tensor(ret[1], is_rgb=False) 204 | 205 | # a = 1 206 | # print(max(hs)) 207 | # print(max(ws)) -------------------------------------------------------------------------------- /datasets/CaDIS.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from utils import DATASETS_INFO, printlog 3 | import pathlib 4 | 5 | 6 | def get_cadis_dataframes(config: dict): 7 | # Make dataframes for the training and the validation set 8 | assert 'data' in config 9 | dataset = config['data']['dataset'] 10 | assert dataset == 'CADIS', f'dataset must be CADIS, instead got {dataset}' 11 | df = pd.read_csv('data/data.csv') # todo this should be moved to data dir 12 | 13 | if 'random_split' in config['data']: 14 | print("***Legacy mode: random split of all data used, instead of split of videos!***") 15 | train = df.sample(frac=config['data']['random_split'][0]).copy() 16 | valid = df.drop(train.index).copy() 17 | split_of_rest = config['data']['random_split'][1] / (1 - config['data']['random_split'][0]) 18 | valid = valid.sample(frac=split_of_rest) 19 | else: 20 | splits = DATASETS_INFO[dataset].DATA_SPLITS[int(config['data']['split'])] 21 | if len(splits) == 3: 22 | printlog("using train-val-test split") 23 | train_videos, valid_videos, test_videos = splits 24 | if config['mode'] == 'infer': 25 | printlog(f"CADIS with mode {config['mode']}") 26 | printlog("going to use test_videos as validation set") 27 | valid_videos = test_videos 28 | elif len(splits) == 2: 29 | printlog("using train-merged[valtest] split") 30 | train_videos, valid_videos = splits 31 | 
else: 32 | raise ValueError('splits must be a list of length 2 or 3') 33 | train = df.loc[df['vid_num'].isin(train_videos)].copy() 34 | valid = df.loc[(df['vid_num'].isin(valid_videos)) & (df['propagated'] == 0)].copy() # No prop lbl in valid 35 | info_string = "Dataframes created. Number of records training / validation: {:06d} / {:06d}\n" \ 36 | " Actual data split training / validation: {:.3f} / {:.3f}" \ 37 | .format(len(train.index), len(valid.index), len(train.index) / len(df), len(valid.index) / len(df)) 38 | 39 | # Replace incorrectly annotated frames if flag set 40 | if config['data']['use_relabeled']: 41 | train_idx_list = train[train['relabeled'] == 1].index 42 | for idx in train_idx_list: 43 | train.loc[idx, 'blacklisted'] = 0 # So the frames don't get removed after 44 | lbl_path = pathlib.Path(train.loc[idx, 'lbl_path']).name 45 | train.loc[idx, 'lbl_path'] = 'relabeled/' + str(lbl_path) 46 | valid_idx_list = valid[valid['relabeled'] == 1].index 47 | for idx in valid_idx_list: 48 | valid.loc[idx, 'blacklisted'] = 0 # So the frames don't get removed after 49 | lbl_path = pathlib.Path(valid.loc[idx, 'lbl_path']).name 50 | valid.loc[idx, 'lbl_path'] = 'relabeled/' + str(lbl_path) 51 | info_string += "\n Relabeled train recs: {}\n" \ 52 | " Relabeled valid recs: {}" \ 53 | .format(len(train_idx_list), len(valid_idx_list)) 54 | 55 | # Remove incorrectly annotated frames if flag set 56 | if config['data']['blacklist']: 57 | train = train.drop(train[train['blacklisted'] == 1].index) 58 | valid = valid.drop(valid[valid['blacklisted'] == 1].index) 59 | t_len, v_len = len(train.index), len(valid.index) 60 | info_string += "\n After blacklisting: Number of records train / valid: {:06d} / {:06d}\n" \ 61 | " Relative data split train / valid: {:.3f} / {:.3f}" \ 62 | .format(t_len, v_len, t_len / (t_len + v_len), v_len / (t_len + v_len)) 63 | train = train.reset_index() 64 | valid = valid.reset_index() 65 | 66 | printlog(f" dataset {dataset}") 67 | 
printlog(info_string) 68 | return train, valid 69 | -------------------------------------------------------------------------------- /datasets/Dataset_from_df.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | import cv2 4 | from torch.utils.data.dataset import Dataset 5 | from torchvision.transforms import Compose, ToPILImage 6 | import torch 7 | from utils import DATASETS_INFO, remap_mask 8 | import numpy as np 9 | 10 | 11 | class DatasetFromDF(Dataset): 12 | def __init__(self, dataframe, experiment, transforms_dict, data_path=None, labels_remaped=False, 13 | return_pseudo_property=False, dataset='CADIS', debug=False): 14 | self.df = dataframe 15 | self.experiment = experiment 16 | self.dataset = dataset 17 | self.common_transforms = Compose(transforms_dict['common']) 18 | self.img_transforms = Compose(transforms_dict['img']) 19 | self.lbl_transforms = Compose(transforms_dict['lbl']) 20 | self.labels_are_remapped = labels_remaped # used when reading pseudo labeled data 21 | self.return_pseudo_property = return_pseudo_property # used to return whether the datapoint is pseudo labelled 22 | self.preloaded = False if data_path is not None else True 23 | if self.preloaded: # Data preloaded, need to assert that 'image' and 'label' exist in the dataframe 24 | assert 'image' in self.df and 'label' in self.df, "For preloaded data, the dataframe passed to the " \ 25 | "PyTorch dataset needs to contain the columns 'image' " \ 26 | "and 'label'" 27 | else: # Standard case: data not preloaded, needs base path to get images / labels from 28 | assert 'img_path' in self.df and 'lbl_path' in self.df, "The dataframe passed to the PyTorch dataset needs"\ 29 | " to contain the columns 'img_path' and 'lbl_path'" 30 | self.data_path = data_path 31 | self.debug = debug 32 | 33 | def __getitem__(self, item): 34 | if self.preloaded: 35 | img = self.df.iloc[item].loc['image'] 36 | lbl = self.df.iloc[item].loc['label'] 
37 | else: 38 | # img = cv2.imread(str(pathlib.Path(self.data_path) / self.df.iloc[item].loc['img_path']))[..., ::-1] 39 | img = cv2.imread( 40 | os.path.join( 41 | self.data_path, 42 | os.path.join(*self.df.iloc[item].loc['img_path'].split('\\'))))[..., ::-1] 43 | img = img - np.zeros_like(img) # deals with negative stride error 44 | # lbl = cv2.imread(str(pathlib.Path(self.data_path) / self.df.iloc[item].loc['lbl_path']), 0) 45 | lbl = cv2.imread( 46 | os.path.join( 47 | self.data_path, 48 | os.path.join(*self.df.iloc[item].loc['lbl_path'].split('\\'))), 0) 49 | lbl = lbl - np.zeros_like(lbl) 50 | 51 | if self.labels_are_remapped: 52 | # if labels are pseudo they are already remapped to experiment label set 53 | lbl = lbl.astype('int32') 54 | else: 55 | lbl = remap_mask(lbl, DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32') 56 | 57 | # Note: .astype('i') is VERY important. If left in uint8, ToTensor() will normalise the segmentation classes! 58 | 59 | # Here (and before Compose(lbl_transforms) we'd need to set the random seed and pray, following this idea: 60 | # https://github.com/pytorch/vision/issues/9#issuecomment-304224800 61 | # Big yikes. 
Big potential problem source, see here: https://github.com/pytorch/pytorch/issues/7068 62 | # If that doesn't work, the whole transforms structure needs to be changed into all-custom functions that will 63 | # transform both img and lbl at the same time, with one random shift / flip / whatever being applied to both 64 | metadata = {'index': item, 'filename': self.df.iloc[item].loc['img_path'], 65 | 'target_filename': str(pathlib.Path(self.df.iloc[item].loc['img_path']).stem)} 66 | 67 | if self.dataset == 'RETOUCH': 68 | subject_id = pathlib.Path(metadata['filename']).parent.stem 69 | slice_id = pathlib.Path(self.df.iloc[item].loc['lbl_path']).stem 70 | metadata['subject_id'] = subject_id 71 | metadata['target_filename'] = f"{subject_id}_{slice_id}" 72 | 73 | img, lbl, metadata = self.common_transforms((img, lbl, metadata)) 74 | img_tensor = self.img_transforms(img) 75 | lbl_tensor = self.lbl_transforms(lbl).squeeze() 76 | if self.return_pseudo_property: 77 | # pseudo_tensor = torch.from_numpy(np.asarray(self.df.iloc[item].loc['pseudo'])) 78 | metadata.update({'pseudo': self.df.iloc[item].loc['pseudo']}) 79 | 80 | if self.debug: 81 | ToPILImage()(img_tensor).show() 82 | ToPILImage()(lbl_tensor).show() 83 | print(f'\nafter aug index : {np.unique(lbl_tensor)} lbl {lbl_tensor.shape} image {img_tensor.shape}') 84 | 85 | return img_tensor, lbl_tensor, metadata 86 | 87 | def __len__(self): 88 | return len(self.df) 89 | -------------------------------------------------------------------------------- /datasets/PascalC.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import torch 4 | from collections import namedtuple 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision.transforms import Compose, ToPILImage 7 | from PIL import Image, ImageFile 8 | from utils import DATASETS_INFO, remap_mask, printlog, mask_to_colormap, get_remapped_colormap 9 | import numpy as np 10 | # import cv2 11 | 
ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | import pathlib 13 | 14 | 15 | class PascalC(Dataset): 16 | def __init__(self, root, transforms_dict, split='train', mode='fine', target_type='semantic', debug=False): 17 | """ 18 | :param root: path to pascal dir (i.e. where the split directories "train" and "val", each containing "image" and "label", are located) 19 | :param transforms_dict: see dataset_from_df.py 20 | :param split: "train" or "val" 21 | :param mode: currently unused 22 | :param target_type: currently unused; only the default 'semantic' is expected (todo: test other target_types if needed) 23 | """ 24 | self.debug = debug 25 | super(PascalC, self).__init__() 26 | self.root = root 27 | self.common_transforms = Compose(transforms_dict['common']) 28 | self.img_transforms = Compose(transforms_dict['img']) 29 | self.lbl_transforms = Compose(transforms_dict['lbl']) 30 | valid_modes = ["train", "val"] 31 | assert (split in valid_modes), f'split {split} is not in valid_modes {valid_modes}' 32 | # self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' 33 | self.split = split # "train" or "val" 34 | 35 | # self.target_type = target_type 36 | self.images = [] 37 | self.targets = [] 38 | # this can only take the following values so hardcoded 39 | self.dataset = 'PASCALC' 40 | self.experiment = 1 41 | 42 | # image and label directories of the selected split 43 | self.images_dir = [] 44 | self.targets_dir = [] 45 | self.images_dir = pathlib.Path(os.path.join(self.root, self.split, 'image')) 46 | self.targets_dir = pathlib.Path(os.path.join(self.root, self.split, 'label')) 47 | 48 | for img_path, target_path in zip(sorted(self.images_dir.glob('*.jpg')), sorted(self.targets_dir.glob('*.png'))): 49 | self.images.append(img_path) 50 | self.targets.append(target_path) 51 | assert(pathlib.Path(self.images[-1]).exists() and pathlib.Path(self.targets[-1]).exists()) 52 | assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem) 53 | 
printlog(f'{self.dataset} data all found split is [ {self.split} ]') 54 | 55 | self.return_filename = False 56 | 57 | def __getitem__(self, index): 58 | """ 59 | Args: 60 | index (int): Index 61 | Returns: 62 | tuple: (image, target) where target is a tuple of all target types if target_type is a list with more 63 | than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation. 64 | """ 65 | 66 | image = Image.open(self.images[index]).convert('RGB') 67 | target = Image.open(self.targets[index]) 68 | 69 | # if self.debug: 70 | # image.show() 71 | # target.show() 72 | 73 | 74 | target = remap_mask(np.array(target), 75 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32') 76 | 77 | target = Image.fromarray(target) 78 | 79 | # print(index, ': ', np.unique(target), ' ', [class_int_to_name[c] for c in np.unique(target) if not (c==59) ]) 80 | # class_int_to_name = DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] 81 | 82 | metadata = {'index': index} 83 | image, target, metadata = self.common_transforms((image, target, metadata)) 84 | img_tensor = self.img_transforms(image) 85 | lbl_tensor = self.lbl_transforms(target).squeeze() 86 | 87 | if self.return_filename: 88 | metadata.update({'img_filename': str(self.images[index]), 89 | 'target_filename': str(self.targets[index])}) 90 | if self.debug: 91 | ToPILImage()(img_tensor).show() 92 | ToPILImage()(lbl_tensor).show() 93 | # 94 | # debug_lbl = mask_to_colormap(to_numpy(lbl_tensor), 95 | # get_remapped_colormap( 96 | # DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], 97 | # self.dataset), 98 | # from_network=True, experiment=self.experiment, 99 | # dataset=self.dataset)[..., ::-1] 100 | # 101 | # 102 | # 103 | # fn = metadata['target_filename'].split('\\')[-1] 104 | # p = pathlib.Path(r'C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC\visuals\val/') 105 | # p1 = pathlib.Path(f'{fn}') 106 | # # 
ToPILImage()(lbl_tensor).save(f"{str(p/p1)}") 107 | # 108 | # cv2.imwrite(f"{str(p/p1)}", debug_lbl) 109 | # 110 | print(f'\nafter aug index, : {np.unique(lbl_tensor)} lbl {lbl_tensor.shape} image {img_tensor.shape} fname:{self.images[index]}') 111 | 112 | return img_tensor, lbl_tensor, metadata 113 | 114 | def __len__(self): 115 | return len(self.images) 116 | 117 | def extra_repr(self): 118 | lines = ["Split: {split}"] 119 | return '\n'.join(lines).format(**self.__dict__) 120 | 121 | 122 | if __name__ == '__main__': 123 | import pathlib 124 | from torch.nn import functional as F 125 | from utils import Pad, RandomResize, RandomCropImgLbl, Resize, FlipNP, to_numpy, pil_plot_tensor 126 | data_path = r'C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC/' 127 | from torchvision.transforms import ToTensor 128 | import PIL.Image as Image 129 | d = {"dataset":'PASCALC', "experiment":1} 130 | 131 | # augs= [ 132 | # FlipNP(probability=(0, 1.0)), 133 | # RandomResize(**d, 134 | # scale_range=[0.5, 2], 135 | # aspect_range=[0.9, 1.1], 136 | # target_size=[520, 520], 137 | # probability=1.0), 138 | # RandomCropImgLbl(**d, 139 | # shape=[512, 512], 140 | # crop_class_max_ratio=0.75), 141 | # ] 142 | # 143 | # augs_val= [Resize(**d, min_side_length=512, fit_stride=32, return_original_labels=True)] 144 | # 145 | # 146 | # train_set = PascalC(root=data_path, debug=True, 147 | # split='train', 148 | # transforms_dict={'common': augs_val, 149 | # 'img': [(ToTensor())], 150 | # 'lbl': [(ToTensor())]}) 151 | 152 | from utils import parse_transform_lists 153 | import json 154 | path_to_config = '../configs/PASCALC/hrnet_contrastive_PC.json' 155 | with open(path_to_config, 'r') as f: 156 | config = json.load(f) 157 | 158 | transforms_list = config['data']['transforms'] 159 | if 'torchvision_normalise' in transforms_list: 160 | del transforms_list[-1] 161 | transforms_values = config['data']['transform_values'] 162 | transforms_dict = parse_transform_lists(transforms_list, 
transforms_values, dataset='PASCALC', experiment=1) 163 | 164 | transforms_list_val = config['data']['transforms_val'] 165 | if 'torchvision_normalise' in transforms_list_val: 166 | del transforms_list_val[-1] 167 | 168 | transforms_values_val = config['data']['transform_values_val'] 169 | transforms_dict_val = parse_transform_lists(transforms_list_val, transforms_values_val, dataset='PASCALC', experiment=1) 170 | 171 | # transforms_values_val = {} 172 | # transforms_dict_val = parse_transform_lists({}, transforms_values_val, dataset='PASCALC', experiment=1) 173 | 174 | 175 | train_set = PascalC(root=data_path, 176 | debug=True, 177 | split='train', 178 | transforms_dict=transforms_dict) 179 | 180 | valid_set = PascalC(root=data_path, 181 | debug=True, 182 | split='val', 183 | transforms_dict=transforms_dict_val) 184 | valid_set.return_filename = True 185 | 186 | issues = [] 187 | train_set.return_filename = True 188 | hs=[] 189 | ws = [] 190 | for ret in valid_set: 191 | # print(ret[0].shape) 192 | # img = ToPILImage()(ret[0]).show() 193 | # lbl = ToPILImage()(ret[1]).show() 194 | 195 | hs.append(ret[0].shape[1]) 196 | ws.append(ret[0].shape[2]) 197 | print(ret[-1]) 198 | print('*'*10) 199 | # meta = ret[-1] 200 | # lbl = meta['original_labels'].unsqueeze(0) 201 | # resized = ret[1].unsqueeze(0).unsqueeze(0).long() 202 | # pad_w, pad_h, stride = meta["pw_ph_stride"] 203 | # if pad_h > 0 or pad_w > 0: 204 | # 205 | # un_padded = resized[:, :, 0:resized.size(2) - pad_h, 0:resized.size(3) - pad_w] 206 | # pil_plot_tensor(un_padded) 207 | # un_resized = F.interpolate(un_padded.float(), size=lbl.size()[-2:], mode='nearest') 208 | # print(torch.sum(un_resized- lbl)) 209 | present_classes = torch.unique(ret[1]) 210 | if len(present_classes) == 1 and 59 in present_classes: 211 | issues.append([ret[-1], present_classes]) 212 | print('issue found !!!! ') 213 | print(present_classes, ret[-1]) 214 | print('issue found !!!! 
') 215 | a = 1 216 | print(max(hs)) 217 | print(max(ws)) -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .Dataset_from_df import DatasetFromDF 2 | from .Cityscapes import Cityscapes 3 | from .PascalC import PascalC 4 | from .ADE20K import ADE20K 5 | from .CaDIS import get_cadis_dataframes 6 | -------------------------------------------------------------------------------- /env_dgx.yml: -------------------------------------------------------------------------------- 1 | name: semseg 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8 7 | - pytorch 8 | - cudatoolkit=10.1 9 | - tensorboard 10 | - torchvision=0.9.1 11 | - h5py 12 | - matplotlib 13 | - numpy 14 | - scipy 15 | - pandas 16 | - pillow 17 | - pip 18 | - future 19 | - pip: 20 | - opencv-python 21 | - tqdm 22 | - timm 23 | - tsne-torch 24 | -------------------------------------------------------------------------------- /losses/DenseContrastiveLossV2_ms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import DATASETS_INFO, is_distributed, concat_all_gather, get_rank, to_numpy, printlog 4 | from torch.nn.functional import one_hot 5 | import torch.distributed 6 | from losses.DenseContrastiveLossV2 import DenseContrastiveLossV2 as DCV2 7 | 8 | def has_inf_or_nan(x): 9 | return torch.isinf(x).max().item(), torch.isnan(x).max().item() 10 | 11 | 12 | class DenseContrastiveLossV2_ms(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | self.parallel = is_distributed() 16 | self.experiment = config['experiment'] 17 | self.dataset = config['dataset'] 18 | self.num_all_classes = len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) 19 | self.num_real_classes = self.num_all_classes - 1 if 255 in 
DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else self.num_all_classes 20 | self.ignore_class = (len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1) if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else -1 21 | self.scales = config['scales'] if 'scales' in config else 2 22 | self.weights = config['weights'] if 'weights' in config else [1.0] * self.scales 23 | assert(self.scales == len(self.weights)), f'given dc loss number of scales [{self.scales}] not equal len of weights {self.weights}' 24 | self.losses = [] 25 | self.eps = torch.tensor(1e-10) 26 | self.meta = {} 27 | self.cross_scale_contrast = config['cross_scale_contrast'] if 'cross_scale_contrast' in config else False 28 | self.cross_scale_temperature = config['temperature'] if 'cross_scale_temperature' not in config else 0.1 29 | self.detach_cs_deepest = config['detach_deepest'] if 'detach_deepest' in config else False 30 | self.w_high_low = config['w_high_low'] if 'w_high_low' in config else 1.0 31 | self.w_high_mid = config['w_high_mid'] if 'w_high_mid' in config else 1.0 32 | self.ms_losses = [] 33 | self.cs_losses = [] 34 | printlog(f'defining dcv2 ms loss with number of scales {self.scales} and weights {self.weights}') 35 | printlog(f'using cross scale contrast {self.cross_scale_contrast}') 36 | for class_name in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]: 37 | self.meta[class_name] = (0.0, 0.0) # pos-neg per class 38 | for s in range(self.scales): 39 | printlog(f'defining dcv2 loss at scale {s}') 40 | setattr(self, f'DCV2_scale{s}', DCV2(config)) 41 | if self.cross_scale_contrast: 42 | printlog(f'using cross-scale contrast with detach_cs_deepest set to {self.detach_cs_deepest}, w_high_low: {self.w_high_low}, w_high_mid: {self.w_high_mid}') 43 | 44 | def forward(self, label: torch.Tensor, features: list, **kwargs): 45 | self.cs_losses = [] 46 | self.ms_losses = [] 47 | flag_error = False 48 | loss = torch.tensor(0.0, dtype=torch.float, 
device=features[0].device) 49 | feats_ms = [] 50 | labels_ms = [] 51 | for s in range(self.scales): 52 | if self.cross_scale_contrast: 53 | loss_s, feats_s, labels_s, flag_error = getattr(self, f'DCV2_scale{s}')(label, features[s]) 54 | loss+= self.weights[s] * loss_s 55 | feats_ms.append(feats_s) 56 | labels_ms.append(labels_s) 57 | else: 58 | loss_s=getattr(self, f'DCV2_scale{s}')(label, features[s]) 59 | loss += self.weights[s] * loss_s 60 | self.ms_losses.append(loss_s.detach()) 61 | 62 | if self.cross_scale_contrast and not flag_error: 63 | assert len(feats_ms) > 1 64 | assert len(labels_ms) > 1 65 | # highest res to lowest res contrast 66 | if self.detach_cs_deepest: 67 | loss_cross_scale = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-1].detach(), labels_ms[-1]) 68 | else: 69 | loss_cross_scale = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-1], labels_ms[-1]) 70 | self.cs_losses.append(loss_cross_scale.detach()) 71 | 72 | loss += self.w_high_low * loss_cross_scale 73 | 74 | if len(feats_ms)>2: # hrnet : 4 , s4-s16 , s4-s32 dlv3 : 3 layer1(s4)-layer4(s8), layer1(s4)-layer3(s8) 75 | if self.detach_cs_deepest: 76 | loss_cross_scale2 = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-2].detach(), labels_ms[-2]) 77 | else: 78 | loss_cross_scale2 = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-2], labels_ms[-2]) 79 | loss += self.w_high_mid * loss_cross_scale2 80 | self.cs_losses.append(loss_cross_scale2.detach()) 81 | 82 | return loss 83 | 84 | def contrastive_loss(self, feats1, labels1, feats2, labels2): 85 | """ 86 | :param feats: T-C-V 87 | T: classes in batch (with repetition), which can be thought of as the number of anchors 88 | C: feature space dimensionality 89 | V: views per class (i.e samples from each class), 90 | which can be thought of as the number of views per anchor 91 | :param labels: T 92 | :return: loss 93 | """ 94 | # prepare feats 95 | feats1 = torch.nn.functional.normalize(feats1, p=2, dim=1) 
# L2 normalization 96 | feats1 = feats1.transpose(dim0=1, dim1=2) # feats are T-V-C 97 | num_anchors, views_per_anchor, c = feats1.shape # get T, V, C 98 | feats_flat1 = feats1.contiguous().view(-1, c) # feats_flat is T*V-C 99 | 100 | labels1 = labels1.contiguous().view(-1, 1) # labels are T-1 101 | labels1 = labels1.repeat(1, views_per_anchor) # labels are T-V 102 | labels1 = labels1.view(-1, 1) # labels are T*V-1 103 | 104 | feats2 = torch.nn.functional.normalize(feats2, p=2, dim=1) # L2 normalization 105 | feats2 = feats2.transpose(dim0=1, dim1=2) # feats are T-V-C 106 | num_anchors, views_per_anchor, c = feats2.shape # get T, V, C 107 | feats_flat2 = feats2.contiguous().view(-1, c) # feats_flat is T*V-C 108 | 109 | labels2 = labels2.contiguous().view(-1, 1) # labels are T-1 110 | labels2 = labels2.repeat(1, views_per_anchor) # labels are T-V 111 | labels2 = labels2.view(-1, 1) # labels are T*V-1 112 | 113 | pos_mask, neg_mask = self.get_masks2(labels1, labels2) 114 | dot_product = torch.div(torch.matmul(feats_flat1, torch.transpose(feats_flat2, 0, 1)), self.cross_scale_temperature) 115 | loss2 = self.InfoNce_loss(pos_mask, neg_mask, dot_product) 116 | return loss2 117 | 118 | @staticmethod 119 | def get_masks2(lbl1, lbl2): 120 | """ 121 | takes flattened labels and identifies pos/neg of each anchor 122 | :param labels: T*V-1 123 | :param num_anchors: T 124 | :param views_per_anchor: V 125 | :return: mask, pos_maks, 126 | """ 127 | # extract mask indicating same class samples 128 | pos_mask = torch.eq(lbl1, torch.transpose(lbl2, 0, 1)).float() # mask T-T # indicator of positives 129 | neg_mask = (1 - pos_mask) # indicator of negatives 130 | return pos_mask, neg_mask 131 | 132 | def InfoNce_loss(self, pos, neg, dot): 133 | """ 134 | :param pos: V*T-V*T 135 | :param neg: V*T-V*T 136 | :param dot: V*T-V*T 137 | :return: 138 | """ 139 | # logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 140 | logits = dot # - logits_max.detach() 141 | 142 | 
neg_logits = torch.exp(logits) * neg 143 | neg_logits = neg_logits.sum(1, keepdim=True) 144 | 145 | exp_logits = torch.exp(logits) 146 | log_prob = logits - torch.log(exp_logits + neg_logits) 147 | 148 | pos_sums = pos.sum(1) 149 | ones = torch.ones(size=pos_sums.size()) 150 | norm = torch.where(pos_sums > 0, pos_sums, ones.to(pos.device)) 151 | 152 | mean_log_prob_pos = (pos * log_prob).sum(1) / norm # normalize by positives 153 | 154 | loss = - mean_log_prob_pos 155 | 156 | loss = loss.mean() 157 | # print('loss.mean() ', has_inf_or_nan(loss)) 158 | # print('loss {}'.format(loss)) 159 | if has_inf_or_nan(loss)[0] or has_inf_or_nan(loss)[1]: 160 | print('\n inf found in loss with positives {} and Negatives {}'.format(pos.sum(1), neg.sum(1))) 161 | return loss 162 | 163 | def get_meta(self): 164 | meta = {} 165 | meta['queue_fillings']= to_numpy(self.queue_fillings) 166 | meta['scales']= int(self.scales) 167 | return meta 168 | -------------------------------------------------------------------------------- /losses/LossWrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | # noinspection PyUnresolvedReferences 4 | from losses import * 5 | from utils import DATASETS_INFO 6 | from typing import Union 7 | 8 | 9 | class LossWrapper(nn.Module): 10 | def __init__(self, config: dict): 11 | super().__init__() 12 | self.config = config 13 | self.loss_weightings = config['losses'] 14 | self.device = config['device'] 15 | self.dataset = config['dataset'] 16 | self.experiment = config['experiment'] 17 | self.total_loss = None 18 | self.loss_classes, self.loss_vals = {}, {} 19 | self.info_string = '' 20 | self.ignore_class = (len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1) \ 21 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else -1 22 | for loss_class in self.loss_weightings: 23 | if loss_class == 'CrossEntropyLoss': 24 | class_weights = None 25 | if 
self.dataset == 'CITYSCAPES': 26 | class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 27 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 28 | 1.0865, 1.1529, 1.0507]).cuda() 29 | print(f'using class_weights {class_weights}') 30 | loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_class, weight=class_weights) 31 | # loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_class) 32 | else: 33 | loss_fct = globals()[loss_class](config) 34 | self.loss_classes.update({loss_class: loss_fct}) 35 | self.loss_vals.update({loss_class: 0}) 36 | self.info_string += loss_class + ', ' 37 | self.info_string = self.info_string[:-2] 38 | self.dc_off = True if 'dc_off_at_epoch' in self.config else False 39 | 40 | def forward(self, 41 | prediction: torch.Tensor, 42 | labels: torch.Tensor, 43 | loss_list: list = None, 44 | deep_features: Union[torch.Tensor,list] = None, 45 | interm_prediction: torch.Tensor = None, 46 | epoch: int = None, 47 | skip_mem_update: bool =False) -> torch.Tensor: 48 | self.total_loss = torch.tensor(0.0, dtype=torch.float, device=self.device) 49 | # Compile list of losses to be evaluated. 
If no specific 'loss_list' is passed 50 | loss_list = list(self.loss_weightings.keys()) if loss_list is None else loss_list 51 | for loss_class in self.loss_weightings: # Go through all the losses 52 | if loss_class in loss_list: # Check if this loss should be calculated 53 | if 'DenseContrastive' in loss_class: 54 | assert deep_features is not None, f'for loss_class {loss_class}, deep_features must be tensor (B,H,W,C) ' \ 55 | f'instead got {deep_features}' 56 | if loss_class == 'LovaszSoftmax': 57 | if self.dc_off and epoch is not None and epoch < self.config['dc_off_at_epoch']: 58 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device) 59 | else: 60 | loss = self.loss_classes[loss_class](prediction, labels) 61 | elif loss_class == 'DenseContrastiveLoss': 62 | if self.dc_off and epoch is not None and epoch >= self.config['dc_off_at_epoch']: 63 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device) 64 | else: 65 | loss = self.loss_classes[loss_class](labels, deep_features) 66 | elif loss_class == 'TwoScaleLoss': 67 | loss = self.loss_classes[loss_class](interm_prediction, prediction, labels.long()) 68 | elif loss_class == 'DenseContrastiveLossV2': 69 | loss = self.loss_classes[loss_class](labels, deep_features) 70 | elif loss_class == 'DenseContrastiveLossV2_ms': 71 | loss = self.loss_classes[loss_class](labels, deep_features) 72 | elif loss_class == 'DenseContrastiveLossV3': 73 | loss = self.loss_classes[loss_class](labels, deep_features, epoch, skip_mem_update) 74 | elif loss_class == 'DenseContrastiveLossV3_ms': 75 | loss = self.loss_classes[loss_class](labels, deep_features, epoch, skip_mem_update) 76 | # self.meta['queue'] = self.loss_classes[loss_class].queue_ptr.clone().numpy() 77 | elif loss_class == 'DenseContrastiveCenters': 78 | loss = self.loss_classes[loss_class](labels, deep_features, epoch, skip_mem_update) 79 | elif loss_class == 'OhemCrossEntropy': 80 | loss = self.loss_classes[loss_class](prediction, labels) 81 | elif 
loss_class == 'CrossEntropyLoss': 82 | loss = self.loss_classes[loss_class](prediction, labels) 83 | else: 84 | print("Error: Loss class '{}' not recognised!".format(loss_class)) 85 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device) 86 | else: 87 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device) 88 | 89 | # Calculate weighted loss 90 | loss *= self.loss_weightings[loss_class] 91 | self.loss_vals[loss_class] = loss.detach() 92 | 93 | # logging each scale seperately if ms/cs loss 94 | if loss_class == 'DenseContrastiveLossV2_ms': 95 | 96 | if hasattr(self.loss_classes[loss_class], 'ms_losses'): 97 | for scale, loss_val_ms in enumerate(self.loss_classes[loss_class].ms_losses): 98 | self.loss_vals.update({f'{loss_class}_ms{scale}':loss_val_ms}) 99 | if self.loss_classes[loss_class].cross_scale_contrast and hasattr(self.loss_classes[loss_class], 'cs_losses'): 100 | for cscale, loss_val_cs in enumerate(self.loss_classes[loss_class].cs_losses): 101 | self.loss_vals.update({f'{loss_class}_cs{cscale}':loss_val_cs}) 102 | self.total_loss += loss 103 | return self.total_loss 104 | -------------------------------------------------------------------------------- /losses/LovaszSoftmax.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import softmax 4 | from utils import DATASETS_INFO 5 | from itertools import filterfalse 6 | 7 | 8 | class LovaszSoftmax(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.eps = torch.as_tensor(1e-10) 12 | self.experiment = config['experiment'] 13 | self.dataset = config['dataset'] 14 | 15 | ignore_index_in_loss = len( 16 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1 \ 17 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else None 18 | # self.num_classes = len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) 19 | self.per_image = False 
if 'per_image' not in config else config['per_image'] 20 | self.classes_to_ignore = ignore_index_in_loss if 'classes_to_ignore' not in config else config['classes_to_ignore'] 21 | self.classes_to_consider = 'present' if 'classes_to_consider' not in config else config['classes_to_consider'] 22 | # classes_to_consider: 'all' for all, 'present' for classes present in labels, or a list of classes to average 23 | 24 | def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 25 | """Multi-class Lovasz-Softmax loss. Adapted from github.com/bermanmaxim/LovaszSoftmax 26 | 27 | :param prediction: NCHW tensor, raw logits from the network 28 | :param target: NHW tensor, ground truth labels 29 | :return: Lovász-Softmax loss 30 | """ 31 | p = softmax(prediction, dim=1) 32 | if self.per_image: 33 | loss = mean(self.lovasz_softmax_flat(*self.flatten_probabilities(p.unsqueeze(0), t.unsqueeze(0))) 34 | for p, t in zip(p, target)) 35 | else: 36 | loss = self.lovasz_softmax_flat(*self.flatten_probabilities(p, target)) 37 | return loss 38 | 39 | def lovasz_softmax_flat(self, prob: torch.Tensor, lbl: torch.Tensor) -> torch.Tensor: 40 | """Multi-class Lovasz-Softmax loss. Adapted from github.com/bermanmaxim/LovaszSoftmax 41 | 42 | :param prob: class probabilities at each prediction (between 0 and 1) 43 | :param lbl: ground truth labels (between 0 and C - 1) 44 | :return: Lovász-Softmax loss 45 | """ 46 | if prob.numel() == 0: 47 | # only void pixels, the gradients should be 0 48 | return prob * 0. 
49 | c = prob.size(1) 50 | losses = [] 51 | class_to_sum = list(range(c)) if self.classes_to_consider in ['all', 'present'] else self.classes_to_consider 52 | 53 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] and c in class_to_sum: 54 | class_to_sum.remove(c) # remove ignore class which is denoted in lbl by values = c 55 | 56 | for c in class_to_sum: 57 | fg = (lbl == c).float() # foreground for class c 58 | # compare by value: 'is' tests identity and can fail for strings loaded from a config 59 | if self.classes_to_consider == 'present' and fg.sum() == 0: 59 | continue 60 | class_pred = prob[:, c] 61 | errors = (fg - class_pred).abs() 62 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 63 | perm = perm.detach() 64 | fg_sorted = fg[perm] 65 | losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted))) 66 | return mean(losses) 67 | 68 | def flatten_probabilities(self, prob: torch.Tensor, lbl: torch.Tensor): 69 | """ 70 | Flattens predictions in the batch 71 | """ 72 | if prob.dim() == 3: 73 | # assumes output of a sigmoid layer 74 | n, h, w = prob.size() 75 | prob = prob.view(n, 1, h, w) 76 | _, c, _, _ = prob.size() 77 | prob = prob.permute(0, 2, 3, 1).contiguous().view(-1, c) # B * H * W, C = P, C 78 | lbl = lbl.view(-1) 79 | if self.classes_to_ignore is None: 80 | return prob, lbl 81 | else: 82 | valid = (lbl != self.classes_to_ignore) 83 | vprobas = prob[valid.nonzero().squeeze()] 84 | vlabels = lbl[valid] 85 | return vprobas, vlabels 86 | 87 | 88 | def lovasz_grad(gt_sorted): 89 | """ 90 | Computes gradient of the Lovasz extension w.r.t sorted errors 91 | See Alg. 1 in paper 92 | """ 93 | p = len(gt_sorted) 94 | gts = gt_sorted.sum() 95 | intersection = gts - gt_sorted.float().cumsum(0) 96 | union = gts + (1 - gt_sorted).float().cumsum(0) 97 | jaccard = 1. 
- intersection / union 98 | if p > 1: # cover 1-pixel case 99 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 100 | return jaccard 101 | 102 | 103 | def isnan(x): 104 | return x != x 105 | 106 | 107 | def mean(ip: torch.Tensor, ignore_nan: bool = False, empty=0): 108 | """ 109 | nanmean compatible with generators. 110 | """ 111 | ip = iter(ip) 112 | if ignore_nan: 113 | ip = filterfalse(isnan, ip) 114 | try: 115 | n = 1 116 | acc = next(ip) 117 | except StopIteration: 118 | if empty == 'raise': 119 | raise ValueError('Empty mean') 120 | return empty 121 | for n, v in enumerate(ip, 2): 122 | acc += v 123 | if n == 1: 124 | return acc 125 | return acc / n 126 | -------------------------------------------------------------------------------- /losses/TwoScaleLoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch.nn as nn 3 | from utils import DATASETS_INFO 4 | from losses import LovaszSoftmax 5 | from torch.nn import CrossEntropyLoss 6 | import torch 7 | 8 | 9 | class TwoScaleLoss(nn.Module): 10 | def __init__(self, config): 11 | """ 12 | Loads two losses one from an intermediate output and one from the final output 13 | for now it assumes the two losses are the same CE-CE or Lovasz-Lovasz etc. 
14 | the weights of the two losses may vary (by default 0.4 for interm and 1.0 final) 15 | :param config: 16 | """ 17 | super(TwoScaleLoss, self).__init__() 18 | interm_loss_class = globals()[config['interm']['name']] 19 | final_loss_class = globals()[config['final']['name']] 20 | self.w_interm = config['interm']['weight'] if 'weight' in config['interm'] else 0.4 21 | self.w_final = config['final']['weight'] if 'weight' in config['final'] else 1.0 22 | self.ignore_label = -100 # if experiment is not given assume nothing is ignored 23 | self.dataset = config['dataset'] 24 | self.experiment = config['experiment'] 25 | if 'experiment' in config: 26 | self.ignore_label = len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1 \ 27 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1].keys() \ 28 | else len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) 29 | 30 | # pass experiment id to constructors of the two losses 31 | config['interm'].update({"experiment": config['experiment'], "dataset": self.dataset}) 32 | config['final'].update({"experiment": config['experiment'], "dataset": self.dataset}) 33 | 34 | if config['interm']['name'] == 'CrossEntropyLoss' and config['final']['name'] == 'CrossEntropyLoss': 35 | class_weights = None 36 | if self.dataset == 'CITYSCAPES': 37 | class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 38 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 39 | 1.0865, 1.1529, 1.0507]).cuda() 40 | print(f'using class weights {class_weights}') 41 | self.loss_interm = interm_loss_class(*config['interm']['args'], 42 | ignore_index=self.ignore_label, weight=class_weights) 43 | self.loss_final = final_loss_class(*config['final']['args'], 44 | ignore_index=self.ignore_label, weight=class_weights) 45 | 46 | # all other losses expect a config 47 | elif config['interm']['name'] == config['final']['name']: 48 | self.loss_interm = 
interm_loss_class(config['interm']) 49 | self.loss_final = final_loss_class(config['final']) 50 | else: 51 | raise NotImplementedError('different losses for interm {}' 52 | ' and final {}'.format(config['interm'], config['final'])) 53 | 54 | print("intermediate loss {} with weight {}".format(interm_loss_class, self.w_interm)) 55 | print("final loss {} with weight {}".format(final_loss_class, self.w_final)) 56 | 57 | def forward(self, logits_interm, logits_final, target): 58 | # upsample intermediate if not already upsampled 59 | ph, pw = logits_interm.size(2), logits_interm.size(3) 60 | h, w = target.size(1), target.size(2) 61 | # todo add align_corners from outside -- 62 | # this was ignored until now as upsampling was happening in model.forward() 63 | # if ph != h or pw != w: 64 | # logits_interm = F.upsample(input=logits_interm, size=(h, w), mode='bilinear') 65 | loss_final = self.loss_final(logits_final, target) 66 | loss_interm = self.loss_interm(logits_interm, target) 67 | loss = loss_final * self.w_final + loss_interm * self.w_interm 68 | return loss 69 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .LovaszSoftmax import LovaszSoftmax 2 | from .DenseContrastiveLossV2 import DenseContrastiveLossV2 3 | from .DenseContrastiveLossV2_ms import DenseContrastiveLossV2_ms 4 | from .TwoScaleLoss import TwoScaleLoss 5 | from .LossWrapper import LossWrapper 6 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | # noinspection PyUnresolvedReferences 3 | from managers import * 4 | from utils import parse_config 5 | 6 | 7 | def str2bool(s: str): 8 | assert isinstance(s, str), f'input argument must be str, instead got {s}' 9 | if s in ['True', 'true']: 10 | return True 11 | elif s in ['False', 'false']: 
12 | return False 13 | else: 14 | raise ValueError(f'string {s} cannot be parsed as a bool') 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument('-c', '--config', type=str, default='configs/FCN_train_config.json', 21 | help='Set path to configuration files, e.g. ' 22 | 'python main.py --config configs/FCN_train_config.json.') 23 | 24 | parser.add_argument('-u', '--user', type=str, default='c', 25 | help='Select user to set correct data / logging paths for your system, e.g. ' 26 | 'python main.py --user theo') 27 | 28 | parser.add_argument('-d', '--device', nargs="+", type=int, default=-1, 29 | help='Select GPU device(s) to run the experiment on, e.g. --device 3') 30 | 31 | parser.add_argument('-s', '--dataset', type=str, default=-1, required=False, 32 | help='Select dataset to run the experiment on') 33 | 34 | parser.add_argument('-p', '--parallel', action='store_true', 35 | help='whether to use distributed training') 36 | 37 | parser.add_argument('-debug', '--debugging', action='store_true', 38 | help='sets manager into debugging mode, e.g. cts is run with val/val split') 39 | 40 | parser.add_argument('-cdnb', '--cudnn_benchmark', type=str, default=None, required=False, 41 | help='if added in args (True/False) then sets cudnn benchmark to the given value ' 42 | 'else uses config ' 43 | 'else sets it to True by default') 44 | 45 | parser.add_argument('-cdne', '--cudnn_enabled', type=str, default=None, required=False, 46 | help='if added in args (True/False) then sets cudnn enabled to the given value ' 47 | 'else uses config ' 48 | 'else sets it to True by default') 49 | 50 | parser.add_argument('-vf', '--valid_freq', type=int, default=None, required=False, 51 | help='sets how often to run validation') 52 | 53 | parser.add_argument('-w', '--workers', type=int, default=None, required=False, 54 | help='workers for dataloader per gpu process') 55 | 56 | parser.add_argument('-ec', '--empty_cache', action='store_true', 57 | help='whether to empty cache (per gpu 
process) after each forward step to avoid OOM --' 58 | ' this is useful in DCV2_ms or DCV3/ms') 59 | 60 | parser.add_argument('-m', '--mode', type=str, default=None, required=False, 61 | help='mode setting e.x training, inference (see BaseManager for others)') 62 | 63 | parser.add_argument('-cpt', '--checkpoint', type=str, default=None, required=False, 64 | help='path to checkpoint folder') 65 | 66 | parser.add_argument('-bs', '--batch_size', type=int, default=None, required=False, 67 | help='batch size -- the number given is then divided by n_gpus if ddp') 68 | 69 | parser.add_argument('-ep', '--epochs', type=int, default=None, required=False, 70 | help='training epochs -- overrides config') 71 | 72 | parser.add_argument('-so', '--save_outputs', action='store_true', 73 | help='whether to save outputs for submission cts') 74 | 75 | parser.add_argument('-rfv', '--run_final_val', action='store_true', 76 | help='whether to run validation with special settings' 77 | ' at the end of training (ex using tta or sliding window inference)') 78 | 79 | parser.add_argument('-tta', '--tta', action='store_true', 80 | help='whether to tta_val at the end of training') 81 | 82 | parser.add_argument('-tsnes', '--tsne_scale', type=int, default=None, required=False, 83 | help=' stride of feats on which to apply tsne must be [4,8,16,32]') 84 | 85 | # loss args for convenience 86 | parser.add_argument('--loss', '-l', choices=[None,'ce', 'ms', 'ms_cs'], default=None, required=False, 87 | help=f'choose loss overriding config (refer to config for other options except {"[ce, ms, ms_cs]"}') 88 | 89 | args = parser.parse_args() 90 | config = parse_config(args.config, args.user, args.device, args.dataset, args.parallel) 91 | manager_class = globals()[config['manager'] + 'Manager'] 92 | print(f'requested device ids: {config["gpu_device"]}') 93 | print('parsing cmdline args') 94 | # override config 95 | config['parallel'] = args.parallel 96 | config['tsne_scale'] = args.tsne_scale 97 | if 
args.loss: 98 | print(f'overriding loss type in config requested [{args.loss}]') 99 | if 'ms' in args.loss: 100 | config['loss'].update({"losses": {"CrossEntropyLoss": 1, "DenseContrastiveLossV2_ms": 0.1}}) 101 | config['loss'].update({"cross_scale_contrast": False}) 102 | if config['graph']['model'] == 'UPerNet': 103 | config['graph'].update({"ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": True, "position":"backbone"}}) 104 | else: 105 | config['graph'].update({"ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": True}}) 106 | 107 | if 'cs' in args.loss: 108 | config['loss'].update({"cross_scale_contrast": True}) 109 | 110 | if args.loss == 'ce': 111 | config['loss'].update({"losses": {"CrossEntropyLoss": 1}}) 112 | if 'ms_projector' in config['graph']: 113 | del config['graph']['ms_projector'] 114 | 115 | if args.save_outputs: 116 | config['save_outputs'] = True 117 | if args.run_final_val: 118 | config['run_final_val'] = True 119 | print('going to run tta val at the end of training') 120 | if args.empty_cache: 121 | config['empty_cache'] = True 122 | print('emptying cache') 123 | if args.batch_size is not None: 124 | config['data']['batch_size'] = args.batch_size 125 | print(f'bsize {args.batch_size}') 126 | if args.epochs is not None: 127 | config['train']['epochs'] = args.epochs 128 | print(f'epochs : {args.epochs}') 129 | if args.tta: 130 | config['tta'] = True 131 | print(f'tta set to {config["tta"]}') 132 | if args.debugging: 133 | config['debugging'] = True 134 | if args.valid_freq is not None: 135 | config['valid_freq'] = args.valid_freq 136 | if args.workers is not None: 137 | config['data']['num_workers'] = args.workers 138 | print(f'workers {args.workers}') 139 | if args.mode is not None: 140 | config['mode'] = args.mode 141 | print(f'mode {args.mode}') 142 | if args.checkpoint is not None: 143 | config['load_checkpoint'] = args.checkpoint 144 | print(f'load_checkpoint set to {args.checkpoint}') 145 | 146 | if 
args.cudnn_benchmark is not None: 147 | config['cudnn_benchmark'] = str2bool(args.cudnn_benchmark) 148 | if args.cudnn_enabled is not None: 149 | config['cudnn_enabled'] = str2bool(args.cudnn_enabled) 150 | 151 | manager = manager_class(config) 152 | 153 | if config['mode'] == 'training' and not manager.parallel: 154 | manager.train() 155 | elif config['mode'] == 'inference': 156 | manager.infer() 157 | elif config['mode'] == 'demo_tsne': 158 | manager.demo_tsne() 159 | elif config['mode'] == 'submission_inference': 160 | manager.submission_infer() 161 | -------------------------------------------------------------------------------- /managers/HRNet_Manager.py: -------------------------------------------------------------------------------- 1 | from managers.BaseManager import BaseManager 2 | from utils import to_comb_image, t_get_confusion_matrix, t_normalise_confusion_matrix, t_get_pixel_accuracy, \ 3 | get_matrix_fig, to_numpy, t_get_mean_iou, DATASETS_INFO, printlog 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import numpy as np 8 | import datetime 9 | from models import HRNet 10 | from losses import LossWrapper 11 | from torch.utils.tensorboard.writer import SummaryWriter 12 | from tqdm import tqdm 13 | 14 | 15 | 16 | class HRNetManager(BaseManager): 17 | 18 | def forward_step(self, img, lbl, **kwrargs): 19 | ret = dict() 20 | 21 | skip_mem_update = False 22 | if 'skip_mem_update' in kwrargs: 23 | skip_mem_update = kwrargs['skip_mem_update'] 24 | 25 | if isinstance(self.loss, LossWrapper): 26 | if self.return_features: 27 | output, proj_features = self.model(img.float()) 28 | loss = self.loss(output, lbl.long(), deep_features=proj_features, epoch=self.epoch, skip_mem_update=skip_mem_update) 29 | else: 30 | output = self.model(img.float()) 31 | proj_features = None 32 | loss = self.loss(output, lbl.long(), epoch=self.epoch) 33 | 34 | # get individual loss terms values for logging 35 | if 'individual_losses' in kwrargs: 
36 | individual_losses = kwrargs['individual_losses'] 37 | for key in self.loss.loss_vals: 38 | individual_losses[key] += self.loss.loss_vals[key] 39 | ret['individual_losses'] = individual_losses 40 | 41 | else: 42 | # not using the LossWrapper module 43 | output = self.model(img.float()) 44 | proj_features = None 45 | loss = self.loss(output, lbl.long()) 46 | 47 | ret['output'] = output 48 | ret['interm_output'] = None 49 | ret['feats'] = proj_features 50 | ret['loss'] = loss 51 | 52 | if self.empty_cache: 53 | torch.cuda.empty_cache() 54 | return ret 55 | 56 | def post_process_output(self, img, output, lbl, metadata, skip_label=False): 57 | if metadata and self.dataset in ['PASCALC', 'ADE20K']: 58 | if "pw_ph_stride" in metadata: 59 | # undo padding due to fit_stride resizing 60 | pad_w, pad_h, stride = metadata["pw_ph_stride"] 61 | if pad_h > 0 or pad_w > 0: 62 | output = output[:, :, 0:output.size(2) - pad_h, 0:output.size(3) - pad_w] 63 | lbl = lbl[:, 0:output.size(2), 0:output.size(3)] # output is already cropped: do not subtract the padding twice 64 | img = img[:, :, 0:output.size(2), 0:output.size(3)] 65 | 66 | if "sh_sw_in_out" in metadata: 67 | if hasattr(self.model, 'module'): 68 | align_corners = self.model.module.align_corners 69 | else: 70 | align_corners = self.model.align_corners 71 | # undo resizing 72 | starting_size = metadata["sh_sw_in_out"][-2] 73 | # starting size is (w, h) because PIL reports image size as (width, height) 74 | output = F.interpolate(input=output, size=starting_size[-2:][::-1], 75 | mode='bilinear', align_corners=align_corners) 76 | img = F.interpolate(input=img, size=starting_size[-2:][::-1], 77 | mode='bilinear', align_corners=align_corners) 78 | lbl = metadata["original_labels"].squeeze(0).long().cuda() 79 | 80 | return img, output, lbl 81 | 82 | def train_one_epoch(self): 83 | """Train the model for one epoch""" 84 | if self.rank == 0 and self.epoch == 0 and self.parallel: 85 | printlog('worker rank {} : CREATING train_writer'.format(self.rank)) 86 | self.train_writer = 
SummaryWriter(log_dir = self.log_dir / 'train') 87 | 88 | self.model.train() 89 | a = datetime.datetime.now() 90 | running_confusion_matrix = 0 91 | for batch_num, batch in enumerate(self.data_loaders[self.train_schedule[self.epoch]]): 92 | if len(batch) == 2: 93 | img, lbl = batch 94 | else: 95 | img, lbl, metadata = batch 96 | # if self.debugging: 97 | # continue 98 | b = (datetime.datetime.now() - a).total_seconds() * 1000 99 | a = datetime.datetime.now() 100 | img, lbl = img.to(self.device, non_blocking=True), lbl.to(self.device, non_blocking=True) 101 | self.optimiser.zero_grad() 102 | # forward 103 | ret = self.forward_step(img, lbl) 104 | loss = ret['loss'] 105 | output = ret['output'] 106 | # backward 107 | loss.backward() 108 | self.optimiser.step() 109 | # lr scheduler 110 | if self.scheduler is not None and self.config['train']['lr_batchwise']: 111 | self.scheduler.step() 112 | 113 | if batch_num == 2 and self.debugging: 114 | break 115 | 116 | # logging 117 | confusion_matrix = t_get_confusion_matrix(output, lbl, self.dataset) 118 | running_confusion_matrix += confusion_matrix 119 | pa, pac = t_get_pixel_accuracy(confusion_matrix) 120 | mious = t_get_mean_iou(confusion_matrix, self.config['data']['experiment'], 121 | self.dataset, categories=True, calculate_mean=False, rare=True) 122 | self.train_logging(batch_num, output, img, lbl, mious, loss, pa, pac, b) 123 | 124 | if 'DenseContrastiveLoss' in self.loss.loss_classes: 125 | col_confusion_matrix = t_normalise_confusion_matrix(running_confusion_matrix, mode='col') 126 | self.train_writer.add_figure('train_confusion_matrix/col_normalised', 127 | get_matrix_fig(to_numpy(col_confusion_matrix), 128 | self.config['data']['experiment'], 129 | self.dataset), self.global_step - 1) 130 | self.loss.loss_classes['DenseContrastiveLoss'].update_confusion_matrix(col_confusion_matrix) 131 | 132 | meta = {} 133 | if 'DenseContrastiveLossV3' in self.loss.loss_classes: 134 | meta = 
self.loss.loss_classes['DenseContrastiveLossV3'].get_meta() 135 | elif 'DenseContrastiveCenters' in self.loss.loss_classes: 136 | meta = self.loss.loss_classes['DenseContrastiveCenters'].get_meta() 137 | 138 | if 'queue_fillings' in meta: 139 | # self.num_real_classes, dtype=torch.long 140 | self.config['queue_fillings'] = meta['queue_fillings'] 141 | self.write_info_json() 142 | 143 | if self.scheduler is not None and not self.config['train']['lr_batchwise']: 144 | self.scheduler.step() 145 | self.train_writer.add_scalar('parameters/learning_rate', self.scheduler.get_lr()[0], self.global_step) \ 146 | if self.rank == 0 else None 147 | 148 | def validate(self): 149 | """Validate the model on the validation data""" 150 | if self.rank == 0: 151 | # only process with rank 0 runs validation step 152 | if self.epoch == 0 and self.parallel: 153 | printlog(f'\n creating valid_writer ... for process rank {self.rank}') 154 | self.valid_writer = SummaryWriter(log_dir= self.log_dir / 'valid') 155 | else: 156 | return 157 | 158 | if not self.parallel: 159 | torch.backends.cudnn.benchmark = False 160 | 161 | self.model.eval() 162 | valid_loss = 0 163 | confusion_matrix = None 164 | individual_losses = dict() 165 | if isinstance(self.loss, LossWrapper): 166 | for key in self.loss.loss_vals: 167 | individual_losses[key] = 0 168 | if 'DenseContrastiveLossV3' in self.loss.loss_classes: # make loss run the non ddp version for validation 169 | self.loss.loss_classes['DenseContrastiveLossV3'].parallel = False 170 | 171 | with torch.no_grad(): 172 | for rec_num, batch in enumerate(tqdm(self.data_loaders['valid_loader'])): 173 | if len(batch) == 2: 174 | img, lbl = batch 175 | metadata = None 176 | else: 177 | img, lbl, metadata = batch 178 | img, lbl = img.to(self.device, non_blocking=True), lbl.to(self.device, non_blocking=True) 179 | 180 | # forward 181 | ret = self.forward_step(img, lbl, individual_losses=individual_losses, skip_mem_update=True) 182 | loss = ret['loss'] 183 | output 
= ret['output'] 184 | valid_loss += loss 185 | img, output, lbl = self.post_process_output(img, output, lbl, metadata) 186 | 187 | # logging 188 | confusion_matrix = t_get_confusion_matrix(output, lbl, self.dataset, confusion_matrix) 189 | if rec_num in np.round(np.linspace(0, len(self.data_loaders['valid_loader']) - 1, self.max_valid_imgs)): 190 | lbl_pred = torch.argmax(nn.Softmax2d()(output), dim=1) 191 | self.valid_writer.add_image( 192 | 'valid_images/record_{:02d}'.format(rec_num), 193 | to_comb_image(self.un_norm(img)[0], lbl[0], lbl_pred[0], self.config['data']['experiment'], self.dataset), 194 | self.global_step, dataformats='HWC') 195 | individual_losses= ret['individual_losses'] if 'individual_losses' in ret else individual_losses 196 | if self.debugging and rec_num == 2: 197 | break 198 | valid_loss /= len(self.data_loaders['valid_loader']) 199 | pa, pac = t_get_pixel_accuracy(confusion_matrix) 200 | mious = t_get_mean_iou(confusion_matrix, self.config['data']['experiment'], self.dataset, True, rare=True) 201 | # logging + checkpoint 202 | self.valid_logging(valid_loss, confusion_matrix, individual_losses, mious, pa, pac) 203 | 204 | if not self.parallel: 205 | torch.backends.cudnn.benchmark = True 206 | 207 | if isinstance(self.loss, LossWrapper): 208 | if 'DenseContrastiveLossV3' in self.loss.loss_classes: # reset 209 | self.loss.loss_classes['DenseContrastiveLossV3'].parallel = self.parallel 210 | -------------------------------------------------------------------------------- /managers/__init__.py: -------------------------------------------------------------------------------- 1 | # from .colorization_manager import ColorizationManager 2 | from .DeepLabv3_Manager import DeepLabv3Manager 3 | from .OCRNet_Manager import OCRNetManager 4 | from .HRNet_Manager import HRNetManager 5 | 6 | -------------------------------------------------------------------------------- /misc/figs/fig1-01-01.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/c511fbcde6ac53b72d663225bdf6dded022ca1ce/misc/figs/fig1-01-01.png -------------------------------------------------------------------------------- /models/Projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Union 4 | from models.Transformers import SelfAttention 5 | from utils import printlog 6 | 7 | class Projector(nn.Module): 8 | 9 | def __init__(self, config): 10 | """ module that maps encoder features to a d-dimensional space 11 | if can be a single conv-linear (optionally) preceded by an fcn with conv-relu layers 12 | """ 13 | super().__init__() 14 | self.d = config['d'] if 'd' in config else 128 # projection dim 15 | self.c_in = config['c_in'] # input features channels (usually == output channels of resnet backbone) 16 | assert isinstance(self.c_in, list) or isinstance(self.c_in, int) 17 | # config['mlp'] list of [k,c] for Conv-Relu layers, if empty only applies Conv(c_in, d, k=1) 18 | self.mlp = config['mlp'] if 'mlp' in config else [] 19 | self.use_bn = config['use_bn'] if 'use_bn' in config else False 20 | self.transformer = config['trans'] if 'trans' in config else False 21 | self.heads = config['heads'] if 'heads' in config else 1 22 | 23 | if isinstance(self.c_in, list): 24 | self.is_ms = True 25 | self._create_ms_mlp() 26 | else: 27 | self.is_ms = False # whether the projector is multiscale 28 | self._create_mlp(self.c_in) 29 | 30 | def _create_ms_mlp(self): 31 | printlog('** creating ms projector **') 32 | for feat_id, c_in in enumerate(self.c_in): 33 | printlog(f'* scale {feat_id} feats: {c_in}') 34 | self._create_mlp(c_in, feat_id) 35 | 36 | def _create_mlp(self, c_in:int, feat_id:Union[list,int]=''): 37 | # sanity checks 38 | assert(isinstance(self.mlp, list)), 
'config["mlp"] must be [[k_1, c_1, s_1], ., [k_n, c_n, s_n]] or [] ' \ 39 | 'k_i is kernel (k_i x k_i) c_i is channels and s_i is stride' 40 | first_layer_has_cout_equal_to_cin = False 41 | if len(self.mlp): 42 | for layer in self.mlp: 43 | assert(isinstance(layer, list)), f'elements of layer definition list must be lists instead got {layer}' 44 | assert(len(layer) == 3 and layer[2] in [1, 2]), 'must provide list of lists of 3 elements each' \ 45 | '[kernel, channels, stride] instead {}'.format(layer[2]) 46 | if layer[1]>0: 47 | assert(layer[0] < layer[1]), 'kernel size is first element of list, got {} {}'.format(layer[0], layer[1]) 48 | 49 | self.convs = [] 50 | c_prev = c_in 51 | if len(self.mlp): 52 | for layer_id, (k, c_out, s) in enumerate(self.mlp): 53 | if layer_id == 0 and c_out==-1: 54 | c_out = c_prev 55 | printlog('Projector creating conv layer, k_{}/c_{}/s_{}'.format(k, c_out, s)) 56 | # if use_bn --> do not use bias 57 | # p = (k + (k - 1) * (d - 1) - s + 1) // 2 58 | p = (k - s + 1) // 2 59 | self.convs.append(nn.Conv2d(c_prev, c_out, kernel_size=k, stride=s, 60 | padding=p, bias=not self.use_bn)) 61 | self.convs.append(nn.ReLU(inplace=True)) 62 | if self.use_bn: 63 | self.convs.append(nn.BatchNorm2d(c_out, momentum=0.0003)) 64 | c_prev = c_out 65 | if self.transformer: 66 | sa = SelfAttention(dim = c_prev, heads=self.heads) 67 | printlog('Projector creating transformer layer, heads_{}/c_{}'.format(self.heads, c_prev)) 68 | self.convs.append(sa) 69 | 70 | printlog('Projector creating linear layer, k_{}/c_{}/s_{}'.format(1, self.d, 1)) 71 | self.convs.append(nn.Conv2d(c_prev, self.d, kernel_size=1, stride=1)) 72 | setattr(self, f'project{feat_id}', nn.Sequential(*self.convs)) 73 | 74 | def forward(self, x:Union[list, torch.tensor]): 75 | # # x are features of shape NCHW 76 | # x = x / torch.norm(x, dim=1, keepdim=True) # Normalise along C: features vectors now lie on unit hypersphere 77 | if self.is_ms: 78 | outs = [] 79 | assert(isinstance(x, list) or 
isinstance(x, tuple)), f'if multiscale projector is used a list is expected as input instead got {type(x)}' 80 | for feat_id, x_i in enumerate(x): 81 | x_i = getattr(self, f'project{feat_id}')(x_i) 82 | outs.append(x_i) 83 | return outs 84 | else: 85 | if isinstance(x, list): 86 | if len(x)==1: 87 | x = x[0] 88 | else: 89 | raise ValueError(f'x is {type(x)}, of length {len(x)}') 90 | x = self.project(x) 91 | return x 92 | 93 | 94 | 95 | 96 | if __name__ == '__main__': 97 | # example 98 | feats1 = torch.rand(size=(2, 1024, 60, 120)).float() 99 | feats0 = torch.rand(size=(2, 2048, 60, 120)).float() 100 | 101 | # proj = Projector({'mlp': [[1,-1, 1], [1, 256, 1]], 'c_in': [512,512,1024,1024], 'd': 128, 'use_bn': True}) 102 | 103 | proj = Projector({'mlp': [[1,-1, 1], [1, 256, 1]], 'c_in': 2048, 'd': 128, "trans": True, "heads":1, 'use_bn': True}) 104 | 105 | # projected_feats = proj(([feats0]*2 )+([feats1]*2)) 106 | p = proj(feats0) 107 | # p_sa = SelfAttention(dim=p.shape[1])(p) 108 | print(p.shape) 109 | 110 | # print(projected_feats.shape) 111 | 112 | # for v, par in proj.named_parameters(): 113 | # if par.requires_grad: 114 | # print(v, par.data.shape, par.requires_grad) 115 | # d = proj.state_dict() 116 | # print(d) 117 | -------------------------------------------------------------------------------- /models/TTAWrapperSlide.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from typing import Union 4 | import datetime 5 | import cv2 6 | import numpy as np 7 | from utils import printlog, to_numpy, to_comb_image, un_normalise 8 | from models import TTAWrapper 9 | 10 | 11 | class TTAWrapperSlide(TTAWrapper): 12 | def __init__(self, 13 | model, 14 | scale_list, 15 | flip=True, 16 | strides:Union[tuple, None]=None, 17 | crop_size:Union[tuple, None]=None, 18 | debug=False): 19 | 20 | super().__init__(model, scale_list, flip) 21 | # self.num_classes = 19 22 | self.num_classes 
= 150 23 | self.crop_size = crop_size if crop_size else [512,512] 24 | self.strides = strides if strides else self.crop_size # defaults to no-overlapping sliding window 25 | # self.base_size = 2048 26 | self.base_size = 512 27 | self.debug = debug 28 | img_scale = (2048, 512) 29 | self.image_flips = [] 30 | self.image_scales = [] 31 | if self.flip: 32 | flips = [True, False] 33 | else: 34 | flips = [False] 35 | for s in self.scales: 36 | for f in flips: 37 | self.image_scales.append(((int(img_scale[0]*s), int(img_scale[1]*s)),s)) 38 | self.image_flips.append(f) 39 | printlog(f'Sliding window : strides : {self.strides} crop_size {self.crop_size} image_scales: {self.image_scales}') 40 | 41 | def inference(self, image, flip=False, scale=1.0, id_=1): 42 | # image BCHW 43 | assert image.device.type == 'cuda' 44 | size = image.size() 45 | pred = self.model(image) 46 | # done internally in model 47 | # pred = F.interpolate( 48 | # input=pred, size=size[-2:], 49 | # mode='bilinear', align_corners=self.model.align_corners 50 | # ) 51 | if flip: 52 | flip_img = to_numpy(image)[:, :, :, ::-1] 53 | flip_output = self.model(torch.from_numpy(flip_img.copy()).cuda()) 54 | # flip_output = F.interpolate( 55 | # input=flip_output, size=size[-2:], 56 | # mode='bilinear', align_corners=self.model.align_corners 57 | # ) 58 | 59 | flip_pred = to_numpy(flip_output).copy() 60 | flip_pred = torch.from_numpy(flip_pred[:, :, :, ::-1].copy()).cuda() 61 | pred += flip_pred 62 | pred = pred * 0.5 63 | if self.debug: 64 | to_comb_image(un_normalise(image[0]), torch.argmax(pred[0], 0), None, 1, 'ADE20K', save=f'pred_scale_{scale}_{id_}.png') 65 | return pred.exp() 66 | 67 | def multi_scale_aug(self, image, new_shape): 68 | new_h, new_w = new_shape 69 | image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) 70 | return image 71 | 72 | def slide_infer(self, new_img, preds, count, scale, flip): 73 | stride_h = int(self.strides[0]) 74 | stride_w = int(self.strides[1]) 75 | new_h, 
new_w = new_img.shape[:-1] 76 | rows = int(np.ceil((new_h - self.crop_size[0]) / stride_h)) + 1 77 | cols = int(np.ceil((new_w - self.crop_size[1]) / stride_w)) + 1 78 | # preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda() 79 | # count = torch.zeros([1, 1, new_h, new_w]).cuda() 80 | id_ = 1 81 | for r in range(rows): 82 | for c in range(cols): 83 | h0 = r * stride_h 84 | w0 = c * stride_w 85 | h1 = min(h0 + self.crop_size[0], new_h) 86 | w1 = min(w0 + self.crop_size[1], new_w) # x2 87 | h0 = max(int(h1 - self.crop_size[0]), 0) # y1 88 | w0 = max(int(w1 - self.crop_size[1]), 0) # x1 89 | crop_img = new_img[h0:h1, w0:w1, :] 90 | crop_img = crop_img.transpose((2, 0, 1)) 91 | crop_img = np.expand_dims(crop_img, axis=0) 92 | crop_img = torch.from_numpy(crop_img) 93 | pred = self.inference(crop_img.cuda(), flip=flip, scale=scale, id_=id_) 94 | id_ += 1 95 | # print(w0, preds.shape[3] - w1, int(h0), preds.shape[2] - h1) 96 | preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0] 97 | count[:, :, h0:h1, w0:w1] += 1 98 | # preds = preds / count 99 | # preds = preds[:, :, :height, :width] 100 | 101 | return preds, count 102 | 103 | def forward(self, x): 104 | a = datetime.datetime.now() 105 | if isinstance(x, tuple): 106 | x = x[0] 107 | assert isinstance(x, torch.Tensor), f'x input must be a tensor instead got {type(x)}' 108 | batch, _, ori_height, ori_width = x.size() 109 | assert batch == 1, "only supporting batchsize 1." 
110 | # x is BCHW 111 | image = to_numpy(x)[0].transpose((1, 2, 0)).copy() 112 | # x is HWC 113 | stride_h = int(self.strides[0] * 1.0) 114 | stride_w = int(self.strides[1] * 1.0) 115 | 116 | final_pred = torch.zeros([1, self.num_classes, ori_height, ori_width]).cuda() 117 | for flip, (shape, scale) in zip(self.image_flips, self.image_scales) : 118 | new_img = self.multi_scale_aug(image, new_shape=shape) 119 | height, width = new_img.shape[:-1] 120 | # if scale < 1.0: 121 | # new_img = new_img.transpose((2, 0, 1)) 122 | # new_img = np.expand_dims(new_img, axis=0) 123 | # new_img = torch.from_numpy(new_img) 124 | # preds = self.inference(new_img.cuda(), flip=True, scale=scale, id_=1) 125 | # preds = preds[:, :, 0:height, 0:width] 126 | # else: 127 | new_h, new_w = new_img.shape[:-1] 128 | preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda() 129 | count = torch.zeros([1, 1, new_h, new_w]).cuda() 130 | 131 | preds, count = self.slide_infer(new_img, preds, count, scale, flip) 132 | 133 | preds = preds / count 134 | preds = preds[:, :, :height, :width] 135 | 136 | preds = F.interpolate( 137 | preds, (ori_height, ori_width), 138 | mode='bilinear', align_corners=self.model.align_corners 139 | ) 140 | 141 | final_pred += preds 142 | if self.debug: 143 | to_comb_image(un_normalise(x[0]), torch.argmax(final_pred[0], 0), None, 1, 'ADE20K', save=f'final.png') 144 | 145 | b = (datetime.datetime.now() - a).total_seconds() * 1000 146 | print(f'\r time:{b}') 147 | return final_pred 148 | 149 | 150 | if __name__ == '__main__': 151 | import pickle 152 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop 153 | # from models.SegFormer import SegFormer 154 | from models.UPerNet import UPerNet 155 | import cv2 156 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix 157 | 158 | file = open('..\\ade20k_img.pkl', 'rb') 159 | img = pickle.load(file) 160 | file.close() 161 | 162 | # path_to_chkpt = 
'..\\logging\\ADE20K\\20220326_185031_e1__upn_ConvNextT_sbn_DCms_cs_epochs127_bs16\\chkpts\\chkpt_epoch_126.pt' 163 | map_location = 'cuda:0' 164 | # checkpoint = torch.load(str(path_to_chkpt), map_location) 165 | 166 | config = dict() 167 | config.update({'backbone': 'ConvNextT', 'out_stride': 32, 'pretrained': False, 'dataset':'ADE20K', 168 | 'pretrained_res':224, 'pretrained_dataset':'22k' , 'align_corners':False}) 169 | model = UPerNet(config, 1) 170 | 171 | # ret = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False) 172 | # print(ret) 173 | T = Compose([ 174 | ToTensor(), 175 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 176 | # RandomCrop(size=[512, 512]) 177 | ]) 178 | 179 | with torch.no_grad(): 180 | 181 | tta_model = TTAWrapperSlide(model, scale_list=[0.5, 1], crop_size=(512, 512), 182 | strides=(326, 326), debug=True) # [0.75, 1.25, 1.5, 1.75, 2, 1.0] 183 | tta_model.cuda() 184 | tta_model.eval() 185 | x = T(img) 186 | # x = x.cuda().float() 187 | y = tta_model.forward(x.unsqueeze(0).float()) 188 | c = 1 -------------------------------------------------------------------------------- /models/TTA_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel import DistributedDataParallel as ddp 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import datetime 6 | import cv2 7 | from utils import printlog 8 | 9 | 10 | class TTAWrapper(nn.Module): 11 | """ 12 | hard-coding common scaling and flipping protocol for simplicity 13 | """ 14 | def __init__(self, model, scale_list=None, flip=True): 15 | super().__init__() 16 | self.scales = list(scale_list) if scale_list else [] # copy so the caller's list is not mutated, e.g. [0.5, 1.5, 1.75, 2] 17 | self.flip = flip 18 | if 1.0 not in self.scales: 19 | self.scales.append(1.0) 20 | if isinstance(model, ddp): 21 | self.model = model.module # unwrap the DistributedDataParallel instance 22 | else: 23 | self.model = model 24 | 25 | self.align_corners = 
self.model.align_corners if hasattr(self.model, 'align_corners') else True 26 | 27 | printlog(f'*** TTA wrapper with flip : [{flip}] --- scales : {self.scales} -- align_corners:{self.align_corners}') 28 | 29 | def maybe_resize(self, x, scale, in_shape): 30 | """ 31 | 32 | :param x: B,C,H,W 33 | :param scale: if s in R+ resizes the image to s*in_shape, 34 | if s=1 then returns a copy of x, 35 | if s=-1 then resizes the image to in_shape 36 | :param in_shape: (H, W) of the original input 37 | :return: resized (or cloned) tensor 38 | """ 39 | scaled_shape = [int(scale * in_shape[0]), int(scale * in_shape[1])] 40 | if scale != 1.0 and scale > 0: 41 | x = F.interpolate(x, size=scaled_shape, mode='bilinear', align_corners=self.align_corners) 42 | elif scale == -1: 43 | x = F.interpolate(x, size=in_shape, mode='bilinear', align_corners=self.align_corners) 44 | else: 45 | x = x.clone() 46 | return x 47 | 48 | def maybe_flip(self, x, f): 49 | if f == 0: 50 | x_f = torch.flip(x, dims=[3]) # horizontal flip (returns a copy) 51 | else: 52 | x_f = x.clone() 53 | return x_f 54 | 55 | def forward(self, x, **kwargs): 56 | if isinstance(x, tuple): 57 | x = x[0] 58 | assert isinstance(x, torch.Tensor), f'x input must be a tensor instead got {type(x)}' 59 | 60 | a = datetime.datetime.now() 61 | assert len(x.shape)==4, 'input must be B,C,H,W' 62 | flag_first = True # flag for the first iteration of the nested loop 63 | in_shape=x.shape[2:4] 64 | out_shape = [1, self.model.num_classes] + list(in_shape) 65 | y_merged = torch.zeros(size=out_shape).cuda() 66 | for f in range(2): 67 | x_f = self.maybe_flip(x, f) # flip 68 | for s in self.scales: 69 | x_f_s = self.maybe_resize(x_f, s, in_shape) # resize 70 | y = self.model(x_f_s) # forward 71 | y = self.maybe_flip(y, f) # unflip 72 | y_merged += self.maybe_resize(y, -1, in_shape) # un-resize 73 | 74 | b = (datetime.datetime.now() - a).total_seconds() * 1000 75 | # print('time taken for tta {:.5f}'.format(b)) 76 | y_merged = y_merged/(2*len(self.scales)) 77 | # cv2.imshow('final', to_comb_image(un_normalise(x[0]), 
torch.argmax(y_merged[0], 0), None, 1, 'CITYSCAPES')) 78 | return y_merged 79 | 80 | 81 | 82 | 83 | 84 | if __name__ == '__main__': 85 | import pickle 86 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop 87 | from models.HRNet import HRNet 88 | from models.UPerNet import UPerNet 89 | import cv2 90 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix 91 | 92 | file = open('..\\img_cts.pkl', 'rb') 93 | img = pickle.load(file) 94 | file.close() 95 | 96 | path_to_chkpt = '..\\logging\\ADE20K\\20220326_185031_e1__upn_ConvNextT_sbn_DCms_cs_epochs127_bs16\\chkpts\\chkpt_epoch_126.pt' 97 | map_location = 'cuda:0' 98 | checkpoint = torch.load(str(path_to_chkpt), map_location) 99 | 100 | config = dict() 101 | config.update({'backbone': 'ConvNextT', 'out_stride': 32, 'pretrained': True, 'dataset':'ADE20K'}) 102 | model = UPerNet(config, 1) 103 | 104 | ret = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False) 105 | print(ret) 106 | T = Compose([ 107 | ToTensor(), 108 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 109 | RandomCrop(size=[512, 512]) 110 | ]) 111 | 112 | with torch.no_grad(): 113 | tta_model = TTAWrapper(model, scale_list=[0.5, 1.5]) 114 | tta_model.cuda() 115 | tta_model.eval() 116 | x = T(img) 117 | x = x.cuda().float() 118 | y = tta_model.forward(x.unsqueeze(0)) 119 | c = 1 120 | 121 | -------------------------------------------------------------------------------- /models/TTA_wrapper_CTS.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from typing import Union 4 | import datetime 5 | import cv2 6 | import numpy as np 7 | from utils import printlog, to_numpy, to_comb_image, un_normalise 8 | from models import TTAWrapper 9 | 10 | 11 | class TTAWrapperCTS(TTAWrapper): 12 | def __init__(self, 13 | model, 14 | scale_list, 15 | flip=True, 16 | 
strides:Union[tuple, None]=None, 17 | crop_size:Union[tuple, None]=None, 18 | debug=False): 19 | 20 | super().__init__(model, scale_list, flip) 21 | self.num_classes = 19 22 | self.crop_size = crop_size if crop_size else [512,1024] 23 | self.strides = strides if strides else self.crop_size # defaults to no-overlapping sliding window 24 | self.base_size = 2048 25 | self.debug = debug 26 | 27 | printlog(f'Sliding window : strides : {self.strides} crop_size {self.crop_size}') 28 | 29 | def inference(self, image, flip=False, scale=1.0, id_=1): 30 | # image BCHW 31 | assert image.device.type == 'cuda' 32 | size = image.size() 33 | pred = self.model(image) 34 | # done internally in model 35 | # pred = F.interpolate( 36 | # input=pred, size=size[-2:], 37 | # mode='bilinear', align_corners=self.model.align_corners 38 | # ) 39 | if flip: 40 | flip_img = to_numpy(image)[:, :, :, ::-1] 41 | flip_output = self.model(torch.from_numpy(flip_img.copy()).cuda()) 42 | # flip_output = F.interpolate( 43 | # input=flip_output, size=size[-2:], 44 | # mode='bilinear', align_corners=self.model.align_corners 45 | # ) 46 | 47 | flip_pred = to_numpy(flip_output).copy() 48 | flip_pred = torch.from_numpy(flip_pred[:, :, :, ::-1].copy()).cuda() 49 | pred += flip_pred 50 | pred = pred * 0.5 51 | if self.debug: 52 | to_comb_image(un_normalise(image[0]), torch.argmax(pred[0], 0), None, 1, 'ADE20K', save=f'pred_scale_{scale}_{id_}.png') 53 | return pred.exp() 54 | 55 | def multi_scale_aug(self, image, label=None, 56 | rand_scale=1, rand_crop=True): 57 | 58 | long_size = int(self.base_size * rand_scale + 0.5) 59 | h, w = image.shape[:2] 60 | if h > w: 61 | new_h = long_size 62 | new_w = int(w * long_size / h + 0.5) 63 | else: 64 | new_w = long_size 65 | new_h = int(h * long_size / w + 0.5) 66 | 67 | image = cv2.resize(image, (new_w, new_h), 68 | interpolation=cv2.INTER_LINEAR) 69 | if label is not None: 70 | label = cv2.resize(label, (new_w, new_h), 71 | interpolation=cv2.INTER_NEAREST) 72 | else: 
73 | return image 74 | 75 | if rand_crop: 76 | image, label = self.rand_crop(image, label) 77 | 78 | return image, label 79 | 80 | def forward(self, x): 81 | a = datetime.datetime.now() 82 | if isinstance(x, tuple): 83 | x = x[0] 84 | assert isinstance(x, torch.Tensor), f'x input must be a tensor instead got {type(x)}' 85 | batch, _, ori_height, ori_width = x.size() 86 | assert batch == 1, "only supporting batchsize 1." 87 | # x is BCHW 88 | image = to_numpy(x)[0].transpose((1, 2, 0)).copy() 89 | # image is HWC 90 | stride_h = int(self.strides[0] * 1.0) 91 | stride_w = int(self.strides[1] * 1.0) 92 | 93 | final_pred = torch.zeros([1, self.num_classes, 94 | ori_height, ori_width]).cuda() 95 | 96 | for scale in self.scales: 97 | new_img = self.multi_scale_aug(image=image, 98 | rand_scale=scale, 99 | rand_crop=False) 100 | # cv2.imshow(f'scale {scale}', new_img) 101 | height, width = new_img.shape[:-1] 102 | 103 | if scale < 1.0: 104 | new_img = new_img.transpose((2, 0, 1)) 105 | new_img = np.expand_dims(new_img, axis=0) 106 | new_img = torch.from_numpy(new_img) 107 | preds = self.inference(new_img.cuda(), flip=True, scale=scale, id_=1) 108 | preds = preds[:, :, 0:height, 0:width] 109 | else: 110 | new_h, new_w = new_img.shape[:-1] 111 | rows = int(np.ceil(1.0 * (new_h - self.crop_size[0]) / stride_h)) + 1 112 | cols = int(np.ceil(1.0 * (new_w - self.crop_size[1]) / stride_w)) + 1 113 | preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda() 114 | count = torch.zeros([1, 1, new_h, new_w]).cuda() 115 | id_ = 1 116 | for r in range(rows): 117 | for c in range(cols): 118 | h0 = r * stride_h 119 | w0 = c * stride_w 120 | h1 = min(h0 + self.crop_size[0], new_h) 121 | w1 = min(w0 + self.crop_size[1], new_w) 122 | h0 = max(int(h1 - self.crop_size[0]), 0) 123 | w0 = max(int(w1 - self.crop_size[1]), 0) 124 | crop_img = new_img[h0:h1, w0:w1, :] 125 | crop_img = crop_img.transpose((2, 0, 1)) 126 | crop_img = np.expand_dims(crop_img, axis=0) 127 | crop_img = 
torch.from_numpy(crop_img) 128 | pred = self.inference(crop_img.cuda(), flip=self.flip, scale=scale, id_= id_) 129 | id_ += 1 130 | 131 | preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0] 132 | count[:, :, h0:h1, w0:w1] += 1 133 | preds = preds / count 134 | preds = preds[:, :, :height, :width] 135 | 136 | preds = F.interpolate( 137 | preds, (ori_height, ori_width), 138 | mode='bilinear', align_corners=self.model.align_corners 139 | ) 140 | 141 | final_pred += preds 142 | if self.debug: 143 | to_comb_image(un_normalise(x[0]), torch.argmax(final_pred[0], 0), None, 1, 'CITYSCAPES', save=f'final.png') 144 | 145 | b = (datetime.datetime.now() - a).total_seconds() * 1000 146 | print(f'\r time:{b}') 147 | return final_pred 148 | 149 | 150 | if __name__ == '__main__': 151 | import pickle 152 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop 153 | # from models.SegFormer import SegFormer 154 | from models.UPerNet import UPerNet 155 | import cv2 156 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix 157 | 158 | file = open('..\\ade20k_img.pkl', 'rb') 159 | img = pickle.load(file) 160 | file.close() 161 | 162 | path_to_chkpt = '..\\logging\\ADE20K\\20220326_185031_e1__upn_ConvNextT_sbn_DCms_cs_epochs127_bs16\\chkpts\\chkpt_epoch_126.pt' 163 | map_location = 'cuda:0' 164 | checkpoint = torch.load(str(path_to_chkpt), map_location) 165 | 166 | config = dict() 167 | 168 | config.update({'backbone': 'ConvNextT', 'out_stride': 32, 'pretrained': False, 'dataset':'ADE20K', 169 | 'pretrained_res':224, 'pretrained_dataset':'22k' , 'align_corners':False}) 170 | 171 | model = UPerNet(config, 1) 172 | 173 | ret = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False) 174 | print(ret) 175 | T = Compose([ 176 | ToTensor(), 177 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 178 | # RandomCrop(size=[512, 512]) 179 | ]) 180 | 181 | with torch.no_grad(): 182 | 183 | 
tta_model = TTAWrapperCTS(model, scale_list=[0.5], crop_size=(512, 512), strides=(341, 341), debug=True) # [0.75, 1.25, 1.5, 1.75, 2, 1.0] 184 | tta_model.cuda() 185 | tta_model.eval() 186 | x = T(img) 187 | # x = x.cuda().float() 188 | y = tta_model.forward(x.unsqueeze(0).float()) 189 | c = 1 190 | -------------------------------------------------------------------------------- /models/TTA_wrapper_PC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parallel import DistributedDataParallel as ddp 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import datetime 6 | import cv2 7 | import numpy as np 8 | from utils import printlog, to_numpy, to_comb_image 9 | from models import TTAWrapper 10 | 11 | 12 | class TTAWrapperPC(TTAWrapper): 13 | def __init__(self, model, scale_list): 14 | super().__init__(model, scale_list) 15 | self.num_classes = 59 16 | self.crop_size = [512,512] 17 | self.base_size = 520 18 | 19 | def inference(self, image, flip=False, scale=1.0, id_=1): 20 | # image BCHW 21 | assert image.device.type == 'cuda' 22 | size = image.size() 23 | pred = self.model(image) 24 | # done internally in model 25 | # pred = F.interpolate( 26 | # input=pred, size=size[-2:], 27 | # mode='bilinear', align_corners=self.model.align_corners 28 | # ) 29 | if flip: 30 | flip_img = to_numpy(image)[:, :, :, ::-1] 31 | flip_output = self.model(torch.from_numpy(flip_img.copy()).cuda()) 32 | # flip_output = F.interpolate( 33 | # input=flip_output, size=size[-2:], 34 | # mode='bilinear', align_corners=self.model.align_corners 35 | # ) 36 | 37 | flip_pred = to_numpy(flip_output).copy() 38 | flip_pred = torch.from_numpy(flip_pred[:, :, :, ::-1].copy()).cuda() 39 | pred += flip_pred 40 | pred = pred * 0.5 41 | 42 | # to_comb_image(un_normalise(image[0]), torch.argmax(pred[0], 0), None, 1, 'PASCALC', save=f'pred_{self.ind}_scale_{scale}_{id_}.png') 43 | return pred.exp() 44 | 45 | def 
multi_scale_aug(self, image, label=None, 46 | rand_scale=1, rand_crop=True): 47 | 48 | long_size = int(self.base_size * rand_scale + 0.5) 49 | h, w = image.shape[:2] 50 | if h > w: 51 | new_h = long_size 52 | new_w = int(w * long_size / h + 0.5) 53 | else: 54 | new_w = long_size 55 | new_h = int(h * long_size / w + 0.5) 56 | 57 | image = cv2.resize(image, (new_w, new_h), 58 | interpolation=cv2.INTER_LINEAR) 59 | if label is not None: 60 | label = cv2.resize(label, (new_w, new_h), 61 | interpolation=cv2.INTER_NEAREST) 62 | else: 63 | return image 64 | 65 | if rand_crop: 66 | image, label = self.rand_crop(image, label) 67 | 68 | return image, label 69 | 70 | def pad_image(self, image, h, w, size, padvalue): 71 | pad_image = image.copy() 72 | pad_h = max(size[0] - h, 0) 73 | pad_w = max(size[1] - w, 0) 74 | if pad_h > 0 or pad_w > 0: 75 | pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0, 76 | pad_w, cv2.BORDER_CONSTANT, 77 | value=padvalue) 78 | 79 | return pad_image 80 | 81 | def forward(self, x, ind=0): 82 | self.ind = ind 83 | a = datetime.datetime.now() 84 | if isinstance(x, tuple): 85 | x = x[0] 86 | assert isinstance(x, torch.Tensor), f'x input must be a tensor, instead got {type(x)}' 87 | batch, _, ori_height, ori_width = x.size() 88 | assert batch == 1, "only supporting batchsize 1."
89 | # x is BCHW 90 | image = to_numpy(x)[0].transpose((1, 2, 0)).copy( ) 91 | # x is HWC 92 | stride_h = int(self.crop_size[0] * 2.0/3.0) 93 | stride_w = int(self.crop_size[1] * 2.0/3.0) 94 | final_pred = torch.zeros([1, self.num_classes, 95 | ori_height, ori_width]).cuda() 96 | 97 | mean = [0.485, 0.456, 0.406] 98 | std = [0.229, 0.224, 0.225] 99 | padvalue = -1.0 * np.array(mean)/np.array(std) 100 | 101 | for scale in self.scales: 102 | new_img = self.multi_scale_aug(image=image, 103 | rand_scale=scale, 104 | rand_crop=False) 105 | # cv2.imshow(f'scale {scale}', new_img) 106 | height, width = new_img.shape[:-1] 107 | 108 | if max(height, width) <= np.min(self.crop_size): 109 | new_img = self.pad_image(new_img, height, width, 110 | self.crop_size, padvalue) 111 | new_img = new_img.transpose((2, 0, 1)) 112 | new_img = np.expand_dims(new_img, axis=0) 113 | new_img = torch.from_numpy(new_img) 114 | preds = self.inference(new_img.cuda(), flip=True, scale=scale, id_=1) 115 | preds = preds[:, :, 0:height, 0:width] 116 | else: 117 | 118 | if height < self.crop_size[0] or width < self.crop_size[1]: 119 | new_img = self.pad_image(new_img, height, width, 120 | self.crop_size, padvalue) 121 | 122 | new_h, new_w = new_img.shape[:-1] 123 | rows = int(np.ceil(1.0 * (new_h - self.crop_size[0]) / stride_h)) + 1 124 | cols = int(np.ceil(1.0 * (new_w - self.crop_size[1]) / stride_w)) + 1 125 | preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda() 126 | count = torch.zeros([1, 1, new_h, new_w]).cuda() 127 | id_ = 1 128 | for r in range(rows): 129 | for c in range(cols): 130 | h0 = r * stride_h 131 | w0 = c * stride_w 132 | h1 = min(h0 + self.crop_size[0], new_h) 133 | w1 = min(w0 + self.crop_size[1], new_w) 134 | # h0 = max(int(h1 - self.crop_size[0]), 0) 135 | # w0 = max(int(w1 - self.crop_size[1]), 0) 136 | crop_img = new_img[h0:h1, w0:w1, :] 137 | 138 | if h1 == new_h or w1 == new_w: 139 | crop_img = self.pad_image(crop_img, 140 | h1-h0, 141 | w1-w0, 142 | 
self.crop_size, 143 | padvalue) 144 | 145 | crop_img = crop_img.transpose((2, 0, 1)) 146 | crop_img = np.expand_dims(crop_img, axis=0) 147 | crop_img = torch.from_numpy(crop_img) 148 | pred = self.inference(crop_img.cuda(), flip=True, scale=scale, id_= id_) 149 | id_ += 1 150 | 151 | preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0] 152 | count[:, :, h0:h1, w0:w1] += 1 153 | preds = preds / count 154 | preds = preds[:, :, :height, :width] 155 | 156 | preds = F.interpolate( 157 | preds, (ori_height, ori_width), 158 | mode='bilinear', align_corners=self.model.align_corners 159 | ) 160 | 161 | final_pred += preds 162 | 163 | # final_pred = F.interpolate( 164 | # final_pred, ori_shape, 165 | # mode='bilinear', align_corners=self.model.align_corners 166 | # ) 167 | 168 | # to_comb_image(un_normalise(x[0]), torch.argmax(final_pred[0], 0), None, 1, 'PASCALC', save=f'final_{self.ind}.png') 169 | b = (datetime.datetime.now() - a).total_seconds() * 1000 170 | print(f'\r time:{b}') 171 | return final_pred 172 | 173 | 174 | if __name__ == '__main__': 175 | import pickle 176 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop 177 | from models.HRNet import HRNet 178 | import cv2 179 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix 180 | from datasets import PascalC 181 | data_path = r'C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC/' 182 | from torchvision.transforms import ToTensor 183 | import PIL.Image as Image 184 | d = {"dataset":'PASCALC', "experiment":1} 185 | 186 | # file = open('..\\img_cts.pkl', 'rb') 187 | # img = pickle.load(file) 188 | # file.close() 189 | 190 | from utils import parse_transform_lists 191 | import json 192 | path_to_config = '../configs/dlv3_contrastive_PC.json' 193 | with open(path_to_config, 'r') as f: 194 | config = json.load(f) 195 | 196 | transforms_list_val = config['data']['transforms_val'] 197 | # if 'torchvision_normalise' in transforms_list_val: 198 | # del 
transforms_list_val[-1] 199 | transforms_values_val = config['data']['transform_values_val'] 200 | transforms_dict_val = parse_transform_lists(transforms_list_val, transforms_values_val, dataset='PASCALC', experiment=1) 201 | valid_set = PascalC(root=data_path, 202 | debug=False, 203 | split='val', 204 | transforms_dict=transforms_dict_val) 205 | 206 | issues = [] 207 | valid_set.return_filename = True 208 | 209 | # if i ==5: 210 | # break 211 | path_to_chkpt = '..\\logging/PASCALC/20211216_072315_e1__hrn_200epochs_hr48_sbn_DCms_cs/chkpts/chkpt_epoch_199.pt' 212 | # path_to_chkpt = '..\\logging/PASCALC/20211215_213857_e1__hrn_200epochs_hr48_sbn_CE/chkpts/chkpt_epoch_199.pt' 213 | map_location = 'cuda:0' 214 | checkpoint = torch.load(str(path_to_chkpt), map_location) 215 | torch.manual_seed(0) 216 | config = dict() 217 | config.update({'backbone': 'hrnet48', 'out_stride': 4, 'pretrained': True, 'dataset':'PASCALC'}) 218 | model = HRNet(config, 1) 219 | msg = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False) 220 | print(msg) 221 | 222 | for i, ret in enumerate(valid_set): 223 | img = ret[0] 224 | ori_shape = ret[2]['original_labels'].shape[-2:] 225 | with torch.no_grad(): 226 | tta_model = TTAWrapperPC(model, scale_list=[0.75, 0.5, 1.5]) 227 | tta_model.cuda() 228 | tta_model.eval() 229 | x = img 230 | y = tta_model.forward(x.unsqueeze(0).float(), ind=i) # forward() accepts (x, ind) only; ori_shape is not a parameter 231 | c = 1 -------------------------------------------------------------------------------- /models/Transformers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # from lib.models.backbones.vit.helper import IntermediateSequential 3 | import torch 4 | 5 | class SelfAttention(nn.Module): 6 | def __init__( 7 | self, dim, heads=1, qkv_bias=False, qk_scale=None, dropout_rate=0.0 8 | ): 9 | super().__init__() 10 | self.num_heads = heads 11 | head_dim = dim // heads 12 | self.scale = qk_scale or head_dim
** -0.5 13 | 14 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 15 | self.dropout_rate = dropout_rate 16 | self.attn_drop = nn.Dropout(dropout_rate) 17 | self.proj = nn.Linear(dim, dim) 18 | self.proj_drop = nn.Dropout(dropout_rate) 19 | 20 | def forward(self, x, unflatten_output=True): 21 | H,W,was_flattened = -1,-1, False 22 | if len(x.shape)==4: 23 | was_flattened=unflatten_output 24 | B,C,H,W = x.shape 25 | x = x.permute(0,2,3,1).reshape(B,-1,C) # B,C,H,W --> B,H,W,C --> B,HW,C 26 | 27 | B, N, C = x.shape 28 | qkv = ( 29 | self.qkv(x) 30 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 31 | .permute(2, 0, 3, 1, 4) 32 | ) 33 | q, k, v = ( 34 | qkv[0], 35 | qkv[1], 36 | qkv[2], 37 | ) # make torchscript happy (cannot use tensor as tuple) 38 | 39 | attn = (q @ k.transpose(-2, -1)) * self.scale 40 | attn = attn.softmax(dim=-1) 41 | if self.dropout_rate > 0.0: 42 | attn = self.attn_drop(attn) 43 | 44 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 45 | x = self.proj(x) 46 | if self.dropout_rate > 0.0: 47 | x = self.proj_drop(x) 48 | if was_flattened: 49 | return x.view(B,H,W,C).permute(0,3,1,2) # B,HW,C --> B,H,W,C --> B,C,H,W 50 | return x 51 | 52 | 53 | class Residual(nn.Module): 54 | def __init__(self, fn): 55 | super().__init__() 56 | self.fn = fn 57 | 58 | def forward(self, x): 59 | return self.fn(x) + x 60 | 61 | 62 | class PreNorm(nn.Module): 63 | def __init__(self, dim, fn): 64 | super().__init__() 65 | self.norm = nn.LayerNorm(dim) 66 | self.fn = fn 67 | 68 | def forward(self, x): 69 | return self.fn(self.norm(x)) 70 | 71 | 72 | class PreNormDrop(nn.Module): 73 | def __init__(self, dim, dropout_rate, fn): 74 | super().__init__() 75 | self.norm = nn.LayerNorm(dim) 76 | self.dropout = nn.Dropout(p=dropout_rate) 77 | self.fn = fn 78 | 79 | def forward(self, x): 80 | return self.dropout(self.fn(self.norm(x))) 81 | 82 | 83 | class FeedForward(nn.Module): 84 | def __init__(self, dim, hidden_dim, dropout_rate): 85 | super().__init__() 86 | self.net =
nn.Sequential( 87 | nn.Linear(dim, hidden_dim), 88 | nn.GELU(), 89 | nn.Dropout(p=dropout_rate), 90 | nn.Linear(hidden_dim, dim), 91 | nn.Dropout(p=dropout_rate), 92 | ) 93 | 94 | def forward(self, x): 95 | return self.net(x) 96 | 97 | 98 | class TransformerModel(nn.Module): 99 | def __init__( 100 | self, 101 | dim, 102 | depth, 103 | heads, 104 | mlp_dim, 105 | dropout_rate=0.1, 106 | attn_dropout_rate=0.1, 107 | ): 108 | super().__init__() 109 | layers = [] 110 | for _ in range(depth): 111 | layers.extend( 112 | [ 113 | Residual( 114 | PreNormDrop( 115 | dim, 116 | dropout_rate, 117 | SelfAttention( 118 | dim, heads=heads, dropout_rate=attn_dropout_rate 119 | ), 120 | ) 121 | ), 122 | Residual( 123 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)) 124 | ), 125 | ] 126 | ) 127 | self.net = torch.nn.Sequential(*layers) 128 | 129 | def forward(self, x): 130 | return self.net(x) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .UPerNet import UPerNet 2 | from .DeepLabv3 import DeepLabv3 3 | from .OCR import OCRNet 4 | from .Projector import Projector 5 | from .HRNet import hrnet48, hrnet32, hrnet18, HRNet 6 | from .TTA_wrapper import TTAWrapper 7 | from .TTA_wrapper_CTS import TTAWrapperCTS 8 | from .TTA_wrapper_PC import TTAWrapperPC 9 | from .TTAWrapperSlide import TTAWrapperSlide 10 | from .Transformers import SelfAttention 11 | from .Swin import SwinTransformer 12 | 13 | -------------------------------------------------------------------------------- /models/hrnet_config.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Create by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn), Rainbowsecret (yuyua@microsoft.com) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | from utils import EasyDict as CN 13 | 14 | # configs for HRNet48 15 | HRNET_48 = CN() 16 | HRNET_48.FINAL_CONV_KERNEL = 1 17 | 18 | HRNET_48.STAGE1 = CN() 19 | HRNET_48.STAGE1.NUM_MODULES = 1 20 | HRNET_48.STAGE1.NUM_BRANCHES = 1 21 | HRNET_48.STAGE1.NUM_BLOCKS = [4] 22 | HRNET_48.STAGE1.NUM_CHANNELS = [64] 23 | HRNET_48.STAGE1.BLOCK = 'BOTTLENECK' 24 | HRNET_48.STAGE1.FUSE_METHOD = 'SUM' 25 | 26 | HRNET_48.STAGE2 = CN() 27 | HRNET_48.STAGE2.NUM_MODULES = 1 28 | HRNET_48.STAGE2.NUM_BRANCHES = 2 29 | HRNET_48.STAGE2.NUM_BLOCKS = [4, 4] 30 | HRNET_48.STAGE2.NUM_CHANNELS = [48, 96] 31 | HRNET_48.STAGE2.BLOCK = 'BASIC' 32 | HRNET_48.STAGE2.FUSE_METHOD = 'SUM' 33 | 34 | HRNET_48.STAGE3 = CN() 35 | HRNET_48.STAGE3.NUM_MODULES = 4 36 | HRNET_48.STAGE3.NUM_BRANCHES = 3 37 | HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4] 38 | HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192] 39 | HRNET_48.STAGE3.BLOCK = 'BASIC' 40 | HRNET_48.STAGE3.FUSE_METHOD = 'SUM' 41 | 42 | HRNET_48.STAGE4 = CN() 43 | HRNET_48.STAGE4.NUM_MODULES = 3 44 | HRNET_48.STAGE4.NUM_BRANCHES = 4 45 | HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 46 | HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384] 47 | HRNET_48.STAGE4.BLOCK = 'BASIC' 48 | HRNET_48.STAGE4.FUSE_METHOD = 'SUM' 49 | 50 | 51 | # configs for HRNet32 52 | HRNET_32 = CN() 53 | HRNET_32.FINAL_CONV_KERNEL = 1 54 | 55 | HRNET_32.STAGE1 = CN() 56 | HRNET_32.STAGE1.NUM_MODULES = 1 57 | HRNET_32.STAGE1.NUM_BRANCHES = 1 58 | HRNET_32.STAGE1.NUM_BLOCKS = [4] 59 | HRNET_32.STAGE1.NUM_CHANNELS = [64] 60 | HRNET_32.STAGE1.BLOCK = 'BOTTLENECK' 61 | HRNET_32.STAGE1.FUSE_METHOD = 'SUM' 62 | 63 | HRNET_32.STAGE2 = CN() 64 | 
HRNET_32.STAGE2.NUM_MODULES = 1 65 | HRNET_32.STAGE2.NUM_BRANCHES = 2 66 | HRNET_32.STAGE2.NUM_BLOCKS = [4, 4] 67 | HRNET_32.STAGE2.NUM_CHANNELS = [32, 64] 68 | HRNET_32.STAGE2.BLOCK = 'BASIC' 69 | HRNET_32.STAGE2.FUSE_METHOD = 'SUM' 70 | 71 | HRNET_32.STAGE3 = CN() 72 | HRNET_32.STAGE3.NUM_MODULES = 4 73 | HRNET_32.STAGE3.NUM_BRANCHES = 3 74 | HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4] 75 | HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128] 76 | HRNET_32.STAGE3.BLOCK = 'BASIC' 77 | HRNET_32.STAGE3.FUSE_METHOD = 'SUM' 78 | 79 | HRNET_32.STAGE4 = CN() 80 | HRNET_32.STAGE4.NUM_MODULES = 3 81 | HRNET_32.STAGE4.NUM_BRANCHES = 4 82 | HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 83 | HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 84 | HRNET_32.STAGE4.BLOCK = 'BASIC' 85 | HRNET_32.STAGE4.FUSE_METHOD = 'SUM' 86 | 87 | 88 | # configs for HRNet18 89 | HRNET_18 = CN() 90 | HRNET_18.FINAL_CONV_KERNEL = 1 91 | 92 | HRNET_18.STAGE1 = CN() 93 | HRNET_18.STAGE1.NUM_MODULES = 1 94 | HRNET_18.STAGE1.NUM_BRANCHES = 1 95 | HRNET_18.STAGE1.NUM_BLOCKS = [4] 96 | HRNET_18.STAGE1.NUM_CHANNELS = [64] 97 | HRNET_18.STAGE1.BLOCK = 'BOTTLENECK' 98 | HRNET_18.STAGE1.FUSE_METHOD = 'SUM' 99 | 100 | HRNET_18.STAGE2 = CN() 101 | HRNET_18.STAGE2.NUM_MODULES = 1 102 | HRNET_18.STAGE2.NUM_BRANCHES = 2 103 | HRNET_18.STAGE2.NUM_BLOCKS = [4, 4] 104 | HRNET_18.STAGE2.NUM_CHANNELS = [18, 36] 105 | HRNET_18.STAGE2.BLOCK = 'BASIC' 106 | HRNET_18.STAGE2.FUSE_METHOD = 'SUM' 107 | 108 | HRNET_18.STAGE3 = CN() 109 | HRNET_18.STAGE3.NUM_MODULES = 4 110 | HRNET_18.STAGE3.NUM_BRANCHES = 3 111 | HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4] 112 | HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72] 113 | HRNET_18.STAGE3.BLOCK = 'BASIC' 114 | HRNET_18.STAGE3.FUSE_METHOD = 'SUM' 115 | 116 | HRNET_18.STAGE4 = CN() 117 | HRNET_18.STAGE4.NUM_MODULES = 3 118 | HRNET_18.STAGE4.NUM_BRANCHES = 4 119 | HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 120 | HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] 121 | HRNET_18.STAGE4.BLOCK = 'BASIC' 122 | 
HRNET_18.STAGE4.FUSE_METHOD = 'SUM' 123 | 124 | 125 | MODEL_CONFIGS = { 126 | 'hrnet18': HRNET_18, 127 | 'hrnet32': HRNET_32, 128 | 'hrnet48': HRNET_48, 129 | } -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import * 2 | from .transforms import * 3 | from .np_transforms import * 4 | from .utils import * 5 | from .torch_utils import * 6 | from .repeat_factor_sampling import RepeatFactorSampler 7 | from .lr_functions import * 8 | from .distributed import * 9 | from .checkpoint_utils import * 10 | from .logger import * 11 | from .config_parsers import * 12 | from .optimizer_utils import * 13 | from .tsne_visualization import * 14 | -------------------------------------------------------------------------------- /utils/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def _check_model_param_prefix(state_dict, prefix:str): 5 | # check if the parameter names of a state dict contain a given prefix 6 | found_prefix_model = False 7 | for param_name in state_dict: 8 | if not param_name.startswith(prefix): 9 | if found_prefix_model: # an earlier param had the prefix but this one does not 10 | raise Warning('module prefix found in some of the model params but not others ' 11 | '-- this will cause bugs!! -- check before proceeding') 12 | found_prefix_model = False
13 | break 14 | else: 15 | found_prefix_model = True 16 | return found_prefix_model 17 | 18 | 19 | def check_module_prefix(chkpt_state_dict, model:nn.Module): 20 | found_prefix_model = _check_model_param_prefix(model.state_dict(), prefix='module.') 21 | found_prefix_chkpt = _check_model_param_prefix(chkpt_state_dict, prefix='module.') 22 | 23 | # remove prefix from chkpt_state_dict keys 24 | # if that prefix is not found in model variable names 25 | if not found_prefix_model and found_prefix_chkpt: # 'not', since '~' on a bool is always truthy 26 | for k in list(chkpt_state_dict.keys()): 27 | # strip the 'module.' prefix from checkpoint keys 28 | if k.startswith('module.'): 29 | # remove prefix 30 | chkpt_state_dict[k[len("module."):]] = chkpt_state_dict[k] 31 | # delete renamed or unused k 32 | del chkpt_state_dict[k] 33 | 34 | return chkpt_state_dict 35 | -------------------------------------------------------------------------------- /utils/datasets_info/CADIS.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class EasyDict(dict): 5 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 6 | 7 | def __getattr__(self, name: str) -> Any: 8 | try: 9 | return self[name] 10 | except KeyError: 11 | raise AttributeError(name) 12 | 13 | def __setattr__(self, name: str, value: Any) -> None: 14 | self[name] = value 15 | 16 | def __delattr__(self, name: str) -> None: 17 | del self[name] 18 | 19 | 20 | DATA_SPLITS = [ # Pre-defined splits of the videos, to be used generally 21 | [[1], [5]], # Split 0: debugging 22 | [[1, 3, 4, 6, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 21, 23, 24, 25], [5, 7, 16, 2, 12, 22]], # Split 1 23 | [list(range(1, 26)), [5, 7, 16, 2, 12, 22]], # Split 2 (all data) 24 | [[1, 8, 9, 10, 14, 15, 21, 23, 24], [5, 7, 16, 2, 12, 22]], # Split 3: "50% of data" (1729 frames, 49.3%) 25 | [[10, 14, 21, 24], [5, 7, 16, 2, 12, 22]], # Split 4: "25% of data"
(834 frames, 23.8%) 26 | [[1, 3, 4, 6, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 21, 23, 24, 25], [5, 7, 16], [2, 12, 22]] # train-val-test 27 | ] 28 | 29 | categories_exp0 = { 30 | 'anatomies': [], 31 | 'instruments': [], 32 | 'others': [] 33 | } 34 | categories_exp1 = { 35 | 'anatomies': [0, 4, 5, 6], 36 | 'instruments': [7], 37 | 'others': [1, 2, 3], 38 | 'rare': [2] 39 | } 40 | categories_exp2 = { 41 | 'anatomies': [0, 4, 5, 6], 42 | 'instruments': [7, 8, 9, 10, 11, 12, 13, 14, 15, 16], 43 | 'others': [1, 2, 3], 44 | 'rare': [16, 10, 9, 12, 14] # picked with freq_thresh 0.2 and s.t rf > 1.5 45 | } 46 | categories_exp3 = { 47 | 'anatomies': [0, 4, 5, 6], 48 | 'instruments': [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], 49 | 'others': [1, 2, 3], 50 | 'rare': [24, 20, 21, 22, 18, 23, 19, 16, 12, 11, 14] # picked with freq_thresh 0.2 and s.t rf > 1.5 51 | } 52 | 53 | class_remapping_exp0 = { 54 | 0: [0], 55 | 1: [1], 56 | 2: [2], 57 | 3: [3], 58 | 4: [4], 59 | 5: [5], 60 | 6: [6], 61 | 7: [7], 62 | 8: [8], 63 | 9: [9], 64 | 10: [10], 65 | 11: [11], 66 | 12: [12], 67 | 13: [13], 68 | 14: [14], 69 | 15: [15], 70 | 16: [16], 71 | 17: [17], 72 | 18: [18], 73 | 19: [19], 74 | 20: [20], 75 | 21: [21], 76 | 22: [22], 77 | 23: [23], 78 | 24: [24], 79 | 25: [25], 80 | 26: [26], 81 | 27: [27], 82 | 28: [28], 83 | 29: [29], 84 | 30: [30], 85 | 31: [31], 86 | 32: [32], 87 | 33: [33], 88 | 34: [34], 89 | 35: [35] 90 | } 91 | classes_exp0 = { 92 | 0: 'Pupil', 93 | 1: 'Surgical Tape', 94 | 2: 'Hand', 95 | 3: 'Eye Retractors', 96 | 4: 'Iris', 97 | 5: 'Skin', 98 | 6: 'Cornea', 99 | 7: 'Hydrodissection Cannula', 100 | 8: 'Viscoelastic Cannula', 101 | 9: 'Capsulorhexis Cystotome', 102 | 10: 'Rycroft Cannula', 103 | 11: 'Bonn Forceps', 104 | 12: 'Primary Knife', 105 | 13: 'Phacoemulsifier Handpiece', 106 | 14: 'Lens Injector', 107 | 15: 'I/A Handpiece', 108 | 16: 'Secondary Knife', 109 | 17: 'Micromanipulator', 110 | 18: 'I/A Handpiece Handle', 111 | 19: 
'Capsulorhexis Forceps', 112 | 20: 'Rycroft Cannula Handle', 113 | 21: 'Phacoemulsifier Handpiece Handle', 114 | 22: 'Capsulorhexis Cystotome Handle', 115 | 23: 'Secondary Knife Handle', 116 | 24: 'Lens Injector Handle', 117 | 25: 'Suture Needle', 118 | 26: 'Needle Holder', 119 | 27: 'Charleux Cannula', 120 | 28: 'Primary Knife Handle', 121 | 29: 'Vitrectomy Handpiece', 122 | 30: 'Mendez Ring', 123 | 31: 'Marker', 124 | 32: 'Hydrodissection Cannula Handle', 125 | 33: 'Troutman Forceps', 126 | 34: 'Cotton', 127 | 35: 'Iris Hooks' 128 | } 129 | 130 | class_remapping_exp1 = { 131 | 0: [0], 132 | 1: [1], 133 | 2: [2], 134 | 3: [3], 135 | 4: [4], 136 | 5: [5], 137 | 6: [6], 138 | 7: [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 139 | 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], 140 | } 141 | classes_exp1 = { 142 | 0: "Pupil", 143 | 1: "Surgical Tape", 144 | 2: "Hand", 145 | 3: "Eye Retractors", 146 | 4: "Iris", 147 | 5: "Skin", 148 | 6: "Cornea", 149 | 7: "Instrument", 150 | } 151 | 152 | class_remapping_exp2 = { 153 | 0: [0], 154 | 1: [1], 155 | 2: [2], 156 | 3: [3], 157 | 4: [4], 158 | 5: [5], 159 | 6: [6], 160 | 7: [7, 8, 10, 27, 20, 32], 161 | 8: [9, 22], 162 | 9: [11, 33], 163 | 10: [12, 28], 164 | 11: [13, 21], 165 | 12: [14, 24], 166 | 13: [15, 18], 167 | 14: [16, 23], 168 | 15: [17], 169 | 16: [19], 170 | 255: [25, 26, 29, 30, 31, 34, 35], 171 | } 172 | classes_exp2 = { 173 | 0: "Pupil", 174 | 1: "Surgical Tape", 175 | 2: "Hand", 176 | 3: "Eye Retractors", 177 | 4: "Iris", 178 | 5: "Skin", 179 | 6: "Cornea", 180 | 7: "Cannula", 181 | 8: "Cap. Cystotome", 182 | 9: "Tissue Forceps", 183 | 10: "Primary Knife", 184 | 11: "Ph. Handpiece", 185 | 12: "Lens Injector", 186 | 13: "I/A Handpiece", 187 | 14: "Secondary Knife", 188 | 15: "Micromanipulator", 189 | 16: "Cap. 
Forceps", 190 | 255: "Ignore", 191 | } 192 | 193 | class_remapping_exp3 = { 194 | 0: [0], 195 | 1: [1], 196 | 2: [2], 197 | 3: [3], 198 | 4: [4], 199 | 5: [5], 200 | 6: [6], 201 | 7: [7], 202 | 8: [8], 203 | 9: [9], 204 | 10: [10], 205 | 11: [11], 206 | 12: [12], 207 | 13: [13], 208 | 14: [14], 209 | 15: [15], 210 | 16: [16], 211 | 17: [17], 212 | 18: [18], 213 | 19: [19], 214 | 20: [20], 215 | 21: [21], 216 | 22: [22], 217 | 23: [23], 218 | 24: [24], 219 | 255: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], 220 | } 221 | classes_exp3 = { 222 | 0: "Pupil", 223 | 1: "Surgical Tape", 224 | 2: "Hand", 225 | 3: "Eye Retractors", 226 | 4: "Iris", 227 | 5: "Skin", 228 | 6: "Cornea", 229 | 7: "Hydro. Cannula", 230 | 8: "Visc. Cannula", 231 | 9: "Cap. Cystotome", 232 | 10: "Rycroft Cannula", 233 | 11: "Bonn Forceps", 234 | 12: "Primary Knife", 235 | 13: "Ph. Handpiece", 236 | 14: "Lens Injector", 237 | 15: "I/A Handpiece", 238 | 16: "Secondary Knife", 239 | 17: "Micromanipulator", 240 | 18: "I/A Handpiece Handle", 241 | 19: "Cap. Forceps", 242 | 20: "R. Cannula Handle", 243 | 21: "Ph. Handpiece Handle", 244 | 22: "Cap. Cystotome Handle", 245 | 23: "Sec. 
Knife Handle", 246 | 24: "Lens Injector Handle", 247 | 255: "Ignore", 248 | } 249 | 250 | CLASS_INFO = [ 251 | [class_remapping_exp0, classes_exp0, categories_exp0], # Original classes 252 | [class_remapping_exp1, classes_exp1, categories_exp1], 253 | [class_remapping_exp2, classes_exp2, categories_exp2], 254 | [class_remapping_exp3, classes_exp3, categories_exp3] 255 | ] 256 | 257 | CLASS_NAMES = [[CLASS_INFO[0][1][key] for key in sorted(CLASS_INFO[0][1].keys())], 258 | [CLASS_INFO[1][1][key] for key in sorted(CLASS_INFO[1][1].keys())], 259 | [CLASS_INFO[2][1][key] for key in sorted(CLASS_INFO[2][1].keys())], 260 | [CLASS_INFO[3][1][key] for key in sorted(CLASS_INFO[3][1].keys())]] 261 | 262 | OVERSAMPLING_PRESETS = { 263 | 'default': [ 264 | [3, 5, 7], # Experiment 1 265 | [7, 8, 15, 16], # Experiment 2 266 | [19, 20, 22, 24] # Experiment 3 267 | ], 268 | 'rare': [ # Same classes as 'rare' category for mIoU metric 269 | [2], # Experiment 1 270 | [16, 10, 9, 12, 14], # Experiment 2 271 | [24, 20, 21, 22, 18, 23, 19, 16, 12, 11, 14] # Experiment 3 272 | ] 273 | } 274 | 275 | CLASS_FREQUENCIES = [ 276 | 1.68024535e-01, 277 | 5.93061223e-02, 278 | 7.38987570e-03, 279 | 5.72173439e-03, 280 | 1.12288211e-01, 281 | 1.33608027e-01, 282 | 4.89257831e-01, 283 | 1.26300163e-03, 284 | 8.96526043e-04, 285 | 9.28408858e-04, 286 | 6.47719387e-04, 287 | 2.61340734e-03, 288 | 1.40455685e-03, 289 | 1.84766048e-03, 290 | 3.25327478e-03, 291 | 3.60986861e-03, 292 | 1.06050077e-03, 293 | 1.97264561e-03, 294 | 5.32642854e-04, 295 | 7.07037962e-04, 296 | 3.66272768e-04, 297 | 4.75095501e-04, 298 | 1.73250919e-04, 299 | 5.49602466e-04, 300 | 2.91966965e-04, 301 | 1.06066764e-05, 302 | 1.54437472e-04, 303 | 4.16546878e-05, 304 | 2.96828324e-06, 305 | 1.02785378e-04, 306 | 4.38665256e-04, 307 | 4.91079867e-04, 308 | 1.13576281e-05, 309 | 1.83788200e-04, 310 | 1.37330396e-04, 311 | 2.35550169e-04 312 | ] 313 | CLASS_SUMS = [ 314 | 406775301, 315 | 143575852, 316 | 17890357, 317 | 13851907, 
318 | 271841675, 319 | 323455413, 320 | 1184457982, 321 | 3057636, 322 | 2170425, 323 | 2247611, 324 | 1568082, 325 | 6326871, 326 | 3400331, 327 | 4473053, 328 | 7875944, 329 | 8739232, 330 | 2567396, 331 | 4775633, 332 | 1289490, 333 | 1711688, 334 | 886720, 335 | 1150172, 336 | 419428, 337 | 1330548, 338 | 706831, 339 | 25678, 340 | 373882, 341 | 100843, 342 | 7186, 343 | 248836, 344 | 1061977, 345 | 1188869, 346 | 27496, 347 | 444938, 348 | 332467, 349 | 570250 350 | ] 351 | 352 | CADIS_INFO = EasyDict(CLASS_INFO=CLASS_INFO, 353 | CLASS_NAMES=CLASS_NAMES, 354 | DATA_SPLITS=DATA_SPLITS, 355 | OVERSAMPLING_PRESETS=OVERSAMPLING_PRESETS, 356 | CLASS_FREQUENCIES=CLASS_FREQUENCIES, 357 | CLASS_SUMS=CLASS_SUMS) 358 | -------------------------------------------------------------------------------- /utils/datasets_info/CITYSCAPES.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class EasyDict(dict): 5 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 6 | 7 | def __getattr__(self, name: str) -> Any: 8 | try: 9 | return self[name] 10 | except KeyError: 11 | raise AttributeError(name) 12 | 13 | def __setattr__(self, name: str, value: Any) -> None: 14 | self[name] = value 15 | 16 | def __delattr__(self, name: str) -> None: 17 | del self[name] 18 | 19 | 20 | categories_exp0 = { 21 | 'void': [0, 1, 2, 3, 4, 5, 6], 22 | 'flat': [7, 8, 9, 10], 23 | 'construction': [11, 12, 13, 14, 15, 16], 24 | 'object': [17, 18, 19, 20], 25 | 'nature': [21, 22], 26 | 'sky': [23], 27 | 'human': [24, 25], 28 | 'vehicle': [26, 27, 28, 29, 30, 31, 32, 33] 29 | } 30 | 31 | categories_exp1 = { 32 | 'flat': [0, 1], 33 | 'construction': [2, 3, 4], 34 | 'object': [5, 6, 7], 35 | 'nature': [8, 9], 36 | 'sky': [10], 37 | 'human': [11, 12], 38 | 'vehicle': [13, 14, 15, 16, 17, 18] 39 | } 40 | 41 | class_remapping_exp0 = { 42 | 0: [0], 43 | 1: [1], 44 | 2: [2], 45 | 3: [3], 46 | 4: [4], 
47 | 5: [5], 48 | 6: [6], 49 | 7: [7], 50 | 8: [8], 51 | 9: [9], 52 | 10: [10], 53 | 11: [11], 54 | 12: [12], 55 | 13: [13], 56 | 14: [14], 57 | 15: [15], 58 | 16: [16], 59 | 17: [17], 60 | 18: [18], 61 | 19: [19], 62 | 20: [20], 63 | 21: [21], 64 | 22: [22], 65 | 23: [23], 66 | 24: [24], 67 | 25: [25], 68 | 26: [26], 69 | 27: [27], 70 | 28: [28], 71 | 29: [29], 72 | 30: [30], 73 | 31: [31], 74 | 32: [32], 75 | 33: [33], 76 | -1: [-1] 77 | } 78 | classes_exp0 = { 79 | 0: 'unlabeled', 80 | 1: 'ego vehicle', 81 | 2: 'rectification border', 82 | 3: 'out of roi', 83 | 4: 'static', 84 | 5: 'dynamic', 85 | 6: 'ground', 86 | 7: 'road', 87 | 8: 'sidewalk', 88 | 9: 'parking', 89 | 10: 'rail track', 90 | 11: 'building', 91 | 12: 'wall', 92 | 13: 'fence', 93 | 14: 'guard rail', 94 | 15: 'bridge', 95 | 16: 'tunnel', 96 | 17: 'pole', 97 | 18: 'polegroup', 98 | 19: 'traffic light', 99 | 20: 'traffic sign', 100 | 21: 'vegetation', 101 | 22: 'terrain', 102 | 23: 'sky', 103 | 24: 'person', 104 | 25: 'rider', 105 | 26: 'car', 106 | 27: 'truck', 107 | 28: 'bus', 108 | 29: 'caravan', 109 | 30: 'trailer', 110 | 31: 'train', 111 | 32: 'motorcycle', 112 | 33: 'bicycle', 113 | 34: 'Cotton', 114 | 35: 'Iris Hooks', 115 | -1: 'license plate' 116 | } 117 | 118 | class_remapping_exp1 = { 119 | 0: [7], 120 | 1: [8], 121 | 2: [11], 122 | 3: [12], 123 | 4: [13], 124 | 5: [17], 125 | 6: [19], 126 | 7: [20], 127 | 8: [21], 128 | 9: [22], 129 | 10: [23], 130 | 11: [24], 131 | 12: [25], 132 | 13: [26], 133 | 14: [27], 134 | 15: [28], 135 | 16: [31], 136 | 17: [32], 137 | 18: [33], 138 | 255: [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 139 | } 140 | 141 | 142 | classes_exp1 = { 143 | 0: 'road', 144 | 1: 'sidewalk', 145 | 2: 'building', 146 | 3: 'wall', 147 | 4: 'fence', 148 | 5: 'pole', 149 | 6: 'traffic light', 150 | 7: 'traffic sign', 151 | 8: 'vegetation', 152 | 9: 'terrain', 153 | 10: 'sky', 154 | 11: 'person', 155 | 12: 'rider', 156 | 13: 'car', 157 | 14: 'truck', 158 | 15: 'bus', 
159 | 16: 'train', 160 | 17: 'motorcycle', 161 | 18: 'bicycle', 162 | 255: 'Ignore' 163 | } 164 | 165 | 166 | CLASS_INFO = [ 167 | [class_remapping_exp0, classes_exp0, categories_exp0], # Original classes 168 | [class_remapping_exp1, classes_exp1, categories_exp1] 169 | ] 170 | 171 | CLASS_NAMES = [[CLASS_INFO[0][1][key] for key in sorted(CLASS_INFO[0][1].keys())], 172 | [CLASS_INFO[1][1][key] for key in sorted(CLASS_INFO[1][1].keys())]] 173 | 174 | CITYSCAPES_INFO = EasyDict(CLASS_INFO=CLASS_INFO, CLASS_NAMES=CLASS_NAMES) 175 | 176 | if __name__ == '__main__': 177 | # all info is in a class attribute of the Cityscapes class 178 | from torchvision.datasets.cityscapes import Cityscapes 179 | CTS_info = Cityscapes.classes 180 | categories_exp1 = {} 181 | colormap = {} 182 | ignored_colormap = {} 183 | class_remap_exp1 = {} 184 | categ_exp0 = {} 185 | categ_exp1 = {} 186 | for cl in CTS_info: 187 | ############################################ 188 | classes_exp0[cl.id] = cl.name 189 | colormap[cl.id] = cl.color 190 | # ignored_colormap[cl.train_id] = cl.color 191 | ############################################ 192 | if cl.train_id in class_remap_exp1: 193 | class_remap_exp1[cl.train_id] += [cl.id] 194 | else: 195 | # -1 mapped to 255 which is used as the ignored class 196 | class_remap_exp1[cl.train_id] = [cl.id] 197 | classes_exp1[cl.train_id] = cl.name 198 | 199 | if cl.category not in categ_exp0: 200 | categ_exp0[cl.category] = [cl.id] 201 | else: 202 | categ_exp0[cl.category] += [cl.id] 203 | 204 | if cl.category not in categ_exp1: 205 | categ_exp1[cl.category] = [cl.train_id] 206 | else: 207 | categ_exp1[cl.category] += [cl.train_id] 208 | 209 | class_remap_exp1.pop(-1) # remove -1 from dictionary 210 | class_remap_exp1[255] += [-1] # and place it in the ignore class 211 | 212 | classes_exp1[255] = 'Ignore' 213 | classes_exp1.pop(-1) # remove -1 from dictionary 214 | 215 | a = 1 216 | --------------------------------------------------------------------------------
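The remapping tables in `CITYSCAPES.py` map each train id to the list of raw Cityscapes ids it absorbs. As a rough illustration of how such a table might be applied to a label map, here is a minimal sketch. The helper name `remap_labels`, its `ignore_index` parameter, and the small `remapping` subset are assumptions made for this example; they are not taken from the repository.

```python
import numpy as np

# Hypothetical subset of class_remapping_exp1: train id -> list of raw Cityscapes ids.
remapping = {
    0: [7],              # road
    1: [8],              # sidewalk
    13: [26],            # car
    255: [0, 1, 2, 3],   # a few of the raw ids that are ignored during training
}


def remap_labels(label: np.ndarray, remapping: dict, ignore_index: int = 255) -> np.ndarray:
    """Map raw ids to train ids (assumed helper, for illustration only).

    Raw ids that do not appear in `remapping` fall back to `ignore_index`.
    """
    out = np.full_like(label, ignore_index)
    for train_id, raw_ids in remapping.items():
        # Select every pixel whose raw id belongs to this train id.
        out[np.isin(label, raw_ids)] = train_id
    return out


raw = np.array([[7, 8, 26], [0, 3, 33]], dtype=np.int64)
remapped = remap_labels(raw, remapping)
```

Writing into a fresh output array, rather than modifying `label` in place, avoids a pixel being remapped twice when a train id coincides with a raw id that is processed later.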
/utils/datasets_info/PASCALC.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | class EasyDict(dict): 5 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 6 | 7 | def __getattr__(self, name: str) -> Any: 8 | try: 9 | return self[name] 10 | except KeyError: 11 | raise AttributeError(name) 12 | 13 | def __setattr__(self, name: str, value: Any) -> None: 14 | self[name] = value 15 | 16 | def __delattr__(self, name: str) -> None: 17 | del self[name] 18 | 19 | 20 | categories_exp0 = { 21 | 'flat': [1, 2], 22 | } 23 | 24 | categories_exp1 = { 25 | 'flat': [1, 2], 26 | } 27 | 28 | class_remapping_exp0 = { 29 | 0:[255], 30 | 1: [1], 31 | 2: [2], 32 | 3: [3], 33 | 4: [4], 34 | 5: [5], 35 | 6: [6], 36 | 7: [7], 37 | 8: [8], 38 | 9: [9], 39 | 10: [10], 40 | 11: [11], 41 | 12: [12], 42 | 13: [13], 43 | 14: [14], 44 | 15: [15], 45 | 16: [16], 46 | 17: [17], 47 | 18: [18], 48 | 19: [19], 49 | 20: [20], 50 | 21: [21], 51 | 22: [22], 52 | 23: [23], 53 | 24: [24], 54 | 25: [25], 55 | 26: [26], 56 | 27: [27], 57 | 28: [28], 58 | 29: [29], 59 | 30: [30], 60 | 31: [31], 61 | 32: [32], 62 | 33: [33], 63 | 34: [34], 64 | 35: [35], 65 | 36: [36], 66 | 37: [37], 67 | 38: [38], 68 | 39: [39], 69 | 40: [40], 70 | 41: [41], 71 | 42: [42], 72 | 43: [43], 73 | 44: [44], 74 | 45: [45], 75 | 46: [46], 76 | 47: [47], 77 | 48: [48], 78 | 49: [49], 79 | 50: [50], 80 | 51: [51], 81 | 52: [52], 82 | 53: [53], 83 | 54: [54], 84 | 55: [55], 85 | 56: [56], 86 | 57: [57], 87 | 58: [58], 88 | 59: [59] 89 | } 90 | # 91 | # CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench', 92 | # 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus', 93 | # 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth', 94 | # 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence', 95 | # 'floor', 'flower', 'food', 'grass', 'ground', 'horse', 96 | # 'keyboard', 'light', 
'motorbike', 'mountain', 'mouse', 'person', 97 | # 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep', 98 | # 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table', 99 | # 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water', 100 | # 'window', 'wood') 101 | # 102 | # PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 103 | # [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 104 | # [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 105 | # [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 106 | # [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 107 | # [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 108 | # [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 109 | # [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 110 | # [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 111 | # [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 112 | # [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 113 | # [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 114 | # [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 115 | # [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 116 | # [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] 117 | 118 | classes_exp0 = { 119 | 0: "background", 120 | 1: "aeroplane", 121 | 2: "bag", 122 | 3: "bed", 123 | 4: "bedclothes", 124 | 5: "bench", 125 | 6: "bicycle", 126 | 7: "bird", 127 | 8: "boat", 128 | 9: "book", 129 | 10: "bottle", 130 | 11: "building", 131 | 12: "bus", 132 | 13: "cabinet", 133 | 14: "car", 134 | 15: "cat", 135 | 16: "ceiling", 136 | 17: "chair", 137 | 18: "cloth", 138 | 19: "computer", 139 | 20: "cow", 140 | 21: "cup", 141 | 22: "curtain", 142 | 23: "dog", 143 | 24: "door", 144 | 25: "fence", 145 | 26: "floor", 146 | 27: "flower", 147 | 28: "food", 148 | 29: "grass", 149 | 30: "ground", 150 | 31: "horse", 151 | 32: "keyboard", 152 | 33: "light", 153 | 
34: "motorbike", 154 | 35: "mountain", 155 | 36: "mouse", 156 | 37: "person", 157 | 38: "plate", 158 | 39: "platform", 159 | 40: "pottedplant", 160 | 41: "road", 161 | 42: "rock", 162 | 43: "sheep", 163 | 44: "shelves", 164 | 45: "sidewalk", 165 | 46: "sign", 166 | 47: "sky", 167 | 48: "snow", 168 | 49: "sofa", 169 | 50: "table", 170 | 51: "track", 171 | 52: "train", 172 | 53: "tree", 173 | 54: "truck", 174 | 55: "tvmonitor", 175 | 56: "wall", 176 | 57: "water", 177 | 58: "window", 178 | 59: "wood" 179 | } 180 | 181 | 182 | class_remapping_exp1 = { 183 | 255: [0], 184 | 0: [1], 185 | 1: [2], 186 | 2: [3], 187 | 3: [4], 188 | 4: [5], 189 | 5: [6], 190 | 6: [7], 191 | 7: [8], 192 | 8: [9], 193 | 9: [10], 194 | 10: [11], 195 | 11: [12], 196 | 12: [13], 197 | 13: [14], 198 | 14: [15], 199 | 15: [16], 200 | 16: [17], 201 | 17: [18], 202 | 18: [19], 203 | 19: [20], 204 | 20: [21], 205 | 21: [22], 206 | 22: [23], 207 | 23: [24], 208 | 24: [25], 209 | 25: [26], 210 | 26: [27], 211 | 27: [28], 212 | 28: [29], 213 | 29: [30], 214 | 30: [31], 215 | 31: [32], 216 | 32: [33], 217 | 33: [34], 218 | 34: [35], 219 | 35: [36], 220 | 36: [37], 221 | 37: [38], 222 | 38: [39], 223 | 39: [40], 224 | 40: [41], 225 | 41: [42], 226 | 42: [43], 227 | 43: [44], 228 | 44: [45], 229 | 45: [46], 230 | 46: [47], 231 | 47: [48], 232 | 48: [49], 233 | 49: [50], 234 | 50: [51], 235 | 51: [52], 236 | 52: [53], 237 | 53: [54], 238 | 54: [55], 239 | 55: [56], 240 | 56: [57], 241 | 57: [58], 242 | 58: [59] 243 | } 244 | 245 | 246 | classes_exp1 = { 247 | 255: "background", 248 | 0: "aeroplane", 249 | 1: "bag", 250 | 2: "bed", 251 | 3: "bedclothes", 252 | 4: "bench", 253 | 5: "bicycle", 254 | 6: "bird", 255 | 7: "boat", 256 | 8: "book", 257 | 9: "bottle", 258 | 10: "building", 259 | 11: "bus", 260 | 12: "cabinet", 261 | 13: "car", 262 | 14: "cat", 263 | 15: "ceiling", 264 | 16: "chair", 265 | 17: "cloth", 266 | 18: "computer", 267 | 19: "cow", 268 | 20: "cup", 269 | 21: "curtain", 270 | 22: "dog", 271 
| 23: "door", 272 | 24: "fence", 273 | 25: "floor", 274 | 26: "flower", 275 | 27: "food", 276 | 28: "grass", 277 | 29: "ground", 278 | 30: "horse", 279 | 31: "keyboard", 280 | 32: "light", 281 | 33: "motorbike", 282 | 34: "mountain", 283 | 35: "mouse", 284 | 36: "person", 285 | 37: "plate", 286 | 38: "platform", 287 | 39: "pottedplant", 288 | 40: "road", 289 | 41: "rock", 290 | 42: "sheep", 291 | 43: "shelves", 292 | 44: "sidewalk", 293 | 45: "sign", 294 | 46: "sky", 295 | 47: "snow", 296 | 48: "sofa", 297 | 49: "table", 298 | 50: "track", 299 | 51: "train", 300 | 52: "tree", 301 | 53: "truck", 302 | 54: "tvmonitor", 303 | 55: "wall", 304 | 56: "water", 305 | 57: "window", 306 | 58: "wood" 307 | } 308 | 309 | 310 | CLASS_INFO = [ 311 | [class_remapping_exp0, classes_exp0, categories_exp0], # Original classes 312 | [class_remapping_exp1, classes_exp1, categories_exp1] 313 | ] 314 | 315 | CLASS_NAMES = [[CLASS_INFO[0][1][key] for key in sorted(CLASS_INFO[0][1].keys())], 316 | [CLASS_INFO[1][1][key] for key in sorted(CLASS_INFO[1][1].keys())]] 317 | 318 | PASCALC_INFO = EasyDict(CLASS_INFO=CLASS_INFO, CLASS_NAMES=CLASS_NAMES) 319 | 320 | 321 | def label_sanity_check(root=None): 322 | import cv2 323 | import warnings 324 | import pathlib 325 | import numpy as np 326 | warning = 0 327 | warning_msg = [] 328 | if root is None: 329 | root = pathlib.Path(r"C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC\val\label/") 330 | for path_to_label in pathlib.Path(root).glob('**/*.PNG'): # also accepts root given as a str 331 | i = cv2.imread(str(path_to_label)) 332 | labels_present = np.unique(i) 333 | print(f'{path_to_label.stem} : {labels_present}') 334 | if max(labels_present) > 59: 335 | warnings.warn(f'invalid label found {labels_present}') 336 | warning += 1 337 | warning_msg.append(f'invalid label found {labels_present}') 338 | return warning_msg, warning 339 | 340 | def class_dict_from_txt(): 341 | d = dict() 342 | content = open('pascal.txt').read() 343 | print('{') 344 | for i in content.split('\n'): 345 | key = 
i.split(':')[0] 346 | val = i.split(':')[-1] 347 | # print(key, val) 348 | d[int(key)] = val 349 | val = val.replace(" ", "") 350 | print(f'{key}:"{val}",') 351 | print('}') 352 | if __name__ == '__main__': 353 | # label_sanity_check() 354 | # class_dict_from_txt() 355 | 356 | # for i in classes_exp0: 357 | # # for remapping 358 | # # print(f'{i-1}:{[i]},') 359 | from utils import get_pascalc_colormap 360 | # A = PALETTE 361 | # print(f'{i - 1}:"{classes_exp0[i]}",') 362 | for i, c in classes_exp0.items(): # iterate over (id, name) pairs so the class name is printed 363 | print(f'{i-1}:"{c}",') -------------------------------------------------------------------------------- /utils/datasets_info/__init__.py: -------------------------------------------------------------------------------- 1 | from .CITYSCAPES import * 2 | from .CADIS import * 3 | from .PASCALC import * 4 | from .ADE20K import * 5 | -------------------------------------------------------------------------------- /utils/defaults.py: -------------------------------------------------------------------------------- 1 | from .datasets_info import CITYSCAPES_INFO, CADIS_INFO, PASCALC_INFO, ADE20K_INFO 2 | import numpy as np 3 | from typing import Any 4 | 5 | 6 | class EasyDict(dict): 7 | """Convenience class that behaves like a dict but allows access with the attribute syntax.""" 8 | 9 | def __getattr__(self, name: str) -> Any: 10 | try: 11 | return self[name] 12 | except KeyError: 13 | raise AttributeError(name) 14 | 15 | def __setattr__(self, name: str, value: Any) -> None: 16 | self[name] = value 17 | 18 | def __delattr__(self, name: str) -> None: 19 | del self[name] 20 | 21 | 22 | DATASETS_INFO = EasyDict(CADIS=CADIS_INFO, CITYSCAPES=CITYSCAPES_INFO, PASCALC=PASCALC_INFO, ADE20K=ADE20K_INFO) 23 | 24 | 25 | def get_cityscapes_colormap(): 26 | """ 27 | Returns cityscapes colormap as in paper 28 | :return: ndarray of rgb colors 29 | """ 30 | return np.asarray( 31 | [(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), (128, 64, 128), 
32 | (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), 33 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), 34 | (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), 35 | (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)] 36 | ) 37 | 38 | 39 | def get_cadis_colormap(): 40 | """ 41 | Returns cadis colormap as in paper 42 | :return: ndarray of rgb colors 43 | """ 44 | return np.asarray( 45 | [ 46 | [0, 137, 255], 47 | [255, 165, 0], 48 | [255, 156, 201], 49 | [99, 0, 255], 50 | [255, 0, 0], 51 | [255, 0, 165], 52 | [255, 255, 255], 53 | [141, 141, 141], 54 | [255, 218, 0], 55 | [173, 156, 255], 56 | [73, 73, 73], 57 | [250, 213, 255], 58 | [255, 156, 156], 59 | [99, 255, 0], 60 | [157, 225, 255], 61 | [255, 89, 124], 62 | [173, 255, 156], 63 | [255, 60, 0], 64 | [40, 0, 255], 65 | [170, 124, 0], 66 | [188, 255, 0], 67 | [0, 207, 255], 68 | [0, 255, 207], 69 | [188, 0, 255], 70 | [243, 0, 255], 71 | [0, 203, 108], 72 | [252, 255, 0], 73 | [93, 182, 177], 74 | [0, 81, 203], 75 | [211, 183, 120], 76 | [231, 203, 0], 77 | [0, 124, 255], 78 | [10, 91, 44], 79 | [2, 0, 60], 80 | [0, 144, 2], 81 | [133, 59, 59], 82 | ] 83 | ) 84 | 85 | 86 | def get_pascalc_colormap(): 87 | cmap = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 88 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 89 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 90 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 91 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 92 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 93 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 94 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 95 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 96 | [224, 255, 8], 
[102, 8, 255], [255, 61, 6], [255, 194, 7], 97 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 98 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 99 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 100 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 101 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] 102 | return cmap 103 | 104 | 105 | def get_ade20k_colormap(): 106 | # 151 VALUES , CMAP[0] IS IGNORED 107 | cmap = [[0,0,0], [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 108 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 109 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 110 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 111 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 112 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 113 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 114 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 115 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 116 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 117 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 118 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 119 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 120 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 121 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 122 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 123 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 124 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 125 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 126 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 127 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 128 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 129 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 130 | [0, 112, 
255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 131 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 132 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 133 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 134 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 135 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], 136 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 137 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 138 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 139 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 140 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 141 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 142 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 143 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 144 | [102, 255, 0], [92, 0, 255]] 145 | return cmap 146 | 147 | 148 | def get_iacl_colormap(): 149 | cmap = [[0, 0, 127], 150 | [0, 0, 254], 151 | [0, 96, 255], 152 | [0, 212, 255], 153 | [76, 255, 170], 154 | [170, 255, 76], 155 | [255, 229, 0], 156 | [255, 122, 0], 157 | [254, 18, 0]] 158 | return cmap 159 | 160 | def get_retouch_colormap(): 161 | cmap = [[0, 0, 0], 162 | [0, 0, 254], 163 | [0, 96, 255], 164 | [0, 212, 255], 165 | [76, 255, 170], 166 | [170, 255, 76], 167 | [255, 229, 0], 168 | [255, 122, 0], 169 | [254, 18, 0]] 170 | return cmap 171 | 172 | 173 | 174 | DEFAULT_VALUES = { 175 | 'sliding_miou_kernel': 7, # Make sure this is odd! 
176 | 'sliding_miou_stride': 4, 177 | } 178 | 179 | DEFAULT_CONFIG_DICT = { 180 | 'mode': 'training', 181 | 'debugging': False, 182 | 'log_every_n_epochs': 100, 183 | 'max_valid_imgs': 10, 184 | 'cuda': True, 185 | 'gpu_device': 0, 186 | 'parallel': False, 187 | 'parallel_gpu_devices': [], 188 | 'seed': 0, 189 | 'tta': False 190 | } 191 | 192 | DEFAULT_CONFIG_NESTED_DICT = { 193 | 'data': { 194 | 'transforms': ['pad'], 195 | 'transform_values': { 196 | 'crop_size': 0.5, 197 | 'crop_mode': 'random', 198 | 'crop_shape': [512, 1024] 199 | }, 200 | 'split': 1, 201 | 'batch_size': 10, 202 | 'num_workers': 0, 203 | 'preload': False, 204 | 'blacklist': True, 205 | 'use_propagated': False, 206 | 'propagated_video_blacklist': False, 207 | 'propagated_quart_blacklist': False, 208 | 'use_relabeled': False, 209 | 'weighted_random': [0, 0], 210 | 'weighted_random_mode': 'v1', 211 | 'oversampling': [0, 0], 212 | 'oversampling_frac': 0.2, 213 | 'oversampling_preset': 'default', 214 | 'adaptive_batching': [0, 0], 215 | 'adaptive_sel_size': 10, 216 | 'adaptive_iou_update': 1, 217 | "repeat_factor": [0, 0], 218 | "repeat_factor_freq_thresh": 0.15, 219 | # loaders for two-step pseudo training 220 | # only loads labelled data with RF 221 | "lab_repeat_factor": [0, 0], 222 | # only loads unlabelled data 223 | "ulab_default": [0, 0], 224 | # loads lab and ulab mixed -- default choice for pseudo training 225 | "mixed_default": [0, 0], 226 | # loads lab with RF and ulab mixed 227 | "mixed_repeat_factor": [0, 0] 228 | }, 229 | 'train': { 230 | 'epochs': 50, 231 | 'lr_fct': 'exponential', 232 | 'lr_batchwise': False, 233 | 'lr_restarts': [], 234 | 'lr_restart_vals': 1, 235 | 'lr_params': None, 236 | }, 237 | 'loss': { 238 | 'temperature': 0.1, 239 | 'dominant_mode': 'all', 240 | 'label_scaling_mode': 'avg_pool', 241 | 'dc_weightings': { 242 | 'outer_freq': False, 243 | 'outer_entropy': False, 244 | 'outer_confusionmatrix': False, 245 | 'inner_crossentropy': False, 246 | 
'inner_idealcrossentropy': False, 247 | 'neg_confusionmatrix': False, 248 | 'neg_negativity': False 249 | }, 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /utils/df_from_data.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import pandas as pd 3 | 4 | import argparse 5 | 6 | # Set path to data, e.g. python df_from_data.py --path 7 | parser = argparse.ArgumentParser() 8 | path = "C:\\Users\\Theodoros Pissas\\Documents\\tresorit\\CaDIS\\segmentation" 9 | parser.add_argument('-p', '--path', type=str, default=path, 10 | help='Set path to data, e.g. python df_from_data.py --path ') 11 | args = parser.parse_args() 12 | 13 | record_list = [] 14 | data_path = pathlib.Path(args.path) 15 | subfolders = [[f, f.name] for f in data_path.iterdir() if f.is_dir()] 16 | for folder_path, folder_name in subfolders: 17 | for image in (folder_path / 'Images').iterdir(): 18 | record_list.append([ 19 | int(folder_name[-2:]), # Video number: 'Video01' --> 1 20 | str(pathlib.PurePosixPath(pathlib.Path(folder_name) / 'Images' / image.name)), # Relative path to the image 21 | str(pathlib.PurePosixPath(pathlib.Path(folder_name) / 'Labels' / image.name)), # Relative path to the label 22 | ]) 23 | df = pd.DataFrame(data=record_list, columns=['vid_num', 'img_path', 'lbl_path']) 24 | df = df.sort_values(by=['vid_num', 'img_path']).reset_index(drop=True) 25 | df.to_pickle('../data/data.pkl') 26 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | """ utils from openseg codebase """ 5 | def is_distributed(): 6 | return dist.is_initialized() 7 | 8 | 9 | def get_world_size(): 10 | if not dist.is_initialized(): 11 | return 1 12 | return dist.get_world_size() 13 | 14 | 15 | def get_rank(): 16 | 
if not dist.is_initialized(): 17 | return 0 18 | return dist.get_rank() 19 | 20 | 21 | def all_reduce_numpy(array): 22 | tensor = torch.from_numpy(array).cuda() 23 | dist.all_reduce(tensor) 24 | return tensor.cpu().numpy() 25 | 26 | def reduce_tensor(inp): 27 | """ 28 | Reduce the loss from all processes so that 29 | process with rank 0 has the averaged results. 30 | """ 31 | world_size = get_world_size() # helper above returns 1 when torch.distributed is not initialized 32 | if world_size < 2: 33 | return inp 34 | with torch.no_grad(): 35 | reduced_inp = inp # alias, not a copy: the in-place reduce below also modifies inp 36 | torch.distributed.reduce(reduced_inp, dst=0) 37 | return reduced_inp / world_size 38 | 39 | 40 | def barrier(): 41 | """Synchronizes all processes. 42 | 43 | This collective blocks processes until the whole group enters this 44 | function. 45 | """ 46 | if dist.is_initialized(): 47 | dist.barrier() # processes in global group wait here until all processes reach this point 48 | return 49 | 50 | @torch.no_grad() 51 | def concat_all_gather(tensor, concat_dim=0): 52 | """ from moco 53 | Performs all_gather operation on the provided tensors. 54 | *** Warning ***: torch.distributed.all_gather has no gradient. 55 | """ 56 | tensors_gather = [torch.ones_like(tensor) 57 | for _ in range(torch.distributed.get_world_size())] 58 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 59 | output = torch.cat(tensors_gather, dim=concat_dim) 60 | return output 61 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Donny You(youansheng@gmail.com) 4 | # Logging tool implemented with the python Package logging. 
5 | 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import argparse 12 | import logging 13 | import os 14 | import sys 15 | 16 | 17 | DEFAULT_LOGFILE_LEVEL = 'debug' 18 | DEFAULT_STDOUT_LEVEL = 'info' 19 | DEFAULT_LOG_FILE = './default.log' 20 | DEFAULT_LOG_FORMAT = '%(asctime)s %(levelname)-7s %(message)s' 21 | 22 | LOG_LEVEL_DICT = { 23 | 'debug': logging.DEBUG, 24 | 'info': logging.INFO, 25 | 'warning': logging.WARNING, 26 | 'error': logging.ERROR, 27 | 'critical': logging.CRITICAL 28 | } 29 | 30 | 31 | class Logger(object): 32 | """ 33 | Args: 34 | Log level: CRITICAL>ERROR>WARNING>INFO>DEBUG. 35 | Log file: The file that stores the logging info. 36 | rewrite: Clear the log file. 37 | log format: The format of log messages. 38 | stdout level: The log level to print on the screen. 39 | """ 40 | logfile_level = None 41 | log_file = None 42 | log_format = None 43 | rewrite = None 44 | stdout_level = None 45 | logger = None 46 | 47 | _caches = {} 48 | 49 | @staticmethod 50 | def init(logfile_level=DEFAULT_LOGFILE_LEVEL, 51 | log_file=DEFAULT_LOG_FILE, 52 | log_format=DEFAULT_LOG_FORMAT, 53 | rewrite=False, 54 | stdout_level=None): 55 | Logger.logfile_level = logfile_level 56 | Logger.log_file = log_file 57 | Logger.log_format = log_format 58 | Logger.rewrite = rewrite 59 | Logger.stdout_level = stdout_level 60 | 61 | Logger.logger = logging.getLogger() 62 | Logger.logger.handlers = [] 63 | fmt = logging.Formatter(Logger.log_format) 64 | 65 | if Logger.logfile_level is not None: 66 | filemode = 'w' 67 | if not Logger.rewrite: 68 | filemode = 'a' 69 | 70 | dir_name = os.path.dirname(os.path.abspath(Logger.log_file)) 71 | if not os.path.exists(dir_name): 72 | os.makedirs(dir_name) 73 | 74 | if Logger.logfile_level not in LOG_LEVEL_DICT: 75 | print('Invalid logging level: {}'.format(Logger.logfile_level)) 76 | Logger.logfile_level = DEFAULT_LOGFILE_LEVEL 77 | 78 | 
Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.logfile_level]) 79 | 80 | fh = logging.FileHandler(Logger.log_file, mode=filemode) 81 | fh.setFormatter(fmt) 82 | fh.setLevel(LOG_LEVEL_DICT[Logger.logfile_level]) 83 | 84 | Logger.logger.addHandler(fh) 85 | 86 | if stdout_level is not None: 87 | if Logger.logfile_level is None: 88 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.stdout_level]) 89 | 90 | console = logging.StreamHandler() 91 | if Logger.stdout_level not in LOG_LEVEL_DICT: 92 | print('Invalid logging level: {}'.format(Logger.stdout_level)) 93 | return 94 | 95 | console.setLevel(LOG_LEVEL_DICT[Logger.stdout_level]) 96 | console.setFormatter(fmt) 97 | Logger.logger.addHandler(console) 98 | 99 | @staticmethod 100 | def set_log_file(file_path): 101 | Logger.log_file = file_path 102 | Logger.init(log_file=file_path) 103 | 104 | @staticmethod 105 | def set_logfile_level(log_level): 106 | if log_level not in LOG_LEVEL_DICT: 107 | print('Invalid logging level: {}'.format(log_level)) 108 | return 109 | 110 | Logger.init(logfile_level=log_level) 111 | 112 | @staticmethod 113 | def clear_log_file(): 114 | Logger.rewrite = True 115 | Logger.init(rewrite=True) 116 | 117 | @staticmethod 118 | def check_logger(): 119 | if Logger.logger is None: 120 | Logger.init(logfile_level=None, stdout_level=DEFAULT_STDOUT_LEVEL) 121 | 122 | @staticmethod 123 | def set_stdout_level(log_level): 124 | if log_level not in LOG_LEVEL_DICT: 125 | print('Invalid logging level: {}'.format(log_level)) 126 | return 127 | 128 | Logger.init(stdout_level=log_level) 129 | 130 | @staticmethod 131 | def debug(message): 132 | Logger.check_logger() 133 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 134 | lineno = sys._getframe().f_back.f_lineno 135 | prefix = '[{}, {}]'.format(filename,lineno) 136 | Logger.logger.debug('{} {}'.format(prefix, message)) 137 | 138 | @staticmethod 139 | def info(message): 140 | Logger.check_logger() 141 | filename = 
os.path.basename(sys._getframe().f_back.f_code.co_filename) 142 | lineno = sys._getframe().f_back.f_lineno 143 | prefix = '[{}, {}]'.format(filename,lineno) 144 | Logger.logger.info('{} {}'.format(prefix, message)) 145 | 146 | @staticmethod 147 | def info_once(message): 148 | Logger.check_logger() 149 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 150 | lineno = sys._getframe().f_back.f_lineno 151 | prefix = '[{}, {}]'.format(filename, lineno) 152 | 153 | if Logger._caches.get((prefix, message)) is not None: 154 | return 155 | 156 | Logger.logger.info('{} {}'.format(prefix, message)) 157 | Logger._caches[(prefix, message)] = True 158 | 159 | @staticmethod 160 | def warn(message): 161 | Logger.check_logger() 162 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 163 | lineno = sys._getframe().f_back.f_lineno 164 | prefix = '[{}, {}]'.format(filename,lineno) 165 | Logger.logger.warning('{} {}'.format(prefix, message)) 166 | 167 | @staticmethod 168 | def error(message): 169 | Logger.check_logger() 170 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 171 | lineno = sys._getframe().f_back.f_lineno 172 | prefix = '[{}, {}]'.format(filename,lineno) 173 | Logger.logger.error('{} {}'.format(prefix, message)) 174 | 175 | @staticmethod 176 | def critical(message): 177 | Logger.check_logger() 178 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 179 | lineno = sys._getframe().f_back.f_lineno 180 | prefix = '[{}, {}]'.format(filename,lineno) 181 | Logger.logger.critical('{} {}'.format(prefix, message)) 182 | 183 | 184 | def printlog(message:str, save_to_log=True, **kwargs): 185 | """Prints a message to the console and, if save_to_log is True, also passes it to Logger.info.""" 186 | print(message, **kwargs) 187 | if save_to_log: 188 | Logger.info(message) 189 | 190 | 191 | if __name__ == "__main__": 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument('--logfile_level', default="debug", type=str, 194 
| dest='logfile_level', help='To set the log level to files.') 195 | parser.add_argument('--stdout_level', default=None, type=str, 196 | dest='stdout_level', help='To set the level to print to screen.') 197 | parser.add_argument('--log_file', default="./default.log", type=str, 198 | dest='log_file', help='The path of log files.') 199 | parser.add_argument('--log_format', default="%(asctime)s %(levelname)-7s %(message)s", 200 | type=str, dest='log_format', help='The format of log messages.') 201 | parser.add_argument('--rewrite', default=False, action='store_true', # type=bool would treat any non-empty string as True 202 | dest='rewrite', help='Clear the existing log file.') 203 | 204 | args = parser.parse_args() 205 | Logger.init(logfile_level=args.logfile_level, stdout_level=args.stdout_level, 206 | log_file=args.log_file, log_format=args.log_format, rewrite=args.rewrite) 207 | 208 | Logger.info("info test.") 209 | Logger.debug("debug test.") 210 | Logger.warn("warn test.") 211 | Logger.error("error test.") 212 | Logger.debug("debug test.") -------------------------------------------------------------------------------- /utils/lr_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | 4 | 5 | class LRFcts: 6 | def __init__(self, config: dict, lr_restart_steps: list, lr_total_steps: int): 7 | self.base_lr = config['learning_rate'] 8 | self.lr_total_steps = lr_total_steps 9 | self.lr_fct = config['lr_fct'] 10 | self.batchwise = config['lr_batchwise'] 11 | 12 | self.uses_restarts = True 13 | if len(lr_restart_steps) == 0: 14 | self.uses_restarts = False 15 | # Restart epochs, and base values 16 | self.lr_restarts = lr_restart_steps 17 | restart_vals = config['lr_restart_vals'] 18 | if 0 not in self.lr_restarts: 19 | self.lr_restarts.insert(0, 0) 20 | self.lr_restart_vals = [1] 21 | if isinstance(restart_vals, float) or isinstance(restart_vals, int): 22 | # Base LR value reduced to fraction every restart, end set to 0 23 | for 
i in range(1, len(self.lr_restarts)): 24 | self.lr_restart_vals.append(self.lr_restart_vals[i - 1] * restart_vals) 25 | elif isinstance(restart_vals, list): 26 | assert len(restart_vals) == len(config['lr_restarts']) - 1, \ 27 | "Value Error: lr_restart_vals is list, but not the same length as lr_restarts" 28 | self.lr_restart_vals.extend(restart_vals) 29 | if lr_total_steps not in self.lr_restarts: 30 | self.lr_restarts.append(lr_total_steps) 31 | self.lr_restart_vals.append(0) 32 | self.lr_restarts = np.array(self.lr_restarts) 33 | self.lr_restart_vals = np.array(self.lr_restart_vals) 34 | 35 | # Length of each restart 36 | self.restart_lengths = np.ones_like(self.lr_restarts) 37 | self.restart_lengths[:-1] = self.lr_restarts[1:] - self.lr_restarts[:-1] 38 | 39 | # Current restart position 40 | self.curr_restart = len(self.lr_restarts) - np.argmax((np.arange(lr_total_steps + 1)[:, np.newaxis] >= self.lr_restarts)[:, ::-1], axis=1) - 1 41 | self.lr_params = dict() 42 | if config['lr_params'] is not None: 43 | self.lr_params = config['lr_params'] 44 | 45 | self.epochs_ulab = config['ulab_epochs'] if 'ulab_epochs' in config else None 46 | self.epochs_lab = config['lab_epochs'] if 'lab_epochs' in config else None 47 | 48 | if self.lr_fct == 'piecewise_static': 49 | # example entry in config['train']["piecewise_static_schedule"]: [[40,1],[50,0.1]] 50 | # if s<=40 ==> lr = learning_rate * 1 elif s<=50 ==> lr = learning_rate * 0.1 51 | assert(len(self.lr_restarts) == 2), 'with piecewise_static lr schedule lr_restarts must be empty list' \ 52 | ' instead got {}'.format(self.lr_restarts) 53 | assert 'piecewise_static_schedule' in self.lr_params 54 | assert isinstance(self.lr_params['piecewise_static_schedule'], list) 55 | assert self.lr_params['piecewise_static_schedule'][-1][0] == config['epochs'], \ 56 | "piecewise_static_schedule's last phase must have first element equal to number of epochs " \ 57 | "instead got: {} and {} 
respectively".format(self.lr_params['piecewise_static_schedule'][-1][0], config['epochs']) 58 | 59 | piecewise_static_schedule = self.lr_params['piecewise_static_schedule'] 60 | self.piecewise_static_schedule = OrderedDict() # this is essential, it has to be an ordered dict 61 | phase_prev = 0 62 | for phase in piecewise_static_schedule: # get ordered dict from list 63 | assert phase_prev < phase[0], ' piecewise_static_schedule must have increasing first elements per phase' \ 64 | ' instead got phase_prev {} and phase {}'.format(phase_prev, phase[0]) 65 | self.piecewise_static_schedule[phase[0]] = phase[1] 66 | phase_prev = phase[0] # track the previous phase end so the assert above actually checks ordering 67 | def __call__(self, step: int): 68 | if self.uses_restarts: 69 | steps_since_restart = step - self.lr_restarts[self.curr_restart[step]] 70 | base_val = self.lr_restart_vals[self.curr_restart[step]] 71 | if self.lr_fct == 'static': 72 | return base_val 73 | elif self.lr_fct == 'piecewise_static': 74 | return self.piecewise_static(step) 75 | elif self.lr_fct == 'exponential': 76 | return self.lr_exponential(base_val, steps_since_restart) 77 | elif self.lr_fct == 'polynomial': 78 | steps_in_restart = self.restart_lengths[self.curr_restart[step]] 79 | return self.lr_polynomial(base_val, steps_since_restart, steps_in_restart) 80 | elif self.lr_fct == 'cosine': 81 | steps_in_restart = self.restart_lengths[self.curr_restart[step]] 82 | return self.lr_cosine(base_val, steps_since_restart, steps_in_restart) 83 | else: 84 | raise ValueError("Learning rate schedule '{}' not recognised.".format(self.lr_fct)) 85 | else: 86 | # todo hacky for now, remove the lr_restarts code to be used only if lr_restarts are used 87 | base_val = 1.0 88 | if (step>self.lr_total_steps): 89 | print(f'warning learning rate scheduler at step {step} exceeds expected lr_total_steps {self.lr_total_steps}') 90 | if self.lr_fct == 'exponential': 91 | return self.lr_exponential(base_val, step) 92 | elif self.lr_fct == 'polynomial': 93 | return self.lr_polynomial(base_val, step, self.lr_total_steps) 94 | elif
self.lr_fct == 'linear-warmup-polynomial': 95 | assert 'warmup_iters' in self.lr_params \ 96 | and 'warmup_rate' in self.lr_params, f'lr_params must be passed via config as dict with keys ' \ 97 | f'warmup_iters and warmup_rate instead got {self.lr_params}' 98 | if step <= self.lr_params['warmup_iters']-1: 99 | return self.linear_warmup(step) 100 | else: 101 | return self.lr_polynomial(base_val, step, self.lr_total_steps) 102 | else: 103 | raise ValueError("Learning rate schedule without restarts '{}' not recognised.".format(self.lr_fct)) 104 | 105 | def piecewise_static(self, step): 106 | # important this only works if self.piecewise_static_schedule is an ordered dict! 107 | for phase_end in self.piecewise_static_schedule.keys(): 108 | lr = self.piecewise_static_schedule[phase_end] 109 | if step <= phase_end: 110 | return lr 111 | 112 | def linear_warmup(self, step: int): 113 | # step + 1 to account for step = 0 ... warmup_iters -1 114 | 115 | lr = 1 - (1 - (step+1) / self.lr_params['warmup_iters']) * (1 - self.lr_params['warmup_rate']) 116 | # warmup_lr = [_lr * (1 - k) for _lr in regular_lr] 117 | return lr 118 | 119 | def lr_exponential(self, base_val: float, steps_current: int): 120 | gamma = self.lr_params if isinstance(self.lr_params, (int, float)) else .98 # lr_params is never None here; a dict (e.g. the empty default) falls back to .98 121 | lr = base_val * gamma ** steps_current 122 | return lr 123 | 124 | def lr_polynomial(self, base_val: float, steps_current: int, max_steps: int): 125 | # max_steps - 1 to account for step = 0 ... max_steps -1 126 | # power = .9 if 'power' in self.lr_params else self.lr_params['power'] 127 | power = self.lr_params.get('power', .9) 128 | # min_lr = self.lr_params['min_lr'] if 'min_lr' in self.lr_params else 0.0 129 | min_lr = self.lr_params.get('min_lr', 0.0) 130 | coeff = (1 - steps_current / (max_steps-1)) ** power 131 | lr = (base_val- min_lr) * coeff + min_lr 132 | return lr 133 | 134 | def lr_cosine(self, base_val, steps_current, max_steps): 135 | lr = base_val * 0.5 * (1.
+ np.cos(np.pi * steps_current / max_steps)) 136 | return lr 137 | 138 | 139 | if __name__ == '__main__': 140 | def lr_exponential(base_val: float, steps_since_restart: int, steps_in_restart=None, gamma: int = .98): 141 | lr = base_val * gamma ** steps_since_restart 142 | return lr 143 | 144 | def lr_cosine(base_val, steps_since_restart, steps_in_restart): 145 | lr = base_val * 0.5 * (1. + np.cos(np.pi * steps_since_restart / steps_in_restart)) 146 | return lr 147 | 148 | 149 | def linear_warmup(step: int): 150 | base_lr = 0.0001 151 | rate = 1e-6 152 | # step + 1 to account for step = 0 ... warmup_iters -1 153 | lr = 1 - (1 - (step+1) / 1500) * (1 - rate) 154 | # warmup_lr = [_lr * (1 - k) for _lr in regular_lr] 155 | return lr * base_lr 156 | 157 | def lr_polynomial( base_val: float, steps_current: int, max_steps: int): 158 | # max_steps - 1 to account for step = 0 ... max_steps -1 159 | power = 1.0 160 | min_lr = 0.0 161 | coeff = (1 - steps_current / (max_steps-1)) ** power 162 | lr = (base_val- min_lr) * coeff + min_lr 163 | return lr 164 | 165 | 166 | def linear_warmup_then_poly(step:int, total_steps): 167 | if step <= 1500 - 1: 168 | return linear_warmup(step) 169 | else: 170 | return lr_polynomial(0.0001, step, total_steps) 171 | 172 | 173 | 174 | 175 | # lr_start = 0.0001 176 | # T = 100 177 | # lrs = [lr_cosine(lr_start, step, T) for step in range(T)] 178 | # lrs_exp = [lr_exponential(lr_start, step % (T//4), T//4) for step in range(T)] 179 | # 180 | # 181 | # 182 | import matplotlib.pyplot as plt 183 | # plt.plot(lrs) 184 | # plt.plot(lrs_exp) 185 | T = 160401 186 | lrs_exp = [linear_warmup_then_poly(step, T) for step in range(T)] 187 | plt.plot(lrs_exp) 188 | plt.show() 189 | a = 1 -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import to_numpy, CLASS_INFO 3 | 4 | 5 | def 
get_confusion_matrix(prediction, target, existing_matrix=None): 6 | """Expects prediction logits (as output by network), and target as classes in single channel (as from data)""" 7 | prediction, target = to_numpy(prediction), to_numpy(target) 8 | num_classes = prediction.shape[1] # prediction is shape NCHW, we want C (one-hot length of all classes) 9 | one_hots = np.eye(num_classes) 10 | prediction = np.moveaxis(prediction, 1, 0) # Prediction is NCHW -> move C to the front to make it CNHW 11 | prediction = np.reshape(prediction, (num_classes, -1)) # Prediction is [C, N*H*W] 12 | prediction = np.argmax(prediction, 0) # Prediction is now [N*H*W] 13 | one_hot_preds = one_hots[prediction] # Prediction is now [N*H*W, C] 14 | one_hot_preds = np.moveaxis(one_hot_preds, 1, 0) # Prediction is now [C, N*H*W] 15 | one_hot_targets = one_hots[target.reshape(-1)] # Target is now [N*H*W, C] 16 | confusion_matrix = np.matmul(one_hot_preds, one_hot_targets).astype('i') # [C, N*H*W] x [N*H*W, C] = [C, C] 17 | # Consistency check: 18 | assert(np.sum(confusion_matrix) == target.size) # All elements summed equals all pixels in original target 19 | for i in range(num_classes): 20 | assert(np.sum(confusion_matrix[i]) == np.sum(prediction == i)) # Row of matrix equals class incidence in pred 21 | assert(np.sum(confusion_matrix[:, i]) == np.sum(target == i)) # Col of matrix equals class incidence in target 22 | if existing_matrix is not None: 23 | assert(existing_matrix.shape == confusion_matrix.shape) 24 | confusion_matrix += existing_matrix 25 | return confusion_matrix 26 | 27 | 28 | def normalise_confusion_matrix(matrix, mode): 29 | if mode == 'row': 30 | row_sums = matrix.sum(axis=1) 31 | row_sums[row_sums == 0] = 1 # to avoid division by 0. Safe, because if sum is 0, all elements have to be 0 too 32 | matrix = matrix / row_sums[:, np.newaxis] 33 | elif mode == 'col': 34 | col_sums = matrix.sum(axis=0) 35 | col_sums[col_sums == 0] = 1 # to avoid division by 0. 
Safe, because if sum is 0, all elements have to be 0 too 36 | matrix = matrix / col_sums[np.newaxis, :] 37 | else: 38 | raise ValueError("Normalise confusion matrix: mode needs to be either 'row' or 'col'.") 39 | return matrix 40 | 41 | 42 | def get_pixel_accuracy(confusion_matrix): 43 | """Pixel accuracies, adapted from https://github.com/CSAILVision/semantic-segmentation-pytorch 44 | 45 | :param confusion_matrix: Confusion matrix with absolute values. Rows are predicted classes, columns ground truths 46 | :return: Overall pixel accuracy, pixel accuracy per class (PA / PAC in CaDISv2 paper) 47 | """ 48 | pred_class_correct = np.diag(confusion_matrix) 49 | acc = np.sum(pred_class_correct) / np.sum(confusion_matrix) 50 | pred_class_sums = np.sum(confusion_matrix, axis=1) 51 | pred_class_sums[pred_class_sums == 0] = 1 # To avoid division by 0 problems. Safe because all elem = 0 when sum = 0 52 | acc_per_class = np.mean(pred_class_correct / pred_class_sums) 53 | return acc, acc_per_class 54 | 55 | 56 | def get_mean_iou(confusion_matrix, experiment, categories=False, single_class=None): 57 | """Uses confusion matrix to compute mean iou. Confusion matrix computed by get_confusion_matrix: row indexes 58 | prediction class, column indexes ground truth class. 
Based on: 59 | github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py 60 | """ 61 | assert experiment in [1, 2, 3], 'experiment must be in [1,2,3] instead got [{}]'.format(experiment) 62 | if single_class is not None: 63 | # compute miou for a single_class 64 | assert(not categories),\ 65 | 'when single_class is not None, category must be False instead got [{}]'.format(categories) 66 | assert(single_class in CLASS_INFO[experiment]),\ 67 | 'single_class must be {} instead got [{}]'.format(CLASS_INFO[experiment][1].keys(), single_class) 68 | return get_single_class_iou(confusion_matrix, experiment, single_class) 69 | elif categories: 70 | # compute miou for the classes of instruments and for the classes of anatomies 71 | # compute miou for all classes 72 | assert (single_class is None),\ 73 | 'when category is not None, single class must be None instead got [{}]'.format(single_class) 74 | miou_instruments = np.mean([get_single_class_iou(confusion_matrix, experiment, c) 75 | for c in CLASS_INFO[experiment][2]['instruments']]) 76 | miou_anatomies = np.mean([get_single_class_iou(confusion_matrix, experiment, c) 77 | for c in CLASS_INFO[experiment][2]['anatomies']]) 78 | miou = np.mean([get_single_class_iou(confusion_matrix, experiment, c) 79 | for c in CLASS_INFO[experiment][1].keys()]) 80 | return miou, miou_instruments, miou_anatomies 81 | else: 82 | # compute miou for all classes 83 | miou = np.mean([get_single_class_iou(confusion_matrix, experiment, c) 84 | for c in CLASS_INFO[experiment][1].keys()]) 85 | return miou 86 | 87 | 88 | def get_single_class_iou(confusion_matrix, experiment, single_class): 89 | if single_class == 255: # This is the 'ignore' class helpfully introduced in exp 2 and 3. 
Needs to NOT be 255 here 90 | single_class = confusion_matrix.shape[0] - 1 91 | # iou = tp/(tp + fp + fn) 92 | # the number of true positive pixels for this class 93 | # the entry on the diagonal of the confusion matrix 94 | tp = confusion_matrix[single_class, single_class] 95 | 96 | # the number of false negative pixels for this class 97 | # the column sum of the matching row in the confusion matrix 98 | # minus the diagonal entry 99 | fn = confusion_matrix[:, single_class].sum() - tp 100 | 101 | # the number of false positive pixels for this class 102 | # Only pixels that are not on a pixel with ground truth class that is ignored 103 | # The row sum of the corresponding row in the confusion matrix 104 | # without the ignored rows and without the actual label of interest 105 | not_ignored = [c for c in CLASS_INFO[experiment][1].keys() if not (c == 255 or c == single_class)] 106 | fp = confusion_matrix[single_class, not_ignored].sum() 107 | 108 | # the denominator of the IOU score 109 | denom = (tp + fp + fn) 110 | if denom == 0: 111 | # return float('nan') 112 | return 0 # Otherwise the mean always returns NaN which is technically correct but not so helpful 113 | # return IOU 114 | return float(tp) / denom 115 | -------------------------------------------------------------------------------- /utils/optimizer_utils.py: -------------------------------------------------------------------------------- 1 | from .logger import printlog 2 | 3 | def get_num_layer_stage_wise(var_name, num_max_layer): 4 | """Get the layer id to set the different learning rates in ``stage_wise`` 5 | decay_type. only for convnext series 6 | Args: 7 | var_name (str): The key of the model. 8 | num_max_layer (int): Maximum number of backbone layers. 9 | Returns: 10 | int: The id number corresponding to different learning rate in 11 | ``LearningRateDecayOptimizerConstructor``. 
12 | """ 13 | 14 | if var_name in ('backbone.cls_token', 'backbone.mask_token', 'backbone.pos_embed'): 15 | return 0 16 | elif var_name.startswith('backbone.downsample_layers'): 17 | return 0 18 | elif var_name.startswith('backbone.stages'): 19 | stage_id = int(var_name.split('.')[2]) 20 | return stage_id + 1 21 | else: 22 | return num_max_layer - 1 # this essentially means all layers beyond layers get lr = base_lr 23 | 24 | 25 | def is_in(param_group, param_group_list): 26 | # assert is_list_of(param_group_list, dict) 27 | param = set(param_group['params']) 28 | param_set = set() 29 | for group in param_group_list: 30 | param_set.update(set(group['params'])) 31 | return not param.isdisjoint(param_set) 32 | 33 | 34 | def get_param_groups_using_keys(model, config): 35 | # this mode specifies keys (strings) that if present in var's name place the var in a parameter group 36 | # each such key generates a new param_group where all variables share wd_mult, lr_mult 37 | base_lr = config['train']['learning_rate'] 38 | base_wd = config['train']['weight_decay'] 39 | params = [] 40 | parameter_groups = {} 41 | params_dict = dict(model.named_parameters()) 42 | for name, param in params_dict.items(): 43 | param_group = {'params': [param], "group_name": name, 'param_names': []} 44 | is_custom = (False, None) 45 | is_first_in_group = False 46 | if not param.requires_grad: 47 | params.append(param_group) 48 | continue 49 | if is_in(param_group, params): 50 | a = 1 51 | group_name = 'base_lr_wd' 52 | for custom_key in config['train']['opt_keys']: 53 | if custom_key in name: 54 | is_custom = (True, custom_key) 55 | lr_mult = config['train']['opt_keys'][custom_key].get('lr_mult', 1.0) 56 | wd_mult = config['train']['opt_keys'][custom_key].get('wd_mult', 1.0) 57 | param_group['lr'] = lr_mult * base_lr 58 | param_group['weight_decay'] = wd_mult * base_wd 59 | group_name = f'{custom_key}_lrm{lr_mult}_wdm{wd_mult}' 60 | break 61 | if not is_custom[0]: 62 | param_group['lr'] = base_lr 63 
| param_group['weight_decay'] = base_wd 64 | 65 | if group_name not in parameter_groups: 66 | param_group['group_name'] = group_name 67 | parameter_groups[group_name] = param_group 68 | is_first_in_group = True 69 | # parameter_groups[group_name]['param_names'] = [name] 70 | if not is_first_in_group: 71 | parameter_groups[group_name]['params'].append(param) 72 | parameter_groups[group_name]['param_names'].append(name) 73 | 74 | params.extend(parameter_groups.values()) 75 | # printlog(f'optimizer param groups : \n {params}') 76 | params_cnt = 0 77 | for g in params: 78 | params_cnt += len(g['param_names']) 79 | assert (len(params_dict) == params_cnt), f'mismatch between params in parameter groups {params_cnt}' \ 80 | f' and model.named_parameters {len(params_dict)}' 81 | return params 82 | 83 | def get_param_groups_with_stage_wise_lr_decay(model, config): 84 | # adapted from convnext repo 85 | # scales the learning rate of deeper layers by decay_rate ** (num_layers - layer_id - 1) 86 | # tl,dr --> latest layers have gradually higher lr 87 | assert 'ConvNext' in config['graph']['backbone'], f"stage_wise_lr currently only supported for " \ 88 | f"ConvNext backbones instead got {config['graph']['backbone']}" 89 | decay_rate = config['train']['stage_wise_lr']['decay_rate'] 90 | num_layers = config['train']['stage_wise_lr']['num_layers'] + 2 # todo +2 is still a mystery (?) 
91 | base_lr = config['train']['learning_rate'] 92 | base_wd = config['train']['weight_decay'] 93 | params = [] 94 | parameter_groups = {} 95 | params_dict = dict(model.named_parameters()) 96 | 97 | for name, param in params_dict.items(): 98 | if len(param.shape) == 1 or name.endswith('.bias') or name in ('pos_embed', 'cls_token'): 99 | # param.shape == 1 is here to ensure some layer-norm modules have 0 weight decay 100 | # despite not containing the word "norm" in their names 101 | # for convnext these are for e.x 'backbone.downsample_layers.0.1.weight' 102 | # or 'backbone.downsample_layers.0.1.bias' 103 | group_name = 'no_decay' 104 | this_weight_decay = 0.0 105 | # printlog(name, this_weight_decay) 106 | else: 107 | group_name = 'decay' 108 | this_weight_decay = base_wd 109 | layer_id = get_num_layer_stage_wise(name, num_layers) 110 | # logger.info(f'set param {name} as id {layer_id}') 111 | group_name = f'layer_{layer_id}_{group_name}' 112 | if group_name not in parameter_groups: 113 | scale = decay_rate ** (num_layers - layer_id - 1) # scale * base_lr is the learning rate for this group 114 | # printlog(group_name, scale) 115 | parameter_groups[group_name] = { 116 | 'weight_decay': this_weight_decay, 117 | 'params': [], 118 | 'param_names': [], 119 | 'lr_scale': scale, 120 | 'group_name': group_name, 121 | 'lr': scale * base_lr, 122 | } 123 | parameter_groups[group_name]['params'].append(param) 124 | parameter_groups[group_name]['param_names'].append(name) 125 | params.extend(parameter_groups.values()) 126 | # printlog(f'optimizer param groups : \n {params}') 127 | params_cnt = 0 128 | for g in params: 129 | params_cnt += len(g['param_names']) 130 | assert (len(params_dict) == params_cnt), f'mismatch between params in parameter groups {params_cnt}' \ 131 | f' and model.named_parameters {len(params_dict)}' 132 | return params 133 | -------------------------------------------------------------------------------- /utils/repeat_factor_sampling.py: 
-------------------------------------------------------------------------------- 1 | import pathlib 2 | import pandas as pd 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Sampler 6 | from utils import DATASETS_INFO, get_class_info, reverse_one_to_many_mapping 7 | from itertools import islice 8 | from torch.utils.data.distributed import DistributedSampler 9 | from .distributed import is_distributed, get_rank, get_world_size 10 | import math 11 | 12 | 13 | def get_class_repeat_factors_for_experiment(lbl_df: pd.DataFrame, 14 | repeat_thresh: float, 15 | exp: int, 16 | return_frequencies=False, 17 | dataset: str = 'CADIS'): 18 | 19 | experiment_cls = DATASETS_INFO[dataset].CLASS_INFO[exp][1] 20 | exp_mapping = DATASETS_INFO[dataset].CLASS_INFO[exp][0] 21 | rev_mapping = reverse_one_to_many_mapping(exp_mapping) 22 | canonical_cls = DATASETS_INFO[dataset].CLASS_NAMES[0] 23 | canonical_num_to_name = reverse_one_to_many_mapping(DATASETS_INFO[dataset].CLASS_INFO[0][1]) 24 | num_frames = lbl_df.shape[0] 25 | 26 | cls_freqs = dict() 27 | cls_rfs = dict() 28 | 29 | for c in canonical_cls: 30 | c_exp = rev_mapping[canonical_num_to_name[c]] # from canonical cls name to experiment num 31 | if c_exp not in cls_freqs.keys(): 32 | cls_freqs[c_exp] = 0 33 | s = lbl_df.loc[lbl_df[c] > 0].shape[0] 34 | cls_freqs[c_exp] += s / num_frames 35 | 36 | for c_exp in experiment_cls: 37 | if cls_freqs[c_exp] == 0: 38 | cls_freqs[c_exp] = repeat_thresh 39 | cls_rfs[c_exp] = np.maximum(1, np.sqrt(repeat_thresh / cls_freqs[c_exp])) 40 | cls_freqs = {k: v for k, v in sorted(cls_freqs.items(), reverse=True, key=lambda item: item[1])} 41 | cls_rfs = {k: v for k, v in sorted(cls_rfs.items(), reverse=True, key=lambda item: item[1])} 42 | if return_frequencies: 43 | return cls_freqs, cls_rfs 44 | else: 45 | return cls_rfs 46 | 47 | 48 | def get_image_repeat_factors_for_experiment(lbl_df: pd.DataFrame, cls_rfs: dict, exp: int, dataset: str): 49 | exp_mapping = 
DATASETS_INFO[dataset].CLASS_INFO[exp][0] 50 | rev_mapping = reverse_one_to_many_mapping(exp_mapping) # from canonical to experiment classes 51 | canonical_cls = DATASETS_INFO[dataset].CLASS_NAMES[0] 52 | canonical_num_to_name = reverse_one_to_many_mapping(DATASETS_INFO[dataset].CLASS_INFO[0][1]) # canonical class to num 53 | img_rfs = [] 54 | inds = [] 55 | for idx, row in lbl_df.iterrows(): # for each frame 56 | class_repeat_factors_in_frame = [] 57 | for c in canonical_cls: 58 | if row[c] > 0: 59 | class_repeat_factors_in_frame.append(cls_rfs[rev_mapping[canonical_num_to_name[c]]]) 60 | img_rfs.append(np.max(class_repeat_factors_in_frame)) 61 | inds.append(idx) 62 | return inds, img_rfs 63 | 64 | 65 | class RepeatFactorSampler(Sampler): 66 | def __init__(self, data_source: torch.utils.data.Dataset, dataframe: pd.DataFrame, 67 | repeat_thresh: float, experiment: int, split: int, blacklist=True, seed=None, dataset='CADIS'): 68 | """ Computes repeat factors and returns repeat factor sampler 69 | Note: this sampler always uses shuffling 70 | :param data_source: a torch dataset object 71 | :param dataframe: a dataframe with class occurrences as columns 72 | :param repeat_thresh: repeat factor threshold (intuitively: frequency below which rf kicks in) 73 | :param experiment: experiment id 74 | :param split: dataset split being used to determine repeat factors for each image in it.
75 | :param blacklist: whether blacklisting is to be applied 76 | :param seed: seeding for torch randomization 77 | :param dataset : todo does not support CTS currently 78 | :return RepeatFactorSampler object 79 | """ 80 | super().__init__(data_source=data_source) 81 | assert(0 <= repeat_thresh < 1 and split in [0, 1, 2]) 82 | seed = 1 if seed is None else seed 83 | self.seed = int(seed) 84 | self.shuffle = True # shuffling is always used with this sampler 85 | self.split = split 86 | self.repeat_thresh = repeat_thresh 87 | df = get_class_info(dataframe, 0, with_name=True) 88 | if blacklist: # drop blacklisted 89 | df = df.drop(df[df['blacklisted'] == 1].index) 90 | df.reset_index() 91 | self.class_repeat_factors, self.repeat_factors = \ 92 | self.repeat_factors_class_and_image_level(df, experiment, repeat_thresh, split, dataset) 93 | self._int_part = torch.trunc(self.repeat_factors) 94 | self._frac_part = self.repeat_factors - self._int_part 95 | self.g = torch.Generator() 96 | self.g.manual_seed(self.seed) 97 | self.epoch = 0 98 | self.indices = None 99 | self.distributed = is_distributed() # todo this should be removed in the future once local has ddp package 100 | 101 | self.num_replicas = get_world_size() 102 | self.rank = get_rank() 103 | print(f'RF sampler -- world_size: {self.num_replicas} rank : {self.rank}') 104 | self.dataset = data_source 105 | # if len(self.dataset) % self.num_replicas ==0: # type: ignore 106 | # self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore 107 | # else: 108 | # self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas) 109 | # self.total_size = self.num_samples * self.num_replicas 110 | 111 | @staticmethod 112 | def repeat_factors_class_and_image_level(df: pd.DataFrame, experiment: int, repeat_thresh: float, 113 | split: int, dataset: str): 114 | train_videos = DATASETS_INFO[dataset].DATA_SPLITS[split][0] 115 | train_df = df.loc[df['vid_num'].isin(train_videos)] 116
| train_df = train_df.reset_index() 117 | # For each class compute the class-level repeat factor: r(c) = max(1, sqrt(t/f(c)) where f(c) is class freq 118 | class_rfs = get_class_repeat_factors_for_experiment(train_df, repeat_thresh, experiment, dataset=dataset) 119 | # For each image I, compute the image-level repeat factor: r(I) = max_{c in I} r(c) 120 | inds, rfs = get_image_repeat_factors_for_experiment(train_df, class_rfs, experiment, dataset) 121 | return class_rfs, torch.tensor(rfs, dtype=torch.float32) 122 | 123 | def __iter__(self): 124 | if self.distributed: # todo this should be removed in the future 125 | start = get_rank() 126 | step = get_world_size() # 1 if not ddp 127 | # to debug 128 | # print(f'rank {get_rank()} -slicing start {start} step {step} ') 129 | print(f'rank {get_rank()} indices : {len([i for i in islice(self._yield_indices(), start, None, step)])}') 130 | yield from islice(self._yield_indices(), start, None, step) 131 | else: 132 | 133 | yield from islice(self._yield_indices(), 0, None, 1) 134 | 135 | def _yield_indices(self): 136 | if self.indices is not None: 137 | indices = torch.tensor(self.indices, dtype=torch.int64) 138 | else: 139 | indices = self._get_epoch_indices(self.g) 140 | ind_left = self.__len__() 141 | print(f'Indices generated {ind_left}, rank : {get_rank()}') 142 | self.g.manual_seed(self.seed + self.epoch) 143 | while ind_left > 0: 144 | # each epoch may have a slightly different size due to the stochastic rounding. 
145 | randperm = torch.randperm(len(indices), generator=self.g) # shuffling 146 | for item in indices[randperm]: 147 | # print(f'yielding : {item} rank : {get_rank()}') 148 | yield int(item) 149 | ind_left -= 1 150 | self.indices = None 151 | 152 | def __len__(self): 153 | if self.indices is not None: 154 | return len(self.indices) 155 | else: 156 | return len(self._get_epoch_indices(self.g)) 157 | 158 | def set_epoch(self, epoch): 159 | self.epoch = epoch 160 | 161 | def _get_epoch_indices(self, generator): 162 | # stochastic rounding so that the target repeat factor 163 | # is achieved in expectation over the course of training 164 | rands = torch.rand(len(self._frac_part), generator=generator) 165 | rounded_rep_factors = self._int_part + (rands < self._frac_part).float() 166 | indices = [] 167 | # replicate each image's index by its rounded repeat factor 168 | for img_index, rep_factor in enumerate(rounded_rep_factors): 169 | indices.extend([img_index] * int(rep_factor.item())) 170 | self.indices = indices 171 | if self.num_replicas>1: # self.distributed and 172 | # ensures each process has access to equal number of indices from the dataset 173 | self.num_indices = len(self.indices) 174 | if self.num_indices % self.num_replicas ==0: 175 | self.indices_per_processs = math.ceil(self.num_indices / self.num_replicas) 176 | else: 177 | self.indices_per_processs = math.ceil((self.num_indices - self.num_replicas) / self.num_replicas) 178 | 179 | self.num_indices_to_keep = self.indices_per_processs * self.num_replicas 180 | self.indices_to_keep = torch.randint(low=0, high=self.num_indices_to_keep-1, 181 | size=[self.num_indices_to_keep], 182 | generator=generator) 183 | 184 | # print(f'num_indices = {self.num_indices} - num_indices_to_keep = {self.self.num_indices_to_keep} - rank : {get_rank()}' ) 185 | return torch.tensor(indices, dtype=torch.int64)[self.indices_to_keep] 186 | 187 | return torch.tensor(indices, dtype=torch.int64) 188 | 189 | 190 | 191 | if __name__ == 
'__main__': 192 | inds = np.arange(1000).tolist() 193 | def dummy(start): 194 | yield from islice(inds, start, None, 4) 195 | a = [[i for i in dummy(start)] for start in range(4)] 196 | -------------------------------------------------------------------------------- /utils/semi_utis.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from collections import OrderedDict 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class BalancedConcatDataset(Dataset): 7 | def __init__(self, *datasets): 8 | self.datasets = datasets 9 | dataset_lengths = [len(d) for d in self.datasets] 10 | self.max_len = max(dataset_lengths) 11 | self.min_len = min(dataset_lengths) 12 | 13 | def __getitem__(self, i): 14 | # each item is a tuple of 1 unlabelled sample and 1 labelled sample 15 | v = [d[i % len(d)] for d in self.datasets] 16 | b = tuple(v) 17 | # b : b[0] = list containing dataset_1 img of shape (C,H,W), mask of shape (H,W), pseudo info of shape (,) 18 | # b : b[1] = list containing dataset_2 img of shape (C,H,W), mask of shape (H,W), pseudo info of shape (,) 19 | return b 20 | 21 | def __len__(self): 22 | # stop when the longest dataset runs out of samples 23 | return self.max_len 24 | 25 | 26 | def get_video_files_from_split(ids, debug=False): 27 | """ Takes a list of video ids (i.e. a split's train videos) and returns a list of 28 | the names of the corresponding mp4 files""" 29 | dicts = dict() 30 | dicts['train_1'] = [1, 2, 3, 4, 5, 6, 7, 8] if not debug else [1, 3, 6] 31 | dicts['train_2'] = [9, 10, 11, 12, 13, 14, 15, 16] 32 | dicts['train_3'] = [17, 18, 19, 20, 21, 22, 23, 24] 33 | dicts['train_4'] = [25] 34 | files = [] 35 | for i in ids: 36 | # s = "{0:0=1d}".format(i) 37 | s = "%02d" % i 38 | if i in dicts['train_1']: 39 | files.append(pathlib.Path('train_1') / pathlib.Path('train{}.mp4'.format(s))) 40 | elif i in dicts['train_2'] and not debug: 41 | files.append(pathlib.Path('train_2') /
pathlib.Path('train{}.mp4'.format(s))) 42 | elif i in dicts['train_3'] and not debug: 43 | files.append(pathlib.Path('train_3') / pathlib.Path('train{}.mp4'.format(s))) 44 | elif i in dicts['train_4'] and not debug: 45 | files.append(pathlib.Path('train_4') / pathlib.Path('train{}.mp4'.format(s))) 46 | return files 47 | 48 | 49 | def get_excluded_frames_from_df(df, train_videos): 50 | train = df.loc[df['vid_num'].isin(train_videos)] 51 | train.reset_index() 52 | train = train.reset_index() 53 | train = train.drop(train[train['blacklisted'] == 1].index) 54 | train = train.reset_index() 55 | img_vid_frames = train['img_path'] 56 | img_vid_frames = img_vid_frames.tolist() 57 | video_to_excluded_frames_dict = OrderedDict() 58 | for f in img_vid_frames: 59 | frame_id = int(f.split('.')[-2][-6:]) 60 | video_id = f.split('Video')[-1][0:2] if '_' not in f.split('Video')[-1][0:2] else f.split('Video')[-1][0] 61 | video_id = int(video_id) 62 | if video_id in video_to_excluded_frames_dict: 63 | video_to_excluded_frames_dict[video_id].append(frame_id) 64 | else: 65 | video_to_excluded_frames_dict[video_id] = [] 66 | video_to_excluded_frames_dict[video_id].append(frame_id) 67 | # sanity check 68 | assert(list(video_to_excluded_frames_dict.keys()) == train_videos) 69 | return video_to_excluded_frames_dict 70 | 71 | -------------------------------------------------------------------------------- /utils/torch_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms.functional as F 3 | from PIL import Image 4 | import random 5 | import math 6 | from torchvision.transforms import ToPILImage 7 | 8 | class ExtRandomScale(object): 9 | def __init__(self, scale_range, interpolation=Image.BILINEAR): 10 | self.scale_range = scale_range 11 | self.interpolation = interpolation 12 | 13 | def __call__(self, img, lbl): 14 | """ 15 | Args: 16 | img (PIL Image): Image to be scaled. 
17 | lbl (PIL Image): Label to be scaled. 18 | Returns: 19 | PIL Image: Rescaled image. 20 | PIL Image: Rescaled label. 21 | """ 22 | # assert img.size == lbl.size 23 | # scale = random.uniform(self.scale_range[0], self.scale_range[1]) 24 | w, h = img.size 25 | rand_log_scale = math.log(self.scale_range[0], 2) + random.random() * (math.log(self.scale_range[1], 2) - math.log(self.scale_range[0], 2)) 26 | random_scale = math.pow(2, rand_log_scale) 27 | new_size = (int(round(w * random_scale)), int(round(h * random_scale))) 28 | image = img.resize(new_size, Image.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter 29 | mask = lbl.resize(new_size, Image.NEAREST) 30 | return image, mask 31 | 32 | if __name__ == '__main__': 33 | h = 8*8 34 | w = 8*8 35 | B = 2 36 | I_ = 2*torch.eye(h, w).rot90() 37 | lbl = torch.ones(size=(h, w)) - torch.eye(h, w) + I_ 38 | x = torch.rand(size=(3, h, w)).float() # ToPILImage expects a (C, H, W) tensor 39 | scaler = ExtRandomScale([0.5,2]) 40 | for i in range(10): 41 | x_s, y_s = scaler(ToPILImage()(x), ToPILImage()(lbl)) 42 | print(x_s.size, y_s.size) 43 | 44 | -------------------------------------------------------------------------------- /utils/tsne_visualization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | from tsne_torch import TorchTSNE as TSNE 4 | from .utils import to_numpy 5 | import pathlib 6 | 7 | def test(): 8 | f = torch.rand(size=(100, 128)) 9 | f_ = TSNE(n_components=2, perplexity=30, n_iter=1000, verbose=True).fit_transform(f) 10 | l = f_.tolist() 11 | x,y = zip(*l) 12 | plt.scatter(x,y) 13 | 14 | 15 | class TsneMAnager(): 16 | def __init__(self, dataset, n_classes, feat_dim, run_id=None, scale=4): 17 | self.dataset = dataset 18 | self.n_classes = n_classes 19 | self.feats_per_class = 1000 20 | self.feat_dim = feat_dim 21 | self.feats = []#torch.zeros(size=(self.n_classes, self.feats_per_class, self.feat_dim)) 22 | self.labels = [] # class id per element of self.feats 23 | self.counts = [0] *
            self.n_classes
        self.scale = scale
        self.run_id = run_id if run_id is not None else 'tsne'

        # tsne settings
        self.perplexity = 30
        self.iters = 2000

    def accumulate(self, feats, labels):
        n, h, w = labels.shape
        assert n == 1
        # downsample the labels to the spatial resolution of the features
        lbl_down = torch.nn.functional.interpolate(labels.unsqueeze(1).float(), (h // self.scale, w // self.scale),
                                                   mode='nearest').long()
        _, _, h, w = lbl_down.shape
        lbl_down = lbl_down.view(-1)
        # feats1 = feats.view(self.feat_dim, h*w)
        feats = feats.squeeze().view(self.feat_dim, -1) # self.feat_dim, h*w
        if self.dataset == 'CITYSCAPES':
            for cl in range(self.n_classes):
                # features sampled per image: feats_per_class // 500 for classes 0-14, 10 for the rest
                views_per_class = (self.feats_per_class // 500) if cl < 15 else 10
                if self.counts[cl] < self.feats_per_class:
                    # squeeze(1) keeps a 1-D index tensor even when a single pixel matches,
                    # so classes with exactly one pixel are sampled instead of skipped
                    indices_from_cl = (lbl_down == cl).nonzero().squeeze(1)
                    random_permutation = torch.randperm(indices_from_cl.shape[0])
                    this_views_per_class = min(views_per_class, indices_from_cl.shape[0])
                    if this_views_per_class > 0:
                        sampled_indices_from_cl = indices_from_cl[random_permutation[:this_views_per_class]]
                        self.feats.append(feats[:, sampled_indices_from_cl].T)
                        self.labels += [cl] * this_views_per_class # class id per element of self.feats
                        self.counts[cl] += this_views_per_class
                        # print(f'class {cl} added {this_views_per_class} {len(indices_from_cl)} feats resulting in counts {self.counts[cl]}')
                else:
                    print(f'class {cl} with counts {self.counts[cl]} is done')
        else:
            raise NotImplementedError()

    def compute(self, log_dir):
        f = torch.cat(self.feats)
        f_tsne = TSNE(n_components=2, perplexity=self.perplexity, n_iter=self.iters, verbose=True).fit_transform(f)
        l = f_tsne.tolist()
        x, y = zip(*l)
        # for colours look here https://matplotlib.org/3.5.0/gallery/color/named_colors.html
        cmap = {0: "red", 1: "green", 2: "blue", 3: "yellow", 4: "pink", 5: "black", 6:
                "orange", 7: "purple",
                8: "beige", 9: "brown", 10: "gray", 11: "cyan", 12: "magenta", 13: "hotpink", 14: "darkviolet", 15: "mediumblue",
                16: "lightsteelblue", 17: "gold", 18: "maroon"}
        colors = [cmap[l] for l in self.labels]
        fig = plt.scatter(x, y, c=colors, label=self.labels)
        plt.savefig(str(pathlib.Path(log_dir)/pathlib.Path(
            f'{self.run_id}_perp-{self.perplexity}_its-{self.iters}_feats-per-class-{self.feats_per_class}_scale{self.scale}.png')))
        print(f'counts: {[(i, c) for i, c in enumerate(self.counts)]}')
        return f_tsne



# def get_tsne_embedddings_ms(feats_ms, labels, scale, dataset):
#     n, h, w = labels.shape
#     assert n == 1
#     lbl_down = torch.nn.functional.interpolate(labels.unsqueeze(1).float(), (h//scale, w//scale), mode='nearest').long()
#     assert isinstance(feats_ms, list)
#     if isinstance(feats_ms, list) or isinstance(feats_ms, tuple):
#         for f in feats_ms:
#             get_tsne_embedddings(f, labels, scale, dataset)

#
#
# def get_tsne_embedddings(feats, labels, scale, dataset):
#     print(feats.shape, labels.shape)
#     n, h, w = labels.shape
#     assert n == 1
#     lbl_down = torch.nn.functional.interpolate(labels.unsqueeze(1).float(), (h//scale, w//scale), mode='nearest').long()
#     c = feats.shape[1] # feature space dimensionality
#     feats = feats.view(h*w, c)
#     lbl_down = lbl_down.view(h*w)
#     if dataset == 'CITYSCAPES':
#         print('computing tsne for CITYSCAPES')
#
#     return 0


--------------------------------------------------------------------------------