├── README.md
├── configs
│   ├── ADE20K
│   │   ├── hrnetocr_contrastive_ADE20K.json
│   │   └── upnswin_contrastive_ADE20K.json
│   ├── CITYSCAPES
│   │   └── hrnet_contrastive_CTS.json
│   ├── __init__.py
│   └── path_info.json
├── data
│   ├── data.csv
│   └── data.pkl
├── datasets
│   ├── ADE20K.py
│   ├── CaDIS.py
│   ├── Cityscapes.py
│   ├── Dataset_from_df.py
│   ├── PascalC.py
│   └── __init__.py
├── env_dgx.yml
├── losses
│   ├── DenseContrastiveLossV2.py
│   ├── DenseContrastiveLossV2_ms.py
│   ├── LossWrapper.py
│   ├── LovaszSoftmax.py
│   ├── TwoScaleLoss.py
│   └── __init__.py
├── main.py
├── managers
│   ├── BaseManager.py
│   ├── DeepLabv3_Manager.py
│   ├── HRNet_Manager.py
│   ├── LoggingManager.py
│   ├── OCRNet_Manager.py
│   └── __init__.py
├── misc
│   └── figs
│       └── fig1-01-01.png
├── models
│   ├── DeepLabv3.py
│   ├── HRNet.py
│   ├── OCR.py
│   ├── Projector.py
│   ├── Swin.py
│   ├── TTAWrapperSlide.py
│   ├── TTA_wrapper.py
│   ├── TTA_wrapper_CTS.py
│   ├── TTA_wrapper_PC.py
│   ├── Transformers.py
│   ├── UPerNet.py
│   ├── __init__.py
│   └── hrnet_config.py
└── utils
    ├── __init__.py
    ├── checkpoint_utils.py
    ├── config_parsers.py
    ├── datasets_info
    │   ├── ADE20K.py
    │   ├── CADIS.py
    │   ├── CITYSCAPES.py
    │   ├── PASCALC.py
    │   └── __init__.py
    ├── defaults.py
    ├── df_from_data.py
    ├── distributed.py
    ├── logger.py
    ├── lr_functions.py
    ├── metrics.py
    ├── np_transforms.py
    ├── optimizer_utils.py
    ├── pointrend_utils.py
    ├── repeat_factor_sampling.py
    ├── semi_utis.py
    ├── torch_transforms.py
    ├── torch_utils.py
    ├── transforms.py
    ├── tsne_visualization.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # ECCV2022 Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation
2 | Implementation of "Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation", to appear at ECCV 2022
3 |
4 | arxiv link : https://arxiv.org/abs/2203.13409
5 |
6 |
7 | 
8 |
9 | > [**Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation**](https://arxiv.org/abs/2203.13409),
10 | > [Theodoros Pissas](https://rvim.online/author/theodoros-pissas/), [Claudio S. Ravasio](https://rvim.online/author/claudio-ravasio/), [Lyndon Da Cruz](), [Christos Bergeles](https://rvim.online/author/christos-bergeles/)
11 | >
12 | > *arXiv technical report ([arXiv 2203.13409](https://arxiv.org/abs/2203.13409))*
13 | >
14 | > *ECCV 2022 ([proceedings]())*
15 |
16 | ## Log
17 | - 20/07 loss code public
18 | - Coming soon: pretrained checkpoints and more configs for more models
19 |
20 |
21 | ## Data and requirements
22 | 1) Download datasets
23 | 2) Modify the paths in configs/path_info.json to match your setup, adding a data_path and a log_path for each user/dataset key (see the example in path_info.json and the sketch below this list)
24 |    a) data_path should be the root directory of the datasets
25 |    b) log_path is where each run will create a directory for its logs and checkpoints
26 | 3) Create the conda environment (PyTorch with CUDA 10.1, as specified in env_dgx.yml)
27 | ```bash
28 | conda env create -f env_dgx.yml
29 | conda activate semseg
30 | ```
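Referring to step 2 above, a minimal sketch to check that a path_info.json entry resolves on your machine (assumes the repository root as working directory and the `theo_CTS` key from the provided example; adapt the key to your own user/dataset):
```python
import json
from pathlib import Path

# Each key in configs/path_info.json maps to [data_path, log_path],
# e.g. "theo_CTS" or "theo_ADE20K" in the provided example.
with open("configs/path_info.json") as f:
    paths = json.load(f)

data_path, log_path = paths["theo_CTS"]
print("data_path exists:", Path(data_path).is_dir())
print("log_path exists:", Path(log_path).is_dir())
```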
31 |
32 | ## Train
33 | To train a model, most settings are specified through the json configuration files found in ```configs```.
34 | Each model/dataset combination uses its own config. A few settings are specified from the command line,
35 | and config settings can also be overridden from the command line (see main.py).
36 | Below are commands to start training on 4 GPUs with the settings used in the paper.
37 |
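For orientation, each config bundles the graph, data, loss and optimisation settings; a minimal sketch of inspecting one of the provided configs (assumes the repository root as working directory):
```python
import json

# Load one of the provided configs and print the loss terms it combines,
# together with the per-scale weights of the multi-scale contrastive loss.
with open("configs/CITYSCAPES/hrnet_contrastive_CTS.json") as f:
    cfg = json.load(f)

print(cfg["graph"]["model"], cfg["graph"]["backbone"])  # HRNet hrnet48
print(cfg["loss"]["losses"])   # {'CrossEntropyLoss': 1, 'DenseContrastiveLossV2_ms': 0.1}
print(cfg["loss"]["weights"])  # [1, 0.7, 0.4, 0.1]
```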
38 | Training with ResNet or HRNet backbones requires ImageNet initialisation, which is handled by torchvision or downloaded from a url, respectively.
39 | To train with Swin backbones we use the ImageNet checkpoints provided by the official implementation: https://github.com/microsoft/Swin-Transformer/.
40 | These must be downloaded into a directory called pytorch_checkpoints, structured as follows:
41 |
42 | ```
43 | pytorch_checkpoints/swin_imagenet/swin_tiny_patch4_window7_224.pth
44 | /swin_small_patch4_window7_224.pth
45 | /swin_base_patch4_window7_224.pth
46 | /swin_large_patch4_window7_224_22k.pth
47 | ```
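A minimal sketch to verify the Swin checkpoints are in place before launching training (assumes pytorch_checkpoints is in the current working directory):
```python
from pathlib import Path

# Expected Swin ImageNet checkpoints, as listed above.
ckpt_dir = Path("pytorch_checkpoints/swin_imagenet")
expected = [
    "swin_tiny_patch4_window7_224.pth",
    "swin_small_patch4_window7_224.pth",
    "swin_base_patch4_window7_224.pth",
    "swin_large_patch4_window7_224_22k.pth",
]
missing = [name for name in expected if not (ckpt_dir / name).is_file()]
print("all Swin checkpoints found" if not missing else f"missing: {missing}")
```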
48 | Example commands to start training (-d = CUDA device ids, -p = multi-GPU training, -bs = batch size, -w = workers per GPU):
49 | - For HRNet on Cityscapes:
50 | ```bash
51 | python main.py -d 0,1,2,3 -p -u theo -c configs/CITYSCAPES/hrnet_contrastive_CTS.json -bs 12 -w 3
52 | ```
53 | - For UPerNet SwinT on ADE20K:
54 | ```bash
55 | python main.py -d 0,1,2,3 -p -u theo -c configs/ADE20K/upnswin_contrastive_ADE20K.json -bs 16 -w 4
56 | ```
57 |
58 | [//]: # (## Run a pretrained model)
59 |
60 | [//]: # (- Example of how to run inference with pretrained model:)
61 |
62 | [//]: # ( ```bash)
63 |
64 | [//]: # ( python main.py -d 0 -u theo -c configs/ADE20K/upnswin_contrastive_ADE20K.json -bs 1 -w 4 -m inference -cpt 20220303_230257_e1__upn_alignFalse_projFpn_swinT_sbn_DCms_cs_epochs127_bs16 -so)
65 |
66 | [//]: # ( ```)
67 |
68 | ## Licensing and copyright
69 |
70 | Please see the LICENSE file for details.
71 |
72 | ## Acknowledgements
73 |
74 | This project utilizes [timm] and the official implementation of [swin] Transformer.
75 | We thank the authors of those projects for open-sourcing their code and model weights.
76 |
77 | [timm]: https://github.com/rwightman/pytorch-image-models
78 |
79 | [swin]: https://github.com/microsoft/Swin-Transformer/
80 |
81 | ## Citation
82 | If you found the paper or code useful, please cite the following:
83 |
84 | ```
85 | @misc{https://doi.org/10.48550/arxiv.2203.13409,
86 | doi = {10.48550/ARXIV.2203.13409},
87 | url = {https://arxiv.org/abs/2203.13409},
88 | author = {Pissas, Theodoros and Ravasio, Claudio S. and Da Cruz, Lyndon and Bergeles, Christos},
89 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
90 | title = {Multi-scale and Cross-scale Contrastive Learning for Semantic Segmentation},
91 | publisher = {arXiv},
92 | year = {2022},
93 | copyright = {Creative Commons Attribution Non Commercial Share Alike 4.0 International}
94 | }
95 | ```
96 |
--------------------------------------------------------------------------------
/configs/ADE20K/hrnetocr_contrastive_ADE20K.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "hrnOCR",
3 | "mode": "training",
4 | "manager": "OCRNet",
5 | "graph": {
6 | "model": "OCRNet",
7 | "backbone": "hrnet48",
8 | "sync_bn":true,
9 | "out_stride": 4,
10 | "pretrained": true,
11 | "align_corners": true,
12 | "ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": true, "before_context": true}
13 | },
14 |
15 | "load_last": true,
16 | "tta":true,
17 | "tta_scales": [0.5, 0.75, 1.25, 1.5, 1.75],
18 | "load_checkpoint_": "no",
19 | "run_final_val": false,
20 |
21 | "data": {
22 | "num_workers":8,
23 | "dataset": "ADE20K",
24 | "use_relabeled": false,
25 | "blacklist": false,
26 | "experiment": 1,
27 | "split": ["train", "val"],
28 | "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"],
29 | "transform_values": {"crop_shape": [512, 512], "crop_class_max_ratio": 0.75,
30 | "scale_range": [0.5, 2]},
31 | "transforms_val": ["torchvision_normalise"],
32 | "transform_values_val": {"min_side_length": 512,
33 | "crop_class_max_ratio": 0.75,
34 | "fit_stride_val": 32},
35 | "batch_size": 16
36 | },
37 |
38 | "loss": {
39 | "name": "LossWrapper",
40 | "label_scaling_mode": "nn",
41 | "dominant_mode": "all",
42 | "temperature": 0.1,
43 | "cross_scale_contrast": true,
44 | "weights": [1, 0.7, 0.4, 0.1],
45 | "scales": 4,
46 | "interm": {"name": "CrossEntropyLoss", "args": [], "weight": 0.4},
47 | "final": {"name": "CrossEntropyLoss", "args": [], "weight": 1.0},
48 | "losses": {"TwoScaleLoss": 1, "DenseContrastiveLossV2_ms": 0.1},
49 |
50 | "losses___": {"TwoScaleLoss": 1},
51 |
52 | "min_views_per_class": 5,
53 | "max_views_per_class": 2500,
54 | "max_features_total": 10000
55 | },
56 |
57 | "train": {
58 | "learning_rate": 0.02,
59 | "lr_fct": "polynomial",
60 | "optim": "SGD",
61 | "lr_batchwise": true,
62 | "epochs": 120,
63 | "momentum": 0.9,
64 | "weight_decay": 0.0001
65 | },
66 | "max_valid_imgs": 2,
67 | "valid_freq": 10,
68 | "log_every_n_epochs": 25,
69 | "cuda": true,
70 | "gpu_device": 0,
71 | "parallel": false,
72 | "seed": 100
73 | }
--------------------------------------------------------------------------------
/configs/ADE20K/upnswin_contrastive_ADE20K.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "upn",
3 | "mode": "training",
4 | "manager": "OCRNet",
5 | "graph": {
6 | "model": "UPerNet",
7 | "backbone": "swinT",
8 | "sync_bn":true,
9 | "out_stride": 32,
10 | "pretrained": false,
11 | "align_corners": false,
12 | "aux_head":{"in_index": 3, "dropout_rate": 0.1},
13 | "dropout_rate" : 0.1,
14 | "ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": true, "position":"fpn"}
15 | },
16 |
17 | "load_last": false,
18 | "tta":false,
19 | "tta_scales": [0.5, 0.75, 1.25, 1.5, 1.75],
20 |
21 | "data": {
22 | "num_workers":0,
23 | "dataset": "ADE20K",
24 | "use_relabeled": false,
25 | "blacklist": false,
26 | "experiment": 1,
27 | "split": ["train", "val"],
28 | "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"],
29 | "transform_values": {"crop_shape": [512, 512], "crop_class_max_ratio": 0.75,
30 | "scale_range": [0.5, 2]},
31 | "transforms_val": ["resize_val", "torchvision_normalise"],
32 | "transform_values_val": {"min_side_length": 512,
33 | "crop_class_max_ratio": 0.75,
34 | "fit_stride_val": 32},
35 | "batch_size": 16
36 | },
37 |
38 | "loss": {
39 | "name": "LossWrapper",
40 | "label_scaling_mode": "nn",
41 | "dominant_mode": "all",
42 | "temperature": 0.1,
43 | "cross_scale_contrast": true,
44 | "weights": [1.0, 0.7, 0.4, 0.1],
45 | "scales": 4,
46 | "interm": {"name": "CrossEntropyLoss", "args": [], "weight": 0.4},
47 | "final": {"name": "CrossEntropyLoss", "args": [], "weight": 1.0},
48 | "losses": {"TwoScaleLoss": 1.0, "DenseContrastiveLossV2_ms": 0.1},
49 | "losses__": {"TwoScaleLoss": 1.0},
50 | "min_views_per_class": 5,
51 | "max_views_per_class": 2500,
52 | "max_features_total": 10000
53 | },
54 |
55 | "train": {
56 | "lr_batchwise": true,
57 | "learning_rate": 0.00006,
58 | "lr_fct": "linear-warmup-polynomial",
59 | "lr_params": {"power": 1.0,
60 | "warmup_iters": 1500,
61 | "warmup_rate": 1e-6 ,
62 | "min_lr": 0.0},
63 | "optim": "AdamW",
64 | "epochs": 127,
65 | "epochs_bs12": 648,
66 | "momentum": 0.9,
67 | "betas": [0.9, 0.999],
68 | "weight_decay": 0.01,
69 | "opt_keys":{"absolute_pos_embed": {"wd_mult":0.0},
70 | "norm": {"wd_mult":0.0},
71 | "relative_position_bias_table":{"wd_mult": 0.0}}
72 | },
73 | "max_valid_imgs": 2,
74 | "valid_freq": 10,
75 | "log_every_n_epochs": 10,
76 | "cuda": true,
77 | "gpu_device": 0,
78 | "parallel": false,
79 | "seed": 100
80 | }
--------------------------------------------------------------------------------
/configs/CITYSCAPES/hrnet_contrastive_CTS.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "hrn",
3 | "mode": "training",
4 | "manager": "HRNet",
5 | "graph": {
6 | "model": "HRNet",
7 | "backbone": "hrnet48",
8 | "sync_bn":true,
9 | "out_stride": 4,
10 | "pretrained": false,
11 | "align_corners": true,
12 | "ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": true, "before_context": true}
13 | },
14 |
15 | "load_last": true,
16 | "tta":true,
17 | "tta_scales": [0.75, 1.25, 1.5, 1.75, 2],
18 | "run_final_val": false,
19 |
20 | "data": {
21 | "num_workers":8,
22 | "dataset": "CITYSCAPES",
23 | "use_relabeled": false,
24 | "blacklist": false,
25 | "experiment": 1,
26 | "split": ["train", "val"],
27 | "transforms": ["flip", "random_scale", "RandomCropImgLbl", "colorjitter", "torchvision_normalise"],
28 | "transform_values": {"crop_shape": [512, 1024], "crop_class_max_ratio": 0.75,
29 | "scale_range": [0.5, 2]},
30 | "transforms_val": ["torchvision_normalise"],
31 | "transform_values_val": {},
32 | "batch_size": 12
33 | },
34 |
35 | "loss": {
36 | "name": "LossWrapper",
37 | "label_scaling_mode": "nn",
38 | "dominant_mode": "all",
39 | "temperature": 0.1,
40 | "cross_scale_contrast": true,
41 | "weights": [1, 0.7, 0.4, 0.1],
42 | "scales": 4,
43 | "losses": {"CrossEntropyLoss": 1,"DenseContrastiveLossV2_ms": 0.1},
44 | "losses___": {"CrossEntropyLoss": 1},
45 | "min_views_per_class": 5,
46 | "max_views_per_class": 2500,
47 | "max_features_total": 10000
48 | },
49 | "train": {
50 | "learning_rate": 0.01,
51 | "lr_fct": "polynomial",
52 | "optim": "SGD",
53 | "lr_batchwise": true,
54 | "epochs": 484,
55 | "momentum": 0.9,
56 | "wd": 0.0005
57 | },
58 | "valid_batch_size": 1,
59 | "max_valid_imgs":2,
60 | "valid_freq": 100,
61 | "log_every_n_epochs": 100,
62 | "cuda": true,
63 | "gpu_device": 0,
64 | "parallel": false,
65 | "seed": 0
66 | }
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/c511fbcde6ac53b72d663225bdf6dded022ca1ce/configs/__init__.py
--------------------------------------------------------------------------------
/configs/path_info.json:
--------------------------------------------------------------------------------
1 |
2 | {
3 | "theo_CTS": [
4 | "D:\\datasets\\CITYSCAPES",
5 | "D:\\datasets\\CITYSCAPES\\logs"
6 | ],
7 |
8 | "theo_ADE20K": [
9 | "D:\\datasets\\ADEChallengeData2016",
10 | "D:\\datasets\\ADEChallengeData2016\\logs"
11 | ]
12 | }
13 |
--------------------------------------------------------------------------------
/data/data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/c511fbcde6ac53b72d663225bdf6dded022ca1ce/data/data.pkl
--------------------------------------------------------------------------------
/datasets/ADE20K.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import namedtuple
4 | from typing import Union
5 | from torch.utils.data.dataset import Dataset
6 | from torchvision.transforms import Compose, ToPILImage
7 | from PIL import Image, ImageFile
8 | from utils import DATASETS_INFO, remap_mask, printlog
9 | import numpy as np
10 | ImageFile.LOAD_TRUNCATED_IMAGES = True
11 | import pathlib
12 |
13 | from utils import DATASETS_INFO, remap_mask, printlog, mask_to_colormap, get_remapped_colormap
14 |
15 |
16 |
17 | class ADE20K(Dataset):
18 |
19 | def __init__(self, root, transforms_dict, split:Union[str,list]='train', debug=False):
20 | """
21 |
22 |         :param root: path to the ADE20K root dir (i.e. where the "ADEChallengeData2016" directory is located)
23 |         :param transforms_dict: see dataset_from_df.py
24 |         :param split: any of "train", "test", "val", or the list ["train", "val"] to train on train+val
25 |         :param debug: if True, prints per-sample label values, tensor shapes and the filename in __getitem__
26 |
27 | """
28 |
29 |
30 | super(ADE20K, self).__init__()
31 | self.root = root
32 | self.common_transforms = Compose(transforms_dict['common'])
33 | self.img_transforms = Compose(transforms_dict['img'])
34 | self.lbl_transforms = Compose(transforms_dict['lbl'])
35 | # assert(mode in ("fine", "coarse"))
36 | valid_splits = ["train", "test", "val", ['train', 'val']]
37 | assert (split in valid_splits), f'split {split} is not in valid_modes {valid_splits}'
38 | # self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
39 | self.split = split # "train", "test", "val"
40 | self.debug = debug
41 | # self.target_type = target_type
42 | self.images = []
43 | self.targets = []
44 | # this can only take the following values so hardcoded
45 | self.dataset = 'ADE20K'
46 | self.experiment = 1
47 | self.img_suffix = '.jpg'
48 | self.target_suffix = '.png'
49 |
50 |
51 | if self.split == ['train', 'val']:
52 | # for training on train + val
53 | printlog('train set is train+val splits')
54 | for i, s in enumerate(self.split):
55 | self.images_dir = os.path.join(self.root, 'ADEChallengeData2016', 'images', s)
56 | self.targets_dir = os.path.join(self.root, 'ADEChallengeData2016', 'annotations', s)
57 | for image_filename in os.listdir(self.images_dir):
58 | img_path = os.path.join(self.images_dir, image_filename)
59 | target_path = os.path.join(self.targets_dir, image_filename.split(self.img_suffix)[-2] + self.target_suffix)
60 | self.images.append(img_path)
61 | self.targets.append(target_path)
62 | assert (pathlib.Path(self.images[-1]).exists() and pathlib.Path(self.targets[-1]).exists())
63 | assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem)
64 |
65 | elif self.split == 'test':
66 | self.images_dir = os.path.join(self.root, 'ADEChallengeData2016', 'images', split)
67 | self.targets_dir = os.path.join(self.root,'ADEChallengeData2016', 'annotations', 'train') # dummy
68 | targets_dummy = os.listdir(self.targets_dir)
69 | for n, image_filename in enumerate(os.listdir(self.images_dir)):
70 | img_path = os.path.join(self.images_dir, image_filename)
71 | target_path = os.path.join(self.targets_dir, targets_dummy[n])
72 | self.images.append(img_path)
73 | self.targets.append(target_path)
74 | assert (pathlib.Path(self.images[-1]).exists()) # and pathlib.Path(self.targets[-1]).exists())
75 | # assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem)
76 | else:
77 | self.images_dir = os.path.join(self.root, 'ADEChallengeData2016', 'images', split)
78 | self.targets_dir = os.path.join(self.root, 'ADEChallengeData2016', 'annotations', split)
79 | for image_filename in os.listdir(self.images_dir):
80 | img_path = os.path.join(self.images_dir, image_filename)
81 | target_path = os.path.join(self.targets_dir, image_filename.split(self.img_suffix)[-2] + self.target_suffix)
82 | self.images.append(img_path)
83 | self.targets.append(target_path)
84 | assert (pathlib.Path(self.images[-1]).exists() and pathlib.Path(self.targets[-1]).exists())
85 | assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem)
86 | printlog(f'ade20k data all found split = {self.split}, images {len(self.images)}, targets {len(self.targets)}')
87 |
88 | self.return_filename = False
89 |
90 |     def __getitem__(self, index):
91 | """
92 | Args:
93 | index (int): Index
94 | Returns:
95 |             tuple: (img_tensor, lbl_tensor, metadata): the transformed image, the segmentation target
96 |                 remapped to the experiment label set, and a metadata dict with the sample index (and filenames if return_filename is set).
97 | """
98 |
99 | image = Image.open(self.images[index]).convert('RGB')
100 | metadata = {'index': index}
101 |
102 | if self.split == 'test':
103 | target = remap_mask(np.ones(shape=np.array(image).shape[0:2], dtype=np.int32),
104 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32')
105 |
106 | else:
107 | target = Image.open(self.targets[index])
108 | target = remap_mask(np.array(target),
109 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32')
110 | # print(index, ': ', np.unique(target))
111 | target = Image.fromarray(target)
112 | # if 14 in np.unique(target).tolist(): # tracl
113 | # target.show()
114 | # target.close()
115 | # return 0 , 0 , 0
116 |
117 | image, target, metadata = self.common_transforms((image, target, metadata))
118 | img_tensor = self.img_transforms(image)
119 | lbl_tensor = self.lbl_transforms(target).squeeze()
120 |
121 | if self.return_filename:
122 | metadata.update({'img_filename': self.images[index],
123 | 'target_filename': self.targets[index]})
124 |
125 | if self.debug:
126 | # ToPILImage()(img_tensor).show()
127 | # ToPILImage()(lbl_tensor).show()
128 | # debug_lbl = mask_to_colormap(to_numpy(lbl_tensor),
129 | # get_remapped_colormap(
130 | # DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0],
131 | # self.dataset),
132 | # from_network=True, experiment=self.experiment,
133 | # dataset=self.dataset)[..., ::-1]
134 | #
135 | #
136 | #
137 | # fn = metadata['target_filename'].split('\\')[-1]
138 | # p = pathlib.Path(r'C:\Users\Theodoros Pissas\Documents\tresorit\ADEChallengeData2016\ADEChallengeData2016\visuals\val/')
139 | # p1 = pathlib.Path(f'{fn}')
140 | # # ToPILImage()(lbl_tensor).save(f"{str(p/p1)}")
141 | #
142 | # cv2.imwrite(f"{str(p/p1)}", debug_lbl)
143 | print(f'\nafter aug index : {np.unique(lbl_tensor)} lbl {lbl_tensor.shape} image {img_tensor.shape} fname:{self.images[index]}')
144 | return img_tensor, lbl_tensor, metadata
145 |
146 | def __len__(self):
147 | return len(self.images)
148 |
149 | if __name__ == '__main__':
150 | import pathlib
151 | import torch
152 | from utils import parse_transform_lists
153 | import json
154 | import cv2
155 | from torch.nn import functional as F
156 | from utils import Pad, RandomResize, RandomCropImgLbl, Resize, FlipNP, to_numpy, pil_plot_tensor, to_comb_image
157 | from torchvision.transforms import ToTensor
158 | import PIL.Image as Image
159 |
160 | data_path = 'C:\\Users\\Theodoros Pissas\\Documents\\tresorit\\ADEChallengeData2016\\'
161 | d = {"dataset":'ADE20K', "experiment":1}
162 | path_to_config = '../configs/ADE20K/upnswin_contrastive_ADE20K.json'
163 | with open(path_to_config, 'r') as f:
164 | config = json.load(f)
165 |
166 | transforms_list = config['data']['transforms']
167 | transforms_values = config['data']['transform_values']
168 | if 'torchvision_normalise' in transforms_list:
169 | del transforms_list[-1]
170 |
171 | transforms_dict = parse_transform_lists(transforms_list, transforms_values, **d)
172 | transforms_list_val = config['data']['transforms_val']
173 | transforms_values_val = config['data']['transform_values_val']
174 |
175 | if 'torchvision_normalise' in transforms_list_val:
176 | del transforms_list_val[-1]
177 | transforms_dict_val = parse_transform_lists(transforms_list_val, transforms_values_val, **d)
178 | del transforms_list_val[0]
179 | train_set = ADE20K(root=data_path,
180 | debug=True,
181 | split=['train', 'val'],
182 | transforms_dict=transforms_dict)
183 | valid_set = ADE20K(root=data_path,
184 | debug=True,
185 | split='test',
186 | transforms_dict=transforms_dict_val)
187 |
188 | issues = []
189 | valid_set.return_filename = True
190 | train_set.return_filename = True
191 | hs=[]
192 | ws = []
193 | for ret in valid_set:
194 | hs.append(ret[0].shape[1])
195 | ws.append(ret[0].shape[2])
196 | present_classes = torch.unique(ret[1])
197 | print(ret[-1])
198 | # elif 15 in present_classes:
199 | # issues.append([ret[-1], present_classes])
200 | # print('bus found !!!! ')
201 | # print(present_classes)
202 | # pil_plot_tensor(ret[0], is_rgb=True)
203 | # pil_plot_tensor(ret[1], is_rgb=False)
204 |
205 | # a = 1
206 | # print(max(hs))
207 | # print(max(ws))
--------------------------------------------------------------------------------
/datasets/CaDIS.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from utils import DATASETS_INFO, printlog
3 | import pathlib
4 |
5 |
6 | def get_cadis_dataframes(config: dict):
7 | # Make dataframes for the training and the validation set
8 | assert 'data' in config
9 | dataset = config['data']['dataset']
10 | assert dataset == 'CADIS', f'dataset must be CADIS instead got {dataset}'
11 | df = pd.read_csv('data/data.csv') # todo this should be moved to data dir
12 |
13 | if 'random_split' in config['data']:
14 | print("***Legacy mode: random split of all data used, instead of split of videos!***")
15 | train = df.sample(frac=config['data']['random_split'][0]).copy()
16 | valid = df.drop(train.index).copy()
17 | split_of_rest = config['data']['random_split'][1] / (1 - config['data']['random_split'][0])
18 | valid = valid.sample(frac=split_of_rest)
19 | else:
20 | splits = DATASETS_INFO[dataset].DATA_SPLITS[int(config['data']['split'])]
21 | if len(splits) == 3:
22 | printlog("using train-val-test split")
23 | train_videos, valid_videos, test_videos = splits
24 | if config['mode'] == 'infer':
25 | printlog(f"CADIS with mode {config['mode']}")
26 |                 printlog("going to use test_videos as the validation set")
27 | valid_videos = test_videos
28 | elif len(splits) == 2:
29 | printlog("using train-merged[valtest] split")
30 | train_videos, valid_videos = splits
31 | else:
32 | raise ValueError('splits must be a list of length 2 or 3')
33 | train = df.loc[df['vid_num'].isin(train_videos)].copy()
34 | valid = df.loc[(df['vid_num'].isin(valid_videos)) & (df['propagated'] == 0)].copy() # No prop lbl in valid
35 | info_string = "Dataframes created. Number of records training / validation: {:06d} / {:06d}\n" \
36 | " Actual data split training / validation: {:.3f} / {:.3f}" \
37 | .format(len(train.index), len(valid.index), len(train.index) / len(df), len(valid.index) / len(df))
38 |
39 | # Replace incorrectly annotated frames if flag set
40 | if config['data']['use_relabeled']:
41 | train_idx_list = train[train['relabeled'] == 1].index
42 | for idx in train_idx_list:
43 | train.loc[idx, 'blacklisted'] = 0 # So the frames don't get removed after
44 | lbl_path = pathlib.Path(train.loc[idx, 'lbl_path']).name
45 | train.loc[idx, 'lbl_path'] = 'relabeled/' + str(lbl_path)
46 | valid_idx_list = valid[valid['relabeled'] == 1].index
47 | for idx in valid_idx_list:
48 | valid.loc[idx, 'blacklisted'] = 0 # So the frames don't get removed after
49 | lbl_path = pathlib.Path(valid.loc[idx, 'lbl_path']).name
50 | valid.loc[idx, 'lbl_path'] = 'relabeled/' + str(lbl_path)
51 | info_string += "\n Relabeled train recs: {}\n" \
52 | " Relabeled valid recs: {}" \
53 | .format(len(train_idx_list), len(valid_idx_list))
54 |
55 | # Remove incorrectly annotated frames if flag set
56 | if config['data']['blacklist']:
57 | train = train.drop(train[train['blacklisted'] == 1].index)
58 | valid = valid.drop(valid[valid['blacklisted'] == 1].index)
59 | t_len, v_len = len(train.index), len(valid.index)
60 | info_string += "\n After blacklisting: Number of records train / valid: {:06d} / {:06d}\n" \
61 | " Relative data split train / valid: {:.3f} / {:.3f}" \
62 | .format(t_len, v_len, t_len / (t_len + v_len), v_len / (t_len + v_len))
63 | train = train.reset_index()
64 | valid = valid.reset_index()
65 |
66 | printlog(f" dataset {dataset}")
67 | printlog(info_string)
68 | return train, valid
69 |
--------------------------------------------------------------------------------
/datasets/Dataset_from_df.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import cv2
4 | from torch.utils.data.dataset import Dataset
5 | from torchvision.transforms import Compose, ToPILImage
6 | import torch
7 | from utils import DATASETS_INFO, remap_mask
8 | import numpy as np
9 |
10 |
11 | class DatasetFromDF(Dataset):
12 | def __init__(self, dataframe, experiment, transforms_dict, data_path=None, labels_remaped=False,
13 | return_pseudo_property=False, dataset='CADIS', debug=False):
14 | self.df = dataframe
15 | self.experiment = experiment
16 | self.dataset = dataset
17 | self.common_transforms = Compose(transforms_dict['common'])
18 | self.img_transforms = Compose(transforms_dict['img'])
19 | self.lbl_transforms = Compose(transforms_dict['lbl'])
20 | self.labels_are_remapped = labels_remaped # used when reading pseudo labeled data
21 | self.return_pseudo_property = return_pseudo_property # used to return whether the datapoint is pseudo labelled
22 | self.preloaded = False if data_path is not None else True
23 | if self.preloaded: # Data preloaded, need to assert that 'image' and 'label' exist in the dataframe
24 | assert 'image' in self.df and 'label' in self.df, "For preloaded data, the dataframe passed to the " \
25 | "PyTorch dataset needs to contain the columns 'image' " \
26 | "and 'label'"
27 | else: # Standard case: data not preloaded, needs base path to get images / labels from
28 | assert 'img_path' in self.df and 'lbl_path' in self.df, "The dataframe passed to the PyTorch dataset needs"\
29 | " to contain the columns 'img_path' and 'lbl_path'"
30 | self.data_path = data_path
31 | self.debug = debug
32 |
33 | def __getitem__(self, item):
34 | if self.preloaded:
35 | img = self.df.iloc[item].loc['image']
36 | lbl = self.df.iloc[item].loc['label']
37 | else:
38 | # img = cv2.imread(str(pathlib.Path(self.data_path) / self.df.iloc[item].loc['img_path']))[..., ::-1]
39 | img = cv2.imread(
40 | os.path.join(
41 | self.data_path,
42 | os.path.join(*self.df.iloc[item].loc['img_path'].split('\\'))))[..., ::-1]
43 | img = img - np.zeros_like(img) # deals with negative stride error
44 | # lbl = cv2.imread(str(pathlib.Path(self.data_path) / self.df.iloc[item].loc['lbl_path']), 0)
45 | lbl = cv2.imread(
46 | os.path.join(
47 | self.data_path,
48 | os.path.join(*self.df.iloc[item].loc['lbl_path'].split('\\'))), 0)
49 | lbl = lbl - np.zeros_like(lbl)
50 |
51 | if self.labels_are_remapped:
52 | # if labels are pseudo they are already remapped to experiment label set
53 | lbl = lbl.astype('int32')
54 | else:
55 | lbl = remap_mask(lbl, DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32')
56 |
57 | # Note: .astype('i') is VERY important. If left in uint8, ToTensor() will normalise the segmentation classes!
58 |
59 | # Here (and before Compose(lbl_transforms) we'd need to set the random seed and pray, following this idea:
60 | # https://github.com/pytorch/vision/issues/9#issuecomment-304224800
61 | # Big yikes. Big potential problem source, see here: https://github.com/pytorch/pytorch/issues/7068
62 | # If that doesn't work, the whole transforms structure needs to be changed into all-custom functions that will
63 | # transform both img and lbl at the same time, with one random shift / flip / whatever being applied to both
64 | metadata = {'index': item, 'filename': self.df.iloc[item].loc['img_path'],
65 | 'target_filename': str(pathlib.Path(self.df.iloc[item].loc['img_path']).stem)}
66 |
67 | if self.dataset == 'RETOUCH':
68 | subject_id = pathlib.Path(metadata['filename']).parent.stem
69 | slice_id = pathlib.Path(self.df.iloc[item].loc['lbl_path']).stem
70 | metadata['subject_id'] = subject_id
71 | metadata['target_filename'] = f"{subject_id}_{slice_id}"
72 |
73 | img, lbl, metadata = self.common_transforms((img, lbl, metadata))
74 | img_tensor = self.img_transforms(img)
75 | lbl_tensor = self.lbl_transforms(lbl).squeeze()
76 | if self.return_pseudo_property:
77 | # pseudo_tensor = torch.from_numpy(np.asarray(self.df.iloc[item].loc['pseudo']))
78 | metadata.update({'pseudo': self.df.iloc[item].loc['pseudo']})
79 |
80 | if self.debug:
81 | ToPILImage()(img_tensor).show()
82 | ToPILImage()(lbl_tensor).show()
83 | print(f'\nafter aug index : {np.unique(lbl_tensor)} lbl {lbl_tensor.shape} image {img_tensor.shape}')
84 |
85 | return img_tensor, lbl_tensor, metadata
86 |
87 | def __len__(self):
88 | return len(self.df)
89 |
--------------------------------------------------------------------------------
/datasets/PascalC.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import torch
4 | from collections import namedtuple
5 | from torch.utils.data.dataset import Dataset
6 | from torchvision.transforms import Compose, ToPILImage
7 | from PIL import Image, ImageFile
8 | from utils import DATASETS_INFO, remap_mask, printlog, mask_to_colormap, get_remapped_colormap
9 | import numpy as np
10 | # import cv2
11 | ImageFile.LOAD_TRUNCATED_IMAGES = True
12 | import pathlib
13 |
14 |
15 | class PascalC(Dataset):
16 | def __init__(self, root, transforms_dict, split='train', mode='fine', target_type='semantic', debug=False):
17 | """
18 |         :param root: path to the Pascal Context root dir (i.e. where the "train" and "val" directories, each containing "image" and "label" subdirectories, are located)
19 | :param transforms_dict: see dataset_from_df.py
20 | :param split: "train" or "val"
21 | :param mode: if "fine" then loads finely annotated images else Coarsely uses coarsely annotated
22 | :param target_type: currently only expects the default: 'semantic' (todo: test other target_types if needed)
23 | """
24 | self.debug = debug
25 | super(PascalC, self).__init__()
26 | self.root = root
27 | self.common_transforms = Compose(transforms_dict['common'])
28 | self.img_transforms = Compose(transforms_dict['img'])
29 | self.lbl_transforms = Compose(transforms_dict['lbl'])
30 | valid_modes = ["train", "val"]
31 | assert (split in valid_modes), f'split {split} is not in valid_modes {valid_modes}'
32 | # self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
33 | self.split = split # "train", "test", "val"
34 |
35 | # self.target_type = target_type
36 | self.images = []
37 | self.targets = []
38 | # this can only take the following values so hardcoded
39 | self.dataset = 'PASCALC'
40 | self.experiment = 1
41 |
42 | # for training on train + val
43 | self.images_dir = []
44 | self.targets_dir = []
45 | self.images_dir = pathlib.Path(os.path.join(self.root, self.split, 'image'))
46 | self.targets_dir = pathlib.Path(os.path.join(self.root, self.split, 'label'))
47 |
48 | for img_path, target_path in zip(sorted(self.images_dir.glob('*.jpg')), sorted(self.targets_dir.glob('*.png'))):
49 | self.images.append(img_path)
50 | self.targets.append(target_path)
51 | assert(pathlib.Path(self.images[-1]).exists() and pathlib.Path(self.targets[-1]).exists())
52 | assert(pathlib.Path(self.images[-1]).stem == pathlib.Path(self.targets[-1]).stem)
53 | printlog(f'{self.dataset} data all found split is [ {self.split} ]')
54 |
55 | self.return_filename = False
56 |
57 | def __getitem__(self, index):
58 | """
59 | Args:
60 | index (int): Index
61 | Returns:
62 |             tuple: (img_tensor, lbl_tensor, metadata): the transformed image, the segmentation target
63 |                 remapped to the experiment label set, and a metadata dict with the sample index (and filenames if return_filename is set).
64 | """
65 |
66 | image = Image.open(self.images[index]).convert('RGB')
67 | target = Image.open(self.targets[index])
68 |
69 | # if self.debug:
70 | # image.show()
71 | # target.show()
72 |
73 |
74 | target = remap_mask(np.array(target),
75 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0], to_network=True).astype('int32')
76 |
77 | target = Image.fromarray(target)
78 |
79 | # print(index, ': ', np.unique(target), ' ', [class_int_to_name[c] for c in np.unique(target) if not (c==59) ])
80 | # class_int_to_name = DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]
81 |
82 | metadata = {'index': index}
83 | image, target, metadata = self.common_transforms((image, target, metadata))
84 | img_tensor = self.img_transforms(image)
85 | lbl_tensor = self.lbl_transforms(target).squeeze()
86 |
87 | if self.return_filename:
88 | metadata.update({'img_filename': str(self.images[index]),
89 | 'target_filename': str(self.targets[index])})
90 | if self.debug:
91 | ToPILImage()(img_tensor).show()
92 | ToPILImage()(lbl_tensor).show()
93 | #
94 | # debug_lbl = mask_to_colormap(to_numpy(lbl_tensor),
95 | # get_remapped_colormap(
96 | # DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][0],
97 | # self.dataset),
98 | # from_network=True, experiment=self.experiment,
99 | # dataset=self.dataset)[..., ::-1]
100 | #
101 | #
102 | #
103 | # fn = metadata['target_filename'].split('\\')[-1]
104 | # p = pathlib.Path(r'C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC\visuals\val/')
105 | # p1 = pathlib.Path(f'{fn}')
106 | # # ToPILImage()(lbl_tensor).save(f"{str(p/p1)}")
107 | #
108 | # cv2.imwrite(f"{str(p/p1)}", debug_lbl)
109 | #
110 | print(f'\nafter aug index, : {np.unique(lbl_tensor)} lbl {lbl_tensor.shape} image {img_tensor.shape} fname:{self.images[index]}')
111 |
112 | return img_tensor, lbl_tensor, metadata
113 |
114 | def __len__(self):
115 | return len(self.images)
116 |
117 | def extra_repr(self):
118 | lines = ["Split: {split}"]
119 | return '\n'.join(lines).format(**self.__dict__)
120 |
121 |
122 | if __name__ == '__main__':
123 | import pathlib
124 | from torch.nn import functional as F
125 | from utils import Pad, RandomResize, RandomCropImgLbl, Resize, FlipNP, to_numpy, pil_plot_tensor
126 | data_path = r'C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC/'
127 | from torchvision.transforms import ToTensor
128 | import PIL.Image as Image
129 | d = {"dataset":'PASCALC', "experiment":1}
130 |
131 | # augs= [
132 | # FlipNP(probability=(0, 1.0)),
133 | # RandomResize(**d,
134 | # scale_range=[0.5, 2],
135 | # aspect_range=[0.9, 1.1],
136 | # target_size=[520, 520],
137 | # probability=1.0),
138 | # RandomCropImgLbl(**d,
139 | # shape=[512, 512],
140 | # crop_class_max_ratio=0.75),
141 | # ]
142 | #
143 | # augs_val= [Resize(**d, min_side_length=512, fit_stride=32, return_original_labels=True)]
144 | #
145 | #
146 | # train_set = PascalC(root=data_path, debug=True,
147 | # split='train',
148 | # transforms_dict={'common': augs_val,
149 | # 'img': [(ToTensor())],
150 | # 'lbl': [(ToTensor())]})
151 |
152 | from utils import parse_transform_lists
153 | import json
154 | path_to_config = '../configs/PASCALC/hrnet_contrastive_PC.json'
155 | with open(path_to_config, 'r') as f:
156 | config = json.load(f)
157 |
158 | transforms_list = config['data']['transforms']
159 | if 'torchvision_normalise' in transforms_list:
160 | del transforms_list[-1]
161 | transforms_values = config['data']['transform_values']
162 | transforms_dict = parse_transform_lists(transforms_list, transforms_values, dataset='PASCALC', experiment=1)
163 |
164 | transforms_list_val = config['data']['transforms_val']
165 |     if 'torchvision_normalise' in transforms_list_val:
166 |         del transforms_list_val[-1]
167 |
168 | transforms_values_val = config['data']['transform_values_val']
169 | transforms_dict_val = parse_transform_lists(transforms_list_val, transforms_values_val, dataset='PASCALC', experiment=1)
170 |
171 | # transforms_values_val = {}
172 | # transforms_dict_val = parse_transform_lists({}, transforms_values_val, dataset='PASCALC', experiment=1)
173 |
174 |
175 | train_set = PascalC(root=data_path,
176 | debug=True,
177 | split='train',
178 | transforms_dict=transforms_dict)
179 |
180 | valid_set = PascalC(root=data_path,
181 | debug=True,
182 | split='val',
183 | transforms_dict=transforms_dict_val)
184 | valid_set.return_filename = True
185 |
186 | issues = []
187 | train_set.return_filename = True
188 | hs=[]
189 | ws = []
190 | for ret in valid_set:
191 | # print(ret[0].shape)
192 | # img = ToPILImage()(ret[0]).show()
193 | # lbl = ToPILImage()(ret[1]).show()
194 |
195 | hs.append(ret[0].shape[1])
196 | ws.append(ret[0].shape[2])
197 | print(ret[-1])
198 | print('*'*10)
199 | # meta = ret[-1]
200 | # lbl = meta['original_labels'].unsqueeze(0)
201 | # resized = ret[1].unsqueeze(0).unsqueeze(0).long()
202 | # pad_w, pad_h, stride = meta["pw_ph_stride"]
203 | # if pad_h > 0 or pad_w > 0:
204 | #
205 | # un_padded = resized[:, :, 0:resized.size(2) - pad_h, 0:resized.size(3) - pad_w]
206 | # pil_plot_tensor(un_padded)
207 | # un_resized = F.interpolate(un_padded.float(), size=lbl.size()[-2:], mode='nearest')
208 | # print(torch.sum(un_resized- lbl))
209 | present_classes = torch.unique(ret[1])
210 | if len(present_classes) == 1 and 59 in present_classes:
211 | issues.append([ret[-1], present_classes])
212 | print('issue found !!!! ')
213 | print(present_classes, ret[-1])
214 | print('issue found !!!! ')
215 | a = 1
216 | print(max(hs))
217 | print(max(ws))
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .Dataset_from_df import DatasetFromDF
2 | from .Cityscapes import Cityscapes
3 | from .PascalC import PascalC
4 | from .ADE20K import ADE20K
5 | from .CaDIS import get_cadis_dataframes
6 |
--------------------------------------------------------------------------------
/env_dgx.yml:
--------------------------------------------------------------------------------
1 | name: semseg
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8
7 | - pytorch
8 | - cudatoolkit=10.1
9 | - tensorboard
10 | - torchvision=0.9.1
11 | - h5py
12 | - matplotlib
13 | - numpy
14 | - scipy
15 | - pandas
16 | - pillow
17 | - pip
18 | - future
19 | - pip:
20 | - opencv-python
21 | - tqdm
22 | - timm
23 | - tsne-torch
24 |
--------------------------------------------------------------------------------
/losses/DenseContrastiveLossV2_ms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils import DATASETS_INFO, is_distributed, concat_all_gather, get_rank, to_numpy, printlog
4 | from torch.nn.functional import one_hot
5 | import torch.distributed
6 | from losses.DenseContrastiveLossV2 import DenseContrastiveLossV2 as DCV2
7 |
8 | def has_inf_or_nan(x):
9 | return torch.isinf(x).max().item(), torch.isnan(x).max().item()
10 |
11 |
12 | class DenseContrastiveLossV2_ms(nn.Module):
13 | def __init__(self, config):
14 | super().__init__()
15 | self.parallel = is_distributed()
16 | self.experiment = config['experiment']
17 | self.dataset = config['dataset']
18 | self.num_all_classes = len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1])
19 | self.num_real_classes = self.num_all_classes - 1 if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else self.num_all_classes
20 | self.ignore_class = (len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1) if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else -1
21 | self.scales = config['scales'] if 'scales' in config else 2
22 | self.weights = config['weights'] if 'weights' in config else [1.0] * self.scales
23 | assert(self.scales == len(self.weights)), f'given dc loss number of scales [{self.scales}] not equal len of weights {self.weights}'
24 | self.losses = []
25 | self.eps = torch.tensor(1e-10)
26 | self.meta = {}
27 | self.cross_scale_contrast = config['cross_scale_contrast'] if 'cross_scale_contrast' in config else False
28 |         self.cross_scale_temperature = config['cross_scale_temperature'] if 'cross_scale_temperature' in config else config['temperature']  # fall back to the main temperature
29 | self.detach_cs_deepest = config['detach_deepest'] if 'detach_deepest' in config else False
30 | self.w_high_low = config['w_high_low'] if 'w_high_low' in config else 1.0
31 | self.w_high_mid = config['w_high_mid'] if 'w_high_mid' in config else 1.0
32 | self.ms_losses = []
33 | self.cs_losses = []
34 | printlog(f'defining dcv2 ms loss with number of scales {self.scales} and weights {self.weights}')
35 | printlog(f'using cross scale contrast {self.cross_scale_contrast}')
36 | for class_name in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]:
37 | self.meta[class_name] = (0.0, 0.0) # pos-neg per class
38 | for s in range(self.scales):
39 | printlog(f'defining dcv2 loss at scale {s}')
40 | setattr(self, f'DCV2_scale{s}', DCV2(config))
41 | if self.cross_scale_contrast:
42 | printlog(f'using cross-scale contrast with detach_cs_deepest set to {self.detach_cs_deepest}, w_high_low: {self.w_high_low}, w_high_mid: {self.w_high_mid}')
43 |
44 | def forward(self, label: torch.Tensor, features: list, **kwargs):
45 | self.cs_losses = []
46 | self.ms_losses = []
47 | flag_error = False
48 | loss = torch.tensor(0.0, dtype=torch.float, device=features[0].device)
49 | feats_ms = []
50 | labels_ms = []
51 | for s in range(self.scales):
52 | if self.cross_scale_contrast:
53 | loss_s, feats_s, labels_s, flag_error = getattr(self, f'DCV2_scale{s}')(label, features[s])
54 | loss+= self.weights[s] * loss_s
55 | feats_ms.append(feats_s)
56 | labels_ms.append(labels_s)
57 | else:
58 | loss_s=getattr(self, f'DCV2_scale{s}')(label, features[s])
59 | loss += self.weights[s] * loss_s
60 | self.ms_losses.append(loss_s.detach())
61 |
62 | if self.cross_scale_contrast and not flag_error:
63 | assert len(feats_ms) > 1
64 | assert len(labels_ms) > 1
65 | # highest res to lowest res contrast
66 | if self.detach_cs_deepest:
67 | loss_cross_scale = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-1].detach(), labels_ms[-1])
68 | else:
69 | loss_cross_scale = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-1], labels_ms[-1])
70 | self.cs_losses.append(loss_cross_scale.detach())
71 |
72 | loss += self.w_high_low * loss_cross_scale
73 |
74 | if len(feats_ms)>2: # hrnet : 4 , s4-s16 , s4-s32 dlv3 : 3 layer1(s4)-layer4(s8), layer1(s4)-layer3(s8)
75 | if self.detach_cs_deepest:
76 | loss_cross_scale2 = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-2].detach(), labels_ms[-2])
77 | else:
78 | loss_cross_scale2 = self.contrastive_loss(feats_ms[0], labels_ms[0], feats_ms[-2], labels_ms[-2])
79 | loss += self.w_high_mid * loss_cross_scale2
80 | self.cs_losses.append(loss_cross_scale2.detach())
81 |
82 | return loss
83 |
84 | def contrastive_loss(self, feats1, labels1, feats2, labels2):
85 | """
86 |         :param feats1, feats2: T-C-V feature tensors from two different scales, where
87 |             T: classes in batch (with repetition), which can be thought of as the number of anchors
88 |             C: feature space dimensionality
89 |             V: views per class (i.e. samples from each class),
90 |                which can be thought of as the number of views per anchor
91 |         :param labels1, labels2: T tensors holding the class of each anchor
92 |         :return: cross-scale contrastive (InfoNCE) loss
93 | """
94 | # prepare feats
95 | feats1 = torch.nn.functional.normalize(feats1, p=2, dim=1) # L2 normalization
96 | feats1 = feats1.transpose(dim0=1, dim1=2) # feats are T-V-C
97 | num_anchors, views_per_anchor, c = feats1.shape # get T, V, C
98 | feats_flat1 = feats1.contiguous().view(-1, c) # feats_flat is T*V-C
99 |
100 | labels1 = labels1.contiguous().view(-1, 1) # labels are T-1
101 | labels1 = labels1.repeat(1, views_per_anchor) # labels are T-V
102 | labels1 = labels1.view(-1, 1) # labels are T*V-1
103 |
104 | feats2 = torch.nn.functional.normalize(feats2, p=2, dim=1) # L2 normalization
105 | feats2 = feats2.transpose(dim0=1, dim1=2) # feats are T-V-C
106 | num_anchors, views_per_anchor, c = feats2.shape # get T, V, C
107 | feats_flat2 = feats2.contiguous().view(-1, c) # feats_flat is T*V-C
108 |
109 | labels2 = labels2.contiguous().view(-1, 1) # labels are T-1
110 | labels2 = labels2.repeat(1, views_per_anchor) # labels are T-V
111 | labels2 = labels2.view(-1, 1) # labels are T*V-1
112 |
113 | pos_mask, neg_mask = self.get_masks2(labels1, labels2)
114 | dot_product = torch.div(torch.matmul(feats_flat1, torch.transpose(feats_flat2, 0, 1)), self.cross_scale_temperature)
115 | loss2 = self.InfoNce_loss(pos_mask, neg_mask, dot_product)
116 | return loss2
117 |
118 | @staticmethod
119 | def get_masks2(lbl1, lbl2):
120 | """
121 | takes flattened labels and identifies pos/neg of each anchor
122 |         :param lbl1: T*V-1 flattened labels of the first scale
123 |         :param lbl2: T*V-1 flattened labels of the second scale
124 |         :return: pos_mask, neg_mask indicating, for each anchor, which entries are positives
125 |             (same class) and which are negatives
126 | """
127 | # extract mask indicating same class samples
128 | pos_mask = torch.eq(lbl1, torch.transpose(lbl2, 0, 1)).float() # mask T-T # indicator of positives
129 | neg_mask = (1 - pos_mask) # indicator of negatives
130 | return pos_mask, neg_mask
131 |
132 | def InfoNce_loss(self, pos, neg, dot):
133 | """
134 | :param pos: V*T-V*T
135 | :param neg: V*T-V*T
136 | :param dot: V*T-V*T
137 |         :return: InfoNCE loss averaged over anchors
138 | """
139 | # logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
140 | logits = dot # - logits_max.detach()
141 |
142 | neg_logits = torch.exp(logits) * neg
143 | neg_logits = neg_logits.sum(1, keepdim=True)
144 |
145 | exp_logits = torch.exp(logits)
146 | log_prob = logits - torch.log(exp_logits + neg_logits)
147 |
148 | pos_sums = pos.sum(1)
149 | ones = torch.ones(size=pos_sums.size())
150 | norm = torch.where(pos_sums > 0, pos_sums, ones.to(pos.device))
151 |
152 | mean_log_prob_pos = (pos * log_prob).sum(1) / norm # normalize by positives
153 |
154 | loss = - mean_log_prob_pos
155 |
156 | loss = loss.mean()
157 | # print('loss.mean() ', has_inf_or_nan(loss))
158 | # print('loss {}'.format(loss))
159 | if has_inf_or_nan(loss)[0] or has_inf_or_nan(loss)[1]:
160 | print('\n inf found in loss with positives {} and Negatives {}'.format(pos.sum(1), neg.sum(1)))
161 | return loss
162 |
163 | def get_meta(self):
164 |         meta = {}
165 |         meta['queue_fillings'] = to_numpy(self.queue_fillings) if hasattr(self, 'queue_fillings') else None  # queue only exists in memory-based variants
166 |         meta['scales'] = int(self.scales)
167 |         return meta
168 |
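To make the tensor shapes in contrastive_loss concrete, here is an editorial toy sketch (not part of the repository) that mirrors its masking and InfoNCE steps on random features; T, V, C, tau and the shared labels are illustrative values only:
```python
import torch

T, V, C, tau = 6, 4, 16, 0.1            # anchors, views per anchor, channels, temperature
feats1 = torch.randn(T, C, V)           # features from the highest-resolution scale
feats2 = torch.randn(T, C, V)           # features from a lower-resolution scale
labels = torch.randint(0, 3, (T,))      # class of each anchor (shared across scales for simplicity)

def flatten(feats, labels):
    # normalise over channels, reorder to T-V-C and flatten to T*V-C, as in contrastive_loss above
    feats = torch.nn.functional.normalize(feats, p=2, dim=1).transpose(1, 2)
    return feats.reshape(-1, C), labels.repeat_interleave(V).view(-1, 1)

f1, l1 = flatten(feats1, labels)
f2, l2 = flatten(feats2, labels)

pos = torch.eq(l1, l2.t()).float()      # positives: same class across the two scales
neg = 1.0 - pos                         # negatives: different class
logits = f1 @ f2.t() / tau
neg_logits = (torch.exp(logits) * neg).sum(1, keepdim=True)
log_prob = logits - torch.log(torch.exp(logits) + neg_logits)
loss = -(pos * log_prob).sum(1) / pos.sum(1).clamp(min=1)   # average over positives per anchor
print(loss.mean())
```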
--------------------------------------------------------------------------------
/losses/LossWrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | # noinspection PyUnresolvedReferences
4 | from losses import *
5 | from utils import DATASETS_INFO
6 | from typing import Union
7 |
8 |
9 | class LossWrapper(nn.Module):
10 | def __init__(self, config: dict):
11 | super().__init__()
12 | self.config = config
13 | self.loss_weightings = config['losses']
14 | self.device = config['device']
15 | self.dataset = config['dataset']
16 | self.experiment = config['experiment']
17 | self.total_loss = None
18 | self.loss_classes, self.loss_vals = {}, {}
19 | self.info_string = ''
20 | self.ignore_class = (len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1) \
21 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else -1
22 | for loss_class in self.loss_weightings:
23 | if loss_class == 'CrossEntropyLoss':
24 | class_weights = None
25 | if self.dataset == 'CITYSCAPES':
26 | class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
27 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
28 | 1.0865, 1.1529, 1.0507]).cuda()
29 | print(f'using class_weights {class_weights}')
30 | loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_class, weight=class_weights)
31 | # loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_class)
32 | else:
33 | loss_fct = globals()[loss_class](config)
34 | self.loss_classes.update({loss_class: loss_fct})
35 | self.loss_vals.update({loss_class: 0})
36 | self.info_string += loss_class + ', '
37 | self.info_string = self.info_string[:-2]
38 | self.dc_off = True if 'dc_off_at_epoch' in self.config else False
39 |
40 | def forward(self,
41 | prediction: torch.Tensor,
42 | labels: torch.Tensor,
43 | loss_list: list = None,
44 | deep_features: Union[torch.Tensor,list] = None,
45 | interm_prediction: torch.Tensor = None,
46 | epoch: int = None,
47 | skip_mem_update: bool =False) -> torch.Tensor:
48 | self.total_loss = torch.tensor(0.0, dtype=torch.float, device=self.device)
49 | # Compile list of losses to be evaluated. If no specific 'loss_list' is passed
50 | loss_list = list(self.loss_weightings.keys()) if loss_list is None else loss_list
51 | for loss_class in self.loss_weightings: # Go through all the losses
52 | if loss_class in loss_list: # Check if this loss should be calculated
53 | if 'DenseContrastive' in loss_class:
54 | assert deep_features is not None, f'for loss_class {loss_class}, deep_features must be tensor (B,H,W,C) ' \
55 | f'instead got {deep_features}'
56 | if loss_class == 'LovaszSoftmax':
57 | if self.dc_off and epoch is not None and epoch < self.config['dc_off_at_epoch']:
58 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device)
59 | else:
60 | loss = self.loss_classes[loss_class](prediction, labels)
61 | elif loss_class == 'DenseContrastiveLoss':
62 | if self.dc_off and epoch is not None and epoch >= self.config['dc_off_at_epoch']:
63 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device)
64 | else:
65 | loss = self.loss_classes[loss_class](labels, deep_features)
66 | elif loss_class == 'TwoScaleLoss':
67 | loss = self.loss_classes[loss_class](interm_prediction, prediction, labels.long())
68 | elif loss_class == 'DenseContrastiveLossV2':
69 | loss = self.loss_classes[loss_class](labels, deep_features)
70 | elif loss_class == 'DenseContrastiveLossV2_ms':
71 | loss = self.loss_classes[loss_class](labels, deep_features)
72 | elif loss_class == 'DenseContrastiveLossV3':
73 | loss = self.loss_classes[loss_class](labels, deep_features, epoch, skip_mem_update)
74 | elif loss_class == 'DenseContrastiveLossV3_ms':
75 | loss = self.loss_classes[loss_class](labels, deep_features, epoch, skip_mem_update)
76 | # self.meta['queue'] = self.loss_classes[loss_class].queue_ptr.clone().numpy()
77 | elif loss_class == 'DenseContrastiveCenters':
78 | loss = self.loss_classes[loss_class](labels, deep_features, epoch, skip_mem_update)
79 | elif loss_class == 'OhemCrossEntropy':
80 | loss = self.loss_classes[loss_class](prediction, labels)
81 | elif loss_class == 'CrossEntropyLoss':
82 | loss = self.loss_classes[loss_class](prediction, labels)
83 | else:
84 | print("Error: Loss class '{}' not recognised!".format(loss_class))
85 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device)
86 | else:
87 | loss = torch.tensor(0.0, dtype=torch.float, device=self.device)
88 |
89 | # Calculate weighted loss
90 | loss *= self.loss_weightings[loss_class]
91 | self.loss_vals[loss_class] = loss.detach()
92 |
93 |             # logging each scale separately when the ms/cs loss is used
94 | if loss_class == 'DenseContrastiveLossV2_ms':
95 |
96 | if hasattr(self.loss_classes[loss_class], 'ms_losses'):
97 | for scale, loss_val_ms in enumerate(self.loss_classes[loss_class].ms_losses):
98 | self.loss_vals.update({f'{loss_class}_ms{scale}':loss_val_ms})
99 | if self.loss_classes[loss_class].cross_scale_contrast and hasattr(self.loss_classes[loss_class], 'cs_losses'):
100 | for cscale, loss_val_cs in enumerate(self.loss_classes[loss_class].cs_losses):
101 | self.loss_vals.update({f'{loss_class}_cs{cscale}':loss_val_cs})
102 | self.total_loss += loss
103 | return self.total_loss
104 |
--------------------------------------------------------------------------------
/losses/LovaszSoftmax.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import softmax
4 | from utils import DATASETS_INFO
5 | from itertools import filterfalse
6 |
7 |
8 | class LovaszSoftmax(nn.Module):
9 | def __init__(self, config):
10 | super().__init__()
11 | self.eps = torch.as_tensor(1e-10)
12 | self.experiment = config['experiment']
13 | self.dataset = config['dataset']
14 |
15 | ignore_index_in_loss = len(
16 | DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1 \
17 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] else None
18 | # self.num_classes = len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1])
19 | self.per_image = False if 'per_image' not in config else config['per_image']
20 | self.classes_to_ignore = ignore_index_in_loss if 'classes_to_ignore' not in config else config['classes_to_ignore']
21 | self.classes_to_consider = 'present' if 'classes_to_consider' not in config else config['classes_to_consider']
22 | # classes_to_consider: 'all' for all, 'present' for classes present in labels, or a list of classes to average
23 |
24 | def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
25 | """Multi-class Lovasz-Softmax loss. Adapted from github.com/bermanmaxim/LovaszSoftmax
26 |
27 | :param prediction: NCHW tensor, raw logits from the network
28 | :param target: NHW tensor, ground truth labels
29 | :return: Lovász-Softmax loss
30 | """
31 | p = softmax(prediction, dim=1)
32 | if self.per_image:
33 | loss = mean(self.lovasz_softmax_flat(*self.flatten_probabilities(p.unsqueeze(0), t.unsqueeze(0)))
34 | for p, t in zip(p, target))
35 | else:
36 | loss = self.lovasz_softmax_flat(*self.flatten_probabilities(p, target))
37 | return loss
38 |
39 | def lovasz_softmax_flat(self, prob: torch.Tensor, lbl: torch.Tensor) -> torch.Tensor:
40 | """Multi-class Lovasz-Softmax loss. Adapted from github.com/bermanmaxim/LovaszSoftmax
41 |
42 | :param prob: class probabilities at each prediction (between 0 and 1)
43 | :param lbl: ground truth labels (between 0 and C - 1)
44 | :return: Lovász-Softmax loss
45 | """
46 | if prob.numel() == 0:
47 | # only void pixels, the gradients should be 0
48 | return prob * 0.
49 | c = prob.size(1)
50 | losses = []
51 | class_to_sum = list(range(c)) if self.classes_to_consider in ['all', 'present'] else self.classes_to_consider
52 |
53 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1] and c in class_to_sum:
54 | class_to_sum.remove(c) # remove ignore class which is denoted in lbl by values = c
55 |
56 | for c in class_to_sum:
57 | fg = (lbl == c).float() # foreground for class c
58 |             if self.classes_to_consider == 'present' and fg.sum() == 0:
59 | continue
60 | class_pred = prob[:, c]
61 | errors = (fg - class_pred).abs()
62 | errors_sorted, perm = torch.sort(errors, 0, descending=True)
63 | perm = perm.detach()
64 | fg_sorted = fg[perm]
65 | losses.append(torch.dot(errors_sorted, lovasz_grad(fg_sorted)))
66 | return mean(losses)
67 |
68 | def flatten_probabilities(self, prob: torch.Tensor, lbl: torch.Tensor):
69 | """
70 | Flattens predictions in the batch
71 | """
72 | if prob.dim() == 3:
73 | # assumes output of a sigmoid layer
74 | n, h, w = prob.size()
75 | prob = prob.view(n, 1, h, w)
76 | _, c, _, _ = prob.size()
77 | prob = prob.permute(0, 2, 3, 1).contiguous().view(-1, c) # B * H * W, C = P, C
78 | lbl = lbl.view(-1)
79 | if self.classes_to_ignore is None:
80 | return prob, lbl
81 | else:
82 | valid = (lbl != self.classes_to_ignore)
83 | vprobas = prob[valid.nonzero().squeeze()]
84 | vlabels = lbl[valid]
85 | return vprobas, vlabels
86 |
87 |
88 | def lovasz_grad(gt_sorted):
89 | """
90 | Computes gradient of the Lovasz extension w.r.t sorted errors
91 | See Alg. 1 in paper
92 | """
93 | p = len(gt_sorted)
94 | gts = gt_sorted.sum()
95 | intersection = gts - gt_sorted.float().cumsum(0)
96 | union = gts + (1 - gt_sorted).float().cumsum(0)
97 | jaccard = 1. - intersection / union
98 | if p > 1: # cover 1-pixel case
99 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
100 | return jaccard
101 |
102 |
103 | def isnan(x):
104 | return x != x
105 |
106 |
107 | def mean(ip: torch.Tensor, ignore_nan: bool = False, empty=0):
108 | """
109 | nanmean compatible with generators.
110 | """
111 | ip = iter(ip)
112 | if ignore_nan:
113 | ip = filterfalse(isnan, ip)
114 | try:
115 | n = 1
116 | acc = next(ip)
117 | except StopIteration:
118 | if empty == 'raise':
119 | raise ValueError('Empty mean')
120 | return empty
121 | for n, v in enumerate(ip, 2):
122 | acc += v
123 | if n == 1:
124 | return acc
125 | return acc / n
126 |
--------------------------------------------------------------------------------
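As a quick sanity check of the Lovász gradient above, the following standalone sketch evaluates it on a toy sorted foreground vector; the helper is copied verbatim from the file so the snippet runs without the repo's `utils`.

```python
import torch

def lovasz_grad(gt_sorted):
    # verbatim copy of the helper above, so this sketch is self-contained
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard

# foreground indicator after sorting errors in descending order
fg_sorted = torch.tensor([1., 1., 0., 0.])
print(lovasz_grad(fg_sorted))  # tensor([0.5000, 0.5000, 0.0000, 0.0000])
```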
/losses/TwoScaleLoss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | from utils import DATASETS_INFO
4 | from losses import LovaszSoftmax
5 | from torch.nn import CrossEntropyLoss
6 | import torch
7 |
8 |
9 | class TwoScaleLoss(nn.Module):
10 | def __init__(self, config):
11 | """
12 | Loads two losses one from an intermediate output and one from the final output
13 | for now it assumes the two losses are the same CE-CE or Lovasz-Lovasz etc.
14 | the weights of the two losses may vary (by default 0.4 for interm and 1.0 final)
15 | :param config:
16 | """
17 | super(TwoScaleLoss, self).__init__()
18 | interm_loss_class = globals()[config['interm']['name']]
19 | final_loss_class = globals()[config['final']['name']]
20 | self.w_interm = config['interm']['weight'] if 'weight' in config['interm'] else 0.4
21 | self.w_final = config['final']['weight'] if 'weight' in config['final'] else 1.0
22 | self.ignore_label = -100 # if experiment is not given assume nothing is ignored
23 | self.dataset = config['dataset']
24 | self.experiment = config['experiment']
25 | if 'experiment' in config:
26 | self.ignore_label = len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1]) - 1 \
27 | if 255 in DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1].keys() \
28 | else len(DATASETS_INFO[self.dataset].CLASS_INFO[self.experiment][1])
29 |
30 | # pass experiment id to constructors of the two losses
31 | config['interm'].update({"experiment": config['experiment'], "dataset": self.dataset})
32 | config['final'].update({"experiment": config['experiment'], "dataset": self.dataset})
33 |
34 | if config['interm']['name'] == 'CrossEntropyLoss' and config['final']['name'] == 'CrossEntropyLoss':
35 | class_weights = None
36 | if self.dataset == 'CITYSCAPES':
37 | class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489,
38 | 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
39 | 1.0865, 1.1529, 1.0507]).cuda()
40 | print(f'using class weights {class_weights}')
41 | self.loss_interm = interm_loss_class(*config['interm']['args'],
42 | ignore_index=self.ignore_label, weight=class_weights)
43 | self.loss_final = final_loss_class(*config['final']['args'],
44 | ignore_index=self.ignore_label, weight=class_weights)
45 |
46 | # all other losses expect a config
47 | elif config['interm']['name'] == config['final']['name']:
48 | self.loss_interm = interm_loss_class(config['interm'])
49 | self.loss_final = final_loss_class(config['final'])
50 | else:
51 | raise NotImplementedError('different losses for interm {}'
52 | ' and final {}'.format(config['interm'], config['final']))
53 |
54 | print("intermediate loss {} with weight {}".format(interm_loss_class, self.w_interm))
55 | print("final loss {} with weight {}".format(final_loss_class, self.w_final))
56 |
57 | def forward(self, logits_interm, logits_final, target):
58 | # upsample intermediate if not already upsampled
59 | ph, pw = logits_interm.size(2), logits_interm.size(3)
60 | h, w = target.size(1), target.size(2)
61 | # todo add align_corners from outside --
62 | # this was ignored until now as upsampling was happening in model.forward()
63 | # if ph != h or pw != w:
64 | # logits_interm = F.upsample(input=logits_interm, size=(h, w), mode='bilinear')
65 | loss_final = self.loss_final(logits_final, target)
66 | loss_interm = self.loss_interm(logits_interm, target)
67 | loss = loss_final * self.w_final + loss_interm * self.w_interm
68 | return loss
69 |
--------------------------------------------------------------------------------
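The config shape `TwoScaleLoss` expects can be read off its constructor; the dict below is a minimal sketch with key names taken from the code above and illustrative values (it is not a config shipped under `configs/`, and the `experiment` id is an assumption).

```python
# Minimal sketch of a TwoScaleLoss config, inferred from the constructor above.
# 'args' is forwarded positionally to torch.nn.CrossEntropyLoss; the weights
# default to 0.4 (interm) and 1.0 (final) when omitted.
two_scale_cfg = {
    "dataset": "CITYSCAPES",
    "experiment": 1,
    "interm": {"name": "CrossEntropyLoss", "args": [], "weight": 0.4},
    "final":  {"name": "CrossEntropyLoss", "args": [], "weight": 1.0},
}
# loss_fn = TwoScaleLoss(two_scale_cfg)   # note: CITYSCAPES class weights are moved to CUDA
# loss = loss_fn(logits_interm, logits_final, target)
```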
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .LovaszSoftmax import LovaszSoftmax
2 | from .DenseContrastiveLossV2 import DenseContrastiveLossV2
3 | from .DenseContrastiveLossV2_ms import DenseContrastiveLossV2_ms
4 | from .TwoScaleLoss import TwoScaleLoss
5 | from .LossWrapper import LossWrapper
6 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | # noinspection PyUnresolvedReferences
3 | from managers import *
4 | from utils import parse_config
5 |
6 |
7 | def str2bool(s: str):
8 |     assert isinstance(s, str), f'input argument must be a str, instead got {s}'
9 |     if s in ['True', 'true']:
10 |         return True
11 |     elif s in ['False', 'false']:
12 |         return False
13 |     else:
14 |         raise ValueError(f'cannot interpret string {s} as a boolean')
15 |
16 |
17 | if __name__ == '__main__':
18 | parser = argparse.ArgumentParser()
19 |
20 | parser.add_argument('-c', '--config', type=str, default='configs/FCN_train_config.json',
21 | help='Set path to configuration files, e.g. '
22 | 'python main.py --config configs/FCN_train_config.json.')
23 |
24 | parser.add_argument('-u', '--user', type=str, default='c',
25 | help='Select user to set correct data / logging paths for your system, e.g. '
26 | 'python main.py --user theo')
27 |
28 |     parser.add_argument('-d', '--device', nargs="+", type=int, default=-1,
29 |                         help='Select GPU device(s) to run the experiment on, e.g. --device 3')
30 | 
31 |     parser.add_argument('-s', '--dataset', type=str, default=-1, required=False,
32 |                         help='Select dataset to run the experiment on, e.g. --dataset CITYSCAPES')
33 |
34 | parser.add_argument('-p', '--parallel', action='store_true',
35 | help='whether to use distributed training')
36 |
37 | parser.add_argument('-debug', '--debugging', action='store_true',
38 |                         help='sets manager into debugging mode, e.g. cts is run with a val/val split')
39 |
40 | parser.add_argument('-cdnb', '--cudnn_benchmark', type=str, default=None, required=False,
41 | help='if added in args then uses cudnn benchmark set to True '
42 | 'else uses config '
43 | 'else sets it to True by default')
44 |
45 | parser.add_argument('-cdne', '--cudnn_enabled', type=str, default=None, required=False,
46 | help='if added in args then uses cudnn enabled set to True '
47 | 'else uses config '
48 | 'else sets it to True by default')
49 |
50 | parser.add_argument('-vf', '--valid_freq', type=int, default=None, required=False,
51 | help='sets how often to run validation')
52 |
53 | parser.add_argument('-w', '--workers', type=int, default=None, required=False,
54 | help='workers for dataloader per gpu process')
55 |
56 | parser.add_argument('-ec', '--empty_cache', action='store_true',
57 | help='whether to empty cache (per gpu process) after each forward step to avoid OOM --'
58 | ' this is useful in DCV2_ms or DCV3/ms')
59 |
60 | parser.add_argument('-m', '--mode', type=str, default=None, required=False,
61 |                         help='mode setting, e.g. training, inference (see BaseManager for others)')
62 |
63 | parser.add_argument('-cpt', '--checkpoint', type=str, default=None, required=False,
64 | help='path to checkpoint folder')
65 |
66 | parser.add_argument('-bs', '--batch_size', type=int, default=None, required=False,
67 | help='batch size -- the number given is then divided by n_gpus if ddp')
68 |
69 | parser.add_argument('-ep', '--epochs', type=int, default=None, required=False,
70 | help='training epochs -- overrides config')
71 |
72 | parser.add_argument('-so', '--save_outputs', action='store_true',
73 | help='whether to save outputs for submission cts')
74 |
75 | parser.add_argument('-rfv', '--run_final_val', action='store_true',
76 | help='whether to run validation with special settings'
77 | ' at the end of training (ex using tta or sliding window inference)')
78 |
79 | parser.add_argument('-tta', '--tta', action='store_true',
80 | help='whether to tta_val at the end of training')
81 |
82 | parser.add_argument('-tsnes', '--tsne_scale', type=int, default=None, required=False,
83 |                         help='stride of the feats on which to apply tsne; must be one of [4, 8, 16, 32]')
84 |
85 | # loss args for convenience
86 |     parser.add_argument('--loss', '-l', choices=[None, 'ce', 'ms', 'ms_cs'], default=None, required=False,
87 |                         help='choose loss, overriding the config (refer to the config for options other than [ce, ms, ms_cs])')
88 |
89 | args = parser.parse_args()
90 | config = parse_config(args.config, args.user, args.device, args.dataset, args.parallel)
91 | manager_class = globals()[config['manager'] + 'Manager']
92 | print(f'requested device ids: {config["gpu_device"]}')
93 | print('parsing cmdline args')
94 | # override config
95 | config['parallel'] = args.parallel
96 | config['tsne_scale'] = args.tsne_scale
97 | if args.loss:
98 | print(f'overriding loss type in config requested [{args.loss}]')
99 | if 'ms' in args.loss:
100 | config['loss'].update({"losses": {"CrossEntropyLoss": 1, "DenseContrastiveLossV2_ms": 0.1}})
101 | config['loss'].update({"cross_scale_contrast": False})
102 | if config['graph']['model'] == 'UPerNet':
103 | config['graph'].update({"ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": True, "position":"backbone"}})
104 | else:
105 | config['graph'].update({"ms_projector": {"mlp": [[1, -1, 1]], "scales":4, "d": 256, "use_bn": True}})
106 |
107 | if 'cs' in args.loss:
108 | config['loss'].update({"cross_scale_contrast": True})
109 |
110 | if args.loss == 'ce':
111 | config['loss'].update({"losses": {"CrossEntropyLoss": 1}})
112 | if 'ms_projector' in config['graph']:
113 | del config['graph']['ms_projector']
114 |
115 | if args.save_outputs:
116 | config['save_outputs'] = True
117 | if args.run_final_val:
118 | config['run_final_val'] = True
119 | print('going to run tta val at the end of training')
120 | if args.empty_cache:
121 | config['empty_cache'] = True
122 | print('emptying cache')
123 | if args.batch_size is not None:
124 | config['data']['batch_size'] = args.batch_size
125 | print(f'bsize {args.batch_size}')
126 | if args.epochs is not None:
127 | config['train']['epochs'] = args.epochs
128 | print(f'epochs : {args.epochs}')
129 | if args.tta:
130 | config['tta'] = True
131 | print(f'tta set to {config["tta"]}')
132 | if args.debugging:
133 | config['debugging'] = True
134 | if args.valid_freq is not None:
135 | config['valid_freq'] = args.valid_freq
136 | if args.workers is not None:
137 | config['data']['num_workers'] = args.workers
138 | print(f'workers {args.workers}')
139 | if args.mode is not None:
140 | config['mode'] = args.mode
141 | print(f'mode {args.mode}')
142 | if args.checkpoint is not None:
143 | config['load_checkpoint'] = args.checkpoint
144 |         print(f'load_checkpoint set to {args.checkpoint}')
145 |
146 | if args.cudnn_benchmark is not None:
147 | config['cudnn_benchmark'] = str2bool(args.cudnn_benchmark)
148 | if args.cudnn_enabled is not None:
149 | config['cudnn_enabled'] = str2bool(args.cudnn_enabled)
150 |
151 | manager = manager_class(config)
152 |
153 | if config['mode'] == 'training' and not manager.parallel:
154 | manager.train()
155 | elif config['mode'] == 'inference':
156 | manager.infer()
157 | elif config['mode'] == 'demo_tsne':
158 | manager.demo_tsne()
159 | elif config['mode'] == 'submission_inference':
160 | manager.submission_infer()
161 |
--------------------------------------------------------------------------------
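Putting the flags above together, a typical launch could look like the following; the user name, device ids and overrides are placeholders for your own setup, and the config file is one of those listed under `configs/`.

```bash
# Example: HRNet on Cityscapes on 4 GPUs with the multi-scale + cross-scale
# contrastive loss, overriding the batch size from the command line.
python main.py --config configs/CITYSCAPES/hrnet_contrastive_CTS.json \
               --user theo --device 0 1 2 3 --parallel \
               --loss ms_cs --batch_size 8
```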
/managers/HRNet_Manager.py:
--------------------------------------------------------------------------------
1 | from managers.BaseManager import BaseManager
2 | from utils import to_comb_image, t_get_confusion_matrix, t_normalise_confusion_matrix, t_get_pixel_accuracy, \
3 | get_matrix_fig, to_numpy, t_get_mean_iou, DATASETS_INFO, printlog
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | import numpy as np
8 | import datetime
9 | from models import HRNet
10 | from losses import LossWrapper
11 | from torch.utils.tensorboard.writer import SummaryWriter
12 | from tqdm import tqdm
13 |
14 |
15 |
16 | class HRNetManager(BaseManager):
17 |
18 |     def forward_step(self, img, lbl, **kwargs):
19 | ret = dict()
20 |
21 | skip_mem_update = False
22 |         if 'skip_mem_update' in kwargs:
23 |             skip_mem_update = kwargs['skip_mem_update']
24 |
25 | if isinstance(self.loss, LossWrapper):
26 | if self.return_features:
27 | output, proj_features = self.model(img.float())
28 | loss = self.loss(output, lbl.long(), deep_features=proj_features, epoch=self.epoch, skip_mem_update=skip_mem_update)
29 | else:
30 | output = self.model(img.float())
31 | proj_features = None
32 | loss = self.loss(output, lbl.long(), epoch=self.epoch)
33 |
34 | # get individual loss terms values for logging
35 |             if 'individual_losses' in kwargs:
36 |                 individual_losses = kwargs['individual_losses']
37 | for key in self.loss.loss_vals:
38 | individual_losses[key] += self.loss.loss_vals[key]
39 | ret['individual_losses'] = individual_losses
40 |
41 | else:
42 | # not using the LossWrapper module
43 | output = self.model(img.float())
44 | proj_features = None
45 | loss = self.loss(output, lbl.long())
46 |
47 | ret['output'] = output
48 | ret['interm_output'] = None
49 | ret['feats'] = proj_features
50 | ret['loss'] = loss
51 |
52 | if self.empty_cache:
53 | torch.cuda.empty_cache()
54 | return ret
55 |
56 | def post_process_output(self, img, output, lbl, metadata, skip_label=False):
57 | if metadata and self.dataset in ['PASCALC', 'ADE20K']:
58 | if "pw_ph_stride" in metadata:
59 | # undo padding due to fit_stride resizing
60 | pad_w, pad_h, stride = metadata["pw_ph_stride"]
61 | if pad_h > 0 or pad_w > 0:
62 | output = output[:, :, 0:output.size(2) - pad_h, 0:output.size(3) - pad_w]
63 | lbl = lbl[:, 0:output.size(2) - pad_h, 0:output.size(3) - pad_w]
64 | img = img[:, :, 0:output.size(2) - pad_h, 0:output.size(3) - pad_w]
65 |
66 | if "sh_sw_in_out" in metadata:
67 | if hasattr(self.model, 'module'):
68 | align_corners = self.model.module.align_corners
69 | else:
70 | align_corners = self.model.align_corners
71 | # undo resizing
72 | starting_size = metadata["sh_sw_in_out"][-2]
73 |                 # starting size is (w, h) because PIL uses a width-first size convention
74 | output = F.interpolate(input=output, size=starting_size[-2:][::-1],
75 | mode='bilinear', align_corners=align_corners)
76 | img = F.interpolate(input=img, size=starting_size[-2:][::-1],
77 | mode='bilinear', align_corners=align_corners)
78 | lbl = metadata["original_labels"].squeeze(0).long().cuda()
79 |
80 | return img, output, lbl
81 |
82 | def train_one_epoch(self):
83 | """Train the model for one epoch"""
84 | if self.rank == 0 and self.epoch == 0 and self.parallel:
85 | printlog('worker rank {} : CREATING train_writer'.format(self.rank))
86 | self.train_writer = SummaryWriter(log_dir = self.log_dir / 'train')
87 |
88 | self.model.train()
89 | a = datetime.datetime.now()
90 | running_confusion_matrix = 0
91 | for batch_num, batch in enumerate(self.data_loaders[self.train_schedule[self.epoch]]):
92 | if len(batch) == 2:
93 | img, lbl = batch
94 | else:
95 | img, lbl, metadata = batch
96 | # if self.debugging:
97 | # continue
98 | b = (datetime.datetime.now() - a).total_seconds() * 1000
99 | a = datetime.datetime.now()
100 | img, lbl = img.to(self.device, non_blocking=True), lbl.to(self.device, non_blocking=True)
101 | self.optimiser.zero_grad()
102 | # forward
103 | ret = self.forward_step(img, lbl)
104 | loss = ret['loss']
105 | output = ret['output']
106 | # backward
107 | loss.backward()
108 | self.optimiser.step()
109 | # lr scheduler
110 | if self.scheduler is not None and self.config['train']['lr_batchwise']:
111 | self.scheduler.step()
112 |
113 | if batch_num == 2 and self.debugging:
114 | break
115 |
116 | # logging
117 | confusion_matrix = t_get_confusion_matrix(output, lbl, self.dataset)
118 | running_confusion_matrix += confusion_matrix
119 | pa, pac = t_get_pixel_accuracy(confusion_matrix)
120 | mious = t_get_mean_iou(confusion_matrix, self.config['data']['experiment'],
121 | self.dataset, categories=True, calculate_mean=False, rare=True)
122 | self.train_logging(batch_num, output, img, lbl, mious, loss, pa, pac, b)
123 |
124 | if 'DenseContrastiveLoss' in self.loss.loss_classes:
125 | col_confusion_matrix = t_normalise_confusion_matrix(running_confusion_matrix, mode='col')
126 | self.train_writer.add_figure('train_confusion_matrix/col_normalised',
127 | get_matrix_fig(to_numpy(col_confusion_matrix),
128 | self.config['data']['experiment'],
129 | self.dataset), self.global_step - 1)
130 | self.loss.loss_classes['DenseContrastiveLoss'].update_confusion_matrix(col_confusion_matrix)
131 |
132 | meta = {}
133 | if 'DenseContrastiveLossV3' in self.loss.loss_classes:
134 | meta = self.loss.loss_classes['DenseContrastiveLossV3'].get_meta()
135 | elif 'DenseContrastiveCenters' in self.loss.loss_classes:
136 | meta = self.loss.loss_classes['DenseContrastiveCenters'].get_meta()
137 |
138 | if 'queue_fillings' in meta:
139 | # self.num_real_classes, dtype=torch.long
140 | self.config['queue_fillings'] = meta['queue_fillings']
141 | self.write_info_json()
142 |
143 | if self.scheduler is not None and not self.config['train']['lr_batchwise']:
144 | self.scheduler.step()
145 | self.train_writer.add_scalar('parameters/learning_rate', self.scheduler.get_lr()[0], self.global_step) \
146 | if self.rank == 0 else None
147 |
148 | def validate(self):
149 | """Validate the model on the validation data"""
150 | if self.rank == 0:
151 | # only process with rank 0 runs validation step
152 | if self.epoch == 0 and self.parallel:
153 | printlog(f'\n creating valid_writer ... for process rank {self.rank}')
154 | self.valid_writer = SummaryWriter(log_dir= self.log_dir / 'valid')
155 | else:
156 | return
157 |
158 | if not self.parallel:
159 | torch.backends.cudnn.benchmark = False
160 |
161 | self.model.eval()
162 | valid_loss = 0
163 | confusion_matrix = None
164 | individual_losses = dict()
165 | if isinstance(self.loss, LossWrapper):
166 | for key in self.loss.loss_vals:
167 | individual_losses[key] = 0
168 | if 'DenseContrastiveLossV3' in self.loss.loss_classes: # make loss run the non ddp version for validation
169 | self.loss.loss_classes['DenseContrastiveLossV3'].parallel = False
170 |
171 | with torch.no_grad():
172 | for rec_num, batch in enumerate(tqdm(self.data_loaders['valid_loader'])):
173 | if len(batch) == 2:
174 | img, lbl = batch
175 | metadata = None
176 | else:
177 | img, lbl, metadata = batch
178 | img, lbl = img.to(self.device, non_blocking=True), lbl.to(self.device, non_blocking=True)
179 |
180 | # forward
181 | ret = self.forward_step(img, lbl, individual_losses=individual_losses, skip_mem_update=True)
182 | loss = ret['loss']
183 | output = ret['output']
184 | valid_loss += loss
185 | img, output, lbl = self.post_process_output(img, output, lbl, metadata)
186 |
187 | # logging
188 | confusion_matrix = t_get_confusion_matrix(output, lbl, self.dataset, confusion_matrix)
189 | if rec_num in np.round(np.linspace(0, len(self.data_loaders['valid_loader']) - 1, self.max_valid_imgs)):
190 | lbl_pred = torch.argmax(nn.Softmax2d()(output), dim=1)
191 | self.valid_writer.add_image(
192 | 'valid_images/record_{:02d}'.format(rec_num),
193 | to_comb_image(self.un_norm(img)[0], lbl[0], lbl_pred[0], self.config['data']['experiment'], self.dataset),
194 | self.global_step, dataformats='HWC')
195 | individual_losses= ret['individual_losses'] if 'individual_losses' in ret else individual_losses
196 | if self.debugging and rec_num == 2:
197 | break
198 | valid_loss /= len(self.data_loaders['valid_loader'])
199 | pa, pac = t_get_pixel_accuracy(confusion_matrix)
200 | mious = t_get_mean_iou(confusion_matrix, self.config['data']['experiment'], self.dataset, True, rare=True)
201 | # logging + checkpoint
202 | self.valid_logging(valid_loss, confusion_matrix, individual_losses, mious, pa, pac)
203 |
204 | if not self.parallel:
205 | torch.backends.cudnn.benchmark = True
206 |
207 | if isinstance(self.loss, LossWrapper):
208 | if 'DenseContrastiveLossV3' in self.loss.loss_classes: # reset
209 | self.loss.loss_classes['DenseContrastiveLossV3'].parallel = self.parallel
210 |
--------------------------------------------------------------------------------
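The padding-undo branch in `post_process_output` above is just a crop back to the pre-padding size; a standalone sketch of that arithmetic with illustrative shapes (150 classes and a pad of 8 pixels are assumptions, not values from the repo):

```python
import torch

# Logits padded to a stride-friendly size are cropped back, mirroring the
# metadata["pw_ph_stride"] branch of post_process_output above.
output = torch.rand(1, 150, 520, 520)   # illustrative padded logits
pad_w, pad_h = 8, 8                     # padding recorded by the dataset transform
output = output[:, :, :output.size(2) - pad_h, :output.size(3) - pad_w]
print(output.shape)                     # torch.Size([1, 150, 512, 512])
```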
/managers/__init__.py:
--------------------------------------------------------------------------------
1 | # from .colorization_manager import ColorizationManager
2 | from .DeepLabv3_Manager import DeepLabv3Manager
3 | from .OCRNet_Manager import OCRNetManager
4 | from .HRNet_Manager import HRNetManager
5 |
6 |
--------------------------------------------------------------------------------
/misc/figs/fig1-01-01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RViMLab/ECCV2022-multi-scale-and-cross-scale-contrastive-segmentation/c511fbcde6ac53b72d663225bdf6dded022ca1ce/misc/figs/fig1-01-01.png
--------------------------------------------------------------------------------
/models/Projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from typing import Union
4 | from models.Transformers import SelfAttention
5 | from utils import printlog
6 |
7 | class Projector(nn.Module):
8 |
9 | def __init__(self, config):
10 | """ module that maps encoder features to a d-dimensional space
11 |         it can be a single conv-linear layer, optionally preceded by an fcn with conv-relu layers
12 | """
13 | super().__init__()
14 | self.d = config['d'] if 'd' in config else 128 # projection dim
15 | self.c_in = config['c_in'] # input features channels (usually == output channels of resnet backbone)
16 | assert isinstance(self.c_in, list) or isinstance(self.c_in, int)
17 | # config['mlp'] list of [k,c] for Conv-Relu layers, if empty only applies Conv(c_in, d, k=1)
18 | self.mlp = config['mlp'] if 'mlp' in config else []
19 | self.use_bn = config['use_bn'] if 'use_bn' in config else False
20 | self.transformer = config['trans'] if 'trans' in config else False
21 | self.heads = config['heads'] if 'heads' in config else 1
22 |
23 | if isinstance(self.c_in, list):
24 | self.is_ms = True
25 | self._create_ms_mlp()
26 | else:
27 | self.is_ms = False # whether the projector is multiscale
28 | self._create_mlp(self.c_in)
29 |
30 | def _create_ms_mlp(self):
31 | printlog('** creating ms projector **')
32 | for feat_id, c_in in enumerate(self.c_in):
33 | printlog(f'* scale {feat_id} feats: {c_in}')
34 | self._create_mlp(c_in, feat_id)
35 |
36 | def _create_mlp(self, c_in:int, feat_id:Union[list,int]=''):
37 | # sanity checks
38 | assert(isinstance(self.mlp, list)), 'config["mlp"] must be [[k_1, c_1, s_1], ., [k_n, c_n, s_n]] or [] ' \
39 | 'k_i is kernel (k_i x k_i) c_i is channels and s_i is stride'
40 | first_layer_has_cout_equal_to_cin = False
41 | if len(self.mlp):
42 | for layer in self.mlp:
43 | assert(isinstance(layer, list)), f'elements of layer definition list must be lists instead got {layer}'
44 |                 assert(len(layer) == 3 and layer[2] in [1, 2]), 'must provide a list of 3-element lists ' \
45 |                                                                 '[kernel, channels, stride], instead got {}'.format(layer)
46 |                 if layer[1] > 0:
47 |                     assert(layer[0] < layer[1]), 'kernel size (first element) must be smaller than channels, got {} {}'.format(layer[0], layer[1])
48 |
49 | self.convs = []
50 | c_prev = c_in
51 | if len(self.mlp):
52 | for layer_id, (k, c_out, s) in enumerate(self.mlp):
53 | if layer_id == 0 and c_out==-1:
54 | c_out = c_prev
55 | printlog('Projector creating conv layer, k_{}/c_{}/s_{}'.format(k, c_out, s))
56 | # if use_bn --> do not use bias
57 | # p = (k + (k - 1) * (d - 1) - s + 1) // 2
58 | p = (k - s + 1) // 2
59 | self.convs.append(nn.Conv2d(c_prev, c_out, kernel_size=k, stride=s,
60 | padding=p, bias=not self.use_bn))
61 | self.convs.append(nn.ReLU(inplace=True))
62 | if self.use_bn:
63 | self.convs.append(nn.BatchNorm2d(c_out, momentum=0.0003))
64 | c_prev = c_out
65 | if self.transformer:
66 | sa = SelfAttention(dim = c_prev, heads=self.heads)
67 | printlog('Projector creating transformer layer, heads_{}/c_{}'.format(self.heads, c_prev))
68 | self.convs.append(sa)
69 |
70 | printlog('Projector creating linear layer, k_{}/c_{}/s_{}'.format(1, self.d, 1))
71 | self.convs.append(nn.Conv2d(c_prev, self.d, kernel_size=1, stride=1))
72 | setattr(self, f'project{feat_id}', nn.Sequential(*self.convs))
73 |
74 |     def forward(self, x: Union[list, torch.Tensor]):
75 | # # x are features of shape NCHW
76 | # x = x / torch.norm(x, dim=1, keepdim=True) # Normalise along C: features vectors now lie on unit hypersphere
77 | if self.is_ms:
78 | outs = []
79 | assert(isinstance(x, list) or isinstance(x, tuple)), f'if multiscale projector is used a list is expected as input instead got {type(x)}'
80 | for feat_id, x_i in enumerate(x):
81 | x_i = getattr(self, f'project{feat_id}')(x_i)
82 | outs.append(x_i)
83 | return outs
84 | else:
85 | if isinstance(x, list):
86 | if len(x)==1:
87 | x = x[0]
88 | else:
89 | raise ValueError(f'x is {type(x)}, of length {len(x)}')
90 | x = self.project(x)
91 | return x
92 |
93 |
94 |
95 |
96 | if __name__ == '__main__':
97 | # example
98 | feats1 = torch.rand(size=(2, 1024, 60, 120)).float()
99 | feats0 = torch.rand(size=(2, 2048, 60, 120)).float()
100 |
101 | # proj = Projector({'mlp': [[1,-1, 1], [1, 256, 1]], 'c_in': [512,512,1024,1024], 'd': 128, 'use_bn': True})
102 |
103 | proj = Projector({'mlp': [[1,-1, 1], [1, 256, 1]], 'c_in': 2048, 'd': 128, "trans": True, "heads":1, 'use_bn': True})
104 |
105 | # projected_feats = proj(([feats0]*2 )+([feats1]*2))
106 | p = proj(feats0)
107 | # p_sa = SelfAttention(dim=p.shape[1])(p)
108 | print(p.shape)
109 |
110 | # print(projected_feats.shape)
111 |
112 | # for v, par in proj.named_parameters():
113 | # if par.requires_grad:
114 | # print(v, par.data.shape, par.requires_grad)
115 | # d = proj.state_dict()
116 | # print(d)
117 |
--------------------------------------------------------------------------------
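The `__main__` block above exercises the single-scale path; the sketch below exercises the multi-scale path in the spirit of the `ms_projector` settings used in `main.py`. The per-scale channel counts and feature sizes are illustrative assumptions, and the import assumes the repo is on `PYTHONPATH` with the rest of the `models` package resolvable.

```python
import torch
from models.Projector import Projector

# One projection head per scale; [[1, -1, 1]] keeps the input channels,
# then a final 1x1 conv maps each scale to d = 256.
ms_proj = Projector({'mlp': [[1, -1, 1]], 'c_in': [96, 192, 384, 768], 'd': 256, 'use_bn': True})
feats = [torch.rand(2, c, 64 // 2 ** i, 64 // 2 ** i) for i, c in enumerate([96, 192, 384, 768])]
outs = ms_proj(feats)
print([tuple(o.shape) for o in outs])   # every scale projected to 256 channels
```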
/models/TTAWrapperSlide.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from typing import Union
4 | import datetime
5 | import cv2
6 | import numpy as np
7 | from utils import printlog, to_numpy, to_comb_image, un_normalise
8 | from models import TTAWrapper
9 |
10 |
11 | class TTAWrapperSlide(TTAWrapper):
12 | def __init__(self,
13 | model,
14 | scale_list,
15 | flip=True,
16 | strides:Union[tuple, None]=None,
17 | crop_size:Union[tuple, None]=None,
18 | debug=False):
19 |
20 | super().__init__(model, scale_list, flip)
21 | # self.num_classes = 19
22 | self.num_classes = 150
23 | self.crop_size = crop_size if crop_size else [512,512]
24 | self.strides = strides if strides else self.crop_size # defaults to no-overlapping sliding window
25 | # self.base_size = 2048
26 | self.base_size = 512
27 | self.debug = debug
28 | img_scale = (2048, 512)
29 | self.image_flips = []
30 | self.image_scales = []
31 | if self.flip:
32 | flips = [True, False]
33 | else:
34 | flips = [False]
35 | for s in self.scales:
36 | for f in flips:
37 | self.image_scales.append(((int(img_scale[0]*s), int(img_scale[1]*s)),s))
38 | self.image_flips.append(f)
39 | printlog(f'Sliding window : strides : {self.strides} crop_size {self.crop_size} image_scales: {self.image_scales}')
40 |
41 | def inference(self, image, flip=False, scale=1.0, id_=1):
42 | # image BCHW
43 | assert image.device.type == 'cuda'
44 | size = image.size()
45 | pred = self.model(image)
46 | # done internally in model
47 | # pred = F.interpolate(
48 | # input=pred, size=size[-2:],
49 | # mode='bilinear', align_corners=self.model.align_corners
50 | # )
51 | if flip:
52 | flip_img = to_numpy(image)[:, :, :, ::-1]
53 | flip_output = self.model(torch.from_numpy(flip_img.copy()).cuda())
54 | # flip_output = F.interpolate(
55 | # input=flip_output, size=size[-2:],
56 | # mode='bilinear', align_corners=self.model.align_corners
57 | # )
58 |
59 | flip_pred = to_numpy(flip_output).copy()
60 | flip_pred = torch.from_numpy(flip_pred[:, :, :, ::-1].copy()).cuda()
61 | pred += flip_pred
62 | pred = pred * 0.5
63 | if self.debug:
64 | to_comb_image(un_normalise(image[0]), torch.argmax(pred[0], 0), None, 1, 'ADE20K', save=f'pred_scale_{scale}_{id_}.png')
65 | return pred.exp()
66 |
67 | def multi_scale_aug(self, image, new_shape):
68 | new_h, new_w = new_shape
69 | image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
70 | return image
71 |
72 | def slide_infer(self, new_img, preds, count, scale, flip):
73 | stride_h = int(self.strides[0])
74 | stride_w = int(self.strides[1])
75 | new_h, new_w = new_img.shape[:-1]
76 | rows = int(np.ceil((new_h - self.crop_size[0]) / stride_h)) + 1
77 | cols = int(np.ceil((new_w - self.crop_size[1]) / stride_w)) + 1
78 | # preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda()
79 | # count = torch.zeros([1, 1, new_h, new_w]).cuda()
80 | id_ = 1
81 | for r in range(rows):
82 | for c in range(cols):
83 | h0 = r * stride_h
84 | w0 = c * stride_w
85 | h1 = min(h0 + self.crop_size[0], new_h)
86 | w1 = min(w0 + self.crop_size[1], new_w) # x2
87 | h0 = max(int(h1 - self.crop_size[0]), 0) # y1
88 | w0 = max(int(w1 - self.crop_size[1]), 0) # x1
89 | crop_img = new_img[h0:h1, w0:w1, :]
90 | crop_img = crop_img.transpose((2, 0, 1))
91 | crop_img = np.expand_dims(crop_img, axis=0)
92 | crop_img = torch.from_numpy(crop_img)
93 | pred = self.inference(crop_img.cuda(), flip=flip, scale=scale, id_=id_)
94 | id_ += 1
95 | # print(w0, preds.shape[3] - w1, int(h0), preds.shape[2] - h1)
96 | preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0]
97 | count[:, :, h0:h1, w0:w1] += 1
98 | # preds = preds / count
99 | # preds = preds[:, :, :height, :width]
100 |
101 | return preds, count
102 |
103 | def forward(self, x):
104 | a = datetime.datetime.now()
105 | if isinstance(x, tuple):
106 | x = x[0]
107 |         assert isinstance(x, torch.Tensor), f'x input must be a tensor, instead got {type(x)}'
108 | batch, _, ori_height, ori_width = x.size()
109 | assert batch == 1, "only supporting batchsize 1."
110 | # x is BCHW
111 | image = to_numpy(x)[0].transpose((1, 2, 0)).copy()
112 | # x is HWC
113 | stride_h = int(self.strides[0] * 1.0)
114 | stride_w = int(self.strides[1] * 1.0)
115 |
116 | final_pred = torch.zeros([1, self.num_classes, ori_height, ori_width]).cuda()
117 | for flip, (shape, scale) in zip(self.image_flips, self.image_scales) :
118 | new_img = self.multi_scale_aug(image, new_shape=shape)
119 | height, width = new_img.shape[:-1]
120 | # if scale < 1.0:
121 | # new_img = new_img.transpose((2, 0, 1))
122 | # new_img = np.expand_dims(new_img, axis=0)
123 | # new_img = torch.from_numpy(new_img)
124 | # preds = self.inference(new_img.cuda(), flip=True, scale=scale, id_=1)
125 | # preds = preds[:, :, 0:height, 0:width]
126 | # else:
127 | new_h, new_w = new_img.shape[:-1]
128 | preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda()
129 | count = torch.zeros([1, 1, new_h, new_w]).cuda()
130 |
131 | preds, count = self.slide_infer(new_img, preds, count, scale, flip)
132 |
133 | preds = preds / count
134 | preds = preds[:, :, :height, :width]
135 |
136 | preds = F.interpolate(
137 | preds, (ori_height, ori_width),
138 | mode='bilinear', align_corners=self.model.align_corners
139 | )
140 |
141 | final_pred += preds
142 | if self.debug:
143 | to_comb_image(un_normalise(x[0]), torch.argmax(final_pred[0], 0), None, 1, 'ADE20K', save=f'final.png')
144 |
145 | b = (datetime.datetime.now() - a).total_seconds() * 1000
146 | print(f'\r time:{b}')
147 | return final_pred
148 |
149 |
150 | if __name__ == '__main__':
151 | import pickle
152 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop
153 | # from models.SegFormer import SegFormer
154 | from models.UPerNet import UPerNet
155 | import cv2
156 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix
157 |
158 | file = open('..\\ade20k_img.pkl', 'rb')
159 | img = pickle.load(file)
160 | file.close()
161 |
162 | # path_to_chkpt = '..\\logging\\ADE20K\\20220326_185031_e1__upn_ConvNextT_sbn_DCms_cs_epochs127_bs16\\chkpts\\chkpt_epoch_126.pt'
163 | map_location = 'cuda:0'
164 | # checkpoint = torch.load(str(path_to_chkpt), map_location)
165 |
166 | config = dict()
167 | config.update({'backbone': 'ConvNextT', 'out_stride': 32, 'pretrained': False, 'dataset':'ADE20K',
168 | 'pretrained_res':224, 'pretrained_dataset':'22k' , 'align_corners':False})
169 | model = UPerNet(config, 1)
170 |
171 | # ret = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False)
172 | # print(ret)
173 | T = Compose([
174 | ToTensor(),
175 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
176 | # RandomCrop(size=[512, 512])
177 | ])
178 |
179 | with torch.no_grad():
180 |
181 | tta_model = TTAWrapperSlide(model, scale_list=[0.5, 1], crop_size=(512, 512),
182 | strides=(326, 326), debug=True) # [0.75, 1.25, 1.5, 1.75, 2, 1.0]
183 | tta_model.cuda()
184 | tta_model.eval()
185 | x = T(img)
186 | # x = x.cuda().float()
187 | y = tta_model.forward(x.unsqueeze(0).float())
188 | c = 1
--------------------------------------------------------------------------------
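The tiling in `slide_infer` follows the usual ceil-grid rule; a standalone sketch of the grid arithmetic with the crop/stride values from the `__main__` example above (the resized image size is illustrative):

```python
import numpy as np

crop_size, strides = (512, 512), (326, 326)
new_h, new_w = 512, 2048            # illustrative resized image size
rows = int(np.ceil((new_h - crop_size[0]) / strides[0])) + 1
cols = int(np.ceil((new_w - crop_size[1]) / strides[1])) + 1
print(rows, cols)                   # 1 x 6 overlapping crops; logits are summed, then divided by `count`
```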
/models/TTA_wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.parallel import DistributedDataParallel as ddp
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import datetime
6 | import cv2
7 | from utils import printlog
8 |
9 |
10 | class TTAWrapper(nn.Module):
11 | """
12 | hard-coding common scaling and flipping protocol for simplicity
13 | """
14 | def __init__(self, model, scale_list=None, flip=True):
15 | super().__init__()
16 |         self.scales = scale_list if scale_list is not None else []  # e.g. [0.5, 0.75, 1.25, 1.5, 1.75, 2]
17 | self.flip = flip
18 | if 1.0 not in self.scales:
19 | self.scales.append(1.0)
20 | if isinstance(model, ddp):
21 |             self.model = model.module
22 | else:
23 | self.model = model
24 |
25 | self.align_corners = self.model.align_corners if hasattr(self.model, 'align_corners') else True
26 |
27 | printlog(f'*** TTA wrapper with flip : [{flip}] --- scales : {self.scales} -- align_corners:{self.align_corners}')
28 |
29 | def maybe_resize(self, x, scale, in_shape):
30 | """
31 |
32 | :param x: B,C,H,W
33 | :param scale: if s in R+ resizes the image to s*in_shape,
34 | if s=1 then return x,
35 | if s=-1 then resize image to in_shape
36 | :param in_shape:
37 | :return:
38 | """
39 | scaled_shape = [int(scale * in_shape[0]), int(scale * in_shape[1])]
40 | if scale != 1.0 and scale > 0:
41 | x = F.interpolate(x, size=scaled_shape, mode='bilinear', align_corners=self.align_corners)
42 | elif scale == -1:
43 | x = F.interpolate(x, size=in_shape, mode='bilinear', align_corners=self.align_corners)
44 | else:
45 | x = x.clone()
46 | return x
47 |
48 | def maybe_flip(self, x, f):
49 | if f == 0:
50 | x_f = torch.flip(x, dims=[3]) # clones
51 | else:
52 | x_f = x.clone()
53 | return x_f
54 |
55 | def forward(self, x, **kwargs):
56 | if isinstance(x, tuple):
57 | x = x[0]
58 |         assert isinstance(x, torch.Tensor), f'x input must be a tensor, instead got {type(x)}'
59 |
60 | a = datetime.datetime.now()
61 | assert len(x.shape)==4, 'input must be B,C,H,W'
62 |         flag_first = True  # flag for the first iteration of the nested loop
63 | in_shape=x.shape[2:4]
64 | out_shape = [1, self.model.num_classes] + list(in_shape)
65 | y_merged = torch.zeros(size=out_shape).cuda()
66 | for f in range(2):
67 | x_f = self.maybe_flip(x, f) # flip
68 | for s in self.scales:
69 | x_f_s = self.maybe_resize(x_f, s, in_shape) # resize
70 | y = self.model(x_f_s) # forward
71 | y = self.maybe_flip(y, f) # unflip
72 | y_merged += self.maybe_resize(y, -1, in_shape) # un-resize
73 |
74 | b = (datetime.datetime.now() - a).total_seconds() * 1000
75 | # print('time taken for tta {:.5f}'.format(b))
76 | y_merged = y_merged/(2*len(self.scales))
77 | # cv2.imshow('final', to_comb_image(un_normalise(x[0]), torch.argmax(y_merged[0], 0), None, 1, 'CITYSCAPES'))
78 | return y_merged
79 |
80 |
81 |
82 |
83 |
84 | if __name__ == '__main__':
85 | import pickle
86 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop
87 | from models.HRNet import HRNet
88 | from models.UPerNet import UPerNet
89 | import cv2
90 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix
91 |
92 | file = open('..\\img_cts.pkl', 'rb')
93 | img = pickle.load(file)
94 | file.close()
95 |
96 | path_to_chkpt = '..\\logging\\ADE20K\\20220326_185031_e1__upn_ConvNextT_sbn_DCms_cs_epochs127_bs16\\chkpts\\chkpt_epoch_126.pt'
97 | map_location = 'cuda:0'
98 | checkpoint = torch.load(str(path_to_chkpt), map_location)
99 |
100 | config = dict()
101 | config.update({'backbone': 'ConvNextT', 'out_stride': 32, 'pretrained': True, 'dataset':'ADE20K'})
102 | model = UPerNet(config, 1)
103 |
104 | ret = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False)
105 | print(ret)
106 | T = Compose([
107 | ToTensor(),
108 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
109 | RandomCrop(size=[512, 512])
110 | ])
111 |
112 | with torch.no_grad():
113 | tta_model = TTAWrapper(model, scale_list=[0.5, 1.5])
114 | tta_model.cuda()
115 | tta_model.eval()
116 | x = T(img)
117 | x = x.cuda().float()
118 | y = tta_model.forward(x.unsqueeze(0))
119 | c = 1
120 |
121 |
--------------------------------------------------------------------------------
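For reference, the merge performed in `TTAWrapper.forward` above averages logits over both the flipped and unflipped input and over all scales; in symbols (M is the wrapped model, S the scale list; notation ours, not from the paper, and flipping is its own inverse):

```latex
y_{\mathrm{merged}} \;=\; \frac{1}{2\,|S|}\sum_{f\in\{\mathrm{id},\,\mathrm{hflip}\}}\;\sum_{s\in S}
\mathrm{resize}^{-1}\!\Big(f\big(M\big(\mathrm{resize}_s\big(f(x)\big)\big)\big)\Big)
```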
/models/TTA_wrapper_CTS.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from typing import Union
4 | import datetime
5 | import cv2
6 | import numpy as np
7 | from utils import printlog, to_numpy, to_comb_image, un_normalise
8 | from models import TTAWrapper
9 |
10 |
11 | class TTAWrapperCTS(TTAWrapper):
12 | def __init__(self,
13 | model,
14 | scale_list,
15 | flip=True,
16 | strides:Union[tuple, None]=None,
17 | crop_size:Union[tuple, None]=None,
18 | debug=False):
19 |
20 | super().__init__(model, scale_list, flip)
21 | self.num_classes = 19
22 | self.crop_size = crop_size if crop_size else [512,1024]
23 | self.strides = strides if strides else self.crop_size # defaults to no-overlapping sliding window
24 | self.base_size = 2048
25 | self.debug = debug
26 |
27 | printlog(f'Sliding window : strides : {self.strides} crop_size {self.crop_size}')
28 |
29 | def inference(self, image, flip=False, scale=1.0, id_=1):
30 | # image BCHW
31 | assert image.device.type == 'cuda'
32 | size = image.size()
33 | pred = self.model(image)
34 | # done internally in model
35 | # pred = F.interpolate(
36 | # input=pred, size=size[-2:],
37 | # mode='bilinear', align_corners=self.model.align_corners
38 | # )
39 | if flip:
40 | flip_img = to_numpy(image)[:, :, :, ::-1]
41 | flip_output = self.model(torch.from_numpy(flip_img.copy()).cuda())
42 | # flip_output = F.interpolate(
43 | # input=flip_output, size=size[-2:],
44 | # mode='bilinear', align_corners=self.model.align_corners
45 | # )
46 |
47 | flip_pred = to_numpy(flip_output).copy()
48 | flip_pred = torch.from_numpy(flip_pred[:, :, :, ::-1].copy()).cuda()
49 | pred += flip_pred
50 | pred = pred * 0.5
51 | if self.debug:
52 | to_comb_image(un_normalise(image[0]), torch.argmax(pred[0], 0), None, 1, 'ADE20K', save=f'pred_scale_{scale}_{id_}.png')
53 | return pred.exp()
54 |
55 | def multi_scale_aug(self, image, label=None,
56 | rand_scale=1, rand_crop=True):
57 |
58 | long_size = int(self.base_size * rand_scale + 0.5)
59 | h, w = image.shape[:2]
60 | if h > w:
61 | new_h = long_size
62 | new_w = int(w * long_size / h + 0.5)
63 | else:
64 | new_w = long_size
65 | new_h = int(h * long_size / w + 0.5)
66 |
67 | image = cv2.resize(image, (new_w, new_h),
68 | interpolation=cv2.INTER_LINEAR)
69 | if label is not None:
70 | label = cv2.resize(label, (new_w, new_h),
71 | interpolation=cv2.INTER_NEAREST)
72 | else:
73 | return image
74 |
75 | if rand_crop:
76 | image, label = self.rand_crop(image, label)
77 |
78 | return image, label
79 |
80 | def forward(self, x):
81 | a = datetime.datetime.now()
82 | if isinstance(x, tuple):
83 | x = x[0]
84 |         assert isinstance(x, torch.Tensor), f'x input must be a tensor, instead got {type(x)}'
85 | batch, _, ori_height, ori_width = x.size()
86 | assert batch == 1, "only supporting batchsize 1."
87 | # x is BCHW
88 | image = to_numpy(x)[0].transpose((1, 2, 0)).copy()
89 | # x is HWC
90 | stride_h = int(self.strides[0] * 1.0)
91 | stride_w = int(self.strides[1] * 1.0)
92 |
93 | final_pred = torch.zeros([1, self.num_classes,
94 | ori_height, ori_width]).cuda()
95 |
96 | for scale in self.scales:
97 | new_img = self.multi_scale_aug(image=image,
98 | rand_scale=scale,
99 | rand_crop=False)
100 | # cv2.imshow(f'scale {scale}', new_img)
101 | height, width = new_img.shape[:-1]
102 |
103 | if scale < 1.0:
104 | new_img = new_img.transpose((2, 0, 1))
105 | new_img = np.expand_dims(new_img, axis=0)
106 | new_img = torch.from_numpy(new_img)
107 | preds = self.inference(new_img.cuda(), flip=True, scale=scale, id_=1)
108 | preds = preds[:, :, 0:height, 0:width]
109 | else:
110 | new_h, new_w = new_img.shape[:-1]
111 | rows = int(np.ceil(1.0 * (new_h - self.crop_size[0]) / stride_h)) + 1
112 | cols = int(np.ceil(1.0 * (new_w - self.crop_size[1]) / stride_w)) + 1
113 | preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda()
114 | count = torch.zeros([1, 1, new_h, new_w]).cuda()
115 | id_ = 1
116 | for r in range(rows):
117 | for c in range(cols):
118 | h0 = r * stride_h
119 | w0 = c * stride_w
120 | h1 = min(h0 + self.crop_size[0], new_h)
121 | w1 = min(w0 + self.crop_size[1], new_w)
122 | h0 = max(int(h1 - self.crop_size[0]), 0)
123 | w0 = max(int(w1 - self.crop_size[1]), 0)
124 | crop_img = new_img[h0:h1, w0:w1, :]
125 | crop_img = crop_img.transpose((2, 0, 1))
126 | crop_img = np.expand_dims(crop_img, axis=0)
127 | crop_img = torch.from_numpy(crop_img)
128 | pred = self.inference(crop_img.cuda(), flip=self.flip, scale=scale, id_= id_)
129 | id_ += 1
130 |
131 | preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0]
132 | count[:, :, h0:h1, w0:w1] += 1
133 | preds = preds / count
134 | preds = preds[:, :, :height, :width]
135 |
136 | preds = F.interpolate(
137 | preds, (ori_height, ori_width),
138 | mode='bilinear', align_corners=self.model.align_corners
139 | )
140 |
141 | final_pred += preds
142 | if self.debug:
143 | to_comb_image(un_normalise(x[0]), torch.argmax(final_pred[0], 0), None, 1, 'CITYSCAPES', save=f'final.png')
144 |
145 | b = (datetime.datetime.now() - a).total_seconds() * 1000
146 | print(f'\r time:{b}')
147 | return final_pred
148 |
149 |
150 | if __name__ == '__main__':
151 | import pickle
152 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop
153 | # from models.SegFormer import SegFormer
154 | from models.UPerNet import UPerNet
155 | import cv2
156 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix
157 |
158 | file = open('..\\ade20k_img.pkl', 'rb')
159 | img = pickle.load(file)
160 | file.close()
161 |
162 | path_to_chkpt = '..\\logging\\ADE20K\\20220326_185031_e1__upn_ConvNextT_sbn_DCms_cs_epochs127_bs16\\chkpts\\chkpt_epoch_126.pt'
163 | map_location = 'cuda:0'
164 | checkpoint = torch.load(str(path_to_chkpt), map_location)
165 |
166 | config = dict()
167 |
168 | config.update({'backbone': 'ConvNextT', 'out_stride': 32, 'pretrained': False, 'dataset':'ADE20K',
169 | 'pretrained_res':224, 'pretrained_dataset':'22k' , 'align_corners':False})
170 |
171 | model = UPerNet(config, 1)
172 |
173 | ret = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False)
174 | print(ret)
175 | T = Compose([
176 | ToTensor(),
177 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
178 | # RandomCrop(size=[512, 512])
179 | ])
180 |
181 | with torch.no_grad():
182 |
183 | tta_model = TTAWrapperCTS(model, scale_list=[0.5], crop_size=(512, 512), strides=(341, 341), debug=True) # [0.75, 1.25, 1.5, 1.75, 2, 1.0]
184 | tta_model.cuda()
185 | tta_model.eval()
186 | x = T(img)
187 | # x = x.cuda().float()
188 | y = tta_model.forward(x.unsqueeze(0).float())
189 | c = 1
190 |
--------------------------------------------------------------------------------
/models/TTA_wrapper_PC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.parallel import DistributedDataParallel as ddp
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import datetime
6 | import cv2
7 | import numpy as np
8 | from utils import printlog, to_numpy, to_comb_image
9 | from models import TTAWrapper
10 |
11 |
12 | class TTAWrapperPC(TTAWrapper):
13 | def __init__(self, model, scale_list):
14 | super().__init__(model, scale_list)
15 | self.num_classes = 59
16 | self.crop_size = [512,512]
17 | self.base_size = 520
18 |
19 | def inference(self, image, flip=False, scale=1.0, id_=1):
20 | # image BCHW
21 | assert image.device.type == 'cuda'
22 | size = image.size()
23 | pred = self.model(image)
24 | # done internally in model
25 | # pred = F.interpolate(
26 | # input=pred, size=size[-2:],
27 | # mode='bilinear', align_corners=self.model.align_corners
28 | # )
29 | if flip:
30 | flip_img = to_numpy(image)[:, :, :, ::-1]
31 | flip_output = self.model(torch.from_numpy(flip_img.copy()).cuda())
32 | # flip_output = F.interpolate(
33 | # input=flip_output, size=size[-2:],
34 | # mode='bilinear', align_corners=self.model.align_corners
35 | # )
36 |
37 | flip_pred = to_numpy(flip_output).copy()
38 | flip_pred = torch.from_numpy(flip_pred[:, :, :, ::-1].copy()).cuda()
39 | pred += flip_pred
40 | pred = pred * 0.5
41 |
42 | # to_comb_image(un_normalise(image[0]), torch.argmax(pred[0], 0), None, 1, 'PASCALC', save=f'pred_{self.ind}_scale_{scale}_{id_}.png')
43 | return pred.exp()
44 |
45 | def multi_scale_aug(self, image, label=None,
46 | rand_scale=1, rand_crop=True):
47 |
48 | long_size = int(self.base_size * rand_scale + 0.5)
49 | h, w = image.shape[:2]
50 | if h > w:
51 | new_h = long_size
52 | new_w = int(w * long_size / h + 0.5)
53 | else:
54 | new_w = long_size
55 | new_h = int(h * long_size / w + 0.5)
56 |
57 | image = cv2.resize(image, (new_w, new_h),
58 | interpolation=cv2.INTER_LINEAR)
59 | if label is not None:
60 | label = cv2.resize(label, (new_w, new_h),
61 | interpolation=cv2.INTER_NEAREST)
62 | else:
63 | return image
64 |
65 | if rand_crop:
66 | image, label = self.rand_crop(image, label)
67 |
68 | return image, label
69 |
70 | def pad_image(self, image, h, w, size, padvalue):
71 | pad_image = image.copy()
72 | pad_h = max(size[0] - h, 0)
73 | pad_w = max(size[1] - w, 0)
74 | if pad_h > 0 or pad_w > 0:
75 | pad_image = cv2.copyMakeBorder(image, 0, pad_h, 0,
76 | pad_w, cv2.BORDER_CONSTANT,
77 | value=padvalue)
78 |
79 | return pad_image
80 |
81 | def forward(self, x, ind=0):
82 | self.ind = ind
83 | a = datetime.datetime.now()
84 | if isinstance(x, tuple):
85 | x = x[0]
86 |         assert isinstance(x, torch.Tensor), f'x input must be a tensor, instead got {type(x)}'
87 | batch, _, ori_height, ori_width = x.size()
88 | assert batch == 1, "only supporting batchsize 1."
89 | # x is BCHW
90 |         image = to_numpy(x)[0].transpose((1, 2, 0)).copy()
91 | # x is HWC
92 | stride_h = int(self.crop_size[0] * 2.0/3.0)
93 | stride_w = int(self.crop_size[1] * 2.0/3.0)
94 | final_pred = torch.zeros([1, self.num_classes,
95 | ori_height, ori_width]).cuda()
96 |
97 | mean = [0.485, 0.456, 0.406]
98 | std = [0.229, 0.224, 0.225]
99 | padvalue = -1.0 * np.array(mean)/np.array(std)
100 |
101 | for scale in self.scales:
102 | new_img = self.multi_scale_aug(image=image,
103 | rand_scale=scale,
104 | rand_crop=False)
105 | # cv2.imshow(f'scale {scale}', new_img)
106 | height, width = new_img.shape[:-1]
107 |
108 | if max(height, width) <= np.min(self.crop_size):
109 | new_img = self.pad_image(new_img, height, width,
110 | self.crop_size, padvalue)
111 | new_img = new_img.transpose((2, 0, 1))
112 | new_img = np.expand_dims(new_img, axis=0)
113 | new_img = torch.from_numpy(new_img)
114 | preds = self.inference(new_img.cuda(), flip=True, scale=scale, id_=1)
115 | preds = preds[:, :, 0:height, 0:width]
116 | else:
117 |
118 | if height < self.crop_size[0] or width < self.crop_size[1]:
119 | new_img = self.pad_image(new_img, height, width,
120 | self.crop_size, padvalue)
121 |
122 | new_h, new_w = new_img.shape[:-1]
123 | rows = int(np.ceil(1.0 * (new_h - self.crop_size[0]) / stride_h)) + 1
124 | cols = int(np.ceil(1.0 * (new_w - self.crop_size[1]) / stride_w)) + 1
125 | preds = torch.zeros([1, self.num_classes, new_h, new_w]).cuda()
126 | count = torch.zeros([1, 1, new_h, new_w]).cuda()
127 | id_ = 1
128 | for r in range(rows):
129 | for c in range(cols):
130 | h0 = r * stride_h
131 | w0 = c * stride_w
132 | h1 = min(h0 + self.crop_size[0], new_h)
133 | w1 = min(w0 + self.crop_size[1], new_w)
134 | # h0 = max(int(h1 - self.crop_size[0]), 0)
135 | # w0 = max(int(w1 - self.crop_size[1]), 0)
136 | crop_img = new_img[h0:h1, w0:w1, :]
137 |
138 | if h1 == new_h or w1 == new_w:
139 | crop_img = self.pad_image(crop_img,
140 | h1-h0,
141 | w1-w0,
142 | self.crop_size,
143 | padvalue)
144 |
145 | crop_img = crop_img.transpose((2, 0, 1))
146 | crop_img = np.expand_dims(crop_img, axis=0)
147 | crop_img = torch.from_numpy(crop_img)
148 | pred = self.inference(crop_img.cuda(), flip=True, scale=scale, id_= id_)
149 | id_ += 1
150 |
151 | preds[:, :, h0:h1, w0:w1] += pred[:, :, 0:h1 - h0, 0:w1 - w0]
152 | count[:, :, h0:h1, w0:w1] += 1
153 | preds = preds / count
154 | preds = preds[:, :, :height, :width]
155 |
156 | preds = F.interpolate(
157 | preds, (ori_height, ori_width),
158 | mode='bilinear', align_corners=self.model.align_corners
159 | )
160 |
161 | final_pred += preds
162 |
163 | # final_pred = F.interpolate(
164 | # final_pred, ori_shape,
165 | # mode='bilinear', align_corners=self.model.align_corners
166 | # )
167 |
168 | # to_comb_image(un_normalise(x[0]), torch.argmax(final_pred[0], 0), None, 1, 'PASCALC', save=f'final_{self.ind}.png')
169 | b = (datetime.datetime.now() - a).total_seconds() * 1000
170 | print(f'\r time:{b}')
171 | return final_pred
172 |
173 |
174 | if __name__ == '__main__':
175 | import pickle
176 | from torchvision.transforms import Normalize, ToTensor, Compose, RandomCrop
177 | from models.HRNet import HRNet
178 | import cv2
179 | from utils import to_numpy, to_comb_image, un_normalise, check_module_prefix
180 | from datasets import PascalC
181 | data_path = r'C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC/'
182 | from torchvision.transforms import ToTensor
183 | import PIL.Image as Image
184 | d = {"dataset":'PASCALC', "experiment":1}
185 |
186 | # file = open('..\\img_cts.pkl', 'rb')
187 | # img = pickle.load(file)
188 | # file.close()
189 |
190 | from utils import parse_transform_lists
191 | import json
192 | path_to_config = '../configs/dlv3_contrastive_PC.json'
193 | with open(path_to_config, 'r') as f:
194 | config = json.load(f)
195 |
196 | transforms_list_val = config['data']['transforms_val']
197 | # if 'torchvision_normalise' in transforms_list_val:
198 | # del transforms_list_val[-1]
199 | transforms_values_val = config['data']['transform_values_val']
200 | transforms_dict_val = parse_transform_lists(transforms_list_val, transforms_values_val, dataset='PASCALC', experiment=1)
201 | valid_set = PascalC(root=data_path,
202 | debug=False,
203 | split='val',
204 | transforms_dict=transforms_dict_val)
205 |
206 | issues = []
207 | valid_set.return_filename = True
208 |
209 | # if i ==5:
210 | # break
211 | path_to_chkpt = '..\\logging/PASCALC/20211216_072315_e1__hrn_200epochs_hr48_sbn_DCms_cs/chkpts/chkpt_epoch_199.pt'
212 | # path_to_chkpt = '..\\logging/PASCALC/20211215_213857_e1__hrn_200epochs_hr48_sbn_CE/chkpts/chkpt_epoch_199.pt'
213 | map_location = 'cuda:0'
214 | checkpoint = torch.load(str(path_to_chkpt), map_location)
215 | torch.manual_seed(0)
216 | config = dict()
217 | config.update({'backbone': 'hrnet48', 'out_stride': 4, 'pretrained': True, 'dataset':'PASCALC'})
218 | model = HRNet(config, 1)
219 | msg = model.load_state_dict(check_module_prefix(checkpoint['model_state_dict'], model), strict=False)
220 | print(msg)
221 |
222 | for i, ret in enumerate(valid_set):
223 | img = ret[0]
224 | ori_shape = ret[2]['original_labels'].shape[-2:]
225 | with torch.no_grad():
226 | tta_model = TTAWrapperPC(model, scale_list=[0.75, 0.5, 1.5])
227 | tta_model.cuda()
228 | tta_model.eval()
229 | x = img
230 |             y = tta_model.forward(x.unsqueeze(0).float(), ind=i)  # forward() takes (x, ind); ori_shape is unused here
231 | c = 1
--------------------------------------------------------------------------------
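A note on the pad value used in `TTAWrapperPC` above: padding the already-normalised image with `-mean/std` is equivalent to padding the raw image with black (zero-valued) pixels before normalisation. A quick standalone check:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
padvalue = -1.0 * mean / std
# a raw pixel value of 0 maps to (0 - mean) / std after normalisation
print(np.allclose(padvalue, (0.0 - mean) / std))   # True
```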
/models/Transformers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | # from lib.models.backbones.vit.helper import IntermediateSequential
3 | import torch
4 |
5 | class SelfAttention(nn.Module):
6 | def __init__(
7 | self, dim, heads=1, qkv_bias=False, qk_scale=None, dropout_rate=0.0
8 | ):
9 | super().__init__()
10 | self.num_heads = heads
11 | head_dim = dim // heads
12 | self.scale = qk_scale or head_dim ** -0.5
13 |
14 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
15 |         self.dropout_rate = dropout_rate
16 | self.attn_drop = nn.Dropout(dropout_rate)
17 | self.proj = nn.Linear(dim, dim)
18 | self.proj_drop = nn.Dropout(dropout_rate)
19 |
20 | def forward(self, x, unflatten_output=True):
21 | H,W,was_flattened = -1,-1, False
22 | if len(x.shape)==4:
23 | was_flattened=unflatten_output
24 | B,C,H,W = x.shape
25 |             x = x.permute(0, 2, 3, 1).reshape(B, -1, C)  # B,C,H,W --> B,H,W,C --> B,HW,C
26 |
27 | B, N, C = x.shape
28 | qkv = (
29 | self.qkv(x)
30 | .reshape(B, N, 3, self.num_heads, C // self.num_heads)
31 | .permute(2, 0, 3, 1, 4)
32 | )
33 | q, k, v = (
34 | qkv[0],
35 | qkv[1],
36 | qkv[2],
37 | ) # make torchscript happy (cannot use tensor as tuple)
38 |
39 | attn = (q @ k.transpose(-2, -1)) * self.scale
40 | attn = attn.softmax(dim=-1)
41 | if self.dropout_rate > 0.0:
42 | attn = self.attn_drop(attn)
43 |
44 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
45 | x = self.proj(x)
46 | if self.dropout_rate > 0.0:
47 | x = self.proj_drop(x)
48 | if was_flattened:
49 | return x.view(B,H,W,C).permute(0,3,1,2) # B,HW,C --> B,H,W,C --> B,C,H,W
50 | return x
51 |
52 |
53 | class Residual(nn.Module):
54 | def __init__(self, fn):
55 | super().__init__()
56 | self.fn = fn
57 |
58 | def forward(self, x):
59 | return self.fn(x) + x
60 |
61 |
62 | class PreNorm(nn.Module):
63 | def __init__(self, dim, fn):
64 | super().__init__()
65 | self.norm = nn.LayerNorm(dim)
66 | self.fn = fn
67 |
68 | def forward(self, x):
69 | return self.fn(self.norm(x))
70 |
71 |
72 | class PreNormDrop(nn.Module):
73 | def __init__(self, dim, dropout_rate, fn):
74 | super().__init__()
75 | self.norm = nn.LayerNorm(dim)
76 | self.dropout = nn.Dropout(p=dropout_rate)
77 | self.fn = fn
78 |
79 | def forward(self, x):
80 | return self.dropout(self.fn(self.norm(x)))
81 |
82 |
83 | class FeedForward(nn.Module):
84 | def __init__(self, dim, hidden_dim, dropout_rate):
85 | super().__init__()
86 | self.net = nn.Sequential(
87 | nn.Linear(dim, hidden_dim),
88 | nn.GELU(),
89 | nn.Dropout(p=dropout_rate),
90 | nn.Linear(hidden_dim, dim),
91 | nn.Dropout(p=dropout_rate),
92 | )
93 |
94 | def forward(self, x):
95 | return self.net(x)
96 |
97 |
98 | class TransformerModel(nn.Module):
99 | def __init__(
100 | self,
101 | dim,
102 | depth,
103 | heads,
104 | mlp_dim,
105 | dropout_rate=0.1,
106 | attn_dropout_rate=0.1,
107 | ):
108 | super().__init__()
109 | layers = []
110 | for _ in range(depth):
111 | layers.extend(
112 | [
113 | Residual(
114 | PreNormDrop(
115 | dim,
116 | dropout_rate,
117 | SelfAttention(
118 | dim, heads=heads, dropout_rate=attn_dropout_rate
119 | ),
120 | )
121 | ),
122 | Residual(
123 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
124 | ),
125 | ]
126 | )
127 | self.net = torch.nn.Sequential(*layers)
128 |
129 | def forward(self, x):
130 | return self.net(x)
--------------------------------------------------------------------------------
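Usage sketch for models/Transformers.py (not part of the repository; assumes torch is installed and the repo root is on PYTHONPATH, with made-up tensor sizes). ```SelfAttention``` accepts either a B,C,H,W feature map, which it flattens and restores internally, or a B,HW,C token sequence, while ```TransformerModel``` stacks pre-norm attention and feed-forward blocks with residual connections:

```python
import torch
from models.Transformers import SelfAttention, TransformerModel

feat = torch.randn(2, 256, 16, 16)                 # B, C, H, W feature map (sizes are arbitrary)
attn = SelfAttention(dim=256, heads=4)
out = attn(feat)                                   # flattened to B, HW, C internally, reshaped back
print(out.shape)                                   # torch.Size([2, 256, 16, 16])

tokens = feat.flatten(2).transpose(1, 2)           # B, HW, C token sequence
encoder = TransformerModel(dim=256, depth=2, heads=4, mlp_dim=512)
print(encoder(tokens).shape)                       # torch.Size([2, 256, 256])
```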
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .UPerNet import UPerNet
2 | from .DeepLabv3 import DeepLabv3
3 | from .OCR import OCRNet
4 | from .Projector import Projector
5 | from .HRNet import hrnet48, hrnet32, hrnet18, HRNet
6 | from .TTA_wrapper import TTAWrapper
7 | from .TTA_wrapper_CTS import TTAWrapperCTS
8 | from .TTA_wrapper_PC import TTAWrapperPC
9 | from .TTAWrapperSlide import TTAWrapperSlide
10 | from .Transformers import SelfAttention
11 | from .Swin import SwinTransformer
12 |
13 |
--------------------------------------------------------------------------------
/models/hrnet_config.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Create by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn), Rainbowsecret (yuyua@microsoft.com)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | from utils import EasyDict as CN
13 |
14 | # configs for HRNet48
15 | HRNET_48 = CN()
16 | HRNET_48.FINAL_CONV_KERNEL = 1
17 |
18 | HRNET_48.STAGE1 = CN()
19 | HRNET_48.STAGE1.NUM_MODULES = 1
20 | HRNET_48.STAGE1.NUM_BRANCHES = 1
21 | HRNET_48.STAGE1.NUM_BLOCKS = [4]
22 | HRNET_48.STAGE1.NUM_CHANNELS = [64]
23 | HRNET_48.STAGE1.BLOCK = 'BOTTLENECK'
24 | HRNET_48.STAGE1.FUSE_METHOD = 'SUM'
25 |
26 | HRNET_48.STAGE2 = CN()
27 | HRNET_48.STAGE2.NUM_MODULES = 1
28 | HRNET_48.STAGE2.NUM_BRANCHES = 2
29 | HRNET_48.STAGE2.NUM_BLOCKS = [4, 4]
30 | HRNET_48.STAGE2.NUM_CHANNELS = [48, 96]
31 | HRNET_48.STAGE2.BLOCK = 'BASIC'
32 | HRNET_48.STAGE2.FUSE_METHOD = 'SUM'
33 |
34 | HRNET_48.STAGE3 = CN()
35 | HRNET_48.STAGE3.NUM_MODULES = 4
36 | HRNET_48.STAGE3.NUM_BRANCHES = 3
37 | HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4]
38 | HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192]
39 | HRNET_48.STAGE3.BLOCK = 'BASIC'
40 | HRNET_48.STAGE3.FUSE_METHOD = 'SUM'
41 |
42 | HRNET_48.STAGE4 = CN()
43 | HRNET_48.STAGE4.NUM_MODULES = 3
44 | HRNET_48.STAGE4.NUM_BRANCHES = 4
45 | HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
46 | HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384]
47 | HRNET_48.STAGE4.BLOCK = 'BASIC'
48 | HRNET_48.STAGE4.FUSE_METHOD = 'SUM'
49 |
50 |
51 | # configs for HRNet32
52 | HRNET_32 = CN()
53 | HRNET_32.FINAL_CONV_KERNEL = 1
54 |
55 | HRNET_32.STAGE1 = CN()
56 | HRNET_32.STAGE1.NUM_MODULES = 1
57 | HRNET_32.STAGE1.NUM_BRANCHES = 1
58 | HRNET_32.STAGE1.NUM_BLOCKS = [4]
59 | HRNET_32.STAGE1.NUM_CHANNELS = [64]
60 | HRNET_32.STAGE1.BLOCK = 'BOTTLENECK'
61 | HRNET_32.STAGE1.FUSE_METHOD = 'SUM'
62 |
63 | HRNET_32.STAGE2 = CN()
64 | HRNET_32.STAGE2.NUM_MODULES = 1
65 | HRNET_32.STAGE2.NUM_BRANCHES = 2
66 | HRNET_32.STAGE2.NUM_BLOCKS = [4, 4]
67 | HRNET_32.STAGE2.NUM_CHANNELS = [32, 64]
68 | HRNET_32.STAGE2.BLOCK = 'BASIC'
69 | HRNET_32.STAGE2.FUSE_METHOD = 'SUM'
70 |
71 | HRNET_32.STAGE3 = CN()
72 | HRNET_32.STAGE3.NUM_MODULES = 4
73 | HRNET_32.STAGE3.NUM_BRANCHES = 3
74 | HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4]
75 | HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128]
76 | HRNET_32.STAGE3.BLOCK = 'BASIC'
77 | HRNET_32.STAGE3.FUSE_METHOD = 'SUM'
78 |
79 | HRNET_32.STAGE4 = CN()
80 | HRNET_32.STAGE4.NUM_MODULES = 3
81 | HRNET_32.STAGE4.NUM_BRANCHES = 4
82 | HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
83 | HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
84 | HRNET_32.STAGE4.BLOCK = 'BASIC'
85 | HRNET_32.STAGE4.FUSE_METHOD = 'SUM'
86 |
87 |
88 | # configs for HRNet18
89 | HRNET_18 = CN()
90 | HRNET_18.FINAL_CONV_KERNEL = 1
91 |
92 | HRNET_18.STAGE1 = CN()
93 | HRNET_18.STAGE1.NUM_MODULES = 1
94 | HRNET_18.STAGE1.NUM_BRANCHES = 1
95 | HRNET_18.STAGE1.NUM_BLOCKS = [4]
96 | HRNET_18.STAGE1.NUM_CHANNELS = [64]
97 | HRNET_18.STAGE1.BLOCK = 'BOTTLENECK'
98 | HRNET_18.STAGE1.FUSE_METHOD = 'SUM'
99 |
100 | HRNET_18.STAGE2 = CN()
101 | HRNET_18.STAGE2.NUM_MODULES = 1
102 | HRNET_18.STAGE2.NUM_BRANCHES = 2
103 | HRNET_18.STAGE2.NUM_BLOCKS = [4, 4]
104 | HRNET_18.STAGE2.NUM_CHANNELS = [18, 36]
105 | HRNET_18.STAGE2.BLOCK = 'BASIC'
106 | HRNET_18.STAGE2.FUSE_METHOD = 'SUM'
107 |
108 | HRNET_18.STAGE3 = CN()
109 | HRNET_18.STAGE3.NUM_MODULES = 4
110 | HRNET_18.STAGE3.NUM_BRANCHES = 3
111 | HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4]
112 | HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72]
113 | HRNET_18.STAGE3.BLOCK = 'BASIC'
114 | HRNET_18.STAGE3.FUSE_METHOD = 'SUM'
115 |
116 | HRNET_18.STAGE4 = CN()
117 | HRNET_18.STAGE4.NUM_MODULES = 3
118 | HRNET_18.STAGE4.NUM_BRANCHES = 4
119 | HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
120 | HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144]
121 | HRNET_18.STAGE4.BLOCK = 'BASIC'
122 | HRNET_18.STAGE4.FUSE_METHOD = 'SUM'
123 |
124 |
125 | MODEL_CONFIGS = {
126 | 'hrnet18': HRNET_18,
127 | 'hrnet32': HRNET_32,
128 | 'hrnet48': HRNET_48,
129 | }
--------------------------------------------------------------------------------
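The nested ```EasyDict``` configs above are read with attribute access; for example, the stage-4 branch widths give the channel count of the concatenated HRNet output (a minimal sketch, assuming the repo root is importable):

```python
from models.hrnet_config import MODEL_CONFIGS

cfg = MODEL_CONFIGS['hrnet48']
widths = cfg.STAGE4.NUM_CHANNELS   # [48, 96, 192, 384]
print(sum(widths))                 # 720 channels after upsampling and concatenating the 4 branches
```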
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import *
2 | from .transforms import *
3 | from .np_transforms import *
4 | from .utils import *
5 | from .torch_utils import *
6 | from .repeat_factor_sampling import RepeatFactorSampler
7 | from .lr_functions import *
8 | from .distributed import *
9 | from .checkpoint_utils import *
10 | from .logger import *
11 | from .config_parsers import *
12 | from .optimizer_utils import *
13 | from .tsne_visualization import *
14 |
--------------------------------------------------------------------------------
/utils/checkpoint_utils.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | def _check_model_param_prefix(state_dict, prefix: str):
5 |     # check whether the parameter names of a state dict start with a given prefix
6 |     found_prefix_model = False
7 |     for param_name in state_dict:
8 |         if not param_name.startswith(prefix):
9 |             if found_prefix_model:
10 |                 raise Warning('module prefix found in some of the model params but not others '
11 |                               '-- this will cause bugs!! -- check before proceeding')
12 |             found_prefix_model = False
13 |             break
14 |         else:
15 |             found_prefix_model = True
16 |     return found_prefix_model
17 |
18 |
19 | def check_module_prefix(chkpt_state_dict, model:nn.Module):
20 | found_prefix_model = _check_model_param_prefix(model.state_dict(), prefix='module.')
21 | found_prefix_chkpt = _check_model_param_prefix(chkpt_state_dict, prefix='module.')
22 |
23 | # remove prefix from chkpt_state_dict keys
24 | # if that prefix is not found in model variable names
25 |     if not found_prefix_model and found_prefix_chkpt:
26 |         for k in list(chkpt_state_dict.keys()):
27 |             # rename keys that carry the 'module.' prefix
28 | if k.startswith('module.'):
29 | # remove prefix
30 | chkpt_state_dict[k[len("module."):]] = chkpt_state_dict[k]
31 | # delete renamed or unused k
32 | del chkpt_state_dict[k]
33 |
34 | return chkpt_state_dict
35 |
--------------------------------------------------------------------------------
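```check_module_prefix``` reconciles checkpoints saved from DataParallel/DistributedDataParallel models (keys prefixed with ```module.```) with a plain model, as in the evaluation snippet shown earlier in this dump. A hedged sketch with a placeholder checkpoint path:

```python
import torch
from models import HRNet
from utils import check_module_prefix

# 'path/to/chkpt.pt' is a placeholder for a checkpoint produced by a DDP training run
checkpoint = torch.load('path/to/chkpt.pt', map_location='cpu')
config = {'backbone': 'hrnet48', 'out_stride': 4, 'pretrained': True, 'dataset': 'PASCALC'}
model = HRNet(config, 1)
state_dict = check_module_prefix(checkpoint['model_state_dict'], model)  # strips 'module.' if needed
print(model.load_state_dict(state_dict, strict=False))
```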
/utils/datasets_info/CADIS.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
4 | class EasyDict(dict):
5 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
6 |
7 | def __getattr__(self, name: str) -> Any:
8 | try:
9 | return self[name]
10 | except KeyError:
11 | raise AttributeError(name)
12 |
13 | def __setattr__(self, name: str, value: Any) -> None:
14 | self[name] = value
15 |
16 | def __delattr__(self, name: str) -> None:
17 | del self[name]
18 |
19 |
20 | DATA_SPLITS = [ # Pre-defined splits of the videos, to be used generally
21 | [[1], [5]], # Split 0: debugging
22 | [[1, 3, 4, 6, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 21, 23, 24, 25], [5, 7, 16, 2, 12, 22]], # Split 1
23 | [list(range(1, 26)), [5, 7, 16, 2, 12, 22]], # Split 2 (all data)
24 | [[1, 8, 9, 10, 14, 15, 21, 23, 24], [5, 7, 16, 2, 12, 22]], # Split 3: "50% of data" (1729 frames, 49.3%)
25 | [[10, 14, 21, 24], [5, 7, 16, 2, 12, 22]], # Split 4: "25% of data" (834 frames, 23.8%)
26 | [[1, 3, 4, 6, 8, 9, 10, 11, 13, 14, 15, 17, 18, 19, 20, 21, 23, 24, 25], [5, 7, 16], [2, 12, 22]] # train-val-test
27 | ]
28 |
29 | categories_exp0 = {
30 | 'anatomies': [],
31 | 'instruments': [],
32 | 'others': []
33 | }
34 | categories_exp1 = {
35 | 'anatomies': [0, 4, 5, 6],
36 | 'instruments': [7],
37 | 'others': [1, 2, 3],
38 | 'rare': [2]
39 | }
40 | categories_exp2 = {
41 | 'anatomies': [0, 4, 5, 6],
42 | 'instruments': [7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
43 | 'others': [1, 2, 3],
44 | 'rare': [16, 10, 9, 12, 14] # picked with freq_thresh 0.2 and s.t rf > 1.5
45 | }
46 | categories_exp3 = {
47 | 'anatomies': [0, 4, 5, 6],
48 | 'instruments': [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
49 | 'others': [1, 2, 3],
50 | 'rare': [24, 20, 21, 22, 18, 23, 19, 16, 12, 11, 14] # picked with freq_thresh 0.2 and s.t rf > 1.5
51 | }
52 |
53 | class_remapping_exp0 = {
54 | 0: [0],
55 | 1: [1],
56 | 2: [2],
57 | 3: [3],
58 | 4: [4],
59 | 5: [5],
60 | 6: [6],
61 | 7: [7],
62 | 8: [8],
63 | 9: [9],
64 | 10: [10],
65 | 11: [11],
66 | 12: [12],
67 | 13: [13],
68 | 14: [14],
69 | 15: [15],
70 | 16: [16],
71 | 17: [17],
72 | 18: [18],
73 | 19: [19],
74 | 20: [20],
75 | 21: [21],
76 | 22: [22],
77 | 23: [23],
78 | 24: [24],
79 | 25: [25],
80 | 26: [26],
81 | 27: [27],
82 | 28: [28],
83 | 29: [29],
84 | 30: [30],
85 | 31: [31],
86 | 32: [32],
87 | 33: [33],
88 | 34: [34],
89 | 35: [35]
90 | }
91 | classes_exp0 = {
92 | 0: 'Pupil',
93 | 1: 'Surgical Tape',
94 | 2: 'Hand',
95 | 3: 'Eye Retractors',
96 | 4: 'Iris',
97 | 5: 'Skin',
98 | 6: 'Cornea',
99 | 7: 'Hydrodissection Cannula',
100 | 8: 'Viscoelastic Cannula',
101 | 9: 'Capsulorhexis Cystotome',
102 | 10: 'Rycroft Cannula',
103 | 11: 'Bonn Forceps',
104 | 12: 'Primary Knife',
105 | 13: 'Phacoemulsifier Handpiece',
106 | 14: 'Lens Injector',
107 | 15: 'I/A Handpiece',
108 | 16: 'Secondary Knife',
109 | 17: 'Micromanipulator',
110 | 18: 'I/A Handpiece Handle',
111 | 19: 'Capsulorhexis Forceps',
112 | 20: 'Rycroft Cannula Handle',
113 | 21: 'Phacoemulsifier Handpiece Handle',
114 | 22: 'Capsulorhexis Cystotome Handle',
115 | 23: 'Secondary Knife Handle',
116 | 24: 'Lens Injector Handle',
117 | 25: 'Suture Needle',
118 | 26: 'Needle Holder',
119 | 27: 'Charleux Cannula',
120 | 28: 'Primary Knife Handle',
121 | 29: 'Vitrectomy Handpiece',
122 | 30: 'Mendez Ring',
123 | 31: 'Marker',
124 | 32: 'Hydrodissection Cannula Handle',
125 | 33: 'Troutman Forceps',
126 | 34: 'Cotton',
127 | 35: 'Iris Hooks'
128 | }
129 |
130 | class_remapping_exp1 = {
131 | 0: [0],
132 | 1: [1],
133 | 2: [2],
134 | 3: [3],
135 | 4: [4],
136 | 5: [5],
137 | 6: [6],
138 | 7: [7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
139 | 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
140 | }
141 | classes_exp1 = {
142 | 0: "Pupil",
143 | 1: "Surgical Tape",
144 | 2: "Hand",
145 | 3: "Eye Retractors",
146 | 4: "Iris",
147 | 5: "Skin",
148 | 6: "Cornea",
149 | 7: "Instrument",
150 | }
151 |
152 | class_remapping_exp2 = {
153 | 0: [0],
154 | 1: [1],
155 | 2: [2],
156 | 3: [3],
157 | 4: [4],
158 | 5: [5],
159 | 6: [6],
160 | 7: [7, 8, 10, 27, 20, 32],
161 | 8: [9, 22],
162 | 9: [11, 33],
163 | 10: [12, 28],
164 | 11: [13, 21],
165 | 12: [14, 24],
166 | 13: [15, 18],
167 | 14: [16, 23],
168 | 15: [17],
169 | 16: [19],
170 | 255: [25, 26, 29, 30, 31, 34, 35],
171 | }
172 | classes_exp2 = {
173 | 0: "Pupil",
174 | 1: "Surgical Tape",
175 | 2: "Hand",
176 | 3: "Eye Retractors",
177 | 4: "Iris",
178 | 5: "Skin",
179 | 6: "Cornea",
180 | 7: "Cannula",
181 | 8: "Cap. Cystotome",
182 | 9: "Tissue Forceps",
183 | 10: "Primary Knife",
184 | 11: "Ph. Handpiece",
185 | 12: "Lens Injector",
186 | 13: "I/A Handpiece",
187 | 14: "Secondary Knife",
188 | 15: "Micromanipulator",
189 | 16: "Cap. Forceps",
190 | 255: "Ignore",
191 | }
192 |
193 | class_remapping_exp3 = {
194 | 0: [0],
195 | 1: [1],
196 | 2: [2],
197 | 3: [3],
198 | 4: [4],
199 | 5: [5],
200 | 6: [6],
201 | 7: [7],
202 | 8: [8],
203 | 9: [9],
204 | 10: [10],
205 | 11: [11],
206 | 12: [12],
207 | 13: [13],
208 | 14: [14],
209 | 15: [15],
210 | 16: [16],
211 | 17: [17],
212 | 18: [18],
213 | 19: [19],
214 | 20: [20],
215 | 21: [21],
216 | 22: [22],
217 | 23: [23],
218 | 24: [24],
219 | 255: [25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
220 | }
221 | classes_exp3 = {
222 | 0: "Pupil",
223 | 1: "Surgical Tape",
224 | 2: "Hand",
225 | 3: "Eye Retractors",
226 | 4: "Iris",
227 | 5: "Skin",
228 | 6: "Cornea",
229 | 7: "Hydro. Cannula",
230 | 8: "Visc. Cannula",
231 | 9: "Cap. Cystotome",
232 | 10: "Rycroft Cannula",
233 | 11: "Bonn Forceps",
234 | 12: "Primary Knife",
235 | 13: "Ph. Handpiece",
236 | 14: "Lens Injector",
237 | 15: "I/A Handpiece",
238 | 16: "Secondary Knife",
239 | 17: "Micromanipulator",
240 | 18: "I/A Handpiece Handle",
241 | 19: "Cap. Forceps",
242 | 20: "R. Cannula Handle",
243 | 21: "Ph. Handpiece Handle",
244 | 22: "Cap. Cystotome Handle",
245 | 23: "Sec. Knife Handle",
246 | 24: "Lens Injector Handle",
247 | 255: "Ignore",
248 | }
249 |
250 | CLASS_INFO = [
251 | [class_remapping_exp0, classes_exp0, categories_exp0], # Original classes
252 | [class_remapping_exp1, classes_exp1, categories_exp1],
253 | [class_remapping_exp2, classes_exp2, categories_exp2],
254 | [class_remapping_exp3, classes_exp3, categories_exp3]
255 | ]
256 |
257 | CLASS_NAMES = [[CLASS_INFO[0][1][key] for key in sorted(CLASS_INFO[0][1].keys())],
258 | [CLASS_INFO[1][1][key] for key in sorted(CLASS_INFO[1][1].keys())],
259 | [CLASS_INFO[2][1][key] for key in sorted(CLASS_INFO[2][1].keys())],
260 | [CLASS_INFO[3][1][key] for key in sorted(CLASS_INFO[3][1].keys())]]
261 |
262 | OVERSAMPLING_PRESETS = {
263 | 'default': [
264 | [3, 5, 7], # Experiment 1
265 | [7, 8, 15, 16], # Experiment 2
266 | [19, 20, 22, 24] # Experiment 3
267 | ],
268 | 'rare': [ # Same classes as 'rare' category for mIoU metric
269 | [2], # Experiment 1
270 | [16, 10, 9, 12, 14], # Experiment 2
271 | [24, 20, 21, 22, 18, 23, 19, 16, 12, 11, 14] # Experiment 3
272 | ]
273 | }
274 |
275 | CLASS_FREQUENCIES = [
276 | 1.68024535e-01,
277 | 5.93061223e-02,
278 | 7.38987570e-03,
279 | 5.72173439e-03,
280 | 1.12288211e-01,
281 | 1.33608027e-01,
282 | 4.89257831e-01,
283 | 1.26300163e-03,
284 | 8.96526043e-04,
285 | 9.28408858e-04,
286 | 6.47719387e-04,
287 | 2.61340734e-03,
288 | 1.40455685e-03,
289 | 1.84766048e-03,
290 | 3.25327478e-03,
291 | 3.60986861e-03,
292 | 1.06050077e-03,
293 | 1.97264561e-03,
294 | 5.32642854e-04,
295 | 7.07037962e-04,
296 | 3.66272768e-04,
297 | 4.75095501e-04,
298 | 1.73250919e-04,
299 | 5.49602466e-04,
300 | 2.91966965e-04,
301 | 1.06066764e-05,
302 | 1.54437472e-04,
303 | 4.16546878e-05,
304 | 2.96828324e-06,
305 | 1.02785378e-04,
306 | 4.38665256e-04,
307 | 4.91079867e-04,
308 | 1.13576281e-05,
309 | 1.83788200e-04,
310 | 1.37330396e-04,
311 | 2.35550169e-04
312 | ]
313 | CLASS_SUMS = [
314 | 406775301,
315 | 143575852,
316 | 17890357,
317 | 13851907,
318 | 271841675,
319 | 323455413,
320 | 1184457982,
321 | 3057636,
322 | 2170425,
323 | 2247611,
324 | 1568082,
325 | 6326871,
326 | 3400331,
327 | 4473053,
328 | 7875944,
329 | 8739232,
330 | 2567396,
331 | 4775633,
332 | 1289490,
333 | 1711688,
334 | 886720,
335 | 1150172,
336 | 419428,
337 | 1330548,
338 | 706831,
339 | 25678,
340 | 373882,
341 | 100843,
342 | 7186,
343 | 248836,
344 | 1061977,
345 | 1188869,
346 | 27496,
347 | 444938,
348 | 332467,
349 | 570250
350 | ]
351 |
352 | CADIS_INFO = EasyDict(CLASS_INFO=CLASS_INFO,
353 | CLASS_NAMES=CLASS_NAMES,
354 | DATA_SPLITS=DATA_SPLITS,
355 | OVERSAMPLING_PRESETS=OVERSAMPLING_PRESETS,
356 | CLASS_FREQUENCIES=CLASS_FREQUENCIES,
357 | CLASS_SUMS=CLASS_SUMS)
358 |
--------------------------------------------------------------------------------
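The per-class statistics above are mutually consistent: ```CLASS_FREQUENCIES``` is, up to rounding, ```CLASS_SUMS``` normalised by the total pixel count. A short sanity check, assuming numpy and an importable repo root:

```python
import numpy as np
from utils.datasets_info.CADIS import CADIS_INFO

sums = np.asarray(CADIS_INFO.CLASS_SUMS, dtype=np.float64)
freqs = sums / sums.sum()                                               # recompute per-class frequencies
print(np.abs(freqs - np.asarray(CADIS_INFO.CLASS_FREQUENCIES)).max())   # ~0, up to rounding of the stored values
print(CADIS_INFO.CLASS_NAMES[2])                                        # 17 merged classes of experiment 2 + 'Ignore'
```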
/utils/datasets_info/CITYSCAPES.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
4 | class EasyDict(dict):
5 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
6 |
7 | def __getattr__(self, name: str) -> Any:
8 | try:
9 | return self[name]
10 | except KeyError:
11 | raise AttributeError(name)
12 |
13 | def __setattr__(self, name: str, value: Any) -> None:
14 | self[name] = value
15 |
16 | def __delattr__(self, name: str) -> None:
17 | del self[name]
18 |
19 |
20 | categories_exp0 = {
21 | 'void': [0, 1, 2, 3, 4, 5, 6],
22 | 'flat': [7, 8, 9, 10],
23 | 'construction': [11, 12, 13, 14, 15, 16],
24 | 'object': [17, 18, 19, 20],
25 | 'nature': [21, 22],
26 | 'sky': [23],
27 | 'human': [24, 25],
28 | 'vehicle': [26, 27, 28, 29, 30, 31, 32, 33]
29 | }
30 |
31 | categories_exp1 = {
32 | 'flat': [0, 1],
33 | 'construction': [2, 3, 4],
34 | 'object': [5, 6, 7],
35 | 'nature': [8, 9],
36 | 'sky': [10],
37 | 'human': [11, 12],
38 | 'vehicle': [13, 14, 15, 16, 17, 18]
39 | }
40 |
41 | class_remapping_exp0 = {
42 | 0: [0],
43 | 1: [1],
44 | 2: [2],
45 | 3: [3],
46 | 4: [4],
47 | 5: [5],
48 | 6: [6],
49 | 7: [7],
50 | 8: [8],
51 | 9: [9],
52 | 10: [10],
53 | 11: [11],
54 | 12: [12],
55 | 13: [13],
56 | 14: [14],
57 | 15: [15],
58 | 16: [16],
59 | 17: [17],
60 | 18: [18],
61 | 19: [19],
62 | 20: [20],
63 | 21: [21],
64 | 22: [22],
65 | 23: [23],
66 | 24: [24],
67 | 25: [25],
68 | 26: [26],
69 | 27: [27],
70 | 28: [28],
71 | 29: [29],
72 | 30: [30],
73 | 31: [31],
74 | 32: [32],
75 | 33: [33],
76 | -1: [-1]
77 | }
78 | classes_exp0 = {
79 | 0: 'unlabeled',
80 | 1: 'ego vehicle',
81 | 2: 'rectification border',
82 | 3: 'out of roi',
83 | 4: 'static',
84 | 5: 'dynamic',
85 | 6: 'ground',
86 | 7: 'road',
87 | 8: 'sidewalk',
88 | 9: 'parking',
89 | 10: 'rail track',
90 | 11: 'building',
91 | 12: 'wall',
92 | 13: 'fence',
93 | 14: 'guard rail',
94 | 15: 'bridge',
95 | 16: 'tunnel',
96 | 17: 'pole',
97 | 18: 'polegroup',
98 | 19: 'traffic light',
99 | 20: 'traffic sign',
100 | 21: 'vegetation',
101 | 22: 'terrain',
102 | 23: 'sky',
103 | 24: 'person',
104 | 25: 'rider',
105 | 26: 'car',
106 | 27: 'truck',
107 | 28: 'bus',
108 | 29: 'caravan',
109 | 30: 'trailer',
110 | 31: 'train',
111 | 32: 'motorcycle',
112 | 33: 'bicycle',
113 | # note: 'Cotton' and 'Iris Hooks' (ids 34/35) are CaDIS classes, not Cityscapes ones,
114 | # and have no entry in class_remapping_exp0 or categories_exp0
115 | -1: 'license plate'
116 | }
117 |
118 | class_remapping_exp1 = {
119 | 0: [7],
120 | 1: [8],
121 | 2: [11],
122 | 3: [12],
123 | 4: [13],
124 | 5: [17],
125 | 6: [19],
126 | 7: [20],
127 | 8: [21],
128 | 9: [22],
129 | 10: [23],
130 | 11: [24],
131 | 12: [25],
132 | 13: [26],
133 | 14: [27],
134 | 15: [28],
135 | 16: [31],
136 | 17: [32],
137 | 18: [33],
138 | 255: [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
139 | }
140 |
141 |
142 | classes_exp1 = {
143 | 0: 'road',
144 | 1: 'sidewalk',
145 | 2: 'building',
146 | 3: 'wall',
147 | 4: 'fence',
148 | 5: 'pole',
149 | 6: 'traffic light',
150 | 7: 'traffic sign',
151 | 8: 'vegetation',
152 | 9: 'terrain',
153 | 10: 'sky',
154 | 11: 'person',
155 | 12: 'rider',
156 | 13: 'car',
157 | 14: 'truck',
158 | 15: 'bus',
159 | 16: 'train',
160 | 17: 'motorcycle',
161 | 18: 'bicycle',
162 | 255: 'Ignore'
163 | }
164 |
165 |
166 | CLASS_INFO = [
167 | [class_remapping_exp0, classes_exp0, categories_exp0], # Original classes
168 | [class_remapping_exp1, classes_exp1, categories_exp1]
169 | ]
170 |
171 | CLASS_NAMES = [[CLASS_INFO[0][1][key] for key in sorted(CLASS_INFO[0][1].keys())],
172 | [CLASS_INFO[1][1][key] for key in sorted(CLASS_INFO[1][1].keys())]]
173 |
174 | CITYSCAPES_INFO = EasyDict(CLASS_INFO=CLASS_INFO, CLASS_NAMES=CLASS_NAMES)
175 |
176 | if __name__ == '__main__':
177 | # all info is in a class attribute of the Cityscapes class
178 | from torchvision.datasets.cityscapes import Cityscapes
179 | CTS_info = Cityscapes.classes
180 | categories_exp1 = {}
181 | colormap = {}
182 |     ignored_colormap = {}
183 | class_remap_exp1 = {}
184 | categ_exp0 = {}
185 | categ_exp1 = {}
186 | for cl in CTS_info:
187 | ############################################
188 | classes_exp0[cl.id] = cl.name
189 | colormap[cl.id] = cl.color
190 |         # ignored_colormap[cl.train_id] = cl.color
191 | ############################################
192 | if cl.train_id in class_remap_exp1:
193 | class_remap_exp1[cl.train_id] += [cl.id]
194 | else:
195 | # -1 mapped to 255 which is used as the ignored class
196 | class_remap_exp1[cl.train_id] = [cl.id]
197 | classes_exp1[cl.train_id] = cl.name
198 |
199 | if cl.category not in categ_exp0:
200 | categ_exp0[cl.category] = [cl.id]
201 | else:
202 | categ_exp0[cl.category] += [cl.id]
203 |
204 | if cl.category not in categ_exp1:
205 | categ_exp1[cl.category] = [cl.train_id]
206 | else:
207 | categ_exp1[cl.category] += [cl.train_id]
208 |
209 | class_remap_exp1.pop(-1) # remove -1 from dictionary
210 | class_remap_exp1[255] += [-1] # and place it in the ignore class
211 |
212 | classes_exp1[255] = 'Ignore'
213 | classes_exp1.pop(-1) # remove -1 from dictionary
214 |
215 | a = 1
216 |
--------------------------------------------------------------------------------
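One common way to apply ```class_remapping_exp1``` to a label image is a 256-entry lookup table; this is only an illustration, not necessarily how the repository's dataset classes implement the remapping:

```python
import numpy as np
from utils.datasets_info.CITYSCAPES import CLASS_INFO

remap = CLASS_INFO[1][0]                     # experiment 1: original ids -> 19 train ids + 255
lut = np.full(256, 255, dtype=np.uint8)      # default everything to the ignore label
for train_id, original_ids in remap.items():
    for oid in original_ids:
        if oid >= 0:                         # skip the -1 'license plate' id
            lut[oid] = train_id

label = np.array([[7, 26, 0], [33, 24, 1]], dtype=np.uint8)  # original Cityscapes ids
print(lut[label])   # [[0 13 255], [18 11 255]] -> road, car, ignore / bicycle, person, ignore
```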
/utils/datasets_info/PASCALC.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
4 | class EasyDict(dict):
5 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
6 |
7 | def __getattr__(self, name: str) -> Any:
8 | try:
9 | return self[name]
10 | except KeyError:
11 | raise AttributeError(name)
12 |
13 | def __setattr__(self, name: str, value: Any) -> None:
14 | self[name] = value
15 |
16 | def __delattr__(self, name: str) -> None:
17 | del self[name]
18 |
19 |
20 | categories_exp0 = {
21 | 'flat': [1, 2],
22 | }
23 |
24 | categories_exp1 = {
25 | 'flat': [1, 2],
26 | }
27 |
28 | class_remapping_exp0 = {
29 | 0:[255],
30 | 1: [1],
31 | 2: [2],
32 | 3: [3],
33 | 4: [4],
34 | 5: [5],
35 | 6: [6],
36 | 7: [7],
37 | 8: [8],
38 | 9: [9],
39 | 10: [10],
40 | 11: [11],
41 | 12: [12],
42 | 13: [13],
43 | 14: [14],
44 | 15: [15],
45 | 16: [16],
46 | 17: [17],
47 | 18: [18],
48 | 19: [19],
49 | 20: [20],
50 | 21: [21],
51 | 22: [22],
52 | 23: [23],
53 | 24: [24],
54 | 25: [25],
55 | 26: [26],
56 | 27: [27],
57 | 28: [28],
58 | 29: [29],
59 | 30: [30],
60 | 31: [31],
61 | 32: [32],
62 | 33: [33],
63 | 34: [34],
64 | 35: [35],
65 | 36: [36],
66 | 37: [37],
67 | 38: [38],
68 | 39: [39],
69 | 40: [40],
70 | 41: [41],
71 | 42: [42],
72 | 43: [43],
73 | 44: [44],
74 | 45: [45],
75 | 46: [46],
76 | 47: [47],
77 | 48: [48],
78 | 49: [49],
79 | 50: [50],
80 | 51: [51],
81 | 52: [52],
82 | 53: [53],
83 | 54: [54],
84 | 55: [55],
85 | 56: [56],
86 | 57: [57],
87 | 58: [58],
88 | 59: [59]
89 | }
90 | #
91 | # CLASSES = ('background', 'aeroplane', 'bag', 'bed', 'bedclothes', 'bench',
92 | # 'bicycle', 'bird', 'boat', 'book', 'bottle', 'building', 'bus',
93 | # 'cabinet', 'car', 'cat', 'ceiling', 'chair', 'cloth',
94 | # 'computer', 'cow', 'cup', 'curtain', 'dog', 'door', 'fence',
95 | # 'floor', 'flower', 'food', 'grass', 'ground', 'horse',
96 | # 'keyboard', 'light', 'motorbike', 'mountain', 'mouse', 'person',
97 | # 'plate', 'platform', 'pottedplant', 'road', 'rock', 'sheep',
98 | # 'shelves', 'sidewalk', 'sign', 'sky', 'snow', 'sofa', 'table',
99 | # 'track', 'train', 'tree', 'truck', 'tvmonitor', 'wall', 'water',
100 | # 'window', 'wood')
101 | #
102 | # PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
103 | # [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
104 | # [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
105 | # [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
106 | # [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
107 | # [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
108 | # [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
109 | # [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
110 | # [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
111 | # [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
112 | # [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
113 | # [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
114 | # [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
115 | # [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
116 | # [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
117 |
118 | classes_exp0 = {
119 | 0: "background",
120 | 1: "aeroplane",
121 | 2: "bag",
122 | 3: "bed",
123 | 4: "bedclothes",
124 | 5: "bench",
125 | 6: "bicycle",
126 | 7: "bird",
127 | 8: "boat",
128 | 9: "book",
129 | 10: "bottle",
130 | 11: "building",
131 | 12: "bus",
132 | 13: "cabinet",
133 | 14: "car",
134 | 15: "cat",
135 | 16: "ceiling",
136 | 17: "chair",
137 | 18: "cloth",
138 | 19: "computer",
139 | 20: "cow",
140 | 21: "cup",
141 | 22: "curtain",
142 | 23: "dog",
143 | 24: "door",
144 | 25: "fence",
145 | 26: "floor",
146 | 27: "flower",
147 | 28: "food",
148 | 29: "grass",
149 | 30: "ground",
150 | 31: "horse",
151 | 32: "keyboard",
152 | 33: "light",
153 | 34: "motorbike",
154 | 35: "mountain",
155 | 36: "mouse",
156 | 37: "person",
157 | 38: "plate",
158 | 39: "platform",
159 | 40: "pottedplant",
160 | 41: "road",
161 | 42: "rock",
162 | 43: "sheep",
163 | 44: "shelves",
164 | 45: "sidewalk",
165 | 46: "sign",
166 | 47: "sky",
167 | 48: "snow",
168 | 49: "sofa",
169 | 50: "table",
170 | 51: "track",
171 | 52: "train",
172 | 53: "tree",
173 | 54: "truck",
174 | 55: "tvmonitor",
175 | 56: "wall",
176 | 57: "water",
177 | 58: "window",
178 | 59: "wood"
179 | }
180 |
181 |
182 | class_remapping_exp1 = {
183 | 255: [0],
184 | 0: [1],
185 | 1: [2],
186 | 2: [3],
187 | 3: [4],
188 | 4: [5],
189 | 5: [6],
190 | 6: [7],
191 | 7: [8],
192 | 8: [9],
193 | 9: [10],
194 | 10: [11],
195 | 11: [12],
196 | 12: [13],
197 | 13: [14],
198 | 14: [15],
199 | 15: [16],
200 | 16: [17],
201 | 17: [18],
202 | 18: [19],
203 | 19: [20],
204 | 20: [21],
205 | 21: [22],
206 | 22: [23],
207 | 23: [24],
208 | 24: [25],
209 | 25: [26],
210 | 26: [27],
211 | 27: [28],
212 | 28: [29],
213 | 29: [30],
214 | 30: [31],
215 | 31: [32],
216 | 32: [33],
217 | 33: [34],
218 | 34: [35],
219 | 35: [36],
220 | 36: [37],
221 | 37: [38],
222 | 38: [39],
223 | 39: [40],
224 | 40: [41],
225 | 41: [42],
226 | 42: [43],
227 | 43: [44],
228 | 44: [45],
229 | 45: [46],
230 | 46: [47],
231 | 47: [48],
232 | 48: [49],
233 | 49: [50],
234 | 50: [51],
235 | 51: [52],
236 | 52: [53],
237 | 53: [54],
238 | 54: [55],
239 | 55: [56],
240 | 56: [57],
241 | 57: [58],
242 | 58: [59]
243 | }
244 |
245 |
246 | classes_exp1 = {
247 | 255: "background",
248 | 0: "aeroplane",
249 | 1: "bag",
250 | 2: "bed",
251 | 3: "bedclothes",
252 | 4: "bench",
253 | 5: "bicycle",
254 | 6: "bird",
255 | 7: "boat",
256 | 8: "book",
257 | 9: "bottle",
258 | 10: "building",
259 | 11: "bus",
260 | 12: "cabinet",
261 | 13: "car",
262 | 14: "cat",
263 | 15: "ceiling",
264 | 16: "chair",
265 | 17: "cloth",
266 | 18: "computer",
267 | 19: "cow",
268 | 20: "cup",
269 | 21: "curtain",
270 | 22: "dog",
271 | 23: "door",
272 | 24: "fence",
273 | 25: "floor",
274 | 26: "flower",
275 | 27: "food",
276 | 28: "grass",
277 | 29: "ground",
278 | 30: "horse",
279 | 31: "keyboard",
280 | 32: "light",
281 | 33: "motorbike",
282 | 34: "mountain",
283 | 35: "mouse",
284 | 36: "person",
285 | 37: "plate",
286 | 38: "platform",
287 | 39: "pottedplant",
288 | 40: "road",
289 | 41: "rock",
290 | 42: "sheep",
291 | 43: "shelves",
292 | 44: "sidewalk",
293 | 45: "sign",
294 | 46: "sky",
295 | 47: "snow",
296 | 48: "sofa",
297 | 49: "table",
298 | 50: "track",
299 | 51: "train",
300 | 52: "tree",
301 | 53: "truck",
302 | 54: "tvmonitor",
303 | 55: "wall",
304 | 56: "water",
305 | 57: "window",
306 | 58: "wood"
307 | }
308 |
309 |
310 | CLASS_INFO = [
311 | [class_remapping_exp0, classes_exp0, categories_exp0], # Original classes
312 | [class_remapping_exp1, classes_exp1, categories_exp1]
313 | ]
314 |
315 | CLASS_NAMES = [[CLASS_INFO[0][1][key] for key in sorted(CLASS_INFO[0][1].keys())],
316 | [CLASS_INFO[1][1][key] for key in sorted(CLASS_INFO[1][1].keys())]]
317 |
318 | PASCALC_INFO = EasyDict(CLASS_INFO=CLASS_INFO, CLASS_NAMES=CLASS_NAMES)
319 |
320 |
321 | def label_sanity_check(root=None):
322 | import cv2
323 | import warnings
324 | import pathlib
325 | import numpy as np
326 | warning = 0
327 | warning_msg = []
328 |     if root is None:
329 | root = pathlib.Path(r"C:\Users\Theodoros Pissas\Documents\tresorit\PASCALC\val\label/")
330 | for path_to_label in root.glob('**/*.PNG'):
331 | i = cv2.imread(str(path_to_label))
332 | labels_present = np.unique(i)
333 | print(f'{path_to_label.stem} : {labels_present}')
334 | if max(labels_present) > 59:
335 | warnings.warn(f'invalid label found {labels_present}')
336 | warning += 1
337 | warning_msg.append(f'invalid label found {labels_present}')
338 | return warning_msg, warning
339 |
340 | def class_dict_from_txt():
341 | d = dict()
342 | content = open('pascal.txt').read()
343 | print('{')
344 | for i in content.split('\n'):
345 | key = i.split(':')[0]
346 | val = i.split(':')[-1]
347 | # print(key, val)
348 | d[int(key)] = val
349 | val = val.replace(" ", "")
350 | print(f'{key}:"{val}",')
351 | print('}')
352 | if __name__ == '__main__':
353 | # label_sanity_check()
354 | # class_dict_from_txt()
355 |
356 | # for i in classes_exp0:
357 | # # for remapping
358 | # # print(f'{i-1}:{[i]},')
359 | from utils import get_pascalc_colormap
360 | # A = PALETTE
361 | # print(f'{i - 1}:"{classes_exp0[i]}",')
362 |     for c in classes_exp0:
363 |         print(f'{c-1}:"{classes_exp0[c]}",')
--------------------------------------------------------------------------------
/utils/datasets_info/__init__.py:
--------------------------------------------------------------------------------
1 | from .CITYSCAPES import *
2 | from .CADIS import *
3 | from .PASCALC import *
4 | from .ADE20K import *
5 |
--------------------------------------------------------------------------------
/utils/defaults.py:
--------------------------------------------------------------------------------
1 | from .datasets_info import CITYSCAPES_INFO, CADIS_INFO, PASCALC_INFO, ADE20K_INFO
2 | import numpy as np
3 | from typing import Any
4 |
5 |
6 | class EasyDict(dict):
7 | """Convenience class that behaves like a dict but allows access with the attribute syntax."""
8 |
9 | def __getattr__(self, name: str) -> Any:
10 | try:
11 | return self[name]
12 | except KeyError:
13 | raise AttributeError(name)
14 |
15 | def __setattr__(self, name: str, value: Any) -> None:
16 | self[name] = value
17 |
18 | def __delattr__(self, name: str) -> None:
19 | del self[name]
20 |
21 |
22 | DATASETS_INFO = EasyDict(CADIS=CADIS_INFO, CITYSCAPES=CITYSCAPES_INFO, PASCALC=PASCALC_INFO, ADE20K=ADE20K_INFO)
23 |
24 |
25 | def get_cityscapes_colormap():
26 | """
27 | Returns cityscapes colormap as in paper
28 | :return: ndarray of rgb colors
29 | """
30 | return np.asarray(
31 | [(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), (128, 64, 128),
32 | (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
33 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30),
34 | (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142),
35 | (0, 0, 70), (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)]
36 | )
37 |
38 |
39 | def get_cadis_colormap():
40 | """
41 | Returns cadis colormap as in paper
42 | :return: ndarray of rgb colors
43 | """
44 | return np.asarray(
45 | [
46 | [0, 137, 255],
47 | [255, 165, 0],
48 | [255, 156, 201],
49 | [99, 0, 255],
50 | [255, 0, 0],
51 | [255, 0, 165],
52 | [255, 255, 255],
53 | [141, 141, 141],
54 | [255, 218, 0],
55 | [173, 156, 255],
56 | [73, 73, 73],
57 | [250, 213, 255],
58 | [255, 156, 156],
59 | [99, 255, 0],
60 | [157, 225, 255],
61 | [255, 89, 124],
62 | [173, 255, 156],
63 | [255, 60, 0],
64 | [40, 0, 255],
65 | [170, 124, 0],
66 | [188, 255, 0],
67 | [0, 207, 255],
68 | [0, 255, 207],
69 | [188, 0, 255],
70 | [243, 0, 255],
71 | [0, 203, 108],
72 | [252, 255, 0],
73 | [93, 182, 177],
74 | [0, 81, 203],
75 | [211, 183, 120],
76 | [231, 203, 0],
77 | [0, 124, 255],
78 | [10, 91, 44],
79 | [2, 0, 60],
80 | [0, 144, 2],
81 | [133, 59, 59],
82 | ]
83 | )
84 |
85 |
86 | def get_pascalc_colormap():
87 | cmap = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
88 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
89 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
90 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
91 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
92 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
93 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
94 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
95 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
96 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
97 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
98 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
99 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
100 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
101 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]]
102 | return cmap
103 |
104 |
105 | def get_ade20k_colormap():
106 | # 151 VALUES , CMAP[0] IS IGNORED
107 | cmap = [[0,0,0], [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
108 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255],
109 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7],
110 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82],
111 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3],
112 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
113 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
114 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224],
115 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255],
116 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7],
117 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
118 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255],
119 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0],
120 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255],
121 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255],
122 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
123 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0],
124 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0],
125 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255],
126 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255],
127 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
128 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255],
129 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255],
130 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255],
131 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0],
132 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
133 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255],
134 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112],
135 | [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160],
136 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163],
137 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
138 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0],
139 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255],
140 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204],
141 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255],
142 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
143 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194],
144 | [102, 255, 0], [92, 0, 255]]
145 | return cmap
146 |
147 |
148 | def get_iacl_colormap():
149 | cmap = [[0, 0, 127],
150 | [0, 0, 254],
151 | [0, 96, 256],
152 | [0, 212, 255],
153 | [76, 255, 170],
154 | [170, 255, 76],
155 | [255, 229, 0],
156 | [255, 122, 0],
157 | [254, 18, 0]]
158 | return cmap
159 |
160 | def get_retouch_colormap():
161 | cmap = [[0, 0, 0],
162 | [0, 0, 254],
163 | [0, 96, 256],
164 | [0, 212, 255],
165 | [76, 255, 170],
166 | [170, 255, 76],
167 | [255, 229, 0],
168 | [255, 122, 0],
169 | [254, 18, 0]]
170 | return cmap
171 |
172 |
173 |
174 | DEFAULT_VALUES = {
175 | 'sliding_miou_kernel': 7, # Make sure this is odd!
176 | 'sliding_miou_stride': 4,
177 | }
178 |
179 | DEFAULT_CONFIG_DICT = {
180 | 'mode': 'training',
181 | 'debugging': False,
182 | 'log_every_n_epochs': 100,
183 | 'max_valid_imgs': 10,
184 | 'cuda': True,
185 | 'gpu_device': 0,
186 | 'parallel': False,
187 | 'parallel_gpu_devices': [],
188 | 'seed': 0,
189 | 'tta': False
190 | }
191 |
192 | DEFAULT_CONFIG_NESTED_DICT = {
193 | 'data': {
194 | 'transforms': ['pad'],
195 | 'transform_values': {
196 | 'crop_size': 0.5,
197 | 'crop_mode': 'random',
198 | 'crop_shape': [512, 1024]
199 | },
200 | 'split': 1,
201 | 'batch_size': 10,
202 | 'num_workers': 0,
203 | 'preload': False,
204 | 'blacklist': True,
205 | 'use_propagated': False,
206 | 'propagated_video_blacklist': False,
207 | 'propagated_quart_blacklist': False,
208 | 'use_relabeled': False,
209 | 'weighted_random': [0, 0],
210 | 'weighted_random_mode': 'v1',
211 | 'oversampling': [0, 0],
212 | 'oversampling_frac': 0.2,
213 | 'oversampling_preset': 'default',
214 | 'adaptive_batching': [0, 0],
215 | 'adaptive_sel_size': 10,
216 | 'adaptive_iou_update': 1,
217 | "repeat_factor": [0, 0],
218 | "repeat_factor_freq_thresh": 0.15,
219 | # loaders for two-step pseudo training
220 | # only loads labelled data with RF
221 | "lab_repeat_factor": [0, 0],
222 | # only loads unlabelled data
223 | "ulab_default": [0, 0],
224 | # loads lab and ulab mixed -- default choice for pseudo training
225 | "mixed_default": [0, 0],
226 | # loads lab with RF and ulab mixed
227 | "mixed_repeat_factor": [0, 0]
228 | },
229 | 'train': {
230 | 'epochs': 50,
231 | 'lr_fct': 'exponential',
232 | 'lr_batchwise': False,
233 | 'lr_restarts': [],
234 | 'lr_restart_vals': 1,
235 | 'lr_params': None,
236 | },
237 | 'loss': {
238 | 'temperature': 0.1,
239 | 'dominant_mode': 'all',
240 | 'label_scaling_mode': 'avg_pool',
241 | 'dc_weightings': {
242 | 'outer_freq': False,
243 | 'outer_entropy': False,
244 | 'outer_confusionmatrix': False,
245 | 'inner_crossentropy': False,
246 | 'inner_idealcrossentropy': False,
247 | 'neg_confusionmatrix': False,
248 | 'neg_negativity': False
249 | },
250 | }
251 | }
252 |
--------------------------------------------------------------------------------
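```DATASETS_INFO``` bundles the per-dataset metadata, and the colormap helpers return RGB lists/arrays indexed by class id, so a prediction mask can be colourised with fancy indexing. A small sketch with a made-up prediction in the original 34-class Cityscapes id space:

```python
import numpy as np
from utils import DATASETS_INFO, get_cityscapes_colormap

pred = np.random.randint(0, 34, size=(4, 4))       # fake prediction, ids 0..33
cmap = get_cityscapes_colormap()                   # (35, 3) ndarray of RGB colours
colour = cmap[pred]                                # (4, 4, 3) colourised mask
names = DATASETS_INFO.CITYSCAPES.CLASS_NAMES[1]    # 19 training classes + 'Ignore'
print(colour.shape, len(names))
```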
/utils/df_from_data.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import pandas as pd
3 |
4 | import argparse
5 |
6 | # Set path to data, e.g. python df_from_data.py --path
7 | parser = argparse.ArgumentParser()
8 | path = "C:\\Users\\Theodoros Pissas\\Documents\\tresorit\\CaDIS\\segmentation"
9 | parser.add_argument('-p', '--path', type=str, default=path,
10 | help='Set path to data, e.g. python df_from_data.py --path ')
11 | args = parser.parse_args()
12 |
13 | record_list = []
14 | data_path = pathlib.Path(args.path)
15 | subfolders = [[f, f.name] for f in data_path.iterdir() if f.is_dir()]
16 | for folder_path, folder_name in subfolders:
17 | for image in (folder_path / 'Images').iterdir():
18 | record_list.append([
19 | int(folder_name[-2:]), # Video number: 'Video01' --> 1
20 | str(pathlib.PurePosixPath(pathlib.Path(folder_name) / 'Images' / image.name)), # Relative path to the image
21 |             str(pathlib.PurePosixPath(pathlib.Path(folder_name) / 'Labels' / image.name)), # Relative path to the label
22 | ])
23 | df = pd.DataFrame(data=record_list, columns=['vid_num', 'img_path', 'lbl_path'])
24 | df = df.sort_values(by=['vid_num', 'img_path']).reset_index(drop=True)
25 | df.to_pickle('../data/data.pkl')
26 |
--------------------------------------------------------------------------------
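The resulting ```data.pkl``` pairs each frame with its video number, so it can be split into train/validation sets with ```DATA_SPLITS``` from the CaDIS metadata (a sketch; the relative path assumes the same working directory the script above was run from):

```python
import pandas as pd
from utils.datasets_info.CADIS import DATA_SPLITS

df = pd.read_pickle('../data/data.pkl')        # produced by df_from_data.py
train_videos, valid_videos = DATA_SPLITS[1]    # split 1: 19 training / 6 validation videos
train_df = df[df['vid_num'].isin(train_videos)]
valid_df = df[df['vid_num'].isin(valid_videos)]
print(len(train_df), len(valid_df))
```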
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 | """ utils from openseg codebase """
5 | def is_distributed():
6 | return dist.is_initialized()
7 |
8 |
9 | def get_world_size():
10 | if not dist.is_initialized():
11 | return 1
12 | return dist.get_world_size()
13 |
14 |
15 | def get_rank():
16 | if not dist.is_initialized():
17 | return 0
18 | return dist.get_rank()
19 |
20 |
21 | def all_reduce_numpy(array):
22 | tensor = torch.from_numpy(array).cuda()
23 | dist.all_reduce(tensor)
24 | return tensor.cpu().numpy()
25 |
26 | def reduce_tensor(inp):
27 | """
28 | Reduce the loss from all processes so that
29 | process with rank 0 has the averaged results.
30 | """
31 | world_size = dist.get_world_size()
32 | if world_size < 2:
33 | return inp
34 | with torch.no_grad():
35 | reduced_inp = inp
36 | torch.distributed.reduce(reduced_inp, dst=0)
37 | return reduced_inp / world_size
38 |
39 |
40 | def barrier():
41 | """Synchronizes all processes.
42 |
43 | This collective blocks processes until the whole group enters this
44 | function.
45 | """
46 | if dist.is_initialized():
47 | dist.barrier() # processes in global group wait here until all processes reach this point
48 | return
49 |
50 | @torch.no_grad()
51 | def concat_all_gather(tensor, concat_dim=0):
52 | """ from moco
53 | Performs all_gather operation on the provided tensors.
54 | *** Warning ***: torch.distributed.all_gather has no gradient.
55 | """
56 | tensors_gather = [torch.ones_like(tensor)
57 | for _ in range(torch.distributed.get_world_size())]
58 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
59 | output = torch.cat(tensors_gather, dim=concat_dim)
60 | return output
61 |
--------------------------------------------------------------------------------
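The helpers above degrade gracefully when ```torch.distributed``` is not initialised, which keeps single-GPU runs working; ```reduce_tensor``` and ```concat_all_gather``` are only meaningful inside an initialised process group. A guarded sketch:

```python
import torch
from utils.distributed import is_distributed, get_world_size, get_rank, reduce_tensor

print(is_distributed(), get_world_size(), get_rank())  # False 1 0 in a single-process run

loss = torch.tensor(0.5)
if is_distributed():
    # inside a DDP run: average the loss across ranks so rank 0 can log it
    loss = reduce_tensor(loss.cuda())
print(loss.item())
```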
/utils/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Author: Donny You(youansheng@gmail.com)
4 | # Logging tool implemented with the python Package logging.
5 |
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import argparse
12 | import logging
13 | import os
14 | import sys
15 |
16 |
17 | DEFAULT_LOGFILE_LEVEL = 'debug'
18 | DEFAULT_STDOUT_LEVEL = 'info'
19 | DEFAULT_LOG_FILE = './default.log'
20 | DEFAULT_LOG_FORMAT = '%(asctime)s %(levelname)-7s %(message)s'
21 |
22 | LOG_LEVEL_DICT = {
23 | 'debug': logging.DEBUG,
24 | 'info': logging.INFO,
25 | 'warning': logging.WARNING,
26 | 'error': logging.ERROR,
27 | 'critical': logging.CRITICAL
28 | }
29 |
30 |
31 | class Logger(object):
32 | """
33 | Args:
34 | Log level: CRITICAL>ERROR>WARNING>INFO>DEBUG.
35 | Log file: The file that stores the logging info.
36 | rewrite: Clear the log file.
37 | log format: The format of log messages.
38 | stdout level: The log level to print on the screen.
39 | """
40 | logfile_level = None
41 | log_file = None
42 | log_format = None
43 | rewrite = None
44 | stdout_level = None
45 | logger = None
46 |
47 | _caches = {}
48 |
49 | @staticmethod
50 | def init(logfile_level=DEFAULT_LOGFILE_LEVEL,
51 | log_file=DEFAULT_LOG_FILE,
52 | log_format=DEFAULT_LOG_FORMAT,
53 | rewrite=False,
54 | stdout_level=None):
55 | Logger.logfile_level = logfile_level
56 | Logger.log_file = log_file
57 | Logger.log_format = log_format
58 | Logger.rewrite = rewrite
59 | Logger.stdout_level = stdout_level
60 |
61 | Logger.logger = logging.getLogger()
62 | Logger.logger.handlers = []
63 | fmt = logging.Formatter(Logger.log_format)
64 |
65 | if Logger.logfile_level is not None:
66 | filemode = 'w'
67 | if not Logger.rewrite:
68 | filemode = 'a'
69 |
70 | dir_name = os.path.dirname(os.path.abspath(Logger.log_file))
71 | if not os.path.exists(dir_name):
72 | os.makedirs(dir_name)
73 |
74 | if Logger.logfile_level not in LOG_LEVEL_DICT:
75 | print('Invalid logging level: {}'.format(Logger.logfile_level))
76 | Logger.logfile_level = DEFAULT_LOGFILE_LEVEL
77 |
78 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.logfile_level])
79 |
80 | fh = logging.FileHandler(Logger.log_file, mode=filemode)
81 | fh.setFormatter(fmt)
82 | fh.setLevel(LOG_LEVEL_DICT[Logger.logfile_level])
83 |
84 | Logger.logger.addHandler(fh)
85 |
86 | if stdout_level is not None:
87 | if Logger.logfile_level is None:
88 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.stdout_level])
89 |
90 | console = logging.StreamHandler()
91 | if Logger.stdout_level not in LOG_LEVEL_DICT:
92 | print('Invalid logging level: {}'.format(Logger.stdout_level))
93 | return
94 |
95 | console.setLevel(LOG_LEVEL_DICT[Logger.stdout_level])
96 | console.setFormatter(fmt)
97 | Logger.logger.addHandler(console)
98 |
99 | @staticmethod
100 | def set_log_file(file_path):
101 | Logger.log_file = file_path
102 | Logger.init(log_file=file_path)
103 |
104 | @staticmethod
105 | def set_logfile_level(log_level):
106 | if log_level not in LOG_LEVEL_DICT:
107 | print('Invalid logging level: {}'.format(log_level))
108 | return
109 |
110 | Logger.init(logfile_level=log_level)
111 |
112 | @staticmethod
113 | def clear_log_file():
114 | Logger.rewrite = True
115 | Logger.init(rewrite=True)
116 |
117 | @staticmethod
118 | def check_logger():
119 | if Logger.logger is None:
120 | Logger.init(logfile_level=None, stdout_level=DEFAULT_STDOUT_LEVEL)
121 |
122 | @staticmethod
123 | def set_stdout_level(log_level):
124 | if log_level not in LOG_LEVEL_DICT:
125 | print('Invalid logging level: {}'.format(log_level))
126 | return
127 |
128 | Logger.init(stdout_level=log_level)
129 |
130 | @staticmethod
131 | def debug(message):
132 | Logger.check_logger()
133 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
134 | lineno = sys._getframe().f_back.f_lineno
135 | prefix = '[{}, {}]'.format(filename,lineno)
136 | Logger.logger.debug('{} {}'.format(prefix, message))
137 |
138 | @staticmethod
139 | def info(message):
140 | Logger.check_logger()
141 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
142 | lineno = sys._getframe().f_back.f_lineno
143 | prefix = '[{}, {}]'.format(filename,lineno)
144 | Logger.logger.info('{} {}'.format(prefix, message))
145 |
146 | @staticmethod
147 | def info_once(message):
148 | Logger.check_logger()
149 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
150 | lineno = sys._getframe().f_back.f_lineno
151 | prefix = '[{}, {}]'.format(filename, lineno)
152 |
153 | if Logger._caches.get((prefix, message)) is not None:
154 | return
155 |
156 | Logger.logger.info('{} {}'.format(prefix, message))
157 | Logger._caches[(prefix, message)] = True
158 |
159 | @staticmethod
160 | def warn(message):
161 | Logger.check_logger()
162 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
163 | lineno = sys._getframe().f_back.f_lineno
164 | prefix = '[{}, {}]'.format(filename,lineno)
165 |         Logger.logger.warning('{} {}'.format(prefix, message))
166 |
167 | @staticmethod
168 | def error(message):
169 | Logger.check_logger()
170 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
171 | lineno = sys._getframe().f_back.f_lineno
172 | prefix = '[{}, {}]'.format(filename,lineno)
173 | Logger.logger.error('{} {}'.format(prefix, message))
174 |
175 | @staticmethod
176 | def critical(message):
177 | Logger.check_logger()
178 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
179 | lineno = sys._getframe().f_back.f_lineno
180 | prefix = '[{}, {}]'.format(filename,lineno)
181 | Logger.logger.critical('{} {}'.format(prefix, message))
182 |
183 |
184 | def printlog(message:str, save_to_log=True, **kwargs):
185 |     """prints a message to the console and also records it in the log file (without the logger echoing it to stdout)"""
186 | print(message, **kwargs)
187 | if save_to_log:
188 | Logger.info(message)
189 |
190 |
191 | if __name__ == "__main__":
192 | parser = argparse.ArgumentParser()
193 | parser.add_argument('--logfile_level', default="debug", type=str,
194 | dest='logfile_level', help='To set the log level to files.')
195 | parser.add_argument('--stdout_level', default=None, type=str,
196 | dest='stdout_level', help='To set the level to print to screen.')
197 | parser.add_argument('--log_file', default="./default.log", type=str,
198 | dest='log_file', help='The path of log files.')
199 | parser.add_argument('--log_format', default="%(asctime)s %(levelname)-7s %(message)s",
200 | type=str, dest='log_format', help='The format of log messages.')
201 | parser.add_argument('--rewrite', default=False, type=bool,
202 | dest='rewrite', help='Clear the log files existed.')
203 |
204 | args = parser.parse_args()
205 | Logger.init(logfile_level=args.logfile_level, stdout_level=args.stdout_level,
206 | log_file=args.log_file, log_format=args.log_format, rewrite=args.rewrite)
207 |
208 | Logger.info("info test.")
209 | Logger.debug("debug test.")
210 | Logger.warn("warn test.")
211 | Logger.error("error test.")
212 | Logger.debug("debug test.")
--------------------------------------------------------------------------------
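Typical use is to initialise the logger once per run and then call ```printlog``` so messages reach both the console and the log file; ```./run.log``` below is a placeholder path:

```python
from utils.logger import Logger, printlog

Logger.init(logfile_level='info', stdout_level=None, log_file='./run.log', rewrite=True)
printlog('starting training')                      # printed to the console and written to ./run.log
Logger.warn('recorded via the logging module only')
```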
/utils/lr_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import OrderedDict
3 |
4 |
5 | class LRFcts:
6 | def __init__(self, config: dict, lr_restart_steps: list, lr_total_steps: int):
7 | self.base_lr = config['learning_rate']
8 | self.lr_total_steps = lr_total_steps
9 | self.lr_fct = config['lr_fct']
10 | self.batchwise = config['lr_batchwise']
11 |
12 | self.uses_restarts = True
13 | if len(lr_restart_steps)== 0:
14 | self.uses_restarts = False
15 | # Restart epochs, and base values
16 | self.lr_restarts = lr_restart_steps
17 | restart_vals = config['lr_restart_vals']
18 | if 0 not in self.lr_restarts:
19 | self.lr_restarts.insert(0, 0)
20 | self.lr_restart_vals = [1]
21 | if isinstance(restart_vals, float) or isinstance(restart_vals, int):
22 | # Base LR value reduced to fraction every restart, end set to 0
23 | for i in range(1, len(self.lr_restarts)):
24 | self.lr_restart_vals.append(self.lr_restart_vals[i - 1] * restart_vals)
25 | elif isinstance(restart_vals, list):
26 | assert len(restart_vals) == len(config['lr_restarts']) - 1, \
27 | "Value Error: lr_restart_vals is list, but not the same length as lr_restarts"
28 | self.lr_restart_vals.extend(restart_vals)
29 | if lr_total_steps not in self.lr_restarts:
30 | self.lr_restarts.append(lr_total_steps)
31 | self.lr_restart_vals.append(0)
32 | self.lr_restarts = np.array(self.lr_restarts)
33 | self.lr_restart_vals = np.array(self.lr_restart_vals)
34 |
35 | # Length of each restart
36 | self.restart_lengths = np.ones_like(self.lr_restarts)
37 | self.restart_lengths[:-1] = self.lr_restarts[1:] - self.lr_restarts[:-1]
38 |
39 | # Current restart position
40 | self.curr_restart = len(self.lr_restarts) - np.argmax((np.arange(lr_total_steps + 1)[:, np.newaxis] >= self.lr_restarts)[:, ::-1], axis=1) - 1
41 | self.lr_params = dict()
42 | if config['lr_params'] is not None:
43 | self.lr_params = config['lr_params']
44 |
45 | self.epochs_ulab = config['ulab_epochs'] if 'ulab_epochs' in config else None
46 | self.epochs_lab = config['lab_epochs'] if 'lab_epochs' in config else None
47 |
48 | if self.lr_fct == 'piecewise_static':
49 | # example entry in config['train']["piecewise_static_schedule"]: [[40,1],[50,0.1]]
50 | # if s<=40 ==> lr = learning_rate * 1 elif s<=50 ==> lr = learning_rate * 0.1
51 | assert(len(self.lr_restarts) == 2), 'with piecewise_static lr schedule lr_restarts must be empty list' \
52 | ' instead got {}'.format(self.lr_restarts)
53 | assert 'piecewise_static_schedule' in self.lr_params
54 | assert isinstance(self.lr_params['piecewise_static_schedule'], list)
55 | assert self.lr_params['piecewise_static_schedule'][-1][0] == config['epochs'], \
56 | "piecewise_static_schedule's last phase must have first element equal to number of epochs " \
57 |             "instead got: {} and {} respectively".format(self.lr_params['piecewise_static_schedule'][-1][0], config['epochs'])
58 |
59 | piecewise_static_schedule = self.lr_params['piecewise_static_schedule']
60 | self.piecewise_static_schedule = OrderedDict() # this is essential, it has to be an ordered dict
61 | phase_prev = 0
62 | for phase in piecewise_static_schedule: # get ordered dict from list
63 | assert phase_prev < phase[0], ' piecewise_static_schedule must have increasing first elements per phase' \
64 | ' instead got phase_prev {} and phase {}'.format(phase_prev, phase[0])
65 | self.piecewise_static_schedule[phase[0]] = phase[1]
66 |
67 | def __call__(self, step: int):
68 | if self.uses_restarts:
69 | steps_since_restart = step - self.lr_restarts[self.curr_restart[step]]
70 | base_val = self.lr_restart_vals[self.curr_restart[step]]
71 | if self.lr_fct == 'static':
72 | return base_val
73 | elif self.lr_fct == 'piecewise_static':
74 | return self.piecewise_static(step)
75 | elif self.lr_fct == 'exponential':
76 | return self.lr_exponential(base_val, steps_since_restart)
77 | elif self.lr_fct == 'polynomial':
78 | steps_in_restart = self.restart_lengths[self.curr_restart[step]]
79 | return self.lr_polynomial(base_val, steps_since_restart, steps_in_restart)
80 | elif self.lr_fct == 'cosine':
81 | steps_in_restart = self.restart_lengths[self.curr_restart[step]]
82 | return self.lr_cosine(base_val, steps_since_restart, steps_in_restart)
83 | else:
84 |                 raise ValueError("Learning rate schedule '{}' not recognised.".format(self.lr_fct))
85 | else:
86 |             # todo: hacky for now -- the restart bookkeeping above should only be set up when lr_restarts are actually used
87 |             base_val = 1.0
88 |             if step > self.lr_total_steps:
89 |                 print(f'warning: learning rate scheduler at step {step} exceeds expected lr_total_steps {self.lr_total_steps}')
90 | if self.lr_fct == 'exponential':
91 | return self.lr_exponential(base_val, step)
92 | elif self.lr_fct == 'polynomial':
93 | return self.lr_polynomial(base_val, step, self.lr_total_steps)
94 | elif self.lr_fct == 'linear-warmup-polynomial':
95 | assert 'warmup_iters' in self.lr_params \
96 | and 'warmup_rate' in self.lr_params, f'lr_params must be passed via config as dict with keys ' \
97 | f'warmup_iters and warmup_rate instead got {self.lr_params}'
98 | if step <= self.lr_params['warmup_iters']-1:
99 | return self.linear_warmup(step)
100 | else:
101 | return self.lr_polynomial(base_val, step, self.lr_total_steps)
102 | else:
103 |                 raise ValueError("Learning rate schedule without restarts '{}' not recognised.".format(self.lr_fct))
104 |
105 | def piecewise_static(self, step):
106 | # important this only works if self.piecewise_static_schedule is an ordered dict!
107 | for phase_end in self.piecewise_static_schedule.keys():
108 | lr = self.piecewise_static_schedule[phase_end]
109 | if step <= phase_end:
110 | return lr
111 |
112 | def linear_warmup(self, step: int):
113 | # step + 1 to account for step = 0 ... warmup_iters -1
114 |
115 | lr = 1 - (1 - (step+1) / self.lr_params['warmup_iters']) * (1 - self.lr_params['warmup_rate'])
116 | # warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
117 | return lr
118 |
119 | def lr_exponential(self, base_val: float, steps_current: int):
120 |         gamma = self.lr_params.get('gamma', .98)
121 | lr = base_val * gamma ** steps_current
122 | return lr
123 |
124 | def lr_polynomial(self, base_val: float, steps_current: int, max_steps: int):
125 | # max_steps - 1 to account for step = 0 ... max_steps -1
126 | # power = .9 if 'power' in self.lr_params else self.lr_params['power']
127 | power = self.lr_params.get('power', .9)
128 | # min_lr = self.lr_params['min_lr'] if 'min_lr' in self.lr_params else 0.0
129 | min_lr = self.lr_params.get('min_lr', 0.0)
130 | coeff = (1 - steps_current / (max_steps-1)) ** power
131 | lr = (base_val- min_lr) * coeff + min_lr
132 | return lr
133 |
134 | def lr_cosine(self, base_val, steps_current, max_steps):
135 | lr = base_val * 0.5 * (1. + np.cos(np.pi * steps_current / max_steps))
136 | return lr
137 |
138 |
139 | if __name__ == '__main__':
140 |     def lr_exponential(base_val: float, steps_since_restart: int, steps_in_restart=None, gamma: float = .98):
141 | lr = base_val * gamma ** steps_since_restart
142 | return lr
143 |
144 | def lr_cosine(base_val, steps_since_restart, steps_in_restart):
145 | lr = base_val * 0.5 * (1. + np.cos(np.pi * steps_since_restart / steps_in_restart))
146 | return lr
147 |
148 |
149 | def linear_warmup(step: int):
150 | base_lr = 0.0001
151 | rate = 1e-6
152 | # step + 1 to account for step = 0 ... warmup_iters -1
153 | lr = 1 - (1 - (step+1) / 1500) * (1 - rate)
154 | # warmup_lr = [_lr * (1 - k) for _lr in regular_lr]
155 | return lr * base_lr
156 |
157 | def lr_polynomial( base_val: float, steps_current: int, max_steps: int):
158 | # max_steps - 1 to account for step = 0 ... max_steps -1
159 | power = 1.0
160 | min_lr = 0.0
161 | coeff = (1 - steps_current / (max_steps-1)) ** power
162 | lr = (base_val- min_lr) * coeff + min_lr
163 | return lr
164 |
165 |
166 | def linear_warmup_then_poly(step:int, total_steps):
167 | if step <= 1500 - 1:
168 | return linear_warmup(step)
169 | else:
170 | return lr_polynomial(0.0001, step, total_steps)
171 |
172 |
173 |
174 |
175 | # lr_start = 0.0001
176 | # T = 100
177 | # lrs = [lr_cosine(lr_start, step, T) for step in range(T)]
178 | # lrs_exp = [lr_exponential(lr_start, step % (T//4), T//4) for step in range(T)]
179 | #
180 | #
181 | #
182 | import matplotlib.pyplot as plt
183 | # plt.plot(lrs)
184 | # plt.plot(lrs_exp)
185 | T = 160401
186 | lrs_exp = [linear_warmup_then_poly(step, T) for step in range(T)]
187 | plt.plot(lrs_exp)
188 | plt.show()
189 | a = 1
--------------------------------------------------------------------------------
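The scheduler above returns a multiplicative factor per step, so it can be plugged into `torch.optim.lr_scheduler.LambdaLR`. Below is a minimal, self-contained sketch of the linear-warmup-then-polynomial behaviour; the function name `warmup_poly_factor` and all constants are illustrative and not part of the repository.

```python
import torch

def warmup_poly_factor(step: int, warmup_iters: int = 1500, warmup_rate: float = 1e-6,
                       total_steps: int = 160401, power: float = 1.0) -> float:
    """Multiplicative LR factor: linear warmup followed by polynomial decay (illustrative)."""
    if step <= warmup_iters - 1:
        # ramps from roughly warmup_rate up to 1.0 over warmup_iters steps
        return 1 - (1 - (step + 1) / warmup_iters) * (1 - warmup_rate)
    # polynomial decay from 1.0 towards 0.0 over the remaining steps
    return (1 - step / (total_steps - 1)) ** power

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_poly_factor)

for step in range(5):
    optimizer.step()  # dummy update
    scheduler.step()
    print(step, scheduler.get_last_lr())
```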
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils import to_numpy, CLASS_INFO
3 |
4 |
5 | def get_confusion_matrix(prediction, target, existing_matrix=None):
6 | """Expects prediction logits (as output by network), and target as classes in single channel (as from data)"""
7 | prediction, target = to_numpy(prediction), to_numpy(target)
8 | num_classes = prediction.shape[1] # prediction is shape NCHW, we want C (one-hot length of all classes)
9 | one_hots = np.eye(num_classes)
10 | prediction = np.moveaxis(prediction, 1, 0) # Prediction is NCHW -> move C to the front to make it CNHW
11 | prediction = np.reshape(prediction, (num_classes, -1)) # Prediction is [C, N*H*W]
12 | prediction = np.argmax(prediction, 0) # Prediction is now [N*H*W]
13 | one_hot_preds = one_hots[prediction] # Prediction is now [N*H*W, C]
14 | one_hot_preds = np.moveaxis(one_hot_preds, 1, 0) # Prediction is now [C, N*H*W]
15 | one_hot_targets = one_hots[target.reshape(-1)] # Target is now [N*H*W, C]
16 | confusion_matrix = np.matmul(one_hot_preds, one_hot_targets).astype('i') # [C, N*H*W] x [N*H*W, C] = [C, C]
17 | # Consistency check:
18 | assert(np.sum(confusion_matrix) == target.size) # All elements summed equals all pixels in original target
19 | for i in range(num_classes):
20 | assert(np.sum(confusion_matrix[i]) == np.sum(prediction == i)) # Row of matrix equals class incidence in pred
21 | assert(np.sum(confusion_matrix[:, i]) == np.sum(target == i)) # Col of matrix equals class incidence in target
22 | if existing_matrix is not None:
23 | assert(existing_matrix.shape == confusion_matrix.shape)
24 | confusion_matrix += existing_matrix
25 | return confusion_matrix
26 |
27 |
28 | def normalise_confusion_matrix(matrix, mode):
29 | if mode == 'row':
30 | row_sums = matrix.sum(axis=1)
31 | row_sums[row_sums == 0] = 1 # to avoid division by 0. Safe, because if sum is 0, all elements have to be 0 too
32 | matrix = matrix / row_sums[:, np.newaxis]
33 | elif mode == 'col':
34 | col_sums = matrix.sum(axis=0)
35 | col_sums[col_sums == 0] = 1 # to avoid division by 0. Safe, because if sum is 0, all elements have to be 0 too
36 | matrix = matrix / col_sums[np.newaxis, :]
37 | else:
38 | raise ValueError("Normalise confusion matrix: mode needs to be either 'row' or 'col'.")
39 | return matrix
40 |
41 |
42 | def get_pixel_accuracy(confusion_matrix):
43 | """Pixel accuracies, adapted from https://github.com/CSAILVision/semantic-segmentation-pytorch
44 |
45 | :param confusion_matrix: Confusion matrix with absolute values. Rows are predicted classes, columns ground truths
46 | :return: Overall pixel accuracy, pixel accuracy per class (PA / PAC in CaDISv2 paper)
47 | """
48 | pred_class_correct = np.diag(confusion_matrix)
49 | acc = np.sum(pred_class_correct) / np.sum(confusion_matrix)
50 | pred_class_sums = np.sum(confusion_matrix, axis=1)
51 | pred_class_sums[pred_class_sums == 0] = 1 # To avoid division by 0 problems. Safe because all elem = 0 when sum = 0
52 | acc_per_class = np.mean(pred_class_correct / pred_class_sums)
53 | return acc, acc_per_class
54 |
55 |
56 | def get_mean_iou(confusion_matrix, experiment, categories=False, single_class=None):
57 | """Uses confusion matrix to compute mean iou. Confusion matrix computed by get_confusion_matrix: row indexes
58 | prediction class, column indexes ground truth class. Based on:
59 | github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py
60 | """
61 | assert experiment in [1, 2, 3], 'experiment must be in [1,2,3] instead got [{}]'.format(experiment)
62 | if single_class is not None:
63 | # compute miou for a single_class
64 | assert(not categories),\
65 | 'when single_class is not None, category must be False instead got [{}]'.format(categories)
66 | assert(single_class in CLASS_INFO[experiment]),\
67 | 'single_class must be {} instead got [{}]'.format(CLASS_INFO[experiment][1].keys(), single_class)
68 | return get_single_class_iou(confusion_matrix, experiment, single_class)
69 | elif categories:
70 | # compute miou for the classes of instruments and for the classes of anatomies
71 | # compute miou for all classes
72 | assert (single_class is None),\
73 | 'when category is not None, single class must be None instead got [{}]'.format(single_class)
74 | miou_instruments = np.mean([get_single_class_iou(confusion_matrix, experiment, c)
75 | for c in CLASS_INFO[experiment][2]['instruments']])
76 | miou_anatomies = np.mean([get_single_class_iou(confusion_matrix, experiment, c)
77 | for c in CLASS_INFO[experiment][2]['anatomies']])
78 | miou = np.mean([get_single_class_iou(confusion_matrix, experiment, c)
79 | for c in CLASS_INFO[experiment][1].keys()])
80 | return miou, miou_instruments, miou_anatomies
81 | else:
82 | # compute miou for all classes
83 | miou = np.mean([get_single_class_iou(confusion_matrix, experiment, c)
84 | for c in CLASS_INFO[experiment][1].keys()])
85 | return miou
86 |
87 |
88 | def get_single_class_iou(confusion_matrix, experiment, single_class):
89 | if single_class == 255: # This is the 'ignore' class helpfully introduced in exp 2 and 3. Needs to NOT be 255 here
90 | single_class = confusion_matrix.shape[0] - 1
91 | # iou = tp/(tp + fp + fn)
92 | # the number of true positive pixels for this class
93 | # the entry on the diagonal of the confusion matrix
94 | tp = confusion_matrix[single_class, single_class]
95 |
96 | # the number of false negative pixels for this class
97 | # the column sum of the matching row in the confusion matrix
98 | # minus the diagonal entry
99 | fn = confusion_matrix[:, single_class].sum() - tp
100 |
101 | # the number of false positive pixels for this class
102 | # Only pixels that are not on a pixel with ground truth class that is ignored
103 | # The row sum of the corresponding row in the confusion matrix
104 | # without the ignored rows and without the actual label of interest
105 | not_ignored = [c for c in CLASS_INFO[experiment][1].keys() if not (c == 255 or c == single_class)]
106 | fp = confusion_matrix[single_class, not_ignored].sum()
107 |
108 | # the denominator of the IOU score
109 | denom = (tp + fp + fn)
110 | if denom == 0:
111 | # return float('nan')
112 | return 0 # Otherwise the mean always returns NaN which is technically correct but not so helpful
113 | # return IOU
114 | return float(tp) / denom
115 |
--------------------------------------------------------------------------------
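For reference, a small self-contained check of the confusion-matrix convention used above (rows index the predicted class, columns the ground truth), with per-class IoU computed directly from it; all arrays below are dummy data.

```python
import numpy as np

num_classes, N, H, W = 3, 2, 4, 4
rng = np.random.default_rng(0)
logits = rng.normal(size=(N, num_classes, H, W))       # dummy network output, NCHW
target = rng.integers(0, num_classes, size=(N, H, W))  # dummy ground truth, NHW

pred = logits.argmax(axis=1).reshape(-1)
gt = target.reshape(-1)
cm = np.zeros((num_classes, num_classes), dtype=np.int64)
np.add.at(cm, (pred, gt), 1)  # rows: predicted class, cols: ground-truth class

# per-class IoU = tp / (tp + fp + fn), mirroring get_single_class_iou above
tp = np.diag(cm)
fp = cm.sum(axis=1) - tp  # predicted as c but labelled otherwise
fn = cm.sum(axis=0) - tp  # labelled c but predicted otherwise
iou = tp / np.maximum(tp + fp + fn, 1)
print(cm)
print(iou, iou.mean())
```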
/utils/optimizer_utils.py:
--------------------------------------------------------------------------------
1 | from .logger import printlog
2 |
3 | def get_num_layer_stage_wise(var_name, num_max_layer):
4 | """Get the layer id to set the different learning rates in ``stage_wise``
5 | decay_type. only for convnext series
6 | Args:
7 | var_name (str): The key of the model.
8 | num_max_layer (int): Maximum number of backbone layers.
9 | Returns:
10 | int: The id number corresponding to different learning rate in
11 | ``LearningRateDecayOptimizerConstructor``.
12 | """
13 |
14 | if var_name in ('backbone.cls_token', 'backbone.mask_token', 'backbone.pos_embed'):
15 | return 0
16 | elif var_name.startswith('backbone.downsample_layers'):
17 | return 0
18 | elif var_name.startswith('backbone.stages'):
19 | stage_id = int(var_name.split('.')[2])
20 | return stage_id + 1
21 | else:
22 |         return num_max_layer - 1  # i.e. all parameters outside the backbone stages get scale 1, so lr = base_lr
23 |
24 |
25 | def is_in(param_group, param_group_list):
26 | # assert is_list_of(param_group_list, dict)
27 | param = set(param_group['params'])
28 | param_set = set()
29 | for group in param_group_list:
30 | param_set.update(set(group['params']))
31 | return not param.isdisjoint(param_set)
32 |
33 |
34 | def get_param_groups_using_keys(model, config):
35 |     # this mode specifies keys (strings); if a key occurs in a variable's name, the variable is placed in
36 |     # that key's parameter group, where all variables share the same wd_mult and lr_mult
37 | base_lr = config['train']['learning_rate']
38 | base_wd = config['train']['weight_decay']
39 | params = []
40 | parameter_groups = {}
41 | params_dict = dict(model.named_parameters())
42 | for name, param in params_dict.items():
43 | param_group = {'params': [param], "group_name": name, 'param_names': []}
44 | is_custom = (False, None)
45 | is_first_in_group = False
46 | if not param.requires_grad:
47 | params.append(param_group)
48 | continue
49 | if is_in(param_group, params):
50 |             pass  # debug placeholder: parameter already appears in the accumulated params (no effect)
51 | group_name = 'base_lr_wd'
52 | for custom_key in config['train']['opt_keys']:
53 | if custom_key in name:
54 | is_custom = (True, custom_key)
55 | lr_mult = config['train']['opt_keys'][custom_key].get('lr_mult', 1.0)
56 | wd_mult = config['train']['opt_keys'][custom_key].get('wd_mult', 1.0)
57 | param_group['lr'] = lr_mult * base_lr
58 | param_group['weight_decay'] = wd_mult * base_wd
59 | group_name = f'{custom_key}_lrm{lr_mult}_wdm{wd_mult}'
60 | break
61 | if not is_custom[0]:
62 | param_group['lr'] = base_lr
63 | param_group['weight_decay'] = base_wd
64 |
65 | if group_name not in parameter_groups:
66 | param_group['group_name'] = group_name
67 | parameter_groups[group_name] = param_group
68 | is_first_in_group = True
69 | # parameter_groups[group_name]['param_names'] = [name]
70 | if not is_first_in_group:
71 | parameter_groups[group_name]['params'].append(param)
72 | parameter_groups[group_name]['param_names'].append(name)
73 |
74 | params.extend(parameter_groups.values())
75 | # printlog(f'optimizer param groups : \n {params}')
76 | params_cnt = 0
77 | for g in params:
78 | params_cnt += len(g['param_names'])
79 | assert (len(params_dict) == params_cnt), f'mismatch between params in parameter groups {params_cnt}' \
80 | f' and model.named_parameters {len(params_dict)}'
81 | return params
82 |
83 | def get_param_groups_with_stage_wise_lr_decay(model, config):
84 | # adapted from convnext repo
85 | # scales the learning rate of deeper layers by decay_rate ** (num_layers - layer_id - 1)
86 |     # tl;dr: later layers get gradually higher lr
87 | assert 'ConvNext' in config['graph']['backbone'], f"stage_wise_lr currently only supported for " \
88 | f"ConvNext backbones instead got {config['graph']['backbone']}"
89 | decay_rate = config['train']['stage_wise_lr']['decay_rate']
90 | num_layers = config['train']['stage_wise_lr']['num_layers'] + 2 # todo +2 is still a mystery (?)
91 | base_lr = config['train']['learning_rate']
92 | base_wd = config['train']['weight_decay']
93 | params = []
94 | parameter_groups = {}
95 | params_dict = dict(model.named_parameters())
96 |
97 | for name, param in params_dict.items():
98 | if len(param.shape) == 1 or name.endswith('.bias') or name in ('pos_embed', 'cls_token'):
99 |             # len(param.shape) == 1 is here to ensure some layer-norm modules get 0 weight decay
100 |             # despite not containing the word "norm" in their names
101 |             # for convnext these are e.g. 'backbone.downsample_layers.0.1.weight'
102 | # or 'backbone.downsample_layers.0.1.bias'
103 | group_name = 'no_decay'
104 | this_weight_decay = 0.0
105 | # printlog(name, this_weight_decay)
106 | else:
107 | group_name = 'decay'
108 | this_weight_decay = base_wd
109 | layer_id = get_num_layer_stage_wise(name, num_layers)
110 | # logger.info(f'set param {name} as id {layer_id}')
111 | group_name = f'layer_{layer_id}_{group_name}'
112 | if group_name not in parameter_groups:
113 | scale = decay_rate ** (num_layers - layer_id - 1) # scale * base_lr is the learning rate for this group
114 | # printlog(group_name, scale)
115 | parameter_groups[group_name] = {
116 | 'weight_decay': this_weight_decay,
117 | 'params': [],
118 | 'param_names': [],
119 | 'lr_scale': scale,
120 | 'group_name': group_name,
121 | 'lr': scale * base_lr,
122 | }
123 | parameter_groups[group_name]['params'].append(param)
124 | parameter_groups[group_name]['param_names'].append(name)
125 | params.extend(parameter_groups.values())
126 | # printlog(f'optimizer param groups : \n {params}')
127 | params_cnt = 0
128 | for g in params:
129 | params_cnt += len(g['param_names'])
130 | assert (len(params_dict) == params_cnt), f'mismatch between params in parameter groups {params_cnt}' \
131 | f' and model.named_parameters {len(params_dict)}'
132 | return params
133 |
--------------------------------------------------------------------------------
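Both helpers above return a list of per-parameter-group dicts that a torch optimizer accepts directly. A toy sketch of the stage-wise idea, far simpler than `get_param_groups_with_stage_wise_lr_decay` and purely illustrative:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Conv2d(8, 8, 3))
base_lr, base_wd, decay_rate, num_stages = 1e-4, 0.05, 0.9, 3

param_groups = []
for stage_id, module in enumerate(model):
    # earlier stages get a smaller lr: base_lr * decay_rate ** (num_stages - stage_id - 1)
    scale = decay_rate ** (num_stages - stage_id - 1)
    param_groups.append({
        'params': list(module.parameters()),
        'lr': scale * base_lr,
        # norm layers get no weight decay, as in the no_decay group above
        'weight_decay': 0.0 if isinstance(module, nn.BatchNorm2d) else base_wd,
    })

optimizer = torch.optim.AdamW(param_groups, lr=base_lr, weight_decay=base_wd)
for g in optimizer.param_groups:
    print(g['lr'], g['weight_decay'], len(g['params']))
```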
/utils/repeat_factor_sampling.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import pandas as pd
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import Sampler
6 | from utils import DATASETS_INFO, get_class_info, reverse_one_to_many_mapping
7 | from itertools import islice
8 | from torch.utils.data.distributed import DistributedSampler
9 | from .distributed import is_distributed, get_rank, get_world_size
10 | import math
11 |
12 |
13 | def get_class_repeat_factors_for_experiment(lbl_df: pd.DataFrame,
14 | repeat_thresh: float,
15 | exp: int,
16 | return_frequencies=False,
17 | dataset: str = 'CADIS'):
18 |
19 | experiment_cls = DATASETS_INFO[dataset].CLASS_INFO[exp][1]
20 | exp_mapping = DATASETS_INFO[dataset].CLASS_INFO[exp][0]
21 | rev_mapping = reverse_one_to_many_mapping(exp_mapping)
22 | canonical_cls = DATASETS_INFO[dataset].CLASS_NAMES[0]
23 | canonical_num_to_name = reverse_one_to_many_mapping(DATASETS_INFO[dataset].CLASS_INFO[0][1])
24 | num_frames = lbl_df.shape[0]
25 |
26 | cls_freqs = dict()
27 | cls_rfs = dict()
28 |
29 | for c in canonical_cls:
30 | c_exp = rev_mapping[canonical_num_to_name[c]] # from canonical cls name to experiment num
31 | if c_exp not in cls_freqs.keys():
32 | cls_freqs[c_exp] = 0
33 | s = lbl_df.loc[lbl_df[c] > 0].shape[0]
34 | cls_freqs[c_exp] += s / num_frames
35 |
36 | for c_exp in experiment_cls:
37 | if cls_freqs[c_exp] == 0:
38 | cls_freqs[c_exp] = repeat_thresh
39 | cls_rfs[c_exp] = np.maximum(1, np.sqrt(repeat_thresh / cls_freqs[c_exp]))
40 | cls_freqs = {k: v for k, v in sorted(cls_freqs.items(), reverse=True, key=lambda item: item[1])}
41 | cls_rfs = {k: v for k, v in sorted(cls_rfs.items(), reverse=True, key=lambda item: item[1])}
42 | if return_frequencies:
43 | return cls_freqs, cls_rfs
44 | else:
45 | return cls_rfs
46 |
47 |
48 | def get_image_repeat_factors_for_experiment(lbl_df: pd.DataFrame, cls_rfs: dict, exp: int, dataset: str):
49 | exp_mapping = DATASETS_INFO[dataset].CLASS_INFO[exp][0]
50 | rev_mapping = reverse_one_to_many_mapping(exp_mapping) # from canonical to experiment classes
51 | canonical_cls = DATASETS_INFO[dataset].CLASS_NAMES[0]
52 | canonical_num_to_name = reverse_one_to_many_mapping(DATASETS_INFO[dataset].CLASS_INFO[0][1]) # canonical class to num
53 | img_rfs = []
54 | inds = []
55 | for idx, row in lbl_df.iterrows(): # for each frame
56 | class_repeat_factors_in_frame = []
57 | for c in canonical_cls:
58 | if row[c] > 0:
59 | class_repeat_factors_in_frame.append(cls_rfs[rev_mapping[canonical_num_to_name[c]]])
60 | img_rfs.append(np.max(class_repeat_factors_in_frame))
61 | inds.append(idx)
62 | return inds, img_rfs
63 |
64 |
65 | class RepeatFactorSampler(Sampler):
66 | def __init__(self, data_source: torch.utils.data.Dataset, dataframe: pd.DataFrame,
67 | repeat_thresh: float, experiment: int, split: int, blacklist=True, seed=None, dataset='CADIS'):
68 | """ Computes repeat factors and returns repeat factor sampler
69 | Note: this sampler always uses shuffling
70 | :param data_source: a torch dataset object
71 |         :param dataframe: a dataframe with class occurrences as columns
72 | :param repeat_thresh: repeat factor threshold (intuitively: frequency below which rf kicks in)
73 | :param experiment: experiment id
74 | :param split: dataset split being used to determine repeat factors for each image in it.
75 |         :param blacklist: whether blacklisting is to be applied
76 | :param seed: seeding for torch randomization
77 | :param dataset : todo does not support CTS currently
78 | :return RepeatFactorSampler object
79 | """
80 | super().__init__(data_source=data_source)
81 | assert(0 <= repeat_thresh < 1 and split in [0, 1, 2])
82 | seed = 1 if seed is None else seed
83 | self.seed = int(seed)
84 | self.shuffle = True # shuffling is always used with this sampler
85 | self.split = split
86 | self.repeat_thresh = repeat_thresh
87 | df = get_class_info(dataframe, 0, with_name=True)
88 | if blacklist: # drop blacklisted
89 | df = df.drop(df[df['blacklisted'] == 1].index)
90 | df.reset_index()
91 | self.class_repeat_factors, self.repeat_factors = \
92 | self.repeat_factors_class_and_image_level(df, experiment, repeat_thresh, split, dataset)
93 | self._int_part = torch.trunc(self.repeat_factors)
94 | self._frac_part = self.repeat_factors - self._int_part
95 | self.g = torch.Generator()
96 | self.g.manual_seed(self.seed)
97 | self.epoch = 0
98 | self.indices = None
99 | self.distributed = is_distributed() # todo this should be removed in the future once local has ddp package
100 |
101 | self.num_replicas = get_world_size()
102 | self.rank = get_rank()
103 | print(f'RF sampler -- world_size: {self.num_replicas} rank : {self.rank}')
104 | self.dataset = data_source
105 | # if len(self.dataset) % self.num_replicas ==0: # type: ignore
106 | # self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore
107 | # else:
108 | # self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
109 | # self.total_size = self.num_samples * self.num_replicas
110 |
111 | @staticmethod
112 | def repeat_factors_class_and_image_level(df: pd.DataFrame, experiment: int, repeat_thresh: float,
113 | split: int, dataset: str):
114 | train_videos = DATASETS_INFO[dataset].DATA_SPLITS[split][0]
115 | train_df = df.loc[df['vid_num'].isin(train_videos)]
116 | train_df = train_df.reset_index()
117 | # For each class compute the class-level repeat factor: r(c) = max(1, sqrt(t/f(c)) where f(c) is class freq
118 | class_rfs = get_class_repeat_factors_for_experiment(train_df, repeat_thresh, experiment, dataset=dataset)
119 | # For each image I, compute the image-level repeat factor: r(I) = max_{c in I} r(c)
120 | inds, rfs = get_image_repeat_factors_for_experiment(train_df, class_rfs, experiment, dataset)
121 | return class_rfs, torch.tensor(rfs, dtype=torch.float32)
122 |
123 | def __iter__(self):
124 | if self.distributed: # todo this should be removed in the future
125 | start = get_rank()
126 | step = get_world_size() # 1 if not ddp
127 | # to debug
128 | # print(f'rank {get_rank()} -slicing start {start} step {step} ')
129 | print(f'rank {get_rank()} indices : {len([i for i in islice(self._yield_indices(), start, None, step)])}')
130 | yield from islice(self._yield_indices(), start, None, step)
131 | else:
132 |
133 | yield from islice(self._yield_indices(), 0, None, 1)
134 |
135 | def _yield_indices(self):
136 | if self.indices is not None:
137 | indices = torch.tensor(self.indices, dtype=torch.int64)
138 | else:
139 | indices = self._get_epoch_indices(self.g)
140 | ind_left = self.__len__()
141 | print(f'Indices generated {ind_left}, rank : {get_rank()}')
142 | self.g.manual_seed(self.seed + self.epoch)
143 | while ind_left > 0:
144 | # each epoch may have a slightly different size due to the stochastic rounding.
145 | randperm = torch.randperm(len(indices), generator=self.g) # shuffling
146 | for item in indices[randperm]:
147 | # print(f'yielding : {item} rank : {get_rank()}')
148 | yield int(item)
149 | ind_left -= 1
150 | self.indices = None
151 |
152 | def __len__(self):
153 | if self.indices is not None:
154 | return len(self.indices)
155 | else:
156 | return len(self._get_epoch_indices(self.g))
157 |
158 | def set_epoch(self, epoch):
159 | self.epoch = epoch
160 |
161 | def _get_epoch_indices(self, generator):
162 | # stochastic rounding so that the target repeat factor
163 | # is achieved in expectation over the course of training
164 | rands = torch.rand(len(self._frac_part), generator=generator)
165 | rounded_rep_factors = self._int_part + (rands < self._frac_part).float()
166 | indices = []
167 | # replicate each image's index by its rounded repeat factor
168 | for img_index, rep_factor in enumerate(rounded_rep_factors):
169 | indices.extend([img_index] * int(rep_factor.item()))
170 | self.indices = indices
171 | if self.num_replicas>1: # self.distributed and
172 | # ensures each process has access to equal number of indices from the dataset
173 | self.num_indices = len(self.indices)
174 |             if self.num_indices % self.num_replicas == 0:
175 |                 self.indices_per_process = math.ceil(self.num_indices / self.num_replicas)
176 |             else:
177 |                 self.indices_per_process = math.ceil((self.num_indices - self.num_replicas) / self.num_replicas)
178 |
179 |             self.num_indices_to_keep = self.indices_per_process * self.num_replicas
180 | self.indices_to_keep = torch.randint(low=0, high=self.num_indices_to_keep-1,
181 | size=[self.num_indices_to_keep],
182 | generator=generator)
183 |
184 | # print(f'num_indices = {self.num_indices} - num_indices_to_keep = {self.self.num_indices_to_keep} - rank : {get_rank()}' )
185 | return torch.tensor(indices, dtype=torch.int64)[self.indices_to_keep]
186 |
187 | return torch.tensor(indices, dtype=torch.int64)
188 |
189 |
190 |
191 | if __name__ == '__main__':
192 | inds = np.arange(1000).tolist()
193 | def dummy(start):
194 | yield from islice(inds, start, None, 4)
195 | a = [[i for i in dummy(start)] for start in range(4)]
196 |
--------------------------------------------------------------------------------
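The sampler above follows the standard repeat-factor recipe: class-level factors r(c) = max(1, sqrt(t / f(c))), image-level factors r(I) = max over the classes present, and stochastic rounding per epoch so the expected number of copies matches r(I). A toy illustration of those three steps with made-up class frequencies:

```python
import numpy as np

repeat_thresh = 0.15
class_freq = {'tool': 0.02, 'anatomy': 0.90, 'background': 0.99}  # made-up per-image frequencies

# class-level repeat factors: rare classes (freq < repeat_thresh) get r(c) > 1
class_rf = {c: max(1.0, np.sqrt(repeat_thresh / f)) for c, f in class_freq.items()}

# image-level repeat factor: max over the classes present in each image
images = [{'tool', 'background'}, {'anatomy', 'background'}]
image_rf = np.array([max(class_rf[c] for c in img) for img in images])

# stochastic rounding: the expected number of copies per epoch equals the repeat factor
rng = np.random.default_rng(0)
int_part = np.floor(image_rf)
copies = int_part + (rng.random(len(image_rf)) < (image_rf - int_part))
print(class_rf)
print(image_rf, copies)
```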
/utils/semi_utis.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from collections import OrderedDict
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class BalancedConcatDataset(Dataset):
7 | def __init__(self, *datasets):
8 | self.datasets = datasets
9 | dataset_lengths = [len(d) for d in self.datasets]
10 | self.max_len = max(dataset_lengths)
11 | self.min_len = min(dataset_lengths)
12 |
13 | def __getitem__(self, i):
14 |         # each item is a tuple of 1 unlabelled sample and 1 labelled sample
15 | v = [d[i % len(d)] for d in self.datasets]
16 | b = tuple(v)
17 | # b : b[0] = list containing dataset_1 img of shape (C,H,W), mask of shape (H,W), pseudo info of shape (,)
18 | # b : b[1] = list containing dataset_2 img of shape (C,H,W), mask of shape (H,W), pseudo info of shape (,)
19 | return b
20 |
21 | def __len__(self):
22 | # stop when the longest dataset runs out of samples
23 | return self.max_len
24 |
25 |
26 | def get_video_files_from_split(ids, debug=False):
27 |     """ gets a list of video ids (i.e. a split's train videos) and returns a list of
28 |     the names of the corresponding mp4 files"""
29 | dicts = dict()
30 | dicts['train_1'] = [1, 2, 3, 4, 5, 6, 7, 8] if not debug else [1, 3, 6]
31 | dicts['train_2'] = [9, 10, 11, 12, 13, 14, 15, 16]
32 | dicts['train_3'] = [17, 18, 19, 20, 21, 22, 23, 24]
33 | dicts['train_4'] = [25]
34 | files = []
35 | for i in ids:
36 | # s = "{0:0=1d}".format(i)
37 | s = "%02d" % i
38 | if i in dicts['train_1']:
39 | files.append(pathlib.Path('train_1') / pathlib.Path('train{}.mp4'.format(s)))
40 | elif i in dicts['train_2'] and not debug:
41 | files.append(pathlib.Path('train_2') / pathlib.Path('train{}.mp4'.format(s)))
42 | elif i in dicts['train_3'] and not debug:
43 | files.append(pathlib.Path('train_3') / pathlib.Path('train{}.mp4'.format(s)))
44 | elif i in dicts['train_4'] and not debug:
45 | files.append(pathlib.Path('train_4') / pathlib.Path('train{}.mp4'.format(s)))
46 | return files
47 |
48 |
49 | def get_excluded_frames_from_df(df, train_videos):
50 | train = df.loc[df['vid_num'].isin(train_videos)]
51 | train.reset_index()
52 | train = train.reset_index()
53 | train = train.drop(train[train['blacklisted'] == 1].index)
54 | train = train.reset_index()
55 | img_vid_frames = train['img_path']
56 | img_vid_frames = img_vid_frames.tolist()
57 | video_to_excluded_frames_dict = OrderedDict()
58 | for f in img_vid_frames:
59 | frame_id = int(f.split('.')[-2][-6:])
60 | video_id = f.split('Video')[-1][0:2] if '_' not in f.split('Video')[-1][0:2] else f.split('Video')[-1][0]
61 | video_id = int(video_id)
62 | if video_id in video_to_excluded_frames_dict:
63 | video_to_excluded_frames_dict[video_id].append(frame_id)
64 | else:
65 | video_to_excluded_frames_dict[video_id] = []
66 | video_to_excluded_frames_dict[video_id].append(frame_id)
67 | # sanity check
68 | assert(list(video_to_excluded_frames_dict.keys()) == train_videos)
69 | return video_to_excluded_frames_dict
70 |
71 |
--------------------------------------------------------------------------------
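`BalancedConcatDataset` pairs one sample from each wrapped dataset per index and cycles the shorter datasets through the modulo in `__getitem__`. A minimal usage sketch, assuming the class defined in semi_utis.py above is in scope and using placeholder tensor datasets:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Placeholder dataset returning constant tensors."""
    def __init__(self, n, offset):
        self.data = [torch.full((3, 4, 4), float(offset + i)) for i in range(n)]
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]

labelled = ToyDataset(4, 0)       # shorter dataset, gets cycled by the modulo
unlabelled = ToyDataset(10, 100)  # longer dataset, defines the epoch length

# BalancedConcatDataset is the class defined in semi_utis.py above
paired = BalancedConcatDataset(labelled, unlabelled)
loader = DataLoader(paired, batch_size=2, shuffle=True)
for labelled_batch, unlabelled_batch in loader:
    print(labelled_batch.shape, unlabelled_batch.shape)  # each is (B, 3, 4, 4)
    break
```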
/utils/torch_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms.functional as F
3 | from PIL import Image
4 | import random
5 | import math
6 | from torchvision.transforms import ToPILImage
7 |
8 | class ExtRandomScale(object):
9 | def __init__(self, scale_range, interpolation=Image.BILINEAR):
10 | self.scale_range = scale_range
11 | self.interpolation = interpolation
12 |
13 | def __call__(self, img, lbl):
14 | """
15 | Args:
16 | img (PIL Image): Image to be scaled.
17 | lbl (PIL Image): Label to be scaled.
18 | Returns:
19 | PIL Image: Rescaled image.
20 | PIL Image: Rescaled label.
21 | """
22 | # assert img.size == lbl.size
23 | # scale = random.uniform(self.scale_range[0], self.scale_range[1])
24 | w, h = img.size
25 | rand_log_scale = math.log(self.scale_range[0], 2) + random.random() * (math.log(self.scale_range[1], 2) - math.log(self.scale_range[0], 2))
26 | random_scale = math.pow(2, rand_log_scale)
27 | new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
28 | image = img.resize(new_size, Image.ANTIALIAS)
29 | mask = lbl.resize(new_size, Image.NEAREST)
30 | return image, mask
31 |
32 | if __name__ == '__main__':
33 | h = 8*8
34 | w = 8*8
35 | B = 2
36 | I_ = 2*torch.eye(h, w).rot90()
37 | lbl = torch.ones(size=(h, w)) - torch.eye(h, w) + I_
38 | x = torch.rand(size=(h, w, 3)).float()
39 | scaler = ExtRandomScale([0.5,2])
40 | for i in range(10):
41 | x_s, y_s = scaler(ToPILImage()(x), ToPILImage()(lbl))
42 | print(x_s.size, y_s.size)
43 |
44 |
--------------------------------------------------------------------------------
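`ExtRandomScale` draws the scale log-uniformly between the two bounds, so for a symmetric range such as [0.5, 2] up-scaling and down-scaling are equally likely, whereas a plain uniform draw over [0.5, 2] would favour up-scaling. A quick numerical check of that sampling scheme (illustrative only):

```python
import math
import random

random.seed(0)
lo, hi = 0.5, 2.0
scales = []
for _ in range(100_000):
    # same log-uniform draw as in ExtRandomScale.__call__
    rand_log_scale = math.log(lo, 2) + random.random() * (math.log(hi, 2) - math.log(lo, 2))
    scales.append(math.pow(2, rand_log_scale))

frac_downscale = sum(s < 1.0 for s in scales) / len(scales)
print(f'fraction of draws below 1.0: {frac_downscale:.3f}')  # approx 0.5 for a log-uniform draw
```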
/utils/tsne_visualization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | from tsne_torch import TorchTSNE as TSNE
4 | from .utils import to_numpy
5 | import pathlib
6 |
7 | def test():
8 | f = torch.rand(size=(100, 128))
9 | f_ = TSNE(n_components=2, perplexity=30, n_iter=1000, verbose=True).fit_transform(f)
10 | l = f_.tolist()
11 | x,y = zip(*l)
12 | plt.scatter(x,y)
13 |
14 |
15 | class TsneMAnager():
16 | def __init__(self, dataset, n_classes, feat_dim, run_id=None, scale=4):
17 | self.dataset = dataset
18 | self.n_classes = n_classes
19 | self.feats_per_class = 1000
20 | self.feat_dim = feat_dim
21 | self.feats = []#torch.zeros(size=(self.n_classes, self.feats_per_class, self.feat_dim))
22 | self.labels = [] # class id per element of self.feats
23 | self.counts = [0] * self.n_classes
24 | self.scale = scale
25 | self.run_id = run_id if run_id is not None else 'tsne'
26 |
27 | # tsne settings
28 | self.perplexity = 30
29 | self.iters = 2000
30 |
31 | def accumulate(self, feats, labels):
32 | n, h, w = labels.shape
33 | assert n == 1
34 | lbl_down = torch.nn.functional.interpolate(labels.unsqueeze(1).float(), (h // self.scale, w // self.scale),
35 | mode='nearest').long()
36 | _, _, h, w = lbl_down.shape
37 | lbl_down = lbl_down.view(-1)
38 | # feats1 = feats.view(self.feat_dim, h*w)
39 | feats = feats.squeeze().view(self.feat_dim, -1) # self.feat_dim, h*w
40 | if self.dataset == 'CITYSCAPES':
41 | for cl in range(self.n_classes):
42 | views_per_class = self.feats_per_class // 500 if cl < 15 else 10
43 | if self.counts[cl] < self.feats_per_class:
44 | indices_from_cl = (lbl_down == cl).nonzero().squeeze()
45 | if len(indices_from_cl.shape) > 0:
46 | random_permutation = torch.randperm(indices_from_cl.shape[0])
47 | this_views_per_class = min(views_per_class, indices_from_cl.shape[0])
48 | if this_views_per_class > 0:
49 | sampled_indices_from_cl = indices_from_cl[random_permutation[:this_views_per_class]]
50 | self.feats.append(feats[:, sampled_indices_from_cl].T)
51 | self.labels += [cl] * this_views_per_class # class id per element of self.feats
52 | self.counts[cl] += this_views_per_class
53 | # print(f'class {cl} added {this_views_per_class} {len(indices_from_cl)} feats resulting in counts {self.counts[cl]}')
54 | else:
55 | print(f'class {cl} with counts {self.counts[cl]} is done')
56 | else:
57 | raise NotImplementedError()
58 |
59 | def compute(self, log_dir):
60 | f = torch.cat(self.feats)
61 | f_tsne = TSNE(n_components=2, perplexity=self.perplexity, n_iter=self.iters, verbose=True).fit_transform(f)
62 | l = f_tsne.tolist()
63 | x, y = zip(*l)
64 | # for colours look here https://matplotlib.org/3.5.0/gallery/color/named_colors.html
65 | cmap = {0: "red", 1: "green", 2: "blue", 3: "yellow", 4: "pink", 5: "black", 6: "orange", 7: "purple",
66 | 8: "beige", 9: "brown", 10: "gray", 11: "cyan", 12: "magenta", 13: "hotpink", 14: "darkviolet", 15: "mediumblue",
67 | 16: "lightsteelblue", 17: "gold", 18: "maroon"}
68 | colors = [cmap[l] for l in self.labels]
69 | fig = plt.scatter(x, y, c=colors, label=self.labels)
70 | plt.savefig(str(pathlib.Path(log_dir)/pathlib.Path(
71 | f'{self.run_id}_perp-{self.perplexity}_its-{self.iters}_feats-per-class-{self.feats_per_class}_scale{self.scale}.png')))
72 | print(f'counts: {[(i, c) for i, c in enumerate(self.counts)]}')
73 | return f_tsne
74 |
75 |
76 |
77 | # def get_tsne_embedddings_ms(feats_ms, labels, scale, dataset):
78 | # n, h, w = labels.shape
79 | # assert n == 1
80 | # lbl_down = torch.nn.functional.interpolate(labels.unsqueeze(1).float(), (h//scale, w//scale), mode='nearest').long()
81 | # assert isinstance(feats_ms, list)
82 | # if isinstance(feats_ms, list) or isinstance(feats_ms, tuple):
83 | # for f in feats_ms:
84 | # get_tsne_embedddings(f, labels, scale, dataset)
85 |
86 | #
87 | #
88 | # def get_tsne_embedddings(feats, labels, scale, dataset):
89 | # print(feats.shape, labels.shape)
90 | # n, h, w = labels.shape
91 | # assert n == 1
92 | # lbl_down = torch.nn.functional.interpolate(labels.unsqueeze(1).float(), (h//scale, w//scale), mode='nearest').long()
93 | # c = feats.shape[1] # feature space dimensionality
94 | # feats = feats.view(h*w, c)
95 | # lbl_down = lbl_down.view(h*w)
96 | # if dataset == 'CITYSCAPES':
97 | # print('computing tsne for CITYSCAPES')
98 | #
99 | # return 0
100 |
101 |
--------------------------------------------------------------------------------