├── src ├── __init__.py ├── dataset │ ├── __init__.py │ ├── cityscapes_coco_mixed.py │ ├── coco.py │ ├── fishyscapes.py │ ├── lost_and_found.py │ └── cityscapes.py ├── metaseg │ ├── __init__.py │ ├── metrics_setup.py │ └── metrics.pyx ├── model │ ├── __init__.py │ ├── mynn.py │ ├── Resnet.py │ ├── DualGCNNet.py │ ├── GALDNet.py │ ├── deepv3.py │ ├── wider_resnet.py │ └── SEresnext.py ├── calc.py ├── helper.py ├── model_utils.py └── imageaugmentations.py ├── preparation ├── __init__.py └── prepare_coco_segmentation.py ├── requirements.txt ├── README.md ├── config.py ├── ood_training.py ├── evaluation.py └── meta_classification.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /preparation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/metaseg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cached-property==1.5.2 2 | Cython==0.29.21 3 | dataclasses==0.8 4 | future==0.18.2 5 | h5py==3.1.0 6 | joblib==0.17.0 7 | numpy==1.19.4 8 | Pillow==8.0.1 9 | scikit-learn==0.23.2 10 | scipy==1.5.4 11 | threadpoolctl==2.1.0 12 | torch==1.7.0 13 | torchvision==0.8.1 14 | typing-extensions==3.7.4.3 15 | pycocotools==2.0.2 -------------------------------------------------------------------------------- /src/metaseg/metrics_setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compile metrics.pyx via: 3 | python3 metrics_setup.py build_ext --inplace 4 | """ 5 | 6 | from distutils.core import setup 7 | from Cython.Build import cythonize 8 | from distutils.extension import Extension 9 | import numpy as np 10 | 11 | ext_core = Extension( 12 | "metrics", 13 | sources=["metrics.pyx"], 14 | include_dirs=[np.get_include()], 15 | extra_compile_args=["-O3"]) 16 | 17 | setup(ext_modules=cythonize([ext_core])) 18 | -------------------------------------------------------------------------------- /src/model/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization 3 | """ 4 | import torch.nn as nn 5 | 6 | def Norm2d(in_channels): 7 | """ 8 | Custom Norm Function to allow flexible switching 9 | """ 10 | return nn.BatchNorm2d(in_channels) 11 | 12 | def initialize_weights(*models): 13 | """ 14 | Initialize Model Weights 15 | """ 16 | for model in models: 17 | for module in model.modules(): 18 | if isinstance(module, (nn.Conv2d, nn.Linear)): 19 | nn.init.kaiming_normal_(module.weight) 20 | if module.bias is not None: 21 | module.bias.data.zero_() 22 | elif isinstance(module, nn.BatchNorm2d): 23 | module.weight.data.fill_(1) 24 | module.bias.data.zero_() 25 | 26 | def Upsample(x, 
size): 27 | """ 28 | Wrapper Around the Upsample Call 29 | """ 30 | return nn.functional.interpolate(x, size=size, mode='bilinear', 31 | align_corners=True) 32 | -------------------------------------------------------------------------------- /src/calc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from src.helper import counts_array_to_data_list 4 | from sklearn.metrics import roc_curve, precision_recall_curve, average_precision_score, auc 5 | 6 | 7 | def calc_precision_recall(data, balance=False): 8 | if balance: 9 | x1 = counts_array_to_data_list(np.array(data["in"]), 1e+5) 10 | x2 = counts_array_to_data_list(np.array(data["out"]), 1e+5) 11 | else: 12 | ratio_in = np.sum(data["in"]) / (np.sum(data["in"]) + np.sum(data["out"])) 13 | ratio_out = 1 - ratio_in 14 | x1 = counts_array_to_data_list(np.array(data["in"]), 1e+7 * ratio_in) 15 | x2 = counts_array_to_data_list(np.array(data["out"]), 1e+7 * ratio_out) 16 | probas_pred1 = np.array(x1) / 100 17 | probas_pred2 = np.array(x2) / 100 18 | y_true = np.concatenate((np.zeros(len(probas_pred1)), np.ones(len(probas_pred2)))) 19 | y_scores = np.concatenate((probas_pred1, probas_pred2)) 20 | return precision_recall_curve(y_true, y_scores) + (average_precision_score(y_true, y_scores), ) 21 | 22 | 23 | def calc_sensitivity_specificity(data, balance=False): 24 | if balance: 25 | x1 = counts_array_to_data_list(np.array(data["in"]), max_size=1e+5) 26 | x2 = counts_array_to_data_list(np.array(data["out"]), max_size=1e+5) 27 | else: 28 | x1 = counts_array_to_data_list(np.array(data["in"])) 29 | x2 = counts_array_to_data_list(np.array(data["out"])) 30 | probas_pred1 = np.array(x1) / 100 31 | probas_pred2 = np.array(x2) / 100 32 | y_true = np.concatenate((np.zeros(len(probas_pred1)), np.ones(len(probas_pred2)))).astype("uint8") 33 | y_scores = np.concatenate((probas_pred1, probas_pred2)) 34 | fpr, tpr, thresholds = roc_curve(y_true, y_scores) 35 | return fpr, tpr, thresholds, auc(fpr, tpr) 36 | -------------------------------------------------------------------------------- /src/dataset/cityscapes_coco_mixed.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from src.dataset.coco import COCO 5 | from src.dataset.cityscapes import Cityscapes 6 | 7 | 8 | class CityscapesCocoMix(Dataset): 9 | 10 | def __init__(self, split='train', transform=None, 11 | cs_root="/home/datasets/cityscapes", 12 | coco_root="/home/datasets/COCO/2017", 13 | subsampling_factor=0.1, cs_split=None, coco_split=None,): 14 | 15 | self.transform = transform 16 | if cs_split is None or coco_split is None: 17 | self.cs_split = split 18 | self.coco_split = split 19 | else: 20 | self.cs_split = cs_split 21 | self.coco_split = coco_split 22 | 23 | self.cs = Cityscapes(root=cs_root, split=self.cs_split) 24 | self.coco = COCO(root=coco_root, split=self.coco_split, proxy_size=int(subsampling_factor*len(self.cs))) 25 | self.images = self.cs.images + self.coco.images 26 | self.targets = self.cs.targets + self.coco.targets 27 | self.train_id_out = self.coco.train_id_out 28 | self.num_classes = self.cs.num_train_ids 29 | self.mean = self.cs.mean 30 | self.std = self.cs.std 31 | self.void_ind = self.cs.ignore_in_eval_ids 32 | 33 | def __getitem__(self, i): 34 | """Return raw image, ground truth in PIL format and absolute path of raw image as string""" 35 | image = Image.open(self.images[i]).convert('RGB') 
36 | target = Image.open(self.targets[i]).convert('L') 37 | if self.transform is not None: 38 | image, target = self.transform(image, target) 39 | return image, target 40 | 41 | def __len__(self): 42 | """Return total number of images in the whole dataset.""" 43 | return len(self.images) 44 | 45 | def __repr__(self): 46 | """Return number of images in each dataset.""" 47 | fmt_str = 'Cityscapes Split: %s\n' % self.cs_split 48 | fmt_str += '----Number of images: %d\n' % len(self.cs) 49 | fmt_str += 'COCO Split: %s\n' % self.coco_split 50 | fmt_str += '----Number of images: %d\n' % len(self.coco) 51 | return fmt_str.strip() 52 | 53 | -------------------------------------------------------------------------------- /src/dataset/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from PIL import Image 5 | from typing import Callable, Optional 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class COCO(Dataset): 10 | 11 | train_id_in = 0 12 | train_id_out = 254 13 | min_image_size = 480 14 | 15 | def __init__(self, root: str, split: str = "train", transform: Optional[Callable] = None, shuffle=True, 16 | proxy_size: Optional[int] = None) -> None: 17 | """ 18 | COCO dataset loader 19 | """ 20 | self.root = root 21 | self.coco_year = list(filter(None, self.root.split("/")))[-1] 22 | self.split = split + self.coco_year 23 | self.images = [] 24 | self.targets = [] 25 | self.transform = transform 26 | 27 | for root, _, filenames in os.walk(os.path.join(self.root, "annotations", "ood_seg_" + self.split)): 28 | assert self.split in ['train' + self.coco_year, 'val' + self.coco_year] 29 | for filename in filenames: 30 | if os.path.splitext(filename)[-1] == '.png': 31 | self.targets.append(os.path.join(root, filename)) 32 | self.images.append(os.path.join(self.root, self.split, filename.split(".")[0] + ".jpg")) 33 | 34 | """ 35 | shuffle data and subsample 36 | """ 37 | if shuffle: 38 | zipped = list(zip(self.images, self.targets)) 39 | random.shuffle(zipped) 40 | self.images, self.targets = zip(*zipped) 41 | if proxy_size is not None: 42 | self.images = list(self.images[:int(proxy_size)]) 43 | self.targets = list(self.targets[:int(proxy_size)]) 44 | 45 | def __len__(self): 46 | """Return total number of images in the whole dataset.""" 47 | return len(self.images) 48 | 49 | def __getitem__(self, i): 50 | """Return raw image and ground truth in PIL format or as torch tensor""" 51 | image = Image.open(self.images[i]).convert('RGB') 52 | target = Image.open(self.targets[i]).convert('L') 53 | if self.transform is not None: 54 | image, target = self.transform(image, target) 55 | return image, target 56 | 57 | def __repr__(self): 58 | """Return number of images in each dataset.""" 59 | fmt_str = 'Number of COCO Images: %d\n' % len(self.images) 60 | return fmt_str.strip() 61 | -------------------------------------------------------------------------------- /src/dataset/fishyscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from collections import namedtuple 5 | from src.dataset.cityscapes import Cityscapes 6 | 7 | 8 | class Fishyscapes(Dataset): 9 | 10 | FishyscapesClass = namedtuple('FishyscapesClass', ['name', 'id', 'train_id', 'hasinstances', 11 | 'ignoreineval', 'color']) 12 | # -------------------------------------------------------------------------------- 13 | # A list of all Lost & Found labels 14 | 
# -------------------------------------------------------------------------------- 15 | labels = [ 16 | FishyscapesClass('in-distribution', 0, 0, False, False, (144, 238, 144)), 17 | FishyscapesClass('out-distribution', 1, 1, False, False, (255, 102, 102)), 18 | FishyscapesClass('unlabeled', 2, 255, False, True, (0, 0, 0)), 19 | ] 20 | 21 | train_id_in = 0 22 | train_id_out = 1 23 | cs = Cityscapes() 24 | mean = cs.mean 25 | std = cs.std 26 | num_eval_classes = cs.num_train_ids 27 | label_id_to_name = {label.id: label.name for label in labels} 28 | train_id_to_name = {label.train_id: label.name for label in labels} 29 | trainid_to_color = {label.train_id: label.color for label in labels} 30 | label_name_to_id = {label.name: label.id for label in labels} 31 | 32 | def __init__(self, split='Static', root="/home/datasets/fishyscapes/", transform=None): 33 | """Load all filenames.""" 34 | self.transform = transform 35 | self.root = root 36 | self.split = split # ['Static', 'LostAndFound'] 37 | self.images = [] # list of all raw input images 38 | self.targets = [] # list of all ground truth TrainIds images 39 | for root, _, filenames in os.walk(os.path.join(root, self.split)): 40 | for filename in filenames: 41 | if os.path.splitext(filename)[1] == '.png': 42 | filename_base = os.path.splitext(filename)[0] 43 | self.images.append(os.path.join(root, filename_base + '.jpg')) 44 | self.targets.append(os.path.join(root, filename_base + '.png')) 45 | self.images = sorted(self.images) 46 | self.targets = sorted(self.targets) 47 | 48 | def __len__(self): 49 | """Return number of images in the dataset split.""" 50 | return len(self.images) 51 | 52 | def __getitem__(self, i): 53 | """Return raw image, trainIds as torch.Tensor or PIL Image""" 54 | image = Image.open(self.images[i]).convert('RGB') 55 | target = Image.open(self.targets[i]).convert('L') 56 | if self.transform is not None: 57 | image, target = self.transform(image, target) 58 | return image, target 59 | 60 | def __repr__(self): 61 | """Print some information about dataset.""" 62 | fmt_str = 'LostAndFound Split: %s\n' % self.split 63 | fmt_str += '----Number of images: %d\n' % len(self.images) 64 | return fmt_str.strip() 65 | -------------------------------------------------------------------------------- /preparation/prepare_coco_segmentation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | sys.path.append(os.path.dirname(sys.path[0])) 6 | 7 | from PIL import Image 8 | from config import cs_coco_roots 9 | from src.dataset.coco import COCO 10 | from pycocotools.coco import COCO as coco_tools 11 | 12 | 13 | def main(): 14 | start = time.time() 15 | root = cs_coco_roots.coco_root 16 | split = "train" 17 | year = 2017 18 | id_in = COCO.train_id_in 19 | id_out = COCO.train_id_out 20 | min_size = COCO.min_image_size 21 | annotation_file = '{}/annotations/instances_{}.json'.format(root, split+str(year)) 22 | images_dir = '{}/{}'.format(root, split+str(year)) 23 | tools = coco_tools(annotation_file) 24 | save_dir = '{}/annotations/ood_seg_{}'.format(root, split+str(year)) 25 | print("\nPrepare COCO{} {} split for OoD training".format(str(year), split)) 26 | 27 | # Names of classes that are excluded - these are Cityscapes classes also available in COCO 28 | exclude_classes = ['person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'traffic light', 'stop sign'] 29 | 30 | # Fetch all image ids that does not include instance from classes defined 
in "exclude_classes" 31 | exclude_cat_Ids = tools.getCatIds(catNms=exclude_classes) 32 | exclude_img_Ids = [] 33 | for cat_Id in exclude_cat_Ids: 34 | exclude_img_Ids += tools.getImgIds(catIds=cat_Id) 35 | exclude_img_Ids = set(exclude_img_Ids) 36 | img_Ids = [int(image[:-4]) for image in os.listdir(images_dir) if int(image[:-4]) not in exclude_img_Ids] 37 | 38 | num_masks = 0 39 | # Process each image 40 | print("Ground truth segmentation mask will be saved in:", save_dir) 41 | if not os.path.exists(save_dir): 42 | os.makedirs(save_dir) 43 | print("Created save directory:", save_dir) 44 | for i, img_Id in enumerate(img_Ids): 45 | img = tools.loadImgs(img_Id)[0] 46 | h, w = img['height'], img['width'] 47 | 48 | # Select only images with height and width of at least min_size 49 | if h >= min_size and w >= min_size: 50 | ann_Ids = tools.getAnnIds(imgIds=img['id'], iscrowd=None) 51 | annotations = tools.loadAnns(ann_Ids) 52 | 53 | # Generate binary segmentation mask 54 | mask = np.ones((h, w), dtype="uint8") * id_in 55 | for j in range(len(annotations)): 56 | mask = np.maximum(tools.annToMask(annotations[j])*id_out, mask) 57 | 58 | # Save segmentation mask 59 | Image.fromarray(mask).save(os.path.join(save_dir, "{:012d}.png".format(img_Id))) 60 | num_masks += 1 61 | print("\rImages Processed: {}/{}".format(i + 1, len(img_Ids)), end=' ') 62 | sys.stdout.flush() 63 | 64 | # Print summary 65 | print("\nNumber of created segmentation masks with height and width of at least %d pixels:" % min_size, num_masks) 66 | end = time.time() 67 | hours, rem = divmod(end - start, 3600) 68 | minutes, seconds = divmod(rem, 60) 69 | print("FINISHED {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)) 70 | 71 | if __name__ == '__main__': 72 | main() 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Entropy Maximization and Meta Classification for Out-of-Distribution Detection in Semantic Segmentation 3 | 4 | **Abstract** Deep neural networks (DNNs) for the semantic segmentation of images are usually trained to operate on a predefined closed set of object classes. This is in contrast to the "open world" setting where DNNs are envisioned to be deployed to. From a functional safety point of view, the ability to detect so-called "out-of-distribution" (OoD) samples, i.e., objects outside of a DNN's semantic space, is crucial for many applications such as automated driving. 5 | We present a two-step procedure for OoD detection. Firstly, we utilize samples from the COCO dataset as OoD proxy and introduce a second training objective to maximize the softmax entropy on these samples. Starting from pretrained semantic segmentation networks we re-train a number of DNNs on different in-distribution datasets and evaluate on completely disjoint OoD datasets. Secondly, we perform a transparent post-processing step to discard false positive OoD samples by so-called "meta classification". To this end, we apply linear models to a set of hand-crafted metrics derived from the DNN's softmax probabilities. 6 | Our method contributes to safer DNNs with more reliable overall system performance. 
7 | 8 | * More details can be found in the preprint https://arxiv.org/abs/2012.06575 9 | * Training with [Cityscapes](https://www.cityscapes-dataset.com/) and [COCO](https://cocodataset.org), evaluation with [LostAndFound](http://www.6d-vision.com/lostandfounddataset) and [Fishyscapes](https://fishyscapes.com/) 10 | 11 | ## Requirements 12 | 13 | This code was tested with **Python 3.6.10** and **CUDA 10.2**. The following Python packages were installed via **pip 20.2.4**, see also ```requirements.txt```: 14 | ``` 15 | Cython == 0.29.21 16 | h5py == 3.1.0 17 | scikit-learn == 0.23.2 18 | scipy == 1.5.4 19 | torch == 1.7.0 20 | torchvision == 0.8.1 21 | pycocotools == 2.0.2 22 | ``` 23 | **Dataset preparation**: In ```preparation/prepare_coco_segmentation.py``` a preprocessing script can be found in order to prepare the COCO images serving as OoD proxy for OoD training. This script generates binary segmentation masks for COCO images not containing any instances that could also be assigned to one of the Cityscapes (train-)classes. Execute via: 24 | ``` 25 | python preparation/prepare_coco_segmentation.py 26 | ``` 27 | Regarding the Cityscapes dataset, the dataloader used in this repo assumes that the *labelTrainId* images are already generated according to the [official Cityscapes script](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py). 28 | 29 | **Cython preparation**: Make sure that the Cython script ```src/metaseg/metrics.pyx``` is compiled on the machine where the code is deployed. If it has not been compiled yet: 30 | ``` 31 | cd src/metaseg/ 32 | python metrics_setup.py build_ext --inplace 33 | cd ../../ 34 | ``` 35 | For pretrained weights, see [https://github.com/NVIDIA/semantic-segmentation/tree/sdcnet](https://github.com/NVIDIA/semantic-segmentation/tree/sdcnet) (for DeepLabv3+) and [https://github.com/lxtGH/GALD-DGCNet](https://github.com/lxtGH/GALD-DGCNet) (for DualGCNNet). 36 | The weights after OoD training can be downloaded [here for DeepLabv3+](https://uni-wuppertal.sciebo.de/s/kCgnr0LQuTbrArA/download) and [here for DualGCNNet](https://uni-wuppertal.sciebo.de/s/VAXiKxZ21eAF68q/download). 37 | 38 | ## Quick start 39 | 40 | Modify settings in ```config.py```. All files will be saved in the directory defined via ```io_root``` (different roots for each dataset that is used). Then run: 41 | ``` 42 | python ood_training.py 43 | python meta_classification.py 44 | python evaluation.py 45 | ``` 46 | 47 | ## More options 48 | 49 | For better automation of experiments, **command-line options** for ```ood_training.py```, ```meta_classification.py``` and ```evaluation.py``` are available. 50 | 51 | Use the ```[-h]``` argument for details about which parameters in ```config.py``` can be modified via command line. Example: 52 | ``` 53 | python ood_training.py -h 54 | ``` 55 | 56 | If no command-line options are provided, the settings in ```config.py``` are applied.
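For example, a run that overrides a few of these settings (flag names as defined in the ```argparse``` setups of ```ood_training.py``` and ```evaluation.py```) could look like:
```
python ood_training.py -model DualGCNNet_res50 -alpha 0.9 -lr 1e-5
python evaluation.py -model DualGCNNet_res50 -val Fishyscapes
```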
57 | 58 | -------------------------------------------------------------------------------- /src/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pickle 4 | 5 | import numpy as np 6 | from collections import Counter 7 | 8 | 9 | class colors: 10 | """Class for colors""" 11 | RED = '\033[31;1m' 12 | GREEN = '\033[32;1m' 13 | YELLOW = '\033[33;1m' 14 | BLUE = '\033[34;1m' 15 | MAGENTA = '\033[35;1m' 16 | CYAN = '\033[36;1m' 17 | BOLD = '\033[1m' 18 | UNDERLINE = '\033[4m' 19 | ENDC = '\033[0m' 20 | 21 | 22 | def getColorEntry(val): 23 | """Colored value output if colorized flag is activated.""" 24 | if not isinstance(val, float) or math.isnan(val): 25 | return colors.ENDC 26 | if val < .20: 27 | return colors.RED 28 | elif val < .40: 29 | return colors.YELLOW 30 | elif val < .60: 31 | return colors.BLUE 32 | elif val < .80: 33 | return colors.CYAN 34 | else: 35 | return colors.GREEN 36 | 37 | 38 | def counts_array_to_data_list(counts_array, max_size=None): 39 | if max_size is None: 40 | max_size = np.sum(counts_array) # max of counted array entry 41 | counts_array = (counts_array / np.sum(counts_array) * max_size).astype("uint32") 42 | counts_dict = {} 43 | for i in range(1, len(counts_array) + 1): 44 | counts_dict[i] = counts_array[i - 1] 45 | return list(Counter(counts_dict).elements()) 46 | 47 | 48 | def get_save_path_metrics_i(i, metaseg_root, subdir): 49 | return os.path.join(metaseg_root, "metrics", subdir, "%04d.p" % i) 50 | 51 | 52 | def get_save_path_components_i(i, metaseg_root, subdir): 53 | return os.path.join(metaseg_root, "components", subdir, "%04d.p" % i) 54 | 55 | 56 | def metrics_dump(metrics, i, metaseg_root, subdir): 57 | dump_path = get_save_path_metrics_i(i, metaseg_root, subdir) 58 | dump_dir = os.path.dirname(dump_path) 59 | if not os.path.exists(dump_dir): 60 | os.makedirs(dump_dir) 61 | pickle.dump(metrics, open(dump_path, "wb")) 62 | 63 | 64 | def components_dump(components, i, metaseg_root, subdir): 65 | dump_path = get_save_path_components_i(i, metaseg_root, subdir) 66 | dump_dir = os.path.dirname(dump_path) 67 | if not os.path.exists(dump_dir): 68 | os.makedirs(dump_dir) 69 | pickle.dump(components.astype('int16'), open(dump_path, "wb")) 70 | 71 | 72 | def metrics_load(i, metaseg_root, subdir): 73 | return pickle.load(open(get_save_path_metrics_i(i, metaseg_root, subdir), "rb")) 74 | 75 | 76 | def components_load(i, metaseg_root, subdir): 77 | return pickle.load(open(get_save_path_components_i(i, metaseg_root, subdir), "rb")) 78 | 79 | 80 | def concatenate_metrics(metaseg_root, subdir, num_imgs): 81 | metrics = metrics_load(0, metaseg_root, subdir) 82 | start = list([0, len(metrics["S"])]) 83 | for i in range(1, num_imgs): 84 | m = metrics_load(i, metaseg_root, subdir) 85 | start += [start[-1] + len(m["S"])] 86 | for j in metrics: 87 | metrics[j] += m[j] 88 | return metrics, start 89 | 90 | 91 | def metrics_to_nparray(metrics, names, normalize=False, non_empty=False, all_metrics=None): 92 | if all_metrics is None: 93 | all_metrics = [] 94 | I = range(len(metrics['S_in'])) 95 | if non_empty: 96 | I = np.asarray(metrics['S_in']) > 0 97 | M = np.asarray([np.asarray(metrics[m])[I] for m in names]) 98 | if all_metrics == []: 99 | MM = M.copy() 100 | else: 101 | MM = np.asarray([np.asarray(all_metrics[m])[I] for m in names]) 102 | if normalize: 103 | for i in range(M.shape[0]): 104 | if names[i] != "class": 105 | M[i] = (np.asarray(M[i]) - np.mean(MM[i], axis=-1)) / (np.std(MM[i], 
axis=-1) + 1e-10) 106 | M = np.squeeze(M.T) 107 | return M 108 | 109 | 110 | def metrics_to_dataset(metrics, nclasses, non_empty=False, all_metrics=None): 111 | if all_metrics is None: 112 | all_metrics = [] 113 | exclude = ["iou", "iou0", "prc"] 114 | X_names = sorted([m for m in metrics if m not in exclude]) 115 | class_names = ["cprob" + str(i) for i in range(nclasses) if "cprob" + str(i) in metrics] 116 | Xa = metrics_to_nparray(metrics, X_names, normalize=True, non_empty=non_empty, all_metrics=all_metrics) 117 | classes = metrics_to_nparray(metrics, class_names, normalize=True, non_empty=non_empty, all_metrics=all_metrics) 118 | ya = metrics_to_nparray(metrics, ["iou"], non_empty=non_empty) 119 | y0a = metrics_to_nparray(metrics, ["iou0"], non_empty=non_empty) ### 1, if iou=0 120 | return Xa, classes, ya, y0a, X_names, class_names 121 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from src.dataset.cityscapes_coco_mixed import CityscapesCocoMix 3 | from src.dataset.lost_and_found import LostAndFound 4 | from src.dataset.fishyscapes import Fishyscapes 5 | 6 | TRAINSETS = ["Cityscapes+COCO"] 7 | VALSETS = ["LostAndFound", "Fishyscapes"] 8 | MODELS = ["DeepLabV3+_WideResNet38", "DualGCNNet_res50"] 9 | 10 | TRAINSET = TRAINSETS[0] 11 | VALSET = VALSETS[0] 12 | MODEL = MODELS[0] 13 | IO = "/home/chan/io/ood_detection/" 14 | 15 | class cs_coco_roots: 16 | """ 17 | OoD training roots for Cityscapes + COCO mix 18 | """ 19 | model_name = MODEL 20 | init_ckpt = os.path.join("/home/chan/io/cityscapes/weights/", model_name + ".pth") 21 | cs_root = "/home/datasets/cityscapes/" 22 | coco_root = "/home/datasets/COCO/2017" 23 | io_root = IO + "meta_ood_" + model_name 24 | weights_dir = os.path.join(io_root, "weights/") 25 | 26 | 27 | class laf_roots: 28 | """ 29 | LostAndFound config class 30 | """ 31 | model_name = MODEL 32 | init_ckpt = os.path.join("/home/chan/io/cityscapes/weights/", model_name + ".pth") 33 | eval_dataset_root = "/home/datasets/lost_and_found/" 34 | eval_sub_dir = "laf_eval" 35 | io_root = os.path.join(IO + "meta_ood_" + model_name, eval_sub_dir) 36 | weights_dir = os.path.join(io_root, "..", "weights/") 37 | 38 | 39 | class fs_roots: 40 | """ 41 | Fishyscapes config class 42 | """ 43 | model_name = MODEL 44 | init_ckpt = os.path.join("/home/chan/io/cityscapes/weights/", model_name + ".pth") 45 | eval_dataset_root = "/home/datasets/fishyscapes/" 46 | eval_sub_dir = "fs_eval" 47 | io_root = os.path.join(IO + "meta_ood_" + model_name, eval_sub_dir) 48 | weights_dir = os.path.join(io_root, "..", "weights/") 49 | 50 | 51 | class params: 52 | """ 53 | Set pipeline parameters 54 | """ 55 | training_starting_epoch = 0 56 | num_training_epochs = 1 57 | pareto_alpha = 0.9 58 | ood_subsampling_factor = 0.1 59 | learning_rate = 1e-5 60 | crop_size = 480 61 | val_epoch = num_training_epochs 62 | batch_size = 8 63 | entropy_threshold = 0.7 64 | 65 | 66 | ######################################################################### 67 | 68 | class config_training_setup(object): 69 | """ 70 | Setup config class for training 71 | If 'None' arguments are passed, the settings from above are applied 72 | """ 73 | def __init__(self, args): 74 | if args["TRAINSET"] is not None: 75 | self.TRAINSET = args["TRAINSET"] 76 | else: 77 | self.TRAINSET = TRAINSET 78 | if self.TRAINSET == "Cityscapes+COCO": 79 | self.roots = cs_coco_roots 80 | self.dataset = 
CityscapesCocoMix 81 | else: 82 | print("TRAINSET not correctly specified... bye...") 83 | exit() 84 | if args["MODEL"] is not None: 85 | tmp = getattr(self.roots, "model_name") 86 | roots_attr = [attr for attr in dir(self.roots) if not attr.startswith('__')] 87 | for attr in roots_attr: 88 | if tmp in getattr(self.roots, attr): 89 | rep = getattr(self.roots, attr).replace(tmp, args["MODEL"]) 90 | setattr(self.roots, attr, rep) 91 | self.params = params 92 | params_attr = [attr for attr in dir(self.params) if not attr.startswith('__')] 93 | for attr in params_attr: 94 | if attr in args: 95 | if args[attr] is not None: 96 | setattr(self.params, attr, args[attr]) 97 | roots_attr = [self.roots.weights_dir] 98 | for attr in roots_attr: 99 | if not os.path.exists(attr): 100 | print("Create directory:", attr) 101 | os.makedirs(attr) 102 | 103 | 104 | class config_evaluation_setup(object): 105 | """ 106 | Setup config class for evaluation 107 | If 'None' arguments are passed, the settings from above are applied 108 | """ 109 | def __init__(self, args): 110 | if args["VALSET"] is not None: 111 | self.VALSET = args["VALSET"] 112 | else: 113 | self.VALSET = VALSET 114 | if self.VALSET == "LostAndFound": 115 | self.roots = laf_roots 116 | self.dataset = LostAndFound 117 | if self.VALSET == "Fishyscapes": 118 | self.roots = fs_roots 119 | self.dataset = Fishyscapes 120 | self.params = params 121 | params_attr = [attr for attr in dir(self.params) if not attr.startswith('__')] 122 | for attr in params_attr: 123 | if attr in args: 124 | if args[attr] is not None: 125 | setattr(self.params, attr, args[attr]) 126 | roots_attr = [self.roots.io_root] 127 | for attr in roots_attr: 128 | if not os.path.exists(attr): 129 | print("Create directory:", attr) 130 | os.makedirs(attr) 131 | -------------------------------------------------------------------------------- /ood_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | 10 | from config import config_training_setup 11 | from src.imageaugmentations import Compose, Normalize, ToTensor, RandomCrop, RandomHorizontalFlip 12 | from src.model_utils import load_network 13 | from torch.utils.data import DataLoader 14 | 15 | 16 | def cross_entropy(logits, targets): 17 | """ 18 | cross entropy loss with one/all hot encoded targets -> logits.size()=targets.size() 19 | :param logits: torch tensor with logits obtained from network forward pass 20 | :param targets: torch tensor one/all hot encoded 21 | :return: computed loss 22 | """ 23 | neg_log_like = - 1.0 * F.log_softmax(logits, 1) 24 | L = torch.mul(targets.float(), neg_log_like) 25 | L = L.mean() 26 | return L 27 | 28 | 29 | def encode_target(target, pareto_alpha, num_classes, ignore_train_ind, ood_ind=254): 30 | """ 31 | encode target tensor with all hot encoding for OoD samples 32 | :param target: torch tensor 33 | :param pareto_alpha: OoD loss weight 34 | :param num_classes: number of classes in original task 35 | :param ignore_train_ind: void class in original task 36 | :param ood_ind: class label corresponding to OoD class 37 | :return: one/all hot encoded torch tensor 38 | """ 39 | npy = target.numpy() 40 | npz = npy.copy() 41 | npy[np.isin(npy, ood_ind)] = num_classes 42 | npy[np.isin(npy, ignore_train_ind)] = num_classes + 1 43 | enc = np.eye(num_classes + 2)[npy][..., :-2] # one hot encoding with last 2 
axis cutoff 44 | enc[(npy == num_classes)] = np.full(num_classes, pareto_alpha / num_classes) # set all hot encoded vector 45 | enc[(enc == 1)] = 1 - pareto_alpha # convex combination between in and out distribution samples 46 | enc[np.isin(npz, ignore_train_ind)] = np.zeros(num_classes) 47 | enc = torch.from_numpy(enc) 48 | enc = enc.permute(0, 3, 1, 2).contiguous() 49 | return enc 50 | 51 | 52 | def training_routine(config): 53 | """Start OoD Training""" 54 | print("START OOD TRAINING") 55 | params = config.params 56 | roots = config.roots 57 | dataset = config.dataset() 58 | print("Pareto alpha:", params.pareto_alpha) 59 | start_epoch = params.training_starting_epoch 60 | epochs = params.num_training_epochs 61 | start = time.time() 62 | 63 | """Initialize model""" 64 | if start_epoch == 0: 65 | network = load_network(model_name=roots.model_name, num_classes=dataset.num_classes, 66 | ckpt_path=roots.init_ckpt, train=True) 67 | else: 68 | basename = roots.model_name + "_epoch_" + str(start_epoch) \ 69 | + "_alpha_" + str(params.pareto_alpha) + ".pth" 70 | network = load_network(model_name=roots.model_name, num_classes=dataset.num_classes, 71 | ckpt_path=os.path.join(roots.weights_dir, basename), train=True) 72 | 73 | transform = Compose([RandomHorizontalFlip(), RandomCrop(params.crop_size), ToTensor(), 74 | Normalize(dataset.mean, dataset.std)]) 75 | 76 | for epoch in range(start_epoch, start_epoch + epochs): 77 | """Perform one epoch of training""" 78 | print('\nEpoch {}/{}'.format(epoch + 1, start_epoch + epochs)) 79 | optimizer = optim.Adam(network.parameters(), lr=params.learning_rate) 80 | trainloader = config.dataset('train', transform, roots.cs_root, roots.coco_root, params.ood_subsampling_factor) 81 | dataloader = DataLoader(trainloader, batch_size=params.batch_size, shuffle=True) 82 | i = 0 83 | loss = None 84 | for x, target in dataloader: 85 | optimizer.zero_grad() 86 | logits = network(x.cuda()) 87 | y = encode_target(target=target, pareto_alpha=params.pareto_alpha, num_classes=dataset.num_classes, 88 | ignore_train_ind=dataset.void_ind, ood_ind=dataset.train_id_out).cuda() 89 | loss = cross_entropy(logits, y) 90 | loss.backward() 91 | optimizer.step() 92 | print('{} Loss: {}'.format(i, loss.item())) 93 | i += 1 94 | 95 | """Save model state""" 96 | save_basename = roots.model_name + "_epoch_" + str(epoch + 1) + "_alpha_" + str(params.pareto_alpha) + ".pth" 97 | print('Saving checkpoint', os.path.join(roots.weights_dir, save_basename)) 98 | torch.save({ 99 | 'epoch': epoch + 1, 100 | 'state_dict': network.state_dict(), 101 | 'optimizer_state_dict': optimizer.state_dict(), 102 | 'loss': loss, 103 | }, os.path.join(roots.weights_dir, save_basename)) 104 | torch.cuda.empty_cache() 105 | 106 | end = time.time() 107 | hours, rem = divmod(end - start, 3600) 108 | minutes, seconds = divmod(rem, 60) 109 | print("FINISHED {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)) 110 | 111 | 112 | def main(args): 113 | """Perform training""" 114 | config = config_training_setup(args) 115 | training_routine(config) 116 | 117 | 118 | if __name__ == '__main__': 119 | """Get Arguments and setup config class""" 120 | parser = argparse.ArgumentParser(description='OPTIONAL argument setting, see also config.py') 121 | parser.add_argument("-train", "--TRAINSET", nargs="?", type=str) 122 | parser.add_argument("-model", "--MODEL", nargs="?", type=str) 123 | parser.add_argument("-epoch", "--training_starting_epoch", nargs="?", type=int) 124 | parser.add_argument("-nepochs", 
"--num_training_epochs", nargs="?", type=int) 125 | parser.add_argument("-alpha", "--pareto_alpha", nargs="?", type=float) 126 | parser.add_argument("-lr", "--learning_rate", nargs="?", type=float) 127 | parser.add_argument("-crop", "--crop_size", nargs="?", type=int) 128 | main(vars(parser.parse_args())) 129 | -------------------------------------------------------------------------------- /src/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import h5py 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from PIL import Image 11 | 12 | from src.model.deepv3 import DeepWV3Plus 13 | from src.model.DualGCNNet import DualSeg_res50 14 | 15 | 16 | def load_network(model_name, num_classes, ckpt_path=None, train=False): 17 | network = None 18 | print("Checkpoint file:", ckpt_path) 19 | print("Load model:", model_name, end="", flush=True) 20 | if model_name == "DeepLabV3+_WideResNet38": 21 | network = nn.DataParallel(DeepWV3Plus(num_classes)) 22 | elif model_name == "DualGCNNet_res50": 23 | network = DualSeg_res50(num_classes) 24 | else: 25 | print("\nModel is not known") 26 | exit() 27 | 28 | if ckpt_path is not None: 29 | network.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False) 30 | network = network.cuda() 31 | if train: 32 | print("... ok") 33 | return network.train() 34 | else: 35 | print("... ok") 36 | return network.eval() 37 | 38 | 39 | def prediction(net, image): 40 | image = image.cuda() 41 | with torch.no_grad(): 42 | out = net(image) 43 | if isinstance(out, tuple): 44 | out = out[0] 45 | out = out.data.cpu() 46 | out = F.softmax(out, 1) 47 | return out.numpy() 48 | 49 | 50 | class inference(object): 51 | 52 | def __init__(self, params, roots, loader, num_classes=None, init_net=True): 53 | self.epoch = params.val_epoch 54 | self.alpha = params.pareto_alpha 55 | self.batch_size = params.batch_size 56 | self.model_name = roots.model_name 57 | self.batch = 0 58 | self.batch_max = int(len(loader) / self.batch_size) + (len(loader) % self.batch_size > 0) 59 | self.loader = loader 60 | self.batchloader = iter(DataLoader(loader, batch_size=self.batch_size, shuffle=False)) 61 | self.probs_root = os.path.join(roots.io_root, "probs") 62 | 63 | if self.epoch == 0: 64 | pattern = "baseline" 65 | ckpt_path = roots.init_ckpt 66 | self.probs_load_dir = os.path.join(self.probs_root, pattern) 67 | else: 68 | pattern = "epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha) 69 | basename = self.model_name + "_" + pattern + ".pth" 70 | self.probs_load_dir = os.path.join(self.probs_root, pattern) 71 | ckpt_path = os.path.join(roots.weights_dir, basename) 72 | if init_net and num_classes is not None: 73 | self.net = load_network(self.model_name, num_classes, ckpt_path) 74 | 75 | def probs_gt_load(self, i, load_dir=None): 76 | if load_dir is None: 77 | load_dir = self.probs_load_dir 78 | try: 79 | filename = os.path.join(load_dir, "probs" + str(i) + ".hdf5") 80 | f_probs = h5py.File(filename, "r") 81 | probs = np.asarray(f_probs['probabilities']) 82 | gt_train = np.asarray(f_probs['gt_train_ids']) 83 | gt_label = np.asarray(f_probs['gt_label_ids']) 84 | probs = np.squeeze(probs) 85 | gt_train = np.squeeze(gt_train) 86 | gt_label = np.squeeze(gt_label) 87 | im_path = f_probs['image_path'][0].decode("utf8") 88 | except OSError: 89 | print("No probs file for image %d, therefore run inference..." 
% i) 90 | probs, gt_train, gt_label, im_path = self.prob_gt_calc(i) 91 | return probs, gt_train, gt_label, im_path 92 | 93 | def probs_gt_save(self, i, save_dir=None): 94 | if save_dir is None: 95 | save_dir = self.probs_load_dir 96 | if not os.path.exists(save_dir): 97 | print("Create directory:", save_dir) 98 | os.makedirs(save_dir) 99 | probs, gt_train, gt_label, im_path = self.prob_gt_calc(i) 100 | file_name = os.path.join(save_dir, "probs" + str(i) + ".hdf5") 101 | f = h5py.File(file_name, "w") 102 | f.create_dataset("probabilities", data=probs) 103 | f.create_dataset("gt_train_ids", data=gt_train) 104 | f.create_dataset("gt_label_ids", data=gt_label) 105 | f.create_dataset("image_path", data=[im_path.encode('utf8')]) 106 | print("file stored:", file_name) 107 | f.close() 108 | 109 | def probs_gt_load_batch(self): 110 | assert self.batch_size > 1, "Please use batch size > 1 or use function 'probs_gt_load()' instead, bye bye..." 111 | x, y, z, im_paths = next(self.batchloader) 112 | probs = prediction(self.net, x) 113 | gt_train = y.numpy() 114 | gt_label = z.numpy() 115 | self.batch += 1 116 | print("\rBatch %d/%d processed" % (self.batch, self.batch_max)) 117 | sys.stdout.flush() 118 | return probs, gt_train, gt_label, im_paths 119 | 120 | def prob_gt_calc(self, i): 121 | x, y = self.loader[i] 122 | probs = np.squeeze(prediction(self.net, x.unsqueeze_(0))) 123 | gt_train = y.numpy() 124 | try: 125 | gt_label = np.array(Image.open(self.loader.annotations[i]).convert('L')) 126 | except AttributeError: 127 | gt_label = np.zeros(gt_train.shape) 128 | im_path = self.loader.images[i] 129 | return probs, gt_train, gt_label, im_path 130 | 131 | 132 | def probs_gt_load(i, load_dir): 133 | try: 134 | filepath = os.path.join(load_dir, "probs" + str(i) + ".hdf5") 135 | f_probs = h5py.File(filepath, "r") 136 | probs = np.asarray(f_probs['probabilities']) 137 | gt_train = np.asarray(f_probs['gt_train_ids']) 138 | gt_label = np.asarray(f_probs['gt_label_ids']) 139 | probs = np.squeeze(probs) 140 | gt_train = np.squeeze(gt_train) 141 | gt_label = np.squeeze(gt_label) 142 | im_path = f_probs['image_path'][0].decode("utf8") 143 | except OSError: 144 | probs, gt_train, gt_label, im_path = None, None, None, None 145 | print("No probs file, see src.model_utils") 146 | exit() 147 | return probs, gt_train, gt_label, im_path 148 | -------------------------------------------------------------------------------- /src/dataset/lost_and_found.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from torch.utils.data import Dataset 4 | from collections import namedtuple 5 | from src.dataset.cityscapes import Cityscapes 6 | 7 | 8 | class LostAndFound(Dataset): 9 | 10 | LostAndFoundClass = namedtuple('LostAndFoundClass', ['name', 'id', 'train_id', 'category_name', 11 | 'category_id', 'color']) 12 | 13 | labels = [ 14 | LostAndFoundClass('unlabeled', 0, 255, 'Miscellaneous', 0, (0, 0, 0)), 15 | LostAndFoundClass('ego vehicle', 0, 255, 'Miscellaneous', 0, (0, 0, 0)), 16 | LostAndFoundClass('rectification border', 0, 255, 'Miscellaneous', 0, (0, 0, 0)), 17 | LostAndFoundClass('out of roi', 0, 255, 'Miscellaneous', 0, (0, 0, 0)), 18 | LostAndFoundClass('background', 0, 255, 'Counter hypotheses', 1, (0, 0, 0)), 19 | LostAndFoundClass('free', 1, 1, 'Counter hypotheses', 1, (128, 64, 128)), 20 | LostAndFoundClass('Crate (black)', 2, 2, 'Standard objects', 2, (0, 0, 142)), 21 | LostAndFoundClass('Crate (black - stacked)', 3, 2, 'Standard objects', 
2, (0, 0, 142)), 22 | LostAndFoundClass('Crate (black - upright)', 4, 2, 'Standard objects', 2, (0, 0, 142)), 23 | LostAndFoundClass('Crate (gray)', 5, 2, 'Standard objects', 2, (0, 0, 142)), 24 | LostAndFoundClass('Crate (gray - stacked) ', 6, 2, 'Standard objects', 2, (0, 0, 142)), 25 | LostAndFoundClass('Crate (gray - upright)', 7, 2, 'Standard objects', 2, (0, 0, 142)), 26 | LostAndFoundClass('Bumper', 8, 2, 'Random hazards', 3, (0, 0, 142)), 27 | LostAndFoundClass('Cardboard box 1', 9, 2, 'Random hazards', 3, (0, 0, 142)), 28 | LostAndFoundClass('Crate (blue)', 10, 2, 'Random hazards', 3, (0, 0, 142)), 29 | LostAndFoundClass('Crate (blue - small)', 11, 2, 'Random hazards', 3, (0, 0, 142)), 30 | LostAndFoundClass('Crate (green)', 12, 2, 'Random hazards', 3, (0, 0, 142)), 31 | LostAndFoundClass('Crate (green - small)', 13, 2, 'Random hazards', 3, (0, 0, 142)), 32 | LostAndFoundClass('Exhaust Pipe', 14, 2, 'Random hazards', 3, (0, 0, 142)), 33 | LostAndFoundClass('Headlight', 15, 2, 'Random hazards', 3, (0, 0, 142)), 34 | LostAndFoundClass('Euro Pallet', 16, 2, 'Random hazards', 3, (0, 0, 142)), 35 | LostAndFoundClass('Pylon', 17, 2, 'Random hazards', 3, (0, 0, 142)), 36 | LostAndFoundClass('Pylon (large)', 18, 2, 'Random hazards', 3, (0, 0, 142)), 37 | LostAndFoundClass('Pylon (white)', 19, 2, 'Random hazards', 3, (0, 0, 142)), 38 | LostAndFoundClass('Rearview mirror', 20, 2, 'Random hazards', 3, (0, 0, 142)), 39 | LostAndFoundClass('Tire', 21, 2, 'Random hazards', 3, (0, 0, 142)), 40 | LostAndFoundClass('Ball', 22, 2, 'Emotional hazards', 4, (0, 0, 142)), 41 | LostAndFoundClass('Bicycle', 23, 2, 'Emotional hazards', 4, (0, 0, 142)), 42 | LostAndFoundClass('Dog (black)', 24, 2, 'Emotional hazards', 4, (0, 0, 142)), 43 | LostAndFoundClass('Dog (white)', 25, 2, 'Emotional hazards', 4, (0, 0, 142)), 44 | LostAndFoundClass('Kid dummy', 26, 2, 'Emotional hazards', 4, (0, 0, 142)), 45 | LostAndFoundClass('Bobby car (gray)', 27, 2, 'Emotional hazards', 4, (0, 0, 142)), 46 | LostAndFoundClass('Bobby Car (red)', 28, 2, 'Emotional hazards', 4, (0, 0, 142)), 47 | LostAndFoundClass('Bobby Car (yellow)', 29, 2, 'Emotional hazards', 4, (0, 0, 142)), 48 | LostAndFoundClass('Cardboard box 2', 30, 2, 'Random hazards', 3, (0, 0, 142)), 49 | LostAndFoundClass('Marker Pole (lying)', 31, 0, 'Random non-hazards', 5, (0, 0, 0)), 50 | LostAndFoundClass('Plastic bag (bloated)', 32, 2, 'Random hazards', 3, (0, 0, 142)), 51 | LostAndFoundClass('Post (red - lying)', 33, 0, 'Random non-hazards', 5, (0, 0, 0)), 52 | LostAndFoundClass('Post Stand', 34, 0, 'Random non-hazards', 5, (0, 0, 0)), 53 | LostAndFoundClass('Styrofoam', 35, 2, 'Random hazards', 3, (0, 0, 142)), 54 | LostAndFoundClass('Timber (small)', 36, 0, 'Random non-hazards', 5, (0, 0, 0)), 55 | LostAndFoundClass('Timber (squared)', 37, 0, 'Random non-hazards', 5, (0, 0, 0)), 56 | LostAndFoundClass('Wheel Cap', 38, 0, 'Random non-hazards', 5, (0, 0, 0)), 57 | LostAndFoundClass('Wood (thin)', 39, 0, 'Random non-hazards', 5, (0, 0, 0)), 58 | LostAndFoundClass('Kid (walking)', 40, 2, 'Humans', 6, (0, 0, 142)), 59 | LostAndFoundClass('Kid (on a bobby car)', 41, 2, 'Humans', 6, (0, 0, 142)), 60 | LostAndFoundClass('Kid (small bobby)', 42, 2, 'Humans', 6, (0, 0, 142)), 61 | LostAndFoundClass('Kid (crawling)', 43, 2, 'Humans', 6, (0, 0, 142)), 62 | ] 63 | 64 | train_id_in = 1 65 | train_id_out = 2 66 | cs = Cityscapes() 67 | mean = cs.mean 68 | std = cs.std 69 | num_eval_classes = cs.num_train_ids 70 | 71 | def __init__(self, split='test', 
root="/home/datasets/lost_and_found/", transform=None): 72 | """Load all filenames.""" 73 | self.transform = transform 74 | self.root = root 75 | self.split = split # ['test', 'train'] 76 | self.images = [] # list of all raw input images 77 | self.targets = [] # list of all ground truth TrainIds images 78 | self.annotations = [] # list of all ground truth LabelIds images 79 | 80 | for root, _, filenames in os.walk(os.path.join(root, 'leftImg8bit', self.split)): 81 | for filename in filenames: 82 | if os.path.splitext(filename)[1] == '.png': 83 | filename_base = '_'.join(filename.split('_')[:-1]) 84 | city = '_'.join(filename.split('_')[:-3]) 85 | self.images.append(os.path.join(root, filename_base + '_leftImg8bit.png')) 86 | target_root = os.path.join(self.root, 'gtCoarse', self.split) 87 | self.targets.append(os.path.join(target_root, city, filename_base + '_gtCoarse_labelTrainIds.png')) 88 | self.annotations.append(os.path.join(target_root, city, filename_base + '_gtCoarse_labelIds.png')) 89 | 90 | def __len__(self): 91 | """Return number of images in the dataset split.""" 92 | return len(self.images) 93 | 94 | def __getitem__(self, i): 95 | """Return raw image and trainIds as PIL image or torch.Tensor""" 96 | image = Image.open(self.images[i]).convert('RGB') 97 | target = Image.open(self.targets[i]).convert('L') 98 | if self.transform is not None: 99 | image, target = self.transform(image, target) 100 | return image, target 101 | 102 | def __repr__(self): 103 | """Return number of images in each dataset.""" 104 | fmt_str = 'LostAndFound Split: %s\n' % self.split 105 | fmt_str += '----Number of images: %d\n' % len(self.images) 106 | return fmt_str.strip() 107 | -------------------------------------------------------------------------------- /src/dataset/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | from typing import Any, Callable, Optional, Tuple 4 | from torch.utils.data import Dataset 5 | from PIL import Image 6 | 7 | 8 | class Cityscapes(Dataset): 9 | """` 10 | Cityscapes Dataset http://www.cityscapes-dataset.com/ 11 | Labels based on https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 12 | """ 13 | CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id', 14 | 'has_instances', 'ignore_in_eval', 'color']) 15 | 16 | labels = [ 17 | CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)), 18 | CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)), 19 | CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)), 20 | CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)), 21 | CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)), 22 | CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)), 23 | CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)), 24 | CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)), 25 | CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)), 26 | CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)), 27 | CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)), 28 | CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)), 29 | CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)), 30 | CityscapesClass('fence', 
13, 4, 'construction', 2, False, False, (190, 153, 153)), 31 | CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)), 32 | CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)), 33 | CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)), 34 | CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)), 35 | CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)), 36 | CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)), 37 | CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)), 38 | CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)), 39 | CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)), 40 | CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)), 41 | CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)), 42 | CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)), 43 | CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)), 44 | CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)), 45 | CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)), 46 | CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)), 47 | CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)), 48 | CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)), 49 | CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)), 50 | CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)), 51 | CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)), 52 | ] 53 | 54 | """Normalization parameters""" 55 | mean = (0.485, 0.456, 0.406) 56 | std = (0.229, 0.224, 0.225) 57 | 58 | """Useful information from labels""" 59 | ignore_in_eval_ids, label_ids, train_ids, train_id2id = [], [], [], [] # empty lists for storing ids 60 | color_palette_train_ids = [(0, 0, 0) for i in range(256)] 61 | for i in range(len(labels)): 62 | if labels[i].ignore_in_eval and labels[i].train_id not in ignore_in_eval_ids: 63 | ignore_in_eval_ids.append(labels[i].train_id) 64 | for i in range(len(labels)): 65 | label_ids.append(labels[i].id) 66 | if labels[i].train_id not in ignore_in_eval_ids: 67 | train_ids.append(labels[i].train_id) 68 | color_palette_train_ids[labels[i].train_id] = labels[i].color 69 | train_id2id.append(labels[i].id) 70 | num_label_ids = len(set(label_ids)) # Number of ids 71 | num_train_ids = len(set(train_ids)) # Number of trainIds 72 | id2label = {label.id: label for label in labels} 73 | train_id2label = {label.train_id: label for label in labels} 74 | 75 | def __init__(self, root: str = "/home/datasets/cityscapes/", split: str = "val", mode: str = "gtFine", 76 | target_type: str = "semantic_train_id", transform: Optional[Callable] = None, 77 | predictions_root: Optional[str] = None) -> None: 78 | """ 79 | Cityscapes dataset loader 80 | """ 81 | self.root = root 82 | self.split = split 83 | self.mode = 'gtFine' if "fine" in mode.lower() else 'gtCoarse' 84 | self.transform = transform 85 | self.images_dir = os.path.join(self.root, 'leftImg8bit', self.split) 86 | self.targets_dir = os.path.join(self.root, self.mode, self.split) 87 | self.predictions_dir = os.path.join(predictions_root, self.split) if predictions_root is not None else "" 88 | self.images = [] 89 | 
self.targets = [] 90 | self.predictions = [] 91 | 92 | for city in os.listdir(self.images_dir): 93 | img_dir = os.path.join(self.images_dir, city) 94 | target_dir = os.path.join(self.targets_dir, city) 95 | pred_dir = os.path.join(self.predictions_dir, city) 96 | for file_name in os.listdir(img_dir): 97 | target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0], 98 | self._get_target_suffix(self.mode, target_type)) 99 | self.images.append(os.path.join(img_dir, file_name)) 100 | self.targets.append(os.path.join(target_dir, target_name)) 101 | self.predictions.append(os.path.join(pred_dir, file_name.replace("_leftImg8bit", ""))) 102 | 103 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 104 | image = Image.open(self.images[index]).convert('RGB') 105 | if self.split in ['train', 'val']: 106 | target = Image.open(self.targets[index]) 107 | else: 108 | target = None 109 | if self.transform is not None: 110 | image, target = self.transform(image, target) 111 | return image, target 112 | 113 | def __len__(self) -> int: 114 | return len(self.images) 115 | 116 | @staticmethod 117 | def _get_target_suffix(mode: str, target_type: str) -> str: 118 | if target_type == 'instance': 119 | return '{}_instanceIds.png'.format(mode) 120 | elif target_type == 'semantic_id': 121 | return '{}_labelIds.png'.format(mode) 122 | elif target_type == 'semantic_train_id': 123 | return '{}_labelTrainIds.png'.format(mode) 124 | elif target_type == 'color': 125 | return '{}_color.png'.format(mode) 126 | else: 127 | print("'%s' is not a valid target type, choose from:\n" % target_type + 128 | "['instance', 'semantic_id', 'semantic_train_id', 'color']") 129 | exit() 130 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import pickle 9 | 10 | from config import config_evaluation_setup 11 | from src.imageaugmentations import Compose, Normalize, ToTensor 12 | from src.model_utils import inference 13 | from scipy.stats import entropy 14 | from src.calc import calc_precision_recall, calc_sensitivity_specificity 15 | from src.helper import concatenate_metrics 16 | from meta_classification import meta_classification 17 | 18 | 19 | class eval_pixels(object): 20 | """ 21 | Evaluate in vs. 
out separability on pixel-level 22 | """ 23 | 24 | def __init__(self, params, roots, dataset): 25 | self.params = params 26 | self.epoch = params.val_epoch 27 | self.alpha = params.pareto_alpha 28 | self.batch_size = params.batch_size 29 | self.roots = roots 30 | self.dataset = dataset 31 | self.save_dir_data = os.path.join(self.roots.io_root, "results/entropy_counts_per_pixel") 32 | self.save_dir_plot = os.path.join(self.roots.io_root, "plots") 33 | if self.epoch == 0: 34 | self.pattern = "baseline" 35 | self.save_path_data = os.path.join(self.save_dir_data, "baseline.p") 36 | else: 37 | self.pattern = "epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha) 38 | self.save_path_data = os.path.join(self.save_dir_data, self.pattern + ".p") 39 | 40 | def counts(self, loader, num_bins=100, save_path=None, rewrite=False): 41 | """ 42 | Count the number in-distribution and out-distribution pixels 43 | and get the networks corresponding confidence scores 44 | :param loader: dataset loader for evaluation data 45 | :param num_bins: (int) number of bins for histogram construction 46 | :param save_path: (str) path where to save the counts data 47 | :param rewrite: (bool) whether to rewrite the data file if already exists 48 | """ 49 | print("\nCounting in-distribution and out-distribution pixels") 50 | if save_path is None: 51 | save_path = self.save_path_data 52 | if not os.path.exists(save_path) or rewrite: 53 | save_dir = os.path.dirname(save_path) 54 | if not os.path.exists(save_dir): 55 | print("Create directory", save_dir) 56 | os.makedirs(save_dir) 57 | bins = np.linspace(start=0, stop=1, num=num_bins + 1) 58 | counts = {"in": np.zeros(num_bins, dtype="int64"), "out": np.zeros(num_bins, dtype="int64")} 59 | inf = inference(self.params, self.roots, loader, self.dataset.num_eval_classes) 60 | for i in range(len(loader)): 61 | probs, gt_train, _, _ = inf.probs_gt_load(i) 62 | ent = entropy(probs, axis=0) / np.log(self.dataset.num_eval_classes) 63 | counts["in"] += np.histogram(ent[gt_train == self.dataset.train_id_in], bins=bins, density=False)[0] 64 | counts["out"] += np.histogram(ent[gt_train == self.dataset.train_id_out], bins=bins, density=False)[0] 65 | print("\rImages Processed: {}/{}".format(i + 1, len(loader)), end=' ') 66 | sys.stdout.flush() 67 | torch.cuda.empty_cache() 68 | pickle.dump(counts, open(save_path, "wb")) 69 | print("Counts data saved:", save_path) 70 | 71 | def oodd_metrics_pixel(self, datloader=None, load_path=None): 72 | """ 73 | Calculate 3 OoD detection metrics, namely AUROC, FPR95, AUPRC 74 | :param datloader: dataset loader 75 | :param load_path: (str) path to counts data (run 'counts' first) 76 | :return: OoD detection metrics 77 | """ 78 | if load_path is None: 79 | load_path = self.save_path_data 80 | if not os.path.exists(load_path): 81 | if datloader is None: 82 | print("Please, specify dataset loader") 83 | exit() 84 | self.counts(loader=datloader, save_path=load_path) 85 | data = pickle.load(open(load_path, "rb")) 86 | fpr, tpr, _, auroc = calc_sensitivity_specificity(data, balance=True) 87 | fpr95 = fpr[(np.abs(tpr - 0.95)).argmin()] 88 | _, _, _, auprc = calc_precision_recall(data) 89 | if self.epoch == 0: 90 | print("\nOoDD Metrics - Epoch %d - Baseline" % self.epoch) 91 | else: 92 | print("\nOoDD Metrics - Epoch %d - Lambda %.2f" % (self.epoch, self.alpha)) 93 | print("AUROC:", auroc) 94 | print("FPR95:", fpr95) 95 | print("AUPRC:", auprc) 96 | return auroc, fpr95, auprc 97 | 98 | 99 | def oodd_metrics_segment(params, roots, dataset, metaseg_dir=None): 
100 | """ 101 | Compute number of errors before / after meta classification and compare to baseline 102 | """ 103 | epoch = params.val_epoch 104 | alpha = params.pareto_alpha 105 | thresh = params.entropy_threshold 106 | num_imgs = len(dataset) 107 | if epoch == 0: 108 | load_subdir = "baseline" + "_t" + str(thresh) 109 | else: 110 | load_subdir = "epoch_" + str(epoch) + "_alpha_" + str(alpha) + "_t" + str(thresh) 111 | if metaseg_dir is None: 112 | metaseg_dir = os.path.join(roots.io_root, "metaseg_io") 113 | try: 114 | m, _ = concatenate_metrics(metaseg_root=metaseg_dir, num_imgs=num_imgs, 115 | subdir="baseline" + "_t" + str(thresh)) 116 | fp_baseline = len([i for i in range(len(m["iou0"])) if m["iou0"][i] == 1]) 117 | m, _ = concatenate_metrics(metaseg_root=metaseg_dir, num_imgs=num_imgs, 118 | subdir="baseline" + "_t" + str(thresh) + "_gt") 119 | fn_baseline = len([i for i in range(len(m["iou"])) if m["iou0"][i] == 1]) 120 | except FileNotFoundError: 121 | fp_baseline, fn_baseline = None, None 122 | m, _ = concatenate_metrics(metaseg_root=metaseg_dir, num_imgs=num_imgs, 123 | subdir=load_subdir + "_gt") 124 | fn_training = len([i for i in range(len(m["iou"])) if m["iou0"][i] == 1]) 125 | fn_meta, fp_training, fp_meta = meta_classification(params=params, roots=roots, dataset=dataset).remove() 126 | 127 | if epoch == 0: 128 | print("\nOoDD Metrics - Epoch %d - Baseline - Entropy Threshold %.2f" % (epoch, thresh)) 129 | else: 130 | print("\nOoDD Metrics - Epoch %d - Lambda %.2f - Entropy Threshold %.2f" % (epoch, alpha, thresh)) 131 | if fp_baseline is not None and fn_baseline is not None: 132 | print("Num FPs baseline :", fp_baseline) 133 | print("Num FNs baseline :", fn_baseline) 134 | if epoch > 0: 135 | print("Num FPs OoD training :", fp_training) 136 | print("Num FNs OoD training :", fn_training) 137 | print("Num FPs OoD training + meta classifier :", fp_meta) 138 | print("Num FNs OoD training + meta classifier :", fn_meta) 139 | return fp_baseline, fn_baseline, fp_training, fn_training, fp_meta, fn_meta 140 | 141 | 142 | def main(args): 143 | config = config_evaluation_setup(args) 144 | if not args["pixel_eval"] and not args["segment_eval"]: 145 | args["pixel_eval"] = args["segment_eval"] = True 146 | 147 | transform = Compose([ToTensor(), Normalize(config.dataset.mean, config.dataset.std)]) 148 | datloader = config.dataset(root=config.roots.eval_dataset_root, transform=transform) 149 | start = time.time() 150 | 151 | """Perform evaluation""" 152 | print("\nEVALUATE MODEL: ", config.roots.model_name) 153 | if args["pixel_eval"]: 154 | print("\nPIXEL-LEVEL EVALUATION") 155 | eval_pixels(config.params, config.roots, config.dataset).oodd_metrics_pixel(datloader=datloader) 156 | 157 | if args["segment_eval"]: 158 | print("\nSEGMENT-LEVEL EVALUATION") 159 | oodd_metrics_segment(config.params, config.roots, datloader) 160 | 161 | end = time.time() 162 | hours, rem = divmod(end - start, 3600) 163 | minutes, seconds = divmod(rem, 60) 164 | print("\nFINISHED {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)) 165 | 166 | 167 | if __name__ == '__main__': 168 | """Get Arguments and setup config class""" 169 | parser = argparse.ArgumentParser(description='OPTIONAL argument setting, see also config.py') 170 | parser.add_argument("-train", "--TRAINSET", nargs="?", type=str) 171 | parser.add_argument("-val", "--VALSET", nargs="?", type=str) 172 | parser.add_argument("-model", "--MODEL", nargs="?", type=str) 173 | parser.add_argument("-epoch", "--val_epoch", nargs="?", type=int) 174 
| parser.add_argument("-alpha", "--pareto_alpha", nargs="?", type=float) 175 | parser.add_argument("-pixel", "--pixel_eval", action='store_true') 176 | parser.add_argument("-segment", "--segment_eval", action='store_true') 177 | main(vars(parser.parse_args())) 178 | -------------------------------------------------------------------------------- /src/model/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | import torch.nn as nn 37 | import torch.utils.model_zoo as model_zoo 38 | from . 
import mynn 39 | 40 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 41 | 'resnet152'] 42 | 43 | 44 | model_urls = { 45 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 46 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 47 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 48 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 49 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 50 | } 51 | 52 | 53 | def conv3x3(in_planes, out_planes, stride=1): 54 | """3x3 convolution with padding""" 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | """ 61 | Basic Block for Resnet 62 | """ 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(BasicBlock, self).__init__() 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = mynn.Norm2d(planes) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(planes, planes) 71 | self.bn2 = mynn.Norm2d(planes) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Bottleneck(nn.Module): 95 | """ 96 | Bottleneck Layer for Resnet 97 | """ 98 | expansion = 4 99 | 100 | def __init__(self, inplanes, planes, stride=1, downsample=None): 101 | super(Bottleneck, self).__init__() 102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 103 | self.bn1 = mynn.Norm2d(planes) 104 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 105 | padding=1, bias=False) 106 | self.bn2 = mynn.Norm2d(planes) 107 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 108 | self.bn3 = mynn.Norm2d(planes * self.expansion) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out += residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | """ 138 | Resnet Global Module for Initialization 139 | """ 140 | def __init__(self, block, layers, num_classes=1000): 141 | self.inplanes = 64 142 | super(ResNet, self).__init__() 143 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = mynn.Norm2d(64) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 152 | self.avgpool = nn.AvgPool2d(7, stride=1) 153 | self.fc = nn.Linear(512 * 
block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, nn.BatchNorm2d): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1): 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | nn.Conv2d(self.inplanes, planes * block.expansion, 167 | kernel_size=1, stride=stride, bias=False), 168 | mynn.Norm2d(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample)) 173 | self.inplanes = planes * block.expansion 174 | for index in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | 190 | x = self.avgpool(x) 191 | x = x.view(x.size(0), -1) 192 | x = self.fc(x) 193 | 194 | return x 195 | 196 | 197 | def resnet18(pretrained=True, **kwargs): 198 | """Constructs a ResNet-18 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 204 | if pretrained: 205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 206 | return model 207 | 208 | 209 | def resnet34(pretrained=True, **kwargs): 210 | """Constructs a ResNet-34 model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 218 | return model 219 | 220 | 221 | def resnet50(pretrained=True, **kwargs): 222 | """Constructs a ResNet-50 model. 223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | """ 227 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 228 | if pretrained: 229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 230 | return model 231 | 232 | 233 | def resnet101(pretrained=True, **kwargs): 234 | """Constructs a ResNet-101 model. 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | """ 239 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 240 | if pretrained: 241 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 242 | return model 243 | 244 | 245 | def resnet152(pretrained=True, **kwargs): 246 | """Constructs a ResNet-152 model. 
247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 252 | if pretrained: 253 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 254 | return model 255 | -------------------------------------------------------------------------------- /src/model/DualGCNNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai(lxt@pku.edu.cn) 4 | # Pytorch implementation of Dual-GCN net 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from src.model.GALDNet import Bottleneck, conv3x3 9 | 10 | BatchNorm2d = nn.BatchNorm2d 11 | BatchNorm1d = nn.BatchNorm1d 12 | 13 | 14 | class SpatialGCN(nn.Module): 15 | def __init__(self, plane): 16 | super(SpatialGCN, self).__init__() 17 | inter_plane = plane // 2 18 | self.node_k = nn.Conv2d(plane, inter_plane, kernel_size=1) 19 | self.node_v = nn.Conv2d(plane, inter_plane, kernel_size=1) 20 | self.node_q = nn.Conv2d(plane, inter_plane, kernel_size=1) 21 | 22 | self.conv_wg = nn.Conv1d(inter_plane, inter_plane, kernel_size=1, bias=False) 23 | self.bn_wg = BatchNorm1d(inter_plane) 24 | self.softmax = nn.Softmax(dim=2) 25 | 26 | self.out = nn.Sequential(nn.Conv2d(inter_plane, plane, kernel_size=1), 27 | BatchNorm2d(plane)) 28 | 29 | def forward(self, x): 30 | # b, c, h, w = x.size() 31 | node_k = self.node_k(x) 32 | node_v = self.node_v(x) 33 | node_q = self.node_q(x) 34 | b,c,h,w = node_k.size() 35 | node_k = node_k.view(b, c, -1).permute(0, 2, 1) 36 | node_q = node_q.view(b, c, -1) 37 | node_v = node_v.view(b, c, -1).permute(0, 2, 1) 38 | # A = k * q 39 | # AV = k * q * v 40 | # AVW = k *(q *v) * w 41 | AV = torch.bmm(node_q,node_v) 42 | AV = self.softmax(AV) 43 | AV = torch.bmm(node_k, AV) 44 | AV = AV.transpose(1, 2).contiguous() 45 | AVW = self.conv_wg(AV) 46 | AVW = self.bn_wg(AVW) 47 | AVW = AVW.view(b, c, h, -1) 48 | out = F.relu_(self.out(AVW) + x) 49 | return out 50 | 51 | 52 | class DualGCN(nn.Module): 53 | """ 54 | Feature GCN with coordinate GCN 55 | """ 56 | def __init__(self, planes, ratio=4): 57 | super(DualGCN, self).__init__() 58 | 59 | self.phi = nn.Conv2d(planes, planes // ratio * 2, kernel_size=1, bias=False) 60 | self.bn_phi = BatchNorm2d(planes // ratio * 2) 61 | self.theta = nn.Conv2d(planes, planes // ratio, kernel_size=1, bias=False) 62 | self.bn_theta = BatchNorm2d(planes // ratio) 63 | 64 | # Interaction Space 65 | # Adjacency Matrix: (-)A_g 66 | self.conv_adj = nn.Conv1d(planes // ratio, planes // ratio, kernel_size=1, bias=False) 67 | self.bn_adj = BatchNorm1d(planes // ratio) 68 | 69 | # State Update Function: W_g 70 | self.conv_wg = nn.Conv1d(planes // ratio * 2, planes // ratio * 2, kernel_size=1, bias=False) 71 | self.bn_wg = BatchNorm1d(planes // ratio * 2) 72 | 73 | # last fc 74 | self.conv3 = nn.Conv2d(planes // ratio * 2, planes, kernel_size=1, bias=False) 75 | self.bn3 = BatchNorm2d(planes) 76 | 77 | self.local = nn.Sequential( 78 | nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False), 79 | BatchNorm2d(planes), 80 | nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False), 81 | BatchNorm2d(planes), 82 | nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False), 83 | BatchNorm2d(planes)) 84 | self.gcn_local_attention = SpatialGCN(planes) 85 | 86 | self.final = nn.Sequential(nn.Conv2d(planes * 2, planes, 
kernel_size=1, bias=False), 87 | BatchNorm2d(planes)) 88 | 89 | def to_matrix(self, x): 90 | n, c, h, w = x.size() 91 | x = x.view(n, c, -1) 92 | return x 93 | 94 | def forward(self, feat): 95 | # # # # Local # # # # 96 | x = feat 97 | local = self.local(feat) 98 | local = self.gcn_local_attention(local) 99 | local = F.interpolate(local, size=x.size()[2:], mode='bilinear', align_corners=True) 100 | spatial_local_feat = x * local + x 101 | 102 | # # # # Projection Space # # # # 103 | x_sqz, b = x, x 104 | 105 | x_sqz = self.phi(x_sqz) 106 | x_sqz = self.bn_phi(x_sqz) 107 | x_sqz = self.to_matrix(x_sqz) 108 | 109 | b = self.theta(b) 110 | b = self.bn_theta(b) 111 | b = self.to_matrix(b) 112 | 113 | # Project 114 | z_idt = torch.matmul(x_sqz, b.transpose(1, 2)) 115 | 116 | # # # # Interaction Space # # # # 117 | z = z_idt.transpose(1, 2).contiguous() 118 | 119 | z = self.conv_adj(z) 120 | z = self.bn_adj(z) 121 | 122 | z = z.transpose(1, 2).contiguous() 123 | # Laplacian smoothing: (I - A_g)Z => Z - A_gZ 124 | z += z_idt 125 | 126 | z = self.conv_wg(z) 127 | z = self.bn_wg(z) 128 | 129 | # # # # Re-projection Space # # # # 130 | # Re-project 131 | y = torch.matmul(z, b) 132 | 133 | n, _, h, w = x.size() 134 | y = y.view(n, -1, h, w) 135 | 136 | y = self.conv3(y) 137 | y = self.bn3(y) 138 | 139 | g_out = F.relu_(x+y) 140 | 141 | # cat or sum, nearly the same results 142 | out = self.final(torch.cat((spatial_local_feat, g_out), 1)) 143 | 144 | return out 145 | 146 | 147 | class DualGCNHead(nn.Module): 148 | def __init__(self, inplanes, interplanes, num_classes): 149 | super(DualGCNHead, self).__init__() 150 | self.conva = nn.Sequential(nn.Conv2d(inplanes, interplanes, 3, padding=1, bias=False), 151 | BatchNorm2d(interplanes), 152 | nn.ReLU(interplanes)) 153 | self.dualgcn = DualGCN(interplanes) 154 | self.convb = nn.Sequential(nn.Conv2d(interplanes, interplanes, 3, padding=1, bias=False), 155 | BatchNorm2d(interplanes), 156 | nn.ReLU(interplanes)) 157 | 158 | self.bottleneck = nn.Sequential( 159 | nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3, padding=1, dilation=1, bias=False), 160 | BatchNorm2d(interplanes), 161 | nn.ReLU(interplanes), 162 | nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True) 163 | ) 164 | 165 | def forward(self, x): 166 | output = self.conva(x) 167 | output = self.dualgcn(output) 168 | output = self.convb(output) 169 | output = self.bottleneck(torch.cat([x, output], 1)) 170 | return output 171 | 172 | 173 | class ResNet(nn.Module): 174 | def __init__(self, block, layers, num_classes): 175 | self.inplanes = 128 176 | super(ResNet, self).__init__() 177 | self.conv1 = nn.Sequential( 178 | conv3x3(3, 64, stride=2), 179 | BatchNorm2d(64), 180 | nn.ReLU(inplace=True), 181 | conv3x3(64, 64), 182 | BatchNorm2d(64), 183 | nn.ReLU(inplace=True), 184 | conv3x3(64, 128)) 185 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 186 | self.bn1 = BatchNorm2d(self.inplanes) 187 | self.relu = nn.ReLU(inplace=False) 188 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) 189 | self.layer1 = self._make_layer(block, 64, layers[0]) 190 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 191 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 192 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1, 2, 4)) 193 | 194 | # # # DualGCN 195 | self.head =DualGCNHead(2048, 512, num_classes) 196 | 197 | self.dsn = nn.Sequential( 198 | 
nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), 199 | BatchNorm2d(512), 200 | nn.Dropout2d(0.1), 201 | nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True) 202 | ) 203 | 204 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): 205 | downsample = None 206 | if stride != 1 or self.inplanes != planes * block.expansion: 207 | downsample = nn.Sequential( 208 | nn.Conv2d(self.inplanes, planes * block.expansion, 209 | kernel_size=1, stride=stride, bias=False), 210 | BatchNorm2d(planes * block.expansion)) 211 | 212 | layers = [] 213 | generate_multi_grid = lambda index, grids: grids[index % len(grids)] if isinstance(grids, tuple) else 1 214 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, 215 | multi_grid=generate_multi_grid(0, multi_grid))) 216 | self.inplanes = planes * block.expansion 217 | for i in range(1, blocks): 218 | layers.append( 219 | block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) 220 | 221 | return nn.Sequential(*layers) 222 | 223 | def forward(self, x): 224 | h, w = x.size(2), x.size(3) 225 | x = self.conv1(x) 226 | x = self.bn1(x) 227 | x = self.relu(x) 228 | x = self.maxpool(x) 229 | x = self.layer1(x) 230 | x = self.layer2(x) 231 | x = self.layer3(x) 232 | x_dsn = self.dsn(x) 233 | x = self.layer4(x) 234 | x = self.head(x) 235 | x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=True) 236 | # return [x, x_dsn] 237 | return x 238 | 239 | 240 | def DualSeg_res101(num_classes=21): 241 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes) 242 | return model 243 | 244 | 245 | def DualSeg_res50(num_classes=21): 246 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes) 247 | return model 248 | -------------------------------------------------------------------------------- /meta_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import os 4 | import pickle 5 | import sys 6 | 7 | import numpy as np 8 | 9 | from config import config_evaluation_setup 10 | from src.imageaugmentations import Compose, Normalize, ToTensor 11 | from src.model_utils import probs_gt_load 12 | from src.helper import metrics_dump, components_dump, concatenate_metrics, metrics_to_dataset, components_load 13 | from src.model_utils import inference 14 | from multiprocessing import Pool, cpu_count 15 | from scipy.stats import entropy 16 | from sklearn.metrics import accuracy_score, roc_auc_score, average_precision_score 17 | from sklearn.model_selection import LeaveOneOut, KFold 18 | from sklearn.linear_model import LogisticRegression 19 | 20 | try: 21 | from src.metaseg.metrics import compute_metrics_components, compute_metrics_mask 22 | except ImportError: 23 | compute_metrics_components, compute_metrics_mask = None, None 24 | print("MetaSeg ImportError: Maybe need to compile (src/metaseg/)metrics.pyx ....") 25 | exit() 26 | 27 | 28 | def metaseg_prepare(params, roots, dataset): 29 | """Generate Metaseg input which are .hdf5 files""" 30 | inf = inference(params, roots, dataset, dataset.num_eval_classes) 31 | for i in range(len(dataset)): 32 | inf.probs_gt_save(i) 33 | 34 | 35 | def entropy_segments_mask(probs, t): 36 | """Generate OoD prediction mask from softmax probabilities""" 37 | ent = entropy(probs) / np.log(probs.shape[0]) 38 | ent[ent < t] = 0 39 | ent[ent >= t] = 1 40 | return ent.astype("uint8") 41 | 42 | 43 | class compute_metrics(object): 44 | 
""" 45 | Compute the hand-crafted segment-wise metrics serving as meta classification input 46 | """ 47 | def __init__(self, params, roots, dataset, num_imgs=None, metaseg_dir=None, num_cores=1): 48 | self.epoch = params.val_epoch 49 | self.alpha = params.pareto_alpha 50 | self.thresh = params.entropy_threshold 51 | self.dataset = dataset 52 | if self.epoch == 0: 53 | self.load_dir = os.path.join(roots.io_root, "probs/baseline") 54 | self.save_subdir = "baseline" + "_t" + str(self.thresh) 55 | else: 56 | self.load_dir = os.path.join(roots.io_root, "probs/epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha)) 57 | self.save_subdir = "epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha) + "_t" + str(self.thresh) 58 | self.num_imgs = len(dataset) if num_imgs is None else num_imgs 59 | if metaseg_dir is None: 60 | self.metaseg_dir = os.path.join(roots.io_root, "metaseg_io") 61 | else: 62 | self.metaseg_dir = metaseg_dir 63 | self.num_cores = num_cores 64 | 65 | def compute_metrics_per_image(self, num_cores=None): 66 | """Perform segment search and compute corresponding segment-wise metrics""" 67 | print("Calculating statistics for", self.save_subdir) 68 | if num_cores is None: 69 | num_cores = self.num_cores 70 | p_args = [(k,) for k in range(self.num_imgs)] 71 | Pool(num_cores).starmap(self.compute_metrics_pred_i, p_args) 72 | Pool(num_cores).starmap(self.compute_metrics_gt_i, p_args) 73 | 74 | def compute_metrics_pred_i(self, i): 75 | """Compute metrics for predicted segments in one image""" 76 | start_i = time.time() 77 | probs, gt_train, _, img_path = probs_gt_load(i, load_dir=self.load_dir) 78 | ent_mask = entropy_segments_mask(probs, self.thresh) 79 | metrics, components = compute_metrics_components(probs=probs, gt_train=gt_train, ood_mask=ent_mask, 80 | ood_index=self.dataset.train_id_out) 81 | metrics_dump(metrics, i, metaseg_root=self.metaseg_dir, subdir=self.save_subdir) 82 | components_dump(components, i, metaseg_root=self.metaseg_dir, subdir=self.save_subdir) 83 | print("image", i, "processed in {}s\r".format(round(time.time() - start_i))) 84 | 85 | def compute_metrics_gt_i(self, i): 86 | """Compute metrics for ground truth segments in one image""" 87 | start_i = time.time() 88 | probs, gt_train, gt_label, img_path = probs_gt_load(i, load_dir=self.load_dir) 89 | ent_mask = entropy_segments_mask(probs, self.thresh) 90 | metrics, components = compute_metrics_mask(probs=probs, mask=ent_mask, gt_train=gt_train, gt_label=gt_label, 91 | ood_index=self.dataset.train_id_out) 92 | metrics_dump(metrics, i, metaseg_root=self.metaseg_dir, subdir=self.save_subdir + "_gt") 93 | components_dump(components, i, metaseg_root=self.metaseg_dir, subdir=self.save_subdir + "_gt") 94 | print("image", i, "processed in {}s\r".format(round(time.time() - start_i))) 95 | 96 | 97 | class meta_classification(object): 98 | """ 99 | Perform meta classification with the aid of logistic regressions in order to remove false positive OoD predictions 100 | """ 101 | def __init__(self, params, roots, dataset=None, num_imgs=None, metaseg_dir=None): 102 | self.epoch = params.val_epoch 103 | self.alpha = params.pareto_alpha 104 | self.thresh = params.entropy_threshold 105 | self.dataset = dataset 106 | self.net = roots.model_name 107 | if self.epoch == 0: 108 | self.load_subdir = "baseline" + "_t" + str(self.thresh) 109 | else: 110 | self.load_subdir = "epoch_" + str(self.epoch) + "_alpha_" + str(self.alpha) + "_t" + str(self.thresh) 111 | if metaseg_dir is None: 112 | self.metaseg_dir = os.path.join(roots.io_root, 
"metaseg_io") 113 | else: 114 | self.metaseg_dir = metaseg_dir 115 | self.num_imgs = len(self.dataset) if num_imgs is None else num_imgs 116 | 117 | def classifier_fit_and_predict(self): 118 | """Fit a logistic regression and cross validate performance""" 119 | print("\nClassifier fit and predict") 120 | metrics, start = concatenate_metrics(metaseg_root=self.metaseg_dir, subdir=self.load_subdir, 121 | num_imgs=self.num_imgs) 122 | Xa, _, _, y0a, X_names, class_names = metrics_to_dataset(metrics, self.dataset.num_eval_classes) 123 | y_pred_proba = np.zeros((len(y0a), 2)) 124 | 125 | model = LogisticRegression(solver="liblinear") 126 | loo = LeaveOneOut() 127 | 128 | for train_index, test_index in loo.split(Xa): 129 | print("TRAIN:", train_index, "TEST:", test_index) 130 | X_train, X_test = Xa[train_index], Xa[test_index] 131 | y_train, y_test = y0a[train_index], y0a[test_index] 132 | model.fit(X_train, y_train) 133 | y_pred_proba[test_index] = model.predict_proba(X_test) 134 | 135 | auroc = roc_auc_score(y0a, y_pred_proba[:, 1]) 136 | auprc = average_precision_score(y0a, y_pred_proba[:, 1]) 137 | y_pred = np.argmax(y_pred_proba, axis=-1) 138 | acc = accuracy_score(y0a, y_pred) 139 | print("\nMeta classifier performance scores:") 140 | print("AUROC:", auroc) 141 | print("AUPRC:", auprc) 142 | print("Accuracy:", acc) 143 | 144 | metrics["kick"] = y_pred 145 | metrics["start"] = start 146 | metrics["auroc"] = auroc 147 | metrics["auprc"] = auprc 148 | metrics["acc"] = acc 149 | 150 | save_path = os.path.join(self.metaseg_dir, "metrics", self.load_subdir, "meta_classified.p") 151 | with open(save_path, 'wb') as f: 152 | pickle.dump(metrics, f, pickle.HIGHEST_PROTOCOL) 153 | print("Saved meta classified:", save_path) 154 | return metrics, start 155 | 156 | def remove(self): 157 | """Based on a meta classifier's decision, remove false positive predictions""" 158 | print("\nRemoving False positive OoD segment predictions") 159 | load_path = os.path.join(self.metaseg_dir, "metrics", self.load_subdir, "meta_classified.p") 160 | if os.path.isfile(load_path): 161 | with open(load_path, "rb") as metrics_file: 162 | metrics = pickle.load(metrics_file) 163 | K = metrics["kick"] 164 | start = metrics["start"] 165 | else: 166 | metrics, start = self.classifier_fit_and_predict() 167 | K = metrics["kick"] 168 | fn_after = 0 169 | for i in range(len(start) - 1): 170 | comp_pred = abs(components_load(i, self.metaseg_dir, self.load_subdir)).flatten() 171 | for l, k in enumerate(K[start[i]:start[i + 1]]): 172 | if k == 1: 173 | comp_pred[comp_pred == l + 1] = 0 174 | comp_pred[comp_pred > 0] = 1 175 | comp_gt = abs(components_load(i, self.metaseg_dir, self.load_subdir + "_gt")).flatten() 176 | for c in np.unique(comp_gt)[1:]: 177 | comp_c = np.squeeze([comp_gt == c]) 178 | if np.sum(comp_c[comp_pred > 0]) == 0: 179 | fn_after += 1 180 | print("\rImages Processed: %d, Num FNs: %d" % (i + 1, fn_after), end=' ') 181 | sys.stdout.flush() 182 | fp_before = len([i for i in range(len(metrics["iou"])) if metrics["iou"][i] == 0]) 183 | fp_after = np.sum([metrics["kick"] == 1]) - np.sum(np.array(metrics["iou0"])[metrics["kick"] == 1]) 184 | return fn_after, fp_before, fp_after 185 | 186 | 187 | def main(args): 188 | config = config_evaluation_setup(args) 189 | 190 | transform = Compose([ToTensor(), Normalize(config.dataset.mean, config.dataset.std)]) 191 | datloader = config.dataset(root=config.roots.eval_dataset_root, split="test", transform=transform) 192 | start = time.time() 193 | 194 | """Perform Meta 
Classification""" 195 | if not args["metaseg_prepare"] and not args["segment_search"] and not args["fp_removal"]: 196 | args["metaseg_prepare"] = args["segment_search"] = args["fp_removal"] = True 197 | if args["metaseg_prepare"]: 198 | print("PREPARE METASEG INPUT") 199 | metaseg_prepare(config.params, config.roots, datloader) 200 | if args["segment_search"]: 201 | print("SEGMENT SEARCH") 202 | compute_metrics(config.params, config.roots, datloader, num_cores=cpu_count() - 2).compute_metrics_per_image() 203 | if args["fp_removal"]: 204 | print("FALSE POSITIVE REMOVAL VIA META CLASSIFICATION") 205 | meta_classification(config.params, config.roots, datloader).classifier_fit_and_predict() 206 | 207 | end = time.time() 208 | hours, rem = divmod(end - start, 3600) 209 | minutes, seconds = divmod(rem, 60) 210 | print("\nFINISHED {:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)) 211 | 212 | 213 | if __name__ == '__main__': 214 | """Get Arguments and setup config class""" 215 | parser = argparse.ArgumentParser(description='OPTIONAL argument setting, see also config.py') 216 | parser.add_argument("-val", "--VALSET", nargs="?", type=str) 217 | parser.add_argument("-model", "--MODEL", nargs="?", type=str) 218 | parser.add_argument("-epoch", "--val_epoch", nargs="?", type=int) 219 | parser.add_argument("-alpha", "--pareto_alpha", nargs="?", type=float) 220 | parser.add_argument("-threshold", "--entropy_threshold", nargs="?", type=float) 221 | parser.add_argument("-prepare", "--metaseg_prepare", action='store_true') 222 | parser.add_argument("-segment", "--segment_search", action='store_true') 223 | parser.add_argument("-removal", "--fp_removal", action='store_true') 224 | main(vars(parser.parse_args())) 225 | -------------------------------------------------------------------------------- /src/model/GALDNet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author: Xiangtai(lxt@pku.edu.cn) 4 | # GA module is borrowed from CGNL paper directly 5 | # Pytorch implementation of GALD-Net 6 | 7 | 8 | import torch.nn as nn 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.nn import BatchNorm2d 12 | 13 | 14 | class SpatialCGNL(nn.Module): 15 | """Spatial CGNL block with dot production kernel for image classfication. 16 | """ 17 | def __init__(self, inplanes, planes, use_scale=False, groups=8): 18 | self.use_scale = use_scale 19 | self.groups = groups 20 | 21 | super(SpatialCGNL, self).__init__() 22 | # conv theta 23 | self.t = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) 24 | # conv phi 25 | self.p = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) 26 | # conv g 27 | self.g = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False) 28 | # conv z 29 | self.z = nn.Conv2d(planes, inplanes, kernel_size=1, stride=1, 30 | groups=self.groups, bias=False) 31 | self.gn = nn.GroupNorm(num_groups=self.groups, num_channels=inplanes) 32 | 33 | if self.use_scale: 34 | print("=> WARN: SpatialCGNL block uses 'SCALE'", \ 35 | 'yellow') 36 | if self.groups: 37 | print("=> WARN: SpatialCGNL block uses '{}' groups".format(self.groups), \ 38 | 'yellow') 39 | 40 | def kernel(self, t, p, g, b, c, h, w): 41 | """The linear kernel (dot production). 
42 | Args: 43 | t: output of conv theata 44 | p: output of conv phi 45 | g: output of conv g 46 | b: batch size 47 | c: channels number 48 | h: height of featuremaps 49 | w: width of featuremaps 50 | """ 51 | t = t.view(b, 1, c * h * w) 52 | p = p.view(b, 1, c * h * w) 53 | g = g.view(b, c * h * w, 1) 54 | 55 | att = torch.bmm(p, g) 56 | 57 | if self.use_scale: 58 | att = att.div((c*h*w)**0.5) 59 | 60 | x = torch.bmm(att, t) 61 | x = x.view(b, c, h, w) 62 | 63 | return x 64 | 65 | def forward(self, x): 66 | residual = x 67 | 68 | t = self.t(x) 69 | p = self.p(x) 70 | g = self.g(x) 71 | b, c, h, w = t.size() 72 | 73 | if self.groups and self.groups > 1: 74 | _c = int(c / self.groups) 75 | 76 | ts = torch.split(t, split_size_or_sections=_c, dim=1) 77 | ps = torch.split(p, split_size_or_sections=_c, dim=1) 78 | gs = torch.split(g, split_size_or_sections=_c, dim=1) 79 | 80 | _t_sequences = [] 81 | 82 | for i in range(self.groups): 83 | _x = self.kernel(ts[i], ps[i], gs[i], 84 | b, _c, h, w) 85 | _t_sequences.append(_x) 86 | 87 | x = torch.cat(_t_sequences, dim=1) 88 | else: 89 | x = self.kernel(t, p, g, 90 | b, c, h, w) 91 | 92 | x = self.z(x) 93 | x = self.gn(x) + residual 94 | 95 | return x 96 | 97 | 98 | class GALDBlock(nn.Module): 99 | def __init__(self, inplane, plane): 100 | super(GALDBlock, self).__init__() 101 | """ 102 | Note down the spatial into 1/16 103 | """ 104 | self.down = nn.Sequential( 105 | nn.Conv2d(inplane, inplane,kernel_size=3, groups=inplane, stride=2), 106 | BatchNorm2d(inplane), 107 | nn.ReLU(inplace=False) 108 | ) 109 | self.long_relation = SpatialCGNL(inplane, plane) 110 | self.local_attention = LocalAttenModule(inplane) 111 | 112 | def forward(self, x): 113 | size = x.size()[2:] 114 | x = self.down(x) 115 | x = self.long_relation(x) 116 | # local attention 117 | x = F.upsample(x,size=size, mode="bilinear", align_corners=True) 118 | res = x 119 | x = self.local_attention(x) 120 | return x + res 121 | 122 | 123 | class LocalAttenModule(nn.Module): 124 | def __init__(self, inplane): 125 | super(LocalAttenModule, self).__init__() 126 | self.dconv1 = nn.Sequential( 127 | nn.Conv2d(inplane, inplane, kernel_size=3, groups=inplane, stride=2), 128 | BatchNorm2d(inplane), 129 | nn.ReLU(inplace=False) 130 | ) 131 | self.dconv2 = nn.Sequential( 132 | nn.Conv2d(inplane, inplane, kernel_size=3, groups=inplane, stride=2), 133 | BatchNorm2d(inplane), 134 | nn.ReLU(inplace=False) 135 | ) 136 | self.dconv3 = nn.Sequential( 137 | nn.Conv2d(inplane, inplane, kernel_size=3, groups=inplane, stride=2), 138 | BatchNorm2d(inplane), 139 | nn.ReLU(inplace=False) 140 | ) 141 | self.sigmoid_spatial = nn.Sigmoid() 142 | 143 | def forward(self, x): 144 | b, c, h, w = x.size() 145 | res1 = x 146 | res2 = x 147 | x = self.dconv1(x) 148 | x = self.dconv2(x) 149 | x = self.dconv3(x) 150 | x = F.upsample(x, size=(h, w), mode="bilinear", align_corners=True) 151 | x_mask = self.sigmoid_spatial(x) 152 | 153 | res1 = res1 * x_mask 154 | 155 | return res2 + res1 156 | 157 | 158 | def conv3x3(in_planes, out_planes, stride=1): 159 | "3x3 convolution with padding" 160 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 161 | padding=1, bias=False) 162 | 163 | 164 | class Bottleneck(nn.Module): 165 | expansion = 4 166 | 167 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, fist_dilation=1, multi_grid=1): 168 | super(Bottleneck, self).__init__() 169 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 170 | self.bn1 = BatchNorm2d(planes) 171 | 
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 172 | padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False) 173 | self.bn2 = BatchNorm2d(planes) 174 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 175 | self.bn3 = BatchNorm2d(planes * 4) 176 | self.relu = nn.ReLU(inplace=False) 177 | self.relu_inplace = nn.ReLU(inplace=True) 178 | self.downsample = downsample 179 | self.dilation = dilation 180 | self.stride = stride 181 | 182 | def forward(self, x): 183 | residual = x 184 | 185 | out = self.conv1(x) 186 | out = self.bn1(out) 187 | out = self.relu(out) 188 | 189 | out = self.conv2(out) 190 | out = self.bn2(out) 191 | out = self.relu(out) 192 | 193 | out = self.conv3(out) 194 | out = self.bn3(out) 195 | 196 | if self.downsample is not None: 197 | residual = self.downsample(x) 198 | 199 | out = out + residual 200 | out = self.relu_inplace(out) 201 | 202 | return out 203 | 204 | 205 | class GALDHead(nn.Module): 206 | def __init__(self, inplanes, interplanes, num_classes): 207 | super(GALDHead, self).__init__() 208 | self.conva = nn.Sequential(nn.Conv2d(inplanes, interplanes, 3, padding=1, bias=False), 209 | BatchNorm2d(interplanes), 210 | nn.ReLU(interplanes)) 211 | self.a2block = GALDBlock(interplanes, interplanes//2) 212 | self.convb = nn.Sequential(nn.Conv2d(interplanes, interplanes, 3, padding=1, bias=False), 213 | BatchNorm2d(interplanes), 214 | nn.ReLU(interplanes)) 215 | 216 | self.bottleneck = nn.Sequential( 217 | nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3, padding=1, dilation=1, bias=False), 218 | BatchNorm2d(interplanes), 219 | nn.ReLU(interplanes), 220 | nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True) 221 | ) 222 | 223 | def forward(self, x): 224 | output = self.conva(x) 225 | output = self.a2block(output) 226 | output = self.convb(output) 227 | output = self.bottleneck(torch.cat([x, output], 1)) 228 | return output 229 | 230 | 231 | class GALDNet(nn.Module): 232 | def __init__(self, block, layers, num_classes, avg=False): 233 | self.inplanes = 128 234 | super(GALDNet, self).__init__() 235 | self.conv1 = nn.Sequential( 236 | conv3x3(3, 64, stride=2), 237 | BatchNorm2d(64), 238 | nn.ReLU(inplace=True), 239 | conv3x3(64, 64), 240 | BatchNorm2d(64), 241 | nn.ReLU(inplace=True), 242 | conv3x3(64, 128)) 243 | self.bn1 = BatchNorm2d(self.inplanes) 244 | self.relu = nn.ReLU(inplace=False) 245 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) 246 | self.layer1 = self._make_layer(block, 64, layers[0]) 247 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 248 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 249 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4, multi_grid=(1, 2, 4)) 250 | 251 | self.head = GALDHead(2048, 512, num_classes=num_classes) 252 | self.dsn = nn.Sequential( 253 | nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1), 254 | BatchNorm2d(512), 255 | nn.ReLU(), 256 | nn.Dropout2d(0.1), 257 | nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True) 258 | ) 259 | 260 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1): 261 | downsample = None 262 | if stride != 1 or self.inplanes != planes * block.expansion: 263 | downsample = nn.Sequential( 264 | nn.Conv2d(self.inplanes, planes * block.expansion, 265 | kernel_size=1, stride=stride, bias=False), 266 | BatchNorm2d(planes * block.expansion, affine=True)) 267 | 268 | 
layers = [] 269 | generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1 270 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid))) 271 | self.inplanes = planes * block.expansion 272 | for i in range(1, blocks): 273 | layers.append(block(self.inplanes, planes, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid))) 274 | 275 | return nn.Sequential(*layers) 276 | 277 | def forward(self, x): 278 | size = x.size()[2:] 279 | x = self.conv1(x) 280 | x = self.bn1(x) 281 | x = self.relu(x) 282 | x = self.maxpool(x) 283 | x = self.layer1(x) 284 | x = self.layer2(x) 285 | x = self.layer3(x) 286 | x_dsn = self.dsn(x) 287 | x = self.layer4(x) 288 | x = self.head(x) 289 | return [x, x_dsn] 290 | 291 | 292 | def GALD_res101(num_classes=21): 293 | model = GALDNet(Bottleneck, [3, 4, 23, 3], num_classes) 294 | return model 295 | 296 | 297 | def GALD_res50(num_classes=21): 298 | model = GALDNet(Bottleneck, [3,4,6,3], num_classes) 299 | return model 300 | 301 | 302 | if __name__ == '__main__': 303 | i = torch.Tensor(1, 3, 769, 769).cuda() 304 | m = GALD_res50(19).cuda() 305 | m.eval() 306 | o = m(i) 307 | print(o[0].size()) -------------------------------------------------------------------------------- /src/model/deepv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/sthalles/deeplab_v3 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2018 Thalles Santos Silva 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | """ 26 | import logging 27 | import torch 28 | from torch import nn 29 | from . import SEresnext 30 | from . 
import Resnet 31 | from .wider_resnet import wider_resnet38_a2 32 | from .mynn import initialize_weights, Norm2d, Upsample 33 | 34 | 35 | class _AtrousSpatialPyramidPoolingModule(nn.Module): 36 | """ 37 | operations performed: 38 | 1x1 x depth 39 | 3x3 x depth dilation 6 40 | 3x3 x depth dilation 12 41 | 3x3 x depth dilation 18 42 | image pooling 43 | concatenate all together 44 | Final 1x1 conv 45 | """ 46 | 47 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)): 48 | super(_AtrousSpatialPyramidPoolingModule, self).__init__() 49 | 50 | # Check if we are using distributed BN and use the nn from encoding.nn 51 | # library rather than using standard pytorch.nn 52 | 53 | if output_stride == 8: 54 | rates = [2 * r for r in rates] 55 | elif output_stride == 16: 56 | pass 57 | else: 58 | raise 'output stride of {} not supported'.format(output_stride) 59 | 60 | self.features = [] 61 | # 1x1 62 | self.features.append( 63 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 64 | Norm2d(reduction_dim), nn.ReLU(inplace=True))) 65 | # other rates 66 | for r in rates: 67 | self.features.append(nn.Sequential( 68 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, 69 | dilation=r, padding=r, bias=False), 70 | Norm2d(reduction_dim), 71 | nn.ReLU(inplace=True) 72 | )) 73 | self.features = torch.nn.ModuleList(self.features) 74 | 75 | # img level features 76 | self.img_pooling = nn.AdaptiveAvgPool2d(1) 77 | self.img_conv = nn.Sequential( 78 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 79 | Norm2d(reduction_dim), nn.ReLU(inplace=True)) 80 | 81 | def forward(self, x): 82 | x_size = x.size() 83 | 84 | img_features = self.img_pooling(x) 85 | img_features = self.img_conv(img_features) 86 | img_features = Upsample(img_features, x_size[2:]) 87 | out = img_features 88 | 89 | for f in self.features: 90 | y = f(x) 91 | out = torch.cat((out, y), 1) 92 | return out 93 | 94 | 95 | class DeepV3Plus(nn.Module): 96 | """ 97 | Implement DeepLab-V3 model 98 | A: stride8 99 | B: stride16 100 | with skip connections 101 | """ 102 | 103 | def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D', 104 | skip='m1', skip_num=48): 105 | super(DeepV3Plus, self).__init__() 106 | self.criterion = criterion 107 | self.variant = variant 108 | self.skip = skip 109 | self.skip_num = skip_num 110 | 111 | if trunk == 'seresnext-50': 112 | resnet = SEresnext.se_resnext50_32x4d() 113 | elif trunk == 'seresnext-101': 114 | resnet = SEresnext.se_resnext101_32x4d() 115 | elif trunk == 'resnet-50': 116 | resnet = Resnet.resnet50() 117 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 118 | elif trunk == 'resnet-101': 119 | resnet = Resnet.resnet101() 120 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 121 | else: 122 | raise ValueError("Not a valid network arch") 123 | 124 | self.layer0 = resnet.layer0 125 | self.layer1, self.layer2, self.layer3, self.layer4 = \ 126 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 127 | 128 | if self.variant == 'D': 129 | for n, m in self.layer3.named_modules(): 130 | if 'conv2' in n: 131 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 132 | elif 'downsample.0' in n: 133 | m.stride = (1, 1) 134 | for n, m in self.layer4.named_modules(): 135 | if 'conv2' in n: 136 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 137 | elif 'downsample.0' in n: 138 | m.stride = (1, 1) 139 | elif self.variant == 'D16': 140 | for n, 
m in self.layer4.named_modules(): 141 | if 'conv2' in n: 142 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 143 | elif 'downsample.0' in n: 144 | m.stride = (1, 1) 145 | else: 146 | # raise 'unknown deepv3 variant: {}'.format(self.variant) 147 | print("Not using Dilation ") 148 | 149 | self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256, 150 | output_stride=8) 151 | 152 | if self.skip == 'm1': 153 | self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False) 154 | elif self.skip == 'm2': 155 | self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False) 156 | else: 157 | raise Exception('Not a valid skip') 158 | 159 | self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False) 160 | 161 | self.final = nn.Sequential( 162 | nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False), 163 | Norm2d(256), 164 | nn.ReLU(inplace=True), 165 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 166 | Norm2d(256), 167 | nn.ReLU(inplace=True), 168 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 169 | 170 | initialize_weights(self.aspp) 171 | initialize_weights(self.bot_aspp) 172 | initialize_weights(self.bot_fine) 173 | initialize_weights(self.final) 174 | 175 | def forward(self, x, gts=None): 176 | 177 | x_size = x.size() # 800 178 | x0 = self.layer0(x) # 400 179 | x1 = self.layer1(x0) # 400 180 | x2 = self.layer2(x1) # 100 181 | x3 = self.layer3(x2) # 100 182 | x4 = self.layer4(x3) # 100 183 | xp = self.aspp(x4) 184 | 185 | dec0_up = self.bot_aspp(xp) 186 | if self.skip == 'm1': 187 | dec0_fine = self.bot_fine(x1) 188 | dec0_up = Upsample(dec0_up, x1.size()[2:]) 189 | else: 190 | dec0_fine = self.bot_fine(x2) 191 | dec0_up = Upsample(dec0_up, x2.size()[2:]) 192 | 193 | dec0 = [dec0_fine, dec0_up] 194 | dec0 = torch.cat(dec0, 1) 195 | dec1 = self.final(dec0) 196 | main_out = Upsample(dec1, x_size[2:]) 197 | 198 | if self.training: 199 | return self.criterion(main_out, gts) 200 | 201 | return main_out 202 | 203 | 204 | class DeepWV3Plus(nn.Module): 205 | """ 206 | Wide_resnet version of DeepLabV3 207 | mod1 208 | pool2 209 | mod2 str2 210 | pool3 211 | mod3-7 212 | 213 | structure: [3, 3, 6, 3, 1, 1] 214 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 215 | (1024, 2048, 4096)] 216 | """ 217 | 218 | def __init__(self, num_classes, trunk='WideResnet38'): 219 | 220 | super(DeepWV3Plus, self).__init__() 221 | logging.debug("Trunk: %s", trunk) 222 | wide_resnet = wider_resnet38_a2(classes=1000, dilation=True) 223 | wide_resnet = torch.nn.DataParallel(wide_resnet) 224 | # try: 225 | # checkpoint = torch.load('./pretrained_models/wider_resnet38.pth.tar', map_location='cpu') 226 | # wide_resnet.load_state_dict(checkpoint['state_dict']) 227 | # del checkpoint 228 | # except: 229 | # print("=====================Could not load ImageNet weights=======================") 230 | # print("Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models.") 231 | 232 | wide_resnet = wide_resnet.module 233 | 234 | self.mod1 = wide_resnet.mod1 235 | self.mod2 = wide_resnet.mod2 236 | self.mod3 = wide_resnet.mod3 237 | self.mod4 = wide_resnet.mod4 238 | self.mod5 = wide_resnet.mod5 239 | self.mod6 = wide_resnet.mod6 240 | self.mod7 = wide_resnet.mod7 241 | self.pool2 = wide_resnet.pool2 242 | self.pool3 = wide_resnet.pool3 243 | del wide_resnet 244 | 245 | self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, 246 | output_stride=8) 247 | 248 | self.bot_fine = nn.Conv2d(128, 48, 
kernel_size=1, bias=False) 249 | self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False) 250 | 251 | self.final = nn.Sequential( 252 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 253 | Norm2d(256), 254 | nn.ReLU(inplace=True), 255 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 256 | Norm2d(256), 257 | nn.ReLU(inplace=True), 258 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 259 | 260 | initialize_weights(self.final) 261 | 262 | def forward(self, inp): 263 | 264 | x_size = inp.size() 265 | x = self.mod1(inp) 266 | m2 = self.mod2(self.pool2(x)) 267 | x = self.mod3(self.pool3(m2)) 268 | x = self.mod4(x) 269 | x = self.mod5(x) 270 | x = self.mod6(x) 271 | x = self.mod7(x) 272 | x = self.aspp(x) 273 | dec0_up = self.bot_aspp(x) 274 | 275 | dec0_fine = self.bot_fine(m2) 276 | dec0_up = Upsample(dec0_up, m2.size()[2:]) 277 | dec0 = [dec0_fine, dec0_up] 278 | dec0 = torch.cat(dec0, 1) 279 | 280 | dec1 = self.final(dec0) 281 | out = Upsample(dec1, x_size[2:]) 282 | 283 | return out 284 | 285 | 286 | def DeepSRNX50V3PlusD_m1(num_classes, criterion): 287 | """ 288 | SEResnet 50 Based Network 289 | """ 290 | return DeepV3Plus(num_classes, trunk='seresnext-50', criterion=criterion, variant='D', 291 | skip='m1') 292 | 293 | def DeepR50V3PlusD_m1(num_classes, criterion): 294 | """ 295 | Resnet 50 Based Network 296 | """ 297 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion, variant='D', skip='m1') 298 | 299 | 300 | def DeepSRNX101V3PlusD_m1(num_classes, criterion): 301 | """ 302 | SeResnext 101 Based Network 303 | """ 304 | return DeepV3Plus(num_classes, trunk='seresnext-101', criterion=criterion, variant='D', 305 | skip='m1') 306 | 307 | -------------------------------------------------------------------------------- /src/imageaugmentations.py: -------------------------------------------------------------------------------- 1 | # adjusted from 2 | # https://github.com/meetshah1995/pytorch-semseg/tree/master/ptsemseg 3 | 4 | # Adapted from 5 | # https://github.com/ZijunDeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py 6 | 7 | import math 8 | import numbers 9 | import logging 10 | import random 11 | import types 12 | import numpy as np 13 | import torch 14 | import torch.nn.functional as F 15 | import torchvision.transforms as trans 16 | 17 | from PIL import Image, ImageOps 18 | 19 | 20 | class Compose(object): 21 | """Wraps together multiple image augmentations. 22 | 23 | Should also be used with only one augmentation, as it ensures, that input 24 | images are of type 'PIL.Image' and handles the augmentation process. 25 | 26 | Args: 27 | augmentations: List of augmentations to be applied. 28 | """ 29 | 30 | def __init__(self, augmentations): 31 | """Initializes the composer with the given augmentations.""" 32 | self.augmentations = augmentations 33 | 34 | def __call__(self, img, mask, *inputs): 35 | """Returns images that are augmented with the given augmentations.""" 36 | # img, mask = Image.fromarray(img, mode='RGB'), Image.fromarray(mask, mode='L') 37 | assert img.size == mask.size 38 | for a in self.augmentations: 39 | img, mask, inputs = a(img, mask, *inputs) 40 | return (img, mask, *inputs) 41 | 42 | 43 | class RandomCrop(object): 44 | """Returns an image of size 'size' that is a random crop of the original. 45 | 46 | Args: 47 | size: Size of the croped image. 48 | padding: Number of pixels to be placed around the original image. 
49 | """ 50 | 51 | def __init__(self, size, padding=0, *inputs, **kwargs): 52 | if isinstance(size, numbers.Number): 53 | self.size = (int(size), int(size)) 54 | else: 55 | self.size = size 56 | self.padding = padding 57 | 58 | def __call__(self, img, mask, *inputs, **kwargs): 59 | """Returns randomly cropped image.""" 60 | if self.padding > 0: 61 | img = ImageOps.expand(img, border=self.padding, fill=0) 62 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 63 | inputs = tuple(ImageOps.expand(i, border=self.padding, fill=0) for i in inputs) 64 | 65 | assert img.size == mask.size 66 | w, h = img.size 67 | th, tw = self.size 68 | if w == tw and h == th: 69 | return img, mask 70 | if w < tw or h < th: 71 | return (img.resize((tw, th), Image.BILINEAR), 72 | mask.resize((tw, th), Image.NEAREST), 73 | tuple(i.resize((tw, th), Image.NEAREST) for i in inputs)) 74 | 75 | x1 = random.randint(0, w - tw) 76 | y1 = random.randint(0, h - th) 77 | return (img.crop((x1, y1, x1 + tw, y1 + th)), 78 | mask.crop((x1, y1, x1 + tw, y1 + th)), 79 | tuple(i.crop((x1, y1, x1 + tw, y1 + th)) for i in inputs)) 80 | 81 | 82 | class CenterCrop(object): 83 | """Returns image of size 'size' that is center cropped. 84 | 85 | Crops an image of size 'size' from the center of an image. If the center 86 | index is not an integer, the value will be rounded. 87 | 88 | Args: 89 | size: The size of the output image. 90 | """ 91 | 92 | def __init__(self, size, *inputs, **kwargs): 93 | if isinstance(size, numbers.Number): 94 | self.size = (int(size), int(size)) 95 | else: 96 | self.size = size 97 | 98 | def __call__(self, img, mask, *inputs, **kwargs): 99 | assert img.size == mask.size 100 | w, h = img.size 101 | th, tw = self.size 102 | x1 = int(round((w - tw) / 2.)) 103 | y1 = int(round((h - th) / 2.)) 104 | return (img.crop((x1, y1, x1 + tw, y1 + th)), 105 | mask.crop((x1, y1, x1 + tw, y1 + th)), *inputs) 106 | 107 | 108 | class RandomHorizontalFlip(object): 109 | """Returns an image the got flipped with a probability of 'prob'. 110 | 111 | Args: 112 | prob: Probability with which the horizontal flip is applied. 
113 | """ 114 | 115 | def __init__(self, prob=0.5, *inputs, **kwargs): 116 | if not isinstance(prob, numbers.Number): 117 | raise TypeError("'prob' needs to be a number.") 118 | self.prob = prob 119 | 120 | def __call__(self, img, mask, *inputs, **kwargs): 121 | if random.random() < self.prob: 122 | return (img.transpose(Image.FLIP_LEFT_RIGHT), 123 | mask.transpose(Image.FLIP_LEFT_RIGHT), 124 | tuple(i.transpose(Image.FLIP_LEFT_RIGHT) for i in inputs)) 125 | return img, mask, tuple(i for i in inputs) 126 | 127 | 128 | class FreeScale(object): 129 | def __init__(self, size, *inputs, **kwargs): 130 | self.size = tuple(reversed(size)) # size: (h, w) 131 | 132 | def __call__(self, img, mask, *inputs, **kwargs): 133 | assert img.size == mask.size 134 | return (img.resize(self.size, Image.BILINEAR), 135 | mask.resize(self.size, Image.NEAREST), *inputs) 136 | 137 | 138 | class Scale(object): 139 | def __init__(self, size, *inputs, **kwargs): 140 | self.size = size 141 | 142 | def __call__(self, img, mask, *inputs, **kwargs): 143 | assert img.size == mask.size 144 | w, h = img.size 145 | if (w >= h and w == self.size) or (h >= w and h == self.size): 146 | return (img, mask, *inputs) 147 | if w > h: 148 | ow = self.size 149 | oh = int(self.size * h / w) 150 | return (img.resize((ow, oh), Image.BILINEAR), 151 | mask.resize((ow, oh), Image.NEAREST), *inputs) 152 | else: 153 | oh = self.size 154 | ow = int(self.size * w / h) 155 | return (img.resize((ow, oh), Image.BILINEAR), 156 | mask.resize((ow, oh), Image.NEAREST), *inputs) 157 | 158 | 159 | class RandomSizedCrop(object): 160 | def __init__(self, size, *inputs, **kwargs): 161 | self.size = size 162 | 163 | def __call__(self, img, mask, *inputs, **kwargs): 164 | assert img.size == mask.size 165 | for attempt in range(10): 166 | area = img.size[0] * img.size[1] 167 | target_area = random.uniform(0.45, 1.0) * area 168 | aspect_ratio = random.uniform(0.5, 2) 169 | 170 | w = int(round(math.sqrt(target_area * aspect_ratio))) 171 | h = int(round(math.sqrt(target_area / aspect_ratio))) 172 | 173 | if random.random() < 0.5: 174 | w, h = h, w 175 | 176 | if w <= img.size[0] and h <= img.size[1]: 177 | x1 = random.randint(0, img.size[0] - w) 178 | y1 = random.randint(0, img.size[1] - h) 179 | 180 | img = img.crop((x1, y1, x1 + w, y1 + h)) 181 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 182 | assert (img.size == (w, h)) 183 | 184 | return (img.resize((self.size, self.size), Image.BILINEAR), 185 | mask.resize((self.size, self.size), Image.NEAREST), *inputs) 186 | 187 | # Fallback 188 | scale = Scale(self.size) 189 | crop = CenterCrop(self.size) 190 | return crop(*scale(img, mask, *inputs)) 191 | 192 | 193 | class RandomRotate(object): 194 | def __init__(self, degree, *inputs, **kwargs): 195 | if not isinstance(degree, numbers.Number): 196 | raise TypeError("'degree' needs to be a number.") 197 | self.degree = degree 198 | 199 | def __call__(self, img, mask, *inputs, **kwargs): 200 | rotate_degree = random.random() * 2 * self.degree - self.degree 201 | return (img.rotate(rotate_degree, Image.BILINEAR), 202 | mask.rotate(rotate_degree, Image.NEAREST), *inputs) 203 | 204 | 205 | class RandomSized(object): 206 | def __init__(self, size, min_scale=0.5, max_scale=2, *inputs, **kwargs): 207 | self.size = size 208 | self.min_scale = min_scale 209 | self.max_scale = max_scale 210 | self.scale = Scale(self.size) 211 | self.crop = RandomCrop(self.size) 212 | 213 | def __call__(self, img, mask, *inputs, **kwargs): 214 | assert img.size == mask.size 215 | 216 | w = 
int(random.uniform(self.min_scale, self.max_scale) * img.size[0]) 217 | h = int(random.uniform(self.min_scale, self.max_scale) * img.size[1]) 218 | 219 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 220 | 221 | return self.crop(*self.scale(img, mask, *inputs)) 222 | 223 | 224 | class RandomOcclusion(object): 225 | def __init__(self, build_prob=0.5, secondary_build_prob=0.99, occlusion_class=-1, start_points=5, min_size=100, 226 | *inputs, **kwargs): 227 | self.log = logging.getLogger(__name__) 228 | if build_prob > 1 or build_prob < 0: 229 | self.log.error('build_prob has to be between 0 and 1!') 230 | raise ValueError('build_prob has to be between 0 and 1!') 231 | if secondary_build_prob > 1 or secondary_build_prob < 0: 232 | self.log.error('secondary_build_prob has to be between 0 and 1!') 233 | raise ValueError('secondary_build_prob has to be between 0 and 1!') 234 | self.build_prob = build_prob 235 | self.secondary_build_prob = secondary_build_prob 236 | self.occlusion_class = occlusion_class 237 | self.start_points = start_points 238 | self.min_size = min_size 239 | 240 | def __call__(self, img, mask, *inputs, **kwargs): 241 | while (mask == self.occlusion_class).sum() < self.min_size: 242 | self.queue = [] 243 | self.flags = torch.full_like(mask, 0).byte() 244 | self.occlusion_map = torch.full_like(mask, 0).byte() 245 | self.img_height = img.shape[-2] 246 | self.img_width = img.shape[-1] 247 | 248 | # add first elements 249 | for _ in range(self.start_points): 250 | x = random.randint(0, self.img_height) 251 | y = random.randint(0, self.img_width) 252 | self.queue.append((x, y)) 253 | while len(self.queue) > 0: 254 | i, j = self.queue.pop(0) 255 | self._scan_neighborhood(i, j) 256 | 257 | if self.occlusion_map.sum().item() >= self.min_size: 258 | for c in range(img.shape[0]): 259 | img[c][self.occlusion_map] = 0 260 | mask[self.occlusion_map] = self.occlusion_class 261 | 262 | return (img, mask, *inputs) 263 | 264 | def _scan_neighborhood(self, i, j, *inputs, **kwargs): 265 | grid = [(i - 1, j - 1), 266 | (i - 1, j), 267 | (i - 1, j + 1), 268 | (i, j - 1), 269 | (i, j + 1), 270 | (i + 1, j - 1), 271 | (i + 1, j), 272 | (i + 1, j + 1)] 273 | if random.random() < self.build_prob: 274 | for ind in grid: 275 | if 0 <= ind[0] < self.img_height and 0 <= ind[1] < self.img_width: 276 | if self.flags[ind] == 0 and random.random() < self.secondary_build_prob: 277 | self.queue.append(ind) 278 | self.occlusion_map[ind] = 1 279 | self.flags[ind] = 1 280 | else: 281 | for ind in grid: 282 | if 0 <= ind[0] < self.img_height and 0 <= ind[1] < self.img_width: 283 | self.flags[ind] = 1 284 | 285 | 286 | class RandomNoise(object): 287 | def __init__(self, prob=0.5, ratio=0.1, *inputs, **kwargs): 288 | self.prob = prob 289 | self.ratio = ratio 290 | 291 | def __call__(self, image, mask, *inputs, **kwargs): 292 | if random.random() < self.prob: 293 | image = (1 - self.ratio) * image + self.ratio * torch.rand_like(image) 294 | return (image, mask, *inputs) 295 | 296 | 297 | class RandomNoiseImage(object): 298 | def __init__(self, prob=0.05, class_index=-1, *inputs, **kwargs): 299 | self.prob = prob 300 | self.class_index = class_index 301 | 302 | def __call__(self, image, mask, *inputs, **kwargs): 303 | if random.random() < self.prob: 304 | image = torch.rand_like(image) 305 | mask = torch.full_like(mask, self.class_index) 306 | return (image, mask, *inputs) 307 | 308 | 309 | class ToTensor(object): 310 | def __call__(self, image, mask, *inputs, **kwargs): 311 | t = 
trans.ToTensor() 312 | return (t(image), torch.tensor(np.array(mask, dtype=np.uint8), dtype=torch.long), 313 | tuple(torch.tensor(np.array(i, dtype=np.uint8), dtype=torch.long) for i in inputs)) 314 | 315 | def __repr__(self, *inputs, **kwargs): 316 | return self.__class__.__name__ + '()' 317 | 318 | 319 | class Normalize(object): 320 | def __init__(self, mean, std, *inputs, **kwargs): 321 | self.mean = mean 322 | self.std = std 323 | self.t = trans.Normalize(mean=self.mean, std=self.std) 324 | 325 | def __call__(self, tensor, mask, *inputs, **kwargs): 326 | return self.t(tensor), mask, tuple(i for i in inputs) 327 | 328 | 329 | class DeStandardize(object): 330 | def __init__(self, mean, std, *inputs, **kwargs): 331 | self.mean = mean 332 | self.std = std 333 | 334 | def __call__(self, tensor, mask, *inputs, **kwargs): 335 | for i in range(tensor.shape[0]): 336 | tensor[i] = tensor[i].mul(self.std[i]).add(self.mean[i]) 337 | return (tensor, mask, *inputs) 338 | 339 | 340 | class Lambda(object): 341 | def __init__(self, lambd, *inputs, **kwargs): 342 | assert isinstance(lambd, types.LambdaType) 343 | self.lambd = lambd 344 | 345 | def __call__(self, img, mask, *inputs, **kwargs): 346 | return self.lambd(img, mask, *inputs) 347 | 348 | def __repr__(self, *inputs, **kwargs): 349 | return self.__class__.__name__ + '()' 350 | -------------------------------------------------------------------------------- /src/model/wider_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/mapillary/inplace_abn/ 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, mapillary 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | import logging 36 | import sys 37 | from collections import OrderedDict 38 | from functools import partial 39 | import torch.nn as nn 40 | import torch 41 | from . 
import mynn 42 | 43 | def bnrelu(channels): 44 | """ 45 | Single Layer BN and Relui 46 | """ 47 | return nn.Sequential(mynn.Norm2d(channels), 48 | nn.ReLU(inplace=True)) 49 | 50 | class GlobalAvgPool2d(nn.Module): 51 | """ 52 | Global average pooling over the input's spatial dimensions 53 | """ 54 | 55 | def __init__(self): 56 | super(GlobalAvgPool2d, self).__init__() 57 | logging.debug("Global Average Pooling Initialized") 58 | 59 | def forward(self, inputs): 60 | in_size = inputs.size() 61 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 62 | 63 | 64 | class IdentityResidualBlock(nn.Module): 65 | """ 66 | Identity Residual Block for WideResnet 67 | """ 68 | def __init__(self, 69 | in_channels, 70 | channels, 71 | stride=1, 72 | dilation=1, 73 | groups=1, 74 | norm_act=bnrelu, 75 | dropout=None, 76 | dist_bn=False 77 | ): 78 | """Configurable identity-mapping residual block 79 | 80 | Parameters 81 | ---------- 82 | in_channels : int 83 | Number of input channels. 84 | channels : list of int 85 | Number of channels in the internal feature maps. 86 | Can either have two or three elements: if three construct 87 | a residual block with two `3 x 3` convolutions, 88 | otherwise construct a bottleneck block with `1 x 1`, then 89 | `3 x 3` then `1 x 1` convolutions. 90 | stride : int 91 | Stride of the first `3 x 3` convolution 92 | dilation : int 93 | Dilation to apply to the `3 x 3` convolutions. 94 | groups : int 95 | Number of convolution groups. 96 | This is used to create ResNeXt-style blocks and is only compatible with 97 | bottleneck blocks. 98 | norm_act : callable 99 | Function to create normalization / activation Module. 100 | dropout: callable 101 | Function to create Dropout Module. 102 | dist_bn: Boolean 103 | A variable to enable or disable use of distributed BN 104 | """ 105 | super(IdentityResidualBlock, self).__init__() 106 | self.dist_bn = dist_bn 107 | 108 | # Check if we are using distributed BN and use the nn from encoding.nn 109 | # library rather than using standard pytorch.nn 110 | 111 | 112 | # Check parameters for inconsistencies 113 | if len(channels) != 2 and len(channels) != 3: 114 | raise ValueError("channels must contain either two or three values") 115 | if len(channels) == 2 and groups != 1: 116 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 117 | 118 | is_bottleneck = len(channels) == 3 119 | need_proj_conv = stride != 1 or in_channels != channels[-1] 120 | 121 | self.bn1 = norm_act(in_channels) 122 | if not is_bottleneck: 123 | layers = [ 124 | ("conv1", nn.Conv2d(in_channels, 125 | channels[0], 126 | 3, 127 | stride=stride, 128 | padding=dilation, 129 | bias=False, 130 | dilation=dilation)), 131 | ("bn2", norm_act(channels[0])), 132 | ("conv2", nn.Conv2d(channels[0], channels[1], 133 | 3, 134 | stride=1, 135 | padding=dilation, 136 | bias=False, 137 | dilation=dilation)) 138 | ] 139 | if dropout is not None: 140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 141 | else: 142 | layers = [ 143 | ("conv1", 144 | nn.Conv2d(in_channels, 145 | channels[0], 146 | 1, 147 | stride=stride, 148 | padding=0, 149 | bias=False)), 150 | ("bn2", norm_act(channels[0])), 151 | ("conv2", nn.Conv2d(channels[0], 152 | channels[1], 153 | 3, stride=1, 154 | padding=dilation, bias=False, 155 | groups=groups, 156 | dilation=dilation)), 157 | ("bn3", norm_act(channels[1])), 158 | ("conv3", nn.Conv2d(channels[1], channels[2], 159 | 1, stride=1, padding=0, bias=False)) 160 | ] 161 | if dropout is not None: 162 | layers = layers[0:4] + 
[("dropout", dropout())] + layers[4:] 163 | self.convs = nn.Sequential(OrderedDict(layers)) 164 | 165 | if need_proj_conv: 166 | self.proj_conv = nn.Conv2d( 167 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 168 | 169 | def forward(self, x): 170 | """ 171 | This is the standard forward function for non-distributed batch norm 172 | """ 173 | if hasattr(self, "proj_conv"): 174 | bn1 = self.bn1(x) 175 | shortcut = self.proj_conv(bn1) 176 | else: 177 | shortcut = x.clone() 178 | bn1 = self.bn1(x) 179 | 180 | out = self.convs(bn1) 181 | out.add_(shortcut) 182 | return out 183 | 184 | 185 | 186 | 187 | class WiderResNet(nn.Module): 188 | """ 189 | WideResnet Global Module for Initialization 190 | """ 191 | def __init__(self, 192 | structure, 193 | norm_act=bnrelu, 194 | classes=0 195 | ): 196 | """Wider ResNet with pre-activation (identity mapping) blocks 197 | 198 | Parameters 199 | ---------- 200 | structure : list of int 201 | Number of residual blocks in each of the six modules of the network. 202 | norm_act : callable 203 | Function to create normalization / activation Module. 204 | classes : int 205 | If not `0` also include global average pooling and \ 206 | a fully-connected layer with `classes` outputs at the end 207 | of the network. 208 | """ 209 | super(WiderResNet, self).__init__() 210 | self.structure = structure 211 | 212 | if len(structure) != 6: 213 | raise ValueError("Expected a structure with six values") 214 | 215 | # Initial layers 216 | self.mod1 = nn.Sequential(OrderedDict([ 217 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 218 | ])) 219 | 220 | # Groups of residual blocks 221 | in_channels = 64 222 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), 223 | (512, 1024, 2048), (1024, 2048, 4096)] 224 | for mod_id, num in enumerate(structure): 225 | # Create blocks for module 226 | blocks = [] 227 | for block_id in range(num): 228 | blocks.append(( 229 | "block%d" % (block_id + 1), 230 | IdentityResidualBlock(in_channels, channels[mod_id], 231 | norm_act=norm_act) 232 | )) 233 | 234 | # Update channels and p_keep 235 | in_channels = channels[mod_id][-1] 236 | 237 | # Create module 238 | if mod_id <= 4: 239 | self.add_module("pool%d" % 240 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 241 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 242 | 243 | # Pooling and predictor 244 | self.bn_out = norm_act(in_channels) 245 | if classes != 0: 246 | self.classifier = nn.Sequential(OrderedDict([ 247 | ("avg_pool", GlobalAvgPool2d()), 248 | ("fc", nn.Linear(in_channels, classes)) 249 | ])) 250 | 251 | def forward(self, img): 252 | out = self.mod1(img) 253 | out = self.mod2(self.pool2(out)) 254 | out = self.mod3(self.pool3(out)) 255 | out = self.mod4(self.pool4(out)) 256 | out = self.mod5(self.pool5(out)) 257 | out = self.mod6(self.pool6(out)) 258 | out = self.mod7(out) 259 | out = self.bn_out(out) 260 | 261 | if hasattr(self, "classifier"): 262 | out = self.classifier(out) 263 | 264 | return out 265 | 266 | 267 | class WiderResNetA2(nn.Module): 268 | """ 269 | Wider ResNet with pre-activation (identity mapping) blocks 270 | 271 | This variant uses down-sampling by max-pooling in the first two blocks and 272 | by strided convolution in the others. 273 | 274 | Parameters 275 | ---------- 276 | structure : list of int 277 | Number of residual blocks in each of the six modules of the network. 278 | norm_act : callable 279 | Function to create normalization / activation Module. 
280 | classes : int 281 | If not `0` also include global average pooling and a fully-connected layer 282 | with `classes` outputs at the end 283 | of the network. 284 | dilation : bool 285 | If `True` apply dilation to the last three modules and change the 286 | down-sampling factor from 32 to 8. 287 | """ 288 | def __init__(self, 289 | structure, 290 | norm_act=bnrelu, 291 | classes=0, 292 | dilation=False, 293 | dist_bn=False 294 | ): 295 | super(WiderResNetA2, self).__init__() 296 | self.dist_bn = dist_bn 297 | 298 | # If using distributed batch norm, use the encoding.nn as oppose to torch.nn 299 | 300 | 301 | nn.Dropout = nn.Dropout2d 302 | norm_act = bnrelu 303 | self.structure = structure 304 | self.dilation = dilation 305 | 306 | if len(structure) != 6: 307 | raise ValueError("Expected a structure with six values") 308 | 309 | # Initial layers 310 | self.mod1 = torch.nn.Sequential(OrderedDict([ 311 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 312 | ])) 313 | 314 | # Groups of residual blocks 315 | in_channels = 64 316 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 317 | (1024, 2048, 4096)] 318 | for mod_id, num in enumerate(structure): 319 | # Create blocks for module 320 | blocks = [] 321 | for block_id in range(num): 322 | if not dilation: 323 | dil = 1 324 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1 325 | else: 326 | if mod_id == 3: 327 | dil = 2 328 | elif mod_id > 3: 329 | dil = 4 330 | else: 331 | dil = 1 332 | stride = 2 if block_id == 0 and mod_id == 2 else 1 333 | 334 | if mod_id == 4: 335 | drop = partial(nn.Dropout, p=0.3) 336 | elif mod_id == 5: 337 | drop = partial(nn.Dropout, p=0.5) 338 | else: 339 | drop = None 340 | 341 | blocks.append(( 342 | "block%d" % (block_id + 1), 343 | IdentityResidualBlock(in_channels, 344 | channels[mod_id], norm_act=norm_act, 345 | stride=stride, dilation=dil, 346 | dropout=drop, dist_bn=self.dist_bn) 347 | )) 348 | 349 | # Update channels and p_keep 350 | in_channels = channels[mod_id][-1] 351 | 352 | # Create module 353 | if mod_id < 2: 354 | self.add_module("pool%d" % 355 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 356 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 357 | 358 | # Pooling and predictor 359 | self.bn_out = norm_act(in_channels) 360 | if classes != 0: 361 | self.classifier = nn.Sequential(OrderedDict([ 362 | ("avg_pool", GlobalAvgPool2d()), 363 | ("fc", nn.Linear(in_channels, classes)) 364 | ])) 365 | 366 | def forward(self, img): 367 | out = self.mod1(img) 368 | out = self.mod2(self.pool2(out)) 369 | out = self.mod3(self.pool3(out)) 370 | out = self.mod4(out) 371 | out = self.mod5(out) 372 | out = self.mod6(out) 373 | out = self.mod7(out) 374 | out = self.bn_out(out) 375 | 376 | if hasattr(self, "classifier"): 377 | return self.classifier(out) 378 | return out 379 | 380 | 381 | _NETS = { 382 | "16": {"structure": [1, 1, 1, 1, 1, 1]}, 383 | "20": {"structure": [1, 1, 1, 3, 1, 1]}, 384 | "38": {"structure": [3, 3, 6, 3, 1, 1]}, 385 | } 386 | 387 | __all__ = [] 388 | for name, params in _NETS.items(): 389 | net_name = "wider_resnet" + name 390 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params)) 391 | __all__.append(net_name) 392 | for name, params in _NETS.items(): 393 | net_name = "wider_resnet" + name + "_a2" 394 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params)) 395 | __all__.append(net_name) 396 | 
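The two loops above register wider_resnet16/20/38 and their "_a2" counterparts in this module's namespace as functools.partial wrappers around WiderResNet and WiderResNetA2. Below is a minimal usage sketch, assuming the repository root is on the Python path; the constructor names and keyword arguments are taken from this file, but the call site itself is illustrative and may differ from how deepv3.py or the training scripts actually build the backbone.

import torch
from src.model import wider_resnet

# WiderResNet-38, variant A2, dilated, without the classification head
backbone = wider_resnet.wider_resnet38_a2(classes=0, dilation=True)
backbone.eval()

with torch.no_grad():
    # dense feature map: 4096 channels, spatially downsampled by a factor of 8
    feats = backbone(torch.randn(1, 3, 128, 256))
print(feats.shape)  # torch.Size([1, 4096, 16, 32]) for this input size

Passing classes=0 skips the global-average-pooling classifier, so the network ends at bn_out and returns the dense feature map that a segmentation head would consume.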
-------------------------------------------------------------------------------- /src/model/SEresnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/Cadene/pretrained-models.pytorch 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, Remi Cadene 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | from collections import OrderedDict 36 | import math 37 | import torch.nn as nn 38 | from torch.utils import model_zoo 39 | from . 
import mynn 40 | 41 | __all__ = ['SENet', 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 42 | 43 | pretrained_settings = { 44 | 'se_resnext50_32x4d': { 45 | 'imagenet': { 46 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 47 | 'input_space': 'RGB', 48 | 'input_size': [3, 224, 224], 49 | 'input_range': [0, 1], 50 | 'mean': [0.485, 0.456, 0.406], 51 | 'std': [0.229, 0.224, 0.225], 52 | 'num_classes': 1000 53 | } 54 | }, 55 | 'se_resnext101_32x4d': { 56 | 'imagenet': { 57 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 58 | 'input_space': 'RGB', 59 | 'input_size': [3, 224, 224], 60 | 'input_range': [0, 1], 61 | 'mean': [0.485, 0.456, 0.406], 62 | 'std': [0.229, 0.224, 0.225], 63 | 'num_classes': 1000 64 | } 65 | }, 66 | } 67 | 68 | 69 | class SEModule(nn.Module): 70 | """ 71 | Sequeeze Excitation Module 72 | """ 73 | def __init__(self, channels, reduction): 74 | super(SEModule, self).__init__() 75 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 76 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 77 | padding=0) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 80 | padding=0) 81 | self.sigmoid = nn.Sigmoid() 82 | 83 | def forward(self, x): 84 | module_input = x 85 | x = self.avg_pool(x) 86 | x = self.fc1(x) 87 | x = self.relu(x) 88 | x = self.fc2(x) 89 | x = self.sigmoid(x) 90 | return module_input * x 91 | 92 | 93 | class Bottleneck(nn.Module): 94 | """ 95 | Base class for bottlenecks that implements `forward()` method. 96 | """ 97 | def forward(self, x): 98 | residual = x 99 | 100 | out = self.conv1(x) 101 | out = self.bn1(out) 102 | out = self.relu(out) 103 | 104 | out = self.conv2(out) 105 | out = self.bn2(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv3(out) 109 | out = self.bn3(out) 110 | 111 | if self.downsample is not None: 112 | residual = self.downsample(x) 113 | 114 | out = self.se_module(out) + residual 115 | out = self.relu(out) 116 | 117 | return out 118 | 119 | 120 | class SEBottleneck(Bottleneck): 121 | """ 122 | Bottleneck for SENet154. 123 | """ 124 | expansion = 4 125 | 126 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 127 | downsample=None): 128 | super(SEBottleneck, self).__init__() 129 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 130 | self.bn1 = mynn.Norm2d(planes * 2) 131 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 132 | stride=stride, padding=1, groups=groups, 133 | bias=False) 134 | self.bn2 = mynn.Norm2d(planes * 4) 135 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 136 | bias=False) 137 | self.bn3 = mynn.Norm2d(planes * 4) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.se_module = SEModule(planes * 4, reduction=reduction) 140 | self.downsample = downsample 141 | self.stride = stride 142 | 143 | 144 | class SEResNetBottleneck(Bottleneck): 145 | """ 146 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 147 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 148 | (the latter is used in the torchvision implementation of ResNet). 
149 | """ 150 | expansion = 4 151 | 152 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 153 | downsample=None): 154 | super(SEResNetBottleneck, self).__init__() 155 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 156 | stride=stride) 157 | self.bn1 = mynn.Norm2d(planes) 158 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 159 | groups=groups, bias=False) 160 | self.bn2 = mynn.Norm2d(planes) 161 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 162 | self.bn3 = mynn.Norm2d(planes * 4) 163 | self.relu = nn.ReLU(inplace=True) 164 | self.se_module = SEModule(planes * 4, reduction=reduction) 165 | self.downsample = downsample 166 | self.stride = stride 167 | 168 | 169 | class SEResNeXtBottleneck(Bottleneck): 170 | """ 171 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 172 | """ 173 | expansion = 4 174 | 175 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 176 | downsample=None, base_width=4): 177 | super(SEResNeXtBottleneck, self).__init__() 178 | width = math.floor(planes * (base_width / 64)) * groups 179 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 180 | stride=1) 181 | self.bn1 = mynn.Norm2d(width) 182 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 183 | padding=1, groups=groups, bias=False) 184 | self.bn2 = mynn.Norm2d(width) 185 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 186 | self.bn3 = mynn.Norm2d(planes * 4) 187 | self.relu = nn.ReLU(inplace=True) 188 | self.se_module = SEModule(planes * 4, reduction=reduction) 189 | self.downsample = downsample 190 | self.stride = stride 191 | 192 | 193 | class SENet(nn.Module): 194 | """ 195 | Main Squeeze Excitation Network Module 196 | """ 197 | 198 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 199 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 200 | downsample_padding=1, num_classes=1000): 201 | """ 202 | Parameters 203 | ---------- 204 | block (nn.Module): Bottleneck class. 205 | - For SENet154: SEBottleneck 206 | - For SE-ResNet models: SEResNetBottleneck 207 | - For SE-ResNeXt models: SEResNeXtBottleneck 208 | layers (list of ints): Number of residual blocks for 4 layers of the 209 | network (layer1...layer4). 210 | groups (int): Number of groups for the 3x3 convolution in each 211 | bottleneck block. 212 | - For SENet154: 64 213 | - For SE-ResNet models: 1 214 | - For SE-ResNeXt models: 32 215 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 216 | - For all models: 16 217 | dropout_p (float or None): Drop probability for the Dropout layer. 218 | If `None` the Dropout layer is not used. 219 | - For SENet154: 0.2 220 | - For SE-ResNet models: None 221 | - For SE-ResNeXt models: None 222 | inplanes (int): Number of input channels for layer1. 223 | - For SENet154: 128 224 | - For SE-ResNet models: 64 225 | - For SE-ResNeXt models: 64 226 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 227 | a single 7x7 convolution in layer0. 228 | - For SENet154: True 229 | - For SE-ResNet models: False 230 | - For SE-ResNeXt models: False 231 | downsample_kernel_size (int): Kernel size for downsampling convolutions 232 | in layer2, layer3 and layer4. 233 | - For SENet154: 3 234 | - For SE-ResNet models: 1 235 | - For SE-ResNeXt models: 1 236 | downsample_padding (int): Padding for downsampling convolutions in 237 | layer2, layer3 and layer4. 
238 | - For SENet154: 1 239 | - For SE-ResNet models: 0 240 | - For SE-ResNeXt models: 0 241 | num_classes (int): Number of outputs in `last_linear` layer. 242 | - For all models: 1000 243 | """ 244 | super(SENet, self).__init__() 245 | self.inplanes = inplanes 246 | if input_3x3: 247 | layer0_modules = [ 248 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 249 | bias=False)), 250 | ('bn1', mynn.Norm2d(64)), 251 | ('relu1', nn.ReLU(inplace=True)), 252 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 253 | bias=False)), 254 | ('bn2', mynn.Norm2d(64)), 255 | ('relu2', nn.ReLU(inplace=True)), 256 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 257 | bias=False)), 258 | ('bn3', mynn.Norm2d(inplanes)), 259 | ('relu3', nn.ReLU(inplace=True)), 260 | ] 261 | else: 262 | layer0_modules = [ 263 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 264 | padding=3, bias=False)), 265 | ('bn1', mynn.Norm2d(inplanes)), 266 | ('relu1', nn.ReLU(inplace=True)), 267 | ] 268 | # To preserve compatibility with Caffe weights `ceil_mode=True` 269 | # is used instead of `padding=1`. 270 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 271 | ceil_mode=True))) 272 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 273 | self.layer1 = self._make_layer( 274 | block, 275 | planes=64, 276 | blocks=layers[0], 277 | groups=groups, 278 | reduction=reduction, 279 | downsample_kernel_size=1, 280 | downsample_padding=0 281 | ) 282 | self.layer2 = self._make_layer( 283 | block, 284 | planes=128, 285 | blocks=layers[1], 286 | stride=2, 287 | groups=groups, 288 | reduction=reduction, 289 | downsample_kernel_size=downsample_kernel_size, 290 | downsample_padding=downsample_padding 291 | ) 292 | self.layer3 = self._make_layer( 293 | block, 294 | planes=256, 295 | blocks=layers[2], 296 | stride=1, 297 | groups=groups, 298 | reduction=reduction, 299 | downsample_kernel_size=downsample_kernel_size, 300 | downsample_padding=downsample_padding 301 | ) 302 | self.layer4 = self._make_layer( 303 | block, 304 | planes=512, 305 | blocks=layers[3], 306 | stride=1, 307 | groups=groups, 308 | reduction=reduction, 309 | downsample_kernel_size=downsample_kernel_size, 310 | downsample_padding=downsample_padding 311 | ) 312 | self.avg_pool = nn.AvgPool2d(7, stride=1) 313 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 314 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 315 | 316 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 317 | downsample_kernel_size=1, downsample_padding=0): 318 | downsample = None 319 | if stride != 1 or self.inplanes != planes * block.expansion: 320 | downsample = nn.Sequential( 321 | nn.Conv2d(self.inplanes, planes * block.expansion, 322 | kernel_size=downsample_kernel_size, stride=stride, 323 | padding=downsample_padding, bias=False), 324 | mynn.Norm2d(planes * block.expansion), 325 | ) 326 | 327 | layers = [] 328 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 329 | downsample)) 330 | self.inplanes = planes * block.expansion 331 | for index in range(1, blocks): 332 | layers.append(block(self.inplanes, planes, groups, reduction)) 333 | 334 | return nn.Sequential(*layers) 335 | 336 | def features(self, x): 337 | """ 338 | Forward Pass through the each layer of SE network 339 | """ 340 | x = self.layer0(x) 341 | x = self.layer1(x) 342 | x = self.layer2(x) 343 | x = self.layer3(x) 344 | x = self.layer4(x) 345 | return x 346 | 347 | def logits(self, x): 348 | """ 349 | AvgPool and 
Linear Layer 350 | """ 351 | x = self.avg_pool(x) 352 | if self.dropout is not None: 353 | x = self.dropout(x) 354 | x = x.view(x.size(0), -1) 355 | x = self.last_linear(x) 356 | return x 357 | 358 | def forward(self, x): 359 | x = self.features(x) 360 | x = self.logits(x) 361 | return x 362 | 363 | 364 | def initialize_pretrained_model(model, num_classes, settings): 365 | """ 366 | Initialize Pretrain Model Information, 367 | Dowload weights, load weights, set variables 368 | """ 369 | assert num_classes == settings['num_classes'], \ 370 | 'num_classes should be {}, but is {}'.format( 371 | settings['num_classes'], num_classes) 372 | weights = model_zoo.load_url(settings['url']) 373 | model.load_state_dict(weights) 374 | model.input_space = settings['input_space'] 375 | model.input_size = settings['input_size'] 376 | model.input_range = settings['input_range'] 377 | model.mean = settings['mean'] 378 | model.std = settings['std'] 379 | 380 | 381 | 382 | def se_resnext50_32x4d(num_classes=1000): 383 | """ 384 | Defination For SE Resnext50 385 | """ 386 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 387 | dropout_p=None, inplanes=64, input_3x3=False, 388 | downsample_kernel_size=1, downsample_padding=0, 389 | num_classes=num_classes) 390 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet'] 391 | initialize_pretrained_model(model, num_classes, settings) 392 | return model 393 | 394 | 395 | def se_resnext101_32x4d(num_classes=1000): 396 | """ 397 | Defination For SE Resnext101 398 | """ 399 | 400 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 401 | dropout_p=None, inplanes=64, input_3x3=False, 402 | downsample_kernel_size=1, downsample_padding=0, 403 | num_classes=num_classes) 404 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet'] 405 | initialize_pretrained_model(model, num_classes, settings) 406 | return model 407 | -------------------------------------------------------------------------------- /src/metaseg/metrics.pyx: -------------------------------------------------------------------------------- 1 | """ 2 | NOTE: 3 | Setup this file first via metrics_setup.py 4 | "python3 metrics_setup.py build_ext --inplace" 5 | """ 6 | 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | 11 | 12 | 13 | 14 | def entropy( probs ): 15 | E = np.sum( np.multiply( probs, np.log(probs+np.finfo(np.float32).eps) ) , axis=-1) / np.log(1.0/probs.shape[-1]) 16 | return np.asarray( E, dtype="float32" ) 17 | 18 | 19 | 20 | def probdist( probs ): 21 | cdef int i, j 22 | arrayA = np.asarray(np.argsort(probs,axis=-1), dtype="uint8") 23 | arrayD = np.ones( probs.shape[:-1], dtype="float32" ) 24 | cdef float[:,:,:] P = probs 25 | cdef float[:,:] D = arrayD 26 | cdef char[:,:,:] A = arrayA 27 | for i in range( arrayD.shape[0] ): 28 | for j in range( arrayD.shape[1] ): 29 | D[i,j] = ( 1 - P[ i, j, A[i,j,-1] ] + P[ i, j, A[i,j,-2] ] ) 30 | return arrayD 31 | 32 | 33 | def varrat( probs ): 34 | cdef int i, j 35 | arrayA = np.asarray(np.argsort(probs,axis=-1), dtype="uint8") 36 | arrayV = np.ones( probs.shape[:-1], dtype="float32" ) 37 | cdef float[:,:,:] P = probs 38 | cdef float[:,:] V = arrayV 39 | cdef char[:,:,:] A = arrayA 40 | for i in range( arrayV.shape[0] ): 41 | for j in range( arrayV.shape[1] ): 42 | V[i,j] = ( 1 - P[ i, j, A[i,j,-1] ] ) 43 | return arrayV 44 | 45 | 46 | 47 | def prediction(probs, gt=None, ignore=False ): 48 | pred = np.asarray( np.argmax( probs, axis=-1 ), dtype="uint8" ) 49 | if ignore and gt is not 
None: 50 | pred[ gt==255 ] = 255 51 | return pred 52 | 53 | 54 | 55 | 56 | def segment_search( int i, int j, unsigned char[:,:] seg, short int seg_ind, unsigned char[:,:] pred, 57 | np.ndarray marked_array, np.ndarray flag_array, float[:,:,:] probs, unsigned char[:,:] gt, 58 | heatmaps, metrics, unsigned short int[:,:] members_k, 59 | unsigned short int[:,:] members_l, int nclasses, int x_max, int y_max ): 60 | 61 | cdef int k, l, ii, jj, x, y, n_in, n_bd, c, I, U, flag_max_x, flag_min_x, flag_max_y, flag_min_y, ic 62 | cdef unsigned char[:,:] flag 63 | cdef short int[:,:] marked 64 | 65 | if seg[i,j] != 0: 66 | n_in, n_bd = 0, 0 67 | c = seg[i,j] 68 | members_k[0,0], members_k[0,1] = i, j 69 | marked = marked_array 70 | 71 | flag_min_x = flag_max_x = i 72 | flag_min_y = flag_max_y = j 73 | 74 | flag = flag_array 75 | flag[i,j] = 1 76 | I, U = 0, 0 77 | marked[i,j] = seg_ind 78 | 79 | for m in metrics: 80 | metrics[m].append( 0 ) 81 | 82 | # go through union of current segment and corresponding ground truth 83 | # and identify all inner pixels, boundary pixels and 84 | # pixels where ground_truth and prediction match 85 | k = 1 86 | l = 0 87 | num_neighbors = 0 88 | while k > 0 or l > 0: 89 | 90 | flag_k = 0 91 | 92 | if k > 0: 93 | k -= 1 94 | x, y = members_k[k] 95 | flag_k = 1 96 | elif l > 0: 97 | l -= 1 98 | x, y = members_l[l] 99 | 100 | if flag_k: 101 | for ii in range(max(x-1,0),min(x+2,x_max)): 102 | for jj in range(max(y-1,0),min(y+2,y_max)): 103 | if seg[ii,jj] == c and marked[ii,jj] == 0: 104 | marked[ii,jj] = seg_ind 105 | flag[ii,jj] = 1 106 | if ii > flag_max_x: 107 | flag_max_x = ii 108 | elif ii < flag_min_x: 109 | flag_min_x = ii 110 | if jj > flag_max_y: 111 | flag_max_y = jj 112 | elif jj < flag_min_y: 113 | flag_min_y = jj 114 | members_k[k,0], members_k[k,1] = ii, jj 115 | k += 1 116 | elif seg[ii,jj] != c: 117 | # if seg[ii,jj] != 255: 118 | metrics["ndist"+str(pred[ii,jj])][-1] = metrics["ndist"+str(pred[ii,jj])][-1]+1 119 | num_neighbors += 1 120 | marked[x,y] = -seg_ind 121 | if gt != []: 122 | if gt[ii,jj] == c and flag[ii,jj]==0: 123 | flag[ii,jj] = 1 124 | if ii > flag_max_x: 125 | flag_max_x = ii 126 | elif ii < flag_min_x: 127 | flag_min_x = ii 128 | if jj > flag_max_y: 129 | flag_max_y = jj 130 | elif jj < flag_min_y: 131 | flag_min_y = jj 132 | members_l[l,0], members_l[l,1] = ii, jj 133 | l += 1 134 | 135 | if not flag_k and gt != []: 136 | if I == 0: 137 | break 138 | for ii in range(max(x-1,0),min(x+2,x_max)): 139 | for jj in range(max(y-1,0),min(y+2,y_max)): 140 | #if gt[ii,jj] == c and flag[ii,jj]==0: # ordinary IoU 141 | if gt[ii,jj] == c and flag[ii,jj]==0 and seg[ii,jj] != c: # IoU_adj 142 | flag[ii,jj] = 1 143 | if ii > flag_max_x: 144 | flag_max_x = ii 145 | elif ii < flag_min_x: 146 | flag_min_x = ii 147 | if jj > flag_max_y: 148 | flag_max_y = jj 149 | elif jj < flag_min_y: 150 | flag_min_y = jj 151 | members_l[l,0], members_l[l,1] = ii, jj 152 | l += 1 153 | 154 | if flag_k: 155 | if marked[x,y] in [seg_ind,-seg_ind]: 156 | # update heap maps 157 | if marked[x,y] == seg_ind: 158 | for h in heatmaps: 159 | metrics[h+"_in"][-1] += heatmaps[h][x,y] 160 | metrics[h+"_var_in"][-1] += heatmaps[h][x,y]**2 161 | n_in += 1 162 | elif marked[x,y] == -seg_ind: 163 | for h in heatmaps: 164 | metrics[h+"_bd"][-1] += heatmaps[h][x,y] 165 | metrics[h+"_var_bd"][-1] += heatmaps[h][x,y]**2 166 | n_bd += 1 167 | for ic in range(nclasses): 168 | metrics["cprob"+str(ic)][-1] += probs[x,y,ic] 169 | metrics["mean_x"][-1] += x 170 | metrics["mean_y"][-1] += y 171 | 
if gt != []: 172 | if gt[x,y] == c: 173 | I += 1 174 | 175 | U += 1 176 | 177 | for ii in range(flag_min_x,flag_max_x+1): 178 | for jj in range(flag_min_y,flag_max_y+1): 179 | flag[ii,jj] = 0 180 | 181 | # compute all metrics 182 | # metrics["class" ][-1] = c 183 | if gt != []: 184 | metrics["iou" ][-1] = float(I) / float(U) 185 | metrics["iou0" ][-1] = int(I == 0) 186 | else: 187 | metrics["iou" ][-1] = -1 188 | metrics["iou0" ][-1] = -1 189 | metrics["S" ][-1] = n_in + n_bd 190 | metrics["S_in" ][-1] = n_in 191 | metrics["S_bd" ][-1] = n_bd 192 | metrics["S_rel" ][-1] = float( n_in + n_bd ) / float(n_bd) 193 | metrics["S_rel_in"][-1] = float( n_in ) / float(n_bd) 194 | metrics["mean_x"][-1] /= ( n_in + n_bd ) 195 | metrics["mean_y"][-1] /= ( n_in + n_bd ) 196 | 197 | for nc in range(nclasses): 198 | metrics["cprob"+str(nc)][-1] /= ( n_in + n_bd ) 199 | 200 | for nc in range(nclasses): 201 | metrics["ndist"+str(nc)][-1] /= float(np.max((num_neighbors,1))) 202 | 203 | for h in heatmaps: 204 | metrics[h ][-1] = (metrics[h+"_in"][-1] + metrics[h+"_bd"][-1]) / float( n_in + n_bd ) 205 | if n_in > 0: 206 | metrics[ h+"_in"][-1] /= float(n_in) 207 | metrics[h+"_bd" ][-1] /= float(n_bd) 208 | metrics[h+"_var" ][-1] = (metrics[h+"_var_in"][-1] + metrics[h+"_var_bd"][-1]) / float( n_in + n_bd ) - (metrics[h][-1] **2 ) 209 | if n_in > 0: 210 | metrics[h+"_var_in"][-1] = metrics[h+"_var_in"][-1] / float(n_in) - metrics[h+"_in"][-1]**2 211 | metrics[h+"_var_bd"][-1] = metrics[h+"_var_bd"][-1] / float(n_bd) - metrics[h+"_bd"][-1]**2 212 | metrics[h+"_rel" ][-1] = metrics[h ][-1] * metrics["S_rel" ][-1] 213 | metrics[h+"_rel_in"][-1] = metrics[h+"_in"][-1] * metrics["S_rel_in"][-1] 214 | metrics[h+"_var_rel" ][-1] = metrics[h+"_var" ][-1] * metrics["S_rel" ][-1] 215 | metrics[h+"_var_rel_in"][-1] = metrics[h+"_var_in"][-1] * metrics["S_rel_in"][-1] 216 | 217 | seg_ind +=1 218 | 219 | return marked_array, metrics, seg_ind 220 | 221 | 222 | def compute_metrics_components( probs, gt_train, ood_mask=None, ood_index=None ): 223 | 224 | cdef int i, j 225 | cdef short int seg_ind 226 | cdef np.ndarray marked 227 | cdef np.ndarray members_k 228 | cdef np.ndarray members_l 229 | cdef short int[:,:] M 230 | 231 | gt_train = np.asarray( gt_train, dtype="uint8" ) 232 | probs = np.asarray( np.transpose(probs, (1, 2, 0)), dtype="float32" ) 233 | pred = np.asarray( prediction(probs, gt=gt_train, ignore=False), dtype="uint8" ) 234 | nclasses = probs.shape[-1] 235 | dims = np.asarray( probs.shape[:-1], dtype="uint16" ) 236 | if ood_mask is not None and ood_index is not None: 237 | seg = np.asarray( ood_mask * 255, dtype="uint8" ) 238 | seg[gt_train==255] = 0 239 | gt = np.asarray( np.isin(gt_train, ood_index) * 255, dtype="uint8" ) 240 | else: 241 | seg = pred 242 | gt = gt_train 243 | 244 | marked = np.zeros( dims, dtype="int16" ) 245 | members_k = np.zeros( (np.prod(dims), 2 ), dtype="uint16" ) 246 | members_l = np.zeros( (np.prod(dims), 2 ), dtype="uint16" ) 247 | flag = np.zeros( dims, dtype="uint8" ) 248 | M = marked 249 | 250 | metrics = {} 251 | keys = ["iou", "iou0", "prc", "mean_x", "mean_y"] 252 | for key in keys: 253 | metrics[key] = list([]) 254 | 255 | heatmaps = { "E": entropy( probs ), "D": probdist( probs ), "V": varrat( probs ) } 256 | for m in list(heatmaps)+["S"]: 257 | metrics[m ] = list([]) 258 | metrics[m+"_in" ] = list([]) 259 | metrics[m+"_bd" ] = list([]) 260 | metrics[m+"_rel" ] = list([]) 261 | metrics[m+"_rel_in"] = list([]) 262 | 263 | if m != "S": 264 | metrics[m+"_var" ] = list([]) 265 | 
metrics[m+"_var_in" ] = list([]) 266 | metrics[m+"_var_bd" ] = list([]) 267 | metrics[m+"_var_rel" ] = list([]) 268 | metrics[m+"_var_rel_in"] = list([]) 269 | 270 | 271 | for i in range(nclasses): 272 | metrics['cprob'+str(i)] = list([]) 273 | metrics['ndist'+str(i)] = list([]) 274 | 275 | seg_ind = 1 276 | 277 | for i in range(dims[0]): 278 | for j in range(dims[1]): 279 | if M[i,j] == 0: 280 | 281 | marked, metrics, seg_ind = segment_search( i, j, seg, seg_ind, pred, marked, flag, probs, gt, heatmaps, 282 | metrics, members_k, members_l, nclasses, dims[0], dims[1] ) 283 | 284 | 285 | return metrics, marked 286 | 287 | 288 | 289 | def segment_search_slim( int i, int j, unsigned char[:,:] seg1, unsigned char[:,:] label_mask, short int seg1_ind, 290 | np.ndarray marked_array, np.ndarray flag_array, unsigned char[:,:] seg2, metrics, 291 | unsigned short int[:,:] members_k, unsigned short int[:,:] members_l, 292 | int x_max, int y_max, int nclasses, float[:,:,:] probs, unsigned char[:,:] pred ): 293 | 294 | cdef int k, l, ii, jj, x, y, n_in, n_bd, c, I, U, flag_max_x, flag_min_x, flag_max_y, flag_min_y, ic 295 | cdef unsigned char[:,:] flag 296 | cdef short int[:,:] marked 297 | 298 | if seg1[i,j] != 0: 299 | 300 | n_in, n_bd = 0, 0 301 | c = seg1[i,j] 302 | c_label = label_mask[i,j] 303 | members_k[0,0], members_k[0,1] = i, j 304 | marked = marked_array 305 | 306 | flag_min_x = flag_max_x = i 307 | flag_min_y = flag_max_y = j 308 | 309 | flag = flag_array 310 | flag[i,j] = 1 311 | I, U = 0, 0 312 | marked[i,j] = seg1_ind 313 | 314 | for m in metrics: 315 | metrics[m].append( 0 ) 316 | 317 | k = 1 318 | l = 0 319 | while k > 0 or l > 0: 320 | 321 | flag_k = 0 322 | 323 | if k > 0: 324 | k -= 1 325 | x, y = members_k[k] 326 | flag_k = 1 327 | elif l > 0: 328 | l -= 1 329 | x, y = members_l[l] 330 | 331 | if flag_k: 332 | for ii in range(max(x-1,0),min(x+2,x_max)): 333 | for jj in range(max(y-1,0),min(y+2,y_max)): 334 | if seg1[ii,jj] == c and marked[ii,jj] == 0: 335 | marked[ii,jj] = seg1_ind 336 | flag[ii,jj] = 1 337 | if ii > flag_max_x: 338 | flag_max_x = ii 339 | elif ii < flag_min_x: 340 | flag_min_x = ii 341 | if jj > flag_max_y: 342 | flag_max_y = jj 343 | elif jj < flag_min_y: 344 | flag_min_y = jj 345 | members_k[k,0], members_k[k,1] = ii, jj 346 | k += 1 347 | elif seg1[ii,jj] != c: 348 | marked[x,y] = -seg1_ind 349 | if seg2 != []: 350 | if seg2[ii,jj] == c and flag[ii,jj]==0: 351 | flag[ii,jj] = 1 352 | if ii > flag_max_x: 353 | flag_max_x = ii 354 | elif ii < flag_min_x: 355 | flag_min_x = ii 356 | if jj > flag_max_y: 357 | flag_max_y = jj 358 | elif jj < flag_min_y: 359 | flag_min_y = jj 360 | members_l[l,0], members_l[l,1] = ii, jj 361 | l += 1 362 | 363 | if not flag_k and seg2 != []: 364 | if I == 0: 365 | break 366 | for ii in range(max(x-1,0),min(x+2,x_max)): 367 | for jj in range(max(y-1,0),min(y+2,y_max)): 368 | #if seg2[ii,jj] == c and flag[ii,jj]==0: # ordinary IoU 369 | if seg2[ii,jj] == c and flag[ii,jj]==0 and seg1[ii,jj] != c: # IoU_adj 370 | flag[ii,jj] = 1 371 | if ii > flag_max_x: 372 | flag_max_x = ii 373 | elif ii < flag_min_x: 374 | flag_min_x = ii 375 | if jj > flag_max_y: 376 | flag_max_y = jj 377 | elif jj < flag_min_y: 378 | flag_min_y = jj 379 | members_l[l,0], members_l[l,1] = ii, jj 380 | l += 1 381 | 382 | if flag_k: 383 | if marked[x,y] in [seg1_ind,-seg1_ind]: 384 | if marked[x,y] == seg1_ind: 385 | n_in += 1 386 | elif marked[x,y] == -seg1_ind: 387 | n_bd += 1 388 | for ic in range(nclasses): 389 | metrics["cprob"+str(ic)][-1] += probs[x,y,ic] 390 | 
metrics["S_"+str(pred[x,y])][-1] += 1 391 | metrics["mean_x"][-1] += x 392 | metrics["mean_y"][-1] += y 393 | if seg2 != []: 394 | if seg2[x,y] == c: 395 | I += 1 396 | 397 | U += 1 398 | 399 | for ii in range(flag_min_x,flag_max_x+1): 400 | for jj in range(flag_min_y,flag_max_y+1): 401 | flag[ii,jj] = 0 402 | 403 | # compute all metrics 404 | metrics["train_class"][-1] = c #### predicted class, train id 405 | metrics["label_class"][-1] = c_label #### label id 406 | if seg2 != []: 407 | metrics["iou"][-1] = float(I) / float(U) 408 | metrics["iou0"][-1] = int(I == 0) 409 | else: 410 | metrics["iou"][-1] = -1 411 | metrics["iou0"][-1] = -1 412 | metrics["rec" ][-1] = float(I) / float(n_in + n_bd) #### recall if seg1==gt 413 | metrics["S" ][-1] = n_in + n_bd 414 | metrics["S_in" ][-1] = n_in 415 | metrics["S_bd" ][-1] = n_bd 416 | # metrics["S_rel" ][-1] = float( n_in + n_bd ) / float(n_bd) 417 | # metrics["S_rel_in"][-1] = float( n_in ) / float(n_bd) 418 | metrics["mean_x"][-1] /= ( n_in + n_bd ) 419 | metrics["mean_y"][-1] /= ( n_in + n_bd ) 420 | 421 | for nc in range(nclasses): 422 | metrics["cprob"+str(nc)][-1] /= ( n_in + n_bd ) 423 | 424 | seg1_ind +=1 425 | 426 | return marked_array, metrics, seg1_ind 427 | 428 | 429 | 430 | def compute_metrics_mask( probs, mask, gt_train, gt_label, ood_index=None ): 431 | 432 | cdef int i, j 433 | cdef short int seg_ind 434 | cdef np.ndarray marked 435 | cdef np.ndarray members_k 436 | cdef np.ndarray members_l 437 | cdef short int[:,:] M 438 | 439 | dims = np.asarray( mask.shape, dtype="uint16" ) 440 | mask = np.asarray( mask, dtype="uint8" ) 441 | gt_train = np.asarray( gt_train, dtype="uint8" ) 442 | gt_label = np.asarray( gt_label, dtype="uint8" ) 443 | probs = np.asarray( np.transpose(probs, (1, 2, 0)), dtype="float32" ) 444 | pred = np.asarray( prediction(probs, gt=gt_train, ignore=False), dtype="uint8" ) 445 | nclasses = probs.shape[-1] 446 | marked = np.zeros( dims, dtype="int16" ) 447 | members_k = np.zeros( (np.prod(dims), 2 ), dtype="uint16" ) 448 | members_l = np.zeros( (np.prod(dims), 2 ), dtype="uint16" ) 449 | flag = np.zeros( dims, dtype="uint8" ) 450 | M = marked 451 | 452 | if ood_index is not None: 453 | seg = np.asarray( mask * 255, dtype="uint8" ) 454 | seg[gt_train==255] = 0 455 | gt = np.asarray( np.isin(gt_train, ood_index) * 255, dtype="uint8" ) 456 | else: 457 | seg = mask 458 | gt = gt_train 459 | 460 | metrics = {} 461 | keys = ["iou", "iou0", "rec", "train_class", "label_class", "mean_x", "mean_y", "S", "S_in", "S_bd"] 462 | for key in keys: 463 | metrics[key] = list([]) 464 | for i in range(nclasses): 465 | metrics['S_'+str(i)] = list([]) 466 | metrics['cprob'+str(i)] = list([]) 467 | 468 | 469 | seg_ind = 1 470 | for i in range(dims[0]): 471 | for j in range(dims[1]): 472 | if M[i,j] == 0: 473 | marked, metrics, seg_ind = segment_search_slim( i, j, gt, gt_label, seg_ind, marked, flag, seg, metrics, 474 | members_k, members_l, dims[0], dims[1], nclasses, probs, pred) 475 | 476 | return metrics, marked --------------------------------------------------------------------------------