├── Dockerfile
├── LICENSE
├── README.md
├── config.py
├── datasets
│   ├── __init__.py
│   ├── cityscapes.py
│   ├── cityscapes_labels.py
│   └── edge_utils.py
├── docs
│   ├── index.html
│   └── resources
│       ├── GSCNN.mp4
│       ├── architecture.jpg
│       ├── bibtex.txt
│       ├── crop.jpg
│       ├── edges.jpg
│       ├── gscnn.gif
│       ├── intro.jpg
│       ├── seg.jpg
│       ├── semboundary.jpg
│       ├── table.png
│       ├── test.jpg
│       └── top.jpg
├── loss.py
├── my_functionals
│   ├── DualTaskLoss.py
│   ├── GatedSpatialConv.py
│   ├── __init__.py
│   └── custom_functional.py
├── network
│   ├── Resnet.py
│   ├── SEresnext.py
│   ├── __init__.py
│   ├── gscnn.py
│   ├── mynn.py
│   └── wider_resnet.py
├── optimizer.py
├── train.py
├── transforms
│   ├── joint_transforms.py
│   └── transforms.py
└── utils
    ├── AttrDict.py
    ├── f_boundary.py
    ├── image_page.py
    └── misc.py
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.0-cuda10.0-cudnn7-devel
2 |
3 | RUN apt-get -y update
4 | RUN apt-get -y upgrade
5 |
6 | RUN apt-get update \
7 | && apt-get install -y software-properties-common wget \
8 | && add-apt-repository -y ppa:ubuntu-toolchain-r/test \
9 | && apt-get update \
10 | && apt-get install -y make git curl vim vim-gnome
11 |
12 | # Install Python 3 and tooling via apt-get
13 | RUN apt-get install -y python3-pip python3-dev vim htop python3-tk pkg-config
14 |
15 | RUN pip3 install --upgrade pip==9.0.1
16 |
17 | # Install from pip
18 | RUN pip3 install pyyaml \
19 | scipy==1.1.0 \
20 | numpy \
21 | tensorflow \
22 | scikit-learn \
23 | scikit-image \
24 | matplotlib \
25 | opencv-python \
26 | torch==1.0.0 \
27 | torchvision==0.2.0 \
28 | torch-encoding==1.0.1 \
29 | tensorboardX \
30 | tqdm
31 |
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019 NVIDIA Corporation. Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler
2 | All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | Permission to use, copy, modify, and distribute this software and its documentation
6 | for any non-commercial purpose is hereby granted without fee, provided that the above
7 | copyright notice appear in all copies and that both that copyright notice and this
8 | permission notice appear in supporting documentation, and that the name of the author
9 | not be used in advertising or publicity pertaining to distribution of the software
10 | without specific, written prior permission.
11 |
12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GSCNN
2 | This is the official code for:
3 |
4 | #### Gated-SCNN: Gated Shape CNNs for Semantic Segmentation
5 |
6 | [Towaki Takikawa](https://tovacinni.github.io), [David Acuna](http://www.cs.toronto.edu/~davidj/), [Varun Jampani](https://varunjampani.github.io), [Sanja Fidler](http://www.cs.toronto.edu/~fidler/)
7 |
8 | ICCV 2019
9 | **[[Paper](https://arxiv.org/abs/1907.05740)] [[Project Page](https://nv-tlabs.github.io/GSCNN/)]**
10 |
11 | 
12 |
13 | Based on https://github.com/NVIDIA/semantic-segmentation.
14 |
15 | ## License
16 | ```
17 | Copyright (C) 2019 NVIDIA Corporation. Towaki Takikawa, David Acuna, Varun Jampani, Sanja Fidler
18 | All rights reserved.
19 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
20 |
21 | Permission to use, copy, modify, and distribute this software and its documentation
22 | for any non-commercial purpose is hereby granted without fee, provided that the above
23 | copyright notice appear in all copies and that both that copyright notice and this
24 | permission notice appear in supporting documentation, and that the name of the author
25 | not be used in advertising or publicity pertaining to distribution of the software
26 | without specific, written prior permission.
27 |
28 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
29 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
30 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
31 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
32 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
33 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
34 |
35 | ```
36 |
37 | ## Usage
38 |
39 | ##### Clone this repo
40 | ```bash
41 | git clone https://github.com/nv-tlabs/GSCNN
42 | cd GSCNN
43 | ```
44 |
45 | #### Python requirements
46 |
47 | Currently, the code supports Python 3 and requires the following packages:
48 | * numpy
49 | * PyTorch (>=1.1.0)
50 | * torchvision
51 | * scipy
52 | * scikit-image
53 | * tensorboardX
54 | * tqdm
55 | * torch-encoding
56 | * opencv
57 | * PyYAML
58 |
59 | #### Download pretrained models
60 |
61 | Download the pretrained model from the [Google Drive Folder](https://drive.google.com/file/d/1wlhAXg-PfoUM-rFy2cksk43Ng3PpsK2c/view), and save it in 'checkpoints/'
62 |
63 | #### Download inferred images
64 |
65 | Download (if needed) the inferred images from the [Google Drive Folder](https://drive.google.com/file/d/105WYnpSagdlf5-ZlSKWkRVeq-MyKLYOV/view)
66 |
67 | #### Evaluation (Cityscapes)
68 | ```bash
69 | python train.py --evaluate --snapshot checkpoints/best_cityscapes_checkpoint.pth
70 | ```
71 |
72 | #### Training
73 |
74 | A note on training: we train on 8 NVIDIA GPUs, so memory will be an issue with WiderResNet38 if you try to train on a single GPU.
75 |
76 | If you use this code, please cite:
77 |
78 | ```
79 | @article{takikawa2019gated,
80 | title={Gated-SCNN: Gated Shape CNNs for Semantic Segmentation},
81 | author={Takikawa, Towaki and Acuna, David and Jampani, Varun and Fidler, Sanja},
82 | journal={ICCV},
83 | year={2019}
84 | }
85 | ```
86 |
87 |
--------------------------------------------------------------------------------
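A quick, hedged way to verify the downloaded snapshot before running the evaluation command above: load it on CPU and inspect its top-level keys. The file path is taken from the README's evaluation command; everything else is a generic PyTorch call, not code from this repository.

```python
# Hedged sanity check: inspect the downloaded checkpoint without a GPU.
import torch

snapshot = torch.load('checkpoints/best_cityscapes_checkpoint.pth',
                      map_location='cpu')
print(type(snapshot))
if isinstance(snapshot, dict):
    print(list(snapshot.keys())[:10])   # e.g. state_dict / epoch entries, if present
```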
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code adapted from:
6 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py
7 |
8 | Source License
9 | # Copyright (c) 2017-present, Facebook, Inc.
10 | #
11 | # Licensed under the Apache License, Version 2.0 (the "License");
12 | # you may not use this file except in compliance with the License.
13 | # You may obtain a copy of the License at
14 | #
15 | # http://www.apache.org/licenses/LICENSE-2.0
16 | #
17 | # Unless required by applicable law or agreed to in writing, software
18 | # distributed under the License is distributed on an "AS IS" BASIS,
19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 | # See the License for the specific language governing permissions and
21 | # limitations under the License.
22 | ##############################################################################
23 | #
24 | # Based on:
25 | # --------------------------------------------------------
26 | # Fast R-CNN
27 | # Copyright (c) 2015 Microsoft
28 | # Licensed under The MIT License [see LICENSE for details]
29 | # Written by Ross Girshick
30 | # --------------------------------------------------------
31 | """
32 |
33 | from __future__ import absolute_import
34 | from __future__ import division
35 | from __future__ import print_function
36 | from __future__ import unicode_literals
37 |
38 | import copy
39 | import six
40 | import os.path as osp
41 |
42 | from ast import literal_eval
43 | import numpy as np
44 | import yaml
45 | import torch
46 | import torch.nn as nn
47 | from torch.nn import init
48 |
49 |
50 | from utils.AttrDict import AttrDict
51 |
52 |
53 | __C = AttrDict()
54 | # Consumers can get config by:
55 | # from config import cfg
56 | cfg = __C
57 | __C.EPOCH = 0
58 | __C.CLASS_UNIFORM_PCT=0.0
59 | __C.BATCH_WEIGHTING=False
60 | __C.BORDER_WINDOW=1
61 | __C.REDUCE_BORDER_EPOCH= -1
62 | __C.STRICTBORDERCLASS= None
63 |
64 | __C.DATASET =AttrDict()
65 | __C.DATASET.CITYSCAPES_DIR='/home/username/data/cityscapes'
66 | __C.DATASET.CV_SPLITS=3
67 |
68 | __C.MODEL = AttrDict()
69 | __C.MODEL.BN = 'regularnorm'
70 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d
71 | __C.MODEL.BIGMEMORY = False
72 |
73 | def assert_and_infer_cfg(args, make_immutable=True):
74 | """Call this function in your script after you have finished setting all cfg
75 | values that are necessary (e.g., merging a config from a file, merging
76 | command line config options, etc.). By default, this function will also
77 | mark the global cfg as immutable to prevent changing the global cfg settings
78 | during script execution (which can lead to hard to debug errors or code
79 | that's harder to understand than is necessary).
80 | """
81 |
82 | if args.batch_weighting:
83 | __C.BATCH_WEIGHTING=True
84 |
85 | if args.syncbn:
86 | import encoding
87 | __C.MODEL.BN = 'syncnorm'
88 | __C.MODEL.BNFUNC = encoding.nn.BatchNorm2d
89 | else:
90 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d
91 | print('Using regular batch norm')
92 |
93 | if make_immutable:
94 | cfg.immutable(True)
95 |
--------------------------------------------------------------------------------
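A rough usage sketch of the config module above (not taken from train.py; the `args` namespace is a stand-in for the real argparse result and only carries the two attributes `assert_and_infer_cfg` actually reads): set whatever `cfg` values you need, then call `assert_and_infer_cfg` once to freeze the global config.

```python
# Minimal sketch, assuming args provides batch_weighting and syncbn only.
from types import SimpleNamespace

from config import cfg, assert_and_infer_cfg

args = SimpleNamespace(batch_weighting=True, syncbn=False)

cfg.DATASET.CITYSCAPES_DIR = '/data/cityscapes'   # point at your local copy
assert_and_infer_cfg(args)                        # also marks cfg immutable

print(cfg.BATCH_WEIGHTING)    # True
# Further assignments such as cfg.EPOCH = 5 are expected to fail now,
# because AttrDict.immutable(True) has been called.
```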
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from datasets import cityscapes
7 | import torchvision.transforms as standard_transforms
8 | import torchvision.utils as vutils
9 | import transforms.joint_transforms as joint_transforms
10 | import transforms.transforms as extended_transforms
11 | from torch.utils.data import DataLoader
12 |
13 | def setup_loaders(args):
14 | '''
15 | input: arguments passed by the user
16 | return: training data loader, validation data loader, train_set
17 | '''
18 |
19 | if args.dataset == 'cityscapes':
20 | args.dataset_cls = cityscapes
21 | args.train_batch_size = args.bs_mult * args.ngpu
22 | if args.bs_mult_val > 0:
23 | args.val_batch_size = args.bs_mult_val * args.ngpu
24 | else:
25 | args.val_batch_size = args.bs_mult * args.ngpu
26 | else:
27 | raise ValueError('Dataset {} is not supported'.format(args.dataset))
28 |
29 | args.num_workers = 4 * args.ngpu
30 | if args.test_mode:
31 | args.num_workers = 0 #1
32 |
33 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
34 |
35 | # Geometric image transformations
36 | train_joint_transform_list = [
37 | joint_transforms.RandomSizeAndCrop(args.crop_size,
38 | False,
39 | pre_size=args.pre_size,
40 | scale_min=args.scale_min,
41 | scale_max=args.scale_max,
42 | ignore_index=args.dataset_cls.ignore_label),
43 | joint_transforms.Resize(args.crop_size),
44 | joint_transforms.RandomHorizontallyFlip()]
45 |
46 | #if args.rotate:
47 | # train_joint_transform_list += [joint_transforms.RandomRotate(args.rotate)]
48 |
49 | train_joint_transform = joint_transforms.Compose(train_joint_transform_list)
50 |
51 | # Image appearance transformations
52 | train_input_transform = []
53 | if args.color_aug:
54 | train_input_transform += [extended_transforms.ColorJitter(
55 | brightness=args.color_aug,
56 | contrast=args.color_aug,
57 | saturation=args.color_aug,
58 | hue=args.color_aug)]
59 |
60 | if args.bblur:
61 | train_input_transform += [extended_transforms.RandomBilateralBlur()]
62 | elif args.gblur:
63 | train_input_transform += [extended_transforms.RandomGaussianBlur()]
64 | else:
65 | pass
66 |
67 | train_input_transform += [standard_transforms.ToTensor(),
68 | standard_transforms.Normalize(*mean_std)]
69 | train_input_transform = standard_transforms.Compose(train_input_transform)
70 |
71 | val_input_transform = standard_transforms.Compose([
72 | standard_transforms.ToTensor(),
73 | standard_transforms.Normalize(*mean_std)
74 | ])
75 |
76 | target_transform = extended_transforms.MaskToTensor()
77 |
78 | target_train_transform = extended_transforms.MaskToTensor()
79 |
80 | if args.dataset == 'cityscapes':
81 | city_mode = 'train' ## Can be trainval
82 | city_quality = 'fine'
83 | train_set = args.dataset_cls.CityScapes(
84 | city_quality, city_mode, 0,
85 | joint_transform=train_joint_transform,
86 | transform=train_input_transform,
87 | target_transform=target_train_transform,
88 | dump_images=args.dump_augmentation_images,
89 | cv_split=args.cv)
90 | val_set = args.dataset_cls.CityScapes('fine', 'val', 0,
91 | transform=val_input_transform,
92 | target_transform=target_transform,
93 | cv_split=args.cv)
94 | else:
95 | raise ValueError('Dataset {} is not supported'.format(args.dataset))
96 |
97 | train_sampler = None
98 | val_sampler = None
99 |
100 | train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
101 | num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler)
102 | val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
103 | num_workers=args.num_workers // 2 , shuffle=False, drop_last=False, sampler = val_sampler)
104 |
105 | return train_loader, val_loader, train_set
106 |
107 |
--------------------------------------------------------------------------------
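A hedged sketch of driving `setup_loaders()` outside train.py. Every attribute on the namespace below is read somewhere inside the function; the concrete values are assumptions rather than the repository's defaults, and the call only works if `cfg.DATASET.CITYSCAPES_DIR` points at a local copy of Cityscapes.

```python
# Sketch: build the loaders from a hand-rolled args namespace (values assumed).
from types import SimpleNamespace

from datasets import setup_loaders

args = SimpleNamespace(
    dataset='cityscapes', bs_mult=2, bs_mult_val=1, ngpu=1,
    test_mode=False, crop_size=720, pre_size=None,
    scale_min=0.5, scale_max=2.0, color_aug=0.25,
    bblur=False, gblur=True, dump_augmentation_images=False, cv=0)

train_loader, val_loader, train_set = setup_loaders(args)
imgs, masks, edges, names = next(iter(train_loader))
print(imgs.shape, masks.shape, edges.shape)   # image batch, mask batch, edge batch
```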
/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import os
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | from torch.utils import data
11 | from collections import defaultdict
12 | import math
13 | import logging
14 | import datasets.cityscapes_labels as cityscapes_labels
15 | import json
16 | from config import cfg
17 | import torchvision.transforms as transforms
18 | import datasets.edge_utils as edge_utils
19 |
20 | trainid_to_name = cityscapes_labels.trainId2name
21 | id_to_trainid = cityscapes_labels.label2trainid
22 | num_classes = 19
23 | ignore_label = 255
24 | root = cfg.DATASET.CITYSCAPES_DIR
25 |
26 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
27 | 153, 153, 153, 250, 170, 30,
28 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
29 | 255, 0, 0, 0, 0, 142, 0, 0, 70,
30 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
31 | zero_pad = 256 * 3 - len(palette)
32 | for i in range(zero_pad):
33 | palette.append(0)
34 |
35 |
36 | def colorize_mask(mask):
37 | # mask: numpy array of the mask
38 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
39 | new_mask.putpalette(palette)
40 | return new_mask
41 |
42 |
43 | def add_items(items, aug_items, cities, img_path, mask_path, mask_postfix, mode, maxSkip):
44 |
45 | for c in cities:
46 | c_items = [name.split('_leftImg8bit.png')[0] for name in
47 | os.listdir(os.path.join(img_path, c))]
48 | for it in c_items:
49 | item = (os.path.join(img_path, c, it + '_leftImg8bit.png'),
50 | os.path.join(mask_path, c, it + mask_postfix))
51 | items.append(item)
52 |
53 | def make_cv_splits(img_dir_name):
54 | '''
55 | Create splits of train/val data.
56 | A split is a lists of cities.
57 | split0 is aligned with the default Cityscapes train/val.
58 | '''
59 | trn_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'train')
60 | val_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'val')
61 |
62 | trn_cities = ['train/' + c for c in os.listdir(trn_path)]
63 | val_cities = ['val/' + c for c in os.listdir(val_path)]
64 |
65 | # want reproducible randomly shuffled
66 | trn_cities = sorted(trn_cities)
67 |
68 | all_cities = val_cities + trn_cities
69 | num_val_cities = len(val_cities)
70 | num_cities = len(all_cities)
71 |
72 | cv_splits = []
73 | for split_idx in range(cfg.DATASET.CV_SPLITS):
74 | split = {}
75 | split['train'] = []
76 | split['val'] = []
77 | offset = split_idx * num_cities // cfg.DATASET.CV_SPLITS
78 | for j in range(num_cities):
79 | if j >= offset and j < (offset + num_val_cities):
80 | split['val'].append(all_cities[j])
81 | else:
82 | split['train'].append(all_cities[j])
83 | cv_splits.append(split)
84 |
85 | return cv_splits
86 |
87 |
88 | def make_split_coarse(img_path):
89 | '''
90 | Create a train/val split for coarse
91 | return: city split in train
92 | '''
93 | all_cities = os.listdir(img_path)
94 | all_cities = sorted(all_cities) # needs to always be the same
95 | val_cities = [] # Can manually set cities to not be included into train split
96 |
97 | split = {}
98 | split['val'] = val_cities
99 | split['train'] = [c for c in all_cities if c not in val_cities]
100 | return split
101 |
102 | def make_test_split(img_dir_name):
103 | test_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'test')
104 | test_cities = ['test/' + c for c in os.listdir(test_path)]
105 |
106 | return test_cities
107 |
108 |
109 | def make_dataset(quality, mode, maxSkip=0, fine_coarse_mult=6, cv_split=0):
110 | '''
111 | Assemble list of images + mask files
112 |
113 | fine - modes: train/val/test/trainval cv:0,1,2
114 | coarse - modes: train/val cv:na
115 |
116 | path examples:
117 | leftImg8bit_trainextra/leftImg8bit/train_extra/augsburg
118 | gtCoarse/gtCoarse/train_extra/augsburg
119 | '''
120 | items = []
121 | aug_items = []
122 |
123 | if quality == 'fine':
124 | assert mode in ['train', 'val', 'test', 'trainval']
125 | img_dir_name = 'leftImg8bit_trainvaltest'
126 | img_path = os.path.join(root, img_dir_name, 'leftImg8bit')
127 | mask_path = os.path.join(root, 'gtFine_trainvaltest', 'gtFine')
128 | mask_postfix = '_gtFine_labelIds.png'
129 | cv_splits = make_cv_splits(img_dir_name)
130 | if mode == 'trainval':
131 | modes = ['train', 'val']
132 | else:
133 | modes = [mode]
134 | for mode in modes:
135 | if mode == 'test':
136 | cv_splits = make_test_split(img_dir_name)
137 | add_items(items, aug_items, cv_splits, img_path, mask_path,
138 | mask_postfix, mode, maxSkip)
139 | else:
140 | logging.info('{} fine cities: '.format(mode) + str(cv_splits[cv_split][mode]))
141 |
142 | add_items(items, aug_items, cv_splits[cv_split][mode], img_path, mask_path,
143 | mask_postfix, mode, maxSkip)
144 | else:
145 | raise ValueError('unknown cityscapes quality {}'.format(quality))
146 | logging.info('Cityscapes-{}: {} images'.format(mode, len(items)+len(aug_items)))
147 | return items, aug_items
148 |
149 |
150 | class CityScapes(data.Dataset):
151 |
152 | def __init__(self, quality, mode, maxSkip=0, joint_transform=None, sliding_crop=None,
153 | transform=None, target_transform=None, dump_images=False,
154 | cv_split=None, eval_mode=False,
155 | eval_scales=None, eval_flip=False):
156 | self.quality = quality
157 | self.mode = mode
158 | self.maxSkip = maxSkip
159 | self.joint_transform = joint_transform
160 | self.sliding_crop = sliding_crop
161 | self.transform = transform
162 | self.target_transform = target_transform
163 | self.dump_images = dump_images
164 | self.eval_mode = eval_mode
165 | self.eval_flip = eval_flip
166 | self.eval_scales = None
167 | if eval_scales is not None:
168 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")]
169 |
170 | if cv_split:
171 | self.cv_split = cv_split
172 | assert cv_split < cfg.DATASET.CV_SPLITS, \
173 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
174 | cv_split, cfg.DATASET.CV_SPLITS)
175 | else:
176 | self.cv_split = 0
177 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split)
178 | if len(self.imgs) == 0:
179 | raise RuntimeError('Found 0 images, please check the data set')
180 |
181 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
182 |
183 | def _eval_get_item(self, img, mask, scales, flip_bool):
184 | return_imgs = []
185 | for flip in range(int(flip_bool)+1):
186 | imgs = []
187 | if flip :
188 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
189 | for scale in scales:
190 | w,h = img.size
191 | target_w, target_h = int(w * scale), int(h * scale)
192 | resize_img =img.resize((target_w, target_h))
193 | tensor_img = transforms.ToTensor()(resize_img)
194 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img)
195 | imgs.append(final_tensor)
196 | return_imgs.append(imgs)
197 | return return_imgs, mask
198 |
199 |
200 |
201 | def __getitem__(self, index):
202 |
203 | img_path, mask_path = self.imgs[index]
204 |
205 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
206 | img_name = os.path.splitext(os.path.basename(img_path))[0]
207 |
208 | mask = np.array(mask)
209 | mask_copy = mask.copy()
210 | for k, v in id_to_trainid.items():
211 | mask_copy[mask == k] = v
212 |
213 | if self.eval_mode:
214 | return self._eval_get_item(img, mask_copy, self.eval_scales, self.eval_flip), img_name
215 |
216 | mask = Image.fromarray(mask_copy.astype(np.uint8))
217 |
218 | # Image Transformations
219 | if self.joint_transform is not None:
220 | img, mask = self.joint_transform(img, mask)
221 | if self.transform is not None:
222 | img = self.transform(img)
223 | if self.target_transform is not None:
224 | mask = self.target_transform(mask)
225 |
226 | _edgemap = mask.numpy()
227 | _edgemap = edge_utils.mask_to_onehot(_edgemap, num_classes)
228 |
229 | _edgemap = edge_utils.onehot_to_binary_edges(_edgemap, 2, num_classes)
230 |
231 | edgemap = torch.from_numpy(_edgemap).float()
232 |
233 | # Debug
234 | if self.dump_images:
235 | outdir = '../../dump_imgs_{}'.format(self.mode)
236 | os.makedirs(outdir, exist_ok=True)
237 | out_img_fn = os.path.join(outdir, img_name + '.png')
238 | out_msk_fn = os.path.join(outdir, img_name + '_mask.png')
239 | mask_img = colorize_mask(np.array(mask))
240 | img.save(out_img_fn)
241 | mask_img.save(out_msk_fn)
242 |
243 | return img, mask, edgemap, img_name
244 |
245 | def __len__(self):
246 | return len(self.imgs)
247 |
248 |
249 | def make_dataset_video():
250 | img_dir_name = 'leftImg8bit_demoVideo'
251 | img_path = os.path.join(root, img_dir_name, 'leftImg8bit/demoVideo')
252 | items = []
253 | categories = os.listdir(img_path)
254 | for c in categories[1:]:
255 | c_items = [name.split('_leftImg8bit.png')[0] for name in
256 | os.listdir(os.path.join(img_path, c))]
257 | for it in c_items:
258 | item = os.path.join(img_path, c, it + '_leftImg8bit.png')
259 | items.append(item)
260 | return items
261 |
262 |
263 | class CityScapesVideo(data.Dataset):
264 |
265 | def __init__(self, transform=None):
266 | self.imgs = make_dataset_video()
267 | if len(self.imgs) == 0:
268 | raise RuntimeError('Found 0 images, please check the data set')
269 | self.transform = transform
270 |
271 | def __getitem__(self, index):
272 | img_path = self.imgs[index]
273 | img = Image.open(img_path).convert('RGB')
274 | img_name = os.path.splitext(os.path.basename(img_path))[0]
275 |
276 | if self.transform is not None:
277 | img = self.transform(img)
278 | return img, img_name
279 |
280 | def __len__(self):
281 | return len(self.imgs)
282 |
283 |
--------------------------------------------------------------------------------
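A small hedged check of the palette handling above: `colorize_mask` can be exercised on a synthetic trainId mask, with no Cityscapes download needed (the output file name is arbitrary).

```python
# Hypothetical quick check of colorize_mask on a fake trainId mask.
import numpy as np

from datasets.cityscapes import colorize_mask

fake_mask = np.random.randint(0, 19, size=(64, 128)).astype(np.uint8)
vis = colorize_mask(fake_mask)            # PIL 'P' image using the 19-class palette
vis.convert('RGB').save('fake_mask_vis.png')
```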
/datasets/cityscapes_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # File taken from https://github.com/mcordts/cityscapesScripts/
6 | # License File Available at:
7 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt
8 |
9 | # ----------------------
10 | # The Cityscapes Dataset
11 | # ----------------------
12 | #
13 | #
14 | # License agreement
15 | # -----------------
16 | #
17 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
18 | #
19 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
20 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
21 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
22 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
23 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).
24 | #
25 | #
26 | # Contact
27 | # -------
28 | #
29 | # Marius Cordts, Mohamed Omran
30 | # www.cityscapes-dataset.net
31 |
32 | """
33 |
34 | from collections import namedtuple
35 |
36 |
37 | #--------------------------------------------------------------------------------
38 | # Definitions
39 | #--------------------------------------------------------------------------------
40 |
41 | # a label and all meta information
42 | Label = namedtuple( 'Label' , [
43 |
44 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
45 | # We use them to uniquely name a class
46 |
47 | 'id' , # An integer ID that is associated with this label.
48 | # The IDs are used to represent the label in ground truth images
49 | # An ID of -1 means that this label does not have an ID and thus
50 | # is ignored when creating ground truth images (e.g. license plate).
51 | # Do not modify these IDs, since exactly these IDs are expected by the
52 | # evaluation server.
53 |
54 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
55 | # ground truth images with train IDs, using the tools provided in the
56 | # 'preparation' folder. However, make sure to validate or submit results
57 | # to our evaluation server using the regular IDs above!
58 | # For trainIds, multiple labels might have the same ID. Then, these labels
59 | # are mapped to the same class in the ground truth images. For the inverse
60 | # mapping, we use the label that is defined first in the list below.
61 | # For example, mapping all void-type classes to the same ID in training,
62 | # might make sense for some approaches.
63 | # Max value is 255!
64 |
65 | 'category' , # The name of the category that this label belongs to
66 |
67 | 'categoryId' , # The ID of this category. Used to create ground truth images
68 | # on category level.
69 |
70 | 'hasInstances', # Whether this label distinguishes between single instances or not
71 |
72 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
73 | # during evaluations or not
74 |
75 | 'color' , # The color of this label
76 | ] )
77 |
78 |
79 | #--------------------------------------------------------------------------------
80 | # A list of all labels
81 | #--------------------------------------------------------------------------------
82 |
83 | # Please adapt the train IDs as appropriate for your approach.
84 | # Note that you might want to ignore labels with ID 255 during training.
85 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
86 | # Make sure to provide your results using the original IDs and not the training IDs.
87 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
88 |
89 | labels = [
90 | # name id trainId category catId hasInstances ignoreInEval color
91 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
92 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
93 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
94 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
95 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
96 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
97 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
98 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
99 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
100 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
101 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
102 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
103 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
104 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
105 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
106 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
107 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
108 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
109 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
110 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
111 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
112 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
113 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
114 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
115 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
116 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
117 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
118 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
119 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
120 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
121 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
122 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
123 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
124 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
125 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
126 | Label( 'license plate' , 34 , 255 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
127 | ]
128 |
129 |
130 | #--------------------------------------------------------------------------------
131 | # Create dictionaries for a fast lookup
132 | #--------------------------------------------------------------------------------
133 |
134 | # Please refer to the main method below for example usages!
135 |
136 | # name to label object
137 | name2label = { label.name : label for label in labels }
138 | # id to label object
139 | id2label = { label.id : label for label in labels }
140 | # trainId to label object
141 | trainId2label = { label.trainId : label for label in reversed(labels) }
142 | # label2trainid
143 | label2trainid = { label.id : label.trainId for label in labels }
144 | # trainId to label object
145 | trainId2name = { label.trainId : label.name for label in labels }
146 | trainId2color = { label.trainId : label.color for label in labels }
147 | # category to list of label objects
148 | category2labels = {}
149 | for label in labels:
150 | category = label.category
151 | if category in category2labels:
152 | category2labels[category].append(label)
153 | else:
154 | category2labels[category] = [label]
155 |
156 | #--------------------------------------------------------------------------------
157 | # Assure single instance name
158 | #--------------------------------------------------------------------------------
159 |
160 | # returns the label name that describes a single instance (if possible)
161 | # e.g. input | output
162 | # ----------------------
163 | # car | car
164 | # cargroup | car
165 | # foo | None
166 | # foogroup | None
167 | # skygroup | None
168 | def assureSingleInstanceName( name ):
169 | # if the name is known, it is not a group
170 | if name in name2label:
171 | return name
172 | # test if the name actually denotes a group
173 | if not name.endswith("group"):
174 | return None
175 | # remove group
176 | name = name[:-len("group")]
177 | # test if the new name exists
178 | if not name in name2label:
179 | return None
180 | # test if the new name denotes a label that actually has instances
181 | if not name2label[name].hasInstances:
182 | return None
183 | # all good then
184 | return name
185 |
186 | #--------------------------------------------------------------------------------
187 | # Main for testing
188 | #--------------------------------------------------------------------------------
189 |
190 | # just a dummy main
191 | if __name__ == "__main__":
192 | # Print all the labels
193 | print("List of cityscapes labels:")
194 | print("")
195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )))
196 | print((" " + ('-' * 98)))
197 | for label in labels:
198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )))
199 | print("")
200 |
201 | print("Example usages:")
202 |
203 | # Map from name to label
204 | name = 'car'
205 | id = name2label[name].id
206 | print(("ID of label '{name}': {id}".format( name=name, id=id )))
207 |
208 | # Map from ID to label
209 | category = id2label[id].category
210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category )))
211 |
212 | # Map from trainID to label
213 | trainId = 0
214 | name = trainId2label[trainId].name
215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )))
216 |
--------------------------------------------------------------------------------
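The lookup tables above are what `datasets/cityscapes.py` uses to remap raw label IDs to train IDs; a hedged sketch of that remapping on a toy array (the input values are just illustrative Cityscapes IDs):

```python
# Sketch of the labelId -> trainId remapping done per ground-truth mask.
import numpy as np

from datasets.cityscapes_labels import label2trainid

label_ids = np.array([[7, 8, 26],
                      [0, 24, 33]])                    # raw Cityscapes label ids
train_ids = np.vectorize(label2trainid.get)(label_ids)
print(train_ids)   # [[0 1 13] [255 11 18]] -- 255 marks void/ignored classes
```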
/datasets/edge_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import os
7 | import numpy as np
8 | from PIL import Image
9 | from scipy.ndimage.morphology import distance_transform_edt
10 |
11 | def mask_to_onehot(mask, num_classes):
12 | """
13 | Converts a segmentation mask (H,W) to (K,H,W) where the leading dim is a
14 | one-hot encoding over the K classes
15 |
16 | """
17 | _mask = [mask == (i + 1) for i in range(num_classes)]
18 | return np.array(_mask).astype(np.uint8)
19 |
20 | def onehot_to_mask(mask):
21 | """
22 | Converts a mask (K,H,W) to (H,W)
23 | """
24 | _mask = np.argmax(mask, axis=0)
25 | _mask[_mask != 0] += 1
26 | return _mask
27 |
28 | def onehot_to_multiclass_edges(mask, radius, num_classes):
29 | """
30 | Converts a segmentation mask (K,H,W) to an edgemap (K,H,W)
31 |
32 | """
33 | if radius < 0:
34 | return mask
35 |
36 | # We need to pad the borders for boundary conditions
37 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0)
38 |
39 | channels = []
40 | for i in range(num_classes):
41 | dist = distance_transform_edt(mask_pad[i, :])+distance_transform_edt(1.0-mask_pad[i, :])
42 | dist = dist[1:-1, 1:-1]
43 | dist[dist > radius] = 0
44 | dist = (dist > 0).astype(np.uint8)
45 | channels.append(dist)
46 |
47 | return np.array(channels)
48 |
49 | def onehot_to_binary_edges(mask, radius, num_classes):
50 | """
51 | Converts a segmentation mask (K,H,W) to a binary edgemap (H,W)
52 |
53 | """
54 |
55 | if radius < 0:
56 | return mask
57 |
58 | # We need to pad the borders for boundary conditions
59 | mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0)
60 |
61 | edgemap = np.zeros(mask.shape[1:])
62 |
63 | for i in range(num_classes):
64 | dist = distance_transform_edt(mask_pad[i, :])+distance_transform_edt(1.0-mask_pad[i, :])
65 | dist = dist[1:-1, 1:-1]
66 | dist[dist > radius] = 0
67 | edgemap += dist
68 | edgemap = np.expand_dims(edgemap, axis=0)
69 | edgemap = (edgemap > 0).astype(np.uint8)
70 | return edgemap
71 |
72 |
--------------------------------------------------------------------------------
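The two helpers above are chained in `CityScapes.__getitem__` to turn a segmentation mask into a binary boundary target; a hedged toy-sized run (radius 2, as used in `datasets/cityscapes.py`):

```python
# Toy run of the mask -> one-hot -> binary-edge pipeline.
import numpy as np

import datasets.edge_utils as edge_utils

mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:6, 2:6] = 1                                      # a square of class 1

onehot = edge_utils.mask_to_onehot(mask, num_classes=2)              # (2, 8, 8)
edges = edge_utils.onehot_to_binary_edges(onehot, 2, num_classes=2)  # (1, 8, 8)
print(edges[0])   # 1s along (and near) the square's boundary, 0 elsewhere
```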
/docs/index.html:
--------------------------------------------------------------------------------
[HTML stub that redirects to https://research.nvidia.com/labs/toronto-ai/GSCNN/]
--------------------------------------------------------------------------------
/docs/resources/GSCNN.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/GSCNN.mp4
--------------------------------------------------------------------------------
/docs/resources/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/architecture.jpg
--------------------------------------------------------------------------------
/docs/resources/bibtex.txt:
--------------------------------------------------------------------------------
1 | @inproceedings{Takikawa2019GatedSCNNGS,
2 | title={Gated-SCNN: Gated Shape CNNs for Semantic Segmentation},
3 | author={Towaki Takikawa and David Acuna and Varun Jampani and Sanja Fidler},
4 | year={2019}
5 | }
6 |
7 |
--------------------------------------------------------------------------------
/docs/resources/crop.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/crop.jpg
--------------------------------------------------------------------------------
/docs/resources/edges.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/edges.jpg
--------------------------------------------------------------------------------
/docs/resources/gscnn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/gscnn.gif
--------------------------------------------------------------------------------
/docs/resources/intro.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/intro.jpg
--------------------------------------------------------------------------------
/docs/resources/seg.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/seg.jpg
--------------------------------------------------------------------------------
/docs/resources/semboundary.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/semboundary.jpg
--------------------------------------------------------------------------------
/docs/resources/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/table.png
--------------------------------------------------------------------------------
/docs/resources/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/test.jpg
--------------------------------------------------------------------------------
/docs/resources/top.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nv-tlabs/GSCNN/edd16c5d3f346a0ec009cec539fbcaf6d79b5111/docs/resources/top.jpg
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import logging
10 | import numpy as np
11 | from config import cfg
12 | from my_functionals.DualTaskLoss import DualTaskLoss
13 |
14 | def get_loss(args):
15 | '''
16 | Get the criterion based on the loss function
17 | args:
18 | return: criterion
19 | '''
20 |
21 | if args.img_wt_loss:
22 | criterion = ImageBasedCrossEntropyLoss2d(
23 | classes=args.dataset_cls.num_classes, size_average=True,
24 | ignore_index=args.dataset_cls.ignore_label,
25 | upper_bound=args.wt_bound).cuda()
26 | elif args.joint_edgeseg_loss:
27 | criterion = JointEdgeSegLoss(classes=args.dataset_cls.num_classes,
28 | ignore_index=args.dataset_cls.ignore_label, upper_bound=args.wt_bound,
29 | edge_weight=args.edge_weight, seg_weight=args.seg_weight, att_weight=args.att_weight, dual_weight=args.dual_weight).cuda()
30 |
31 | else:
32 | criterion = CrossEntropyLoss2d(size_average=True,
33 | ignore_index=args.dataset_cls.ignore_label).cuda()
34 |
35 | criterion_val = JointEdgeSegLoss(classes=args.dataset_cls.num_classes, mode='val',
36 | ignore_index=args.dataset_cls.ignore_label, upper_bound=args.wt_bound,
37 | edge_weight=args.edge_weight, seg_weight=args.seg_weight).cuda()
38 |
39 | return criterion, criterion_val
40 |
41 | class JointEdgeSegLoss(nn.Module):
42 | def __init__(self, classes, weight=None, reduction='mean', ignore_index=255,
43 | norm=False, upper_bound=1.0, mode='train',
44 | edge_weight=1, seg_weight=1, att_weight=1, dual_weight=1, edge='none'):
45 | super(JointEdgeSegLoss, self).__init__()
46 | self.num_classes = classes
47 | if mode == 'train':
48 | self.seg_loss = ImageBasedCrossEntropyLoss2d(
49 | classes=classes, ignore_index=ignore_index, upper_bound=upper_bound).cuda()
50 | elif mode == 'val':
51 | self.seg_loss = CrossEntropyLoss2d(size_average=True,
52 | ignore_index=ignore_index).cuda()
53 |
54 | self.edge_weight = edge_weight
55 | self.seg_weight = seg_weight
56 | self.att_weight = att_weight
57 | self.dual_weight = dual_weight
58 |
59 | self.dual_task = DualTaskLoss()
60 |
61 | def bce2d(self, input, target):
62 | n, c, h, w = input.size()
63 |
64 | log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1)
65 | target_t = target.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1)
66 | target_trans = target_t.clone()
67 |
68 | pos_index = (target_t ==1)
69 | neg_index = (target_t ==0)
70 | ignore_index=(target_t >1)
71 |
72 | target_trans[pos_index] = 1
73 | target_trans[neg_index] = 0
74 |
75 | pos_index = pos_index.data.cpu().numpy().astype(bool)
76 | neg_index = neg_index.data.cpu().numpy().astype(bool)
77 | ignore_index=ignore_index.data.cpu().numpy().astype(bool)
78 |
79 | weight = torch.Tensor(log_p.size()).fill_(0)
80 | weight = weight.numpy()
81 | pos_num = pos_index.sum()
82 | neg_num = neg_index.sum()
83 | sum_num = pos_num + neg_num
84 | weight[pos_index] = neg_num*1.0 / sum_num
85 | weight[neg_index] = pos_num*1.0 / sum_num
86 |
87 | weight[ignore_index] = 0
88 |
89 | weight = torch.from_numpy(weight)
90 | weight = weight.cuda()
91 | loss = F.binary_cross_entropy_with_logits(log_p, target_t, weight, size_average=True)
92 | return loss
93 |
94 | def edge_attention(self, input, target, edge):
95 | n, c, h, w = input.size()
96 | filler = torch.ones_like(target) * 255
97 | return self.seg_loss(input,
98 | torch.where(edge.max(1)[0] > 0.8, target, filler))
99 |
100 | def forward(self, inputs, targets):
101 | segin, edgein = inputs
102 | segmask, edgemask = targets
103 |
104 | losses = {}
105 |
106 | losses['seg_loss'] = self.seg_weight * self.seg_loss(segin, segmask)
107 | losses['edge_loss'] = self.edge_weight * 20 * self.bce2d(edgein, edgemask)
108 | losses['att_loss'] = self.att_weight * self.edge_attention(segin, segmask, edgein)
109 | losses['dual_loss'] = self.dual_weight * self.dual_task(segin, segmask)
110 |
111 | return losses
112 |
113 | #Img Weighted Loss
114 | class ImageBasedCrossEntropyLoss2d(nn.Module):
115 |
116 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255,
117 | norm=False, upper_bound=1.0):
118 | super(ImageBasedCrossEntropyLoss2d, self).__init__()
119 | logging.info("Using Per Image based weighted loss")
120 | self.num_classes = classes
121 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index)
122 | self.norm = norm
123 | self.upper_bound = upper_bound
124 | self.batch_weights = cfg.BATCH_WEIGHTING
125 |
126 | def calculateWeights(self, target):
127 | hist = np.histogram(target.flatten(), range(
128 | self.num_classes + 1), normed=True)[0]
129 | if self.norm:
130 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
131 | else:
132 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
133 | return hist
134 |
135 | def forward(self, inputs, targets):
136 | target_cpu = targets.data.cpu().numpy()
137 | if self.batch_weights:
138 | weights = self.calculateWeights(target_cpu)
139 | self.nll_loss.weight = torch.Tensor(weights).cuda()
140 |
141 | loss = 0.0
142 | for i in range(0, inputs.shape[0]):
143 | if not self.batch_weights:
144 | weights = self.calculateWeights(target_cpu[i])
145 | self.nll_loss.weight = torch.Tensor(weights).cuda()
146 |
147 | loss += self.nll_loss(F.log_softmax(inputs[i].unsqueeze(0)),
148 | targets[i].unsqueeze(0))
149 | return loss
150 |
151 |
152 | # Cross Entropy NLL Loss
153 | class CrossEntropyLoss2d(nn.Module):
154 | def __init__(self, weight=None, size_average=True, ignore_index=255):
155 | super(CrossEntropyLoss2d, self).__init__()
156 | logging.info("Using Cross Entropy Loss")
157 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index)
158 |
159 | def forward(self, inputs, targets):
160 | return self.nll_loss(F.log_softmax(inputs), targets)
161 |
162 |
--------------------------------------------------------------------------------
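A hedged end-to-end check of `JointEdgeSegLoss` on random tensors. The shapes mirror what train.py feeds in (segmentation logits plus a one-channel edge map, with integer trainId masks and a binary edge target), but the tensors here are synthetic; a CUDA device and the pinned dependency versions from the Dockerfile are assumed.

```python
# Sketch: the criterion returns a dict of weighted terms; summing them gives
# the scalar that is backpropagated (the exact reduction lives in train.py).
import torch

from loss import JointEdgeSegLoss

criterion = JointEdgeSegLoss(classes=19).cuda()

seg_logits = torch.randn(2, 19, 64, 64, requires_grad=True, device='cuda')
edge_logits = torch.randn(2, 1, 64, 64, requires_grad=True, device='cuda')
seg_gt = torch.randint(0, 19, (2, 64, 64), device='cuda')
edge_gt = torch.randint(0, 2, (2, 1, 64, 64), device='cuda').float()

losses = criterion((seg_logits, edge_logits), (seg_gt, edge_gt))
total = sum(losses.values())
total.backward()
print({k: float(v) for k, v in losses.items()})
```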
/my_functionals/DualTaskLoss.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code adapted from:
6 | # https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
7 | #
8 | # MIT License
9 | #
10 | # Copyright (c) 2016 Eric Jang
11 | #
12 | # Permission is hereby granted, free of charge, to any person obtaining a copy
13 | # of this software and associated documentation files (the "Software"), to deal
14 | # in the Software without restriction, including without limitation the rights
15 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | # copies of the Software, and to permit persons to whom the Software is
17 | # furnished to do so, subject to the following conditions:
18 | #
19 | # The above copyright notice and this permission notice shall be included in all
20 | # copies or substantial portions of the Software.
21 | #
22 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | # SOFTWARE.
29 | """
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 |
35 | import numpy as np
36 | from my_functionals.custom_functional import compute_grad_mag
37 |
38 | def perturbate_input_(input, n_elements=200):
39 | N, C, H, W = input.shape
40 | assert N == 1
41 | c_ = np.random.random_integers(0, C - 1, n_elements)
42 | h_ = np.random.random_integers(0, H - 1, n_elements)
43 | w_ = np.random.random_integers(0, W - 1, n_elements)
44 | for c_idx in c_:
45 | for h_idx in h_:
46 | for w_idx in w_:
47 | input[0, c_idx, h_idx, w_idx] = 1
48 | return input
49 |
50 | def _sample_gumbel(shape, eps=1e-10):
51 | """
52 | Sample from Gumbel(0, 1)
53 |
54 | based on
55 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb ,
56 | (MIT license)
57 | """
58 | U = torch.rand(shape).cuda()
59 | return - torch.log(eps - torch.log(U + eps))
60 |
61 |
62 | def _gumbel_softmax_sample(logits, tau=1, eps=1e-10):
63 | """
64 | Draw a sample from the Gumbel-Softmax distribution
65 |
66 | based on
67 | https://github.com/ericjang/gumbel-softmax/blob/3c8584924603869e90ca74ac20a6a03d99a91ef9/Categorical%20VAE.ipynb
68 | (MIT license)
69 | """
70 | assert logits.dim() == 3
71 | gumbel_noise = _sample_gumbel(logits.size(), eps=eps)
72 | y = logits + gumbel_noise
73 | return F.softmax(y / tau, 1)
74 |
75 |
76 | def _one_hot_embedding(labels, num_classes):
77 | """Embedding labels to one-hot form.
78 |
79 | Args:
80 | labels: (LongTensor) class labels, sized [N,].
81 | num_classes: (int) number of classes.
82 |
83 | Returns:
84 | (tensor) encoded labels, sized [N, #classes].
85 | """
86 |
87 | y = torch.eye(num_classes).cuda()
88 | return y[labels].permute(0,3,1,2)
89 |
90 | class DualTaskLoss(nn.Module):
91 | def __init__(self, cuda=False):
92 | super(DualTaskLoss, self).__init__()
93 | self._cuda = cuda
94 | return
95 |
96 | def forward(self, input_logits, gts, ignore_pixel=255):
97 | """
98 | :param input_logits: NxCxHxW
99 | :param gt_semantic_masks: NxCxHxW
100 | :return: final loss
101 | """
102 | N, C, H, W = input_logits.shape
103 | th = 1e-8 # 1e-10
104 | eps = 1e-10
105 | ignore_mask = (gts == ignore_pixel).detach()
106 | input_logits = torch.where(ignore_mask.view(N, 1, H, W).expand(N, 19, H, W),
107 | torch.zeros(N,C,H,W).cuda(),
108 | input_logits)
109 | gt_semantic_masks = gts.detach()
110 | gt_semantic_masks = torch.where(ignore_mask, torch.zeros(N,H,W).long().cuda(), gt_semantic_masks)
111 | gt_semantic_masks = _one_hot_embedding(gt_semantic_masks, 19).detach()
112 |
113 | g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)
114 | g = g.reshape((N, C, H, W))
115 | g = compute_grad_mag(g, cuda=self._cuda)
116 |
117 | g_hat = compute_grad_mag(gt_semantic_masks, cuda=self._cuda)
118 |
119 | g = g.view(N, -1)
120 | g_hat = g_hat.view(N, -1)
121 | loss_ewise = F.l1_loss(g, g_hat, reduction='none', reduce=False)
122 |
123 | p_plus_g_mask = (g >= th).detach().float()
124 | loss_p_plus_g = torch.sum(loss_ewise * p_plus_g_mask) / (torch.sum(p_plus_g_mask) + eps)
125 |
126 | p_plus_g_hat_mask = (g_hat >= th).detach().float()
127 | loss_p_plus_g_hat = torch.sum(loss_ewise * p_plus_g_hat_mask) / (torch.sum(p_plus_g_hat_mask) + eps)
128 |
129 | total_loss = 0.5 * loss_p_plus_g + 0.5 * loss_p_plus_g_hat
130 |
131 | return total_loss
132 |
133 |
--------------------------------------------------------------------------------
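Direct hedged usage of the module above, separate from its use inside `JointEdgeSegLoss`. CUDA is required because the helper tensors in `forward` are allocated with `.cuda()`, and the class is hard-wired to 19 classes; the tensors below are synthetic.

```python
# Sketch: DualTaskLoss compares gradient magnitudes of the Gumbel-softmax of
# the logits against those of the one-hot ground truth; 255 pixels are masked.
import torch

from my_functionals.DualTaskLoss import DualTaskLoss

dual = DualTaskLoss()
logits = torch.randn(2, 19, 32, 32, device='cuda')     # N x C x H x W
gts = torch.randint(0, 19, (2, 32, 32), device='cuda')
gts[:, :4, :4] = 255                                    # some ignored pixels
print(dual(logits, gts).item())
```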
/my_functionals/GatedSpatialConv.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch.nn as nn
7 | import torch
8 | import torch.nn.functional as F
9 | from torch.nn.modules.conv import _ConvNd
10 | from torch.nn.modules.utils import _pair
11 | import numpy as np
12 | import math
13 | import network.mynn as mynn
14 | import my_functionals.custom_functional as myF
15 | class GatedSpatialConv2d(_ConvNd):
16 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
17 | padding=0, dilation=1, groups=1, bias=False):
18 | """
19 |
20 | :param in_channels:
21 | :param out_channels:
22 | :param kernel_size:
23 | :param stride:
24 | :param padding:
25 | :param dilation:
26 | :param groups:
27 | :param bias:
28 | """
29 |
30 | kernel_size = _pair(kernel_size)
31 | stride = _pair(stride)
32 | padding = _pair(padding)
33 | dilation = _pair(dilation)
34 | super(GatedSpatialConv2d, self).__init__(
35 | in_channels, out_channels, kernel_size, stride, padding, dilation,
36 | False, _pair(0), groups, bias, 'zeros')
37 |
38 | self._gate_conv = nn.Sequential(
39 | mynn.Norm2d(in_channels+1),
40 | nn.Conv2d(in_channels+1, in_channels+1, 1),
41 | nn.ReLU(),
42 | nn.Conv2d(in_channels+1, 1, 1),
43 | mynn.Norm2d(1),
44 | nn.Sigmoid()
45 | )
46 |
47 | def forward(self, input_features, gating_features):
48 | """
49 |
50 | :param input_features: [NxCxHxW] features coming from the shape branch (canny branch).
51 | :param gating_features: [Nx1xHxW] features coming from the texture branch (resnet). Only one channel feature map.
52 | :return:
53 | """
54 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1))
55 |
56 | input_features = (input_features * (alphas + 1))
57 | return F.conv2d(input_features, self.weight, self.bias, self.stride,
58 | self.padding, self.dilation, self.groups)
59 |
60 | def reset_parameters(self):
61 | nn.init.xavier_normal_(self.weight)
62 | if self.bias is not None:
63 | nn.init.zeros_(self.bias)
64 |
65 |
66 | class Conv2dPad(nn.Conv2d):
67 | def forward(self, input):
68 | return myF.conv2d_same(input,self.weight,self.groups)
69 |
70 | class HighFrequencyGatedSpatialConv2d(_ConvNd):
71 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
72 | padding=0, dilation=1, groups=1, bias=False):
73 | """
74 |
75 | :param in_channels:
76 | :param out_channels:
77 | :param kernel_size:
78 | :param stride:
79 | :param padding:
80 | :param dilation:
81 | :param groups:
82 | :param bias:
83 | """
84 |
85 | kernel_size = _pair(kernel_size)
86 | stride = _pair(stride)
87 | padding = _pair(padding)
88 | dilation = _pair(dilation)
89 | super(HighFrequencyGatedSpatialConv2d, self).__init__(
90 | in_channels, out_channels, kernel_size, stride, padding, dilation,
91 | False, _pair(0), groups, bias)
92 |
93 | self._gate_conv = nn.Sequential(
94 | mynn.Norm2d(in_channels+1),
95 | nn.Conv2d(in_channels+1, in_channels+1, 1),
96 | nn.ReLU(),
97 | nn.Conv2d(in_channels+1, 1, 1),
98 | mynn.Norm2d(1),
99 | nn.Sigmoid()
100 | )
101 |
102 | kernel_size = 7
103 | sigma = 3
104 |
105 | x_cord = torch.arange(kernel_size).float()
106 | x_grid = x_cord.repeat(kernel_size).view(kernel_size, kernel_size).float()
107 | y_grid = x_grid.t().float()
108 | xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
109 |
110 | mean = (kernel_size - 1)/2.
111 | variance = sigma**2.
112 | gaussian_kernel = (1./(2.*math.pi*variance)) *\
113 | torch.exp(
114 | -torch.sum((xy_grid - mean)**2., dim=-1) /\
115 | (2*variance)
116 | )
117 |
118 | gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
119 |
120 | gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
121 | gaussian_kernel = gaussian_kernel.repeat(in_channels, 1, 1, 1)
122 |
123 | self.gaussian_filter = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, padding=3,
124 | kernel_size=kernel_size, groups=in_channels, bias=False)
125 |
126 | self.gaussian_filter.weight.data = gaussian_kernel
127 | self.gaussian_filter.weight.requires_grad = False
128 |
129 | self.cw = nn.Conv2d(in_channels * 2, in_channels, 1)
130 |
131 | self.procdog = nn.Sequential(
132 | nn.Conv2d(in_channels, in_channels, 1),
133 | mynn.Norm2d(in_channels),
134 | nn.Sigmoid()
135 | )
136 |
137 | def forward(self, input_features, gating_features):
138 | """
139 |
140 |         :param input_features: [NxCxHxW] features coming from the shape branch (canny branch).
141 |         :param gating_features: [Nx1xHxW] features coming from the texture branch (resnet). Only a one-channel feature map.
142 | :return:
143 | """
144 | n, c, h, w = input_features.size()
145 | smooth_features = self.gaussian_filter(input_features)
146 | dog_features = input_features - smooth_features
147 | dog_features = self.cw(torch.cat((dog_features, input_features), dim=1))
148 |
149 | alphas = self._gate_conv(torch.cat([input_features, gating_features], dim=1))
150 |
151 | dog_features = dog_features * (alphas + 1)
152 |
153 | return F.conv2d(dog_features, self.weight, self.bias, self.stride,
154 | self.padding, self.dilation, self.groups)
155 |
156 | def reset_parameters(self):
157 | nn.init.xavier_normal_(self.weight)
158 | if self.bias is not None:
159 | nn.init.zeros_(self.bias)
160 |
161 | def t():
162 | import matplotlib.pyplot as plt
163 |
164 | canny_map_filters_in = 8
165 | canny_map = np.random.normal(size=(1, canny_map_filters_in, 10, 10)) # NxCxHxW
166 | resnet_map = np.random.normal(size=(1, 1, 10, 10)) # NxCxHxW
167 | plt.imshow(canny_map[0, 0])
168 | plt.show()
169 |
170 | canny_map = torch.from_numpy(canny_map).float()
171 | resnet_map = torch.from_numpy(resnet_map).float()
172 |
173 | gconv = GatedSpatialConv2d(canny_map_filters_in, canny_map_filters_in,
174 | kernel_size=3, stride=1, padding=1)
175 | output_map = gconv(canny_map, resnet_map)
176 | print('done')
177 |
178 |
179 | if __name__ == "__main__":
180 | t()
181 |
182 |
--------------------------------------------------------------------------------
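The gating in GatedSpatialConv2d above reduces to computing a sigmoid attention map from the concatenated shape- and texture-stream features and using it to rescale the shape stream before the learned convolution. Below is a minimal, self-contained sketch of just that step; it substitutes plain nn.BatchNorm2d for mynn.Norm2d (which the repo resolves through config.py) and omits the final F.conv2d, so it illustrates the mechanism rather than reproducing the module.

import torch
import torch.nn as nn

def gating_sketch(shape_feats, gating_feats):
    # shape_feats: NxCxHxW shape-stream features; gating_feats: Nx1xHxW cue from the regular stream
    c = shape_feats.shape[1]
    gate = nn.Sequential(
        nn.BatchNorm2d(c + 1),          # stand-in for mynn.Norm2d (assumption for this sketch)
        nn.Conv2d(c + 1, c + 1, 1),
        nn.ReLU(),
        nn.Conv2d(c + 1, 1, 1),
        nn.BatchNorm2d(1),
        nn.Sigmoid(),
    )
    alphas = gate(torch.cat([shape_feats, gating_feats], dim=1))   # Nx1xHxW attention in (0, 1)
    # residual gating: features pass through where alphas ~ 0 and are amplified where alphas ~ 1
    return shape_feats * (alphas + 1)

x = torch.randn(2, 8, 16, 16)
g = torch.randn(2, 1, 16, 16)
print(gating_sketch(x, g).shape)        # torch.Size([2, 8, 16, 16])
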
/my_functionals/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
--------------------------------------------------------------------------------
/my_functionals/custom_functional.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torchvision.transforms.functional import pad
9 | import numpy as np
10 |
11 |
12 | def calc_pad_same(in_siz, out_siz, stride, ksize):
13 |     """Calculate the total 'same' padding needed along one dimension.
14 |     Args:
15 |     in_siz, out_siz, stride, ksize: input size, desired output size, stride and kernel size.
16 |     Returns:
17 |     pad_: total padding width so the convolution output matches out_siz.
18 | """
19 | return (out_siz - 1) * stride + ksize - in_siz
20 |
21 |
22 | def conv2d_same(input, kernel, groups,bias=None,stride=1,padding=0,dilation=1):
23 | n, c, h, w = input.shape
24 | kout, ki_c_g, kh, kw = kernel.shape
25 | pw = calc_pad_same(w, w, 1, kw)
26 | ph = calc_pad_same(h, h, 1, kh)
27 | pw_l = pw // 2
28 | pw_r = pw - pw_l
29 | ph_t = ph // 2
30 | ph_b = ph - ph_t
31 |
32 | input_ = F.pad(input, (pw_l, pw_r, ph_t, ph_b))
33 | result = F.conv2d(input_, kernel, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
34 | assert result.shape == input.shape
35 | return result
36 |
37 |
38 | def gradient_central_diff(input, cuda):
39 |     return input, input  # NOTE: short-circuits the central-difference kernel below; x and y are returned as the raw input
40 | kernel = [[1, 0, -1]]
41 | kernel_t = 0.5 * torch.Tensor(kernel) * -1. # pytorch implements correlation instead of conv
42 | if type(cuda) is int:
43 | if cuda != -1:
44 | kernel_t = kernel_t.cuda(device=cuda)
45 | else:
46 | if cuda is True:
47 | kernel_t = kernel_t.cuda()
48 | n, c, h, w = input.shape
49 |
50 | x = conv2d_same(input, kernel_t.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c)
51 | y = conv2d_same(input, kernel_t.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]), c)
52 | return x, y
53 |
54 |
55 | def compute_single_sided_diferences(o_x, o_y, input):
56 | # n,c,h,w
57 | #input = input.clone()
58 | o_y[:, :, 0, :] = input[:, :, 1, :].clone() - input[:, :, 0, :].clone()
59 | o_x[:, :, :, 0] = input[:, :, :, 1].clone() - input[:, :, :, 0].clone()
60 | # --
61 | o_y[:, :, -1, :] = input[:, :, -1, :].clone() - input[:, :, -2, :].clone()
62 | o_x[:, :, :, -1] = input[:, :, :, -1].clone() - input[:, :, :, -2].clone()
63 | return o_x, o_y
64 |
65 |
66 | def numerical_gradients_2d(input, cuda=False):
67 | """
68 |     Numerical gradients over a batch, implemented with a grouped conv operator.
69 |     The single-sided differences at the borders are meant to be recomputed separately.
70 |     It matches np.gradient(image), except that the output order here is x, y whereas np.gradient returns y, x.
71 |     :param input: N,C,H,W
72 |     :param cuda: whether or not to use cuda
73 | :return: X,Y
74 | """
75 | n, c, h, w = input.shape
76 | assert h > 1 and w > 1
77 | x, y = gradient_central_diff(input, cuda)
78 | return x, y
79 |
80 |
81 | def convTri(input, r, cuda=False):
82 | """
83 | Convolves an image by a 2D triangle filter (the 1D triangle filter f is
84 | [1:r r+1 r:-1:1]/(r+1)^2, the 2D version is simply conv2(f,f'))
85 |     :param input: N,C,H,W tensor to be smoothed
86 | :param r: integer filter radius
87 | :param cuda: move the kernel to gpu
88 | :return:
89 | """
90 | if (r <= 1):
91 | raise ValueError()
92 | n, c, h, w = input.shape
93 |     return input  # NOTE: short-circuits the triangle filtering; the kernel construction below is unreachable
94 | f = list(range(1, r + 1)) + [r + 1] + list(reversed(range(1, r + 1)))
95 | kernel = torch.Tensor([f]) / (r + 1) ** 2
96 | if type(cuda) is int:
97 | if cuda != -1:
98 | kernel = kernel.cuda(device=cuda)
99 | else:
100 | if cuda is True:
101 | kernel = kernel.cuda()
102 |
103 | # padding w
104 | input_ = F.pad(input, (1, 1, 0, 0), mode='replicate')
105 | input_ = F.pad(input_, (r, r, 0, 0), mode='reflect')
106 | input_ = [input_[:, :, :, :r], input, input_[:, :, :, -r:]]
107 | input_ = torch.cat(input_, 3)
108 | t = input_
109 |
110 | # padding h
111 | input_ = F.pad(input_, (0, 0, 1, 1), mode='replicate')
112 | input_ = F.pad(input_, (0, 0, r, r), mode='reflect')
113 | input_ = [input_[:, :, :r, :], t, input_[:, :, -r:, :]]
114 | input_ = torch.cat(input_, 2)
115 |
116 | output = F.conv2d(input_,
117 | kernel.unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]),
118 | padding=0, groups=c)
119 | output = F.conv2d(output,
120 | kernel.t().unsqueeze(0).unsqueeze(0).repeat([c, 1, 1, 1]),
121 | padding=0, groups=c)
122 | return output
123 |
124 |
125 | def compute_normal(E, cuda=False):
126 | if torch.sum(torch.isnan(E)) != 0:
127 | print('nans found here')
128 | import ipdb;
129 | ipdb.set_trace()
130 | E_ = convTri(E, 4, cuda)
131 | Ox, Oy = numerical_gradients_2d(E_, cuda)
132 | Oxx, _ = numerical_gradients_2d(Ox, cuda)
133 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda)
134 |
135 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5)
136 | t = torch.atan(aa)
137 | O = torch.remainder(t, np.pi)
138 |
139 | if torch.sum(torch.isnan(O)) != 0:
140 | print('nans found here')
141 | import ipdb;
142 | ipdb.set_trace()
143 |
144 | return O
145 |
146 |
147 | def compute_normal_2(E, cuda=False):
148 | if torch.sum(torch.isnan(E)) != 0:
149 | print('nans found here')
150 | import ipdb;
151 | ipdb.set_trace()
152 | E_ = convTri(E, 4, cuda)
153 | Ox, Oy = numerical_gradients_2d(E_, cuda)
154 | Oxx, _ = numerical_gradients_2d(Ox, cuda)
155 | Oxy, Oyy = numerical_gradients_2d(Oy, cuda)
156 |
157 | aa = Oyy * torch.sign(-(Oxy + 1e-5)) / (Oxx + 1e-5)
158 | t = torch.atan(aa)
159 | O = torch.remainder(t, np.pi)
160 |
161 | if torch.sum(torch.isnan(O)) != 0:
162 | print('nans found here')
163 | import ipdb;
164 | ipdb.set_trace()
165 |
166 | return O, (Oyy, Oxx)
167 |
168 |
169 | def compute_grad_mag(E, cuda=False):
170 | E_ = convTri(E, 4, cuda)
171 | Ox, Oy = numerical_gradients_2d(E_, cuda)
172 | mag = torch.sqrt(torch.mul(Ox,Ox) + torch.mul(Oy,Oy) + 1e-6)
173 |     mag = mag / mag.max()
174 |
175 | return mag
176 |
--------------------------------------------------------------------------------
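A quick shape check for the helpers above, assuming the repo root is on PYTHONPATH. Note that with the early returns in gradient_central_diff and convTri, compute_grad_mag currently reduces to a normalized magnitude of the (unsmoothed) input.

import torch
import my_functionals.custom_functional as myF

x = torch.randn(2, 3, 32, 32)                 # N, C, H, W
k = torch.randn(3, 1, 5, 5)                   # depthwise 5x5 kernel (one filter per channel)
y = myF.conv2d_same(x, k, groups=3)
print(y.shape)                                # torch.Size([2, 3, 32, 32]) -- spatial size preserved

mag = myF.compute_grad_mag(x)
print(mag.shape, float(mag.max()))            # same shape as x, max normalized to 1.0
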
/network/Resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code Adapted from:
6 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
7 | #
8 | # BSD 3-Clause License
9 | #
10 | # Copyright (c) 2017,
11 | # All rights reserved.
12 | #
13 | # Redistribution and use in source and binary forms, with or without
14 | # modification, are permitted provided that the following conditions are met:
15 | #
16 | # * Redistributions of source code must retain the above copyright notice, this
17 | # list of conditions and the following disclaimer.
18 | #
19 | # * Redistributions in binary form must reproduce the above copyright notice,
20 | # this list of conditions and the following disclaimer in the documentation
21 | # and/or other materials provided with the distribution.
22 | #
23 | # * Neither the name of the copyright holder nor the names of its
24 | # contributors may be used to endorse or promote products derived from
25 | # this software without specific prior written permission.
26 | #
27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 | """
38 |
39 | import torch.nn as nn
40 | import math
41 | import torch.utils.model_zoo as model_zoo
42 | import network.mynn as mynn
43 |
44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
45 | 'resnet152']
46 |
47 |
48 | model_urls = {
49 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
50 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
51 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
52 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
53 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
54 | }
55 |
56 |
57 | def conv3x3(in_planes, out_planes, stride=1):
58 | """3x3 convolution with padding"""
59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
60 | padding=1, bias=False)
61 |
62 |
63 | class BasicBlock(nn.Module):
64 | expansion = 1
65 |
66 | def __init__(self, inplanes, planes, stride=1, downsample=None):
67 | super(BasicBlock, self).__init__()
68 | self.conv1 = conv3x3(inplanes, planes, stride)
69 | self.bn1 = mynn.Norm2d(planes)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.conv2 = conv3x3(planes, planes)
72 | self.bn2 = mynn.Norm2d(planes)
73 | self.downsample = downsample
74 | self.stride = stride
75 | for m in self.modules():
76 | if isinstance(m, nn.Conv2d):
77 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
78 | elif isinstance(m, nn.BatchNorm2d):
79 | nn.init.constant_(m.weight, 1)
80 | nn.init.constant_(m.bias, 0)
81 |
82 | def forward(self, x):
83 | residual = x
84 |
85 | out = self.conv1(x)
86 | out = self.bn1(out)
87 | out = self.relu(out)
88 |
89 | out = self.conv2(out)
90 | out = self.bn2(out)
91 |
92 | if self.downsample is not None:
93 | residual = self.downsample(x)
94 |
95 | out += residual
96 | out = self.relu(out)
97 |
98 | return out
99 |
100 |
101 | class Bottleneck(nn.Module):
102 | expansion = 4
103 |
104 | def __init__(self, inplanes, planes, stride=1, downsample=None):
105 | super(Bottleneck, self).__init__()
106 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
107 | self.bn1 = mynn.Norm2d(planes)
108 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
109 | padding=1, bias=False)
110 | self.bn2 = mynn.Norm2d(planes)
111 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
112 | self.bn3 = mynn.Norm2d(planes * self.expansion)
113 | self.relu = nn.ReLU(inplace=True)
114 | self.downsample = downsample
115 | self.stride = stride
116 |
117 | def forward(self, x):
118 | residual = x
119 |
120 | out = self.conv1(x)
121 | out = self.bn1(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv2(out)
125 | out = self.bn2(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv3(out)
129 | out = self.bn3(out)
130 |
131 | if self.downsample is not None:
132 | residual = self.downsample(x)
133 |
134 | out += residual
135 | out = self.relu(out)
136 |
137 | return out
138 |
139 |
140 | class ResNet(nn.Module):
141 |
142 | def __init__(self, block, layers, num_classes=1000):
143 | self.inplanes = 64
144 | super(ResNet, self).__init__()
145 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
146 | bias=False)
147 | self.bn1 = mynn.Norm2d(64)
148 | self.relu = nn.ReLU(inplace=True)
149 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
150 | self.layer1 = self._make_layer(block, 64, layers[0])
151 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
154 | self.avgpool = nn.AvgPool2d(7, stride=1)
155 | self.fc = nn.Linear(512 * block.expansion, num_classes)
156 |
157 | for m in self.modules():
158 | if isinstance(m, nn.Conv2d):
159 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
160 | elif isinstance(m, nn.BatchNorm2d):
161 | nn.init.constant_(m.weight, 1)
162 | nn.init.constant_(m.bias, 0)
163 |
164 | def _make_layer(self, block, planes, blocks, stride=1):
165 | downsample = None
166 | if stride != 1 or self.inplanes != planes * block.expansion:
167 | downsample = nn.Sequential(
168 | nn.Conv2d(self.inplanes, planes * block.expansion,
169 | kernel_size=1, stride=stride, bias=False),
170 | mynn.Norm2d(planes * block.expansion),
171 | )
172 |
173 | layers = []
174 | layers.append(block(self.inplanes, planes, stride, downsample))
175 | self.inplanes = planes * block.expansion
176 | for i in range(1, blocks):
177 | layers.append(block(self.inplanes, planes))
178 |
179 | return nn.Sequential(*layers)
180 |
181 | def forward(self, x):
182 | x = self.conv1(x)
183 | x = self.bn1(x)
184 | x = self.relu(x)
185 | x = self.maxpool(x)
186 |
187 | x = self.layer1(x)
188 | x = self.layer2(x)
189 | x = self.layer3(x)
190 | x = self.layer4(x)
191 |
192 | x = self.avgpool(x)
193 | x = x.view(x.size(0), -1)
194 | x = self.fc(x)
195 |
196 | return x
197 |
198 |
199 | def resnet18(pretrained=True, **kwargs):
200 | """Constructs a ResNet-18 model.
201 |
202 | Args:
203 | pretrained (bool): If True, returns a model pre-trained on ImageNet
204 | """
205 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
206 | if pretrained:
207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
208 | return model
209 |
210 |
211 | def resnet34(pretrained=True, **kwargs):
212 | """Constructs a ResNet-34 model.
213 |
214 | Args:
215 | pretrained (bool): If True, returns a model pre-trained on ImageNet
216 | """
217 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
218 | if pretrained:
219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
220 | return model
221 |
222 |
223 | def resnet50(pretrained=True, **kwargs):
224 | """Constructs a ResNet-50 model.
225 |
226 | Args:
227 | pretrained (bool): If True, returns a model pre-trained on ImageNet
228 | """
229 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
230 | if pretrained:
231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
232 | return model
233 |
234 |
235 | def resnet101(pretrained=True, **kwargs):
236 | """Constructs a ResNet-101 model.
237 |
238 | Args:
239 | pretrained (bool): If True, returns a model pre-trained on ImageNet
240 | """
241 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
242 | if pretrained:
243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
244 | return model
245 |
246 |
247 | def resnet152(pretrained=True, **kwargs):
248 | """Constructs a ResNet-152 model.
249 |
250 | Args:
251 | pretrained (bool): If True, returns a model pre-trained on ImageNet
252 | """
253 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
254 | if pretrained:
255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
256 | return model
257 |
--------------------------------------------------------------------------------
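This file is essentially the stock torchvision ResNet with BatchNorm2d swapped for mynn.Norm2d; GSCNN reuses its BasicBlock as the residual units of the shape stream (self.res1/res2/res3 in network/gscnn.py). A rough smoke test, assuming the repo root is on PYTHONPATH and config.py provides a usable cfg.MODEL.BNFUNC (e.g. nn.BatchNorm2d):

import torch
from network import Resnet

# BasicBlock(64, 64) is what GSCNN instantiates as self.res1 for the shape stream.
block = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)    # torch.Size([1, 64, 32, 32]) -- two 3x3 convs plus the identity shortcut
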
/network/SEresnext.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code adapted from:
6 | # https://github.com/Cadene/pretrained-models.pytorch
7 | #
8 | # BSD 3-Clause License
9 | #
10 | # Copyright (c) 2017, Remi Cadene
11 | # All rights reserved.
12 | #
13 | # Redistribution and use in source and binary forms, with or without
14 | # modification, are permitted provided that the following conditions are met:
15 | #
16 | # * Redistributions of source code must retain the above copyright notice, this
17 | # list of conditions and the following disclaimer.
18 | #
19 | # * Redistributions in binary form must reproduce the above copyright notice,
20 | # this list of conditions and the following disclaimer in the documentation
21 | # and/or other materials provided with the distribution.
22 | #
23 | # * Neither the name of the copyright holder nor the names of its
24 | # contributors may be used to endorse or promote products derived from
25 | # this software without specific prior written permission.
26 | #
27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 | """
38 |
39 |
40 | from collections import OrderedDict
41 | import math
42 | import network.mynn as mynn
43 | import torch.nn as nn
44 | from torch.utils import model_zoo
45 |
46 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
47 | 'se_resnext50_32x4d', 'se_resnext101_32x4d']
48 |
49 | pretrained_settings = {
50 | 'se_resnext50_32x4d': {
51 | 'imagenet': {
52 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
53 | 'input_space': 'RGB',
54 | 'input_size': [3, 224, 224],
55 | 'input_range': [0, 1],
56 | 'mean': [0.485, 0.456, 0.406],
57 | 'std': [0.229, 0.224, 0.225],
58 | 'num_classes': 1000
59 | }
60 | },
61 | 'se_resnext101_32x4d': {
62 | 'imagenet': {
63 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
64 | 'input_space': 'RGB',
65 | 'input_size': [3, 224, 224],
66 | 'input_range': [0, 1],
67 | 'mean': [0.485, 0.456, 0.406],
68 | 'std': [0.229, 0.224, 0.225],
69 | 'num_classes': 1000
70 | }
71 | },
72 | }
73 |
74 |
75 | class SEModule(nn.Module):
76 |
77 | def __init__(self, channels, reduction):
78 | super(SEModule, self).__init__()
79 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
80 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
81 | padding=0)
82 | self.relu = nn.ReLU(inplace=True)
83 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
84 | padding=0)
85 | self.sigmoid = nn.Sigmoid()
86 |
87 | def forward(self, x):
88 | module_input = x
89 | x = self.avg_pool(x)
90 | x = self.fc1(x)
91 | x = self.relu(x)
92 | x = self.fc2(x)
93 | x = self.sigmoid(x)
94 | return module_input * x
95 |
96 |
97 | class Bottleneck(nn.Module):
98 | """
99 | Base class for bottlenecks that implements `forward()` method.
100 | """
101 | def forward(self, x):
102 | residual = x
103 |
104 | out = self.conv1(x)
105 | out = self.bn1(out)
106 | out = self.relu(out)
107 |
108 | out = self.conv2(out)
109 | out = self.bn2(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv3(out)
113 | out = self.bn3(out)
114 |
115 | if self.downsample is not None:
116 | residual = self.downsample(x)
117 |
118 | out = self.se_module(out) + residual
119 | out = self.relu(out)
120 |
121 | return out
122 |
123 |
124 | class SEBottleneck(Bottleneck):
125 | """
126 | Bottleneck for SENet154.
127 | """
128 | expansion = 4
129 |
130 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
131 | downsample=None):
132 | super(SEBottleneck, self).__init__()
133 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
134 | self.bn1 = mynn.Norm2d(planes * 2)
135 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
136 | stride=stride, padding=1, groups=groups,
137 | bias=False)
138 | self.bn2 = mynn.Norm2d(planes * 4)
139 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
140 | bias=False)
141 | self.bn3 = mynn.Norm2d(planes * 4)
142 | self.relu = nn.ReLU(inplace=True)
143 | self.se_module = SEModule(planes * 4, reduction=reduction)
144 | self.downsample = downsample
145 | self.stride = stride
146 |
147 |
148 | class SEResNetBottleneck(Bottleneck):
149 | """
150 |     ResNet bottleneck with a Squeeze-and-Excitation module. It follows the Caffe
151 |     implementation and uses `stride=stride` in `conv1` and not in `conv2`
152 | (the latter is used in the torchvision implementation of ResNet).
153 | """
154 | expansion = 4
155 |
156 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
157 | downsample=None):
158 | super(SEResNetBottleneck, self).__init__()
159 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
160 | stride=stride)
161 | self.bn1 = mynn.Norm2d(planes)
162 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
163 | groups=groups, bias=False)
164 | self.bn2 = mynn.Norm2d(planes)
165 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
166 | self.bn3 = mynn.Norm2d(planes * 4)
167 | self.relu = nn.ReLU(inplace=True)
168 | self.se_module = SEModule(planes * 4, reduction=reduction)
169 | self.downsample = downsample
170 | self.stride = stride
171 |
172 |
173 | class SEResNeXtBottleneck(Bottleneck):
174 | """
175 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
176 | """
177 | expansion = 4
178 |
179 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
180 | downsample=None, base_width=4):
181 | super(SEResNeXtBottleneck, self).__init__()
182 | width = math.floor(planes * (base_width / 64)) * groups
183 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
184 | stride=1)
185 | self.bn1 = mynn.Norm2d(width)
186 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
187 | padding=1, groups=groups, bias=False)
188 | self.bn2 = mynn.Norm2d(width)
189 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
190 | self.bn3 = mynn.Norm2d(planes * 4)
191 | self.relu = nn.ReLU(inplace=True)
192 | self.se_module = SEModule(planes * 4, reduction=reduction)
193 | self.downsample = downsample
194 | self.stride = stride
195 |
196 |
197 | class SENet(nn.Module):
198 |
199 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
200 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
201 | downsample_padding=1, num_classes=1000):
202 | """
203 | Parameters
204 | ----------
205 | block (nn.Module): Bottleneck class.
206 | - For SENet154: SEBottleneck
207 | - For SE-ResNet models: SEResNetBottleneck
208 | - For SE-ResNeXt models: SEResNeXtBottleneck
209 | layers (list of ints): Number of residual blocks for 4 layers of the
210 | network (layer1...layer4).
211 | groups (int): Number of groups for the 3x3 convolution in each
212 | bottleneck block.
213 | - For SENet154: 64
214 | - For SE-ResNet models: 1
215 | - For SE-ResNeXt models: 32
216 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
217 | - For all models: 16
218 | dropout_p (float or None): Drop probability for the Dropout layer.
219 | If `None` the Dropout layer is not used.
220 | - For SENet154: 0.2
221 | - For SE-ResNet models: None
222 | - For SE-ResNeXt models: None
223 | inplanes (int): Number of input channels for layer1.
224 | - For SENet154: 128
225 | - For SE-ResNet models: 64
226 | - For SE-ResNeXt models: 64
227 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
228 | a single 7x7 convolution in layer0.
229 | - For SENet154: True
230 | - For SE-ResNet models: False
231 | - For SE-ResNeXt models: False
232 | downsample_kernel_size (int): Kernel size for downsampling convolutions
233 | in layer2, layer3 and layer4.
234 | - For SENet154: 3
235 | - For SE-ResNet models: 1
236 | - For SE-ResNeXt models: 1
237 | downsample_padding (int): Padding for downsampling convolutions in
238 | layer2, layer3 and layer4.
239 | - For SENet154: 1
240 | - For SE-ResNet models: 0
241 | - For SE-ResNeXt models: 0
242 | num_classes (int): Number of outputs in `last_linear` layer.
243 | - For all models: 1000
244 | """
245 | super(SENet, self).__init__()
246 | self.inplanes = inplanes
247 |
248 | if input_3x3:
249 | layer0_modules = [
250 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
251 | bias=False)),
252 | ('bn1', mynn.Norm2d(64)),
253 | ('relu1', nn.ReLU(inplace=True)),
254 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
255 | bias=False)),
256 | ('bn2', mynn.Norm2d(64)),
257 | ('relu2', nn.ReLU(inplace=True)),
258 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
259 | bias=False)),
260 | ('bn3', mynn.Norm2d(inplanes)),
261 | ('relu3', nn.ReLU(inplace=True)),
262 | ]
263 | else:
264 | layer0_modules = [
265 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
266 | padding=3, bias=False)),
267 | ('bn1', mynn.Norm2d(inplanes)),
268 | ('relu1', nn.ReLU(inplace=True)),
269 | ]
270 | # To preserve compatibility with Caffe weights `ceil_mode=True`
271 | # is used instead of `padding=1`.
272 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
273 | ceil_mode=True)))
274 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
275 | self.layer1 = self._make_layer(
276 | block,
277 | planes=64,
278 | blocks=layers[0],
279 | groups=groups,
280 | reduction=reduction,
281 | downsample_kernel_size=1,
282 | downsample_padding=0
283 | )
284 | self.layer2 = self._make_layer(
285 | block,
286 | planes=128,
287 | blocks=layers[1],
288 | stride=2,
289 | groups=groups,
290 | reduction=reduction,
291 | downsample_kernel_size=downsample_kernel_size,
292 | downsample_padding=downsample_padding
293 | )
294 | self.layer3 = self._make_layer(
295 | block,
296 | planes=256,
297 | blocks=layers[2],
298 | stride=1,
299 | groups=groups,
300 | reduction=reduction,
301 | downsample_kernel_size=downsample_kernel_size,
302 | downsample_padding=downsample_padding
303 | )
304 | self.layer4 = self._make_layer(
305 | block,
306 | planes=512,
307 | blocks=layers[3],
308 | stride=1,
309 | groups=groups,
310 | reduction=reduction,
311 | downsample_kernel_size=downsample_kernel_size,
312 | downsample_padding=downsample_padding
313 | )
314 | self.avg_pool = nn.AvgPool2d(7, stride=1)
315 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
316 | self.last_linear = nn.Linear(512 * block.expansion, num_classes)
317 |
318 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
319 | downsample_kernel_size=1, downsample_padding=0):
320 | downsample = None
321 | if stride != 1 or self.inplanes != planes * block.expansion:
322 | downsample = nn.Sequential(
323 | nn.Conv2d(self.inplanes, planes * block.expansion,
324 | kernel_size=downsample_kernel_size, stride=stride,
325 | padding=downsample_padding, bias=False),
326 | mynn.Norm2d(planes * block.expansion),
327 | )
328 |
329 | layers = []
330 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
331 | downsample))
332 | self.inplanes = planes * block.expansion
333 | for i in range(1, blocks):
334 | layers.append(block(self.inplanes, planes, groups, reduction))
335 |
336 | return nn.Sequential(*layers)
337 |
338 | def features(self, x):
339 | x = self.layer0(x)
340 | x = self.layer1(x)
341 | x = self.layer2(x)
342 | x = self.layer3(x)
343 | x = self.layer4(x)
344 | return x
345 |
346 | def logits(self, x):
347 | x = self.avg_pool(x)
348 | if self.dropout is not None:
349 | x = self.dropout(x)
350 | x = x.view(x.size(0), -1)
351 | x = self.last_linear(x)
352 | return x
353 |
354 | def forward(self, x):
355 | x = self.features(x)
356 | x = self.logits(x)
357 | return x
358 |
359 |
360 | def initialize_pretrained_model(model, num_classes, settings):
361 | assert num_classes == settings['num_classes'], \
362 | 'num_classes should be {}, but is {}'.format(
363 | settings['num_classes'], num_classes)
364 | weights = model_zoo.load_url(settings['url'])
365 | model.load_state_dict(weights)
366 | model.input_space = settings['input_space']
367 | model.input_size = settings['input_size']
368 | model.input_range = settings['input_range']
369 | model.mean = settings['mean']
370 | model.std = settings['std']
371 |
372 |
373 |
374 | def se_resnext50_32x4d(num_classes=1000):
375 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
376 | dropout_p=None, inplanes=64, input_3x3=False,
377 | downsample_kernel_size=1, downsample_padding=0,
378 | num_classes=num_classes)
379 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet']
380 | initialize_pretrained_model(model, num_classes, settings)
381 | return model
382 |
383 |
384 | def se_resnext101_32x4d(num_classes=1000):
385 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
386 | dropout_p=None, inplanes=64, input_3x3=False,
387 | downsample_kernel_size=1, downsample_padding=0,
388 | num_classes=num_classes)
389 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet']
390 | initialize_pretrained_model(model, num_classes, settings)
391 | return model
392 |
--------------------------------------------------------------------------------
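The core of SEModule above is a per-channel squeeze-and-excitation: global average pooling, a two-layer bottleneck, a sigmoid, then channel-wise rescaling. A standalone restatement of that path (not an import of the module, so it has no dependency on config.py):

import torch
import torch.nn as nn

channels, reduction = 64, 16
x = torch.randn(2, channels, 14, 14)

squeeze = nn.AdaptiveAvgPool2d(1)                               # per-channel global context: Nx64x1x1
excite = nn.Sequential(
    nn.Conv2d(channels, channels // reduction, kernel_size=1),  # bottleneck down to channels/reduction
    nn.ReLU(inplace=True),
    nn.Conv2d(channels // reduction, channels, kernel_size=1),  # back up to per-channel gates
    nn.Sigmoid(),
)
weights = excite(squeeze(x))          # Nx64x1x1 values in (0, 1)
print((x * weights).shape)            # torch.Size([2, 64, 14, 14]) -- features rescaled channel-wise
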
/network/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import importlib
7 | import torch
8 | import logging
9 |
10 | def get_net(args, criterion):
11 | net = get_model(network=args.arch, num_classes=args.dataset_cls.num_classes,
12 | criterion=criterion, trunk=args.trunk)
13 | num_params = sum([param.nelement() for param in net.parameters()])
14 | logging.info('Model params = {:2.1f}M'.format(num_params / 1000000))
15 |
16 | net = net.cuda()
17 | net = torch.nn.DataParallel(net)
18 | return net
19 |
20 |
21 | def get_model(network, num_classes, criterion, trunk):
22 |
23 | module = network[:network.rfind('.')]
24 | model = network[network.rfind('.')+1:]
25 | mod = importlib.import_module(module)
26 | net_func = getattr(mod, model)
27 | net = net_func(num_classes=num_classes, trunk=trunk, criterion=criterion)
28 | return net
29 |
30 |
31 |
--------------------------------------------------------------------------------
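get_model resolves the architecture from a dotted string at runtime, which is why --arch values look like module paths. A small illustration of that string handling (the 'network.gscnn.GSCNN' value is an assumption about what train.py passes; the snippet does not build the network):

arch = 'network.gscnn.GSCNN'                # assumed example --arch value
module_name = arch[:arch.rfind('.')]        # 'network.gscnn'
class_name = arch[arch.rfind('.') + 1:]     # 'GSCNN'
print(module_name, class_name)

# get_model then effectively does:
#   net_func = getattr(importlib.import_module(module_name), class_name)
#   net = net_func(num_classes=..., trunk=..., criterion=...)
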
/network/gscnn.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code Adapted from:
6 | # https://github.com/sthalles/deeplab_v3
7 | #
8 | # MIT License
9 | #
10 | # Copyright (c) 2018 Thalles Santos Silva
11 | #
12 | # Permission is hereby granted, free of charge, to any person obtaining a copy
13 | # of this software and associated documentation files (the "Software"), to deal
14 | # in the Software without restriction, including without limitation the rights
15 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | # copies of the Software, and to permit persons to whom the Software is
17 | # furnished to do so, subject to the following conditions:
18 | #
19 | # The above copyright notice and this permission notice shall be included in all
20 | # copies or substantial portions of the Software.
21 | #
22 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28 | """
29 |
30 | import torch
31 | import torch.nn.functional as F
32 | from torch import nn
33 | from network import SEresnext
34 | from network import Resnet
35 | from network.wider_resnet import wider_resnet38_a2
36 | from config import cfg
37 | from network.mynn import initialize_weights, Norm2d
38 | from torch.autograd import Variable
39 |
40 | from my_functionals import GatedSpatialConv as gsc
41 |
42 | import cv2
43 | import numpy as np
44 |
45 | class Crop(nn.Module):
46 | def __init__(self, axis, offset):
47 | super(Crop, self).__init__()
48 | self.axis = axis
49 | self.offset = offset
50 |
51 | def forward(self, x, ref):
52 | """
53 |
54 |         :param x: input tensor to crop
55 |         :param ref: reference tensor (usually the network input) whose size along each axis is matched
56 |         :return: x cropped to ref's size from self.axis onwards, starting at self.offset
57 | """
58 | for axis in range(self.axis, x.dim()):
59 | ref_size = ref.size(axis)
60 | indices = torch.arange(self.offset, self.offset + ref_size).long()
61 | indices = x.data.new().resize_(indices.size()).copy_(indices).long()
62 | x = x.index_select(axis, Variable(indices))
63 | return x
64 |
65 |
66 | class MyIdentity(nn.Module):
67 | def __init__(self, axis, offset):
68 | super(MyIdentity, self).__init__()
69 | self.axis = axis
70 | self.offset = offset
71 |
72 | def forward(self, x, ref):
73 | """
74 |
75 |         :param x: input tensor
76 |         :param ref: reference tensor; ignored here, kept only for interface parity with Crop
77 |         :return: x unchanged
78 | """
79 | return x
80 |
81 | class SideOutputCrop(nn.Module):
82 | """
83 | This is the original implementation ConvTranspose2d (fixed) and crops
84 | """
85 |
86 | def __init__(self, num_output, kernel_sz=None, stride=None, upconv_pad=0, do_crops=True):
87 | super(SideOutputCrop, self).__init__()
88 | self._do_crops = do_crops
89 | self.conv = nn.Conv2d(num_output, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
90 |
91 | if kernel_sz is not None:
92 | self.upsample = True
93 | self.upsampled = nn.ConvTranspose2d(1, out_channels=1, kernel_size=kernel_sz, stride=stride,
94 | padding=upconv_pad,
95 | bias=False)
96 | ##doing crops
97 | if self._do_crops:
98 | self.crops = Crop(2, offset=kernel_sz // 4)
99 | else:
100 | self.crops = MyIdentity(None, None)
101 | else:
102 | self.upsample = False
103 |
104 | def forward(self, res, reference=None):
105 | side_output = self.conv(res)
106 | if self.upsample:
107 | side_output = self.upsampled(side_output)
108 | side_output = self.crops(side_output, reference)
109 |
110 | return side_output
111 |
112 |
113 | class _AtrousSpatialPyramidPoolingModule(nn.Module):
114 | '''
115 | operations performed:
116 | 1x1 x depth
117 | 3x3 x depth dilation 6
118 | 3x3 x depth dilation 12
119 | 3x3 x depth dilation 18
120 | image pooling
121 | concatenate all together
122 | Final 1x1 conv
123 | '''
124 |
125 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=[6, 12, 18]):
126 | super(_AtrousSpatialPyramidPoolingModule, self).__init__()
127 |
128 | # Check if we are using distributed BN and use the nn from encoding.nn
129 | # library rather than using standard pytorch.nn
130 |
131 | if output_stride == 8:
132 | rates = [2 * r for r in rates]
133 | elif output_stride == 16:
134 | pass
135 | else:
136 |             raise ValueError('output stride of {} not supported'.format(output_stride))
137 |
138 | self.features = []
139 | # 1x1
140 | self.features.append(
141 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
142 | Norm2d(reduction_dim), nn.ReLU(inplace=True)))
143 | # other rates
144 | for r in rates:
145 | self.features.append(nn.Sequential(
146 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
147 | dilation=r, padding=r, bias=False),
148 | Norm2d(reduction_dim),
149 | nn.ReLU(inplace=True)
150 | ))
151 | self.features = torch.nn.ModuleList(self.features)
152 |
153 | # img level features
154 | self.img_pooling = nn.AdaptiveAvgPool2d(1)
155 | self.img_conv = nn.Sequential(
156 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
157 | Norm2d(reduction_dim), nn.ReLU(inplace=True))
158 | self.edge_conv = nn.Sequential(
159 | nn.Conv2d(1, reduction_dim, kernel_size=1, bias=False),
160 | Norm2d(reduction_dim), nn.ReLU(inplace=True))
161 |
162 |
163 | def forward(self, x, edge):
164 | x_size = x.size()
165 |
166 | img_features = self.img_pooling(x)
167 | img_features = self.img_conv(img_features)
168 | img_features = F.interpolate(img_features, x_size[2:],
169 | mode='bilinear',align_corners=True)
170 | out = img_features
171 |
172 | edge_features = F.interpolate(edge, x_size[2:],
173 | mode='bilinear',align_corners=True)
174 | edge_features = self.edge_conv(edge_features)
175 | out = torch.cat((out, edge_features), 1)
176 |
177 | for f in self.features:
178 | y = f(x)
179 | out = torch.cat((out, y), 1)
180 | return out
181 |
182 | class GSCNN(nn.Module):
183 | '''
184 | Wide_resnet version of DeepLabV3
185 | mod1
186 | pool2
187 | mod2 str2
188 | pool3
189 | mod3-7
190 |
191 | structure: [3, 3, 6, 3, 1, 1]
192 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048),
193 | (1024, 2048, 4096)]
194 | '''
195 |
196 | def __init__(self, num_classes, trunk=None, criterion=None):
197 |
198 | super(GSCNN, self).__init__()
199 | self.criterion = criterion
200 | self.num_classes = num_classes
201 |
202 | wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
203 | wide_resnet = torch.nn.DataParallel(wide_resnet)
204 |
205 | wide_resnet = wide_resnet.module
206 | self.mod1 = wide_resnet.mod1
207 | self.mod2 = wide_resnet.mod2
208 | self.mod3 = wide_resnet.mod3
209 | self.mod4 = wide_resnet.mod4
210 | self.mod5 = wide_resnet.mod5
211 | self.mod6 = wide_resnet.mod6
212 | self.mod7 = wide_resnet.mod7
213 | self.pool2 = wide_resnet.pool2
214 | self.pool3 = wide_resnet.pool3
215 | self.interpolate = F.interpolate
216 | del wide_resnet
217 |
218 | self.dsn1 = nn.Conv2d(64, 1, 1)
219 | self.dsn3 = nn.Conv2d(256, 1, 1)
220 | self.dsn4 = nn.Conv2d(512, 1, 1)
221 | self.dsn7 = nn.Conv2d(4096, 1, 1)
222 |
223 | self.res1 = Resnet.BasicBlock(64, 64, stride=1, downsample=None)
224 | self.d1 = nn.Conv2d(64, 32, 1)
225 | self.res2 = Resnet.BasicBlock(32, 32, stride=1, downsample=None)
226 | self.d2 = nn.Conv2d(32, 16, 1)
227 | self.res3 = Resnet.BasicBlock(16, 16, stride=1, downsample=None)
228 | self.d3 = nn.Conv2d(16, 8, 1)
229 | self.fuse = nn.Conv2d(8, 1, kernel_size=1, padding=0, bias=False)
230 |
231 | self.cw = nn.Conv2d(2, 1, kernel_size=1, padding=0, bias=False)
232 |
233 | self.gate1 = gsc.GatedSpatialConv2d(32, 32)
234 | self.gate2 = gsc.GatedSpatialConv2d(16, 16)
235 | self.gate3 = gsc.GatedSpatialConv2d(8, 8)
236 |
237 | self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
238 | output_stride=8)
239 |
240 | self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
241 | self.bot_aspp = nn.Conv2d(1280 + 256, 256, kernel_size=1, bias=False)
242 |
243 | self.final_seg = nn.Sequential(
244 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
245 | Norm2d(256),
246 | nn.ReLU(inplace=True),
247 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
248 | Norm2d(256),
249 | nn.ReLU(inplace=True),
250 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
251 |
252 | self.sigmoid = nn.Sigmoid()
253 | initialize_weights(self.final_seg)
254 |
255 | def forward(self, inp, gts=None):
256 |
257 | x_size = inp.size()
258 |
259 | # res 1
260 | m1 = self.mod1(inp)
261 |
262 | # res 2
263 | m2 = self.mod2(self.pool2(m1))
264 |
265 | # res 3
266 | m3 = self.mod3(self.pool3(m2))
267 |
268 | # res 4-7
269 | m4 = self.mod4(m3)
270 | m5 = self.mod5(m4)
271 | m6 = self.mod6(m5)
272 | m7 = self.mod7(m6)
273 |
274 | s3 = F.interpolate(self.dsn3(m3), x_size[2:],
275 | mode='bilinear', align_corners=True)
276 | s4 = F.interpolate(self.dsn4(m4), x_size[2:],
277 | mode='bilinear', align_corners=True)
278 | s7 = F.interpolate(self.dsn7(m7), x_size[2:],
279 | mode='bilinear', align_corners=True)
280 |
281 | m1f = F.interpolate(m1, x_size[2:], mode='bilinear', align_corners=True)
282 |
283 | im_arr = inp.cpu().numpy().transpose((0,2,3,1)).astype(np.uint8)
284 | canny = np.zeros((x_size[0], 1, x_size[2], x_size[3]))
285 | for i in range(x_size[0]):
286 | canny[i] = cv2.Canny(im_arr[i],10,100)
287 | canny = torch.from_numpy(canny).cuda().float()
288 |
289 | cs = self.res1(m1f)
290 | cs = F.interpolate(cs, x_size[2:],
291 | mode='bilinear', align_corners=True)
292 | cs = self.d1(cs)
293 | cs = self.gate1(cs, s3)
294 | cs = self.res2(cs)
295 | cs = F.interpolate(cs, x_size[2:],
296 | mode='bilinear', align_corners=True)
297 | cs = self.d2(cs)
298 | cs = self.gate2(cs, s4)
299 | cs = self.res3(cs)
300 | cs = F.interpolate(cs, x_size[2:],
301 | mode='bilinear', align_corners=True)
302 | cs = self.d3(cs)
303 | cs = self.gate3(cs, s7)
304 | cs = self.fuse(cs)
305 | cs = F.interpolate(cs, x_size[2:],
306 | mode='bilinear', align_corners=True)
307 | edge_out = self.sigmoid(cs)
308 | cat = torch.cat((edge_out, canny), dim=1)
309 | acts = self.cw(cat)
310 | acts = self.sigmoid(acts)
311 |
312 | # aspp
313 | x = self.aspp(m7, acts)
314 | dec0_up = self.bot_aspp(x)
315 |
316 | dec0_fine = self.bot_fine(m2)
317 | dec0_up = self.interpolate(dec0_up, m2.size()[2:], mode='bilinear',align_corners=True)
318 | dec0 = [dec0_fine, dec0_up]
319 | dec0 = torch.cat(dec0, 1)
320 |
321 | dec1 = self.final_seg(dec0)
322 | seg_out = self.interpolate(dec1, x_size[2:], mode='bilinear')
323 |
324 | if self.training:
325 | return self.criterion((seg_out, edge_out), gts)
326 | else:
327 | return seg_out, edge_out
328 |
329 |
--------------------------------------------------------------------------------
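One step in GSCNN.forward that is easy to miss is the Canny prior: the batch is moved to the CPU, converted to uint8 HxWxC images, run through cv2.Canny per image, and brought back as an Nx1xHxW tensor that is later concatenated with edge_out. The same plumbing, restated on the CPU with a random stand-in batch (the network keeps the result on the GPU):

import cv2
import numpy as np
import torch

inp = torch.rand(2, 3, 64, 64) * 255                               # stand-in for an image batch
im_arr = inp.cpu().numpy().transpose((0, 2, 3, 1)).astype(np.uint8)
canny = np.zeros((inp.size(0), 1, inp.size(2), inp.size(3)))
for i in range(inp.size(0)):
    canny[i] = cv2.Canny(im_arr[i], 10, 100)                       # per-image edge map, thresholds 10/100
canny = torch.from_numpy(canny).float()                            # Nx1xHxW edge prior
print(canny.shape)                                                 # torch.Size([2, 1, 64, 64])
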
/network/mynn.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from config import cfg
7 | import torch.nn as nn
8 | from math import sqrt
9 | import torch
10 | from torch.autograd.function import InplaceFunction
11 | from itertools import repeat
12 | from torch.nn.modules import Module
13 | from torch.utils.checkpoint import checkpoint
14 |
15 |
16 | def Norm2d(in_channels):
17 | """
18 | Custom Norm Function to allow flexible switching
19 | """
20 | layer = getattr(cfg.MODEL,'BNFUNC')
21 | normalizationLayer = layer(in_channels)
22 | return normalizationLayer
23 |
24 |
25 | def initialize_weights(*models):
26 | for model in models:
27 | for module in model.modules():
28 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
29 |                 nn.init.kaiming_normal_(module.weight)
30 | if module.bias is not None:
31 | module.bias.data.zero_()
32 | elif isinstance(module, nn.BatchNorm2d):
33 | module.weight.data.fill_(1)
34 | module.bias.data.zero_()
35 |
--------------------------------------------------------------------------------
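Norm2d is just an indirection: it instantiates whatever callable cfg.MODEL.BNFUNC points at, which is how the codebase switches between plain and synchronized batch norm without touching the model files. A quick check, assuming config.py imports cleanly and allows the attribute to be (re)assigned:

import torch.nn as nn
from config import cfg
from network import mynn

cfg.MODEL.BNFUNC = nn.BatchNorm2d     # assumption: override for a CPU-only check
layer = mynn.Norm2d(64)
print(type(layer).__name__)           # BatchNorm2d
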
/network/wider_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code adapted from:
6 | # https://github.com/mapillary/inplace_abn/
7 | #
8 | # BSD 3-Clause License
9 | #
10 | # Copyright (c) 2017, mapillary
11 | # All rights reserved.
12 | #
13 | # Redistribution and use in source and binary forms, with or without
14 | # modification, are permitted provided that the following conditions are met:
15 | #
16 | # * Redistributions of source code must retain the above copyright notice, this
17 | # list of conditions and the following disclaimer.
18 | #
19 | # * Redistributions in binary form must reproduce the above copyright notice,
20 | # this list of conditions and the following disclaimer in the documentation
21 | # and/or other materials provided with the distribution.
22 | #
23 | # * Neither the name of the copyright holder nor the names of its
24 | # contributors may be used to endorse or promote products derived from
25 | # this software without specific prior written permission.
26 | #
27 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
28 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
29 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
30 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
31 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
32 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
33 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
34 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
35 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
36 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
37 | """
38 |
39 | import sys
40 | from collections import OrderedDict
41 | from functools import partial
42 | import torch.nn as nn
43 | import torch
44 | import network.mynn as mynn
45 |
46 | def bnrelu(channels):
47 | return nn.Sequential(mynn.Norm2d(channels),
48 | nn.ReLU(inplace=True))
49 |
50 | class GlobalAvgPool2d(nn.Module):
51 |
52 | def __init__(self):
53 | """Global average pooling over the input's spatial dimensions"""
54 | super(GlobalAvgPool2d, self).__init__()
55 |
56 | def forward(self, inputs):
57 | in_size = inputs.size()
58 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
59 |
60 |
61 | class IdentityResidualBlock(nn.Module):
62 |
63 | def __init__(self,
64 | in_channels,
65 | channels,
66 | stride=1,
67 | dilation=1,
68 | groups=1,
69 | norm_act=bnrelu,
70 | dropout=None,
71 | dist_bn=False
72 | ):
73 | """Configurable identity-mapping residual block
74 |
75 | Parameters
76 | ----------
77 | in_channels : int
78 | Number of input channels.
79 | channels : list of int
80 | Number of channels in the internal feature maps.
81 | Can either have two or three elements: if three construct
82 | a residual block with two `3 x 3` convolutions,
83 | otherwise construct a bottleneck block with `1 x 1`, then
84 | `3 x 3` then `1 x 1` convolutions.
85 | stride : int
86 | Stride of the first `3 x 3` convolution
87 | dilation : int
88 | Dilation to apply to the `3 x 3` convolutions.
89 | groups : int
90 | Number of convolution groups.
91 | This is used to create ResNeXt-style blocks and is only compatible with
92 | bottleneck blocks.
93 | norm_act : callable
94 | Function to create normalization / activation Module.
95 | dropout: callable
96 | Function to create Dropout Module.
97 | dist_bn: Boolean
98 | A variable to enable or disable use of distributed BN
99 | """
100 | super(IdentityResidualBlock, self).__init__()
101 | self.dist_bn = dist_bn
102 |
103 | # Check if we are using distributed BN and use the nn from encoding.nn
104 | # library rather than using standard pytorch.nn
105 |
106 |
107 | # Check parameters for inconsistencies
108 | if len(channels) != 2 and len(channels) != 3:
109 | raise ValueError("channels must contain either two or three values")
110 | if len(channels) == 2 and groups != 1:
111 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
112 |
113 | is_bottleneck = len(channels) == 3
114 | need_proj_conv = stride != 1 or in_channels != channels[-1]
115 |
116 | self.bn1 = norm_act(in_channels)
117 | if not is_bottleneck:
118 | layers = [
119 | ("conv1", nn.Conv2d(in_channels,
120 | channels[0],
121 | 3,
122 | stride=stride,
123 | padding=dilation,
124 | bias=False,
125 | dilation=dilation)),
126 | ("bn2", norm_act(channels[0])),
127 | ("conv2", nn.Conv2d(channels[0], channels[1],
128 | 3,
129 | stride=1,
130 | padding=dilation,
131 | bias=False,
132 | dilation=dilation))
133 | ]
134 | if dropout is not None:
135 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
136 | else:
137 | layers = [
138 | ("conv1",
139 | nn.Conv2d(in_channels,
140 | channels[0],
141 | 1,
142 | stride=stride,
143 | padding=0,
144 | bias=False)),
145 | ("bn2", norm_act(channels[0])),
146 | ("conv2", nn.Conv2d(channels[0],
147 | channels[1],
148 | 3, stride=1,
149 | padding=dilation, bias=False,
150 | groups=groups,
151 | dilation=dilation)),
152 | ("bn3", norm_act(channels[1])),
153 | ("conv3", nn.Conv2d(channels[1], channels[2],
154 | 1, stride=1, padding=0, bias=False))
155 | ]
156 | if dropout is not None:
157 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
158 | self.convs = nn.Sequential(OrderedDict(layers))
159 |
160 | if need_proj_conv:
161 | self.proj_conv = nn.Conv2d(
162 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
163 |
164 | def forward(self, x):
165 | """
166 | This is the standard forward function for non-distributed batch norm
167 | """
168 | if hasattr(self, "proj_conv"):
169 | bn1 = self.bn1(x)
170 | shortcut = self.proj_conv(bn1)
171 | else:
172 | shortcut = x.clone()
173 | bn1 = self.bn1(x)
174 |
175 | out = self.convs(bn1)
176 | out.add_(shortcut)
177 | return out
178 |
179 |
180 |
181 |
182 | class WiderResNet(nn.Module):
183 |
184 | def __init__(self,
185 | structure,
186 | norm_act=bnrelu,
187 | classes=0
188 | ):
189 | """Wider ResNet with pre-activation (identity mapping) blocks
190 |
191 | Parameters
192 | ----------
193 | structure : list of int
194 | Number of residual blocks in each of the six modules of the network.
195 | norm_act : callable
196 | Function to create normalization / activation Module.
197 | classes : int
198 |             If not `0` also include global average pooling and
199 | a fully-connected layer with `classes` outputs at the end
200 | of the network.
201 | """
202 | super(WiderResNet, self).__init__()
203 | self.structure = structure
204 |
205 | if len(structure) != 6:
206 | raise ValueError("Expected a structure with six values")
207 |
208 | # Initial layers
209 | self.mod1 = nn.Sequential(OrderedDict([
210 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))
211 | ]))
212 |
213 | # Groups of residual blocks
214 | in_channels = 64
215 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024),
216 | (512, 1024, 2048), (1024, 2048, 4096)]
217 | for mod_id, num in enumerate(structure):
218 | # Create blocks for module
219 | blocks = []
220 | for block_id in range(num):
221 | blocks.append((
222 | "block%d" % (block_id + 1),
223 | IdentityResidualBlock(in_channels, channels[mod_id],
224 | norm_act=norm_act)
225 | ))
226 |
227 | # Update channels and p_keep
228 | in_channels = channels[mod_id][-1]
229 |
230 | # Create module
231 | if mod_id <= 4:
232 | self.add_module("pool%d" %
233 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1))
234 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
235 |
236 | # Pooling and predictor
237 | self.bn_out = norm_act(in_channels)
238 | if classes != 0:
239 | self.classifier = nn.Sequential(OrderedDict([
240 | ("avg_pool", GlobalAvgPool2d()),
241 | ("fc", nn.Linear(in_channels, classes))
242 | ]))
243 |
244 | def forward(self, img):
245 | out = self.mod1(img)
246 | out = self.mod2(self.pool2(out))
247 | out = self.mod3(self.pool3(out))
248 | out = self.mod4(self.pool4(out))
249 | out = self.mod5(self.pool5(out))
250 | out = self.mod6(self.pool6(out))
251 | out = self.mod7(out)
252 | out = self.bn_out(out)
253 |
254 | if hasattr(self, "classifier"):
255 | out = self.classifier(out)
256 |
257 | return out
258 |
259 |
260 | class WiderResNetA2(nn.Module):
261 |
262 | def __init__(self,
263 | structure,
264 | norm_act=bnrelu,
265 | classes=0,
266 | dilation=False,
267 | dist_bn=False
268 | ):
269 | """Wider ResNet with pre-activation (identity mapping) blocks
270 |
271 | This variant uses down-sampling by max-pooling in the first two blocks and \
272 | by strided convolution in the others.
273 |
274 | Parameters
275 | ----------
276 | structure : list of int
277 | Number of residual blocks in each of the six modules of the network.
278 | norm_act : callable
279 | Function to create normalization / activation Module.
280 | classes : int
281 | If not `0` also include global average pooling and a fully-connected layer
282 |             with `classes` outputs at the end
283 | of the network.
284 | dilation : bool
285 | If `True` apply dilation to the last three modules and change the
286 |             down-sampling factor from 32 to 8.
287 | """
288 | super(WiderResNetA2, self).__init__()
289 | self.dist_bn = dist_bn
290 |
291 |         # If using distributed batch norm, use encoding.nn as opposed to torch.nn
292 |
293 |
294 |         nn.Dropout = nn.Dropout2d  # NOTE: globally rebinds nn.Dropout to Dropout2d for the blocks built below
295 | norm_act = bnrelu
296 | self.structure = structure
297 | self.dilation = dilation
298 |
299 | if len(structure) != 6:
300 | raise ValueError("Expected a structure with six values")
301 |
302 | # Initial layers
303 | self.mod1 = torch.nn.Sequential(OrderedDict([
304 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))
305 | ]))
306 |
307 | # Groups of residual blocks
308 | in_channels = 64
309 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048),
310 | (1024, 2048, 4096)]
311 | for mod_id, num in enumerate(structure):
312 | # Create blocks for module
313 | blocks = []
314 | for block_id in range(num):
315 | if not dilation:
316 | dil = 1
317 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1
318 | else:
319 | if mod_id == 3:
320 | dil = 2
321 | elif mod_id > 3:
322 | dil = 4
323 | else:
324 | dil = 1
325 | stride = 2 if block_id == 0 and mod_id == 2 else 1
326 |
327 | if mod_id == 4:
328 | drop = partial(nn.Dropout, p=0.3)
329 | elif mod_id == 5:
330 | drop = partial(nn.Dropout, p=0.5)
331 | else:
332 | drop = None
333 |
334 | blocks.append((
335 | "block%d" % (block_id + 1),
336 | IdentityResidualBlock(in_channels,
337 | channels[mod_id], norm_act=norm_act,
338 | stride=stride, dilation=dil,
339 | dropout=drop, dist_bn=self.dist_bn)
340 | ))
341 |
342 | # Update channels and p_keep
343 | in_channels = channels[mod_id][-1]
344 |
345 | # Create module
346 | if mod_id < 2:
347 | self.add_module("pool%d" %
348 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1))
349 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
350 |
351 | # Pooling and predictor
352 | self.bn_out = norm_act(in_channels)
353 | if classes != 0:
354 | self.classifier = nn.Sequential(OrderedDict([
355 | ("avg_pool", GlobalAvgPool2d()),
356 | ("fc", nn.Linear(in_channels, classes))
357 | ]))
358 |
359 | def forward(self, img):
360 | out = self.mod1(img)
361 | out = self.mod2(self.pool2(out))
362 | out = self.mod3(self.pool3(out))
363 | out = self.mod4(out)
364 | out = self.mod5(out)
365 | out = self.mod6(out)
366 | out = self.mod7(out)
367 | out = self.bn_out(out)
368 |
369 | if hasattr(self, "classifier"):
370 | return self.classifier(out)
371 | else:
372 | return out
373 |
374 |
375 | _NETS = {
376 | "16": {"structure": [1, 1, 1, 1, 1, 1]},
377 | "20": {"structure": [1, 1, 1, 3, 1, 1]},
378 | "38": {"structure": [3, 3, 6, 3, 1, 1]},
379 | }
380 |
381 | __all__ = []
382 | for name, params in _NETS.items():
383 | net_name = "wider_resnet" + name
384 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params))
385 | __all__.append(net_name)
386 | for name, params in _NETS.items():
387 | net_name = "wider_resnet" + name + "_a2"
388 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params))
389 | __all__.append(net_name)
390 |
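391 | # ---------------------------------------------------------------------------
392 | # Illustrative usage sketch (not part of the original file). The loops above
393 | # expose module-level constructors such as wider_resnet38 and wider_resnet38_a2;
394 | # the input size below is an arbitrary assumption, and the example assumes the
395 | # default bnrelu normalization layer runs on the local device.
396 | if __name__ == "__main__":
397 |     net = wider_resnet38_a2(classes=0, dilation=True)
398 |     net.eval()
399 |     with torch.no_grad():
400 |         feats = net(torch.randn(1, 3, 224, 224))
401 |     # With dilation=True the output stride is 8, so for a 224x224 input the
402 |     # feature map is roughly (1, 4096, 28, 28).
403 |     print(feats.shape)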
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | from torch import optim
8 | import math
9 | import logging
10 | from config import cfg
11 |
12 | def get_optimizer(args, net):
13 |
14 | param_groups = net.parameters()
15 |
16 | if args.sgd:
17 | optimizer = optim.SGD(param_groups,
18 | lr=args.lr,
19 | weight_decay=args.weight_decay,
20 | momentum=args.momentum,
21 | nesterov=False)
22 | elif args.adam:
23 | amsgrad=False
24 | if args.amsgrad:
25 | amsgrad=True
26 | optimizer = optim.Adam(param_groups,
27 | lr=args.lr,
28 | weight_decay=args.weight_decay,
29 | amsgrad=amsgrad
30 | )
31 | else:
32 |         raise ValueError('Not a valid optimizer')
33 |
34 | if args.lr_schedule == 'poly':
35 | lambda1 = lambda epoch: math.pow(1 - epoch / args.max_epoch, args.poly_exp)
36 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
37 | else:
38 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
39 |
40 | if args.snapshot:
41 | logging.info('Loading weights from model {}'.format(args.snapshot))
42 | net, optimizer = restore_snapshot(args, net, optimizer, args.snapshot)
43 | else:
44 |         logging.info('Loaded weights from ImageNet-pretrained classifier')
45 |
46 | return optimizer, scheduler
47 |
48 | def restore_snapshot(args, net, optimizer, snapshot):
49 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
50 | logging.info("Load Compelete")
51 | if args.sgd_finetuned:
52 | print('skipping load optimizer')
53 | else:
54 | if 'optimizer' in checkpoint and args.restore_optimizer:
55 | optimizer.load_state_dict(checkpoint['optimizer'])
56 |
57 | if 'state_dict' in checkpoint:
58 | net = forgiving_state_restore(net, checkpoint['state_dict'])
59 | else:
60 | net = forgiving_state_restore(net, checkpoint)
61 |
62 | return net, optimizer
63 |
64 | def forgiving_state_restore(net, loaded_dict):
65 | # Handle partial loading when some tensors don't match up in size.
66 | # Because we want to use models that were trained off a different
67 | # number of classes.
68 | net_state_dict = net.state_dict()
69 | new_loaded_dict = {}
70 | for k in net_state_dict:
71 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
72 | new_loaded_dict[k] = loaded_dict[k]
73 | else:
74 | logging.info('Skipped loading parameter {}'.format(k))
75 | net_state_dict.update(new_loaded_dict)
76 | net.load_state_dict(net_state_dict)
77 | return net
78 |
79 |
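80 | # ---------------------------------------------------------------------------
81 | # Worked example (illustrative, not part of the original file): with the
82 | # 'poly' schedule above and lr=0.01, max_epoch=175, poly_exp=1.0, each epoch
83 | # multiplies the base LR by (1 - epoch / 175) ** 1.0, e.g.
84 | #   epoch 0   -> 0.01 * 1.0000 = 0.01
85 | #   epoch 87  -> 0.01 * 0.5029 ~ 0.005
86 | #   epoch 174 -> 0.01 * 0.0057 ~ 0.00006
87 | # forgiving_state_restore() above copies only the checkpoint tensors whose
88 | # names and shapes match the current model, so heads trained with a different
89 | # number of classes are skipped instead of raising an error.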
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | import argparse
9 | from functools import partial
10 | from config import cfg, assert_and_infer_cfg
11 | import logging
12 | import math
13 | import os
14 | import sys
15 |
16 | import torch
17 | import numpy as np
18 |
19 | from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist
20 | from utils.f_boundary import eval_mask_boundary
21 | import datasets
22 | import loss
23 | import network
24 | import optimizer
25 |
26 | # Argument Parser
27 | parser = argparse.ArgumentParser(description='GSCNN')
28 | parser.add_argument('--lr', type=float, default=0.01)
29 | parser.add_argument('--arch', type=str, default='network.gscnn.GSCNN')
30 | parser.add_argument('--dataset', type=str, default='cityscapes')
31 | parser.add_argument('--cv', type=int, default=0,
32 | help='cross validation split')
33 | parser.add_argument('--joint_edgeseg_loss', action='store_true', default=True,
34 | help='joint loss')
35 | parser.add_argument('--img_wt_loss', action='store_true', default=False,
36 | help='per-image class-weighted loss')
37 | parser.add_argument('--batch_weighting', action='store_true', default=False,
38 | help='Batch weighting for class')
39 | parser.add_argument('--eval_thresholds', type=str, default='0.0005,0.001875,0.00375,0.005',
40 | help='Thresholds for boundary evaluation')
41 | parser.add_argument('--rescale', type=float, default=1.0,
42 | help='Rescaled LR Rate')
43 | parser.add_argument('--repoly', type=float, default=1.5,
44 | help='Rescaled Poly')
45 |
46 | parser.add_argument('--edge_weight', type=float, default=1.0,
47 | help='Edge loss weight for joint loss')
48 | parser.add_argument('--seg_weight', type=float, default=1.0,
49 | help='Segmentation loss weight for joint loss')
50 | parser.add_argument('--att_weight', type=float, default=1.0,
51 | help='Attention loss weight for joint loss')
52 | parser.add_argument('--dual_weight', type=float, default=1.0,
53 | help='Dual loss weight for joint loss')
54 |
55 | parser.add_argument('--evaluate', action='store_true', default=False)
56 |
57 | parser.add_argument("--local_rank", default=0, type=int)
58 |
59 | parser.add_argument('--sgd', action='store_true', default=True)
60 | parser.add_argument('--sgd_finetuned',action='store_true',default=False)
61 | parser.add_argument('--adam', action='store_true', default=False)
62 | parser.add_argument('--amsgrad', action='store_true', default=False)
63 |
64 | parser.add_argument('--trunk', type=str, default='resnet101',
65 | help='trunk model, can be: resnet101 (default), resnet50')
66 | parser.add_argument('--max_epoch', type=int, default=175)
67 | parser.add_argument('--start_epoch', type=int, default=0)
68 | parser.add_argument('--color_aug', type=float,
69 | default=0.25, help='level of color augmentation')
70 | parser.add_argument('--rotate', type=float,
71 | default=0, help='rotation')
72 | parser.add_argument('--gblur', action='store_true', default=True)
73 | parser.add_argument('--bblur', action='store_true', default=False)
74 | parser.add_argument('--lr_schedule', type=str, default='poly',
75 | help='name of lr schedule: poly')
76 | parser.add_argument('--poly_exp', type=float, default=1.0,
77 | help='polynomial LR exponent')
78 | parser.add_argument('--bs_mult', type=int, default=1)
79 | parser.add_argument('--bs_mult_val', type=int, default=2)
80 | parser.add_argument('--crop_size', type=int, default=720,
81 | help='training crop size')
82 | parser.add_argument('--pre_size', type=int, default=None,
83 | help='resize image shorter edge to this before augmentation')
84 | parser.add_argument('--scale_min', type=float, default=0.5,
85 | help='dynamically scale training images down to this size')
86 | parser.add_argument('--scale_max', type=float, default=2.0,
87 | help='dynamically scale training images up to this size')
88 | parser.add_argument('--weight_decay', type=float, default=1e-4)
89 | parser.add_argument('--momentum', type=float, default=0.9)
90 | parser.add_argument('--snapshot', type=str, default=None)
91 | parser.add_argument('--restore_optimizer', action='store_true', default=False)
92 | parser.add_argument('--exp', type=str, default='default',
93 | help='experiment directory name')
94 | parser.add_argument('--tb_tag', type=str, default='',
95 | help='add tag to tb dir')
96 | parser.add_argument('--ckpt', type=str, default='logs/ckpt')
97 | parser.add_argument('--tb_path', type=str, default='logs/tb')
98 | parser.add_argument('--syncbn', action='store_true', default=True,
99 | help='Synchronized BN')
100 | parser.add_argument('--dump_augmentation_images', action='store_true', default=False,
101 |                     help='Dump augmented images for debugging')
102 | parser.add_argument('--test_mode', action='store_true', default=False,
103 |                     help='minimal run (a few iterations over two epochs) to verify nothing failed')
104 | parser.add_argument('-wb', '--wt_bound', type=float, default=1.0)
105 | parser.add_argument('--maxSkip', type=int, default=0)
106 | args = parser.parse_args()
107 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
108 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
109 |
110 | #Enable CUDNN Benchmarking optimization
111 | torch.backends.cudnn.benchmark = True
112 | args.world_size = 1
113 | # Test mode: run two epochs with a few iterations of training and validation
114 | if args.test_mode:
115 | args.max_epoch = 2
116 |
117 | if 'WORLD_SIZE' in os.environ:
118 | args.world_size = int(os.environ['WORLD_SIZE'])
119 | print("Total world size: ", int(os.environ['WORLD_SIZE']))
120 |
121 | def main():
122 | '''
123 | Main Function
124 |
125 | '''
126 |
127 | #Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
128 | assert_and_infer_cfg(args)
129 | writer = prep_experiment(args,parser)
130 | train_loader, val_loader, train_obj = datasets.setup_loaders(args)
131 | criterion, criterion_val = loss.get_loss(args)
132 | net = network.get_net(args, criterion)
133 | optim, scheduler = optimizer.get_optimizer(args, net)
134 |
135 | torch.cuda.empty_cache()
136 |
137 | if args.evaluate:
138 | # Early evaluation for benchmarking
139 | default_eval_epoch = 1
140 | validate(val_loader, net, criterion_val,
141 | optim, default_eval_epoch, writer)
142 | evaluate(val_loader, net)
143 | return
144 |
145 | #Main Loop
146 | for epoch in range(args.start_epoch, args.max_epoch):
147 | # Update EPOCH CTR
148 | cfg.immutable(False)
149 | cfg.EPOCH = epoch
150 | cfg.immutable(True)
151 |
152 | scheduler.step()
153 |
154 | train(train_loader, net, criterion, optim, epoch, writer)
155 | validate(val_loader, net, criterion_val,
156 | optim, epoch, writer)
157 |
158 |
159 | def train(train_loader, net, criterion, optimizer, curr_epoch, writer):
160 | '''
161 | Runs the training loop per epoch
162 | train_loader: Data loader for train
163 |     net: the network
164 | criterion: loss fn
165 | optimizer: optimizer
166 | curr_epoch: current epoch
167 | writer: tensorboard writer
168 | return: val_avg for step function if required
169 | '''
170 | net.train()
171 |
172 | train_main_loss = AverageMeter()
173 | train_edge_loss = AverageMeter()
174 | train_seg_loss = AverageMeter()
175 | train_att_loss = AverageMeter()
176 | train_dual_loss = AverageMeter()
177 | curr_iter = curr_epoch * len(train_loader)
178 |
179 | for i, data in enumerate(train_loader):
180 | if i==0:
181 | print('running....')
182 |
183 | inputs, mask, edge, _img_name = data
184 |
185 | if torch.sum(torch.isnan(inputs)) > 0:
186 | import pdb; pdb.set_trace()
187 |
188 | batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
189 |
190 | inputs, mask, edge = inputs.cuda(), mask.cuda(), edge.cuda()
191 |
192 | if i==0:
193 |                 print('inputs moved to GPU')
194 |
195 | optimizer.zero_grad()
196 |
197 | main_loss = None
198 | loss_dict = None
199 |
200 | if args.joint_edgeseg_loss:
201 | loss_dict = net(inputs, gts=(mask, edge))
202 |
203 | if args.seg_weight > 0:
204 | log_seg_loss = loss_dict['seg_loss'].mean().clone().detach_()
205 | train_seg_loss.update(log_seg_loss.item(), batch_pixel_size)
206 | main_loss = loss_dict['seg_loss']
207 |
208 | if args.edge_weight > 0:
209 | log_edge_loss = loss_dict['edge_loss'].mean().clone().detach_()
210 | train_edge_loss.update(log_edge_loss.item(), batch_pixel_size)
211 | if main_loss is not None:
212 | main_loss += loss_dict['edge_loss']
213 | else:
214 | main_loss = loss_dict['edge_loss']
215 |
216 | if args.att_weight > 0:
217 | log_att_loss = loss_dict['att_loss'].mean().clone().detach_()
218 | train_att_loss.update(log_att_loss.item(), batch_pixel_size)
219 | if main_loss is not None:
220 | main_loss += loss_dict['att_loss']
221 | else:
222 | main_loss = loss_dict['att_loss']
223 |
224 | if args.dual_weight > 0:
225 | log_dual_loss = loss_dict['dual_loss'].mean().clone().detach_()
226 | train_dual_loss.update(log_dual_loss.item(), batch_pixel_size)
227 | if main_loss is not None:
228 | main_loss += loss_dict['dual_loss']
229 | else:
230 | main_loss = loss_dict['dual_loss']
231 |
232 | else:
233 | main_loss = net(inputs, gts=mask)
234 |
235 | main_loss = main_loss.mean()
236 | log_main_loss = main_loss.clone().detach_()
237 |
238 | train_main_loss.update(log_main_loss.item(), batch_pixel_size)
239 |
240 | main_loss.backward()
241 |
242 | optimizer.step()
243 |
244 | if i==0:
245 | print('step 1 done')
246 |
247 | curr_iter += 1
248 |
249 | if args.local_rank == 0:
250 | msg = '[epoch {}], [iter {} / {}], [train main loss {:0.6f}], [seg loss {:0.6f}], [edge loss {:0.6f}], [lr {:0.6f}]'.format(
251 | curr_epoch, i + 1, len(train_loader), train_main_loss.avg, train_seg_loss.avg, train_edge_loss.avg, optimizer.param_groups[-1]['lr'] )
252 |
253 | logging.info(msg)
254 |
255 | # Log tensorboard metrics for each iteration of the training phase
256 | writer.add_scalar('training/loss', (train_main_loss.val),
257 | curr_iter)
258 | writer.add_scalar('training/lr', optimizer.param_groups[-1]['lr'],
259 | curr_iter)
260 | if args.joint_edgeseg_loss:
261 |
262 | writer.add_scalar('training/seg_loss', (train_seg_loss.val),
263 | curr_iter)
264 | writer.add_scalar('training/edge_loss', (train_edge_loss.val),
265 | curr_iter)
266 | writer.add_scalar('training/att_loss', (train_att_loss.val),
267 | curr_iter)
268 | writer.add_scalar('training/dual_loss', (train_dual_loss.val),
269 | curr_iter)
270 | if i > 5 and args.test_mode:
271 | return
272 |
273 | def validate(val_loader, net, criterion, optimizer, curr_epoch, writer):
274 | '''
275 | Runs the validation loop after each training epoch
276 | val_loader: Data loader for validation
277 |     net: the network
278 | criterion: loss fn
279 | optimizer: optimizer
280 | curr_epoch: current epoch
281 | writer: tensorboard writer
282 | return:
283 | '''
284 | net.eval()
285 | val_loss = AverageMeter()
286 | mf_score = AverageMeter()
287 | IOU_acc = 0
288 | dump_images = []
289 | heatmap_images = []
290 | for vi, data in enumerate(val_loader):
291 | input, mask, edge, img_names = data
292 | assert len(input.size()) == 4 and len(mask.size()) == 3
293 | assert input.size()[2:] == mask.size()[1:]
294 | h, w = mask.size()[1:]
295 |
296 | batch_pixel_size = input.size(0) * input.size(2) * input.size(3)
297 | input, mask_cuda, edge_cuda = input.cuda(), mask.cuda(), edge.cuda()
298 |
299 | with torch.no_grad():
300 | seg_out, edge_out = net(input) # output = (1, 19, 713, 713)
301 |
302 | if args.joint_edgeseg_loss:
303 | loss_dict = criterion((seg_out, edge_out), (mask_cuda, edge_cuda))
304 | val_loss.update(sum(loss_dict.values()).item(), batch_pixel_size)
305 | else:
306 | val_loss.update(criterion(seg_out, mask_cuda).item(), batch_pixel_size)
307 |
308 |             # Collect outputs from the different GPUs onto a single GPU, since the
309 |             # encoding.parallel criterion-parallel wrapper computes the loss
310 |             # in a distributed fashion
311 |
312 | seg_predictions = seg_out.data.max(1)[1].cpu()
313 | edge_predictions = edge_out.max(1)[0].cpu()
314 |
315 | #Logging
316 | if vi % 20 == 0:
317 | if args.local_rank == 0:
318 | logging.info('validating: %d / %d' % (vi + 1, len(val_loader)))
319 | if vi > 10 and args.test_mode:
320 | break
321 | _edge = edge.max(1)[0]
322 |
323 | #Image Dumps
324 | if vi < 10:
325 | dump_images.append([mask, seg_predictions, img_names])
326 | heatmap_images.append([_edge, edge_predictions, img_names])
327 |
328 | IOU_acc += fast_hist(seg_predictions.numpy().flatten(), mask.numpy().flatten(),
329 | args.dataset_cls.num_classes)
330 |
331 | del seg_out, edge_out, vi, data
332 |
333 | if args.local_rank == 0:
334 | evaluate_eval(args, net, optimizer, val_loss, mf_score, IOU_acc, dump_images, heatmap_images,
335 | writer, curr_epoch, args.dataset_cls)
336 |
337 | return val_loss.avg
338 |
339 | def evaluate(val_loader, net):
340 | '''
341 | Runs the evaluation loop and prints F score
342 | val_loader: Data loader for validation
343 |     net: the network
344 | return:
345 | '''
346 | net.eval()
347 | for thresh in args.eval_thresholds.split(','):
348 | mf_score1 = AverageMeter()
349 | mf_pc_score1 = AverageMeter()
350 | ap_score1 = AverageMeter()
351 | ap_pc_score1 = AverageMeter()
352 | Fpc = np.zeros((args.dataset_cls.num_classes))
353 | Fc = np.zeros((args.dataset_cls.num_classes))
354 | for vi, data in enumerate(val_loader):
355 | input, mask, edge, img_names = data
356 | assert len(input.size()) == 4 and len(mask.size()) == 3
357 | assert input.size()[2:] == mask.size()[1:]
358 | h, w = mask.size()[1:]
359 |
360 | batch_pixel_size = input.size(0) * input.size(2) * input.size(3)
361 | input, mask_cuda, edge_cuda = input.cuda(), mask.cuda(), edge.cuda()
362 |
363 | with torch.no_grad():
364 | seg_out, edge_out = net(input)
365 |
366 | seg_predictions = seg_out.data.max(1)[1].cpu()
367 | edge_predictions = edge_out.max(1)[0].cpu()
368 |
369 | logging.info('evaluating: %d / %d' % (vi + 1, len(val_loader)))
370 | _Fpc, _Fc = eval_mask_boundary(seg_predictions.numpy(), mask.numpy(), args.dataset_cls.num_classes, bound_th=float(thresh))
371 | Fc += _Fc
372 | Fpc += _Fpc
373 |
374 | del seg_out, edge_out, vi, data
375 |
376 | logging.info('Threshold: ' + thresh)
377 | logging.info('F_Score: ' + str(np.sum(Fpc/Fc)/args.dataset_cls.num_classes))
378 | logging.info('F_Score (Classwise): ' + str(Fpc/Fc))
379 |
380 | if __name__ == '__main__':
381 | main()
382 |
383 |
384 |
385 |
386 |
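387 | # ---------------------------------------------------------------------------
388 | # Example invocation (illustrative; the experiment name is an assumption,
389 | # all flags exist in the argument parser above):
390 | #   python train.py --dataset cityscapes --arch network.gscnn.GSCNN \
391 | #       --lr 0.01 --max_epoch 175 --crop_size 720 --bs_mult 1 --bs_mult_val 2 \
392 | #       --exp gscnn_default --ckpt logs/ckpt --tb_path logs/tb
393 | # Add --snapshot <checkpoint> (optionally with --restore_optimizer) to resume,
394 | # or --evaluate to run validation plus the boundary F-score evaluation only.
395 | # ---------------------------------------------------------------------------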
--------------------------------------------------------------------------------
/transforms/joint_transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code borrowed from:
6 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/joint_transforms.py
7 | #
8 | #
9 | # MIT License
10 | #
11 | # Copyright (c) 2017 ZijunDeng
12 | #
13 | # Permission is hereby granted, free of charge, to any person obtaining a copy
14 | # of this software and associated documentation files (the "Software"), to deal
15 | # in the Software without restriction, including without limitation the rights
16 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17 | # copies of the Software, and to permit persons to whom the Software is
18 | # furnished to do so, subject to the following conditions:
19 | #
20 | # The above copyright notice and this permission notice shall be included in all
21 | # copies or substantial portions of the Software.
22 | #
23 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29 | # SOFTWARE.
30 |
31 | """
32 |
33 | import math
34 | import numbers
35 | import random
36 | from PIL import Image, ImageOps
37 | import numpy as np
38 | from scipy.ndimage import generate_binary_structure, maximum_filter, binary_erosion
39 |
40 | class Compose(object):
41 | def __init__(self, transforms):
42 | self.transforms = transforms
43 |
44 | def __call__(self, img, mask):
45 | assert img.size == mask.size
46 | for t in self.transforms:
47 | img, mask = t(img, mask)
48 | return img, mask
49 |
50 |
51 | class RandomCrop(object):
52 | '''
53 | Take a random crop from the image.
54 |
55 | First the image or crop size may need to be adjusted if the incoming image
56 | is too small...
57 |
58 | If the image is smaller than the crop, then:
59 | the image is padded up to the size of the crop
60 | unless 'nopad', in which case the crop size is shrunk to fit the image
61 |
62 | A random crop is taken such that the crop fits within the image.
63 | If a centroid is passed in, the crop must intersect the centroid.
64 | '''
65 | def __init__(self, size, ignore_index=0, nopad=True):
66 | if isinstance(size, numbers.Number):
67 | self.size = (int(size), int(size))
68 | else:
69 | self.size = size
70 | self.ignore_index = ignore_index
71 | self.nopad = nopad
72 | self.pad_color = (0, 0, 0)
73 |
74 | def __call__(self, img, mask, centroid=None):
75 | assert img.size == mask.size
76 | w, h = img.size
77 | # ASSUME H, W
78 | th, tw = self.size
79 | if w == tw and h == th:
80 | return img, mask
81 |
82 | if self.nopad:
83 | if th > h or tw > w:
84 | # Instead of padding, adjust crop size to the shorter edge of image.
85 | shorter_side = min(w, h)
86 | th, tw = shorter_side, shorter_side
87 | else:
88 | # Check if we need to pad img to fit for crop_size.
89 | if th > h:
90 | pad_h = (th - h) // 2 + 1
91 | else:
92 | pad_h = 0
93 | if tw > w:
94 | pad_w = (tw - w) // 2 + 1
95 | else:
96 | pad_w = 0
97 | border = (pad_w, pad_h, pad_w, pad_h)
98 | if pad_h or pad_w:
99 | img = ImageOps.expand(img, border=border, fill=self.pad_color)
100 | mask = ImageOps.expand(mask, border=border, fill=self.ignore_index)
101 | w, h = img.size
102 |
103 | if centroid is not None:
104 |             # Need to ensure that centroid is covered by crop and that crop
105 | # sits fully within the image
106 | c_x, c_y = centroid
107 | max_x = w - tw
108 | max_y = h - th
109 | x1 = random.randint(c_x - tw, c_x)
110 | x1 = min(max_x, max(0, x1))
111 | y1 = random.randint(c_y - th, c_y)
112 | y1 = min(max_y, max(0, y1))
113 | else:
114 | if w == tw:
115 | x1 = 0
116 | else:
117 | x1 = random.randint(0, w - tw)
118 | if h == th:
119 | y1 = 0
120 | else:
121 | y1 = random.randint(0, h - th)
122 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
123 |
124 |
125 | class ResizeHeight(object):
126 | def __init__(self, size, interpolation=Image.BICUBIC):
127 | self.target_h = size
128 | self.interpolation = interpolation
129 |
130 | def __call__(self, img, mask):
131 | w, h = img.size
132 | target_w = int(w / h * self.target_h)
133 | return (img.resize((target_w, self.target_h), self.interpolation),
134 | mask.resize((target_w, self.target_h), Image.NEAREST))
135 |
136 |
137 | class CenterCrop(object):
138 | def __init__(self, size):
139 | if isinstance(size, numbers.Number):
140 | self.size = (int(size), int(size))
141 | else:
142 | self.size = size
143 |
144 | def __call__(self, img, mask):
145 | assert img.size == mask.size
146 | w, h = img.size
147 | th, tw = self.size
148 | x1 = int(round((w - tw) / 2.))
149 | y1 = int(round((h - th) / 2.))
150 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
151 |
152 |
153 | class CenterCropPad(object):
154 | def __init__(self, size, ignore_index=0):
155 | if isinstance(size, numbers.Number):
156 | self.size = (int(size), int(size))
157 | else:
158 | self.size = size
159 | self.ignore_index = ignore_index
160 |
161 | def __call__(self, img, mask):
162 |
163 | assert img.size == mask.size
164 | w, h = img.size
165 | if isinstance(self.size, tuple):
166 | tw, th = self.size[0], self.size[1]
167 | else:
168 | th, tw = self.size, self.size
169 |
170 |
171 | if w < tw:
172 | pad_x = tw - w
173 | else:
174 | pad_x = 0
175 | if h < th:
176 | pad_y = th - h
177 | else:
178 | pad_y = 0
179 |
180 | if pad_x or pad_y:
181 | # left, top, right, bottom
182 | img = ImageOps.expand(img, border=(pad_x, pad_y, pad_x, pad_y), fill=0)
183 | mask = ImageOps.expand(mask, border=(pad_x, pad_y, pad_x, pad_y),
184 | fill=self.ignore_index)
185 |
186 | x1 = int(round((w - tw) / 2.))
187 | y1 = int(round((h - th) / 2.))
188 | return img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))
189 |
190 |
191 |
192 | class PadImage(object):
193 | def __init__(self, size, ignore_index):
194 | self.size = size
195 | self.ignore_index = ignore_index
196 |
197 |
198 | def __call__(self, img, mask):
199 | assert img.size == mask.size
200 | th, tw = self.size, self.size
201 |
202 |
203 | w, h = img.size
204 |
205 | if w > tw or h > th :
206 | wpercent = (tw/float(w))
207 | target_h = int((float(img.size[1])*float(wpercent)))
208 | img, mask = img.resize((tw, target_h), Image.BICUBIC), mask.resize((tw, target_h), Image.NEAREST)
209 |
210 | w, h = img.size
211 | ##Pad
212 | img = ImageOps.expand(img, border=(0,0,tw-w, th-h), fill=0)
213 | mask = ImageOps.expand(mask, border=(0,0,tw-w, th-h), fill=self.ignore_index)
214 |
215 | return img, mask
216 |
217 | class RandomHorizontallyFlip(object):
218 | def __call__(self, img, mask):
219 | if random.random() < 0.5:
220 | return img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(
221 | Image.FLIP_LEFT_RIGHT)
222 | return img, mask
223 |
224 |
225 | class FreeScale(object):
226 | def __init__(self, size):
227 | self.size = tuple(reversed(size)) # size: (h, w)
228 |
229 | def __call__(self, img, mask):
230 | assert img.size == mask.size
231 | return img.resize(self.size, Image.BICUBIC), mask.resize(self.size, Image.NEAREST)
232 |
233 |
234 | class Scale(object):
235 | '''
236 | Scale image such that longer side is == size
237 | '''
238 |
239 | def __init__(self, size):
240 | self.size = size
241 |
242 | def __call__(self, img, mask):
243 | assert img.size == mask.size
244 | w, h = img.size
245 | if (w >= h and w == self.size) or (h >= w and h == self.size):
246 | return img, mask
247 | if w > h:
248 | ow = self.size
249 | oh = int(self.size * h / w)
250 | return img.resize((ow, oh), Image.BICUBIC), mask.resize(
251 | (ow, oh), Image.NEAREST)
252 | else:
253 | oh = self.size
254 | ow = int(self.size * w / h)
255 | return img.resize((ow, oh), Image.BICUBIC), mask.resize(
256 | (ow, oh), Image.NEAREST)
257 |
258 |
259 | class ScaleMin(object):
260 | '''
261 | Scale image such that shorter side is == size
262 | '''
263 |
264 | def __init__(self, size):
265 | self.size = size
266 |
267 | def __call__(self, img, mask):
268 | assert img.size == mask.size
269 | w, h = img.size
270 | if (w <= h and w == self.size) or (h <= w and h == self.size):
271 | return img, mask
272 | if w < h:
273 | ow = self.size
274 | oh = int(self.size * h / w)
275 | return img.resize((ow, oh), Image.BICUBIC), mask.resize(
276 | (ow, oh), Image.NEAREST)
277 | else:
278 | oh = self.size
279 | ow = int(self.size * w / h)
280 | return img.resize((ow, oh), Image.BICUBIC), mask.resize(
281 | (ow, oh), Image.NEAREST)
282 |
283 |
284 | class Resize(object):
285 | '''
286 | Resize image to exact size of crop
287 | '''
288 |
289 | def __init__(self, size):
290 | self.size = (size, size)
291 |
292 | def __call__(self, img, mask):
293 | assert img.size == mask.size
294 | w, h = img.size
295 |         if (w, h) == self.size:
296 | return img, mask
297 | return (img.resize(self.size, Image.BICUBIC),
298 | mask.resize(self.size, Image.NEAREST))
299 |
300 |
301 | class RandomSizedCrop(object):
302 | def __init__(self, size):
303 | self.size = size
304 |
305 | def __call__(self, img, mask):
306 | assert img.size == mask.size
307 | for attempt in range(10):
308 | area = img.size[0] * img.size[1]
309 | target_area = random.uniform(0.45, 1.0) * area
310 | aspect_ratio = random.uniform(0.5, 2)
311 |
312 | w = int(round(math.sqrt(target_area * aspect_ratio)))
313 | h = int(round(math.sqrt(target_area / aspect_ratio)))
314 |
315 | if random.random() < 0.5:
316 | w, h = h, w
317 |
318 | if w <= img.size[0] and h <= img.size[1]:
319 | x1 = random.randint(0, img.size[0] - w)
320 | y1 = random.randint(0, img.size[1] - h)
321 |
322 | img = img.crop((x1, y1, x1 + w, y1 + h))
323 | mask = mask.crop((x1, y1, x1 + w, y1 + h))
324 | assert (img.size == (w, h))
325 |
326 | return img.resize((self.size, self.size), Image.BICUBIC),\
327 | mask.resize((self.size, self.size), Image.NEAREST)
328 |
329 | # Fallback
330 | scale = Scale(self.size)
331 | crop = CenterCrop(self.size)
332 | return crop(*scale(img, mask))
333 |
334 |
335 | class RandomRotate(object):
336 | def __init__(self, degree):
337 | self.degree = degree
338 |
339 | def __call__(self, img, mask):
340 | rotate_degree = random.random() * 2 * self.degree - self.degree
341 | return img.rotate(rotate_degree, Image.BICUBIC), mask.rotate(
342 | rotate_degree, Image.NEAREST)
343 |
344 |
345 | class RandomSizeAndCrop(object):
346 | def __init__(self, size, crop_nopad,
347 | scale_min=0.5, scale_max=2.0, ignore_index=0, pre_size=None):
348 | self.size = size
349 | self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad)
350 | self.scale_min = scale_min
351 | self.scale_max = scale_max
352 | self.pre_size = pre_size
353 |
354 | def __call__(self, img, mask, centroid=None):
355 | assert img.size == mask.size
356 |
357 | # first, resize such that shorter edge is pre_size
358 | if self.pre_size is None:
359 | scale_amt = 1.
360 | elif img.size[1] < img.size[0]:
361 | scale_amt = self.pre_size / img.size[1]
362 | else:
363 | scale_amt = self.pre_size / img.size[0]
364 | scale_amt *= random.uniform(self.scale_min, self.scale_max)
365 | w, h = [int(i * scale_amt) for i in img.size]
366 |
367 | if centroid is not None:
368 | centroid = [int(c * scale_amt) for c in centroid]
369 |
370 | img, mask = img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST)
371 |
372 | return self.crop(img, mask, centroid)
373 |
374 |
375 | class SlidingCropOld(object):
376 | def __init__(self, crop_size, stride_rate, ignore_label):
377 | self.crop_size = crop_size
378 | self.stride_rate = stride_rate
379 | self.ignore_label = ignore_label
380 |
381 | def _pad(self, img, mask):
382 | h, w = img.shape[: 2]
383 | pad_h = max(self.crop_size - h, 0)
384 | pad_w = max(self.crop_size - w, 0)
385 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
386 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant',
387 | constant_values=self.ignore_label)
388 | return img, mask
389 |
390 | def __call__(self, img, mask):
391 | assert img.size == mask.size
392 |
393 | w, h = img.size
394 | long_size = max(h, w)
395 |
396 | img = np.array(img)
397 | mask = np.array(mask)
398 |
399 | if long_size > self.crop_size:
400 | stride = int(math.ceil(self.crop_size * self.stride_rate))
401 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1
402 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1
403 | img_sublist, mask_sublist = [], []
404 | for yy in range(h_step_num):
405 | for xx in range(w_step_num):
406 | sy, sx = yy * stride, xx * stride
407 | ey, ex = sy + self.crop_size, sx + self.crop_size
408 | img_sub = img[sy: ey, sx: ex, :]
409 | mask_sub = mask[sy: ey, sx: ex]
410 | img_sub, mask_sub = self._pad(img_sub, mask_sub)
411 | img_sublist.append(
412 | Image.fromarray(
413 | img_sub.astype(
414 | np.uint8)).convert('RGB'))
415 | mask_sublist.append(
416 | Image.fromarray(
417 | mask_sub.astype(
418 | np.uint8)).convert('P'))
419 | return img_sublist, mask_sublist
420 | else:
421 | img, mask = self._pad(img, mask)
422 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
423 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
424 | return img, mask
425 |
426 |
427 | class SlidingCrop(object):
428 | def __init__(self, crop_size, stride_rate, ignore_label):
429 | self.crop_size = crop_size
430 | self.stride_rate = stride_rate
431 | self.ignore_label = ignore_label
432 |
433 | def _pad(self, img, mask):
434 | h, w = img.shape[: 2]
435 | pad_h = max(self.crop_size - h, 0)
436 | pad_w = max(self.crop_size - w, 0)
437 | img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
438 | mask = np.pad(mask, ((0, pad_h), (0, pad_w)), 'constant',
439 | constant_values=self.ignore_label)
440 | return img, mask, h, w
441 |
442 | def __call__(self, img, mask):
443 | assert img.size == mask.size
444 |
445 | w, h = img.size
446 | long_size = max(h, w)
447 |
448 | img = np.array(img)
449 | mask = np.array(mask)
450 |
451 | if long_size > self.crop_size:
452 | stride = int(math.ceil(self.crop_size * self.stride_rate))
453 | h_step_num = int(math.ceil((h - self.crop_size) / float(stride))) + 1
454 | w_step_num = int(math.ceil((w - self.crop_size) / float(stride))) + 1
455 | img_slices, mask_slices, slices_info = [], [], []
456 | for yy in range(h_step_num):
457 | for xx in range(w_step_num):
458 | sy, sx = yy * stride, xx * stride
459 | ey, ex = sy + self.crop_size, sx + self.crop_size
460 | img_sub = img[sy: ey, sx: ex, :]
461 | mask_sub = mask[sy: ey, sx: ex]
462 | img_sub, mask_sub, sub_h, sub_w = self._pad(img_sub, mask_sub)
463 | img_slices.append(
464 | Image.fromarray(
465 | img_sub.astype(
466 | np.uint8)).convert('RGB'))
467 | mask_slices.append(
468 | Image.fromarray(
469 | mask_sub.astype(
470 | np.uint8)).convert('P'))
471 | slices_info.append([sy, ey, sx, ex, sub_h, sub_w])
472 | return img_slices, mask_slices, slices_info
473 | else:
474 | img, mask, sub_h, sub_w = self._pad(img, mask)
475 | img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
476 | mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
477 | return [img], [mask], [[0, sub_h, 0, sub_w, sub_h, sub_w]]
478 |
479 |
480 | class ClassUniform(object):
481 | def __init__(self, size, crop_nopad, scale_min=0.5, scale_max=2.0, ignore_index=0,
482 | class_list=[16, 15, 14]):
483 | """
484 | This is the initialization for class uniform sampling
485 | :param size: crop size (int)
486 | :param crop_nopad: Padding or no padding (bool)
487 | :param scale_min: Minimum Scale (float)
488 | :param scale_max: Maximum Scale (float)
489 | :param ignore_index: The index value to ignore in the GT images (unsigned int)
490 |         :param class_list: A list of classes to sample around; by default Truck, Train, Bus
491 | """
492 | self.size = size
493 | self.crop = RandomCrop(self.size, ignore_index=ignore_index, nopad=crop_nopad)
494 |
495 |         self.class_list = [int(c) for c in class_list.split(",")] if isinstance(class_list, str) else list(class_list)
496 |
497 | self.scale_min = scale_min
498 | self.scale_max = scale_max
499 |
500 | def detect_peaks(self, image):
501 | """
502 |         Takes an image and detects the peaks using the local maximum filter.
503 | Returns a boolean mask of the peaks (i.e. 1 when
504 | the pixel's value is the neighborhood maximum, 0 otherwise)
505 |
506 |         :param image: A 2D input image
507 |         :return: A binary output image of the same size as the input, with pixel
508 |         value 1 wherever there is a peak at that point
509 | """
510 |
511 | # define an 8-connected neighborhood
512 | neighborhood = generate_binary_structure(2, 2)
513 |
514 | # apply the local maximum filter; all pixel of maximal value
515 | # in their neighborhood are set to 1
516 | local_max = maximum_filter(image, footprint=neighborhood) == image
517 | # local_max is a mask that contains the peaks we are
518 | # looking for, but also the background.
519 | # In order to isolate the peaks we must remove the background from the mask.
520 |
521 | # we create the mask of the background
522 | background = (image == 0)
523 |
524 | # a little technicality: we must erode the background in order to
525 |         # successfully subtract it from local_max, otherwise a line will
526 | # appear along the background border (artifact of the local maximum filter)
527 | eroded_background = binary_erosion(background, structure=neighborhood,
528 | border_value=1)
529 |
530 | # we obtain the final mask, containing only peaks,
531 | # by removing the background from the local_max mask (xor operation)
532 | detected_peaks = local_max ^ eroded_background
533 |
534 | return detected_peaks
535 |
536 | def __call__(self, img, mask):
537 | """
538 | :param img: PIL Input Image
539 | :param mask: PIL Input Mask
540 | :return: PIL output PIL (mask, crop) of self.crop_size
541 | """
542 | assert img.size == mask.size
543 |
544 | scale_amt = random.uniform(self.scale_min, self.scale_max)
545 | w = int(scale_amt * img.size[0])
546 | h = int(scale_amt * img.size[1])
547 |
548 | if scale_amt < 1.0:
549 | img, mask = img.resize((w, h), Image.BICUBIC), mask.resize((w, h),
550 | Image.NEAREST)
551 | return self.crop(img, mask)
552 | else:
553 | # Smart Crop ( Class Uniform's ABN)
554 | origw, origh = mask.size
555 | img_new, mask_new = \
556 | img.resize((w, h), Image.BICUBIC), mask.resize((w, h), Image.NEAREST)
557 | interested_class = self.class_list # [16, 15, 14] # Train, Truck, Bus
558 | data = np.array(mask)
559 | arr = np.zeros((1024, 2048))
560 | for class_of_interest in interested_class:
561 | # hist = np.histogram(data==class_of_interest)
562 |                 class_map = np.where(data == class_of_interest, data, 0)
563 |                 class_map = class_map.astype('float64') / class_map.sum() / class_of_interest
564 |                 class_map[np.isnan(class_map)] = 0
565 |                 arr = arr + class_map
566 |
567 | origarr = arr
568 | window_size = 250
569 |
570 | # Given a list of classes of interest find the points on the image that are
571 | # of interest to crop from
572 | sum_arr = np.zeros((1024, 2048)).astype('float32')
573 | tmp = np.zeros((1024, 2048)).astype('float32')
574 | for x in range(0, arr.shape[0] - window_size, window_size):
575 | for y in range(0, arr.shape[1] - window_size, window_size):
576 | sum_arr[int(x + window_size / 2), int(y + window_size / 2)] = origarr[
577 | x:x + window_size,
578 | y:y + window_size].sum()
579 | tmp[x:x + window_size, y:y + window_size] = \
580 | origarr[x:x + window_size, y:y + window_size].sum()
581 |
582 | # Scaling Ratios in X and Y for non-uniform images
583 | ratio = (float(origw) / w, float(origh) / h)
584 | output = self.detect_peaks(sum_arr)
585 | coord = (np.column_stack(np.where(output))).tolist()
586 |
587 |             # Check if there are any peaks in the image to crop from; if not,
588 |             # fall back to the standard cropping behaviour
589 | if len(coord) == 0:
590 | return self.crop(img_new, mask_new)
591 | else:
592 | # If peaks are detected, random peak selection followed by peak
593 | # coordinate scaling to new scaled image and then random
594 | # cropping around the peak point in the scaled image
595 | randompick = np.random.randint(len(coord))
596 | y, x = coord[randompick]
597 | y, x = int(y * ratio[0]), int(x * ratio[1])
598 | window_size = window_size * ratio[0]
599 | cropx = random.uniform(
600 | max(0, (x - window_size / 2) - (self.size - window_size)),
601 | max((x - window_size / 2), (x - window_size / 2) - (
602 | (w - window_size) - x + window_size / 2)))
603 |
604 | cropy = random.uniform(
605 | max(0, (y - window_size / 2) - (self.size - window_size)),
606 | max((y - window_size / 2), (y - window_size / 2) - (
607 | (h - window_size) - y + window_size / 2)))
608 |
609 | return_img = img_new.crop(
610 | (cropx, cropy, cropx + self.size, cropy + self.size))
611 | return_mask = mask_new.crop(
612 | (cropx, cropy, cropx + self.size, cropy + self.size))
613 | return (return_img, return_mask)
614 |
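615 | # ---------------------------------------------------------------------------
616 | # Illustrative sketch (not part of the original file): composing the joint
617 | # transforms above on a (PIL image, PIL mask) pair. The 2048x1024 input and
618 | # the 720 crop size are arbitrary assumptions for the example.
619 | if __name__ == "__main__":
620 |     img = Image.new("RGB", (2048, 1024))
621 |     mask = Image.new("L", (2048, 1024))
622 |     joint_transform = Compose([
623 |         RandomSizeAndCrop(720, crop_nopad=False, scale_min=0.5, scale_max=2.0,
624 |                           ignore_index=255),
625 |         RandomHorizontallyFlip(),
626 |         RandomRotate(10),
627 |     ])
628 |     out_img, out_mask = joint_transform(img, mask)
629 |     print(out_img.size, out_mask.size)  # both (720, 720) after the random crop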
--------------------------------------------------------------------------------
/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code borrowed from:
6 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py
7 | #
8 | #
9 | # MIT License
10 | #
11 | # Copyright (c) 2017 ZijunDeng
12 | #
13 | # Permission is hereby granted, free of charge, to any person obtaining a copy
14 | # of this software and associated documentation files (the "Software"), to deal
15 | # in the Software without restriction, including without limitation the rights
16 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17 | # copies of the Software, and to permit persons to whom the Software is
18 | # furnished to do so, subject to the following conditions:
19 | #
20 | # The above copyright notice and this permission notice shall be included in all
21 | # copies or substantial portions of the Software.
22 | #
23 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29 | # SOFTWARE.
30 |
31 | """
32 |
33 | import random
34 | import numpy as np
35 | from skimage.filters import gaussian
36 | from skimage.restoration import denoise_bilateral
37 | import torch
38 | from PIL import Image, ImageFilter, ImageEnhance
39 | import torchvision.transforms as torch_tr
40 | from scipy import ndimage
41 | from config import cfg
42 | from scipy.ndimage.interpolation import shift
43 | from scipy.misc import imsave
44 | from skimage.segmentation import find_boundaries
45 | class RandomVerticalFlip(object):
46 | def __call__(self, img):
47 | if random.random() < 0.5:
48 | return img.transpose(Image.FLIP_TOP_BOTTOM)
49 | return img
50 |
51 |
52 | class DeNormalize(object):
53 | def __init__(self, mean, std):
54 | self.mean = mean
55 | self.std = std
56 |
57 | def __call__(self, tensor):
58 | for t, m, s in zip(tensor, self.mean, self.std):
59 | t.mul_(s).add_(m)
60 | return tensor
61 |
62 |
63 | class MaskToTensor(object):
64 | def __call__(self, img):
65 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
66 |
67 | class RelaxedBoundaryLossToTensor(object):
68 | def __init__(self,ignore_id, num_classes):
69 | self.ignore_id=ignore_id
70 | self.num_classes= num_classes
71 |
72 |
73 | def new_one_hot_converter(self,a):
74 | ncols = self.num_classes+1
75 | out = np.zeros( (a.size,ncols), dtype=np.uint8)
76 | out[np.arange(a.size),a.ravel()] = 1
77 | out.shape = a.shape + (ncols,)
78 | return out
79 |
80 | def __call__(self,img):
81 |
82 |
83 |
84 | img_arr = np.array(img)
85 |         # debug only: scipy.misc.imsave('orig.png', img_arr)
86 |
87 | img_arr[img_arr==self.ignore_id]=self.num_classes
88 |
89 |         if cfg.STRICTBORDERCLASS is not None:
90 | one_hot_orig = self.new_one_hot_converter(img_arr)
91 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1]))
92 | for cls in cfg.STRICTBORDERCLASS:
93 | mask = np.logical_or(mask,(img_arr == cls))
94 | one_hot = 0
95 |
96 | #print(cfg.EPOCH, "Non Reduced", cfg.TRAIN.REDUCE_RELAXEDITERATIONCOUNT)
97 | border = cfg.BORDER_WINDOW
98 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
99 | border = border // 2
100 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8)
101 | print(cfg.EPOCH, "Reduced")
102 |
103 | for i in range(-border,border+1):
104 | for j in range(-border, border+1):
105 | shifted= shift(img_arr,(i,j), cval=self.num_classes)
106 | one_hot += self.new_one_hot_converter(shifted)
107 |
108 | one_hot[one_hot>1] = 1
109 |
110 |         if cfg.STRICTBORDERCLASS is not None:
111 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot)
112 |
113 | one_hot = np.moveaxis(one_hot,-1,0)
114 |
115 |
116 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
117 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot)
118 |         # debug only: print(one_hot.shape)
119 |         return torch.from_numpy(one_hot).byte()
120 |         #return torch.from_numpy(one_hot).float()
121 |
122 |
123 | class ResizeHeight(object):
124 | def __init__(self, size, interpolation=Image.BILINEAR):
125 | self.target_h = size
126 | self.interpolation = interpolation
127 |
128 | def __call__(self, img):
129 | w, h = img.size
130 | target_w = int(w / h * self.target_h)
131 | return img.resize((target_w, self.target_h), self.interpolation)
132 |
133 |
134 | class FreeScale(object):
135 | def __init__(self, size, interpolation=Image.BILINEAR):
136 | self.size = tuple(reversed(size)) # size: (h, w)
137 | self.interpolation = interpolation
138 |
139 | def __call__(self, img):
140 | return img.resize(self.size, self.interpolation)
141 |
142 |
143 | class FlipChannels(object):
144 | def __call__(self, img):
145 | img = np.array(img)[:, :, ::-1]
146 | return Image.fromarray(img.astype(np.uint8))
147 |
148 | class RandomGaussianBlur(object):
149 | def __call__(self, img):
150 | sigma = 0.15 + random.random() * 1.15
151 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
152 | blurred_img *= 255
153 | return Image.fromarray(blurred_img.astype(np.uint8))
154 |
155 |
156 | class RandomBilateralBlur(object):
157 | def __call__(self, img):
158 | sigma = random.uniform(0.05,0.75)
159 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True)
160 | blurred_img *= 255
161 | return Image.fromarray(blurred_img.astype(np.uint8))
162 |
163 | try:
164 | import accimage
165 | except ImportError:
166 | accimage = None
167 |
168 |
169 | def _is_pil_image(img):
170 | if accimage is not None:
171 | return isinstance(img, (Image.Image, accimage.Image))
172 | else:
173 | return isinstance(img, Image.Image)
174 |
175 |
176 | def adjust_brightness(img, brightness_factor):
177 | """Adjust brightness of an Image.
178 |
179 | Args:
180 | img (PIL Image): PIL Image to be adjusted.
181 | brightness_factor (float): How much to adjust the brightness. Can be
182 | any non negative number. 0 gives a black image, 1 gives the
183 | original image while 2 increases the brightness by a factor of 2.
184 |
185 | Returns:
186 | PIL Image: Brightness adjusted image.
187 | """
188 | if not _is_pil_image(img):
189 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
190 |
191 | enhancer = ImageEnhance.Brightness(img)
192 | img = enhancer.enhance(brightness_factor)
193 | return img
194 |
195 |
196 | def adjust_contrast(img, contrast_factor):
197 | """Adjust contrast of an Image.
198 |
199 | Args:
200 | img (PIL Image): PIL Image to be adjusted.
201 | contrast_factor (float): How much to adjust the contrast. Can be any
202 | non negative number. 0 gives a solid gray image, 1 gives the
203 | original image while 2 increases the contrast by a factor of 2.
204 |
205 | Returns:
206 | PIL Image: Contrast adjusted image.
207 | """
208 | if not _is_pil_image(img):
209 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
210 |
211 | enhancer = ImageEnhance.Contrast(img)
212 | img = enhancer.enhance(contrast_factor)
213 | return img
214 |
215 |
216 | def adjust_saturation(img, saturation_factor):
217 | """Adjust color saturation of an image.
218 |
219 | Args:
220 | img (PIL Image): PIL Image to be adjusted.
221 | saturation_factor (float): How much to adjust the saturation. 0 will
222 | give a black and white image, 1 will give the original image while
223 | 2 will enhance the saturation by a factor of 2.
224 |
225 | Returns:
226 | PIL Image: Saturation adjusted image.
227 | """
228 | if not _is_pil_image(img):
229 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
230 |
231 | enhancer = ImageEnhance.Color(img)
232 | img = enhancer.enhance(saturation_factor)
233 | return img
234 |
235 |
236 | def adjust_hue(img, hue_factor):
237 | """Adjust hue of an image.
238 |
239 | The image hue is adjusted by converting the image to HSV and
240 | cyclically shifting the intensities in the hue channel (H).
241 | The image is then converted back to original image mode.
242 |
243 | `hue_factor` is the amount of shift in H channel and must be in the
244 | interval `[-0.5, 0.5]`.
245 |
246 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
247 |
248 | Args:
249 | img (PIL Image): PIL Image to be adjusted.
250 | hue_factor (float): How much to shift the hue channel. Should be in
251 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
252 | HSV space in positive and negative direction respectively.
253 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
254 | with complementary colors while 0 gives the original image.
255 |
256 | Returns:
257 | PIL Image: Hue adjusted image.
258 | """
259 | if not(-0.5 <= hue_factor <= 0.5):
260 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
261 |
262 | if not _is_pil_image(img):
263 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
264 |
265 | input_mode = img.mode
266 | if input_mode in {'L', '1', 'I', 'F'}:
267 | return img
268 |
269 | h, s, v = img.convert('HSV').split()
270 |
271 | np_h = np.array(h, dtype=np.uint8)
272 |     # uint8 addition takes care of the cyclic rotation across boundaries
273 | with np.errstate(over='ignore'):
274 | np_h += np.uint8(hue_factor * 255)
275 | h = Image.fromarray(np_h, 'L')
276 |
277 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
278 | return img
279 |
280 |
281 | class ColorJitter(object):
282 | """Randomly change the brightness, contrast and saturation of an image.
283 |
284 | Args:
285 | brightness (float): How much to jitter brightness. brightness_factor
286 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
287 | contrast (float): How much to jitter contrast. contrast_factor
288 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
289 | saturation (float): How much to jitter saturation. saturation_factor
290 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
291 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
292 | [-hue, hue]. Should be >=0 and <= 0.5.
293 | """
294 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
295 | self.brightness = brightness
296 | self.contrast = contrast
297 | self.saturation = saturation
298 | self.hue = hue
299 |
300 | @staticmethod
301 | def get_params(brightness, contrast, saturation, hue):
302 | """Get a randomized transform to be applied on image.
303 |
304 | Arguments are same as that of __init__.
305 |
306 | Returns:
307 | Transform which randomly adjusts brightness, contrast and
308 | saturation in a random order.
309 | """
310 | transforms = []
311 | if brightness > 0:
312 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
313 | transforms.append(
314 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
315 |
316 | if contrast > 0:
317 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
318 | transforms.append(
319 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
320 |
321 | if saturation > 0:
322 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
323 | transforms.append(
324 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
325 |
326 | if hue > 0:
327 | hue_factor = np.random.uniform(-hue, hue)
328 | transforms.append(
329 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
330 |
331 | np.random.shuffle(transforms)
332 | transform = torch_tr.Compose(transforms)
333 |
334 | return transform
335 |
336 | def __call__(self, img):
337 | """
338 | Args:
339 | img (PIL Image): Input image.
340 |
341 | Returns:
342 | PIL Image: Color jittered image.
343 | """
344 | transform = self.get_params(self.brightness, self.contrast,
345 | self.saturation, self.hue)
346 | return transform(img)
347 |
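348 | # ---------------------------------------------------------------------------
349 | # Illustrative sketch (not part of the original file): unlike joint_transforms,
350 | # these transforms act on a single PIL image. The 0.25 jitter strength mirrors
351 | # the --color_aug default in train.py; the input size is an arbitrary assumption.
352 | if __name__ == "__main__":
353 |     img = Image.new("RGB", (512, 256), color=(128, 64, 32))
354 |     jitter = ColorJitter(brightness=0.25, contrast=0.25,
355 |                          saturation=0.25, hue=0.25)
356 |     blurred = RandomGaussianBlur()(jitter(img))
357 |     print(blurred.size, blurred.mode)  # (512, 256) RGB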
--------------------------------------------------------------------------------
/utils/AttrDict.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code adapted from:
6 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py
7 |
8 | Source License
9 | # Copyright (c) 2017-present, Facebook, Inc.
10 | #
11 | # Licensed under the Apache License, Version 2.0 (the "License");
12 | # you may not use this file except in compliance with the License.
13 | # You may obtain a copy of the License at
14 | #
15 | # http://www.apache.org/licenses/LICENSE-2.0
16 | #
17 | # Unless required by applicable law or agreed to in writing, software
18 | # distributed under the License is distributed on an "AS IS" BASIS,
19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 | # See the License for the specific language governing permissions and
21 | # limitations under the License.
22 | ##############################################################################
23 | #
24 | # Based on:
25 | # --------------------------------------------------------
26 | # Fast R-CNN
27 | # Copyright (c) 2015 Microsoft
28 | # Licensed under The MIT License [see LICENSE for details]
29 | # Written by Ross Girshick
30 | # --------------------------------------------------------
31 | """
32 |
33 |
34 | class AttrDict(dict):
35 |
36 | IMMUTABLE = '__immutable__'
37 |
38 | def __init__(self, *args, **kwargs):
39 | super(AttrDict, self).__init__(*args, **kwargs)
40 | self.__dict__[AttrDict.IMMUTABLE] = False
41 |
42 | def __getattr__(self, name):
43 | if name in self.__dict__:
44 | return self.__dict__[name]
45 | elif name in self:
46 | return self[name]
47 | else:
48 | raise AttributeError(name)
49 |
50 | def __setattr__(self, name, value):
51 | if not self.__dict__[AttrDict.IMMUTABLE]:
52 | if name in self.__dict__:
53 | self.__dict__[name] = value
54 | else:
55 | self[name] = value
56 | else:
57 | raise AttributeError(
58 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
59 | format(name, value)
60 | )
61 |
62 | def immutable(self, is_immutable):
63 | """Set immutability to is_immutable and recursively apply the setting
64 | to all nested AttrDicts.
65 | """
66 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
67 | # Recursively set immutable state
68 | for v in self.__dict__.values():
69 | if isinstance(v, AttrDict):
70 | v.immutable(is_immutable)
71 | for v in self.values():
72 | if isinstance(v, AttrDict):
73 | v.immutable(is_immutable)
74 |
75 | def is_immutable(self):
76 | return self.__dict__[AttrDict.IMMUTABLE]
77 |
--------------------------------------------------------------------------------
/utils/f_boundary.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | # Code adapted from:
6 | # https://github.com/fperazzi/davis/blob/master/python/lib/davis/measures/f_boundary.py
7 | #
8 | # Source License
9 | #
10 | # BSD 3-Clause License
11 | #
12 | # Copyright (c) 2017,
13 | # All rights reserved.
14 | #
15 | # Redistribution and use in source and binary forms, with or without
16 | # modification, are permitted provided that the following conditions are met:
17 | #
18 | # * Redistributions of source code must retain the above copyright notice, this
19 | # list of conditions and the following disclaimer.
20 | #
21 | # * Redistributions in binary form must reproduce the above copyright notice,
22 | # this list of conditions and the following disclaimer in the documentation
23 | # and/or other materials provided with the distribution.
24 | #
25 | # * Neither the name of the copyright holder nor the names of its
26 | # contributors may be used to endorse or promote products derived from
27 | # this software without specific prior written permission.
28 | #
29 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
30 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
31 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
32 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
33 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
34 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
35 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
36 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
37 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
38 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
39 | ##############################################################################
40 | #
41 | # Based on:
42 | # ----------------------------------------------------------------------------
43 | # A Benchmark Dataset and Evaluation Methodology for Video Object Segmentation
44 | # Copyright (c) 2016 Federico Perazzi
45 | # Licensed under the BSD License [see LICENSE for details]
46 | # Written by Federico Perazzi
47 | # ----------------------------------------------------------------------------
48 | """
49 |
50 |
51 |
52 |
53 | import numpy as np
54 | from multiprocessing import Pool
55 | from tqdm import tqdm
56 |
57 | """ Utilities for computing, reading and saving benchmark evaluation."""
58 |
59 | def eval_mask_boundary(seg_mask, gt_mask, num_classes, num_proc=10, bound_th=0.008):
60 |     """
61 |     Compute per-class boundary F scores for a batch of segmentation masks.
62 | 
63 |     Arguments:
64 |         seg_mask (ndarray): segmentation mask prediction
65 |         gt_mask (ndarray): segmentation mask ground truth
66 |         num_classes (int): number of classes
67 | 
68 |     Returns:
69 |         Fpc (ndarray): per-class sum of F scores over the batch (NaNs counted as 0)
70 |         Fc (ndarray): per-class count of samples with a valid (non-NaN) F score
71 |     """
72 | p = Pool(processes=num_proc)
73 | batch_size = seg_mask.shape[0]
74 |
75 | Fpc = np.zeros(num_classes)
76 | Fc = np.zeros(num_classes)
77 | for class_id in tqdm(range(num_classes)):
78 | args = [((seg_mask[i] == class_id).astype(np.uint8),
79 | (gt_mask[i] == class_id).astype(np.uint8),
80 | gt_mask[i] == 255,
81 | bound_th)
82 | for i in range(batch_size)]
83 | temp = p.map(db_eval_boundary_wrapper, args)
84 | temp = np.array(temp)
85 | Fs = temp[:,0]
86 | _valid = ~np.isnan(Fs)
87 | Fc[class_id] = np.sum(_valid)
88 | Fs[np.isnan(Fs)] = 0
89 | Fpc[class_id] = sum(Fs)
90 | return Fpc, Fc
91 |
92 |
93 | #def db_eval_boundary_wrapper_wrapper(args):
94 | # seg_mask, gt_mask, class_id, batch_size, Fpc = args
95 | # print("class_id:" + str(class_id))
96 | # p = Pool(processes=10)
97 | # args = [((seg_mask[i] == class_id).astype(np.uint8),
98 | # (gt_mask[i] == class_id).astype(np.uint8))
99 | # for i in range(batch_size)]
100 | # Fs = p.map(db_eval_boundary_wrapper, args)
101 | # Fpc[class_id] = sum(Fs)
102 | # return
103 |
104 | def db_eval_boundary_wrapper(args):
105 | foreground_mask, gt_mask, ignore, bound_th = args
106 | return db_eval_boundary(foreground_mask, gt_mask,ignore, bound_th)
107 |
108 | def db_eval_boundary(foreground_mask, gt_mask, ignore_mask, bound_th=0.008):
109 |     """
110 |     Compute the boundary F-measure for a single frame.
111 |     Calculates precision/recall for boundaries between foreground_mask and
112 |     gt_mask using morphological operators to speed it up.
113 | 
114 |     Arguments:
115 |         foreground_mask (ndarray): binary segmentation image.
116 |         gt_mask (ndarray): binary annotated image.
117 |         ignore_mask (ndarray): boolean mask of pixels excluded from evaluation.
118 | 
119 |     Returns:
120 |         F (float): boundaries F-measure
121 |         P (float): boundaries precision
122 |     """
123 | assert np.atleast_3d(foreground_mask).shape[2] == 1
124 |
125 | bound_pix = bound_th if bound_th >= 1 else \
126 | np.ceil(bound_th*np.linalg.norm(foreground_mask.shape))
127 |
128 | #print(bound_pix)
129 | #print(gt.shape)
130 | #print(np.unique(gt))
131 | foreground_mask[ignore_mask] = 0
132 | gt_mask[ignore_mask] = 0
133 |
134 | # Get the pixel boundaries of both masks
135 |     fg_boundary = seg2bmap(foreground_mask)
136 |     gt_boundary = seg2bmap(gt_mask)
137 |
138 | from skimage.morphology import binary_dilation,disk
139 |
140 | fg_dil = binary_dilation(fg_boundary,disk(bound_pix))
141 | gt_dil = binary_dilation(gt_boundary,disk(bound_pix))
142 |
143 | # Get the intersection
144 | gt_match = gt_boundary * fg_dil
145 | fg_match = fg_boundary * gt_dil
146 |
147 | # Area of the intersection
148 | n_fg = np.sum(fg_boundary)
149 | n_gt = np.sum(gt_boundary)
150 |
151 | #% Compute precision and recall
152 | if n_fg == 0 and n_gt > 0:
153 | precision = 1
154 | recall = 0
155 | elif n_fg > 0 and n_gt == 0:
156 | precision = 0
157 | recall = 1
158 | elif n_fg == 0 and n_gt == 0:
159 | precision = 1
160 | recall = 1
161 | else:
162 | precision = np.sum(fg_match)/float(n_fg)
163 | recall = np.sum(gt_match)/float(n_gt)
164 |
165 | # Compute F measure
166 | if precision + recall == 0:
167 | F = 0
168 | else:
169 |         F = 2 * precision * recall / (precision + recall)
170 |
171 | return F, precision
172 |
173 | def seg2bmap(seg,width=None,height=None):
174 | """
175 | From a segmentation, compute a binary boundary map with 1 pixel wide
176 | boundaries. The boundary pixels are offset by 1/2 pixel towards the
177 | origin from the actual segment boundary.
178 |
179 | Arguments:
180 | seg : Segments labeled from 1..k.
181 | width : Width of desired bmap <= seg.shape[1]
182 | height : Height of desired bmap <= seg.shape[0]
183 |
184 | Returns:
185 | bmap (ndarray): Binary boundary map.
186 |
187 | David Martin
188 | January 2003
189 | """
190 |
191 | seg = seg.astype(np.bool)
192 | seg[seg>0] = 1
193 |
194 | assert np.atleast_3d(seg).shape[2] == 1
195 |
196 | width = seg.shape[1] if width is None else width
197 | height = seg.shape[0] if height is None else height
198 |
199 | h,w = seg.shape[:2]
200 |
201 | ar1 = float(width) / float(height)
202 | ar2 = float(w) / float(h)
203 |
204 |     assert not (width > w or height > h or abs(ar1 - ar2) > 0.01), \
205 |         "Can't convert %dx%d seg to %dx%d bmap." % (w, h, width, height)
206 |
207 | e = np.zeros_like(seg)
208 | s = np.zeros_like(seg)
209 | se = np.zeros_like(seg)
210 |
211 | e[:,:-1] = seg[:,1:]
212 | s[:-1,:] = seg[1:,:]
213 | se[:-1,:-1] = seg[1:,1:]
214 |
215 | b = seg^e | seg^s | seg^se
216 | b[-1,:] = seg[-1,:]^e[-1,:]
217 | b[:,-1] = seg[:,-1]^s[:,-1]
218 | b[-1,-1] = 0
219 |
220 | if w == width and h == height:
221 | bmap = b
222 | else:
223 | bmap = np.zeros((height,width))
224 | for x in range(w):
225 | for y in range(h):
226 | if b[y,x]:
227 |                     j = int(np.floor(y * height / h))  # rescale row index to bmap size
228 |                     i = int(np.floor(x * width / w))   # rescale column index to bmap size
229 |                     bmap[j, i] = 1
230 |
231 | return bmap
232 |
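233 | 
234 | # Minimal usage sketch (illustrative; shapes, class count and batch size are placeholders):
235 | #
236 | #   pred = np.random.randint(0, 19, size=(2, 64, 128))   # (batch, H, W) predicted labels
237 | #   gt   = np.random.randint(0, 19, size=(2, 64, 128))   # ground-truth labels, 255 = ignore
238 | #   Fpc, Fc = eval_mask_boundary(pred, gt, num_classes=19)
239 | #   per_class_f = Fpc / Fc     # Fc counts samples with a valid (non-NaN) score per class
240 | #   mean_f = np.nanmean(per_class_f)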
--------------------------------------------------------------------------------
/utils/image_page.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import glob
7 | import os
8 |
9 | class ImagePage(object):
10 | '''
11 | This creates an HTML page of embedded images, useful for showing evaluation results.
12 |
13 |     Usage:
14 |     ip = ImagePage(experiment_name, html_fn)
15 | 
16 |     # Add a table with N images ...
17 |     ip.add_table([(img, descr), (img, descr), ...])
18 |
19 | # Generate html page
20 | ip.write_page()
21 | '''
22 | def __init__(self, experiment_name, html_filename):
23 | self.experiment_name = experiment_name
24 | self.html_filename = html_filename
25 | self.outfile = open(self.html_filename, 'w')
26 | self.items = []
27 |
28 |     def _print_header(self):
29 |         header = '''<html>
30 | <body>
31 | 
32 | <h2>Experiment = {}</h2>
33 | 
34 | '''.format(self.experiment_name)
35 |         self.outfile.write(header)
36 | 
37 |     def _print_footer(self):
38 |         self.outfile.write('''
39 | </body></html>''')
40 | 
41 |     def _print_table_header(self, table_name):
42 |         table_hdr = '''<h3>{}</h3>
43 | <table border="1">
44 | <tr>'''.format(table_name)
45 |         self.outfile.write(table_hdr)
46 | 
47 |     def _print_table_footer(self):
48 |         table_ftr = '''</tr>
49 | </table>
50 | '''
51 |         self.outfile.write(table_ftr)
52 | 
53 |     def _print_table_guts(self, img_fn, descr):
54 |         table = '''<td valign="top">
55 | <a href="{img_fn}"><img src="{img_fn}"></a>
56 | <br>
57 | {descr}
58 | <br>
59 | </td>
60 | '''.format(img_fn=img_fn, descr=descr)
61 |         self.outfile.write(table)
62 |
63 | def add_table(self, img_label_pairs):
64 | self.items.append(img_label_pairs)
65 |
66 | def _write_table(self, table):
67 | img, _descr = table[0]
68 | self._print_table_header(os.path.basename(img))
69 | for img, descr in table:
70 | self._print_table_guts(img, descr)
71 | self._print_table_footer()
72 |
73 | def write_page(self):
74 | self._print_header()
75 |
76 | for table in self.items:
77 | self._write_table(table)
78 |
79 | self._print_footer()
80 |
81 |
82 | def main():
83 | images = glob.glob('dump_imgs_train/*.png')
84 | images = [i for i in images if 'mask' not in i]
85 |
86 | ip = ImagePage('test page', 'dd.html')
87 | for img in images:
88 | basename = os.path.splitext(img)[0]
89 | mask_img = basename + '_mask.png'
90 | ip.add_table(((img, 'image'), (mask_img, 'mask')))
91 | ip.write_page()
92 |
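93 | 
94 | # Allow standalone use (e.g. `python utils/image_page.py`); main() above assumes a
95 | # dump_imgs_train/ directory of image/_mask.png pairs.
96 | if __name__ == '__main__':
97 |     main()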
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import sys
7 | import re
8 | import os
9 | import shutil
10 | import torch
11 | from datetime import datetime
12 | import logging
13 | from subprocess import call
14 | import shlex
15 | from tensorboardX import SummaryWriter
16 | import numpy as np
17 | from utils.image_page import ImagePage
18 | import torchvision.transforms as standard_transforms
19 | import torchvision.utils as vutils
20 | from PIL import Image
21 |
22 | # Create unique output dir name based on non-default command line args
23 | def make_exp_name(args, parser):
24 | exp_name = '{}-{}'.format(args.dataset[:3], args.arch[:])
25 | dict_args = vars(args)
26 |
27 | # sort so that we get a consistent directory name
28 | argnames = sorted(dict_args)
29 |
30 | # build experiment name with non-default args
31 | for argname in argnames:
32 | if dict_args[argname] != parser.get_default(argname):
33 | if argname == 'exp' or argname == 'arch' or argname == 'prev_best_filepath':
34 | continue
35 | if argname == 'snapshot':
36 | arg_str = '-PT'
37 | elif argname == 'nosave':
38 | arg_str = ''
39 | argname=''
40 | elif argname == 'freeze_trunk':
41 | argname = ''
42 | arg_str = '-fr'
43 | elif argname == 'syncbn':
44 | argname = ''
45 | arg_str = '-sbn'
46 | elif argname == 'relaxedloss':
47 | argname = ''
48 | arg_str = 're-loss'
49 | elif isinstance(dict_args[argname], bool):
50 | arg_str = 'T' if dict_args[argname] else 'F'
51 | else:
52 | arg_str = str(dict_args[argname])[:6]
53 | exp_name += '-{}_{}'.format(str(argname), arg_str)
54 | # clean special chars out
55 | exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name)
56 |     # exp_name = 'testing'  # uncomment to force a fixed output dir while debugging
57 | return exp_name
58 |
59 |
60 | def save_log(prefix, output_dir, date_str):
61 | fmt = '%(asctime)s.%(msecs)03d %(message)s'
62 | date_fmt = '%m-%d %H:%M:%S'
63 | filename = os.path.join(output_dir, prefix + '_' + date_str + '.log')
64 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt,
65 | filename=filename, filemode='w')
66 | console = logging.StreamHandler()
67 | console.setLevel(logging.INFO)
68 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt)
69 | console.setFormatter(formatter)
70 | logging.getLogger('').addHandler(console)
71 | #logging.basicConfig(level=logging.INFO, format='%(asctime)s.%(msecs)03d %(levelname)s:\t%(message)s', datefmt='%Y-%m-%d %H:%M:%S')
72 |
73 |
74 | def save_code(exp_path, date_str):
75 | code_root = '.' # FIXME!
76 | zip_outfile = os.path.join(exp_path, 'code_{}.tgz'.format(date_str))
77 | print('Saving code to {}'.format(zip_outfile))
78 | cmd = 'tar -czvf {zip_outfile} --exclude=\'*.pyc\' --exclude=\'*.png\' ' +\
79 | '--exclude=\'*tfevents*\' {root}/train.py ' + \
80 | ' {root}/utils {root}/datasets {root}/models'
81 | cmd = cmd.format(zip_outfile=zip_outfile, root=code_root)
82 | call(shlex.split(cmd), stdout=open(os.devnull, 'wb'))
83 |
84 |
85 | def prep_experiment(args, parser):
86 | '''
87 | Make output directories, setup logging, Tensorboard, snapshot code.
88 | '''
89 | ckpt_path = args.ckpt
90 | tb_path = args.tb_path
91 | exp_name = make_exp_name(args, parser)
92 | args.exp_path = os.path.join(ckpt_path, args.exp, exp_name)
93 | args.tb_exp_path = os.path.join(tb_path, args.exp, exp_name)
94 | args.ngpu = torch.cuda.device_count()
95 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
96 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
97 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
98 | args.last_record = {}
99 | os.makedirs(args.exp_path, exist_ok=True)
100 | os.makedirs(args.tb_exp_path, exist_ok=True)
101 | save_log('log', args.exp_path, args.date_str)
102 | #save_code(args.exp_path, args.date_str)
103 |     with open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w') as f:
104 |         f.write(str(args) + '\n\n')
105 |
106 | writer = SummaryWriter(logdir=args.tb_exp_path, comment=args.tb_tag)
107 | return writer
108 |
109 | class AverageMeter(object):
110 |
111 | def __init__(self):
112 | self.reset()
113 |
114 | def reset(self):
115 | self.val = 0
116 | self.avg = 0
117 | self.sum = 0
118 | self.count = 0
119 |
120 | def update(self, val, n=1):
121 | self.val = val
122 | self.sum += val * n
123 | self.count += n
124 | self.avg = self.sum / self.count
125 |
126 |
127 |
128 | def evaluate_eval(args, net, optimizer, val_loss, mf_score, hist, dump_images, heatmap_images, writer, epoch=0, dataset=None):
129 |     '''
130 |     Modified IoU mechanism for on-the-fly IoU calculations (prevents memory overflow for
131 |     large datasets). Only applies to eval/eval.py.
132 |     '''
133 | # axis 0: gt, axis 1: prediction
134 | acc = np.diag(hist).sum() / hist.sum()
135 | acc_cls = np.diag(hist) / hist.sum(axis=1)
136 | acc_cls = np.nanmean(acc_cls)
137 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
138 |
139 | print_evaluate_results(hist, iu, writer, epoch, dataset)
140 | freq = hist.sum(axis=1) / hist.sum()
141 | mean_iu = np.nanmean(iu)
142 | logging.info('mean {}'.format(mean_iu))
143 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
144 | #return acc, acc_cls, mean_iu, fwavacc
145 |
146 | # update latest snapshot
147 | if 'mean_iu' in args.last_record:
148 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(
149 | args.last_record['epoch'], args.last_record['mean_iu'])
150 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
151 | try:
152 | os.remove(last_snapshot)
153 | except OSError:
154 | pass
155 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(epoch, mean_iu)
156 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
157 | args.last_record['mean_iu'] = mean_iu
158 | args.last_record['epoch'] = epoch
159 |
160 | torch.cuda.synchronize()
161 |
162 | torch.save({
163 | 'state_dict': net.state_dict(),
164 | 'optimizer': optimizer.state_dict(),
165 | 'epoch': epoch,
166 | 'mean_iu': mean_iu,
167 | 'command': ' '.join(sys.argv[1:])
168 | }, last_snapshot)
169 |
170 | # update best snapshot
171 | if mean_iu > args.best_record['mean_iu'] :
172 | # remove old best snapshot
173 | if args.best_record['epoch'] != -1:
174 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format(
175 | args.best_record['epoch'], args.best_record['mean_iu'])
176 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
177 | assert os.path.exists(best_snapshot), \
178 | 'cant find old snapshot {}'.format(best_snapshot)
179 | os.remove(best_snapshot)
180 |
181 |
182 | # save new best
183 | args.best_record['val_loss'] = val_loss.avg
184 | args.best_record['mask_f1_score'] = mf_score.avg
185 | args.best_record['epoch'] = epoch
186 | args.best_record['acc'] = acc
187 | args.best_record['acc_cls'] = acc_cls
188 | args.best_record['mean_iu'] = mean_iu
189 | args.best_record['fwavacc'] = fwavacc
190 |
191 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format(
192 | args.best_record['epoch'], args.best_record['mean_iu'])
193 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
194 | shutil.copyfile(last_snapshot, best_snapshot)
195 |
196 |
197 | to_save_dir = os.path.join(args.exp_path, 'best_images')
198 | os.makedirs(to_save_dir, exist_ok=True)
199 | ip = ImagePage(epoch, '{}/index.html'.format(to_save_dir))
200 |
201 | val_visual = []
202 |
203 | idx = 0
204 |
205 | visualize = standard_transforms.Compose([
206 | standard_transforms.Scale(384),
207 | standard_transforms.ToTensor()
208 | ])
209 | for bs_idx, bs_data in enumerate(dump_images):
210 | for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])):
211 | gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy())
212 | pred = data[1].cpu().numpy()
213 | predictions_pil = args.dataset_cls.colorize_mask(pred)
214 | img_name = data[2]
215 |
216 | prediction_fn = '{}_prediction.png'.format(img_name)
217 | predictions_pil.save(os.path.join(to_save_dir, prediction_fn))
218 | gt_fn = '{}_gt.png'.format(img_name)
219 | gt_pil.save(os.path.join(to_save_dir, gt_fn))
220 | ip.add_table([(gt_fn, 'gt'), (prediction_fn, 'prediction')])
221 | val_visual.extend([visualize(gt_pil.convert('RGB')),
222 | visualize(predictions_pil.convert('RGB'))])
223 | idx = idx+1
224 | if idx >= 9:
225 | ip.write_page()
226 | break
227 | for bs_idx, bs_data in enumerate(heatmap_images):
228 | for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])):
229 |
230 | gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy())
231 |
232 | predictions_pil = data[1].cpu().numpy()
233 | predictions_pil = (predictions_pil / predictions_pil.max()) * 255
234 | predictions_pil = Image.fromarray(predictions_pil.astype(np.uint8))
235 | img_name = data[2]
236 |
237 | prediction_fn = '{}_prediction.png'.format(img_name)
238 | predictions_pil.save(os.path.join(to_save_dir, prediction_fn))
239 | gt_fn = '{}_gt.png'.format(img_name)
240 | gt_pil.save(os.path.join(to_save_dir, gt_fn))
241 | ip.add_table([(gt_fn, 'gt'), (prediction_fn, 'prediction')])
242 | val_visual.extend([visualize(gt_pil.convert('RGB')),
243 | visualize(predictions_pil.convert('RGB'))])
244 | idx = idx+1
245 | if idx >= 9:
246 | ip.write_page()
247 | break
248 |
249 | val_visual = torch.stack(val_visual, 0)
250 | val_visual = vutils.make_grid(val_visual, nrow=10, padding=5)
251 | writer.add_image(last_snapshot, val_visual)
252 |
253 | logging.info('-' * 107)
254 | fmt_str = '[epoch %d], [val loss %.5f], [mask f1 %.5f], [acc %.5f], [acc_cls %.5f], ' +\
255 | '[mean_iu %.5f], [fwavacc %.5f]'
256 | logging.info(fmt_str % (epoch, val_loss.avg, mf_score.avg, acc, acc_cls, mean_iu, fwavacc))
257 | fmt_str = 'best record: [val loss %.5f], [mask f1 %.5f], [acc %.5f], [acc_cls %.5f], ' +\
258 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], '
259 | logging.info(fmt_str % (args.best_record['val_loss'], args.best_record['mask_f1_score'],
260 | args.best_record['acc'],
261 | args.best_record['acc_cls'], args.best_record['mean_iu'],
262 | args.best_record['fwavacc'], args.best_record['epoch']))
263 | logging.info('-' * 107)
264 |
265 | # tensorboard logging of validation phase metrics
266 |
267 | writer.add_scalar('training/acc', acc, epoch)
268 | writer.add_scalar('training/acc_cls', acc_cls, epoch)
269 | writer.add_scalar('training/mean_iu', mean_iu, epoch)
270 | writer.add_scalar('training/val_loss', val_loss.avg, epoch)
271 | writer.add_scalar('training/mask_f1_score', mf_score.avg, epoch)
272 |
273 |
274 | def fast_hist(label_pred, label_true, num_classes):
275 | mask = (label_true >= 0) & (label_true < num_classes)
276 | hist = np.bincount(
277 | num_classes * label_true[mask].astype(int) +
278 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
279 | return hist
280 |
281 |
282 |
283 | def print_evaluate_results(hist, iu, writer=None, epoch=0, dataset=None):
284 | try:
285 | id2cat = dataset.id2cat
286 |     except AttributeError:
287 | id2cat = {i: i for i in range(dataset.num_classes)}
288 | iu_false_positive = hist.sum(axis=1) - np.diag(hist)
289 | iu_false_negative = hist.sum(axis=0) - np.diag(hist)
290 | iu_true_positive = np.diag(hist)
291 |
292 | logging.info('IoU:')
293 | logging.info('label_id label iU Precision Recall TP FP FN')
294 | for idx, i in enumerate(iu):
295 | idx_string = "{:2d}".format(idx)
296 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
297 | iu_string = '{:5.2f}'.format(i * 100)
298 | total_pixels = hist.sum()
299 | tp = '{:5.2f}'.format(100 * iu_true_positive[idx] / total_pixels)
300 | fp = '{:5.2f}'.format(
301 | iu_false_positive[idx] / iu_true_positive[idx])
302 | fn = '{:5.2f}'.format(iu_false_negative[idx] / iu_true_positive[idx])
303 | precision = '{:5.2f}'.format(
304 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx]))
305 | recall = '{:5.2f}'.format(
306 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx]))
307 | logging.info('{} {} {} {} {} {} {} {}'.format(
308 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn))
309 |
310 | writer.add_scalar('val_class_iu/{}'.format(id2cat[idx]), i * 100, epoch)
311 | writer.add_scalar('val_class_precision/{}'.format(id2cat[idx]),
312 | iu_true_positive[idx] / (iu_true_positive[idx] +
313 | iu_false_positive[idx]), epoch)
314 | writer.add_scalar('val_class_recall/{}'.format(id2cat[idx]),
315 | iu_true_positive[idx] / (iu_true_positive[idx] +
316 | iu_false_negative[idx]), epoch)
317 |
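318 | 
319 | # Minimal sketch of the confusion-matrix bookkeeping used by evaluate_eval above
320 | # (illustrative only; `preds`/`gts` stand in for iterables of HxW label arrays):
321 | #
322 | #   hist = np.zeros((num_classes, num_classes))
323 | #   for pred, gt in zip(preds, gts):
324 | #       hist += fast_hist(pred.flatten(), gt.flatten(), num_classes)
325 | #   iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
326 | #   mean_iu = np.nanmean(iu)   # the mean IoU that evaluate_eval logs and snapshots on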
--------------------------------------------------------------------------------