├── Dockerfile ├── LICENSE ├── README.md ├── config.py ├── datasets ├── __init__.py ├── cityscapes.py ├── cityscapes_labels.py └── edge_utils.py ├── docs ├── index.html └── resources │ ├── GSCNN.mp4 │ ├── architecture.jpg │ ├── bibtex.txt │ ├── crop.jpg │ ├── edges.jpg │ ├── gscnn.gif │ ├── intro.jpg │ ├── seg.jpg │ ├── semboundary.jpg │ ├── table.png │ ├── test.jpg │ └── top.jpg ├── loss.py ├── my_functionals ├── DualTaskLoss.py ├── GatedSpatialConv.py ├── __init__.py └── custom_functional.py ├── network ├── Resnet.py ├── SEresnext.py ├── __init__.py ├── gscnn.py ├── mynn.py └── wider_resnet.py ├── optimizer.py ├── train.py ├── transforms ├── joint_transforms.py └── transforms.py └── utils ├── AttrDict.py ├── f_boundary.py ├── image_page.py └── misc.py /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.0-cuda10.0-cudnn7-devel 2 | 3 | RUN apt-get -y update 4 | RUN apt-get -y upgrade 5 | 6 | RUN apt-get update \ 7 | && apt-get install -y software-properties-common wget \ 8 | && add-apt-repository -y ppa:ubuntu-toolchain-r/test \ 9 | && apt-get update \ 10 | && apt-get install -y make git curl vim vim-gnome 11 | 12 | # Install apt-get 13 | RUN apt-get install -y python3-pip python3-dev vim htop python3-tk pkg-config 14 | 15 | RUN pip3 install --upgrade pip==9.0.1 16 | 17 | # Install from pip 18 | RUN pip3 install pyyaml \ 19 | scipy==1.1.0 \ 20 | numpy \ 21 | tensorflow \ 22 | scikit-learn \ 23 | scikit-image \ 24 | matplotlib \ 25 | opencv-python \ 26 | torch==1.0.0 \ 27 | torchvision==0.2.0 \ 28 | torch-encoding==1.0.1 \ 29 | tensorboardX \ 30 | tqdm 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 NVIDIA Corporation. Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler 2 | All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | Permission to use, copy, modify, and distribute this software and its documentation 6 | for any non-commercial purpose is hereby granted without fee, provided that the above 7 | copyright notice appear in all copies and that both that copyright notice and this 8 | permission notice appear in supporting documentation, and that the name of the author 9 | not be used in advertising or publicity pertaining to distribution of the software 10 | without specific, written prior permission. 11 | 12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GSCNN 2 | This is the official code for: 3 | 4 | #### Gated-SCNN: Gated Shape CNNs for Semantic Segmentation 5 | 6 | [Towaki Takikawa](https://tovacinni.github.io), [David Acuna](http://www.cs.toronto.edu/~davidj/), [Varun Jampani](https://varunjampani.github.io), [Sanja Fidler](http://www.cs.toronto.edu/~fidler/) 7 | 8 | ICCV 2019 9 | **[[Paper](https://arxiv.org/abs/1907.05740)] [[Project Page](https://nv-tlabs.github.io/GSCNN/)]** 10 | 11 | ![GSCNN DEMO](docs/resources/gscnn.gif) 12 | 13 | Based on https://github.com/NVIDIA/semantic-segmentation. 14 | 15 | ## License 16 | ``` 17 | Copyright (C) 2019 NVIDIA Corporation. Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler 18 | All rights reserved. 19 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 20 | 21 | Permission to use, copy, modify, and distribute this software and its documentation 22 | for any non-commercial purpose is hereby granted without fee, provided that the above 23 | copyright notice appear in all copies and that both that copyright notice and this 24 | permission notice appear in supporting documentation, and that the name of the author 25 | not be used in advertising or publicity pertaining to distribution of the software 26 | without specific, written prior permission. 27 | 28 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 29 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 30 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 31 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 32 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 33 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 34 | 35 | ``` 36 | 37 | ## Usage 38 | 39 | ##### Clone this repo 40 | ```bash 41 | git clone https://github.com/nv-tlabs/GSCNN 42 | cd GSCNN 43 | ``` 44 | 45 | #### Python requirements 46 | 47 | Currently, the code supports Python 3 with the following packages: 48 | * numpy 49 | * PyTorch (>=1.1.0) 50 | * torchvision 51 | * scipy 52 | * scikit-image 53 | * tensorboardX 54 | * tqdm 55 | * torch-encoding 56 | * opencv 57 | * PyYAML 58 | 59 | #### Download pretrained models 60 | 61 | Download the pretrained model from the [Google Drive Folder](https://drive.google.com/file/d/1wlhAXg-PfoUM-rFy2cksk43Ng3PpsK2c/view), and save it in 'checkpoints/'. 62 | 63 | #### Download inferred images 64 | 65 | Download (if needed) the inferred images from the [Google Drive Folder](https://drive.google.com/file/d/105WYnpSagdlf5-ZlSKWkRVeq-MyKLYOV/view). 66 | 67 | #### Evaluation (Cityscapes) 68 | ```bash 69 | python train.py --evaluate --snapshot checkpoints/best_cityscapes_checkpoint.pth 70 | ``` 71 | 72 | #### Training 73 | 74 | A note on training: we train on 8 NVIDIA GPUs, so training with the WiderResNet38 backbone will be an issue if you try to train on a single GPU.
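For reference, the binary boundary ground truth consumed by the edge and dual-task losses is generated on the fly from the segmentation labels by `datasets/edge_utils.py`. Below is a minimal, illustrative sketch of that conversion on a toy mask; it assumes the Python requirements above are installed and is not part of the training pipeline itself:

```python
import numpy as np
from datasets import edge_utils

num_classes = 19

# Toy (H, W) mask of Cityscapes train IDs: two regions labelled 1 and 2.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[:, :32] = 1
mask[:, 32:] = 2

onehot = edge_utils.mask_to_onehot(mask, num_classes)              # (K, H, W) one-hot mask
edges = edge_utils.onehot_to_binary_edges(onehot, 2, num_classes)  # (1, H, W) binary edge map, 2-pixel radius
print(edges.shape, int(edges.sum()))                               # non-zero pixels lie along the region boundary
```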
75 | 76 | If you use this code, please cite: 77 | 78 | ``` 79 | @article{takikawa2019gated, 80 | title={Gated-SCNN: Gated Shape CNNs for Semantic Segmentation}, 81 | author={Takikawa, Towaki and Acuna, David and Jampani, Varun and Fidler, Sanja}, 82 | journal={ICCV}, 83 | year={2019} 84 | } 85 | ``` 86 | 87 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code adapted from: 6 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py 7 | 8 | Source License 9 | # Copyright (c) 2017-present, Facebook, Inc. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | ############################################################################## 23 | # 24 | # Based on: 25 | # -------------------------------------------------------- 26 | # Fast R-CNN 27 | # Copyright (c) 2015 Microsoft 28 | # Licensed under The MIT License [see LICENSE for details] 29 | # Written by Ross Girshick 30 | # -------------------------------------------------------- 31 | """ 32 | 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | from __future__ import unicode_literals 37 | 38 | import copy 39 | import six 40 | import os.path as osp 41 | 42 | from ast import literal_eval 43 | import numpy as np 44 | import yaml 45 | import torch 46 | import torch.nn as nn 47 | from torch.nn import init 48 | 49 | 50 | from utils.AttrDict import AttrDict 51 | 52 | 53 | __C = AttrDict() 54 | # Consumers can get config by: 55 | # from fast_rcnn_config import cfg 56 | cfg = __C 57 | __C.EPOCH = 0 58 | __C.CLASS_UNIFORM_PCT=0.0 59 | __C.BATCH_WEIGHTING=False 60 | __C.BORDER_WINDOW=1 61 | __C.REDUCE_BORDER_EPOCH= -1 62 | __C.STRICTBORDERCLASS= None 63 | 64 | __C.DATASET =AttrDict() 65 | __C.DATASET.CITYSCAPES_DIR='/home/username/data/cityscapes' 66 | __C.DATASET.CV_SPLITS=3 67 | 68 | __C.MODEL = AttrDict() 69 | __C.MODEL.BN = 'regularnorm' 70 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 71 | __C.MODEL.BIGMEMORY = False 72 | 73 | def assert_and_infer_cfg(args, make_immutable=True): 74 | """Call this function in your script after you have finished setting all cfg 75 | values that are necessary (e.g., merging a config from a file, merging 76 | command line config options, etc.). By default, this function will also 77 | mark the global cfg as immutable to prevent changing the global cfg settings 78 | during script execution (which can lead to hard to debug errors or code 79 | that's harder to understand than is necessary). 
80 | """ 81 | 82 | if args.batch_weighting: 83 | __C.BATCH_WEIGHTING=True 84 | 85 | if args.syncbn: 86 | import encoding 87 | __C.MODEL.BN = 'syncnorm' 88 | __C.MODEL.BNFUNC = encoding.nn.BatchNorm2d 89 | else: 90 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 91 | print('Using regular batch norm') 92 | 93 | if make_immutable: 94 | cfg.immutable(True) 95 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from datasets import cityscapes 7 | import torchvision.transforms as standard_transforms 8 | import torchvision.utils as vutils 9 | import transforms.joint_transforms as joint_transforms 10 | import transforms.transforms as extended_transforms 11 | from torch.utils.data import DataLoader 12 | 13 | def setup_loaders(args): 14 | ''' 15 | input: argument passed by the user 16 | return: training data loader, validation data loader loader, train_set 17 | ''' 18 | 19 | if args.dataset == 'cityscapes': 20 | args.dataset_cls = cityscapes 21 | args.train_batch_size = args.bs_mult * args.ngpu 22 | if args.bs_mult_val > 0: 23 | args.val_batch_size = args.bs_mult_val * args.ngpu 24 | else: 25 | args.val_batch_size = args.bs_mult * args.ngpu 26 | else: 27 | raise 28 | 29 | args.num_workers = 4 * args.ngpu 30 | if args.test_mode: 31 | args.num_workers = 0 #1 32 | 33 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 34 | 35 | # Geometric image transformations 36 | train_joint_transform_list = [ 37 | joint_transforms.RandomSizeAndCrop(args.crop_size, 38 | False, 39 | pre_size=args.pre_size, 40 | scale_min=args.scale_min, 41 | scale_max=args.scale_max, 42 | ignore_index=args.dataset_cls.ignore_label), 43 | joint_transforms.Resize(args.crop_size), 44 | joint_transforms.RandomHorizontallyFlip()] 45 | 46 | #if args.rotate: 47 | # train_joint_transform_list += [joint_transforms.RandomRotate(args.rotate)] 48 | 49 | train_joint_transform = joint_transforms.Compose(train_joint_transform_list) 50 | 51 | # Image appearance transformations 52 | train_input_transform = [] 53 | if args.color_aug: 54 | train_input_transform += [extended_transforms.ColorJitter( 55 | brightness=args.color_aug, 56 | contrast=args.color_aug, 57 | saturation=args.color_aug, 58 | hue=args.color_aug)] 59 | 60 | if args.bblur: 61 | train_input_transform += [extended_transforms.RandomBilateralBlur()] 62 | elif args.gblur: 63 | train_input_transform += [extended_transforms.RandomGaussianBlur()] 64 | else: 65 | pass 66 | 67 | train_input_transform += [standard_transforms.ToTensor(), 68 | standard_transforms.Normalize(*mean_std)] 69 | train_input_transform = standard_transforms.Compose(train_input_transform) 70 | 71 | val_input_transform = standard_transforms.Compose([ 72 | standard_transforms.ToTensor(), 73 | standard_transforms.Normalize(*mean_std) 74 | ]) 75 | 76 | target_transform = extended_transforms.MaskToTensor() 77 | 78 | target_train_transform = extended_transforms.MaskToTensor() 79 | 80 | if args.dataset == 'cityscapes': 81 | city_mode = 'train' ## Can be trainval 82 | city_quality = 'fine' 83 | train_set = args.dataset_cls.CityScapes( 84 | city_quality, city_mode, 0, 85 | joint_transform=train_joint_transform, 86 | transform=train_input_transform, 87 | target_transform=target_train_transform, 88 | 
dump_images=args.dump_augmentation_images, 89 | cv_split=args.cv) 90 | val_set = args.dataset_cls.CityScapes('fine', 'val', 0, 91 | transform=val_input_transform, 92 | target_transform=target_transform, 93 | cv_split=args.cv) 94 | else: 95 | raise 96 | 97 | train_sampler = None 98 | val_sampler = None 99 | 100 | train_loader = DataLoader(train_set, batch_size=args.train_batch_size, 101 | num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler) 102 | val_loader = DataLoader(val_set, batch_size=args.val_batch_size, 103 | num_workers=args.num_workers // 2 , shuffle=False, drop_last=False, sampler = val_sampler) 104 | 105 | return train_loader, val_loader, train_set 106 | 107 | -------------------------------------------------------------------------------- /datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | from torch.utils import data 11 | from collections import defaultdict 12 | import math 13 | import logging 14 | import datasets.cityscapes_labels as cityscapes_labels 15 | import json 16 | from config import cfg 17 | import torchvision.transforms as transforms 18 | import datasets.edge_utils as edge_utils 19 | 20 | trainid_to_name = cityscapes_labels.trainId2name 21 | id_to_trainid = cityscapes_labels.label2trainid 22 | num_classes = 19 23 | ignore_label = 255 24 | root = cfg.DATASET.CITYSCAPES_DIR 25 | 26 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 27 | 153, 153, 153, 250, 170, 30, 28 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 29 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 30 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 31 | zero_pad = 256 * 3 - len(palette) 32 | for i in range(zero_pad): 33 | palette.append(0) 34 | 35 | 36 | def colorize_mask(mask): 37 | # mask: numpy array of the mask 38 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 39 | new_mask.putpalette(palette) 40 | return new_mask 41 | 42 | 43 | def add_items(items, aug_items, cities, img_path, mask_path, mask_postfix, mode, maxSkip): 44 | 45 | for c in cities: 46 | c_items = [name.split('_leftImg8bit.png')[0] for name in 47 | os.listdir(os.path.join(img_path, c))] 48 | for it in c_items: 49 | item = (os.path.join(img_path, c, it + '_leftImg8bit.png'), 50 | os.path.join(mask_path, c, it + mask_postfix)) 51 | items.append(item) 52 | 53 | def make_cv_splits(img_dir_name): 54 | ''' 55 | Create splits of train/val data. 56 | A split is a lists of cities. 57 | split0 is aligned with the default Cityscapes train/val. 
58 | ''' 59 | trn_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'train') 60 | val_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'val') 61 | 62 | trn_cities = ['train/' + c for c in os.listdir(trn_path)] 63 | val_cities = ['val/' + c for c in os.listdir(val_path)] 64 | 65 | # want reproducible randomly shuffled 66 | trn_cities = sorted(trn_cities) 67 | 68 | all_cities = val_cities + trn_cities 69 | num_val_cities = len(val_cities) 70 | num_cities = len(all_cities) 71 | 72 | cv_splits = [] 73 | for split_idx in range(cfg.DATASET.CV_SPLITS): 74 | split = {} 75 | split['train'] = [] 76 | split['val'] = [] 77 | offset = split_idx * num_cities // cfg.DATASET.CV_SPLITS 78 | for j in range(num_cities): 79 | if j >= offset and j < (offset + num_val_cities): 80 | split['val'].append(all_cities[j]) 81 | else: 82 | split['train'].append(all_cities[j]) 83 | cv_splits.append(split) 84 | 85 | return cv_splits 86 | 87 | 88 | def make_split_coarse(img_path): 89 | ''' 90 | Create a train/val split for coarse 91 | return: city split in train 92 | ''' 93 | all_cities = os.listdir(img_path) 94 | all_cities = sorted(all_cities) # needs to always be the same 95 | val_cities = [] # Can manually set cities to not be included into train split 96 | 97 | split = {} 98 | split['val'] = val_cities 99 | split['train'] = [c for c in all_cities if c not in val_cities] 100 | return split 101 | 102 | def make_test_split(img_dir_name): 103 | test_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'test') 104 | test_cities = ['test/' + c for c in os.listdir(test_path)] 105 | 106 | return test_cities 107 | 108 | 109 | def make_dataset(quality, mode, maxSkip=0, fine_coarse_mult=6, cv_split=0): 110 | ''' 111 | Assemble list of images + mask files 112 | 113 | fine - modes: train/val/test/trainval cv:0,1,2 114 | coarse - modes: train/val cv:na 115 | 116 | path examples: 117 | leftImg8bit_trainextra/leftImg8bit/train_extra/augsburg 118 | gtCoarse/gtCoarse/train_extra/augsburg 119 | ''' 120 | items = [] 121 | aug_items = [] 122 | 123 | if quality == 'fine': 124 | assert mode in ['train', 'val', 'test', 'trainval'] 125 | img_dir_name = 'leftImg8bit_trainvaltest' 126 | img_path = os.path.join(root, img_dir_name, 'leftImg8bit') 127 | mask_path = os.path.join(root, 'gtFine_trainvaltest', 'gtFine') 128 | mask_postfix = '_gtFine_labelIds.png' 129 | cv_splits = make_cv_splits(img_dir_name) 130 | if mode == 'trainval': 131 | modes = ['train', 'val'] 132 | else: 133 | modes = [mode] 134 | for mode in modes: 135 | if mode == 'test': 136 | cv_splits = make_test_split(img_dir_name) 137 | add_items(items, cv_splits, img_path, mask_path, 138 | mask_postfix) 139 | else: 140 | logging.info('{} fine cities: '.format(mode) + str(cv_splits[cv_split][mode])) 141 | 142 | add_items(items, aug_items, cv_splits[cv_split][mode], img_path, mask_path, 143 | mask_postfix, mode, maxSkip) 144 | else: 145 | raise 'unknown cityscapes quality {}'.format(quality) 146 | logging.info('Cityscapes-{}: {} images'.format(mode, len(items)+len(aug_items))) 147 | return items, aug_items 148 | 149 | 150 | class CityScapes(data.Dataset): 151 | 152 | def __init__(self, quality, mode, maxSkip=0, joint_transform=None, sliding_crop=None, 153 | transform=None, target_transform=None, dump_images=False, 154 | cv_split=None, eval_mode=False, 155 | eval_scales=None, eval_flip=False): 156 | self.quality = quality 157 | self.mode = mode 158 | self.maxSkip = maxSkip 159 | self.joint_transform = joint_transform 160 | self.sliding_crop = sliding_crop 161 | 
self.transform = transform 162 | self.target_transform = target_transform 163 | self.dump_images = dump_images 164 | self.eval_mode = eval_mode 165 | self.eval_flip = eval_flip 166 | self.eval_scales = None 167 | if eval_scales != None: 168 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")] 169 | 170 | if cv_split: 171 | self.cv_split = cv_split 172 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 173 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 174 | cv_split, cfg.DATASET.CV_SPLITS) 175 | else: 176 | self.cv_split = 0 177 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split) 178 | if len(self.imgs) == 0: 179 | raise RuntimeError('Found 0 images, please check the data set') 180 | 181 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 182 | 183 | def _eval_get_item(self, img, mask, scales, flip_bool): 184 | return_imgs = [] 185 | for flip in range(int(flip_bool)+1): 186 | imgs = [] 187 | if flip : 188 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 189 | for scale in scales: 190 | w,h = img.size 191 | target_w, target_h = int(w * scale), int(h * scale) 192 | resize_img =img.resize((target_w, target_h)) 193 | tensor_img = transforms.ToTensor()(resize_img) 194 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img) 195 | imgs.append(tensor_img) 196 | return_imgs.append(imgs) 197 | return return_imgs, mask 198 | 199 | 200 | 201 | def __getitem__(self, index): 202 | 203 | img_path, mask_path = self.imgs[index] 204 | 205 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 206 | img_name = os.path.splitext(os.path.basename(img_path))[0] 207 | 208 | mask = np.array(mask) 209 | mask_copy = mask.copy() 210 | for k, v in id_to_trainid.items(): 211 | mask_copy[mask == k] = v 212 | 213 | if self.eval_mode: 214 | return self._eval_get_item(img, mask_copy, self.eval_scales, self.eval_flip), img_name 215 | 216 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 217 | 218 | # Image Transformations 219 | if self.joint_transform is not None: 220 | img, mask = self.joint_transform(img, mask) 221 | if self.transform is not None: 222 | img = self.transform(img) 223 | if self.target_transform is not None: 224 | mask = self.target_transform(mask) 225 | 226 | _edgemap = mask.numpy() 227 | _edgemap = edge_utils.mask_to_onehot(_edgemap, num_classes) 228 | 229 | _edgemap = edge_utils.onehot_to_binary_edges(_edgemap, 2, num_classes) 230 | 231 | edgemap = torch.from_numpy(_edgemap).float() 232 | 233 | # Debug 234 | if self.dump_images: 235 | outdir = '../../dump_imgs_{}'.format(self.mode) 236 | os.makedirs(outdir, exist_ok=True) 237 | out_img_fn = os.path.join(outdir, img_name + '.png') 238 | out_msk_fn = os.path.join(outdir, img_name + '_mask.png') 239 | mask_img = colorize_mask(np.array(mask)) 240 | img.save(out_img_fn) 241 | mask_img.save(out_msk_fn) 242 | 243 | return img, mask, edgemap, img_name 244 | 245 | def __len__(self): 246 | return len(self.imgs) 247 | 248 | 249 | def make_dataset_video(): 250 | img_dir_name = 'leftImg8bit_demoVideo' 251 | img_path = os.path.join(root, img_dir_name, 'leftImg8bit/demoVideo') 252 | items = [] 253 | categories = os.listdir(img_path) 254 | for c in categories[1:]: 255 | c_items = [name.split('_leftImg8bit.png')[0] for name in 256 | os.listdir(os.path.join(img_path, c))] 257 | for it in c_items: 258 | item = os.path.join(img_path, c, it + '_leftImg8bit.png') 259 | items.append(item) 260 | return items 261 | 262 | 263 | class CityScapesVideo(data.Dataset): 264 | 265 | def 
__init__(self, transform=None): 266 | self.imgs = make_dataset_video() 267 | if len(self.imgs) == 0: 268 | raise RuntimeError('Found 0 images, please check the data set') 269 | self.transform = transform 270 | 271 | def __getitem__(self, index): 272 | img_path = self.imgs[index] 273 | img = Image.open(img_path).convert('RGB') 274 | img_name = os.path.splitext(os.path.basename(img_path))[0] 275 | 276 | if self.transform is not None: 277 | img = self.transform(img) 278 | return img, img_name 279 | 280 | def __len__(self): 281 | return len(self.imgs) 282 | 283 | -------------------------------------------------------------------------------- /datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # File taken from https://github.com/mcordts/cityscapesScripts/ 6 | # License File Available at: 7 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 8 | 9 | # ---------------------- 10 | # The Cityscapes Dataset 11 | # ---------------------- 12 | # 13 | # 14 | # License agreement 15 | # ----------------- 16 | # 17 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 18 | # 19 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 20 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 21 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 22 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 23 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 24 | # 25 | # 26 | # Contact 27 | # ------- 28 | # 29 | # Marius Cordts, Mohamed Omran 30 | # www.cityscapes-dataset.net 31 | 32 | """ 33 | 34 | from collections import namedtuple 35 | 36 | 37 | #-------------------------------------------------------------------------------- 38 | # Definitions 39 | #-------------------------------------------------------------------------------- 40 | 41 | # a label and all meta information 42 | Label = namedtuple( 'Label' , [ 43 | 44 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 45 | # We use them to uniquely name a class 46 | 47 | 'id' , # An integer ID that is associated with this label. 
48 | # The IDs are used to represent the label in ground truth images 49 | # An ID of -1 means that this label does not have an ID and thus 50 | # is ignored when creating ground truth images (e.g. license plate). 51 | # Do not modify these IDs, since exactly these IDs are expected by the 52 | # evaluation server. 53 | 54 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 55 | # ground truth images with train IDs, using the tools provided in the 56 | # 'preparation' folder. However, make sure to validate or submit results 57 | # to our evaluation server using the regular IDs above! 58 | # For trainIds, multiple labels might have the same ID. Then, these labels 59 | # are mapped to the same class in the ground truth images. For the inverse 60 | # mapping, we use the label that is defined first in the list below. 61 | # For example, mapping all void-type classes to the same ID in training, 62 | # might make sense for some approaches. 63 | # Max value is 255! 64 | 65 | 'category' , # The name of the category that this label belongs to 66 | 67 | 'categoryId' , # The ID of this category. Used to create ground truth images 68 | # on category level. 69 | 70 | 'hasInstances', # Whether this label distinguishes between single instances or not 71 | 72 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 73 | # during evaluations or not 74 | 75 | 'color' , # The color of this label 76 | ] ) 77 | 78 | 79 | #-------------------------------------------------------------------------------- 80 | # A list of all labels 81 | #-------------------------------------------------------------------------------- 82 | 83 | # Please adapt the train IDs as appropriate for you approach. 84 | # Note that you might want to ignore labels with ID 255 during training. 85 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 86 | # Make sure to provide your results using the original IDs and not the training IDs. 87 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
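# For example, the training code in this repository follows that convention:
# datasets/cityscapes.py exposes ignore_label = 255 and the segmentation losses in
# loss.py are constructed with ignore_index=255, roughly equivalent to
#
#     criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
#
# so pixels whose train ID is 255 never contribute to the loss.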
88 | 89 | labels = [ 90 | # name id trainId category catId hasInstances ignoreInEval color 91 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 93 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 94 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 95 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 96 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 97 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 98 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 99 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 100 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 101 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 102 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 103 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 104 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 105 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 106 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 107 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 108 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 109 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 110 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 111 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 112 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 113 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 114 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 115 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 116 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 117 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 118 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 119 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 120 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 121 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 122 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 123 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 124 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 125 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 126 | Label( 'license plate' , 34 , 255 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 127 | ] 128 | 129 | 130 | #-------------------------------------------------------------------------------- 131 | # Create dictionaries for a fast lookup 132 | #-------------------------------------------------------------------------------- 133 | 134 | # Please refer to the main method below for example usages! 
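# Illustrative helper (hypothetical; not called anywhere in this repository): it mirrors how
# datasets/cityscapes.py uses the label2trainid dictionary defined just below to remap a raw
# ground-truth ID mask into train IDs before the losses are computed.
def example_ids_to_trainids(mask):
    """Remap a (H, W) numpy array of raw Cityscapes label IDs to train IDs."""
    remapped = mask.copy()
    for label_id, train_id in label2trainid.items():
        remapped[mask == label_id] = train_id
    return remapped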
135 | 136 | # name to label object 137 | name2label = { label.name : label for label in labels } 138 | # id to label object 139 | id2label = { label.id : label for label in labels } 140 | # trainId to label object 141 | trainId2label = { label.trainId : label for label in reversed(labels) } 142 | # label2trainid 143 | label2trainid = { label.id : label.trainId for label in labels } 144 | # trainId to label object 145 | trainId2name = { label.trainId : label.name for label in labels } 146 | trainId2color = { label.trainId : label.color for label in labels } 147 | # category to list of label objects 148 | category2labels = {} 149 | for label in labels: 150 | category = label.category 151 | if category in category2labels: 152 | category2labels[category].append(label) 153 | else: 154 | category2labels[category] = [label] 155 | 156 | #-------------------------------------------------------------------------------- 157 | # Assure single instance name 158 | #-------------------------------------------------------------------------------- 159 | 160 | # returns the label name that describes a single instance (if possible) 161 | # e.g. input | output 162 | # ---------------------- 163 | # car | car 164 | # cargroup | car 165 | # foo | None 166 | # foogroup | None 167 | # skygroup | None 168 | def assureSingleInstanceName( name ): 169 | # if the name is known, it is not a group 170 | if name in name2label: 171 | return name 172 | # test if the name actually denotes a group 173 | if not name.endswith("group"): 174 | return None 175 | # remove group 176 | name = name[:-len("group")] 177 | # test if the new name exists 178 | if not name in name2label: 179 | return None 180 | # test if the new name denotes a label that actually has instances 181 | if not name2label[name].hasInstances: 182 | return None 183 | # all good then 184 | return name 185 | 186 | #-------------------------------------------------------------------------------- 187 | # Main for testing 188 | #-------------------------------------------------------------------------------- 189 | 190 | # just a dummy main 191 | if __name__ == "__main__": 192 | # Print all the labels 193 | print("List of cityscapes labels:") 194 | print("") 195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 196 | print((" " + ('-' * 98))) 197 | for label in labels: 198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 199 | print("") 200 | 201 | print("Example usages:") 202 | 203 | # Map from name to label 204 | name = 'car' 205 | id = name2label[name].id 206 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 207 | 208 | # Map from ID to label 209 | category = id2label[id].category 210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 211 | 212 | # Map from trainID to label 213 | trainId = 0 214 | name = trainId2label[trainId].name 215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 216 | -------------------------------------------------------------------------------- /datasets/edge_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.ndimage.morphology import distance_transform_edt 10 | 11 | def mask_to_onehot(mask, num_classes): 12 | """ 13 | Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one 14 | hot encoding vector 15 | 16 | """ 17 | _mask = [mask == (i + 1) for i in range(num_classes)] 18 | return np.array(_mask).astype(np.uint8) 19 | 20 | def onehot_to_mask(mask): 21 | """ 22 | Converts a mask (K,H,W) to (H,W) 23 | """ 24 | _mask = np.argmax(mask, axis=0) 25 | _mask[_mask != 0] += 1 26 | return _mask 27 | 28 | def onehot_to_multiclass_edges(mask, radius, num_classes): 29 | """ 30 | Converts a segmentation mask (K,H,W) to an edgemap (K,H,W) 31 | 32 | """ 33 | if radius < 0: 34 | return mask 35 | 36 | # We need to pad the borders for boundary conditions 37 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) 38 | 39 | channels = [] 40 | for i in range(num_classes): 41 | dist = distance_transform_edt(mask_pad[i, :])+distance_transform_edt(1.0-mask_pad[i, :]) 42 | dist = dist[1:-1, 1:-1] 43 | dist[dist > radius] = 0 44 | dist = (dist > 0).astype(np.uint8) 45 | channels.append(dist) 46 | 47 | return np.array(channels) 48 | 49 | def onehot_to_binary_edges(mask, radius, num_classes): 50 | """ 51 | Converts a segmentation mask (K,H,W) to a binary edgemap (H,W) 52 | 53 | """ 54 | 55 | if radius < 0: 56 | return mask 57 | 58 | # We need to pad the borders for boundary conditions 59 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) 60 | 61 | edgemap = np.zeros(mask.shape[1:]) 62 | 63 | for i in range(num_classes): 64 | dist = distance_transform_edt(mask_pad[i, :])+distance_transform_edt(1.0-mask_pad[i, :]) 65 | dist = dist[1:-1, 1:-1] 66 | dist[dist > radius] = 0 67 | edgemap += dist 68 | edgemap = np.expand_dims(edgemap, axis=0) 69 | edgemap = (edgemap > 0).astype(np.uint8) 70 | return edgemap 71 | 72 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | Redirecting to https://research.nvidia.com/labs/toronto-ai/GSCNN/ 4 | 5 | 6 | -------------------------------------------------------------------------------- /docs/resources/GSCNN.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/GSCNN.mp4 -------------------------------------------------------------------------------- /docs/resources/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/architecture.jpg -------------------------------------------------------------------------------- /docs/resources/bibtex.txt: -------------------------------------------------------------------------------- 1 | @inproceedings{Takikawa2019GatedSCNNGS, 2 | title={Gated-SCNN: Gated Shape CNNs for Semantic Segmentation}, 3 | author={Towaki Takikawa and David Acuna and Varun Jampani and Sanja Fidler}, 4 | year={2019} 5 | } 6 | 7 | -------------------------------------------------------------------------------- /docs/resources/crop.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/crop.jpg -------------------------------------------------------------------------------- /docs/resources/edges.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/edges.jpg -------------------------------------------------------------------------------- /docs/resources/gscnn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/gscnn.gif -------------------------------------------------------------------------------- /docs/resources/intro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/intro.jpg -------------------------------------------------------------------------------- /docs/resources/seg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/seg.jpg -------------------------------------------------------------------------------- /docs/resources/semboundary.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/semboundary.jpg -------------------------------------------------------------------------------- /docs/resources/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/table.png -------------------------------------------------------------------------------- /docs/resources/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/test.jpg -------------------------------------------------------------------------------- /docs/resources/top.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/top.jpg -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import logging 10 | import numpy as np 11 | from config import cfg 12 | from my_functionals.DualTaskLoss import DualTaskLoss 13 | 14 | def get_loss(args): 15 | ''' 16 | Get the criterion based on the loss function 17 | args: 18 | return: criterion 19 | ''' 20 | 21 | if args.img_wt_loss: 22 | criterion = ImageBasedCrossEntropyLoss2d( 23 | classes=args.dataset_cls.num_classes, size_average=True, 24 | ignore_index=args.dataset_cls.ignore_label, 25 | upper_bound=args.wt_bound).cuda() 26 | elif args.joint_edgeseg_loss: 27 | criterion = JointEdgeSegLoss(classes=args.dataset_cls.num_classes, 28 | ignore_index=args.dataset_cls.ignore_label, upper_bound=args.wt_bound, 29 | edge_weight=args.edge_weight, seg_weight=args.seg_weight, att_weight=args.att_weight, dual_weight=args.dual_weight).cuda() 30 | 31 | else: 32 | criterion = CrossEntropyLoss2d(size_average=True, 33 | ignore_index=args.dataset_cls.ignore_label).cuda() 34 | 35 | criterion_val = JointEdgeSegLoss(classes=args.dataset_cls.num_classes, mode='val', 36 | ignore_index=args.dataset_cls.ignore_label, upper_bound=args.wt_bound, 37 | edge_weight=args.edge_weight, seg_weight=args.seg_weight).cuda() 38 | 39 | return criterion, criterion_val 40 | 41 | class JointEdgeSegLoss(nn.Module): 42 | def __init__(self, classes, weight=None, reduction='mean', ignore_index=255, 43 | norm=False, upper_bound=1.0, mode='train', 44 | edge_weight=1, seg_weight=1, att_weight=1, dual_weight=1, edge='none'): 45 | super(JointEdgeSegLoss, self).__init__() 46 | self.num_classes = classes 47 | if mode == 'train': 48 | self.seg_loss = ImageBasedCrossEntropyLoss2d( 49 | classes=classes, ignore_index=ignore_index, upper_bound=upper_bound).cuda() 50 | elif mode == 'val': 51 | self.seg_loss = CrossEntropyLoss2d(size_average=True, 52 | ignore_index=ignore_index).cuda() 53 | 54 | self.edge_weight = edge_weight 55 | self.seg_weight = seg_weight 56 | self.att_weight = att_weight 57 | self.dual_weight = dual_weight 58 | 59 | self.dual_task = DualTaskLoss() 60 | 61 | def bce2d(self, input, target): 62 | n, c, h, w = input.size() 63 | 64 | log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1) 65 | target_t = target.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1) 66 | target_trans = target_t.clone() 67 | 68 | pos_index = (target_t ==1) 69 | neg_index = (target_t ==0) 70 | ignore_index=(target_t >1) 71 | 72 | target_trans[pos_index] = 1 73 | target_trans[neg_index] = 0 74 | 75 | pos_index = pos_index.data.cpu().numpy().astype(bool) 76 | neg_index = neg_index.data.cpu().numpy().astype(bool) 77 | ignore_index=ignore_index.data.cpu().numpy().astype(bool) 78 | 79 | weight = torch.Tensor(log_p.size()).fill_(0) 80 | weight = weight.numpy() 81 | pos_num = pos_index.sum() 82 | neg_num = neg_index.sum() 83 | sum_num = pos_num + neg_num 84 | weight[pos_index] = neg_num*1.0 / sum_num 85 | weight[neg_index] = pos_num*1.0 / sum_num 86 | 87 | weight[ignore_index] = 0 88 | 89 | weight = torch.from_numpy(weight) 90 | weight = weight.cuda() 91 | loss = F.binary_cross_entropy_with_logits(log_p, target_t, weight, size_average=True) 92 | return loss 93 | 94 | def edge_attention(self, input, target, edge): 95 | n, c, h, w = input.size() 96 | filler = torch.ones_like(target) * 255 97 | return self.seg_loss(input, 98 | torch.where(edge.max(1)[0] > 0.8, target, filler)) 99 | 100 | def forward(self, inputs, targets): 101 | segin, edgein = inputs 102 | segmask, edgemask = targets 103 | 104 
| losses = {} 105 | 106 | losses['seg_loss'] = self.seg_weight * self.seg_loss(segin, segmask) 107 | losses['edge_loss'] = self.edge_weight * 20 * self.bce2d(edgein, edgemask) 108 | losses['att_loss'] = self.att_weight * self.edge_attention(segin, segmask, edgein) 109 | losses['dual_loss'] = self.dual_weight * self.dual_task(segin, segmask) 110 | 111 | return losses 112 | 113 | #Img Weighted Loss 114 | class ImageBasedCrossEntropyLoss2d(nn.Module): 115 | 116 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255, 117 | norm=False, upper_bound=1.0): 118 | super(ImageBasedCrossEntropyLoss2d, self).__init__() 119 | logging.info("Using Per Image based weighted loss") 120 | self.num_classes = classes 121 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 122 | self.norm = norm 123 | self.upper_bound = upper_bound 124 | self.batch_weights = cfg.BATCH_WEIGHTING 125 | 126 | def calculateWeights(self, target): 127 | hist = np.histogram(target.flatten(), range( 128 | self.num_classes + 1), normed=True)[0] 129 | if self.norm: 130 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 131 | else: 132 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 133 | return hist 134 | 135 | def forward(self, inputs, targets): 136 | target_cpu = targets.data.cpu().numpy() 137 | if self.batch_weights: 138 | weights = self.calculateWeights(target_cpu) 139 | self.nll_loss.weight = torch.Tensor(weights).cuda() 140 | 141 | loss = 0.0 142 | for i in range(0, inputs.shape[0]): 143 | if not self.batch_weights: 144 | weights = self.calculateWeights(target_cpu[i]) 145 | self.nll_loss.weight = torch.Tensor(weights).cuda() 146 | 147 | loss += self.nll_loss(F.log_softmax(inputs[i].unsqueeze(0)), 148 | targets[i].unsqueeze(0)) 149 | return loss 150 | 151 | 152 | #Cross Entroply NLL Loss 153 | class CrossEntropyLoss2d(nn.Module): 154 | def __init__(self, weight=None, size_average=True, ignore_index=255): 155 | super(CrossEntropyLoss2d, self).__init__() 156 | logging.info("Using Cross Entropy Loss") 157 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 158 | 159 | def forward(self, inputs, targets): 160 | return self.nll_loss(F.log_softmax(inputs), targets) 161 | 162 | -------------------------------------------------------------------------------- /my_functionals/DualTaskLoss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code adapted from: 6 | # https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb 7 | # 8 | # MIT License 9 | # 10 | # Copyright (c) 2016 Eric Jang 11 | # 12 | # Permission is hereby granted, free of charge, to any person obtaining a copy 13 | # of this software and associated documentation files (the "Software"), to deal 14 | # in the Software without restriction, including without limitation the rights 15 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | # copies of the Software, and to permit persons to whom the Software is 17 | # furnished to do so, subject to the following conditions: 18 | # 19 | # The above copyright notice and this permission notice shall be included in all 20 | # copies or substantial portions of the Software. 
21 | # 22 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | # SOFTWARE. 29 | """ 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | import numpy as np 36 | from my_functionals.custom_functional import compute_grad_mag 37 | 38 | def perturbate_input_(input, n_elements=200): 39 | N, C, H, W = input.shape 40 | assert N == 1 41 | c_ = np.random.random_integers(0, C - 1, n_elements) 42 | h_ = np.random.random_integers(0, H - 1, n_elements) 43 | w_ = np.random.random_integers(0, W - 1, n_elements) 44 | for c_idx in c_: 45 | for h_idx in h_: 46 | for w_idx in w_: 47 | input[0, c_idx, h_idx, w_idx] = 1 48 | return input 49 | 50 | def _sample_gumbel(shape, eps=1e-10): 51 | """ 52 | Sample from Gumbel(0, 1) 53 | 54 | based on 55 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb , 56 | (MIT license) 57 | """ 58 | U = torch.rand(shape).cuda() 59 | return - torch.log(eps - torch.log(U + eps)) 60 | 61 | 62 | def _gumbel_softmax_sample(logits, tau=1, eps=1e-10): 63 | """ 64 | Draw a sample from the Gumbel-Softmax distribution 65 | 66 | based on 67 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb 68 | (MIT license) 69 | """ 70 | assert logits.dim() == 3 71 | gumbel_noise = _sample_gumbel(logits.size(), eps=eps) 72 | y = logits + gumbel_noise 73 | return F.softmax(y / tau, 1) 74 | 75 | 76 | def _one_hot_embedding(labels, num_classes): 77 | """Embedding labels to one-hot form. 78 | 79 | Args: 80 | labels: (LongTensor) class labels, sized [N,]. 81 | num_classes: (int) number of classes. 82 | 83 | Returns: 84 | (tensor) encoded labels, sized [N, #classes]. 
85 | """ 86 | 87 | y = torch.eye(num_classes).cuda() 88 | return y[labels].permute(0,3,1,2) 89 | 90 | class DualTaskLoss(nn.Module): 91 | def __init__(self, cuda=False): 92 | super(DualTaskLoss, self).__init__() 93 | self._cuda = cuda 94 | return 95 | 96 | def forward(self, input_logits, gts, ignore_pixel=255): 97 | """ 98 | :param input_logits: NxCxHxW 99 | :param gt_semantic_masks: NxCxHxW 100 | :return: final loss 101 | """ 102 | N, C, H, W = input_logits.shape 103 | th = 1e-8 # 1e-10 104 | eps = 1e-10 105 | ignore_mask = (gts == ignore_pixel).detach() 106 | input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, 19, H, W), 107 | torch.zeros(N,C,H,W).cuda(), 108 | input_logits) 109 | gt_semantic_masks = gts.detach() 110 | gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N,H,W).long().cuda(), gt_semantic_masks) 111 | gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach() 112 | 113 | g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5) 114 | g = g.reshape((N, C, H, W)) 115 | g = compute_grad_mag(g, cuda=self._cuda) 116 | 117 | g_hat = compute_grad_mag(gt_semantic_masks, cuda=self._cuda) 118 | 119 | g = g.view(N, -1) 120 | g_hat = g_hat.view(N, -1) 121 | loss_ewise = F.l1_loss(g, g_hat, reduction='none', reduce=False) 122 | 123 | p_plus_g_mask = (g >= th).detach().float() 124 | loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps) 125 | 126 | p_plus_g_hat_mask = (g_hat >= th).detach().float() 127 | loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps) 128 | 129 | total_loss = 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat 130 | 131 | return total_loss 132 | 133 | -------------------------------------------------------------------------------- /my_functionals/GatedSpatialConv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.nn.modules.conv import _ConvNd 10 | from torch.nn.modules.utils import _pair 11 | import numpy as np 12 | import math 13 | import network.mynn as mynn 14 | import my_functionals.custom_functional as myF 15 | class GatedSpatialConv2d(_ConvNd): 16 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 17 | padding=0, dilation=1, groups=1, bias=False): 18 | """ 19 | 20 | :param in_channels: 21 | :param out_channels: 22 | :param kernel_size: 23 | :param stride: 24 | :param padding: 25 | :param dilation: 26 | :param groups: 27 | :param bias: 28 | """ 29 | 30 | kernel_size = _pair(kernel_size) 31 | stride = _pair(stride) 32 | padding = _pair(padding) 33 | dilation = _pair(dilation) 34 | super(GatedSpatialConv2d, self).__init__( 35 | in_channels, out_channels, kernel_size, stride, padding, dilation, 36 | False, _pair(0), groups, bias, 'zeros') 37 | 38 | self._gate_conv = nn.Sequential( 39 | mynn.Norm2d(in_channels+1), 40 | nn.Conv2d(in_channels+1, in_channels+1, 1), 41 | nn.ReLU(), 42 | nn.Conv2d(in_channels+1, 1, 1), 43 | mynn.Norm2d(1), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def forward(self, input_features, gating_features): 48 | """ 49 | 50 | :param input_features: [NxCxHxW] featuers comming from the shape branch (canny branch). 51 | :param gating_features: [Nx1xHxW] features comming from the texture branch (resnet). 
Only one channel feature map. 52 | :return: 53 | """ 54 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1)) 55 | 56 | input_features = (input_features * (alphas + 1)) 57 | return F.conv2d(input_features, self.weight, self.bias, self.stride, 58 | self.padding, self.dilation, self.groups) 59 | 60 | def reset_parameters(self): 61 | nn.init.xavier_normal_(self.weight) 62 | if self.bias is not None: 63 | nn.init.zeros_(self.bias) 64 | 65 | 66 | class Conv2dPad(nn.Conv2d): 67 | def forward(self, input): 68 | return myF.conv2d_same(input,self.weight,self.groups) 69 | 70 | class HighFrequencyGatedSpatialConv2d(_ConvNd): 71 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 72 | padding=0, dilation=1, groups=1, bias=False): 73 | """ 74 | 75 | :param in_channels: 76 | :param out_channels: 77 | :param kernel_size: 78 | :param stride: 79 | :param padding: 80 | :param dilation: 81 | :param groups: 82 | :param bias: 83 | """ 84 | 85 | kernel_size = _pair(kernel_size) 86 | stride = _pair(stride) 87 | padding = _pair(padding) 88 | dilation = _pair(dilation) 89 | super(HighFrequencyGatedSpatialConv2d, self).__init__( 90 | in_channels, out_channels, kernel_size, stride, padding, dilation, 91 | False, _pair(0), groups, bias) 92 | 93 | self._gate_conv = nn.Sequential( 94 | mynn.Norm2d(in_channels+1), 95 | nn.Conv2d(in_channels+1, in_channels+1, 1), 96 | nn.ReLU(), 97 | nn.Conv2d(in_channels+1, 1, 1), 98 | mynn.Norm2d(1), 99 | nn.Sigmoid() 100 | ) 101 | 102 | kernel_size = 7 103 | sigma = 3 104 | 105 | x_cord = torch.arange(kernel_size).float() 106 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size).float() 107 | y_grid = x_grid.t().float() 108 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float() 109 | 110 | mean = (kernel_size - 1)/2. 111 | variance = sigma**2. 112 | gaussian_kernel = (1./(2.*math.pi*variance)) *\ 113 | torch.exp( 114 | -torch.sum((xy_grid - mean)**2., dim=-1) /\ 115 | (2*variance) 116 | ) 117 | 118 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel) 119 | 120 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size) 121 | gaussian_kernel = gaussian_kernel.repeat(in_channels, 1, 1, 1) 122 | 123 | self.gaussian_filter = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, padding=3, 124 | kernel_size=kernel_size, groups=in_channels, bias=False) 125 | 126 | self.gaussian_filter.weight.data = gaussian_kernel 127 | self.gaussian_filter.weight.requires_grad = False 128 | 129 | self.cw = nn.Conv2d(in_channels * 2, in_channels, 1) 130 | 131 | self.procdog = nn.Sequential( 132 | nn.Conv2d(in_channels, in_channels, 1), 133 | mynn.Norm2d(in_channels), 134 | nn.Sigmoid() 135 | ) 136 | 137 | def forward(self, input_features, gating_features): 138 | """ 139 | 140 | :param input_features: [NxCxHxW] featuers comming from the shape branch (canny branch). 141 | :param gating_features: [Nx1xHxW] features comming from the texture branch (resnet). Only one channel feature map. 
142 | :return: 143 | """ 144 | n, c, h, w = input_features.size() 145 | smooth_features = self.gaussian_filter(input_features) 146 | dog_features = input_features - smooth_features 147 | dog_features = self.cw(torch.cat((dog_features, input_features), dim=1)) 148 | 149 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1)) 150 | 151 | dog_features = dog_features * (alphas + 1) 152 | 153 | return F.conv2d(dog_features, self.weight, self.bias, self.stride, 154 | self.padding, self.dilation, self.groups) 155 | 156 | def reset_parameters(self): 157 | nn.init.xavier_normal_(self.weight) 158 | if self.bias is not None: 159 | nn.init.zeros_(self.bias) 160 | 161 | def t(): 162 | import matplotlib.pyplot as plt 163 | 164 | canny_map_filters_in = 8 165 | canny_map = np.random.normal(size=(1, canny_map_filters_in, 10, 10)) # NxCxHxW 166 | resnet_map = np.random.normal(size=(1, 1, 10, 10)) # NxCxHxW 167 | plt.imshow(canny_map[0, 0]) 168 | plt.show() 169 | 170 | canny_map = torch.from_numpy(canny_map).float() 171 | resnet_map = torch.from_numpy(resnet_map).float() 172 | 173 | gconv = GatedSpatialConv2d(canny_map_filters_in, canny_map_filters_in, 174 | kernel_size=3, stride=1, padding=1) 175 | output_map = gconv(canny_map, resnet_map) 176 | print('done') 177 | 178 | 179 | if __name__ == "__main__": 180 | t() 181 | 182 | -------------------------------------------------------------------------------- /my_functionals/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | -------------------------------------------------------------------------------- /my_functionals/custom_functional.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torchvision.transforms.functional import pad 9 | import numpy as np 10 | 11 | 12 | def calc_pad_same(in_siz, out_siz, stride, ksize): 13 | """Calculate same padding width. 14 | Args: 15 | ksize: kernel size [I, J]. 16 | Returns: 17 | pad_: Actual padding width. 18 | """ 19 | return (out_siz - 1) * stride + ksize - in_siz 20 | 21 | 22 | def conv2d_same(input, kernel, groups,bias=None,stride=1,padding=0,dilation=1): 23 | n, c, h, w = input.shape 24 | kout, ki_c_g, kh, kw = kernel.shape 25 | pw = calc_pad_same(w, w, 1, kw) 26 | ph = calc_pad_same(h, h, 1, kh) 27 | pw_l = pw // 2 28 | pw_r = pw - pw_l 29 | ph_t = ph // 2 30 | ph_b = ph - ph_t 31 | 32 | input_ = F.pad(input, (pw_l, pw_r, ph_t, ph_b)) 33 | result = F.conv2d(input_, kernel, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 34 | assert result.shape == input.shape 35 | return result 36 | 37 | 38 | def gradient_central_diff(input, cuda): 39 | return input, input 40 | kernel = [[1, 0, -1]] 41 | kernel_t = 0.5 * torch.Tensor(kernel) * -1. 
# pytorch implements correlation instead of conv 42 | if type(cuda) is int: 43 | if cuda != -1: 44 | kernel_t = kernel_t.cuda(device=cuda) 45 | else: 46 | if cuda is True: 47 | kernel_t = kernel_t.cuda() 48 | n, c, h, w = input.shape 49 | 50 | x = conv2d_same(input, kernel_t.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c) 51 | y = conv2d_same(input, kernel_t.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c) 52 | return x, y 53 | 54 | 55 | def compute_single_sided_diferences(o_x, o_y, input): 56 | # n,c,h,w 57 | #input = input.clone() 58 | o_y[:, :, 0, :] = input[:, :, 1, :].clone() - input[:, :, 0, :].clone() 59 | o_x[:, :, :, 0] = input[:, :, :, 1].clone() - input[:, :, :, 0].clone() 60 | # -- 61 | o_y[:, :, -1, :] = input[:, :, -1, :].clone() - input[:, :, -2, :].clone() 62 | o_x[:, :, :, -1] = input[:, :, :, -1].clone() - input[:, :, :, -2].clone() 63 | return o_x, o_y 64 | 65 | 66 | def numerical_gradients_2d(input, cuda=False): 67 | """ 68 | numerical gradients implementation over batches using torch group conv operator. 69 | the single sided differences are re-computed later. 70 | it matches np.gradient(image) with the difference than here output=x,y for an image while there output=y,x 71 | :param input: N,C,H,W 72 | :param cuda: whether or not use cuda 73 | :return: X,Y 74 | """ 75 | n, c, h, w = input.shape 76 | assert h > 1 and w > 1 77 | x, y = gradient_central_diff(input, cuda) 78 | return x, y 79 | 80 | 81 | def convTri(input, r, cuda=False): 82 | """ 83 | Convolves an image by a 2D triangle filter (the 1D triangle filter f is 84 | [1:r r+1 r:-1:1]/(r+1)^2, the 2D version is simply conv2(f,f')) 85 | :param input: 86 | :param r: integer filter radius 87 | :param cuda: move the kernel to gpu 88 | :return: 89 | """ 90 | if (r <= 1): 91 | raise ValueError() 92 | n, c, h, w = input.shape 93 | return input 94 | f = list(range(1, r + 1)) + [r + 1] + list(reversed(range(1, r + 1))) 95 | kernel = torch.Tensor([f]) / (r + 1) ** 2 96 | if type(cuda) is int: 97 | if cuda != -1: 98 | kernel = kernel.cuda(device=cuda) 99 | else: 100 | if cuda is True: 101 | kernel = kernel.cuda() 102 | 103 | # padding w 104 | input_ = F.pad(input, (1, 1, 0, 0), mode='replicate') 105 | input_ = F.pad(input_, (r, r, 0, 0), mode='reflect') 106 | input_ = [input_[:, :, :, :r], input, input_[:, :, :, -r:]] 107 | input_ = torch.cat(input_, 3) 108 | t = input_ 109 | 110 | # padding h 111 | input_ = F.pad(input_, (0, 0, 1, 1), mode='replicate') 112 | input_ = F.pad(input_, (0, 0, r, r), mode='reflect') 113 | input_ = [input_[:, :, :r, :], t, input_[:, :, -r:, :]] 114 | input_ = torch.cat(input_, 2) 115 | 116 | output = F.conv2d(input_, 117 | kernel.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), 118 | padding=0, groups=c) 119 | output = F.conv2d(output, 120 | kernel.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), 121 | padding=0, groups=c) 122 | return output 123 | 124 | 125 | def compute_normal(E, cuda=False): 126 | if torch.sum(torch.isnan(E)) != 0: 127 | print('nans found here') 128 | import ipdb; 129 | ipdb.set_trace() 130 | E_ = convTri(E, 4, cuda) 131 | Ox, Oy = numerical_gradients_2d(E_, cuda) 132 | Oxx, _ = numerical_gradients_2d(Ox, cuda) 133 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda) 134 | 135 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5) 136 | t = torch.atan(aa) 137 | O = torch.remainder(t, np.pi) 138 | 139 | if torch.sum(torch.isnan(O)) != 0: 140 | print('nans found here') 141 | import ipdb; 142 | ipdb.set_trace() 143 | 144 | return O 145 | 146 | 147 | def compute_normal_2(E, 
cuda=False): 148 | if torch.sum(torch.isnan(E)) != 0: 149 | print('nans found here') 150 | import ipdb; 151 | ipdb.set_trace() 152 | E_ = convTri(E, 4, cuda) 153 | Ox, Oy = numerical_gradients_2d(E_, cuda) 154 | Oxx, _ = numerical_gradients_2d(Ox, cuda) 155 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda) 156 | 157 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5) 158 | t = torch.atan(aa) 159 | O = torch.remainder(t, np.pi) 160 | 161 | if torch.sum(torch.isnan(O)) != 0: 162 | print('nans found here') 163 | import ipdb; 164 | ipdb.set_trace() 165 | 166 | return O, (Oyy, Oxx) 167 | 168 | 169 | def compute_grad_mag(E, cuda=False): 170 | E_ = convTri(E, 4, cuda) 171 | Ox, Oy = numerical_gradients_2d(E_, cuda) 172 | mag = torch.sqrt(torch.mul(Ox,Ox) + torch.mul(Oy,Oy) + 1e-6) 173 | mag = mag / mag.max(); 174 | 175 | return mag 176 | -------------------------------------------------------------------------------- /network/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code Adapted from: 6 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | # 8 | # BSD 3-Clause License 9 | # 10 | # Copyright (c) 2017, 11 | # All rights reserved. 12 | # 13 | # Redistribution and use in source and binary forms, with or without 14 | # modification, are permitted provided that the following conditions are met: 15 | # 16 | # * Redistributions of source code must retain the above copyright notice, this 17 | # list of conditions and the following disclaimer. 18 | # 19 | # * Redistributions in binary form must reproduce the above copyright notice, 20 | # this list of conditions and the following disclaimer in the documentation 21 | # and/or other materials provided with the distribution. 22 | # 23 | # * Neither the name of the copyright holder nor the names of its 24 | # contributors may be used to endorse or promote products derived from 25 | # this software without specific prior written permission. 26 | # 27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
37 | """ 38 | 39 | import torch.nn as nn 40 | import math 41 | import torch.utils.model_zoo as model_zoo 42 | import network.mynn as mynn 43 | 44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 45 | 'resnet152'] 46 | 47 | 48 | model_urls = { 49 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 50 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 51 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 52 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 53 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 54 | } 55 | 56 | 57 | def conv3x3(in_planes, out_planes, stride=1): 58 | """3x3 convolution with padding""" 59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 60 | padding=1, bias=False) 61 | 62 | 63 | class BasicBlock(nn.Module): 64 | expansion = 1 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None): 67 | super(BasicBlock, self).__init__() 68 | self.conv1 = conv3x3(inplanes, planes, stride) 69 | self.bn1 = mynn.Norm2d(planes) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.conv2 = conv3x3(planes, planes) 72 | self.bn2 = mynn.Norm2d(planes) 73 | self.downsample = downsample 74 | self.stride = stride 75 | for m in self.modules(): 76 | if isinstance(m, nn.Conv2d): 77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 78 | elif isinstance(m, nn.BatchNorm2d): 79 | nn.init.constant_(m.weight, 1) 80 | nn.init.constant_(m.bias, 0) 81 | 82 | def forward(self, x): 83 | residual = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class Bottleneck(nn.Module): 102 | expansion = 4 103 | 104 | def __init__(self, inplanes, planes, stride=1, downsample=None): 105 | super(Bottleneck, self).__init__() 106 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 107 | self.bn1 = mynn.Norm2d(planes) 108 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 109 | padding=1, bias=False) 110 | self.bn2 = mynn.Norm2d(planes) 111 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 112 | self.bn3 = mynn.Norm2d(planes * self.expansion) 113 | self.relu = nn.ReLU(inplace=True) 114 | self.downsample = downsample 115 | self.stride = stride 116 | 117 | def forward(self, x): 118 | residual = x 119 | 120 | out = self.conv1(x) 121 | out = self.bn1(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv2(out) 125 | out = self.bn2(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv3(out) 129 | out = self.bn3(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | 134 | out += residual 135 | out = self.relu(out) 136 | 137 | return out 138 | 139 | 140 | class ResNet(nn.Module): 141 | 142 | def __init__(self, block, layers, num_classes=1000): 143 | self.inplanes = 64 144 | super(ResNet, self).__init__() 145 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 146 | bias=False) 147 | self.bn1 = mynn.Norm2d(64) 148 | self.relu = nn.ReLU(inplace=True) 149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 150 | self.layer1 = self._make_layer(block, 64, layers[0]) 151 | self.layer2 = self._make_layer(block, 128, layers[1], 
stride=2) 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 154 | self.avgpool = nn.AvgPool2d(7, stride=1) 155 | self.fc = nn.Linear(512 * block.expansion, num_classes) 156 | 157 | for m in self.modules(): 158 | if isinstance(m, nn.Conv2d): 159 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 160 | elif isinstance(m, nn.BatchNorm2d): 161 | nn.init.constant_(m.weight, 1) 162 | nn.init.constant_(m.bias, 0) 163 | 164 | def _make_layer(self, block, planes, blocks, stride=1): 165 | downsample = None 166 | if stride != 1 or self.inplanes != planes * block.expansion: 167 | downsample = nn.Sequential( 168 | nn.Conv2d(self.inplanes, planes * block.expansion, 169 | kernel_size=1, stride=stride, bias=False), 170 | mynn.Norm2d(planes * block.expansion), 171 | ) 172 | 173 | layers = [] 174 | layers.append(block(self.inplanes, planes, stride, downsample)) 175 | self.inplanes = planes * block.expansion 176 | for i in range(1, blocks): 177 | layers.append(block(self.inplanes, planes)) 178 | 179 | return nn.Sequential(*layers) 180 | 181 | def forward(self, x): 182 | x = self.conv1(x) 183 | x = self.bn1(x) 184 | x = self.relu(x) 185 | x = self.maxpool(x) 186 | 187 | x = self.layer1(x) 188 | x = self.layer2(x) 189 | x = self.layer3(x) 190 | x = self.layer4(x) 191 | 192 | x = self.avgpool(x) 193 | x = x.view(x.size(0), -1) 194 | x = self.fc(x) 195 | 196 | return x 197 | 198 | 199 | def resnet18(pretrained=True, **kwargs): 200 | """Constructs a ResNet-18 model. 201 | 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 208 | return model 209 | 210 | 211 | def resnet34(pretrained=True, **kwargs): 212 | """Constructs a ResNet-34 model. 213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 220 | return model 221 | 222 | 223 | def resnet50(pretrained=True, **kwargs): 224 | """Constructs a ResNet-50 model. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 230 | if pretrained: 231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 232 | return model 233 | 234 | 235 | def resnet101(pretrained=True, **kwargs): 236 | """Constructs a ResNet-101 model. 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 242 | if pretrained: 243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 244 | return model 245 | 246 | 247 | def resnet152(pretrained=True, **kwargs): 248 | """Constructs a ResNet-152 model. 
249 | 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | """ 253 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 254 | if pretrained: 255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 256 | return model 257 | -------------------------------------------------------------------------------- /network/SEresnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code adapted from: 6 | # https://github.com/Cadene/pretrained-models.pytorch 7 | # 8 | # BSD 3-Clause License 9 | # 10 | # Copyright (c) 2017, Remi Cadene 11 | # All rights reserved. 12 | # 13 | # Redistribution and use in source and binary forms, with or without 14 | # modification, are permitted provided that the following conditions are met: 15 | # 16 | # * Redistributions of source code must retain the above copyright notice, this 17 | # list of conditions and the following disclaimer. 18 | # 19 | # * Redistributions in binary form must reproduce the above copyright notice, 20 | # this list of conditions and the following disclaimer in the documentation 21 | # and/or other materials provided with the distribution. 22 | # 23 | # * Neither the name of the copyright holder nor the names of its 24 | # contributors may be used to endorse or promote products derived from 25 | # this software without specific prior written permission. 26 | # 27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
37 | """ 38 | 39 | 40 | from collections import OrderedDict 41 | import math 42 | import network.mynn as mynn 43 | import torch.nn as nn 44 | from torch.utils import model_zoo 45 | 46 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 47 | 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 48 | 49 | pretrained_settings = { 50 | 'se_resnext50_32x4d': { 51 | 'imagenet': { 52 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 53 | 'input_space': 'RGB', 54 | 'input_size': [3, 224, 224], 55 | 'input_range': [0, 1], 56 | 'mean': [0.485, 0.456, 0.406], 57 | 'std': [0.229, 0.224, 0.225], 58 | 'num_classes': 1000 59 | } 60 | }, 61 | 'se_resnext101_32x4d': { 62 | 'imagenet': { 63 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 64 | 'input_space': 'RGB', 65 | 'input_size': [3, 224, 224], 66 | 'input_range': [0, 1], 67 | 'mean': [0.485, 0.456, 0.406], 68 | 'std': [0.229, 0.224, 0.225], 69 | 'num_classes': 1000 70 | } 71 | }, 72 | } 73 | 74 | 75 | class SEModule(nn.Module): 76 | 77 | def __init__(self, channels, reduction): 78 | super(SEModule, self).__init__() 79 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 80 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 81 | padding=0) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 84 | padding=0) 85 | self.sigmoid = nn.Sigmoid() 86 | 87 | def forward(self, x): 88 | module_input = x 89 | x = self.avg_pool(x) 90 | x = self.fc1(x) 91 | x = self.relu(x) 92 | x = self.fc2(x) 93 | x = self.sigmoid(x) 94 | return module_input * x 95 | 96 | 97 | class Bottleneck(nn.Module): 98 | """ 99 | Base class for bottlenecks that implements `forward()` method. 100 | """ 101 | def forward(self, x): 102 | residual = x 103 | 104 | out = self.conv1(x) 105 | out = self.bn1(out) 106 | out = self.relu(out) 107 | 108 | out = self.conv2(out) 109 | out = self.bn2(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv3(out) 113 | out = self.bn3(out) 114 | 115 | if self.downsample is not None: 116 | residual = self.downsample(x) 117 | 118 | out = self.se_module(out) + residual 119 | out = self.relu(out) 120 | 121 | return out 122 | 123 | 124 | class SEBottleneck(Bottleneck): 125 | """ 126 | Bottleneck for SENet154. 127 | """ 128 | expansion = 4 129 | 130 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 131 | downsample=None): 132 | super(SEBottleneck, self).__init__() 133 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 134 | self.bn1 = mynn.Norm2d(planes * 2) 135 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 136 | stride=stride, padding=1, groups=groups, 137 | bias=False) 138 | self.bn2 = mynn.Norm2d(planes * 4) 139 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 140 | bias=False) 141 | self.bn3 = mynn.Norm2d(planes * 4) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.se_module = SEModule(planes * 4, reduction=reduction) 144 | self.downsample = downsample 145 | self.stride = stride 146 | 147 | 148 | class SEResNetBottleneck(Bottleneck): 149 | """ 150 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 151 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 152 | (the latter is used in the torchvision implementation of ResNet). 
153 | """ 154 | expansion = 4 155 | 156 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 157 | downsample=None): 158 | super(SEResNetBottleneck, self).__init__() 159 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 160 | stride=stride) 161 | self.bn1 = mynn.Norm2d(planes) 162 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 163 | groups=groups, bias=False) 164 | self.bn2 = mynn.Norm2d(planes) 165 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 166 | self.bn3 = mynn.Norm2d(planes * 4) 167 | self.relu = nn.ReLU(inplace=True) 168 | self.se_module = SEModule(planes * 4, reduction=reduction) 169 | self.downsample = downsample 170 | self.stride = stride 171 | 172 | 173 | class SEResNeXtBottleneck(Bottleneck): 174 | """ 175 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 176 | """ 177 | expansion = 4 178 | 179 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 180 | downsample=None, base_width=4): 181 | super(SEResNeXtBottleneck, self).__init__() 182 | width = math.floor(planes * (base_width / 64)) * groups 183 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 184 | stride=1) 185 | self.bn1 = mynn.Norm2d(width) 186 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 187 | padding=1, groups=groups, bias=False) 188 | self.bn2 = mynn.Norm2d(width) 189 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 190 | self.bn3 = mynn.Norm2d(planes * 4) 191 | self.relu = nn.ReLU(inplace=True) 192 | self.se_module = SEModule(planes * 4, reduction=reduction) 193 | self.downsample = downsample 194 | self.stride = stride 195 | 196 | 197 | class SENet(nn.Module): 198 | 199 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 200 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 201 | downsample_padding=1, num_classes=1000): 202 | """ 203 | Parameters 204 | ---------- 205 | block (nn.Module): Bottleneck class. 206 | - For SENet154: SEBottleneck 207 | - For SE-ResNet models: SEResNetBottleneck 208 | - For SE-ResNeXt models: SEResNeXtBottleneck 209 | layers (list of ints): Number of residual blocks for 4 layers of the 210 | network (layer1...layer4). 211 | groups (int): Number of groups for the 3x3 convolution in each 212 | bottleneck block. 213 | - For SENet154: 64 214 | - For SE-ResNet models: 1 215 | - For SE-ResNeXt models: 32 216 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 217 | - For all models: 16 218 | dropout_p (float or None): Drop probability for the Dropout layer. 219 | If `None` the Dropout layer is not used. 220 | - For SENet154: 0.2 221 | - For SE-ResNet models: None 222 | - For SE-ResNeXt models: None 223 | inplanes (int): Number of input channels for layer1. 224 | - For SENet154: 128 225 | - For SE-ResNet models: 64 226 | - For SE-ResNeXt models: 64 227 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 228 | a single 7x7 convolution in layer0. 229 | - For SENet154: True 230 | - For SE-ResNet models: False 231 | - For SE-ResNeXt models: False 232 | downsample_kernel_size (int): Kernel size for downsampling convolutions 233 | in layer2, layer3 and layer4. 234 | - For SENet154: 3 235 | - For SE-ResNet models: 1 236 | - For SE-ResNeXt models: 1 237 | downsample_padding (int): Padding for downsampling convolutions in 238 | layer2, layer3 and layer4. 
239 | - For SENet154: 1 240 | - For SE-ResNet models: 0 241 | - For SE-ResNeXt models: 0 242 | num_classes (int): Number of outputs in `last_linear` layer. 243 | - For all models: 1000 244 | """ 245 | super(SENet, self).__init__() 246 | self.inplanes = inplanes 247 | 248 | if input_3x3: 249 | layer0_modules = [ 250 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 251 | bias=False)), 252 | ('bn1', mynn.Norm2d(64)), 253 | ('relu1', nn.ReLU(inplace=True)), 254 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 255 | bias=False)), 256 | ('bn2', mynn.Norm2d(64)), 257 | ('relu2', nn.ReLU(inplace=True)), 258 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 259 | bias=False)), 260 | ('bn3', mynn.Norm2d(inplanes)), 261 | ('relu3', nn.ReLU(inplace=True)), 262 | ] 263 | else: 264 | layer0_modules = [ 265 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 266 | padding=3, bias=False)), 267 | ('bn1', mynn.Norm2d(inplanes)), 268 | ('relu1', nn.ReLU(inplace=True)), 269 | ] 270 | # To preserve compatibility with Caffe weights `ceil_mode=True` 271 | # is used instead of `padding=1`. 272 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 273 | ceil_mode=True))) 274 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 275 | self.layer1 = self._make_layer( 276 | block, 277 | planes=64, 278 | blocks=layers[0], 279 | groups=groups, 280 | reduction=reduction, 281 | downsample_kernel_size=1, 282 | downsample_padding=0 283 | ) 284 | self.layer2 = self._make_layer( 285 | block, 286 | planes=128, 287 | blocks=layers[1], 288 | stride=2, 289 | groups=groups, 290 | reduction=reduction, 291 | downsample_kernel_size=downsample_kernel_size, 292 | downsample_padding=downsample_padding 293 | ) 294 | self.layer3 = self._make_layer( 295 | block, 296 | planes=256, 297 | blocks=layers[2], 298 | stride=1, 299 | groups=groups, 300 | reduction=reduction, 301 | downsample_kernel_size=downsample_kernel_size, 302 | downsample_padding=downsample_padding 303 | ) 304 | self.layer4 = self._make_layer( 305 | block, 306 | planes=512, 307 | blocks=layers[3], 308 | stride=1, 309 | groups=groups, 310 | reduction=reduction, 311 | downsample_kernel_size=downsample_kernel_size, 312 | downsample_padding=downsample_padding 313 | ) 314 | self.avg_pool = nn.AvgPool2d(7, stride=1) 315 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 316 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 317 | 318 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 319 | downsample_kernel_size=1, downsample_padding=0): 320 | downsample = None 321 | if stride != 1 or self.inplanes != planes * block.expansion: 322 | downsample = nn.Sequential( 323 | nn.Conv2d(self.inplanes, planes * block.expansion, 324 | kernel_size=downsample_kernel_size, stride=stride, 325 | padding=downsample_padding, bias=False), 326 | mynn.Norm2d(planes * block.expansion), 327 | ) 328 | 329 | layers = [] 330 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 331 | downsample)) 332 | self.inplanes = planes * block.expansion 333 | for i in range(1, blocks): 334 | layers.append(block(self.inplanes, planes, groups, reduction)) 335 | 336 | return nn.Sequential(*layers) 337 | 338 | def features(self, x): 339 | x = self.layer0(x) 340 | x = self.layer1(x) 341 | x = self.layer2(x) 342 | x = self.layer3(x) 343 | x = self.layer4(x) 344 | return x 345 | 346 | def logits(self, x): 347 | x = self.avg_pool(x) 348 | if self.dropout is not None: 349 | x = self.dropout(x) 350 | x = 
x.view(x.size(0), -1) 351 | x = self.last_linear(x) 352 | return x 353 | 354 | def forward(self, x): 355 | x = self.features(x) 356 | x = self.logits(x) 357 | return x 358 | 359 | 360 | def initialize_pretrained_model(model, num_classes, settings): 361 | assert num_classes == settings['num_classes'], \ 362 | 'num_classes should be {}, but is {}'.format( 363 | settings['num_classes'], num_classes) 364 | weights = model_zoo.load_url(settings['url']) 365 | model.load_state_dict(weights) 366 | model.input_space = settings['input_space'] 367 | model.input_size = settings['input_size'] 368 | model.input_range = settings['input_range'] 369 | model.mean = settings['mean'] 370 | model.std = settings['std'] 371 | 372 | 373 | 374 | def se_resnext50_32x4d(num_classes=1000): 375 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 376 | dropout_p=None, inplanes=64, input_3x3=False, 377 | downsample_kernel_size=1, downsample_padding=0, 378 | num_classes=num_classes) 379 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet'] 380 | initialize_pretrained_model(model, num_classes, settings) 381 | return model 382 | 383 | 384 | def se_resnext101_32x4d(num_classes=1000): 385 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 386 | dropout_p=None, inplanes=64, input_3x3=False, 387 | downsample_kernel_size=1, downsample_padding=0, 388 | num_classes=num_classes) 389 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet'] 390 | initialize_pretrained_model(model, num_classes, settings) 391 | return model 392 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import importlib 7 | import torch 8 | import logging 9 | 10 | def get_net(args, criterion): 11 | net = get_model(network=args.arch, num_classes=args.dataset_cls.num_classes, 12 | criterion=criterion, trunk=args.trunk) 13 | num_params = sum([param.nelement() for param in net.parameters()]) 14 | logging.info('Model params = {:2.1f}M'.format(num_params / 1000000)) 15 | 16 | net = net.cuda() 17 | net = torch.nn.DataParallel(net) 18 | return net 19 | 20 | 21 | def get_model(network, num_classes, criterion, trunk): 22 | 23 | module = network[:network.rfind('.')] 24 | model = network[network.rfind('.')+1:] 25 | mod = importlib.import_module(module) 26 | net_func = getattr(mod, model) 27 | net = net_func(num_classes=num_classes, trunk=trunk, criterion=criterion) 28 | return net 29 | 30 | 31 | -------------------------------------------------------------------------------- /network/gscnn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | 5 | # Code Adapted from: 6 | # https://github.com/sthalles/deeplab_v3 7 | # 8 | # MIT License 9 | # 10 | # Copyright (c) 2018 Thalles Santos Silva 11 | # 12 | # Permission is hereby granted, free of charge, to any person obtaining a copy 13 | # of this software and associated documentation files (the "Software"), to deal 14 | # in the Software without restriction, including without limitation the rights 15 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | # copies of the Software, and to permit persons to whom the Software is 17 | # furnished to do so, subject to the following conditions: 18 | # 19 | # The above copyright notice and this permission notice shall be included in all 20 | # copies or substantial portions of the Software. 21 | # 22 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | """ 29 | 30 | import torch 31 | import torch.nn.functional as F 32 | from torch import nn 33 | from network import SEresnext 34 | from network import Resnet 35 | from network.wider_resnet import wider_resnet38_a2 36 | from config import cfg 37 | from network.mynn import initialize_weights, Norm2d 38 | from torch.autograd import Variable 39 | 40 | from my_functionals import GatedSpatialConv as gsc 41 | 42 | import cv2 43 | import numpy as np 44 | 45 | class Crop(nn.Module): 46 | def __init__(self, axis, offset): 47 | super(Crop, self).__init__() 48 | self.axis = axis 49 | self.offset = offset 50 | 51 | def forward(self, x, ref): 52 | """ 53 | 54 | :param x: input layer 55 | :param ref: reference usually data in 56 | :return: 57 | """ 58 | for axis in range(self.axis, x.dim()): 59 | ref_size = ref.size(axis) 60 | indices = torch.arange(self.offset, self.offset + ref_size).long() 61 | indices = x.data.new().resize_(indices.size()).copy_(indices).long() 62 | x = x.index_select(axis, Variable(indices)) 63 | return x 64 | 65 | 66 | class MyIdentity(nn.Module): 67 | def __init__(self, axis, offset): 68 | super(MyIdentity, self).__init__() 69 | self.axis = axis 70 | self.offset = offset 71 | 72 | def forward(self, x, ref): 73 | """ 74 | 75 | :param x: input layer 76 | :param ref: reference usually data in 77 | :return: 78 | """ 79 | return x 80 | 81 | class SideOutputCrop(nn.Module): 82 | """ 83 | This is the original implementation ConvTranspose2d (fixed) and crops 84 | """ 85 | 86 | def __init__(self, num_output, kernel_sz=None, stride=None, upconv_pad=0, do_crops=True): 87 | super(SideOutputCrop, self).__init__() 88 | self._do_crops = do_crops 89 | self.conv = nn.Conv2d(num_output, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) 90 | 91 | if kernel_sz is not None: 92 | self.upsample = True 93 | self.upsampled = nn.ConvTranspose2d(1, out_channels=1, kernel_size=kernel_sz, stride=stride, 94 | padding=upconv_pad, 95 | bias=False) 96 | ##doing crops 97 | if self._do_crops: 98 | self.crops = Crop(2, offset=kernel_sz // 4) 99 | else: 100 | self.crops = MyIdentity(None, None) 101 | else: 102 | self.upsample = False 103 | 104 | def forward(self, res, reference=None): 105 | side_output = self.conv(res) 106 | if self.upsample: 
107 | side_output = self.upsampled(side_output) 108 | side_output = self.crops(side_output, reference) 109 | 110 | return side_output 111 | 112 | 113 | class _AtrousSpatialPyramidPoolingModule(nn.Module): 114 | ''' 115 | operations performed: 116 | 1x1 x depth 117 | 3x3 x depth dilation 6 118 | 3x3 x depth dilation 12 119 | 3x3 x depth dilation 18 120 | image pooling 121 | concatenate all together 122 | Final 1x1 conv 123 | ''' 124 | 125 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=[6, 12, 18]): 126 | super(_AtrousSpatialPyramidPoolingModule, self).__init__() 127 | 128 | # Check if we are using distributed BN and use the nn from encoding.nn 129 | # library rather than using standard pytorch.nn 130 | 131 | if output_stride == 8: 132 | rates = [2 * r for r in rates] 133 | elif output_stride == 16: 134 | pass 135 | else: 136 | raise 'output stride of {} not supported'.format(output_stride) 137 | 138 | self.features = [] 139 | # 1x1 140 | self.features.append( 141 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 142 | Norm2d(reduction_dim), nn.ReLU(inplace=True))) 143 | # other rates 144 | for r in rates: 145 | self.features.append(nn.Sequential( 146 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, 147 | dilation=r, padding=r, bias=False), 148 | Norm2d(reduction_dim), 149 | nn.ReLU(inplace=True) 150 | )) 151 | self.features = torch.nn.ModuleList(self.features) 152 | 153 | # img level features 154 | self.img_pooling = nn.AdaptiveAvgPool2d(1) 155 | self.img_conv = nn.Sequential( 156 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 157 | Norm2d(reduction_dim), nn.ReLU(inplace=True)) 158 | self.edge_conv = nn.Sequential( 159 | nn.Conv2d(1, reduction_dim, kernel_size=1, bias=False), 160 | Norm2d(reduction_dim), nn.ReLU(inplace=True)) 161 | 162 | 163 | def forward(self, x, edge): 164 | x_size = x.size() 165 | 166 | img_features = self.img_pooling(x) 167 | img_features = self.img_conv(img_features) 168 | img_features = F.interpolate(img_features, x_size[2:], 169 | mode='bilinear',align_corners=True) 170 | out = img_features 171 | 172 | edge_features = F.interpolate(edge, x_size[2:], 173 | mode='bilinear',align_corners=True) 174 | edge_features = self.edge_conv(edge_features) 175 | out = torch.cat((out, edge_features), 1) 176 | 177 | for f in self.features: 178 | y = f(x) 179 | out = torch.cat((out, y), 1) 180 | return out 181 | 182 | class GSCNN(nn.Module): 183 | ''' 184 | Wide_resnet version of DeepLabV3 185 | mod1 186 | pool2 187 | mod2 str2 188 | pool3 189 | mod3-7 190 | 191 | structure: [3, 3, 6, 3, 1, 1] 192 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 193 | (1024, 2048, 4096)] 194 | ''' 195 | 196 | def __init__(self, num_classes, trunk=None, criterion=None): 197 | 198 | super(GSCNN, self).__init__() 199 | self.criterion = criterion 200 | self.num_classes = num_classes 201 | 202 | wide_resnet = wider_resnet38_a2(classes=1000, dilation=True) 203 | wide_resnet = torch.nn.DataParallel(wide_resnet) 204 | 205 | wide_resnet = wide_resnet.module 206 | self.mod1 = wide_resnet.mod1 207 | self.mod2 = wide_resnet.mod2 208 | self.mod3 = wide_resnet.mod3 209 | self.mod4 = wide_resnet.mod4 210 | self.mod5 = wide_resnet.mod5 211 | self.mod6 = wide_resnet.mod6 212 | self.mod7 = wide_resnet.mod7 213 | self.pool2 = wide_resnet.pool2 214 | self.pool3 = wide_resnet.pool3 215 | self.interpolate = F.interpolate 216 | del wide_resnet 217 | 218 | self.dsn1 = nn.Conv2d(64, 1, 1) 219 | self.dsn3 = nn.Conv2d(256, 1, 
1) 220 | self.dsn4 = nn.Conv2d(512, 1, 1) 221 | self.dsn7 = nn.Conv2d(4096, 1, 1) 222 | 223 | self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None) 224 | self.d1 = nn.Conv2d(64, 32, 1) 225 | self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None) 226 | self.d2 = nn.Conv2d(32, 16, 1) 227 | self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None) 228 | self.d3 = nn.Conv2d(16, 8, 1) 229 | self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False) 230 | 231 | self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False) 232 | 233 | self.gate1 = gsc.GatedSpatialConv2d(32, 32) 234 | self.gate2 = gsc.GatedSpatialConv2d(16, 16) 235 | self.gate3 = gsc.GatedSpatialConv2d(8, 8) 236 | 237 | self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, 238 | output_stride=8) 239 | 240 | self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False) 241 | self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False) 242 | 243 | self.final_seg = nn.Sequential( 244 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 245 | Norm2d(256), 246 | nn.ReLU(inplace=True), 247 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 248 | Norm2d(256), 249 | nn.ReLU(inplace=True), 250 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 251 | 252 | self.sigmoid = nn.Sigmoid() 253 | initialize_weights(self.final_seg) 254 | 255 | def forward(self, inp, gts=None): 256 | 257 | x_size = inp.size() 258 | 259 | # res 1 260 | m1 = self.mod1(inp) 261 | 262 | # res 2 263 | m2 = self.mod2(self.pool2(m1)) 264 | 265 | # res 3 266 | m3 = self.mod3(self.pool3(m2)) 267 | 268 | # res 4-7 269 | m4 = self.mod4(m3) 270 | m5 = self.mod5(m4) 271 | m6 = self.mod6(m5) 272 | m7 = self.mod7(m6) 273 | 274 | s3 = F.interpolate(self.dsn3(m3), x_size[2:], 275 | mode='bilinear', align_corners=True) 276 | s4 = F.interpolate(self.dsn4(m4), x_size[2:], 277 | mode='bilinear', align_corners=True) 278 | s7 = F.interpolate(self.dsn7(m7), x_size[2:], 279 | mode='bilinear', align_corners=True) 280 | 281 | m1f = F.interpolate(m1, x_size[2:], mode='bilinear', align_corners=True) 282 | 283 | im_arr = inp.cpu().numpy().transpose((0,2,3,1)).astype(np.uint8) 284 | canny = np.zeros((x_size[0], 1, x_size[2], x_size[3])) 285 | for i in range(x_size[0]): 286 | canny[i] = cv2.Canny(im_arr[i],10,100) 287 | canny = torch.from_numpy(canny).cuda().float() 288 | 289 | cs = self.res1(m1f) 290 | cs = F.interpolate(cs, x_size[2:], 291 | mode='bilinear', align_corners=True) 292 | cs = self.d1(cs) 293 | cs = self.gate1(cs, s3) 294 | cs = self.res2(cs) 295 | cs = F.interpolate(cs, x_size[2:], 296 | mode='bilinear', align_corners=True) 297 | cs = self.d2(cs) 298 | cs = self.gate2(cs, s4) 299 | cs = self.res3(cs) 300 | cs = F.interpolate(cs, x_size[2:], 301 | mode='bilinear', align_corners=True) 302 | cs = self.d3(cs) 303 | cs = self.gate3(cs, s7) 304 | cs = self.fuse(cs) 305 | cs = F.interpolate(cs, x_size[2:], 306 | mode='bilinear', align_corners=True) 307 | edge_out = self.sigmoid(cs) 308 | cat = torch.cat((edge_out, canny), dim=1) 309 | acts = self.cw(cat) 310 | acts = self.sigmoid(acts) 311 | 312 | # aspp 313 | x = self.aspp(m7, acts) 314 | dec0_up = self.bot_aspp(x) 315 | 316 | dec0_fine = self.bot_fine(m2) 317 | dec0_up = self.interpolate(dec0_up, m2.size()[2:], mode='bilinear',align_corners=True) 318 | dec0 = [dec0_fine, dec0_up] 319 | dec0 = torch.cat(dec0, 1) 320 | 321 | dec1 = self.final_seg(dec0) 322 | seg_out = self.interpolate(dec1, x_size[2:], mode='bilinear') 323 | 324 | if self.training: 325 | return 
self.criterion((seg_out, edge_out), gts) 326 | else: 327 | return seg_out, edge_out 328 | 329 | -------------------------------------------------------------------------------- /network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from config import cfg 7 | import torch.nn as nn 8 | from math import sqrt 9 | import torch 10 | from torch.autograd.function import InplaceFunction 11 | from itertools import repeat 12 | from torch.nn.modules import Module 13 | from torch.utils.checkpoint import checkpoint 14 | 15 | 16 | def Norm2d(in_channels): 17 | """ 18 | Custom Norm Function to allow flexible switching 19 | """ 20 | layer = getattr(cfg.MODEL,'BNFUNC') 21 | normalizationLayer = layer(in_channels) 22 | return normalizationLayer 23 | 24 | 25 | def initialize_weights(*models): 26 | for model in models: 27 | for module in model.modules(): 28 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 29 | nn.init.kaiming_normal(module.weight) 30 | if module.bias is not None: 31 | module.bias.data.zero_() 32 | elif isinstance(module, nn.BatchNorm2d): 33 | module.weight.data.fill_(1) 34 | module.bias.data.zero_() 35 | -------------------------------------------------------------------------------- /network/wider_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code adapted from: 6 | # https://github.com/mapillary/inplace_abn/ 7 | # 8 | # BSD 3-Clause License 9 | # 10 | # Copyright (c) 2017, mapillary 11 | # All rights reserved. 12 | # 13 | # Redistribution and use in source and binary forms, with or without 14 | # modification, are permitted provided that the following conditions are met: 15 | # 16 | # * Redistributions of source code must retain the above copyright notice, this 17 | # list of conditions and the following disclaimer. 18 | # 19 | # * Redistributions in binary form must reproduce the above copyright notice, 20 | # this list of conditions and the following disclaimer in the documentation 21 | # and/or other materials provided with the distribution. 22 | # 23 | # * Neither the name of the copyright holder nor the names of its 24 | # contributors may be used to endorse or promote products derived from 25 | # this software without specific prior written permission. 26 | # 27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
37 | """ 38 | 39 | import sys 40 | from collections import OrderedDict 41 | from functools import partial 42 | import torch.nn as nn 43 | import torch 44 | import network.mynn as mynn 45 | 46 | def bnrelu(channels): 47 | return nn.Sequential(mynn.Norm2d(channels), 48 | nn.ReLU(inplace=True)) 49 | 50 | class GlobalAvgPool2d(nn.Module): 51 | 52 | def __init__(self): 53 | """Global average pooling over the input's spatial dimensions""" 54 | super(GlobalAvgPool2d, self).__init__() 55 | 56 | def forward(self, inputs): 57 | in_size = inputs.size() 58 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 59 | 60 | 61 | class IdentityResidualBlock(nn.Module): 62 | 63 | def __init__(self, 64 | in_channels, 65 | channels, 66 | stride=1, 67 | dilation=1, 68 | groups=1, 69 | norm_act=bnrelu, 70 | dropout=None, 71 | dist_bn=False 72 | ): 73 | """Configurable identity-mapping residual block 74 | 75 | Parameters 76 | ---------- 77 | in_channels : int 78 | Number of input channels. 79 | channels : list of int 80 | Number of channels in the internal feature maps. 81 | Can either have two or three elements: if three construct 82 | a residual block with two `3 x 3` convolutions, 83 | otherwise construct a bottleneck block with `1 x 1`, then 84 | `3 x 3` then `1 x 1` convolutions. 85 | stride : int 86 | Stride of the first `3 x 3` convolution 87 | dilation : int 88 | Dilation to apply to the `3 x 3` convolutions. 89 | groups : int 90 | Number of convolution groups. 91 | This is used to create ResNeXt-style blocks and is only compatible with 92 | bottleneck blocks. 93 | norm_act : callable 94 | Function to create normalization / activation Module. 95 | dropout: callable 96 | Function to create Dropout Module. 97 | dist_bn: Boolean 98 | A variable to enable or disable use of distributed BN 99 | """ 100 | super(IdentityResidualBlock, self).__init__() 101 | self.dist_bn = dist_bn 102 | 103 | # Check if we are using distributed BN and use the nn from encoding.nn 104 | # library rather than using standard pytorch.nn 105 | 106 | 107 | # Check parameters for inconsistencies 108 | if len(channels) != 2 and len(channels) != 3: 109 | raise ValueError("channels must contain either two or three values") 110 | if len(channels) == 2 and groups != 1: 111 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 112 | 113 | is_bottleneck = len(channels) == 3 114 | need_proj_conv = stride != 1 or in_channels != channels[-1] 115 | 116 | self.bn1 = norm_act(in_channels) 117 | if not is_bottleneck: 118 | layers = [ 119 | ("conv1", nn.Conv2d(in_channels, 120 | channels[0], 121 | 3, 122 | stride=stride, 123 | padding=dilation, 124 | bias=False, 125 | dilation=dilation)), 126 | ("bn2", norm_act(channels[0])), 127 | ("conv2", nn.Conv2d(channels[0], channels[1], 128 | 3, 129 | stride=1, 130 | padding=dilation, 131 | bias=False, 132 | dilation=dilation)) 133 | ] 134 | if dropout is not None: 135 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 136 | else: 137 | layers = [ 138 | ("conv1", 139 | nn.Conv2d(in_channels, 140 | channels[0], 141 | 1, 142 | stride=stride, 143 | padding=0, 144 | bias=False)), 145 | ("bn2", norm_act(channels[0])), 146 | ("conv2", nn.Conv2d(channels[0], 147 | channels[1], 148 | 3, stride=1, 149 | padding=dilation, bias=False, 150 | groups=groups, 151 | dilation=dilation)), 152 | ("bn3", norm_act(channels[1])), 153 | ("conv3", nn.Conv2d(channels[1], channels[2], 154 | 1, stride=1, padding=0, bias=False)) 155 | ] 156 | if dropout is not None: 157 | layers = layers[0:4] + 
[("dropout", dropout())] + layers[4:] 158 | self.convs = nn.Sequential(OrderedDict(layers)) 159 | 160 | if need_proj_conv: 161 | self.proj_conv = nn.Conv2d( 162 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 163 | 164 | def forward(self, x): 165 | """ 166 | This is the standard forward function for non-distributed batch norm 167 | """ 168 | if hasattr(self, "proj_conv"): 169 | bn1 = self.bn1(x) 170 | shortcut = self.proj_conv(bn1) 171 | else: 172 | shortcut = x.clone() 173 | bn1 = self.bn1(x) 174 | 175 | out = self.convs(bn1) 176 | out.add_(shortcut) 177 | return out 178 | 179 | 180 | 181 | 182 | class WiderResNet(nn.Module): 183 | 184 | def __init__(self, 185 | structure, 186 | norm_act=bnrelu, 187 | classes=0 188 | ): 189 | """Wider ResNet with pre-activation (identity mapping) blocks 190 | 191 | Parameters 192 | ---------- 193 | structure : list of int 194 | Number of residual blocks in each of the six modules of the network. 195 | norm_act : callable 196 | Function to create normalization / activation Module. 197 | classes : int 198 | If not `0` also include global average pooling and \ 199 | a fully-connected layer with `classes` outputs at the end 200 | of the network. 201 | """ 202 | super(WiderResNet, self).__init__() 203 | self.structure = structure 204 | 205 | if len(structure) != 6: 206 | raise ValueError("Expected a structure with six values") 207 | 208 | # Initial layers 209 | self.mod1 = nn.Sequential(OrderedDict([ 210 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 211 | ])) 212 | 213 | # Groups of residual blocks 214 | in_channels = 64 215 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), 216 | (512, 1024, 2048), (1024, 2048, 4096)] 217 | for mod_id, num in enumerate(structure): 218 | # Create blocks for module 219 | blocks = [] 220 | for block_id in range(num): 221 | blocks.append(( 222 | "block%d" % (block_id + 1), 223 | IdentityResidualBlock(in_channels, channels[mod_id], 224 | norm_act=norm_act) 225 | )) 226 | 227 | # Update channels and p_keep 228 | in_channels = channels[mod_id][-1] 229 | 230 | # Create module 231 | if mod_id <= 4: 232 | self.add_module("pool%d" % 233 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 234 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 235 | 236 | # Pooling and predictor 237 | self.bn_out = norm_act(in_channels) 238 | if classes != 0: 239 | self.classifier = nn.Sequential(OrderedDict([ 240 | ("avg_pool", GlobalAvgPool2d()), 241 | ("fc", nn.Linear(in_channels, classes)) 242 | ])) 243 | 244 | def forward(self, img): 245 | out = self.mod1(img) 246 | out = self.mod2(self.pool2(out)) 247 | out = self.mod3(self.pool3(out)) 248 | out = self.mod4(self.pool4(out)) 249 | out = self.mod5(self.pool5(out)) 250 | out = self.mod6(self.pool6(out)) 251 | out = self.mod7(out) 252 | out = self.bn_out(out) 253 | 254 | if hasattr(self, "classifier"): 255 | out = self.classifier(out) 256 | 257 | return out 258 | 259 | 260 | class WiderResNetA2(nn.Module): 261 | 262 | def __init__(self, 263 | structure, 264 | norm_act=bnrelu, 265 | classes=0, 266 | dilation=False, 267 | dist_bn=False 268 | ): 269 | """Wider ResNet with pre-activation (identity mapping) blocks 270 | 271 | This variant uses down-sampling by max-pooling in the first two blocks and \ 272 | by strided convolution in the others. 273 | 274 | Parameters 275 | ---------- 276 | structure : list of int 277 | Number of residual blocks in each of the six modules of the network. 
278 | norm_act : callable 279 | Function to create normalization / activation Module. 280 | classes : int 281 | If not `0` also include global average pooling and a fully-connected layer 282 | \with `classes` outputs at the end 283 | of the network. 284 | dilation : bool 285 | If `True` apply dilation to the last three modules and change the 286 | \down-sampling factor from 32 to 8. 287 | """ 288 | super(WiderResNetA2, self).__init__() 289 | self.dist_bn = dist_bn 290 | 291 | # If using distributed batch norm, use the encoding.nn as oppose to torch.nn 292 | 293 | 294 | nn.Dropout = nn.Dropout2d 295 | norm_act = bnrelu 296 | self.structure = structure 297 | self.dilation = dilation 298 | 299 | if len(structure) != 6: 300 | raise ValueError("Expected a structure with six values") 301 | 302 | # Initial layers 303 | self.mod1 = torch.nn.Sequential(OrderedDict([ 304 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 305 | ])) 306 | 307 | # Groups of residual blocks 308 | in_channels = 64 309 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 310 | (1024, 2048, 4096)] 311 | for mod_id, num in enumerate(structure): 312 | # Create blocks for module 313 | blocks = [] 314 | for block_id in range(num): 315 | if not dilation: 316 | dil = 1 317 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1 318 | else: 319 | if mod_id == 3: 320 | dil = 2 321 | elif mod_id > 3: 322 | dil = 4 323 | else: 324 | dil = 1 325 | stride = 2 if block_id == 0 and mod_id == 2 else 1 326 | 327 | if mod_id == 4: 328 | drop = partial(nn.Dropout, p=0.3) 329 | elif mod_id == 5: 330 | drop = partial(nn.Dropout, p=0.5) 331 | else: 332 | drop = None 333 | 334 | blocks.append(( 335 | "block%d" % (block_id + 1), 336 | IdentityResidualBlock(in_channels, 337 | channels[mod_id], norm_act=norm_act, 338 | stride=stride, dilation=dil, 339 | dropout=drop, dist_bn=self.dist_bn) 340 | )) 341 | 342 | # Update channels and p_keep 343 | in_channels = channels[mod_id][-1] 344 | 345 | # Create module 346 | if mod_id < 2: 347 | self.add_module("pool%d" % 348 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 349 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 350 | 351 | # Pooling and predictor 352 | self.bn_out = norm_act(in_channels) 353 | if classes != 0: 354 | self.classifier = nn.Sequential(OrderedDict([ 355 | ("avg_pool", GlobalAvgPool2d()), 356 | ("fc", nn.Linear(in_channels, classes)) 357 | ])) 358 | 359 | def forward(self, img): 360 | out = self.mod1(img) 361 | out = self.mod2(self.pool2(out)) 362 | out = self.mod3(self.pool3(out)) 363 | out = self.mod4(out) 364 | out = self.mod5(out) 365 | out = self.mod6(out) 366 | out = self.mod7(out) 367 | out = self.bn_out(out) 368 | 369 | if hasattr(self, "classifier"): 370 | return self.classifier(out) 371 | else: 372 | return out 373 | 374 | 375 | _NETS = { 376 | "16": {"structure": [1, 1, 1, 1, 1, 1]}, 377 | "20": {"structure": [1, 1, 1, 3, 1, 1]}, 378 | "38": {"structure": [3, 3, 6, 3, 1, 1]}, 379 | } 380 | 381 | __all__ = [] 382 | for name, params in _NETS.items(): 383 | net_name = "wider_resnet" + name 384 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params)) 385 | __all__.append(net_name) 386 | for name, params in _NETS.items(): 387 | net_name = "wider_resnet" + name + "_a2" 388 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params)) 389 | __all__.append(net_name) 390 | -------------------------------------------------------------------------------- /optimizer.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | from torch import optim 8 | import math 9 | import logging 10 | from config import cfg 11 | 12 | def get_optimizer(args, net): 13 | 14 | param_groups = net.parameters() 15 | 16 | if args.sgd: 17 | optimizer = optim.SGD(param_groups, 18 | lr=args.lr, 19 | weight_decay=args.weight_decay, 20 | momentum=args.momentum, 21 | nesterov=False) 22 | elif args.adam: 23 | amsgrad=False 24 | if args.amsgrad: 25 | amsgrad=True 26 | optimizer = optim.Adam(param_groups, 27 | lr=args.lr, 28 | weight_decay=args.weight_decay, 29 | amsgrad=amsgrad 30 | ) 31 | else: 32 | raise ('Not a valid optimizer') 33 | 34 | if args.lr_schedule == 'poly': 35 | lambda1 = lambda epoch: math.pow(1 - epoch / args.max_epoch, args.poly_exp) 36 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 37 | else: 38 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 39 | 40 | if args.snapshot: 41 | logging.info('Loading weights from model {}'.format(args.snapshot)) 42 | net, optimizer = restore_snapshot(args, net, optimizer, args.snapshot) 43 | else: 44 | logging.info('Loaded weights from IMGNET classifier') 45 | 46 | return optimizer, scheduler 47 | 48 | def restore_snapshot(args, net, optimizer, snapshot): 49 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 50 | logging.info("Load Compelete") 51 | if args.sgd_finetuned: 52 | print('skipping load optimizer') 53 | else: 54 | if 'optimizer' in checkpoint and args.restore_optimizer: 55 | optimizer.load_state_dict(checkpoint['optimizer']) 56 | 57 | if 'state_dict' in checkpoint: 58 | net = forgiving_state_restore(net, checkpoint['state_dict']) 59 | else: 60 | net = forgiving_state_restore(net, checkpoint) 61 | 62 | return net, optimizer 63 | 64 | def forgiving_state_restore(net, loaded_dict): 65 | # Handle partial loading when some tensors don't match up in size. 66 | # Because we want to use models that were trained off a different 67 | # number of classes. 68 | net_state_dict = net.state_dict() 69 | new_loaded_dict = {} 70 | for k in net_state_dict: 71 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 72 | new_loaded_dict[k] = loaded_dict[k] 73 | else: 74 | logging.info('Skipped loading parameter {}'.format(k)) 75 | net_state_dict.update(new_loaded_dict) 76 | net.load_state_dict(net_state_dict) 77 | return net 78 | 79 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | """ 5 | 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | import argparse 9 | from functools import partial 10 | from config import cfg, assert_and_infer_cfg 11 | import logging 12 | import math 13 | import os 14 | import sys 15 | 16 | import torch 17 | import numpy as np 18 | 19 | from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist 20 | from utils.f_boundary import eval_mask_boundary 21 | import datasets 22 | import loss 23 | import network 24 | import optimizer 25 | 26 | # Argument Parser 27 | parser = argparse.ArgumentParser(description='GSCNN') 28 | parser.add_argument('--lr', type=float, default=0.01) 29 | parser.add_argument('--arch', type=str, default='network.gscnn.GSCNN') 30 | parser.add_argument('--dataset', type=str, default='cityscapes') 31 | parser.add_argument('--cv', type=int, default=0, 32 | help='cross validation split') 33 | parser.add_argument('--joint_edgeseg_loss', action='store_true', default=True, 34 | help='joint loss') 35 | parser.add_argument('--img_wt_loss', action='store_true', default=False, 36 | help='per-image class-weighted loss') 37 | parser.add_argument('--batch_weighting', action='store_true', default=False, 38 | help='Batch weighting for class') 39 | parser.add_argument('--eval_thresholds', type=str, default='0.0005,0.001875,0.00375,0.005', 40 | help='Thresholds for boundary evaluation') 41 | parser.add_argument('--rescale', type=float, default=1.0, 42 | help='Rescaled LR Rate') 43 | parser.add_argument('--repoly', type=float, default=1.5, 44 | help='Rescaled Poly') 45 | 46 | parser.add_argument('--edge_weight', type=float, default=1.0, 47 | help='Edge loss weight for joint loss') 48 | parser.add_argument('--seg_weight', type=float, default=1.0, 49 | help='Segmentation loss weight for joint loss') 50 | parser.add_argument('--att_weight', type=float, default=1.0, 51 | help='Attention loss weight for joint loss') 52 | parser.add_argument('--dual_weight', type=float, default=1.0, 53 | help='Dual loss weight for joint loss') 54 | 55 | parser.add_argument('--evaluate', action='store_true', default=False) 56 | 57 | parser.add_argument("--local_rank", default=0, type=int) 58 | 59 | parser.add_argument('--sgd', action='store_true', default=True) 60 | parser.add_argument('--sgd_finetuned',action='store_true',default=False) 61 | parser.add_argument('--adam', action='store_true', default=False) 62 | parser.add_argument('--amsgrad', action='store_true', default=False) 63 | 64 | parser.add_argument('--trunk', type=str, default='resnet101', 65 | help='trunk model, can be: resnet101 (default), resnet50') 66 | parser.add_argument('--max_epoch', type=int, default=175) 67 | parser.add_argument('--start_epoch', type=int, default=0) 68 | parser.add_argument('--color_aug', type=float, 69 | default=0.25, help='level of color augmentation') 70 | parser.add_argument('--rotate', type=float, 71 | default=0, help='rotation') 72 | parser.add_argument('--gblur', action='store_true', default=True) 73 | parser.add_argument('--bblur', action='store_true', default=False) 74 | parser.add_argument('--lr_schedule', type=str, default='poly', 75 | help='name of lr schedule: poly') 76 | parser.add_argument('--poly_exp', type=float, default=1.0, 77 | help='polynomial LR exponent') 78 | parser.add_argument('--bs_mult', type=int, default=1) 79 | parser.add_argument('--bs_mult_val', type=int, default=2) 80 | parser.add_argument('--crop_size', type=int, default=720, 81 | help='training crop size') 82 | 
parser.add_argument('--pre_size', type=int, default=None, 83 | help='resize image shorter edge to this before augmentation') 84 | parser.add_argument('--scale_min', type=float, default=0.5, 85 | help='dynamically scale training images down to this size') 86 | parser.add_argument('--scale_max', type=float, default=2.0, 87 | help='dynamically scale training images up to this size') 88 | parser.add_argument('--weight_decay', type=float, default=1e-4) 89 | parser.add_argument('--momentum', type=float, default=0.9) 90 | parser.add_argument('--snapshot', type=str, default=None) 91 | parser.add_argument('--restore_optimizer', action='store_true', default=False) 92 | parser.add_argument('--exp', type=str, default='default', 93 | help='experiment directory name') 94 | parser.add_argument('--tb_tag', type=str, default='', 95 | help='add tag to tb dir') 96 | parser.add_argument('--ckpt', type=str, default='logs/ckpt') 97 | parser.add_argument('--tb_path', type=str, default='logs/tb') 98 | parser.add_argument('--syncbn', action='store_true', default=True, 99 | help='Synchronized BN') 100 | parser.add_argument('--dump_augmentation_images', action='store_true', default=False, 101 | help='Synchronized BN') 102 | parser.add_argument('--test_mode', action='store_true', default=False, 103 | help='minimum testing (1 epoch run ) to verify nothing failed') 104 | parser.add_argument('-wb', '--wt_bound', type=float, default=1.0) 105 | parser.add_argument('--maxSkip', type=int, default=0) 106 | args = parser.parse_args() 107 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 108 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 109 | 110 | #Enable CUDNN Benchmarking optimization 111 | torch.backends.cudnn.benchmark = True 112 | args.world_size = 1 113 | #Test Mode run two epochs with a few iterations of training and val 114 | if args.test_mode: 115 | args.max_epoch = 2 116 | 117 | if 'WORLD_SIZE' in os.environ: 118 | args.world_size = int(os.environ['WORLD_SIZE']) 119 | print("Total world size: ", int(os.environ['WORLD_SIZE'])) 120 | 121 | def main(): 122 | ''' 123 | Main Function 124 | 125 | ''' 126 | 127 | #Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer 128 | assert_and_infer_cfg(args) 129 | writer = prep_experiment(args,parser) 130 | train_loader, val_loader, train_obj = datasets.setup_loaders(args) 131 | criterion, criterion_val = loss.get_loss(args) 132 | net = network.get_net(args, criterion) 133 | optim, scheduler = optimizer.get_optimizer(args, net) 134 | 135 | torch.cuda.empty_cache() 136 | 137 | if args.evaluate: 138 | # Early evaluation for benchmarking 139 | default_eval_epoch = 1 140 | validate(val_loader, net, criterion_val, 141 | optim, default_eval_epoch, writer) 142 | evaluate(val_loader, net) 143 | return 144 | 145 | #Main Loop 146 | for epoch in range(args.start_epoch, args.max_epoch): 147 | # Update EPOCH CTR 148 | cfg.immutable(False) 149 | cfg.EPOCH = epoch 150 | cfg.immutable(True) 151 | 152 | scheduler.step() 153 | 154 | train(train_loader, net, criterion, optim, epoch, writer) 155 | validate(val_loader, net, criterion_val, 156 | optim, epoch, writer) 157 | 158 | 159 | def train(train_loader, net, criterion, optimizer, curr_epoch, writer): 160 | ''' 161 | Runs the training loop per epoch 162 | train_loader: Data loader for train 163 | net: thet network 164 | criterion: loss fn 165 | optimizer: optimizer 166 | curr_epoch: current epoch 167 | writer: tensorboard writer 168 | return: val_avg for step function if required 169 | ''' 170 | net.train() 
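# The loss bookkeeping below relies on utils.misc.AverageMeter, which is not
# reproduced in this listing; the class here is an assumed minimal sketch of the
# conventional meter: update(val, n) folds in a value weighted by n (here the number
# of pixels in the batch) and .avg exposes the running weighted mean that is logged.
class AverageMeter(object):
    """Running weighted average, as typically used for per-iteration loss logging."""
    def __init__(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val             # most recent value (written to tensorboard)
        self.sum += val * n        # weighted by the batch pixel count
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
meter.update(0.7, n=2 * 720 * 720)   # e.g. main loss over a 2-image batch of 720x720 crops
meter.update(0.5, n=1 * 720 * 720)
print(round(meter.avg, 4))           # 0.6333 -> pixel-weighted running mean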
171 | 172 | train_main_loss = AverageMeter() 173 | train_edge_loss = AverageMeter() 174 | train_seg_loss = AverageMeter() 175 | train_att_loss = AverageMeter() 176 | train_dual_loss = AverageMeter() 177 | curr_iter = curr_epoch * len(train_loader) 178 | 179 | for i, data in enumerate(train_loader): 180 | if i==0: 181 | print('running....') 182 | 183 | inputs, mask, edge, _img_name = data 184 | 185 | if torch.sum(torch.isnan(inputs)) > 0: 186 | import pdb; pdb.set_trace() 187 | 188 | batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3) 189 | 190 | inputs, mask, edge = inputs.cuda(), mask.cuda(), edge.cuda() 191 | 192 | if i==0: 193 | print('forward done') 194 | 195 | optimizer.zero_grad() 196 | 197 | main_loss = None 198 | loss_dict = None 199 | 200 | if args.joint_edgeseg_loss: 201 | loss_dict = net(inputs, gts=(mask, edge)) 202 | 203 | if args.seg_weight > 0: 204 | log_seg_loss = loss_dict['seg_loss'].mean().clone().detach_() 205 | train_seg_loss.update(log_seg_loss.item(), batch_pixel_size) 206 | main_loss = loss_dict['seg_loss'] 207 | 208 | if args.edge_weight > 0: 209 | log_edge_loss = loss_dict['edge_loss'].mean().clone().detach_() 210 | train_edge_loss.update(log_edge_loss.item(), batch_pixel_size) 211 | if main_loss is not None: 212 | main_loss += loss_dict['edge_loss'] 213 | else: 214 | main_loss = loss_dict['edge_loss'] 215 | 216 | if args.att_weight > 0: 217 | log_att_loss = loss_dict['att_loss'].mean().clone().detach_() 218 | train_att_loss.update(log_att_loss.item(), batch_pixel_size) 219 | if main_loss is not None: 220 | main_loss += loss_dict['att_loss'] 221 | else: 222 | main_loss = loss_dict['att_loss'] 223 | 224 | if args.dual_weight > 0: 225 | log_dual_loss = loss_dict['dual_loss'].mean().clone().detach_() 226 | train_dual_loss.update(log_dual_loss.item(), batch_pixel_size) 227 | if main_loss is not None: 228 | main_loss += loss_dict['dual_loss'] 229 | else: 230 | main_loss = loss_dict['dual_loss'] 231 | 232 | else: 233 | main_loss = net(inputs, gts=mask) 234 | 235 | main_loss = main_loss.mean() 236 | log_main_loss = main_loss.clone().detach_() 237 | 238 | train_main_loss.update(log_main_loss.item(), batch_pixel_size) 239 | 240 | main_loss.backward() 241 | 242 | optimizer.step() 243 | 244 | if i==0: 245 | print('step 1 done') 246 | 247 | curr_iter += 1 248 | 249 | if args.local_rank == 0: 250 | msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [seg loss {:0.6f}], [edge loss {:0.6f}], [lr {:0.6f}]'.format( 251 | curr_epoch, i + 1, len(train_loader), train_main_loss.avg, train_seg_loss.avg, train_edge_loss.avg, optimizer.param_groups[-1]['lr'] ) 252 | 253 | logging.info(msg) 254 | 255 | # Log tensorboard metrics for each iteration of the training phase 256 | writer.add_scalar('training/loss', (train_main_loss.val), 257 | curr_iter) 258 | writer.add_scalar('training/lr', optimizer.param_groups[-1]['lr'], 259 | curr_iter) 260 | if args.joint_edgeseg_loss: 261 | 262 | writer.add_scalar('training/seg_loss', (train_seg_loss.val), 263 | curr_iter) 264 | writer.add_scalar('training/edge_loss', (train_edge_loss.val), 265 | curr_iter) 266 | writer.add_scalar('training/att_loss', (train_att_loss.val), 267 | curr_iter) 268 | writer.add_scalar('training/dual_loss', (train_dual_loss.val), 269 | curr_iter) 270 | if i > 5 and args.test_mode: 271 | return 272 | 273 | def validate(val_loader, net, criterion, optimizer, curr_epoch, writer): 274 | ''' 275 | Runs the validation loop after each training epoch 276 | val_loader: Data loader for validation 277 | net: 
thet network 278 | criterion: loss fn 279 | optimizer: optimizer 280 | curr_epoch: current epoch 281 | writer: tensorboard writer 282 | return: 283 | ''' 284 | net.eval() 285 | val_loss = AverageMeter() 286 | mf_score = AverageMeter() 287 | IOU_acc = 0 288 | dump_images = [] 289 | heatmap_images = [] 290 | for vi, data in enumerate(val_loader): 291 | input, mask, edge, img_names = data 292 | assert len(input.size()) == 4 and len(mask.size()) == 3 293 | assert input.size()[2:] == mask.size()[1:] 294 | h, w = mask.size()[1:] 295 | 296 | batch_pixel_size = input.size(0) * input.size(2) * input.size(3) 297 | input, mask_cuda, edge_cuda = input.cuda(), mask.cuda(), edge.cuda() 298 | 299 | with torch.no_grad(): 300 | seg_out, edge_out = net(input) # output = (1, 19, 713, 713) 301 | 302 | if args.joint_edgeseg_loss: 303 | loss_dict = criterion((seg_out, edge_out), (mask_cuda, edge_cuda)) 304 | val_loss.update(sum(loss_dict.values()).item(), batch_pixel_size) 305 | else: 306 | val_loss.update(criterion(seg_out, mask_cuda).item(), batch_pixel_size) 307 | 308 | # Collect data from different GPU to a single GPU since 309 | # encoding.parallel.criterionparallel function calculates distributed loss 310 | # functions 311 | 312 | seg_predictions = seg_out.data.max(1)[1].cpu() 313 | edge_predictions = edge_out.max(1)[0].cpu() 314 | 315 | #Logging 316 | if vi % 20 == 0: 317 | if args.local_rank == 0: 318 | logging.info('validating: %d / %d' % (vi + 1, len(val_loader))) 319 | if vi > 10 and args.test_mode: 320 | break 321 | _edge = edge.max(1)[0] 322 | 323 | #Image Dumps 324 | if vi < 10: 325 | dump_images.append([mask, seg_predictions, img_names]) 326 | heatmap_images.append([_edge, edge_predictions, img_names]) 327 | 328 | IOU_acc += fast_hist(seg_predictions.numpy().flatten(), mask.numpy().flatten(), 329 | args.dataset_cls.num_classes) 330 | 331 | del seg_out, edge_out, vi, data 332 | 333 | if args.local_rank == 0: 334 | evaluate_eval(args, net, optimizer, val_loss, mf_score, IOU_acc, dump_images, heatmap_images, 335 | writer, curr_epoch, args.dataset_cls) 336 | 337 | return val_loss.avg 338 | 339 | def evaluate(val_loader, net): 340 | ''' 341 | Runs the evaluation loop and prints F score 342 | val_loader: Data loader for validation 343 | net: thet network 344 | return: 345 | ''' 346 | net.eval() 347 | for thresh in args.eval_thresholds.split(','): 348 | mf_score1 = AverageMeter() 349 | mf_pc_score1 = AverageMeter() 350 | ap_score1 = AverageMeter() 351 | ap_pc_score1 = AverageMeter() 352 | Fpc = np.zeros((args.dataset_cls.num_classes)) 353 | Fc = np.zeros((args.dataset_cls.num_classes)) 354 | for vi, data in enumerate(val_loader): 355 | input, mask, edge, img_names = data 356 | assert len(input.size()) == 4 and len(mask.size()) == 3 357 | assert input.size()[2:] == mask.size()[1:] 358 | h, w = mask.size()[1:] 359 | 360 | batch_pixel_size = input.size(0) * input.size(2) * input.size(3) 361 | input, mask_cuda, edge_cuda = input.cuda(), mask.cuda(), edge.cuda() 362 | 363 | with torch.no_grad(): 364 | seg_out, edge_out = net(input) 365 | 366 | seg_predictions = seg_out.data.max(1)[1].cpu() 367 | edge_predictions = edge_out.max(1)[0].cpu() 368 | 369 | logging.info('evaluating: %d / %d' % (vi + 1, len(val_loader))) 370 | _Fpc, _Fc = eval_mask_boundary(seg_predictions.numpy(), mask.numpy(), args.dataset_cls.num_classes, bound_th=float(thresh)) 371 | Fc += _Fc 372 | Fpc += _Fpc 373 | 374 | del seg_out, edge_out, vi, data 375 | 376 | logging.info('Threshold: ' + thresh) 377 | logging.info('F_Score: ' + 
str(np.sum(Fpc/Fc)/args.dataset_cls.num_classes)) 378 | logging.info('F_Score (Classwise): ' + str(Fpc/Fc)) 379 | 380 | if __name__ == '__main__': 381 | main() 382 | 383 | 384 | 385 | 386 | -------------------------------------------------------------------------------- /transforms/joint_transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code borrowded from: 6 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py 7 | # 8 | # 9 | # MIT License 10 | # 11 | # Copyright (c) 2017 ZijunDeng 12 | # 13 | # Permission is hereby granted, free of charge, to any person obtaining a copy 14 | # of this software and associated documentation files (the "Software"), to deal 15 | # in the Software without restriction, including without limitation the rights 16 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | # copies of the Software, and to permit persons to whom the Software is 18 | # furnished to do so, subject to the following conditions: 19 | # 20 | # The above copyright notice and this permission notice shall be included in all 21 | # copies or substantial portions of the Software. 22 | # 23 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | # SOFTWARE. 30 | 31 | """ 32 | 33 | import math 34 | import numbers 35 | import random 36 | from PIL import Image, ImageOps 37 | import numpy as np 38 | import random 39 | 40 | class Compose(object): 41 | def __init__(self, transforms): 42 | self.transforms = transforms 43 | 44 | def __call__(self, img, mask): 45 | assert img.size == mask.size 46 | for t in self.transforms: 47 | img, mask = t(img, mask) 48 | return img, mask 49 | 50 | 51 | class RandomCrop(object): 52 | ''' 53 | Take a random crop from the image. 54 | 55 | First the image or crop size may need to be adjusted if the incoming image 56 | is too small... 57 | 58 | If the image is smaller than the crop, then: 59 | the image is padded up to the size of the crop 60 | unless 'nopad', in which case the crop size is shrunk to fit the image 61 | 62 | A random crop is taken such that the crop fits within the image. 63 | If a centroid is passed in, the crop must intersect the centroid. 64 | ''' 65 | def __init__(self, size, ignore_index=0, nopad=True): 66 | if isinstance(size, numbers.Number): 67 | self.size = (int(size), int(size)) 68 | else: 69 | self.size = size 70 | self.ignore_index = ignore_index 71 | self.nopad = nopad 72 | self.pad_color = (0, 0, 0) 73 | 74 | def __call__(self, img, mask, centroid=None): 75 | assert img.size == mask.size 76 | w, h = img.size 77 | # ASSUME H, W 78 | th, tw = self.size 79 | if w == tw and h == th: 80 | return img, mask 81 | 82 | if self.nopad: 83 | if th > h or tw > w: 84 | # Instead of padding, adjust crop size to the shorter edge of image. 
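                # Worked example (illustrative): with nopad=True, a 1024x512 image and a
                # requested 720x720 crop give shorter_side = 512, so the crop shrinks to
                # 512x512 and is taken fully inside the image instead of padding it.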
85 | shorter_side = min(w, h) 86 | th, tw = shorter_side, shorter_side 87 | else: 88 | # Check if we need to pad img to fit for crop_size. 89 | if th > h: 90 | pad_h = (th - h) // 2 + 1 91 | else: 92 | pad_h = 0 93 | if tw > w: 94 | pad_w = (tw - w) // 2 + 1 95 | else: 96 | pad_w = 0 97 | border = (pad_w, pad_h, pad_w, pad_h) 98 | if pad_h or pad_w: 99 | img = ImageOps.expand(img, border=border, fill=self.pad_color) 100 | mask = ImageOps.expand(mask, border=border, fill=self.ignore_index) 101 | w, h = img.size 102 | 103 | if centroid is not None: 104 | # Need to insure that centroid is covered by crop and that crop 105 | # sits fully within the image 106 | c_x, c_y = centroid 107 | max_x = w - tw 108 | max_y = h - th 109 | x1 = random.randint(c_x - tw, c_x) 110 | x1 = min(max_x, max(0, x1)) 111 | y1 = random.randint(c_y - th, c_y) 112 | y1 = min(max_y, max(0, y1)) 113 | else: 114 | if w == tw: 115 | x1 = 0 116 | else: 117 | x1 = random.randint(0, w - tw) 118 | if h == th: 119 | y1 = 0 120 | else: 121 | y1 = random.randint(0, h - th) 122 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 123 | 124 | 125 | class ResizeHeight(object): 126 | def __init__(self, size, interpolation=Image.BICUBIC): 127 | self.target_h = size 128 | self.interpolation = interpolation 129 | 130 | def __call__(self, img, mask): 131 | w, h = img.size 132 | target_w = int(w / h * self.target_h) 133 | return (img.resize((target_w, self.target_h), self.interpolation), 134 | mask.resize((target_w, self.target_h), Image.NEAREST)) 135 | 136 | 137 | class CenterCrop(object): 138 | def __init__(self, size): 139 | if isinstance(size, numbers.Number): 140 | self.size = (int(size), int(size)) 141 | else: 142 | self.size = size 143 | 144 | def __call__(self, img, mask): 145 | assert img.size == mask.size 146 | w, h = img.size 147 | th, tw = self.size 148 | x1 = int(round((w - tw) / 2.)) 149 | y1 = int(round((h - th) / 2.)) 150 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 151 | 152 | 153 | class CenterCropPad(object): 154 | def __init__(self, size, ignore_index=0): 155 | if isinstance(size, numbers.Number): 156 | self.size = (int(size), int(size)) 157 | else: 158 | self.size = size 159 | self.ignore_index = ignore_index 160 | 161 | def __call__(self, img, mask): 162 | 163 | assert img.size == mask.size 164 | w, h = img.size 165 | if isinstance(self.size, tuple): 166 | tw, th = self.size[0], self.size[1] 167 | else: 168 | th, tw = self.size, self.size 169 | 170 | 171 | if w < tw: 172 | pad_x = tw - w 173 | else: 174 | pad_x = 0 175 | if h < th: 176 | pad_y = th - h 177 | else: 178 | pad_y = 0 179 | 180 | if pad_x or pad_y: 181 | # left, top, right, bottom 182 | img = ImageOps.expand(img, border=(pad_x, pad_y, pad_x, pad_y), fill=0) 183 | mask = ImageOps.expand(mask, border=(pad_x, pad_y, pad_x, pad_y), 184 | fill=self.ignore_index) 185 | 186 | x1 = int(round((w - tw) / 2.)) 187 | y1 = int(round((h - th) / 2.)) 188 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)) 189 | 190 | 191 | 192 | class PadImage(object): 193 | def __init__(self, size, ignore_index): 194 | self.size = size 195 | self.ignore_index = ignore_index 196 | 197 | 198 | def __call__(self, img, mask): 199 | assert img.size == mask.size 200 | th, tw = self.size, self.size 201 | 202 | 203 | w, h = img.size 204 | 205 | if w > tw or h > th : 206 | wpercent = (tw/float(w)) 207 | target_h = int((float(img.size[1])*float(wpercent))) 208 | img, mask = img.resize((tw, 
target_h), Image.BICUBIC), mask.resize((tw, target_h), Image.NEAREST) 209 | 210 | w, h = img.size 211 | ##Pad 212 | img = ImageOps.expand(img, border=(0,0,tw-w, th-h), fill=0) 213 | mask = ImageOps.expand(mask, border=(0,0,tw-w, th-h), fill=self.ignore_index) 214 | 215 | return img, mask 216 | 217 | class RandomHorizontallyFlip(object): 218 | def __call__(self, img, mask): 219 | if random.random() < 0.5: 220 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose( 221 | Image.FLIP_LEFT_RIGHT) 222 | return img, mask 223 | 224 | 225 | class FreeScale(object): 226 | def __init__(self, size): 227 | self.size = tuple(reversed(size)) # size: (h, w) 228 | 229 | def __call__(self, img, mask): 230 | assert img.size == mask.size 231 | return img.resize(self.size, Image.BICUBIC), mask.resize(self.size, Image.NEAREST) 232 | 233 | 234 | class Scale(object): 235 | ''' 236 | Scale image such that longer side is == size 237 | ''' 238 | 239 | def __init__(self, size): 240 | self.size = size 241 | 242 | def __call__(self, img, mask): 243 | assert img.size == mask.size 244 | w, h = img.size 245 | if (w >= h and w == self.size) or (h >= w and h == self.size): 246 | return img, mask 247 | if w > h: 248 | ow = self.size 249 | oh = int(self.size * h / w) 250 | return img.resize((ow, oh), Image.BICUBIC), mask.resize( 251 | (ow, oh), Image.NEAREST) 252 | else: 253 | oh = self.size 254 | ow = int(self.size * w / h) 255 | return img.resize((ow, oh), Image.BICUBIC), mask.resize( 256 | (ow, oh), Image.NEAREST) 257 | 258 | 259 | class ScaleMin(object): 260 | ''' 261 | Scale image such that shorter side is == size 262 | ''' 263 | 264 | def __init__(self, size): 265 | self.size = size 266 | 267 | def __call__(self, img, mask): 268 | assert img.size == mask.size 269 | w, h = img.size 270 | if (w <= h and w == self.size) or (h <= w and h == self.size): 271 | return img, mask 272 | if w < h: 273 | ow = self.size 274 | oh = int(self.size * h / w) 275 | return img.resize((ow, oh), Image.BICUBIC), mask.resize( 276 | (ow, oh), Image.NEAREST) 277 | else: 278 | oh = self.size 279 | ow = int(self.size * w / h) 280 | return img.resize((ow, oh), Image.BICUBIC), mask.resize( 281 | (ow, oh), Image.NEAREST) 282 | 283 | 284 | class Resize(object): 285 | ''' 286 | Resize image to exact size of crop 287 | ''' 288 | 289 | def __init__(self, size): 290 | self.size = (size, size) 291 | 292 | def __call__(self, img, mask): 293 | assert img.size == mask.size 294 | w, h = img.size 295 | if (w == h and w == self.size): 296 | return img, mask 297 | return (img.resize(self.size, Image.BICUBIC), 298 | mask.resize(self.size, Image.NEAREST)) 299 | 300 | 301 | class RandomSizedCrop(object): 302 | def __init__(self, size): 303 | self.size = size 304 | 305 | def __call__(self, img, mask): 306 | assert img.size == mask.size 307 | for attempt in range(10): 308 | area = img.size[0] * img.size[1] 309 | target_area = random.uniform(0.45, 1.0) * area 310 | aspect_ratio = random.uniform(0.5, 2) 311 | 312 | w = int(round(math.sqrt(target_area * aspect_ratio))) 313 | h = int(round(math.sqrt(target_area / aspect_ratio))) 314 | 315 | if random.random() < 0.5: 316 | w, h = h, w 317 | 318 | if w <= img.size[0] and h <= img.size[1]: 319 | x1 = random.randint(0, img.size[0] - w) 320 | y1 = random.randint(0, img.size[1] - h) 321 | 322 | img = img.crop((x1, y1, x1 + w, y1 + h)) 323 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 324 | assert (img.size == (w, h)) 325 | 326 | return img.resize((self.size, self.size), Image.BICUBIC),\ 327 | mask.resize((self.size, 
self.size), Image.NEAREST) 328 | 329 | # Fallback 330 | scale = Scale(self.size) 331 | crop = CenterCrop(self.size) 332 | return crop(*scale(img, mask)) 333 | 334 | 335 | class RandomRotate(object): 336 | def __init__(self, degree): 337 | self.degree = degree 338 | 339 | def __call__(self, img, mask): 340 | rotate_degree = random.random() * 2 * self.degree - self.degree 341 | return img.rotate(rotate_degree, Image.BICUBIC), mask.rotate( 342 | rotate_degree, Image.NEAREST) 343 | 344 | 345 | class RandomSizeAndCrop(object): 346 | def __init__(self, size, crop_nopad, 347 | scale_min=0.5, scale_max=2.0, ignore_index=0, pre_size=None): 348 | self.size = size 349 | self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad) 350 | self.scale_min = scale_min 351 | self.scale_max = scale_max 352 | self.pre_size = pre_size 353 | 354 | def __call__(self, img, mask, centroid=None): 355 | assert img.size == mask.size 356 | 357 | # first, resize such that shorter edge is pre_size 358 | if self.pre_size is None: 359 | scale_amt = 1. 360 | elif img.size[1] < img.size[0]: 361 | scale_amt = self.pre_size / img.size[1] 362 | else: 363 | scale_amt = self.pre_size / img.size[0] 364 | scale_amt *= random.uniform(self.scale_min, self.scale_max) 365 | w, h = [int(i * scale_amt) for i in img.size] 366 | 367 | if centroid is not None: 368 | centroid = [int(c * scale_amt) for c in centroid] 369 | 370 | img, mask = img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST) 371 | 372 | return self.crop(img, mask, centroid) 373 | 374 | 375 | class SlidingCropOld(object): 376 | def __init__(self, crop_size, stride_rate, ignore_label): 377 | self.crop_size = crop_size 378 | self.stride_rate = stride_rate 379 | self.ignore_label = ignore_label 380 | 381 | def _pad(self, img, mask): 382 | h, w = img.shape[: 2] 383 | pad_h = max(self.crop_size - h, 0) 384 | pad_w = max(self.crop_size - w, 0) 385 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 386 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', 387 | constant_values=self.ignore_label) 388 | return img, mask 389 | 390 | def __call__(self, img, mask): 391 | assert img.size == mask.size 392 | 393 | w, h = img.size 394 | long_size = max(h, w) 395 | 396 | img = np.array(img) 397 | mask = np.array(mask) 398 | 399 | if long_size > self.crop_size: 400 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 401 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 402 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 403 | img_sublist, mask_sublist = [], [] 404 | for yy in range(h_step_num): 405 | for xx in range(w_step_num): 406 | sy, sx = yy * stride, xx * stride 407 | ey, ex = sy + self.crop_size, sx + self.crop_size 408 | img_sub = img[sy: ey, sx: ex, :] 409 | mask_sub = mask[sy: ey, sx: ex] 410 | img_sub, mask_sub = self._pad(img_sub, mask_sub) 411 | img_sublist.append( 412 | Image.fromarray( 413 | img_sub.astype( 414 | np.uint8)).convert('RGB')) 415 | mask_sublist.append( 416 | Image.fromarray( 417 | mask_sub.astype( 418 | np.uint8)).convert('P')) 419 | return img_sublist, mask_sublist 420 | else: 421 | img, mask = self._pad(img, mask) 422 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 423 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 424 | return img, mask 425 | 426 | 427 | class SlidingCrop(object): 428 | def __init__(self, crop_size, stride_rate, ignore_label): 429 | self.crop_size = crop_size 430 | self.stride_rate = stride_rate 431 | 
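        # stride_rate is the fraction of crop_size used as the sliding-window step;
        # e.g. crop_size=713 with stride_rate=2/3 gives stride = ceil(713 * 2 / 3) = 476,
        # so neighbouring crops overlap by 713 - 476 = 237 pixels (values illustrative).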
self.ignore_label = ignore_label 432 | 433 | def _pad(self, img, mask): 434 | h, w = img.shape[: 2] 435 | pad_h = max(self.crop_size - h, 0) 436 | pad_w = max(self.crop_size - w, 0) 437 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant') 438 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant', 439 | constant_values=self.ignore_label) 440 | return img, mask, h, w 441 | 442 | def __call__(self, img, mask): 443 | assert img.size == mask.size 444 | 445 | w, h = img.size 446 | long_size = max(h, w) 447 | 448 | img = np.array(img) 449 | mask = np.array(mask) 450 | 451 | if long_size > self.crop_size: 452 | stride = int(math.ceil(self.crop_size * self.stride_rate)) 453 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1 454 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1 455 | img_slices, mask_slices, slices_info = [], [], [] 456 | for yy in range(h_step_num): 457 | for xx in range(w_step_num): 458 | sy, sx = yy * stride, xx * stride 459 | ey, ex = sy + self.crop_size, sx + self.crop_size 460 | img_sub = img[sy: ey, sx: ex, :] 461 | mask_sub = mask[sy: ey, sx: ex] 462 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub) 463 | img_slices.append( 464 | Image.fromarray( 465 | img_sub.astype( 466 | np.uint8)).convert('RGB')) 467 | mask_slices.append( 468 | Image.fromarray( 469 | mask_sub.astype( 470 | np.uint8)).convert('P')) 471 | slices_info.append([sy, ey, sx, ex, sub_h, sub_w]) 472 | return img_slices, mask_slices, slices_info 473 | else: 474 | img, mask, sub_h, sub_w = self._pad(img, mask) 475 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB') 476 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 477 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]] 478 | 479 | 480 | class ClassUniform(object): 481 | def __init__(self, size, crop_nopad, scale_min=0.5, scale_max=2.0, ignore_index=0, 482 | class_list=[16, 15, 14]): 483 | """ 484 | This is the initialization for class uniform sampling 485 | :param size: crop size (int) 486 | :param crop_nopad: Padding or no padding (bool) 487 | :param scale_min: Minimum Scale (float) 488 | :param scale_max: Maximum Scale (float) 489 | :param ignore_index: The index value to ignore in the GT images (unsigned int) 490 | :param class_list: A list of class to sample around, by default Truck, train, bus 491 | """ 492 | self.size = size 493 | self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad) 494 | 495 | self.class_list = class_list.replace(" ", "").split(",") 496 | 497 | self.scale_min = scale_min 498 | self.scale_max = scale_max 499 | 500 | def detect_peaks(self, image): 501 | """ 502 | Takes an image and detect the peaks usingthe local maximum filter. 503 | Returns a boolean mask of the peaks (i.e. 1 when 504 | the pixel's value is the neighborhood maximum, 0 otherwise) 505 | 506 | :param image: An 2d input images 507 | :return: Binary output images of the same size as input with pixel value equal 508 | to 1 indicating that there is peak at that point 509 | """ 510 | 511 | # define an 8-connected neighborhood 512 | neighborhood = generate_binary_structure(2, 2) 513 | 514 | # apply the local maximum filter; all pixel of maximal value 515 | # in their neighborhood are set to 1 516 | local_max = maximum_filter(image, footprint=neighborhood) == image 517 | # local_max is a mask that contains the peaks we are 518 | # looking for, but also the background. 519 | # In order to isolate the peaks we must remove the background from the mask. 
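        # Concrete illustration: in a 5x5 array that is all zeros except a single peak
        # at the centre, maximum_filter marks the centre *and* every zero pixel whose
        # whole 3x3 neighbourhood is also zero. The eroded background mask built below
        # covers exactly those spurious zero "maxima", so the XOR leaves only the peak.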
520 | 521 | # we create the mask of the background 522 | background = (image == 0) 523 | 524 | # a little technicality: we must erode the background in order to 525 | # successfully subtract it form local_max, otherwise a line will 526 | # appear along the background border (artifact of the local maximum filter) 527 | eroded_background = binary_erosion(background, structure=neighborhood, 528 | border_value=1) 529 | 530 | # we obtain the final mask, containing only peaks, 531 | # by removing the background from the local_max mask (xor operation) 532 | detected_peaks = local_max ^ eroded_background 533 | 534 | return detected_peaks 535 | 536 | def __call__(self, img, mask): 537 | """ 538 | :param img: PIL Input Image 539 | :param mask: PIL Input Mask 540 | :return: PIL output PIL (mask, crop) of self.crop_size 541 | """ 542 | assert img.size == mask.size 543 | 544 | scale_amt = random.uniform(self.scale_min, self.scale_max) 545 | w = int(scale_amt * img.size[0]) 546 | h = int(scale_amt * img.size[1]) 547 | 548 | if scale_amt < 1.0: 549 | img, mask = img.resize((w, h), Image.BICUBIC), mask.resize((w, h), 550 | Image.NEAREST) 551 | return self.crop(img, mask) 552 | else: 553 | # Smart Crop ( Class Uniform's ABN) 554 | origw, origh = mask.size 555 | img_new, mask_new = \ 556 | img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST) 557 | interested_class = self.class_list # [16, 15, 14] # Train, Truck, Bus 558 | data = np.array(mask) 559 | arr = np.zeros((1024, 2048)) 560 | for class_of_interest in interested_class: 561 | # hist = np.histogram(data==class_of_interest) 562 | map = np.where(data == class_of_interest, data, 0) 563 | map = map.astype('float64') / map.sum() / class_of_interest 564 | map[np.isnan(map)] = 0 565 | arr = arr + map 566 | 567 | origarr = arr 568 | window_size = 250 569 | 570 | # Given a list of classes of interest find the points on the image that are 571 | # of interest to crop from 572 | sum_arr = np.zeros((1024, 2048)).astype('float32') 573 | tmp = np.zeros((1024, 2048)).astype('float32') 574 | for x in range(0, arr.shape[0] - window_size, window_size): 575 | for y in range(0, arr.shape[1] - window_size, window_size): 576 | sum_arr[int(x + window_size / 2), int(y + window_size / 2)] = origarr[ 577 | x:x + window_size, 578 | y:y + window_size].sum() 579 | tmp[x:x + window_size, y:y + window_size] = \ 580 | origarr[x:x + window_size, y:y + window_size].sum() 581 | 582 | # Scaling Ratios in X and Y for non-uniform images 583 | ratio = (float(origw) / w, float(origh) / h) 584 | output = self.detect_peaks(sum_arr) 585 | coord = (np.column_stack(np.where(output))).tolist() 586 | 587 | # Check if there are any peaks in the images to crop from if not do standard 588 | # cropping behaviour 589 | if len(coord) == 0: 590 | return self.crop(img_new, mask_new) 591 | else: 592 | # If peaks are detected, random peak selection followed by peak 593 | # coordinate scaling to new scaled image and then random 594 | # cropping around the peak point in the scaled image 595 | randompick = np.random.randint(len(coord)) 596 | y, x = coord[randompick] 597 | y, x = int(y * ratio[0]), int(x * ratio[1]) 598 | window_size = window_size * ratio[0] 599 | cropx = random.uniform( 600 | max(0, (x - window_size / 2) - (self.size - window_size)), 601 | max((x - window_size / 2), (x - window_size / 2) - ( 602 | (w - window_size) - x + window_size / 2))) 603 | 604 | cropy = random.uniform( 605 | max(0, (y - window_size / 2) - (self.size - window_size)), 606 | max((y - window_size / 2), (y - 
window_size / 2) - ( 607 | (h - window_size) - y + window_size / 2))) 608 | 609 | return_img = img_new.crop( 610 | (cropx, cropy, cropx + self.size, cropy + self.size)) 611 | return_mask = mask_new.crop( 612 | (cropx, cropy, cropx + self.size, cropy + self.size)) 613 | return (return_img, return_mask) 614 | -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code borrowded from: 6 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 7 | # 8 | # 9 | # MIT License 10 | # 11 | # Copyright (c) 2017 ZijunDeng 12 | # 13 | # Permission is hereby granted, free of charge, to any person obtaining a copy 14 | # of this software and associated documentation files (the "Software"), to deal 15 | # in the Software without restriction, including without limitation the rights 16 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | # copies of the Software, and to permit persons to whom the Software is 18 | # furnished to do so, subject to the following conditions: 19 | # 20 | # The above copyright notice and this permission notice shall be included in all 21 | # copies or substantial portions of the Software. 22 | # 23 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | # SOFTWARE. 
30 | 31 | """ 32 | 33 | import random 34 | import numpy as np 35 | from skimage.filters import gaussian 36 | from skimage.restoration import denoise_bilateral 37 | import torch 38 | from PIL import Image, ImageFilter, ImageEnhance 39 | import torchvision.transforms as torch_tr 40 | from scipy import ndimage 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | from scipy.misc import imsave 44 | from skimage.segmentation import find_boundaries 45 | class RandomVerticalFlip(object): 46 | def __call__(self, img): 47 | if random.random() < 0.5: 48 | return img.transpose(Image.FLIP_TOP_BOTTOM) 49 | return img 50 | 51 | 52 | class DeNormalize(object): 53 | def __init__(self, mean, std): 54 | self.mean = mean 55 | self.std = std 56 | 57 | def __call__(self, tensor): 58 | for t, m, s in zip(tensor, self.mean, self.std): 59 | t.mul_(s).add_(m) 60 | return tensor 61 | 62 | 63 | class MaskToTensor(object): 64 | def __call__(self, img): 65 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 66 | 67 | class RelaxedBoundaryLossToTensor(object): 68 | def __init__(self,ignore_id, num_classes): 69 | self.ignore_id=ignore_id 70 | self.num_classes= num_classes 71 | 72 | 73 | def new_one_hot_converter(self,a): 74 | ncols = self.num_classes+1 75 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 76 | out[np.arange(a.size),a.ravel()] = 1 77 | out.shape = a.shape + (ncols,) 78 | return out 79 | 80 | def __call__(self,img): 81 | 82 | 83 | 84 | img_arr = np.array(img) 85 | import scipy;scipy.misc.imsave('orig.png',img_arr) 86 | 87 | img_arr[img_arr==self.ignore_id]=self.num_classes 88 | 89 | if cfg.STRICTBORDERCLASS != None: 90 | one_hot_orig = self.new_one_hot_converter(img_arr) 91 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 92 | for cls in cfg.STRICTBORDERCLASS: 93 | mask = np.logical_or(mask,(img_arr == cls)) 94 | one_hot = 0 95 | 96 | #print(cfg.EPOCH, "Non Reduced", cfg.TRAIN.REDUCE_RELAXEDITERATIONCOUNT) 97 | border = cfg.BORDER_WINDOW 98 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 99 | border = border // 2 100 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 101 | print(cfg.EPOCH, "Reduced") 102 | 103 | for i in range(-border,border+1): 104 | for j in range(-border, border+1): 105 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 106 | one_hot += self.new_one_hot_converter(shifted) 107 | 108 | one_hot[one_hot>1] = 1 109 | 110 | if cfg.STRICTBORDERCLASS != None: 111 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 112 | 113 | one_hot = np.moveaxis(one_hot,-1,0) 114 | 115 | 116 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 117 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 118 | print(one_hot.shape) 119 | return torch.from_numpy(one_hot).byte() 120 | #return torch.from_numpy(one_hot).float() 121 | exit(0) 122 | 123 | class ResizeHeight(object): 124 | def __init__(self, size, interpolation=Image.BILINEAR): 125 | self.target_h = size 126 | self.interpolation = interpolation 127 | 128 | def __call__(self, img): 129 | w, h = img.size 130 | target_w = int(w / h * self.target_h) 131 | return img.resize((target_w, self.target_h), self.interpolation) 132 | 133 | 134 | class FreeScale(object): 135 | def __init__(self, size, interpolation=Image.BILINEAR): 136 | self.size = tuple(reversed(size)) # size: (h, w) 137 | self.interpolation = interpolation 138 | 139 | def __call__(self, img): 140 | return img.resize(self.size, self.interpolation) 
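# Illustrative distillation of the border relaxation performed by
# RelaxedBoundaryLossToTensor above (an assumed simplification, not a drop-in
# replacement): the label map is one-hot encoded, shifted by every offset inside a
# (2*border+1) window, and the shifted one-hots are OR-ed together, so pixels near
# a class boundary end up "on" for every class present in their neighbourhood.
import numpy as np
from scipy.ndimage import shift as nd_shift

def relaxed_one_hot(labels, num_classes, border=1):
    ignore_fill = num_classes                     # out-of-bounds pixels go to the extra channel
    one_hot = np.zeros(labels.shape + (num_classes + 1,), dtype=np.uint8)
    for i in range(-border, border + 1):
        for j in range(-border, border + 1):
            shifted = nd_shift(labels, (i, j), order=0, cval=ignore_fill).astype(np.int64)
            np.put_along_axis(one_hot, shifted[..., None], 1, axis=-1)   # OR into the channels
    return np.moveaxis(one_hot, -1, 0)            # -> (num_classes + 1, H, W)

labels = np.array([[0, 0, 1, 1],
                   [0, 0, 1, 1]])
relaxed = relaxed_one_hot(labels, num_classes=2)
print(relaxed[0])   # class-0 channel stays on one pixel past the boundary (column 2)
print(relaxed[1])   # class-1 channel likewise bleeds one pixel back to the left (column 1)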
141 | 142 | 143 | class FlipChannels(object): 144 | def __call__(self, img): 145 | img = np.array(img)[:, :, ::-1] 146 | return Image.fromarray(img.astype(np.uint8)) 147 | 148 | class RandomGaussianBlur(object): 149 | def __call__(self, img): 150 | sigma = 0.15 + random.random() * 1.15 151 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 152 | blurred_img *= 255 153 | return Image.fromarray(blurred_img.astype(np.uint8)) 154 | 155 | 156 | class RandomBilateralBlur(object): 157 | def __call__(self, img): 158 | sigma = random.uniform(0.05,0.75) 159 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 160 | blurred_img *= 255 161 | return Image.fromarray(blurred_img.astype(np.uint8)) 162 | 163 | try: 164 | import accimage 165 | except ImportError: 166 | accimage = None 167 | 168 | 169 | def _is_pil_image(img): 170 | if accimage is not None: 171 | return isinstance(img, (Image.Image, accimage.Image)) 172 | else: 173 | return isinstance(img, Image.Image) 174 | 175 | 176 | def adjust_brightness(img, brightness_factor): 177 | """Adjust brightness of an Image. 178 | 179 | Args: 180 | img (PIL Image): PIL Image to be adjusted. 181 | brightness_factor (float): How much to adjust the brightness. Can be 182 | any non negative number. 0 gives a black image, 1 gives the 183 | original image while 2 increases the brightness by a factor of 2. 184 | 185 | Returns: 186 | PIL Image: Brightness adjusted image. 187 | """ 188 | if not _is_pil_image(img): 189 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 190 | 191 | enhancer = ImageEnhance.Brightness(img) 192 | img = enhancer.enhance(brightness_factor) 193 | return img 194 | 195 | 196 | def adjust_contrast(img, contrast_factor): 197 | """Adjust contrast of an Image. 198 | 199 | Args: 200 | img (PIL Image): PIL Image to be adjusted. 201 | contrast_factor (float): How much to adjust the contrast. Can be any 202 | non negative number. 0 gives a solid gray image, 1 gives the 203 | original image while 2 increases the contrast by a factor of 2. 204 | 205 | Returns: 206 | PIL Image: Contrast adjusted image. 207 | """ 208 | if not _is_pil_image(img): 209 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 210 | 211 | enhancer = ImageEnhance.Contrast(img) 212 | img = enhancer.enhance(contrast_factor) 213 | return img 214 | 215 | 216 | def adjust_saturation(img, saturation_factor): 217 | """Adjust color saturation of an image. 218 | 219 | Args: 220 | img (PIL Image): PIL Image to be adjusted. 221 | saturation_factor (float): How much to adjust the saturation. 0 will 222 | give a black and white image, 1 will give the original image while 223 | 2 will enhance the saturation by a factor of 2. 224 | 225 | Returns: 226 | PIL Image: Saturation adjusted image. 227 | """ 228 | if not _is_pil_image(img): 229 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 230 | 231 | enhancer = ImageEnhance.Color(img) 232 | img = enhancer.enhance(saturation_factor) 233 | return img 234 | 235 | 236 | def adjust_hue(img, hue_factor): 237 | """Adjust hue of an image. 238 | 239 | The image hue is adjusted by converting the image to HSV and 240 | cyclically shifting the intensities in the hue channel (H). 241 | The image is then converted back to original image mode. 242 | 243 | `hue_factor` is the amount of shift in H channel and must be in the 244 | interval `[-0.5, 0.5]`. 245 | 246 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 
247 | 248 | Args: 249 | img (PIL Image): PIL Image to be adjusted. 250 | hue_factor (float): How much to shift the hue channel. Should be in 251 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 252 | HSV space in positive and negative direction respectively. 253 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 254 | with complementary colors while 0 gives the original image. 255 | 256 | Returns: 257 | PIL Image: Hue adjusted image. 258 | """ 259 | if not(-0.5 <= hue_factor <= 0.5): 260 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 261 | 262 | if not _is_pil_image(img): 263 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 264 | 265 | input_mode = img.mode 266 | if input_mode in {'L', '1', 'I', 'F'}: 267 | return img 268 | 269 | h, s, v = img.convert('HSV').split() 270 | 271 | np_h = np.array(h, dtype=np.uint8) 272 | # uint8 addition take cares of rotation across boundaries 273 | with np.errstate(over='ignore'): 274 | np_h += np.uint8(hue_factor * 255) 275 | h = Image.fromarray(np_h, 'L') 276 | 277 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 278 | return img 279 | 280 | 281 | class ColorJitter(object): 282 | """Randomly change the brightness, contrast and saturation of an image. 283 | 284 | Args: 285 | brightness (float): How much to jitter brightness. brightness_factor 286 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 287 | contrast (float): How much to jitter contrast. contrast_factor 288 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 289 | saturation (float): How much to jitter saturation. saturation_factor 290 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 291 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 292 | [-hue, hue]. Should be >=0 and <= 0.5. 293 | """ 294 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 295 | self.brightness = brightness 296 | self.contrast = contrast 297 | self.saturation = saturation 298 | self.hue = hue 299 | 300 | @staticmethod 301 | def get_params(brightness, contrast, saturation, hue): 302 | """Get a randomized transform to be applied on image. 303 | 304 | Arguments are same as that of __init__. 305 | 306 | Returns: 307 | Transform which randomly adjusts brightness, contrast and 308 | saturation in a random order. 309 | """ 310 | transforms = [] 311 | if brightness > 0: 312 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 313 | transforms.append( 314 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 315 | 316 | if contrast > 0: 317 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 318 | transforms.append( 319 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 320 | 321 | if saturation > 0: 322 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 323 | transforms.append( 324 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 325 | 326 | if hue > 0: 327 | hue_factor = np.random.uniform(-hue, hue) 328 | transforms.append( 329 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 330 | 331 | np.random.shuffle(transforms) 332 | transform = torch_tr.Compose(transforms) 333 | 334 | return transform 335 | 336 | def __call__(self, img): 337 | """ 338 | Args: 339 | img (PIL Image): Input image. 340 | 341 | Returns: 342 | PIL Image: Color jittered image. 
343 | """ 344 | transform = self.get_params(self.brightness, self.contrast, 345 | self.saturation, self.hue) 346 | return transform(img) 347 | -------------------------------------------------------------------------------- /utils/AttrDict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code adapted from: 6 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py 7 | 8 | Source License 9 | # Copyright (c) 2017-present, Facebook, Inc. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | ############################################################################## 23 | # 24 | # Based on: 25 | # -------------------------------------------------------- 26 | # Fast R-CNN 27 | # Copyright (c) 2015 Microsoft 28 | # Licensed under The MIT License [see LICENSE for details] 29 | # Written by Ross Girshick 30 | # -------------------------------------------------------- 31 | """ 32 | 33 | 34 | class AttrDict(dict): 35 | 36 | IMMUTABLE = '__immutable__' 37 | 38 | def __init__(self, *args, **kwargs): 39 | super(AttrDict, self).__init__(*args, **kwargs) 40 | self.__dict__[AttrDict.IMMUTABLE] = False 41 | 42 | def __getattr__(self, name): 43 | if name in self.__dict__: 44 | return self.__dict__[name] 45 | elif name in self: 46 | return self[name] 47 | else: 48 | raise AttributeError(name) 49 | 50 | def __setattr__(self, name, value): 51 | if not self.__dict__[AttrDict.IMMUTABLE]: 52 | if name in self.__dict__: 53 | self.__dict__[name] = value 54 | else: 55 | self[name] = value 56 | else: 57 | raise AttributeError( 58 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 59 | format(name, value) 60 | ) 61 | 62 | def immutable(self, is_immutable): 63 | """Set immutability to is_immutable and recursively apply the setting 64 | to all nested AttrDicts. 65 | """ 66 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 67 | # Recursively set immutable state 68 | for v in self.__dict__.values(): 69 | if isinstance(v, AttrDict): 70 | v.immutable(is_immutable) 71 | for v in self.values(): 72 | if isinstance(v, AttrDict): 73 | v.immutable(is_immutable) 74 | 75 | def is_immutable(self): 76 | return self.__dict__[AttrDict.IMMUTABLE] 77 | -------------------------------------------------------------------------------- /utils/f_boundary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | 5 | # Code adapted from: 6 | # https://github.com/fperazzi/davis/blob/master/python/lib/davis/measures/f_boundary.py 7 | # 8 | # Source License 9 | # 10 | # BSD 3-Clause License 11 | # 12 | # Copyright (c) 2017, 13 | # All rights reserved. 14 | # 15 | # Redistribution and use in source and binary forms, with or without 16 | # modification, are permitted provided that the following conditions are met: 17 | # 18 | # * Redistributions of source code must retain the above copyright notice, this 19 | # list of conditions and the following disclaimer. 20 | # 21 | # * Redistributions in binary form must reproduce the above copyright notice, 22 | # this list of conditions and the following disclaimer in the documentation 23 | # and/or other materials provided with the distribution. 24 | # 25 | # * Neither the name of the copyright holder nor the names of its 26 | # contributors may be used to endorse or promote products derived from 27 | # this software without specific prior written permission. 28 | # 29 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 30 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 31 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 32 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 33 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 34 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 35 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 36 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 37 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 38 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.s 39 | ############################################################################## 40 | # 41 | # Based on: 42 | # ---------------------------------------------------------------------------- 43 | # A Benchmark Dataset and Evaluation Methodology for Video Object Segmentation 44 | # Copyright (c) 2016 Federico Perazzi 45 | # Licensed under the BSD License [see LICENSE for details] 46 | # Written by Federico Perazzi 47 | # ---------------------------------------------------------------------------- 48 | """ 49 | 50 | 51 | 52 | 53 | import numpy as np 54 | from multiprocessing import Pool 55 | from tqdm import tqdm 56 | 57 | """ Utilities for computing, reading and saving benchmark evaluation.""" 58 | 59 | def eval_mask_boundary(seg_mask,gt_mask,num_classes,num_proc=10,bound_th=0.008): 60 | """ 61 | Compute F score for a segmentation mask 62 | 63 | Arguments: 64 | seg_mask (ndarray): segmentation mask prediction 65 | gt_mask (ndarray): segmentation mask ground truth 66 | num_classes (int): number of classes 67 | 68 | Returns: 69 | F (float): mean F score across all classes 70 | Fpc (listof float): F score per class 71 | """ 72 | p = Pool(processes=num_proc) 73 | batch_size = seg_mask.shape[0] 74 | 75 | Fpc = np.zeros(num_classes) 76 | Fc = np.zeros(num_classes) 77 | for class_id in tqdm(range(num_classes)): 78 | args = [((seg_mask[i] == class_id).astype(np.uint8), 79 | (gt_mask[i] == class_id).astype(np.uint8), 80 | gt_mask[i] == 255, 81 | bound_th) 82 | for i in range(batch_size)] 83 | temp = p.map(db_eval_boundary_wrapper, args) 84 | temp = np.array(temp) 85 | Fs = temp[:,0] 86 | _valid = ~np.isnan(Fs) 87 | Fc[class_id] = np.sum(_valid) 88 | Fs[np.isnan(Fs)] = 0 89 | Fpc[class_id] = 
90 |     return Fpc, Fc
91 | 
92 | 
93 | #def db_eval_boundary_wrapper_wrapper(args):
94 | #    seg_mask, gt_mask, class_id, batch_size, Fpc = args
95 | #    print("class_id:" + str(class_id))
96 | #    p = Pool(processes=10)
97 | #    args = [((seg_mask[i] == class_id).astype(np.uint8),
98 | #            (gt_mask[i] == class_id).astype(np.uint8))
99 | #            for i in range(batch_size)]
100 | #    Fs = p.map(db_eval_boundary_wrapper, args)
101 | #    Fpc[class_id] = sum(Fs)
102 | #    return
103 | 
104 | def db_eval_boundary_wrapper(args):
105 |     foreground_mask, gt_mask, ignore, bound_th = args
106 |     return db_eval_boundary(foreground_mask, gt_mask, ignore, bound_th)
107 | 
108 | def db_eval_boundary(foreground_mask, gt_mask, ignore_mask, bound_th=0.008):
109 |     """
110 |     Compute boundary precision, recall and the F-measure for a single frame.
111 |     Calculates precision/recall for boundaries between foreground_mask and
112 |     gt_mask using morphological operators to speed it up.
113 | 
114 |     Arguments:
115 |         foreground_mask (ndarray): binary segmentation image.
116 |         gt_mask (ndarray): binary annotated image.
117 |         ignore_mask (ndarray): pixels excluded from the evaluation.
118 | 
119 |     Returns:
120 |         F (float): boundaries F-measure
121 |         P (float): boundaries precision
122 |     """
123 |     assert np.atleast_3d(foreground_mask).shape[2] == 1
124 | 
125 |     bound_pix = bound_th if bound_th >= 1 else \
126 |         np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
127 | 
128 |     #print(bound_pix)
129 |     #print(gt.shape)
130 |     #print(np.unique(gt))
131 |     foreground_mask[ignore_mask] = 0
132 |     gt_mask[ignore_mask] = 0
133 | 
134 |     # Get the pixel boundaries of both masks
135 |     fg_boundary = seg2bmap(foreground_mask)
136 |     gt_boundary = seg2bmap(gt_mask)
137 | 
138 |     from skimage.morphology import binary_dilation, disk
139 | 
140 |     fg_dil = binary_dilation(fg_boundary, disk(bound_pix))
141 |     gt_dil = binary_dilation(gt_boundary, disk(bound_pix))
142 | 
143 |     # Get the intersection
144 |     gt_match = gt_boundary * fg_dil
145 |     fg_match = fg_boundary * gt_dil
146 | 
147 |     # Area of the intersection
148 |     n_fg = np.sum(fg_boundary)
149 |     n_gt = np.sum(gt_boundary)
150 | 
151 |     # Compute precision and recall
152 |     if n_fg == 0 and n_gt > 0:
153 |         precision = 1
154 |         recall = 0
155 |     elif n_fg > 0 and n_gt == 0:
156 |         precision = 0
157 |         recall = 1
158 |     elif n_fg == 0 and n_gt == 0:
159 |         precision = 1
160 |         recall = 1
161 |     else:
162 |         precision = np.sum(fg_match) / float(n_fg)
163 |         recall = np.sum(gt_match) / float(n_gt)
164 | 
165 |     # Compute F measure
166 |     if precision + recall == 0:
167 |         F = 0
168 |     else:
169 |         F = 2 * precision * recall / (precision + recall)
170 | 
171 |     return F, precision
172 | 
173 | def seg2bmap(seg, width=None, height=None):
174 |     """
175 |     From a segmentation, compute a binary boundary map with 1 pixel wide
176 |     boundaries. The boundary pixels are offset by 1/2 pixel towards the
177 |     origin from the actual segment boundary.
178 | 
179 |     Arguments:
180 |         seg : Segments labeled from 1..k.
181 |         width : Width of desired bmap <= seg.shape[1]
182 |         height : Height of desired bmap <= seg.shape[0]
183 | 
184 |     Returns:
185 |         bmap (ndarray): Binary boundary map.
186 | 
187 |     David Martin
188 |     January 2003
189 |     """
190 | 
191 |     seg = seg.astype(bool)
192 |     seg[seg > 0] = 1
193 | 
194 |     assert np.atleast_3d(seg).shape[2] == 1
195 | 
196 |     width = seg.shape[1] if width is None else width
197 |     height = seg.shape[0] if height is None else height
198 | 
199 |     h, w = seg.shape[:2]
200 | 
201 |     ar1 = float(width) / float(height)
202 |     ar2 = float(w) / float(h)
203 | 
204 |     assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
205 |         "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
206 | 
207 |     e = np.zeros_like(seg)
208 |     s = np.zeros_like(seg)
209 |     se = np.zeros_like(seg)
210 | 
211 |     e[:, :-1] = seg[:, 1:]
212 |     s[:-1, :] = seg[1:, :]
213 |     se[:-1, :-1] = seg[1:, 1:]
214 | 
215 |     b = seg ^ e | seg ^ s | seg ^ se
216 |     b[-1, :] = seg[-1, :] ^ e[-1, :]
217 |     b[:, -1] = seg[:, -1] ^ s[:, -1]
218 |     b[-1, -1] = 0
219 | 
220 |     if w == width and h == height:
221 |         bmap = b
222 |     else:
223 |         bmap = np.zeros((height, width))
224 |         for x in range(w):
225 |             for y in range(h):
226 |                 if b[y, x]:
227 |                     j = 1 + int(np.floor((y - 1) * height / h))  # scale row index to bmap size
228 |                     i = 1 + int(np.floor((x - 1) * width / w))   # scale column index to bmap size
229 |                     bmap[j, i] = 1
230 | 
231 |     return bmap
232 | 
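As a quick sanity check, the boundary utilities above can be exercised on dummy data. The snippet below is a minimal sketch rather than code from this repository: the random label maps, the `num_proc` value, and the final `Fpc / Fc` normalisation are illustrative assumptions about how a caller (for example the training loop) might turn the per-class sums into mean F scores. It is meant to be run from the repository root with the listed requirements installed.

```python
import numpy as np
from utils.f_boundary import eval_mask_boundary

if __name__ == '__main__':
    num_classes = 3
    rng = np.random.RandomState(0)

    # Dummy batch of 2 predictions and ground truths (H x W integer label maps).
    seg_mask = rng.randint(0, num_classes, size=(2, 64, 64))
    gt_mask = rng.randint(0, num_classes, size=(2, 64, 64))
    gt_mask[0, :4, :4] = 255   # 255 marks pixels ignored in the ground truth

    # Fpc[c] sums the boundary F score of class c over the batch,
    # Fc[c] counts the images that produced a valid score for class c.
    Fpc, Fc = eval_mask_boundary(seg_mask, gt_mask, num_classes,
                                 num_proc=2, bound_th=0.008)

    mean_f_per_class = Fpc / np.maximum(Fc, 1)
    print('mean boundary F per class:', mean_f_per_class)
    print('overall boundary F       :', mean_f_per_class.mean())
```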

-------------------------------------------------------------------------------- /utils/image_page.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | 
6 | import glob
7 | import os
8 | 
9 | class ImagePage(object):
10 |     '''
11 |     This creates an HTML page of embedded images, useful for showing evaluation results.
12 | 
13 |     Usage:
14 |     ip = ImagePage(experiment_name, html_fn)
15 | 
16 |     # Add a table with N image/description pairs ...
17 |     ip.add_table([(img1, descr1), (img2, descr2), ...])
18 | 
19 |     # Generate html page
20 |     ip.write_page()
21 |     '''
22 |     def __init__(self, experiment_name, html_filename):
23 |         self.experiment_name = experiment_name
24 |         self.html_filename = html_filename
25 |         self.outfile = open(self.html_filename, 'w')
26 |         self.items = []
27 |     # Minimal generic markup: each add_table() call becomes one <h3> heading plus one single-row <table> of images.
28 |     def _print_header(self):
29 |         header = '''<html>
30 | <head>
31 | <title>Experiment = {}</title>
32 | </head>
33 | <body>
34 | '''.format(self.experiment_name)
35 |         self.outfile.write(header)
36 | 
37 |     def _print_footer(self):
38 |         self.outfile.write('''</body>
39 | </html>''')
40 | 
41 |     def _print_table_header(self, table_name):
42 |         table_hdr = '''<h3>{}</h3>
43 | <table border="1" style="table-layout: fixed;"><tr>
44 | '''.format(table_name)
45 |         self.outfile.write(table_hdr)
46 | 
47 |     def _print_table_footer(self):
48 |         table_ftr = '''</tr>
49 | </table>'''
50 |         self.outfile.write(table_ftr)
51 | 
52 |     def _print_table_guts(self, img_fn, descr):
53 |         table = '''<td valign="top" align="center">
54 | <a href="{img_fn}">
55 | <img src="{img_fn}" width="768">
56 | </a>
57 | <br>
58 | <p>{descr}</p>
59 | </td>
60 | '''.format(img_fn=img_fn, descr=descr)
61 |         self.outfile.write(table)
62 | 
63 |     def add_table(self, img_label_pairs):
64 |         self.items.append(img_label_pairs)
65 | 
66 |     def _write_table(self, table):
67 |         img, _descr = table[0]
68 |         self._print_table_header(os.path.basename(img))
69 |         for img, descr in table:
70 |             self._print_table_guts(img, descr)
71 |         self._print_table_footer()
72 | 
73 |     def write_page(self):
74 |         self._print_header()
75 | 
76 |         for table in self.items:
77 |             self._write_table(table)
78 | 
79 |         self._print_footer()
80 | 
81 | 
82 | def main():
83 |     images = glob.glob('dump_imgs_train/*.png')
84 |     images = [i for i in images if 'mask' not in i]
85 | 
86 |     ip = ImagePage('test page', 'dd.html')
87 |     for img in images:
88 |         basename = os.path.splitext(img)[0]
89 |         mask_img = basename + '_mask.png'
90 |         ip.add_table(((img, 'image'), (mask_img, 'mask')))
91 |     ip.write_page()
92 | 
-------------------------------------------------------------------------------- /utils/misc.py: --------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | 
6 | import sys
7 | import re
8 | import os
9 | import shutil
10 | import torch
11 | from datetime import datetime
12 | import logging
13 | from subprocess import call
14 | import shlex
15 | from tensorboardX import SummaryWriter
16 | import numpy as np
17 | from utils.image_page import ImagePage
18 | import torchvision.transforms as standard_transforms
19 | import torchvision.utils as vutils
20 | from PIL import Image
21 | 
22 | # Create unique output dir name based on non-default command line args
23 | def make_exp_name(args, parser):
24 |     exp_name = '{}-{}'.format(args.dataset[:3], args.arch[:])
25 |     dict_args = vars(args)
26 | 
27 |     # sort so that we get a consistent directory name
28 |     argnames = sorted(dict_args)
29 | 
30 |     # build experiment name with non-default args
31 |     for argname in argnames:
32 |         if dict_args[argname] != parser.get_default(argname):
33 |             if argname == 'exp' or argname == 'arch' or argname == 'prev_best_filepath':
34 |                 continue
35 |             if argname == 'snapshot':
36 |                 arg_str = '-PT'
37 |             elif argname == 'nosave':
38 |                 arg_str = ''
39 |                 argname = ''
40 |             elif argname == 'freeze_trunk':
41 |                 argname = ''
42 |                 arg_str = '-fr'
43 |             elif argname == 'syncbn':
44 |                 argname = ''
45 |                 arg_str = '-sbn'
46 |             elif argname == 'relaxedloss':
47 |                 argname = ''
48 |                 arg_str = 're-loss'
49 |             elif isinstance(dict_args[argname], bool):
50 |                 arg_str = 'T' if dict_args[argname] else 'F'
51 |             else:
52 |                 arg_str = str(dict_args[argname])[:6]
53 |             exp_name += '-{}_{}'.format(str(argname), arg_str)
54 |     # clean special chars out
55 |     exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name)
56 |     # exp_name = 'testing'  # uncomment to force a fixed experiment directory name
57 |     return exp_name
58 | 
59 | 
60 | def save_log(prefix, output_dir, date_str):
61 |     fmt = '%(asctime)s.%(msecs)03d %(message)s'
62 |     date_fmt = '%m-%d %H:%M:%S'
63 |     filename = os.path.join(output_dir, prefix + '_' + date_str + '.log')
64 |     logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt,
65 |                         filename=filename, filemode='w')
66 |     console = logging.StreamHandler()
67 |     console.setLevel(logging.INFO)
68 |     formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt)
69 |     console.setFormatter(formatter)
70 |     logging.getLogger('').addHandler(console)
71 |     #logging.basicConfig(level=logging.INFO, 
format='%(asctime)s.%(msecs)03d %(levelname)s:\t%(message)s', datefmt='%Y-%m-%d %H:%M:%S') 72 | 73 | 74 | def save_code(exp_path, date_str): 75 | code_root = '.' # FIXME! 76 | zip_outfile = os.path.join(exp_path, 'code_{}.tgz'.format(date_str)) 77 | print('Saving code to {}'.format(zip_outfile)) 78 | cmd = 'tar -czvf {zip_outfile} --exclude=\'*.pyc\' --exclude=\'*.png\' ' +\ 79 | '--exclude=\'*tfevents*\' {root}/train.py ' + \ 80 | ' {root}/utils {root}/datasets {root}/models' 81 | cmd = cmd.format(zip_outfile=zip_outfile, root=code_root) 82 | call(shlex.split(cmd), stdout=open(os.devnull, 'wb')) 83 | 84 | 85 | def prep_experiment(args, parser): 86 | ''' 87 | Make output directories, setup logging, Tensorboard, snapshot code. 88 | ''' 89 | ckpt_path = args.ckpt 90 | tb_path = args.tb_path 91 | exp_name = make_exp_name(args, parser) 92 | args.exp_path = os.path.join(ckpt_path, args.exp, exp_name) 93 | args.tb_exp_path = os.path.join(tb_path, args.exp, exp_name) 94 | args.ngpu = torch.cuda.device_count() 95 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S')) 96 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 97 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 98 | args.last_record = {} 99 | os.makedirs(args.exp_path, exist_ok=True) 100 | os.makedirs(args.tb_exp_path, exist_ok=True) 101 | save_log('log', args.exp_path, args.date_str) 102 | #save_code(args.exp_path, args.date_str) 103 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write( 104 | str(args) + '\n\n') 105 | 106 | writer = SummaryWriter(logdir=args.tb_exp_path, comment=args.tb_tag) 107 | return writer 108 | 109 | class AverageMeter(object): 110 | 111 | def __init__(self): 112 | self.reset() 113 | 114 | def reset(self): 115 | self.val = 0 116 | self.avg = 0 117 | self.sum = 0 118 | self.count = 0 119 | 120 | def update(self, val, n=1): 121 | self.val = val 122 | self.sum += val * n 123 | self.count += n 124 | self.avg = self.sum / self.count 125 | 126 | 127 | 128 | def evaluate_eval(args, net, optimizer, val_loss, mf_score, hist, dump_images, heatmap_images, writer, epoch=0, dataset=None, ): 129 | ''' 130 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 131 | large dataset) Only applies to eval/eval.py 132 | ''' 133 | # axis 0: gt, axis 1: prediction 134 | acc = np.diag(hist).sum() / hist.sum() 135 | acc_cls = np.diag(hist) / hist.sum(axis=1) 136 | acc_cls = np.nanmean(acc_cls) 137 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 138 | 139 | print_evaluate_results(hist, iu, writer, epoch, dataset) 140 | freq = hist.sum(axis=1) / hist.sum() 141 | mean_iu = np.nanmean(iu) 142 | logging.info('mean {}'.format(mean_iu)) 143 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 144 | #return acc, acc_cls, mean_iu, fwavacc 145 | 146 | # update latest snapshot 147 | if 'mean_iu' in args.last_record: 148 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format( 149 | args.last_record['epoch'], args.last_record['mean_iu']) 150 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 151 | try: 152 | os.remove(last_snapshot) 153 | except OSError: 154 | pass 155 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(epoch, mean_iu) 156 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 157 | args.last_record['mean_iu'] = mean_iu 158 | args.last_record['epoch'] = epoch 159 | 160 | torch.cuda.synchronize() 161 | 162 | torch.save({ 163 | 'state_dict': net.state_dict(), 164 | 'optimizer': optimizer.state_dict(), 
165 | 'epoch': epoch, 166 | 'mean_iu': mean_iu, 167 | 'command': ' '.join(sys.argv[1:]) 168 | }, last_snapshot) 169 | 170 | # update best snapshot 171 | if mean_iu > args.best_record['mean_iu'] : 172 | # remove old best snapshot 173 | if args.best_record['epoch'] != -1: 174 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format( 175 | args.best_record['epoch'], args.best_record['mean_iu']) 176 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 177 | assert os.path.exists(best_snapshot), \ 178 | 'cant find old snapshot {}'.format(best_snapshot) 179 | os.remove(best_snapshot) 180 | 181 | 182 | # save new best 183 | args.best_record['val_loss'] = val_loss.avg 184 | args.best_record['mask_f1_score'] = mf_score.avg 185 | args.best_record['epoch'] = epoch 186 | args.best_record['acc'] = acc 187 | args.best_record['acc_cls'] = acc_cls 188 | args.best_record['mean_iu'] = mean_iu 189 | args.best_record['fwavacc'] = fwavacc 190 | 191 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format( 192 | args.best_record['epoch'], args.best_record['mean_iu']) 193 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 194 | shutil.copyfile(last_snapshot, best_snapshot) 195 | 196 | 197 | to_save_dir = os.path.join(args.exp_path, 'best_images') 198 | os.makedirs(to_save_dir, exist_ok=True) 199 | ip = ImagePage(epoch, '{}/index.html'.format(to_save_dir)) 200 | 201 | val_visual = [] 202 | 203 | idx = 0 204 | 205 | visualize = standard_transforms.Compose([ 206 | standard_transforms.Scale(384), 207 | standard_transforms.ToTensor() 208 | ]) 209 | for bs_idx, bs_data in enumerate(dump_images): 210 | for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])): 211 | gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy()) 212 | pred = data[1].cpu().numpy() 213 | predictions_pil = args.dataset_cls.colorize_mask(pred) 214 | img_name = data[2] 215 | 216 | prediction_fn = '{}_prediction.png'.format(img_name) 217 | predictions_pil.save(os.path.join(to_save_dir, prediction_fn)) 218 | gt_fn = '{}_gt.png'.format(img_name) 219 | gt_pil.save(os.path.join(to_save_dir, gt_fn)) 220 | ip.add_table([(gt_fn, 'gt'), (prediction_fn, 'prediction')]) 221 | val_visual.extend([visualize(gt_pil.convert('RGB')), 222 | visualize(predictions_pil.convert('RGB'))]) 223 | idx = idx+1 224 | if idx >= 9: 225 | ip.write_page() 226 | break 227 | for bs_idx, bs_data in enumerate(heatmap_images): 228 | for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])): 229 | 230 | gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy()) 231 | 232 | predictions_pil = data[1].cpu().numpy() 233 | predictions_pil = (predictions_pil / predictions_pil.max()) * 255 234 | predictions_pil = Image.fromarray(predictions_pil.astype(np.uint8)) 235 | img_name = data[2] 236 | 237 | prediction_fn = '{}_prediction.png'.format(img_name) 238 | predictions_pil.save(os.path.join(to_save_dir, prediction_fn)) 239 | gt_fn = '{}_gt.png'.format(img_name) 240 | gt_pil.save(os.path.join(to_save_dir, gt_fn)) 241 | ip.add_table([(gt_fn, 'gt'), (prediction_fn, 'prediction')]) 242 | val_visual.extend([visualize(gt_pil.convert('RGB')), 243 | visualize(predictions_pil.convert('RGB'))]) 244 | idx = idx+1 245 | if idx >= 9: 246 | ip.write_page() 247 | break 248 | 249 | val_visual = torch.stack(val_visual, 0) 250 | val_visual = vutils.make_grid(val_visual, nrow=10, padding=5) 251 | writer.add_image(last_snapshot, val_visual) 252 | 253 | logging.info('-' * 107) 254 | fmt_str = '[epoch %d], [val loss %.5f], [mask f1 %.5f], [acc %.5f], 
[acc_cls %.5f], ' +\ 255 | '[mean_iu %.5f], [fwavacc %.5f]' 256 | logging.info(fmt_str % (epoch, val_loss.avg, mf_score.avg, acc, acc_cls, mean_iu, fwavacc)) 257 | fmt_str = 'best record: [val loss %.5f], [mask f1 %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 258 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], ' 259 | logging.info(fmt_str % (args.best_record['val_loss'], args.best_record['mask_f1_score'], 260 | args.best_record['acc'], 261 | args.best_record['acc_cls'], args.best_record['mean_iu'], 262 | args.best_record['fwavacc'], args.best_record['epoch'])) 263 | logging.info('-' * 107) 264 | 265 | # tensorboard logging of validation phase metrics 266 | 267 | writer.add_scalar('training/acc', acc, epoch) 268 | writer.add_scalar('training/acc_cls', acc_cls, epoch) 269 | writer.add_scalar('training/mean_iu', mean_iu, epoch) 270 | writer.add_scalar('training/val_loss', val_loss.avg, epoch) 271 | writer.add_scalar('training/mask_f1_score', mf_score.avg, epoch) 272 | 273 | 274 | def fast_hist(label_pred, label_true, num_classes): 275 | mask = (label_true >= 0) & (label_true < num_classes) 276 | hist = np.bincount( 277 | num_classes * label_true[mask].astype(int) + 278 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 279 | return hist 280 | 281 | 282 | 283 | def print_evaluate_results(hist, iu, writer=None, epoch=0, dataset=None): 284 | try: 285 | id2cat = dataset.id2cat 286 | except: 287 | id2cat = {i: i for i in range(dataset.num_classes)} 288 | iu_false_positive = hist.sum(axis=1) - np.diag(hist) 289 | iu_false_negative = hist.sum(axis=0) - np.diag(hist) 290 | iu_true_positive = np.diag(hist) 291 | 292 | logging.info('IoU:') 293 | logging.info('label_id label iU Precision Recall TP FP FN') 294 | for idx, i in enumerate(iu): 295 | idx_string = "{:2d}".format(idx) 296 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else '' 297 | iu_string = '{:5.2f}'.format(i * 100) 298 | total_pixels = hist.sum() 299 | tp = '{:5.2f}'.format(100 * iu_true_positive[idx] / total_pixels) 300 | fp = '{:5.2f}'.format( 301 | iu_false_positive[idx] / iu_true_positive[idx]) 302 | fn = '{:5.2f}'.format(iu_false_negative[idx] / iu_true_positive[idx]) 303 | precision = '{:5.2f}'.format( 304 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx])) 305 | recall = '{:5.2f}'.format( 306 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx])) 307 | logging.info('{} {} {} {} {} {} {} {}'.format( 308 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn)) 309 | 310 | writer.add_scalar('val_class_iu/{}'.format(id2cat[idx]), i * 100, epoch) 311 | writer.add_scalar('val_class_precision/{}'.format(id2cat[idx]), 312 | iu_true_positive[idx] / (iu_true_positive[idx] + 313 | iu_false_positive[idx]), epoch) 314 | writer.add_scalar('val_class_recall/{}'.format(id2cat[idx]), 315 | iu_true_positive[idx] / (iu_true_positive[idx] + 316 | iu_false_negative[idx]), epoch) 317 | --------------------------------------------------------------------------------
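For reference, the segmentation metrics computed in `evaluate_eval` are driven by the running confusion matrix that `fast_hist` builds. The short sketch below is not part of the repository; the dummy labels and the batch-accumulation loop are assumptions meant only to illustrate how the histogram and the per-class IoU formula fit together, and it assumes the repository's dependencies are installed so that `utils.misc` imports cleanly.

```python
import numpy as np
from utils.misc import fast_hist

if __name__ == '__main__':
    num_classes = 3
    rng = np.random.RandomState(0)

    # Accumulate one confusion matrix over the whole validation run,
    # one mini-batch at a time (rows = ground truth, columns = prediction).
    hist = np.zeros((num_classes, num_classes))
    for _ in range(2):                                   # two dummy "batches"
        pred = rng.randint(0, num_classes, size=(4, 64, 64))
        gt = rng.randint(0, num_classes, size=(4, 64, 64))
        hist += fast_hist(pred.flatten(), gt.flatten(), num_classes)

    # Same per-class IoU formula used in evaluate_eval: TP / (TP + FP + FN).
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    print('per-class IoU:', iu)
    print('mean IoU     :', np.nanmean(iu))
```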