├── README.md
├── assets
│   └── comparison.png
├── config.py
├── datasets
│   ├── __init__.py
│   ├── bdd100k.py
│   ├── cityscapes.py
│   ├── cityscapes_labels.py
│   ├── gtav.py
│   ├── imagenet.py
│   ├── mapillary.py
│   ├── multi_loader.py
│   ├── nullloader.py
│   ├── sampler.py
│   ├── synthia.py
│   └── uniform.py
├── loss.py
├── network
│   ├── Resnet.py
│   ├── __init__.py
│   ├── adain.py
│   ├── cel.py
│   ├── deepv3.py
│   └── mynn.py
├── optimizer.py
├── scripts
│   ├── train_wildnet_r50os16_gtav.sh
│   └── valid_wildnet_r50os16_gtav.sh
├── split_data
│   ├── gtav_split_test.txt
│   ├── gtav_split_train.txt
│   ├── gtav_split_val.txt
│   ├── synthia_split_train.txt
│   └── synthia_split_val.txt
├── train.py
├── transforms
│   ├── __init__.py
│   ├── joint_transforms.py
│   └── transforms.py
├── utils
│   ├── __init__.py
│   ├── attr_dict.py
│   ├── misc.py
│   └── my_data_parallel.py
└── valid.py
/README.md:
--------------------------------------------------------------------------------
1 | ## WildNet (CVPR 2022): Official Project Webpage
2 | This repository provides the official PyTorch implementation of the following paper:
3 | > [**WildNet: Learning Domain Generalized Semantic Segmentation from the Wild**](https://arxiv.org/abs/2204.01446)
4 | > [Suhyeon Lee](https://suhyeonlee.github.io/), [Hongje Seong](https://hongje.github.io/), [Seongwon Lee](https://sungonce.github.io/), [Euntai Kim](https://cilab.yonsei.ac.kr/)
5 | > Yonsei University
6 |
7 | > **Abstract:**
8 | *We present a new domain generalized semantic segmentation network named WildNet, which learns domain-generalized features by leveraging a variety of contents and styles from the wild.
9 | In domain generalization, the low generalization ability for unseen target domains is clearly due to overfitting to the source domain.
10 | To address this problem, previous works have focused on generalizing the domain by removing or diversifying the styles of the source domain.
11 | These alleviated overfitting to the source-style but overlooked overfitting to the source-content.
12 | In this paper, we propose to diversify both the content and style of the source domain with the help of the wild.
13 | Our main idea is for networks to naturally learn domain-generalized semantic information from the wild.
14 | To this end, we diversify styles by augmenting source features to resemble wild styles and enable networks to adapt to a variety of styles.
15 | Furthermore, we encourage networks to learn class-discriminant features by providing semantic variations borrowed from the wild to source contents in the feature space.
16 | Finally, we regularize networks to capture consistent semantic information even when both the content and style of the source domain are extended to the wild.
17 | Extensive experiments on five different datasets validate the effectiveness of our WildNet, and we significantly outperform state-of-the-art methods.*
18 |
19 |
20 |
21 |
22 |
23 | ## PyTorch Implementation
24 | Our PyTorch implementation is heavily derived from [RobustNet](https://github.com/shachoi/RobustNet) (CVPR 2021). If you use this code in your research, please also cite their work.
25 | [[link to license](https://github.com/shachoi/RobustNet/blob/main/LICENSE)]
26 |
27 | ### Installation
28 | Clone this repository.
29 | ```
30 | git clone https://github.com/suhyeonlee/WildNet.git
31 | cd WildNet
32 | ```
33 | Install the following packages.
34 | ```
35 | conda create --name wildnet python=3.7
36 | conda activate wildnet
37 | conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.1 -c pytorch
38 | conda install scipy==1.1.0
39 | conda install tqdm==4.46.0
40 | conda install scikit-image==0.16.2
41 | pip install tensorboardX
42 | pip install thop
43 | pip install kmeans1d
44 | imageio_download_bin freeimage
45 | ```
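To check that the environment was set up correctly, you can optionally run the following in a Python shell (a quick sanity check; the expected version follows from the install commands above):
```
import torch
print(torch.__version__)          # expect 1.9.1
print(torch.cuda.is_available())  # expect True on a CUDA-capable machine
```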
46 | ### How to Run WildNet
47 | We trained our model with the source domain ([GTAV](https://download.visinf.tu-darmstadt.de/data/from_games/) or [Cityscapes](https://www.cityscapes-dataset.com/)) and the wild domain ([ImageNet](https://www.image-net.org/)).
48 | Then we evaluated the model on [Cityscapes](https://www.cityscapes-dataset.com/), [BDD-100K](https://bair.berkeley.edu/blog/2018/05/30/bdd/), [Synthia](https://synthia-dataset.net/downloads/) ([SYNTHIA-RAND-CITYSCAPES](http://synthia-dataset.net/download/808/)), [GTAV](https://download.visinf.tu-darmstadt.de/data/from_games/) and [Mapillary Vistas](https://www.mapillary.com/dataset/vistas?pKey=2ix3yvnjy9fwqdzwum3t9g&lat=20&lng=0&z=1.5).
49 |
50 | We adopt class uniform sampling proposed in [this paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhu_Improving_Semantic_Segmentation_via_Video_Propagation_and_Label_Relaxation_CVPR_2019_paper.pdf) to handle the class imbalance problem.
51 |
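The idea, in a minimal Python sketch (not the repo's exact `datasets/uniform.py` implementation): a fraction of each training epoch is drawn evenly across classes from precomputed per-class centroid records, and the rest comes from the regular image list.
```
import random

def build_epoch_sketch(imgs, centroids, num_classes, pct):
    # imgs: list of (img_path, mask_path)
    # centroids: {class_id: [(img_path, mask_path, centroid, class_id), ...]}
    num_uniform = int(len(imgs) * pct)
    per_class = num_uniform // num_classes
    epoch = random.sample(imgs, len(imgs) - num_uniform)   # regular, randomly chosen samples
    for class_id in range(num_classes):
        records = centroids.get(class_id, [])
        if records:                                          # oversample each class via its centroid crops
            epoch.extend(random.choices(records, k=per_class))
    random.shuffle(epoch)
    return epoch
```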
52 |
53 | 1. For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/ and unzip them.
54 | Make the directory structures for Cityscapes, BDD-100K, Mapillary, and ImageNet as follows.
55 | ```
56 | cityscapes
57 | └ leftImg8bit_trainvaltest
58 | └ leftImg8bit
59 | └ train
60 | └ val
61 | └ test
62 | └ gtFine_trainvaltest
63 | └ gtFine
64 | └ train
65 | └ val
66 | └ test
67 | ```
68 | ```
69 | bdd-100k
70 | └ images
71 | └ train
72 | └ val
73 | └ test
74 | └ labels
75 | └ train
76 | └ val
77 | ```
78 | ```
79 | mapillary
80 | └ training
81 | └ images
82 | └ labels
83 | └ validation
84 | └ images
85 | └ labels
86 | └ test
87 | └ images
88 | └ labels
89 | ```
90 | ```
91 | imagenet
92 | └ data
93 | └ train
94 | └ val
95 | ```
96 |
97 | 2. We used [GTAV_Split](https://download.visinf.tu-darmstadt.de/data/from_games/code/read_mapping.zip) to split the GTAV dataset into training/validation/test sets. Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data); a sketch of how the files could be arranged appears after the directory layout below.
98 |
99 | ```
100 | GTAV
101 | └ images
102 | └ train
103 | └ folder
104 | └ valid
105 | └ folder
106 | └ test
107 | └ folder
108 | └ labels
109 | └ train
110 | └ folder
111 | └ valid
112 | └ folder
113 | └ test
114 | └ folder
115 | ```
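A hypothetical helper for arranging the downloaded GTAV files into the layout above. It assumes (unverified here) that each `split_data/gtav_split_*.txt` lists one file name per line and that images and labels share file names; check the actual split files before using it.
```
import os
import shutil

def organize_gtav(gtav_root, split_dir):
    # Map each split file to its target sub-directory in the layout above.
    for split_txt, subdir in [('train', 'train'), ('val', 'valid'), ('test', 'test')]:
        with open(os.path.join(split_dir, 'gtav_split_{}.txt'.format(split_txt))) as f:
            names = [line.strip() for line in f if line.strip()]
        for kind in ('images', 'labels'):
            dst_dir = os.path.join(gtav_root, kind, subdir, 'folder')
            os.makedirs(dst_dir, exist_ok=True)
            for name in names:
                src = os.path.join(gtav_root, kind, name)
                if os.path.isfile(src):
                    shutil.move(src, os.path.join(dst_dir, name))
```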
116 |
117 | 3. We split the [Synthia dataset](http://synthia-dataset.net/download/808/) into train/val sets following [RobustNet](https://github.com/shachoi/RobustNet). Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data).
118 |
119 | ```
120 | synthia
121 | └ RGB
122 | └ train
123 | └ val
124 | └ GT
125 | └ COLOR
126 | └ train
127 | └ val
128 | └ LABELS
129 | └ train
130 | └ val
131 | ```
132 |
133 | 4. You should modify the dataset paths in **"/config.py"** according to where your datasets are located.
134 | ```
135 | #Cityscapes Dir Location
136 | __C.DATASET.CITYSCAPES_DIR =
137 | #Mapillary Dataset Dir Location
138 | __C.DATASET.MAPILLARY_DIR =
139 | #GTAV Dataset Dir Location
140 | __C.DATASET.GTAV_DIR =
141 | #BDD-100K Dataset Dir Location
142 | __C.DATASET.BDD_DIR =
143 | #Synthia Dataset Dir Location
144 | __C.DATASET.SYNTHIA_DIR =
145 | #ImageNet Dataset Dir Location
146 | __C.DATASET.IMAGENET_DIR =
147 | ```
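For example, if your datasets were stored under `/data` (hypothetical paths; substitute your own locations), the entries could look like:
```
__C.DATASET.CITYSCAPES_DIR = '/data/cityscapes'
__C.DATASET.MAPILLARY_DIR = '/data/mapillary'
__C.DATASET.GTAV_DIR = '/data/gtav'
__C.DATASET.BDD_DIR = '/data/bdd100k/seg'
__C.DATASET.SYNTHIA_DIR = '/data/synthia'
__C.DATASET.IMAGENET_DIR = '/data/imagenet/data'
```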
148 |
149 | 5. You can train WildNet with the following command.
150 | ```
151 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/train_wildnet_r50os16_gtav.sh
152 | ```
153 |
154 | 6. You can download our ResNet-50 model from [Google Drive](https://drive.google.com/file/d/16V_cWtVbJuJoQ-DJq4Yui7Umf4wggemu/view) and validate the pretrained model with the following command.
155 | ```
156 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/valid_wildnet_r50os16_gtav.sh
157 | ```
158 |
159 | ## Citation
160 | If you find this work useful in your research, please cite our paper:
161 | ```
162 | @inproceedings{lee2022wildnet,
163 | title={WildNet: Learning Domain Generalized Semantic Segmentation from the Wild},
164 | author={Lee, Suhyeon and Seong, Hongje and Lee, Seongwon and Kim, Euntai},
165 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
166 | year={2022}
167 | }
168 | ```
169 |
170 | ## Terms of Use
171 | This software is for non-commercial use only.
172 | The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) license
173 | (see [this page](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details).
174 |
175 |
--------------------------------------------------------------------------------
/assets/comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suhyeonlee/WildNet/c83aa92b7cd591512045fc5093da25a280bde430/assets/comparison.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 | ##############################################################################
30 | #Config
31 | ##############################################################################
32 |
33 |
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 | from __future__ import unicode_literals
38 |
39 |
40 | import torch
41 |
42 |
43 | from utils.attr_dict import AttrDict
44 |
45 |
46 | __C = AttrDict()
47 | cfg = __C
48 | __C.ITER = 0
49 | __C.EPOCH = 0
50 |
51 | __C.RANDOM_SEED = 304
52 | # Use Class Uniform Sampling to give each class proper sampling
53 | __C.CLASS_UNIFORM_PCT = 0.0
54 |
55 | # Use class-weighted loss per batch to increase the loss for classes with low pixel counts
56 | __C.BATCH_WEIGHTING = False
57 |
58 | # Border Relaxation Count
59 | __C.BORDER_WINDOW = 1
60 | # Number of epochs to use before turning off border restriction
61 | __C.REDUCE_BORDER_ITER = -1
62 | __C.REDUCE_BORDER_EPOCH = -1
63 | # Comma-separated list of class IDs to relax
64 | __C.STRICTBORDERCLASS = None
65 |
66 | #Attribute Dictionary for Dataset
67 | __C.DATASET = AttrDict()
68 | #Cityscapes Dir Location
69 | __C.DATASET.CITYSCAPES_DIR = '/mnt/hdd2/DGSS/cityscapes'
70 | #SDC Augmented Cityscapes Dir Location
71 | __C.DATASET.CITYSCAPES_AUG_DIR = ''
72 | #Mapillary Dataset Dir Location
73 | __C.DATASET.MAPILLARY_DIR = '/mnt/hdd2/DGSS/mapillary'
74 | #GTAV Dataset Dir Location
75 | __C.DATASET.GTAV_DIR = '/mnt/hdd2/DGSS/gta5'
76 | #BDD-100K Dataset Dir Location
77 | __C.DATASET.BDD_DIR = '/mnt/hdd2/DGSS/bdd100k/seg'
78 | #Synthia Dataset Dir Location
79 | __C.DATASET.SYNTHIA_DIR = '/mnt/hdd2/DGSS/synthia'
80 | #ImageNet Dataset Dir Location
81 | __C.DATASET.IMAGENET_DIR ='/mnt/hdd2/ImageNet/data'
82 | #Kitti Dataset Dir Location
83 | __C.DATASET.KITTI_DIR = ''
84 | #SDC Augmented Kitti Dataset Dir Location
85 | __C.DATASET.KITTI_AUG_DIR = ''
86 | #Camvid Dataset Dir Location
87 | __C.DATASET.CAMVID_DIR = ''
88 | #Number of splits to support
89 | __C.DATASET.CV_SPLITS = 3
90 |
91 |
92 | __C.MODEL = AttrDict()
93 | __C.MODEL.BN = 'pytorch-syncnorm'
94 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm
95 |
96 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True):
97 | """Call this function in your script after you have finished setting all cfg
98 | values that are necessary (e.g., merging a config from a file, merging
99 | command line config options, etc.). By default, this function will also
100 | mark the global cfg as immutable to prevent changing the global cfg settings
101 | during script execution (which can lead to hard to debug errors or code
102 | that's harder to understand than is necessary).
103 | """
104 |
105 | if hasattr(args, 'syncbn') and args.syncbn:
106 | __C.MODEL.BN = 'pytorch-syncnorm'
107 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm
108 | print('Using pytorch sync batch norm')
109 | else:
110 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d
111 | print('Using regular batch norm')
112 |
113 | if not train_mode:
114 | cfg.immutable(True)
115 | return
116 | if args.class_uniform_pct:
117 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct
118 |
119 | if args.batch_weighting:
120 | __C.BATCH_WEIGHTING = True
121 |
122 | if args.jointwtborder:
123 | if args.strict_bdr_cls != '':
124 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")]
125 | if args.rlx_off_iter > -1:
126 | __C.REDUCE_BORDER_ITER = args.rlx_off_iter
127 |
128 | if make_immutable:
129 | cfg.immutable(True)
130 |
--------------------------------------------------------------------------------
/datasets/bdd100k.py:
--------------------------------------------------------------------------------
1 | """
2 | BDD100K Dataset Loader
3 | """
4 | import logging
5 | import json
6 | import os
7 | import numpy as np
8 | from PIL import Image
9 | from skimage import color
10 |
11 | from torch.utils import data
12 | import torch
13 | import torchvision.transforms as transforms
14 | import datasets.uniform as uniform
15 | import datasets.cityscapes_labels as cityscapes_labels
16 |
17 | from config import cfg
18 |
19 | trainid_to_name = cityscapes_labels.trainId2name
20 | id_to_trainid = cityscapes_labels.label2trainid
21 | trainid_to_trainid = cityscapes_labels.trainId2trainId
22 | color_to_trainid = cityscapes_labels.color2trainId
23 | num_classes = 19
24 | ignore_label = 255
25 | root = cfg.DATASET.BDD_DIR
26 | img_postfix = '.jpg'
27 |
28 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
29 | 153, 153, 153, 250, 170, 30,
30 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
31 | 255, 0, 0, 0, 0, 142, 0, 0, 70,
32 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
33 | zero_pad = 256 * 3 - len(palette)
34 | for i in range(zero_pad):
35 | palette.append(0)
36 |
37 |
38 | def colorize_mask(mask):
39 | """
40 | Colorize a segmentation mask.
41 | """
42 | # mask: numpy array of the mask
43 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
44 | new_mask.putpalette(palette)
45 | return new_mask
46 |
47 |
48 | def add_items(items, aug_items, img_path, mask_path, mask_postfix, mode, maxSkip):
49 | """
50 |
51 |     Add more items to the list from the augmented dataset
52 | """
53 |
54 | if mode == "train":
55 | img_path = os.path.join(img_path, 'train')
56 | mask_path = os.path.join(mask_path, 'train')
57 | elif mode == "val":
58 | img_path = os.path.join(img_path, 'val')
59 | mask_path = os.path.join(mask_path, 'val')
60 |
61 | list_items = [name.split(img_postfix)[0] for name in
62 | os.listdir(img_path)]
63 | for it in list_items:
64 | item = (os.path.join(img_path, it + img_postfix),
65 | os.path.join(mask_path, it + mask_postfix))
66 | # ########################################################
67 | # ###### dataset augmentation ############################
68 | # ########################################################
69 | # if mode == "train" and maxSkip > 0:
70 | # new_img_path = os.path.join(aug_root, 'leftImg8bit_trainvaltest', 'leftImg8bit')
71 | # new_mask_path = os.path.join(aug_root, 'gtFine_trainvaltest', 'gtFine')
72 | # file_info = it.split("_")
73 | # cur_seq_id = file_info[-1]
74 |
75 | # prev_seq_id = "%06d" % (int(cur_seq_id) - maxSkip)
76 | # next_seq_id = "%06d" % (int(cur_seq_id) + maxSkip)
77 | # prev_it = file_info[0] + "_" + file_info[1] + "_" + prev_seq_id
78 | # next_it = file_info[0] + "_" + file_info[1] + "_" + next_seq_id
79 | # prev_item = (os.path.join(new_img_path, c, prev_it + img_postfix),
80 | # os.path.join(new_mask_path, c, prev_it + mask_postfix))
81 | # if os.path.isfile(prev_item[0]) and os.path.isfile(prev_item[1]):
82 | # aug_items.append(prev_item)
83 | # next_item = (os.path.join(new_img_path, c, next_it + img_postfix),
84 | # os.path.join(new_mask_path, c, next_it + mask_postfix))
85 | # if os.path.isfile(next_item[0]) and os.path.isfile(next_item[1]):
86 | # aug_items.append(next_item)
87 | items.append(item)
88 | # items.extend(extra_items)
89 |
90 |
91 | def make_cv_splits(img_dir_name):
92 | """
93 | Create splits of train/val data.
94 |     A split is a list of cities.
95 | split0 is aligned with the default Cityscapes train/val.
96 | """
97 | trn_path = os.path.join(root, img_dir_name, 'train')
98 | val_path = os.path.join(root, img_dir_name, 'val')
99 |
100 | trn_cities = ['train/' + c for c in os.listdir(trn_path)]
101 | val_cities = ['val/' + c for c in os.listdir(val_path)]
102 |
103 |     # sort so the ordering is reproducible
104 | trn_cities = sorted(trn_cities)
105 |
106 | all_cities = val_cities + trn_cities
107 | num_val_cities = len(val_cities)
108 | num_cities = len(all_cities)
109 |
110 | cv_splits = []
111 | for split_idx in range(cfg.DATASET.CV_SPLITS):
112 | split = {}
113 | split['train'] = []
114 | split['val'] = []
115 | offset = split_idx * num_cities // cfg.DATASET.CV_SPLITS
116 | for j in range(num_cities):
117 | if j >= offset and j < (offset + num_val_cities):
118 | split['val'].append(all_cities[j])
119 | else:
120 | split['train'].append(all_cities[j])
121 | cv_splits.append(split)
122 |
123 | return cv_splits
124 |
125 |
126 | def make_split_coarse(img_path):
127 | """
128 | Create a train/val split for coarse
129 | return: city split in train
130 | """
131 | all_cities = os.listdir(img_path)
132 | all_cities = sorted(all_cities) # needs to always be the same
133 | val_cities = [] # Can manually set cities to not be included into train split
134 |
135 | split = {}
136 | split['val'] = val_cities
137 | split['train'] = [c for c in all_cities if c not in val_cities]
138 | return split
139 |
140 |
141 | def make_test_split(img_dir_name):
142 | test_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'test')
143 | test_cities = ['test/' + c for c in os.listdir(test_path)]
144 |
145 | return test_cities
146 |
147 |
148 | def make_dataset(mode, maxSkip=0, cv_split=0):
149 | """
150 | Assemble list of images + mask files
151 |
152 | fine - modes: train/val/test/trainval cv:0,1,2
153 | coarse - modes: train/val cv:na
154 |
155 | path examples:
156 | leftImg8bit_trainextra/leftImg8bit/train_extra/augsburg
157 | gtCoarse/gtCoarse/train_extra/augsburg
158 | """
159 | items = []
160 | aug_items = []
161 |
162 | assert mode in ['train', 'val', 'test', 'trainval']
163 | img_dir_name = 'images'
164 | img_path = os.path.join(root, img_dir_name)
165 | mask_path = os.path.join(root, 'labels')
166 | mask_postfix = '_train_id.png'
167 | # cv_splits = make_cv_splits(img_dir_name)
168 | if mode == 'trainval':
169 | modes = ['train', 'val']
170 | else:
171 | modes = [mode]
172 | for mode in modes:
173 | logging.info('{} fine cities: '.format(mode))
174 | add_items(items, aug_items, img_path, mask_path,
175 | mask_postfix, mode, maxSkip)
176 |
177 | # logging.info('Cityscapes-{}: {} images'.format(mode, len(items)))
178 | logging.info('BDD100K-{}: {} images'.format(mode, len(items) + len(aug_items)))
179 | return items, aug_items
180 |
181 |
182 | class BDD100K(data.Dataset):
183 |
184 | def __init__(self, mode, maxSkip=0, joint_transform=None, sliding_crop=None,
185 | transform=None, target_transform=None, target_aux_transform=None, dump_images=False,
186 | cv_split=None, eval_mode=False,
187 | eval_scales=None, eval_flip=False, image_in=False,
188 | extract_feature=False):
189 | self.mode = mode
190 | self.maxSkip = maxSkip
191 | self.joint_transform = joint_transform
192 | self.sliding_crop = sliding_crop
193 | self.transform = transform
194 | self.target_transform = target_transform
195 | self.target_aux_transform = target_aux_transform
196 | self.dump_images = dump_images
197 | self.eval_mode = eval_mode
198 | self.eval_flip = eval_flip
199 | self.eval_scales = None
200 | self.image_in = image_in
201 | self.extract_feature = extract_feature
202 |
203 |
204 | if eval_scales != None:
205 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")]
206 |
207 | if cv_split:
208 | self.cv_split = cv_split
209 | assert cv_split < cfg.DATASET.CV_SPLITS, \
210 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
211 | cv_split, cfg.DATASET.CV_SPLITS)
212 | else:
213 | self.cv_split = 0
214 | self.imgs, _ = make_dataset(mode, self.maxSkip, cv_split=self.cv_split)
215 | if len(self.imgs) == 0:
216 | raise RuntimeError('Found 0 images, please check the data set')
217 |
218 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
219 |
220 | def _eval_get_item(self, img, mask, scales, flip_bool):
221 | return_imgs = []
222 | for flip in range(int(flip_bool) + 1):
223 | imgs = []
224 | if flip:
225 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
226 | for scale in scales:
227 | w, h = img.size
228 | target_w, target_h = int(w * scale), int(h * scale)
229 | resize_img = img.resize((target_w, target_h))
230 | tensor_img = transforms.ToTensor()(resize_img)
231 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img)
232 | imgs.append(final_tensor)
233 | return_imgs.append(imgs)
234 | return return_imgs, mask
235 |
236 | def __getitem__(self, index):
237 |
238 | img_path, mask_path = self.imgs[index]
239 |
240 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
241 | img_name = os.path.splitext(os.path.basename(img_path))[0]
242 |
243 | mask = np.array(mask)
244 | mask_copy = mask.copy()
245 | for k, v in trainid_to_trainid.items():
246 | mask_copy[mask == k] = v
247 |
248 | if self.eval_mode:
249 | return [transforms.ToTensor()(img)], self._eval_get_item(img, mask_copy,
250 | self.eval_scales,
251 | self.eval_flip), img_name
252 |
253 | mask = Image.fromarray(mask_copy.astype(np.uint8))
254 |
255 | # Image Transformations
256 | if self.extract_feature is not True:
257 | if self.joint_transform is not None:
258 | img, mask = self.joint_transform(img, mask)
259 |
260 | if self.transform is not None:
261 | img = self.transform(img)
262 |
263 | rgb_mean_std_gt = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
264 | img_gt = transforms.Normalize(*rgb_mean_std_gt)(img)
265 |
266 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
267 | if self.image_in:
268 | eps = 1e-5
269 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])],
270 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps])
271 | img = transforms.Normalize(*rgb_mean_std)(img)
272 |
273 | if self.target_aux_transform is not None:
274 | mask_aux = self.target_aux_transform(mask)
275 | else:
276 | mask_aux = torch.tensor([0])
277 | if self.target_transform is not None:
278 | mask = self.target_transform(mask)
279 |
280 | # Debug
281 | if self.dump_images:
282 | outdir = '../../dump_imgs_{}'.format(self.mode)
283 | os.makedirs(outdir, exist_ok=True)
284 | out_img_fn = os.path.join(outdir, img_name + '.png')
285 | out_msk_fn = os.path.join(outdir, img_name + '_mask.png')
286 | mask_img = colorize_mask(np.array(mask))
287 | img.save(out_img_fn)
288 | mask_img.save(out_msk_fn)
289 |
290 | return img, mask, img_name, mask_aux
291 |
292 | def __len__(self):
293 | return len(self.imgs)
294 |
295 | class BDD100KUniform(data.Dataset):
296 | """
297 | Please do not use this for AGG
298 | """
299 |
300 | def __init__(self, mode, maxSkip=0, joint_transform_list=None, sliding_crop=None,
301 | transform=None, target_transform=None, target_aux_transform=None, dump_images=False,
302 | cv_split=None, class_uniform_pct=0.5, class_uniform_tile=1024,
303 | test=False, coarse_boost_classes=None, image_in=False, extract_feature=False):
304 | self.mode = mode
305 | self.maxSkip = maxSkip
306 | self.joint_transform_list = joint_transform_list
307 | self.sliding_crop = sliding_crop
308 | self.transform = transform
309 | self.target_transform = target_transform
310 | self.target_aux_transform = target_aux_transform
311 | self.dump_images = dump_images
312 | self.class_uniform_pct = class_uniform_pct
313 | self.class_uniform_tile = class_uniform_tile
314 | self.coarse_boost_classes = coarse_boost_classes
315 | self.image_in = image_in
316 | self.extract_feature = extract_feature
317 |
318 |
319 | if cv_split:
320 | self.cv_split = cv_split
321 | assert cv_split < cfg.DATASET.CV_SPLITS, \
322 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
323 | cv_split, cfg.DATASET.CV_SPLITS)
324 | else:
325 | self.cv_split = 0
326 |
327 | self.imgs, self.aug_imgs = make_dataset(mode, self.maxSkip, cv_split=self.cv_split)
328 | assert len(self.imgs), 'Found 0 images, please check the data set'
329 |
330 | # Centroids for fine data
331 | json_fn = 'bdd100k_{}_cv{}_tile{}.json'.format(
332 | self.mode, self.cv_split, self.class_uniform_tile)
333 | if os.path.isfile(json_fn):
334 | with open(json_fn, 'r') as json_data:
335 | centroids = json.load(json_data)
336 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
337 | else:
338 | self.centroids = uniform.class_centroids_all(
339 | self.imgs,
340 | num_classes,
341 | id2trainid=trainid_to_trainid,
342 | tile_size=class_uniform_tile)
343 | with open(json_fn, 'w') as outfile:
344 | json.dump(self.centroids, outfile, indent=4)
345 |
346 | self.fine_centroids = self.centroids.copy()
347 |
348 | self.build_epoch()
349 |
350 | def cities_uniform(self, imgs, name):
351 | """ list out cities in imgs_uniform """
352 | cities = {}
353 | for item in imgs:
354 | img_fn = item[0]
355 | img_fn = os.path.basename(img_fn)
356 | city = img_fn.split('_')[0]
357 | cities[city] = 1
358 | city_names = cities.keys()
359 | logging.info('Cities for {} '.format(name) + str(sorted(city_names)))
360 |
361 | def build_epoch(self, cut=False):
362 | """
363 | Perform Uniform Sampling per epoch to create a new list for training such that it
364 | uniformly samples all classes
365 | """
366 | if self.class_uniform_pct > 0:
367 | if cut:
368 |                 # after max_cu_epoch, only use fine images to fine-tune
369 | self.imgs_uniform = uniform.build_epoch(self.imgs,
370 | self.fine_centroids,
371 | num_classes,
372 | cfg.CLASS_UNIFORM_PCT)
373 | else:
374 | self.imgs_uniform = uniform.build_epoch(self.imgs + self.aug_imgs,
375 | self.centroids,
376 | num_classes,
377 | cfg.CLASS_UNIFORM_PCT)
378 | else:
379 | self.imgs_uniform = self.imgs
380 |
381 | def __getitem__(self, index):
382 | elem = self.imgs_uniform[index]
383 | centroid = None
384 | if len(elem) == 4:
385 | img_path, mask_path, centroid, class_id = elem
386 | else:
387 | img_path, mask_path = elem
388 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
389 | img_name = os.path.splitext(os.path.basename(img_path))[0]
390 |
391 | mask = np.array(mask)
392 | mask_copy = mask.copy()
393 | for k, v in trainid_to_trainid.items():
394 | mask_copy[mask == k] = v
395 | mask = Image.fromarray(mask_copy.astype(np.uint8))
396 |
397 | # Image Transformations
398 | if self.extract_feature is not True:
399 | if self.joint_transform_list is not None:
400 | for idx, xform in enumerate(self.joint_transform_list):
401 | if idx == 0 and centroid is not None:
402 | # HACK
403 | # We assume that the first transform is capable of taking
404 | # in a centroid
405 | img, mask = xform(img, mask, centroid)
406 | else:
407 | img, mask = xform(img, mask)
408 |
409 | # Debug
410 | if self.dump_images and centroid is not None:
411 | outdir = '../../dump_imgs_{}'.format(self.mode)
412 | os.makedirs(outdir, exist_ok=True)
413 | dump_img_name = trainid_to_name[class_id] + '_' + img_name
414 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
415 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
416 | mask_img = colorize_mask(np.array(mask))
417 | img.save(out_img_fn)
418 | mask_img.save(out_msk_fn)
419 |
420 | if self.transform is not None:
421 | img = self.transform(img)
422 |
423 | rgb_mean_std_gt = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
424 | img_gt = transforms.Normalize(*rgb_mean_std_gt)(img)
425 |
426 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
427 | if self.image_in:
428 | eps = 1e-5
429 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])],
430 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps])
431 | img = transforms.Normalize(*rgb_mean_std)(img)
432 |
433 | if self.target_aux_transform is not None:
434 | mask_aux = self.target_aux_transform(mask)
435 | else:
436 | mask_aux = torch.tensor([0])
437 | if self.target_transform is not None:
438 | mask = self.target_transform(mask)
439 |
440 | return img, mask, img_name, mask_aux
441 |
442 | def __len__(self):
443 | return len(self.imgs_uniform)
444 |
445 | class BDD100KAug(data.Dataset):
446 |
447 | def __init__(self, mode, maxSkip=0, joint_transform=None, sliding_crop=None,
448 | transform=None, color_transform=None, geometric_transform=None, target_transform=None, target_aux_transform=None, dump_images=False,
449 | cv_split=None, eval_mode=False,
450 | eval_scales=None, eval_flip=False, image_in=False,
451 | extract_feature=False):
452 | self.mode = mode
453 | self.maxSkip = maxSkip
454 | self.joint_transform = joint_transform
455 | self.sliding_crop = sliding_crop
456 | self.transform = transform
457 | self.color_transform = color_transform
458 | self.geometric_transform = geometric_transform
459 | self.target_transform = target_transform
460 | self.target_aux_transform = target_aux_transform
461 | self.dump_images = dump_images
462 | self.eval_mode = eval_mode
463 | self.eval_flip = eval_flip
464 | self.eval_scales = None
465 | self.image_in = image_in
466 | self.extract_feature = extract_feature
467 |
468 |
469 | if eval_scales != None:
470 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")]
471 |
472 | if cv_split:
473 | self.cv_split = cv_split
474 | assert cv_split < cfg.DATASET.CV_SPLITS, \
475 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
476 | cv_split, cfg.DATASET.CV_SPLITS)
477 | else:
478 | self.cv_split = 0
479 | self.imgs, _ = make_dataset(mode, self.maxSkip, cv_split=self.cv_split)
480 | if len(self.imgs) == 0:
481 | raise RuntimeError('Found 0 images, please check the data set')
482 |
483 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
484 |
485 | def _eval_get_item(self, img, mask, scales, flip_bool):
486 | return_imgs = []
487 | for flip in range(int(flip_bool) + 1):
488 | imgs = []
489 | if flip:
490 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
491 | for scale in scales:
492 | w, h = img.size
493 | target_w, target_h = int(w * scale), int(h * scale)
494 | resize_img = img.resize((target_w, target_h))
495 | tensor_img = transforms.ToTensor()(resize_img)
496 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img)
497 | imgs.append(final_tensor)
498 | return_imgs.append(imgs)
499 | return return_imgs, mask
500 |
501 | def __getitem__(self, index):
502 |
503 | img_path, mask_path = self.imgs[index]
504 |
505 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
506 | img_name = os.path.splitext(os.path.basename(img_path))[0]
507 |
508 | mask = np.array(mask)
509 | mask_copy = mask.copy()
510 | for k, v in trainid_to_trainid.items():
511 | mask_copy[mask == k] = v
512 |
513 | if self.eval_mode:
514 | return [transforms.ToTensor()(img)], self._eval_get_item(img, mask_copy,
515 | self.eval_scales,
516 | self.eval_flip), img_name
517 |
518 | mask = Image.fromarray(mask_copy.astype(np.uint8))
519 |
520 | if self.joint_transform is not None:
521 | img, mask = self.joint_transform(img, mask)
522 |
523 | if self.transform is not None:
524 | img_or = self.transform(img)
525 |
526 | if self.color_transform is not None:
527 | img_color = self.color_transform(img)
528 |
529 | if self.geometric_transform is not None:
530 | img_geometric = self.geometric_transform(img)
531 |
532 | rgb_mean_std_or = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
533 | rgb_mean_std_color = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
534 | rgb_mean_std_geometric = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
535 | if self.image_in:
536 | eps = 1e-5
537 | rgb_mean_std_or = ([torch.mean(img_or[0]), torch.mean(img_or[1]), torch.mean(img_or[2])],
538 | [torch.std(img_or[0])+eps, torch.std(img_or[1])+eps, torch.std(img_or[2])+eps])
539 | rgb_mean_std_color = ([torch.mean(img_color[0]), torch.mean(img_color[1]), torch.mean(img_color[2])],
540 | [torch.std(img_color[0])+eps, torch.std(img_color[1])+eps, torch.std(img_color[2])+eps])
541 | rgb_mean_std_geometric = ([torch.mean(img_geometric[0]), torch.mean(img_geometric[1]), torch.mean(img_geometric[2])],
542 | [torch.std(img_geometric[0])+eps, torch.std(img_geometric[1])+eps, torch.std(img_geometric[2])+eps])
543 | img_or = transforms.Normalize(*rgb_mean_std_or)(img_or)
544 | img_color = transforms.Normalize(*rgb_mean_std_color)(img_color)
545 | img_geometric = transforms.Normalize(*rgb_mean_std_geometric)(img_geometric)
546 |
547 | return img_or, img_color, img_geometric, img_name
548 |
549 | def __len__(self):
550 | return len(self.imgs)
551 |
552 |
--------------------------------------------------------------------------------
/datasets/cityscapes_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | # File taken from https://github.com/mcordts/cityscapesScripts/
3 | # License File Available at:
4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt
5 |
6 | # ----------------------
7 | # The Cityscapes Dataset
8 | # ----------------------
9 | #
10 | #
11 | # License agreement
12 | # -----------------
13 | #
14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
15 | #
16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).
21 | #
22 | #
23 | # Contact
24 | # -------
25 | #
26 | # Marius Cordts, Mohamed Omran
27 | # www.cityscapes-dataset.net
28 |
29 | """
30 | from collections import namedtuple
31 |
32 |
33 | #--------------------------------------------------------------------------------
34 | # Definitions
35 | #--------------------------------------------------------------------------------
36 |
37 | # a label and all meta information
38 | Label = namedtuple( 'Label' , [
39 |
40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
41 | # We use them to uniquely name a class
42 |
43 | 'id' , # An integer ID that is associated with this label.
44 | # The IDs are used to represent the label in ground truth images
45 | # An ID of -1 means that this label does not have an ID and thus
46 | # is ignored when creating ground truth images (e.g. license plate).
47 | # Do not modify these IDs, since exactly these IDs are expected by the
48 | # evaluation server.
49 |
50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
51 | # ground truth images with train IDs, using the tools provided in the
52 | # 'preparation' folder. However, make sure to validate or submit results
53 | # to our evaluation server using the regular IDs above!
54 | # For trainIds, multiple labels might have the same ID. Then, these labels
55 | # are mapped to the same class in the ground truth images. For the inverse
56 | # mapping, we use the label that is defined first in the list below.
57 | # For example, mapping all void-type classes to the same ID in training,
58 | # might make sense for some approaches.
59 | # Max value is 255!
60 |
61 | 'category' , # The name of the category that this label belongs to
62 |
63 | 'categoryId' , # The ID of this category. Used to create ground truth images
64 | # on category level.
65 |
66 | 'hasInstances', # Whether this label distinguishes between single instances or not
67 |
68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
69 | # during evaluations or not
70 |
71 | 'color' , # The color of this label
72 | ] )
73 |
74 |
75 | #--------------------------------------------------------------------------------
76 | # A list of all labels
77 | #--------------------------------------------------------------------------------
78 |
79 | # Please adapt the train IDs as appropriate for your approach.
80 | # Note that you might want to ignore labels with ID 255 during training.
81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
82 | # Make sure to provide your results using the original IDs and not the training IDs.
83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
84 |
85 | labels = [
86 | # name id trainId category catId hasInstances ignoreInEval color
87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,154) ), # (153,153,153)
106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,143) ), # ( 0, 0,142)
122 | ]
123 |
124 |
125 | #--------------------------------------------------------------------------------
126 | # Create dictionaries for a fast lookup
127 | #--------------------------------------------------------------------------------
128 |
129 | # Please refer to the main method below for example usages!
130 |
131 | # name to label object
132 | name2label = { label.name : label for label in labels }
133 | # id to label object
134 | id2label = { label.id : label for label in labels }
135 | # trainId to label object
136 | trainId2label = { label.trainId : label for label in reversed(labels) }
137 | # label2trainid
138 | label2trainid = { label.id : label.trainId for label in labels }
139 | # trainId to label object
140 | trainId2name = { label.trainId : label.name for label in labels }
141 | trainId2color = { label.trainId : label.color for label in labels }
142 |
143 | color2trainId = { label.color : label.trainId for label in labels }
144 |
145 | trainId2trainId = { label.trainId : label.trainId for label in labels }
146 |
147 | # category to list of label objects
148 | category2labels = {}
149 | for label in labels:
150 | category = label.category
151 | if category in category2labels:
152 | category2labels[category].append(label)
153 | else:
154 | category2labels[category] = [label]
155 |
156 | #--------------------------------------------------------------------------------
157 | # Assure single instance name
158 | #--------------------------------------------------------------------------------
159 |
160 | # returns the label name that describes a single instance (if possible)
161 | # e.g. input | output
162 | # ----------------------
163 | # car | car
164 | # cargroup | car
165 | # foo | None
166 | # foogroup | None
167 | # skygroup | None
168 | def assureSingleInstanceName( name ):
169 | # if the name is known, it is not a group
170 | if name in name2label:
171 | return name
172 | # test if the name actually denotes a group
173 | if not name.endswith("group"):
174 | return None
175 | # remove group
176 | name = name[:-len("group")]
177 | # test if the new name exists
178 | if not name in name2label:
179 | return None
180 | # test if the new name denotes a label that actually has instances
181 | if not name2label[name].hasInstances:
182 | return None
183 | # all good then
184 | return name
185 |
186 | #--------------------------------------------------------------------------------
187 | # Main for testing
188 | #--------------------------------------------------------------------------------
189 |
190 | # just a dummy main
191 | if __name__ == "__main__":
192 | # Print all the labels
193 | print("List of cityscapes labels:")
194 | print("")
195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )))
196 | print((" " + ('-' * 98)))
197 | for label in labels:
198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )))
199 | print("")
200 |
201 | print("Example usages:")
202 |
203 | # Map from name to label
204 | name = 'car'
205 | id = name2label[name].id
206 | print(("ID of label '{name}': {id}".format( name=name, id=id )))
207 |
208 | # Map from ID to label
209 | category = id2label[id].category
210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category )))
211 |
212 | # Map from trainID to label
213 | trainId = 0
214 | name = trainId2label[trainId].name
215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )))
216 |
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | ImageNet Dataset Loader
3 | """
4 | import logging
5 | import json
6 | import os
7 | import numpy as np
8 | from PIL import Image
9 | from skimage import color
10 |
11 | from torch.utils import data
12 | import torch
13 | import torchvision.transforms as transforms
14 | import datasets.uniform as uniform
15 | import datasets.cityscapes_labels as cityscapes_labels
16 | import scipy.misc as m
17 |
18 | from config import cfg
19 |
20 | import pdb
21 | import random
22 | from itertools import cycle, islice
23 |
24 | root = cfg.DATASET.IMAGENET_DIR
25 | img_postfix = '.JPEG'
26 |
27 |
28 | def make_cv_splits():
29 | """
30 | Create splits of train/valid data.
31 |     A split is a list of 'cities' (here, ImageNet class directories).
32 | split0 is aligned with the default Cityscapes train/valid.
33 | """
34 | trn_path = os.path.join(root, 'train')
35 | val_path = os.path.join(root, 'val')
36 |
37 | trn_cities = ['train/' + c for c in os.listdir(trn_path)]
38 | val_cities = ['val/' + c for c in os.listdir(val_path)]
39 |
40 |     # sort so the ordering is reproducible
41 | trn_cities = sorted(trn_cities)
42 |
43 | all_cities = val_cities + trn_cities
44 | num_val_cities = len(val_cities)
45 | num_cities = len(all_cities)
46 |
47 | cv_splits = []
48 | for split_idx in range(cfg.DATASET.CV_SPLITS):
49 | split = {}
50 | split['train'] = []
51 | split['val'] = []
52 | offset = split_idx * num_cities // cfg.DATASET.CV_SPLITS
53 | for j in range(num_cities):
54 | if j >= offset and j < (offset + num_val_cities):
55 | split['val'].append(all_cities[j])
56 | else:
57 | split['train'].append(all_cities[j])
58 | cv_splits.append(split)
59 |
60 | return cv_splits
61 |
62 |
63 | def add_items(items, aug_items, cities, img_path, mode, maxSkip):
64 | """
65 |
66 |     Add more items to the list from the augmented dataset
67 | """
68 |
69 | for c in cities:
70 | c_items = [name.split(img_postfix)[0] for name in
71 | os.listdir(os.path.join(img_path, c))]
72 | for it in c_items:
73 | item = (os.path.join(img_path, c, it + img_postfix))
74 | items.append(item)
75 |
76 |
77 | def make_dataset(mode, maxSkip=0, cv_split=0):
78 | """
79 | Assemble list of images + mask files
80 |
81 | fine - modes: train/valid/test/trainval cv:0,1,2
82 | coarse - modes: train/valid cv:na
83 |
84 | path examples:
85 | leftImg8bit_trainextra/leftImg8bit/train_extra/augsburg
86 | gtCoarse/gtCoarse/train_extra/augsburg
87 | """
88 | items = []
89 | aug_items = []
90 |
91 | assert mode in ['train', 'val', 'trainval']
92 | cv_splits = make_cv_splits()
93 | if mode == 'trainval':
94 | modes = ['train', 'val']
95 | else:
96 | modes = [mode]
97 | for mode in modes:
98 | logging.info('{} imagenet: '.format(mode) + str(cv_splits[cv_split][mode]))
99 | add_items(items, aug_items, cv_splits[cv_split][mode], root, mode, maxSkip)
100 |
101 | logging.info('ImageNet-{}: {} images'.format(mode, len(items) + len(aug_items)))
102 |
103 | return items, aug_items
104 |
105 |
106 | class ImageNet(data.Dataset):
107 |
108 | def __init__(self, mode, maxSkip=0, joint_transform=None, sliding_crop=None,
109 | transform=None, dump_images=False,
110 | cv_split=None, eval_mode=False,
111 | eval_scales=None, eval_flip=False, image_in=False,
112 | extract_feature=False):
113 | self.mode = mode
114 | self.maxSkip = maxSkip
115 | self.joint_transform = joint_transform
116 | self.sliding_crop = sliding_crop
117 | self.transform = transform
118 | self.dump_images = dump_images
119 | self.eval_mode = eval_mode
120 | self.eval_flip = eval_flip
121 | self.eval_scales = None
122 | self.image_in = image_in
123 | self.extract_feature = extract_feature
124 |
125 |
126 | if eval_scales != None:
127 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")]
128 |
129 | if cv_split:
130 | self.cv_split = cv_split
131 | assert cv_split < cfg.DATASET.CV_SPLITS, \
132 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
133 | cv_split, cfg.DATASET.CV_SPLITS)
134 | else:
135 | self.cv_split = 0
136 | self.imgs, _ = make_dataset(mode, self.maxSkip, cv_split=self.cv_split)
137 | if len(self.imgs) == 0:
138 | raise RuntimeError('Found 0 images, please check the data set')
139 |
140 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
141 |
142 | def _eval_get_item(self, img, scales, flip_bool):
143 | return_imgs = []
144 | for flip in range(int(flip_bool) + 1):
145 | imgs = []
146 | if flip:
147 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
148 | for scale in scales:
149 | w, h = img.size
150 | target_w, target_h = int(w * scale), int(h * scale)
151 | resize_img = img.resize((target_w, target_h))
152 | tensor_img = transforms.ToTensor()(resize_img)
153 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img)
154 | imgs.append(final_tensor)
155 | return_imgs.append(imgs)
156 | return return_imgs
157 |
158 | def __getitem__(self, index):
159 |
160 | img_path = self.imgs[index]
161 | img = Image.open(img_path).convert('RGB')
162 | img_name = os.path.splitext(os.path.basename(img_path))[0]
163 |
164 | if self.eval_mode:
165 | return [transforms.ToTensor()(img)], self._eval_get_item(img,
166 | self.eval_scales,
167 | self.eval_flip), img_name
168 | # Image Transformations
169 | if self.extract_feature is not True:
170 | if self.joint_transform is not None:
171 | img = self.joint_transform(img)
172 |
173 | if self.transform is not None:
174 | img = self.transform(img)
175 |
176 | rgb_mean_std_gt = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
177 | img_gt = transforms.Normalize(*rgb_mean_std_gt)(img)
178 |
179 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
180 | if self.image_in:
181 | eps = 1e-5
182 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])],
183 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps])
184 | img = transforms.Normalize(*rgb_mean_std)(img)
185 |
186 | # Debug
187 | if self.dump_images:
188 | outdir = '../../dump_imgs_{}'.format(self.mode)
189 | os.makedirs(outdir, exist_ok=True)
190 | out_img_fn = os.path.join(outdir, img_name + '.png')
191 | img.save(out_img_fn)
192 |
193 | return img, img_name
194 |
195 | def __len__(self):
196 | return len(self.imgs)
197 |
--------------------------------------------------------------------------------
/datasets/mapillary.py:
--------------------------------------------------------------------------------
1 | """
2 | Mapillary Dataset Loader
3 | """
4 | import logging
5 | import json
6 | import os
7 | import numpy as np
8 | from PIL import Image, ImageCms
9 | from skimage import color
10 |
11 | from torch.utils import data
12 | import torch
13 | import torchvision.transforms as transforms
14 | import datasets.uniform as uniform
15 | import datasets.cityscapes_labels as cityscapes_labels
16 | import transforms.transforms as extended_transforms
17 | import copy
18 |
19 | from config import cfg
20 |
21 | # Convert this dataset to have labels from cityscapes
22 | num_classes = 19 #65
23 | ignore_label = 255 #65
24 | root = cfg.DATASET.MAPILLARY_DIR
25 | config_fn = os.path.join(root, 'config.json')
26 | color_mapping = []
27 | id_to_trainid = {}
28 | id_to_ignore_or_group = {}
29 |
30 |
31 | def gen_id_to_ignore():
32 | global id_to_ignore_or_group
33 | for i in range(66):
34 | id_to_ignore_or_group[i] = ignore_label
35 |
36 | ### Convert each class to cityscapes one
37 | ### Road
38 | # Road
39 | id_to_ignore_or_group[13] = 0
40 | # Lane Marking - General
41 | id_to_ignore_or_group[24] = 0
42 | # Manhole
43 | id_to_ignore_or_group[41] = 0
44 |
45 | ### Sidewalk
46 | # Curb
47 | id_to_ignore_or_group[2] = 1
48 | # Sidewalk
49 | id_to_ignore_or_group[15] = 1
50 |
51 | ### Building
52 | # Building
53 | id_to_ignore_or_group[17] = 2
54 |
55 | ### Wall
56 | # Wall
57 | id_to_ignore_or_group[6] = 3
58 |
59 | ### Fence
60 | # Fence
61 | id_to_ignore_or_group[3] = 4
62 |
63 | ### Pole
64 | # Pole
65 | id_to_ignore_or_group[45] = 5
66 | # Utility Pole
67 | id_to_ignore_or_group[47] = 5
68 |
69 | ### Traffic Light
70 | # Traffic Light
71 | id_to_ignore_or_group[48] = 6
72 |
73 | ### Traffic Sign
74 | # Traffic Sign
75 | id_to_ignore_or_group[50] = 7
76 |
77 | ### Vegetation
78 |     # Vegetation
79 | id_to_ignore_or_group[30] = 8
80 |
81 | ### Terrain
82 | # Terrain
83 | id_to_ignore_or_group[29] = 9
84 |
85 | ### Sky
86 | # Sky
87 | id_to_ignore_or_group[27] = 10
88 |
89 | ### Person
90 | # Person
91 | id_to_ignore_or_group[19] = 11
92 |
93 | ### Rider
94 | # Bicyclist
95 | id_to_ignore_or_group[20] = 12
96 | # Motorcyclist
97 | id_to_ignore_or_group[21] = 12
98 | # Other Rider
99 | id_to_ignore_or_group[22] = 12
100 |
101 | ### Car
102 | # Car
103 | id_to_ignore_or_group[55] = 13
104 |
105 | ### Truck
106 | # Truck
107 | id_to_ignore_or_group[61] = 14
108 |
109 | ### Bus
110 | # Bus
111 | id_to_ignore_or_group[54] = 15
112 |
113 | ### Train
114 | # On Rails
115 | id_to_ignore_or_group[58] = 16
116 |
117 | ### Motorcycle
118 | # Motorcycle
119 | id_to_ignore_or_group[57] = 17
120 |
121 | ### Bicycle
122 | # Bicycle
123 | id_to_ignore_or_group[52] = 18
124 |
125 |
126 | def colorize_mask(image_array):
127 | """
128 | Colorize a segmentation mask
129 | """
130 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P')
131 | new_mask.putpalette(color_mapping)
132 | return new_mask
133 |
134 |
135 | def make_dataset(quality, mode):
136 | """
137 | Create File List
138 | """
139 | assert (quality == 'semantic' and mode in ['train', 'val'])
140 | img_dir_name = None
141 | if quality == 'semantic':
142 | if mode == 'train':
143 | img_dir_name = 'training'
144 | if mode == 'val':
145 | img_dir_name = 'validation'
146 | mask_path = os.path.join(root, img_dir_name, 'labels')
147 | else:
148 | raise BaseException("Instance Segmentation Not support")
149 |
150 | img_path = os.path.join(root, img_dir_name, 'images')
151 | print(img_path)
152 | if quality != 'video':
153 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)])
154 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)])
155 | assert imgs == msks
156 |
157 | items = []
158 | c_items = os.listdir(img_path)
159 | if '.DS_Store' in c_items:
160 | c_items.remove('.DS_Store')
161 |
162 | for it in c_items:
163 | if quality == 'video':
164 | item = (os.path.join(img_path, it), os.path.join(img_path, it))
165 | else:
166 | item = (os.path.join(img_path, it),
167 | os.path.join(mask_path, it.replace(".jpg", ".png")))
168 | items.append(item)
169 | return items
170 |
171 |
172 | def gen_colormap():
173 | """
174 | Get Color Map from file
175 | """
176 | global color_mapping
177 |
178 | # load mapillary config
179 | with open(config_fn) as config_file:
180 | config = json.load(config_file)
181 | config_labels = config['labels']
182 |
183 | # calculate label color mapping
184 | colormap = []
185 | id2name = {}
186 | for i in range(0, len(config_labels)):
187 | colormap = colormap + config_labels[i]['color']
188 | id2name[i] = config_labels[i]['readable']
189 | color_mapping = colormap
190 | return id2name
191 |
192 |
193 | class Mapillary(data.Dataset):
194 | def __init__(self, quality, mode, joint_transform_list=None,
195 | transform=None, target_transform=None, target_aux_transform=None,
196 | image_in=False, dump_images=False, class_uniform_pct=0,
197 | class_uniform_tile=768, test=False):
198 | """
199 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform.
200 | 0.0 means fully random.
201 | class_uniform_tile_size = Class uniform tile size
202 | """
203 | gen_id_to_ignore()
204 | self.quality = quality
205 | self.mode = mode
206 | self.joint_transform_list = joint_transform_list
207 | self.transform = transform
208 | self.target_transform = target_transform
209 | self.image_in = image_in
210 | self.target_aux_transform = target_aux_transform
211 | self.dump_images = dump_images
212 | self.class_uniform_pct = class_uniform_pct
213 | self.class_uniform_tile = class_uniform_tile
214 | self.id2name = gen_colormap()
215 | self.imgs_uniform = None
216 |
217 |
218 | # find all images
219 | self.imgs = make_dataset(quality, mode)
220 | if len(self.imgs) == 0:
221 | raise RuntimeError('Found 0 images, please check the data set')
222 | if test:
223 | np.random.shuffle(self.imgs)
224 | self.imgs = self.imgs[:200]
225 |
226 | if self.class_uniform_pct:
227 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile)
228 | if os.path.isfile(json_fn):
229 | with open(json_fn, 'r') as json_data:
230 | centroids = json.load(json_data)
231 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
232 | else:
233 | # centroids is a dict (indexed by class) of lists of centroids
234 | self.centroids = uniform.class_centroids_all(
235 | self.imgs,
236 | num_classes,
237 | id2trainid=None,
238 | tile_size=self.class_uniform_tile)
239 | with open(json_fn, 'w') as outfile:
240 | json.dump(self.centroids, outfile, indent=4)
241 | else:
242 | self.centroids = []
243 | self.build_epoch()
244 |
245 | def build_epoch(self):
246 | if self.class_uniform_pct != 0:
247 | self.imgs_uniform = uniform.build_epoch(self.imgs,
248 | self.centroids,
249 | num_classes,
250 | self.class_uniform_pct)
251 | else:
252 | self.imgs_uniform = self.imgs
253 |
254 | def __getitem__(self, index):
255 | if len(self.imgs_uniform[index]) == 2:
256 | img_path, mask_path = self.imgs_uniform[index]
257 | centroid = None
258 | class_id = None
259 | else:
260 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index]
261 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
262 | img_name = os.path.splitext(os.path.basename(img_path))[0]
263 |
264 | mask = np.array(mask)
265 | mask_copy = mask.copy()
266 | for k, v in id_to_ignore_or_group.items():
267 | mask_copy[mask == k] = v
268 | mask = Image.fromarray(mask_copy.astype(np.uint8))
269 |
270 | # Image Transformations
271 | if self.joint_transform_list is not None:
272 | for idx, xform in enumerate(self.joint_transform_list):
273 | if idx == 0 and centroid is not None:
274 | # HACK! Assume the first transform accepts a centroid
275 | img, mask = xform(img, mask, centroid)
276 | else:
277 | img, mask = xform(img, mask)
278 |
279 | if self.dump_images:
280 | outdir = 'dump_imgs_{}'.format(self.mode)
281 | os.makedirs(outdir, exist_ok=True)
282 | if centroid is not None:
283 | dump_img_name = self.id2name[class_id] + '_' + img_name
284 | else:
285 | dump_img_name = img_name
286 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
287 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
288 | mask_img = colorize_mask(np.array(mask))
289 | img.save(out_img_fn)
290 | mask_img.save(out_msk_fn)
291 |
292 | if self.transform is not None:
293 | img = self.transform(img)
294 |
295 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
296 | img_gt = transforms.Normalize(*rgb_mean_std)(img)
297 | if self.image_in:
298 | eps = 1e-5
299 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])],
300 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps])
301 | img = transforms.Normalize(*rgb_mean_std)(img)
302 |
303 | if self.target_aux_transform is not None:
304 | mask_aux = self.target_aux_transform(mask)
305 | else:
306 | mask_aux = torch.tensor([0])
307 | if self.target_transform is not None:
308 | mask = self.target_transform(mask)
309 |
310 | mask = extended_transforms.MaskToTensor()(mask)
311 | return img, mask, img_name, mask_aux
312 |
313 | def __len__(self):
314 | return len(self.imgs_uniform)
315 |
316 | def calculate_weights(self):
317 |         raise NotImplementedError("not supported yet")
318 |
--------------------------------------------------------------------------------
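A minimal usage sketch for the Mapillary loader above (illustrative only: it assumes the Mapillary `root` configured in `config.py` points at a local copy of the dataset, and the transform is a placeholder rather than the pipeline the training scripts set up):

```
import torchvision.transforms as standard_transforms

# Placeholder transform; the actual training pipeline configures its own.
mapillary_val = Mapillary('semantic', 'val',
                          transform=standard_transforms.ToTensor(),
                          class_uniform_pct=0)

img, mask, img_name, mask_aux = mapillary_val[0]
# img: FloatTensor [3, H, W], mask: LongTensor [H, W] of training ids
```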
/datasets/multi_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom DomainUniformConcatDataset
3 | """
4 | import numpy as np
5 |
6 | from torch.utils.data import Dataset
7 | import torch
8 | from config import cfg
9 |
10 |
11 | np.random.seed(cfg.RANDOM_SEED)
12 |
13 |
14 | class DomainUniformConcatDataset(Dataset):
15 | """
16 | DomainUniformConcatDataset
17 |
18 | Sample images uniformly across the domains
19 | If bs_mul is n, this outputs # of domains * n images per batch
20 | """
21 | @staticmethod
22 | def cumsum(sequence):
23 | r, s = [], 0
24 | for e in sequence:
25 | l = len(e)
26 | r.append(l + s)
27 | s += l
28 | return r
29 |
30 | def __init__(self, args, datasets):
31 | """
32 |         This dataset returns a sample image (source)
33 |         and an augmented sample image (target)
34 | Args:
35 | args: input config arguments
36 | datasets: list of datasets to concat
37 | """
38 | super(DomainUniformConcatDataset, self).__init__()
39 | self.datasets = datasets
40 | self.lengths = [len(d) for d in datasets]
41 | self.offsets = self.cumsum(datasets)
42 | self.length = np.sum(self.lengths)
43 |
44 | print("# domains: {}, Total length: {}, 1 epoch: {}, offsets: {}".format(
45 | str(len(datasets)), str(self.length), str(len(self)), str(self.offsets)))
46 |
47 |
48 | def __len__(self):
49 | """
50 | Returns:
51 | The number of images in a domain that has minimum image samples
52 | """
53 | return min(self.lengths)
54 |
55 |
56 | def _get_batch_from_dataset(self, dataset, idx):
57 | """
58 | Get batch from dataset
59 | New idx = idx + random integer
60 | Args:
61 | dataset: dataset class object
62 | idx: integer
63 |
64 | Returns:
65 | One batch from dataset
66 | """
67 | p_index = idx + np.random.randint(len(dataset))
68 | if p_index > len(dataset) - 1:
69 | p_index -= len(dataset)
70 |
71 | return dataset[p_index]
72 |
73 |
74 | def __getitem__(self, idx):
75 | """
76 | Args:
77 | idx (int): Index
78 |
79 | Returns:
80 |             images corresponding to the index from each domain
81 | """
82 | imgs = []
83 | masks = []
84 | img_names = []
85 | mask_auxs = []
86 |
87 | for dataset in self.datasets:
88 | img, mask, img_name, mask_aux = self._get_batch_from_dataset(dataset, idx)
89 | imgs.append(img)
90 | masks.append(mask)
91 | img_names.append(img_name)
92 | mask_auxs.append(mask_aux)
93 | imgs, masks, mask_auxs = torch.stack(imgs, 0), torch.stack(masks, 0), torch.stack(mask_auxs, 0)
94 |
95 | return imgs, masks, img_names, mask_auxs
96 |
97 |
--------------------------------------------------------------------------------
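The wrapper above returns one sample from every constituent dataset per index and stacks them along a leading domain dimension. A rough sketch of how it might be driven (the dataset objects and the `args` namespace are assumptions, not taken from train.py):

```
from torch.utils.data import DataLoader

# 'source_set' and 'wild_set' stand in for two dataset objects (e.g. the GTAV
# and ImageNet loaders) that return (img, mask, img_name, mask_aux) tuples of
# matching tensor shapes; 'args' is the parsed training argument namespace.
concat_set = DomainUniformConcatDataset(args, [source_set, wild_set])
loader = DataLoader(concat_set, batch_size=2, shuffle=True)

imgs, masks, img_names, mask_auxs = next(iter(loader))
# imgs: [2, num_domains, 3, H, W] -- per-index domain stack plus the batch dim
```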
/datasets/nullloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Null Loader
3 | """
4 | import numpy as np
5 | import torch
6 | from torch.utils import data
7 |
8 | num_classes = 19
9 | ignore_label = 255
10 |
11 | class NullLoader(data.Dataset):
12 | """
13 | Null Dataset for Performance
14 | """
15 | def __init__(self,crop_size):
16 | self.imgs = range(200)
17 | self.crop_size = crop_size
18 |
19 | def __getitem__(self, index):
20 | #Return img, mask, name
21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index)
22 |
23 | def __len__(self):
24 | return len(self.imgs)
--------------------------------------------------------------------------------
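NullLoader emits constant zero tensors, which can help time the training loop without any disk I/O. A small sketch (the crop size is an arbitrary example):

```
from torch.utils.data import DataLoader

null_set = NullLoader(crop_size=768)
imgs, masks, names = next(iter(DataLoader(null_set, batch_size=4)))
print(imgs.shape, masks.shape)   # [4, 3, 768, 768] and [4, 768, 768]
```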
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 |
37 |
38 | import math
39 | import torch
40 | from torch.distributed import get_world_size, get_rank
41 | from torch.utils.data import Sampler
42 |
43 | class DistributedSampler(Sampler):
44 | """Sampler that restricts data loading to a subset of the dataset.
45 |
46 | It is especially useful in conjunction with
47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
48 | process can pass a DistributedSampler instance as a DataLoader sampler,
49 | and load a subset of the original dataset that is exclusive to it.
50 |
51 | .. note::
52 | Dataset is assumed to be of constant size.
53 |
54 | Arguments:
55 | dataset: Dataset used for sampling.
56 | num_replicas (optional): Number of processes participating in
57 | distributed training.
58 | rank (optional): Rank of the current process within num_replicas.
59 | """
60 |
61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None):
62 | if num_replicas is None:
63 | num_replicas = get_world_size()
64 | if rank is None:
65 | rank = get_rank()
66 | self.dataset = dataset
67 | self.num_replicas = num_replicas
68 | self.rank = rank
69 | self.epoch = 0
70 | self.consecutive_sample = consecutive_sample
71 | self.permutation = permutation
72 | if pad:
73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
74 | else:
75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
76 | self.total_size = self.num_samples * self.num_replicas
77 |
78 | def __iter__(self):
79 | # deterministically shuffle based on epoch
80 | g = torch.Generator()
81 | g.manual_seed(self.epoch)
82 |
83 | if self.permutation:
84 | indices = list(torch.randperm(len(self.dataset), generator=g))
85 | else:
86 | indices = list([x for x in range(len(self.dataset))])
87 |
88 | # add extra samples to make it evenly divisible
89 | if self.total_size > len(indices):
90 | indices += indices[:(self.total_size - len(indices))]
91 |
92 | # subsample
93 | if self.consecutive_sample:
94 | offset = self.num_samples * self.rank
95 | indices = indices[offset:offset + self.num_samples]
96 | else:
97 | indices = indices[self.rank:self.total_size:self.num_replicas]
98 | assert len(indices) == self.num_samples
99 |
100 | return iter(indices)
101 |
102 | def __len__(self):
103 | return self.num_samples
104 |
105 | def set_epoch(self, epoch):
106 | self.epoch = epoch
107 |
108 | def set_num_samples(self):
109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
110 | self.total_size = self.num_samples * self.num_replicas
--------------------------------------------------------------------------------
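The only subtle part of `__iter__` is how each rank carves its share out of the (optionally shuffled) index list: interleaved by default, contiguous when `consecutive_sample` is set. A toy illustration with 10 samples and 2 replicas:

```
indices = list(range(10))
num_replicas = 2
num_samples = len(indices) // num_replicas   # pad=False behaviour

for rank in range(num_replicas):
    interleaved = indices[rank:num_samples * num_replicas:num_replicas]
    offset = num_samples * rank
    consecutive = indices[offset:offset + num_samples]
    print(rank, interleaved, consecutive)
# 0 [0, 2, 4, 6, 8] [0, 1, 2, 3, 4]
# 1 [1, 3, 5, 7, 9] [5, 6, 7, 8, 9]
```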
/datasets/uniform.py:
--------------------------------------------------------------------------------
1 | """
2 | Uniform sampling of classes.
3 | For all images, for all classes, generate centroids around which to sample.
4 |
5 | All images are divided into tiles.
6 | For each tile, a class can be present or not. If it is
7 | present, calculate the centroid of the class and record it.
8 |
9 | We would like to thank Peter Kontschieder for the inspiration of this idea.
10 | """
11 |
12 | import logging
13 | from collections import defaultdict
14 | from PIL import Image
15 | import numpy as np
16 | from scipy import ndimage
17 | from tqdm import tqdm
18 |
19 | pbar = None
20 |
21 | class Point():
22 | """
23 | Point Class For X and Y Location
24 | """
25 | def __init__(self, x, y):
26 | self.x = x
27 | self.y = y
28 |
29 |
30 | def calc_tile_locations(tile_size, image_size):
31 | """
32 | Divide an image into tiles to help us cover classes that are spread out.
33 | tile_size: size of tile to distribute
34 | image_size: original image size
35 | return: locations of the tiles
36 | """
37 | image_size_y, image_size_x = image_size
38 | locations = []
39 | for y in range(image_size_y // tile_size):
40 | for x in range(image_size_x // tile_size):
41 | x_offs = x * tile_size
42 | y_offs = y * tile_size
43 | locations.append((x_offs, y_offs))
44 | return locations
45 |
46 |
47 | def class_centroids_image(item, tile_size, num_classes, id2trainid):
48 | """
49 | For one image, calculate centroids for all classes present in image.
50 |     item: (image_fn, label_fn) pair
51 |     tile_size: size of the square tile in pixels
52 |     num_classes: number of classes
53 |     id2trainid: mapping from original id to training ids
54 |     return: dict mapping each class id to a list of per-tile centroids
55 | """
56 | image_fn, label_fn = item
57 | centroids = defaultdict(list)
58 | mask = np.array(Image.open(label_fn))
59 | if len(mask.shape) == 3:
60 | # Remove instance mask
61 | mask = mask[:,:,0]
62 | image_size = mask.shape
63 | tile_locations = calc_tile_locations(tile_size, image_size)
64 |
65 | mask_copy = mask.copy()
66 | if id2trainid:
67 | for k, v in id2trainid.items():
68 | mask[mask_copy == k] = v
69 |
70 | for x_offs, y_offs in tile_locations:
71 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
72 | for class_id in range(num_classes):
73 | if class_id in patch:
74 | patch_class = (patch == class_id).astype(int)
75 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
76 | centroid_y = int(centroid_y) + y_offs
77 | centroid_x = int(centroid_x) + x_offs
78 | centroid = (centroid_x, centroid_y)
79 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
80 | pbar.update(1)
81 | return centroids
82 |
83 | import scipy.misc as m
84 |
85 | def class_centroids_image_from_color(item, tile_size, num_classes, id2trainid):
86 | """
87 | For one image, calculate centroids for all classes present in image.
88 |     item: (image_fn, label_fn) pair
89 |     tile_size: size of the square tile in pixels
90 |     num_classes: number of classes
91 |     id2trainid: mapping from original id to training ids
92 |     return: dict mapping each class id to a list of per-tile centroids
93 | """
94 | image_fn, label_fn = item
95 | centroids = defaultdict(list)
96 | mask = m.imread(label_fn)
97 | image_size = mask[:,:,0].shape
98 | tile_locations = calc_tile_locations(tile_size, image_size)
99 |
100 | # mask = m.imread(label_fn)
101 | # mask_copy = np.full((img.size[1], img.size[0]), 255, dtype=np.uint8)
102 | # for k, v in id2trainid.items():
103 | # mask_copy[(mask == k)[:,:,0]] = v
104 | # mask = Image.fromarray(mask_copy.astype(np.uint8))
105 |
106 | # mask_copy = mask.copy()
107 | # mask_copy = mask.copy()
108 | # if id2trainid:
109 | # for k, v in id2trainid.items():
110 | # mask[mask_copy == k] = v
111 |
112 | mask_copy = np.full(image_size, 255, dtype=np.uint8)
113 |
114 | if id2trainid:
115 | for k, v in id2trainid.items():
116 | # print("0", mask.shape)
117 | # print("1", ((mask == np.array(k))[:,:,0]).shape) # 1052, 1914
118 | # # print("2", mask == np.array(k)[:,:,0])
119 | # break
120 | # if v != 255:
121 | # print(v)
122 | # if v == 2:
123 | # print(k, v, "num", np.count_nonzero(mask == np.array(k)))
124 | # break
125 | if v != 255 and v != -1:
126 | mask_copy[(mask == np.array(k))[:,:,0] & (mask == np.array(k))[:,:,1] & (mask == np.array(k))[:,:,2]] = v
127 | mask = mask_copy
128 |
129 | # mask_copy = mask.copy()
130 | # if id2trainid:
131 | # for k, v in id2trainid.items():
132 | # mask[mask_copy == k] = v
133 |
134 | for x_offs, y_offs in tile_locations:
135 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
136 | for class_id in range(num_classes):
137 | if class_id in patch:
138 | patch_class = (patch == class_id).astype(int)
139 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
140 | centroid_y = int(centroid_y) + y_offs
141 | centroid_x = int(centroid_x) + x_offs
142 | centroid = (centroid_x, centroid_y)
143 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
144 | pbar.update(1)
145 | return centroids
146 |
147 | def pooled_class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024):
148 | """
149 | Calculate class centroids for all classes for all images for all tiles.
150 | items: list of (image_fn, label_fn)
151 | tile size: size of tile
152 | returns: dict that contains a list of centroids for each class
153 | """
154 | from multiprocessing.dummy import Pool
155 | from functools import partial
156 | pool = Pool(32)
157 | global pbar
158 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
159 | class_centroids_item = partial(class_centroids_image_from_color,
160 | num_classes=num_classes,
161 | id2trainid=id2trainid,
162 | tile_size=tile_size)
163 |
164 | centroids = defaultdict(list)
165 | new_centroids = pool.map(class_centroids_item, items)
166 | pool.close()
167 | pool.join()
168 |
169 | # combine each image's items into a single global dict
170 | for image_items in new_centroids:
171 | for class_id in image_items:
172 | centroids[class_id].extend(image_items[class_id])
173 | return centroids
174 |
175 |
176 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
177 | """
178 | Calculate class centroids for all classes for all images for all tiles.
179 | items: list of (image_fn, label_fn)
180 | tile size: size of tile
181 | returns: dict that contains a list of centroids for each class
182 | """
183 | from multiprocessing.dummy import Pool
184 | from functools import partial
185 | pool = Pool(80)
186 | global pbar
187 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
188 | class_centroids_item = partial(class_centroids_image,
189 | num_classes=num_classes,
190 | id2trainid=id2trainid,
191 | tile_size=tile_size)
192 |
193 | centroids = defaultdict(list)
194 | new_centroids = pool.map(class_centroids_item, items)
195 | pool.close()
196 | pool.join()
197 |
198 | # combine each image's items into a single global dict
199 | for image_items in new_centroids:
200 | for class_id in image_items:
201 | centroids[class_id].extend(image_items[class_id])
202 | return centroids
203 |
204 |
205 | def unpooled_class_centroids_all(items, num_classes, tile_size=1024):
206 | """
207 | Calculate class centroids for all classes for all images for all tiles.
208 | items: list of (image_fn, label_fn)
209 | tile size: size of tile
210 | returns: dict that contains a list of centroids for each class
211 | """
212 | centroids = defaultdict(list)
213 | global pbar
214 | pbar = tqdm(total=len(items), desc='centroid extraction')
215 | for image, label in items:
216 |         new_centroids = class_centroids_image((image, label),
217 |                                                tile_size,
218 |                                                num_classes, id2trainid=None)
219 | for class_id in new_centroids:
220 | centroids[class_id].extend(new_centroids[class_id])
221 |
222 | return centroids
223 |
224 |
225 | def class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024):
226 | """
227 | intermediate function to call pooled_class_centroid
228 | """
229 |
230 | pooled_centroids = pooled_class_centroids_all_from_color(items, num_classes,
231 | id2trainid, tile_size)
232 | return pooled_centroids
233 |
234 |
235 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
236 | """
237 | intermediate function to call pooled_class_centroid
238 | """
239 |
240 | pooled_centroids = pooled_class_centroids_all(items, num_classes,
241 | id2trainid, tile_size)
242 | return pooled_centroids
243 |
244 |
245 | def random_sampling(alist, num):
246 | """
247 | Randomly sample num items from the list
248 | alist: list of centroids to sample from
249 |     num: number of samples to draw; may exceed the list length, in which case sampling wraps around
250 | return: class uniform samples from the list
251 | """
252 | sampling = []
253 | len_list = len(alist)
254 | assert len_list, 'len_list is zero!'
255 | indices = np.arange(len_list)
256 | np.random.shuffle(indices)
257 |
258 | for i in range(num):
259 | item = alist[indices[i % len_list]]
260 | sampling.append(item)
261 | return sampling
262 |
263 |
264 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct):
265 | """
266 |     Generate an epoch's worth of crops using uniform sampling. Needs to be called every epoch.
267 | imgs: list of imgs
268 | centroids:
269 | num_classes:
270 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch )
271 | """
272 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct))
273 | num_epoch = int(len(imgs))
274 |
275 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch))
276 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes)
277 | num_rand = num_epoch - num_per_class * num_classes
278 | # create random crops
279 | imgs_uniform = random_sampling(imgs, num_rand)
280 |
281 | # now add uniform sampling
282 | for class_id in range(num_classes):
283 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id]))
284 | logging.info(string_format)
285 | for class_id in range(num_classes):
286 | centroid_len = len(centroids[class_id])
287 | if centroid_len == 0:
288 | pass
289 | else:
290 | class_centroids = random_sampling(centroids[class_id], num_per_class)
291 | imgs_uniform.extend(class_centroids)
292 |
293 | return imgs_uniform
294 |
--------------------------------------------------------------------------------
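`build_epoch` splits each epoch between purely random samples and class-centred (centroid) samples; the arithmetic is easiest to see with concrete numbers (the values below are only an example):

```
num_epoch = 2975            # len(imgs)
class_uniform_pct = 0.5
num_classes = 19

num_per_class = int((num_epoch * class_uniform_pct) / num_classes)
num_rand = num_epoch - num_per_class * num_classes
print(num_per_class, num_rand)   # 78 class-centred samples per class, 1493 random samples
```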
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss.py
3 | """
4 |
5 | import logging
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import datasets
11 | from config import cfg
12 |
13 |
14 | def get_loss(args):
15 | """
16 | Get the criterion based on the loss function
17 | args: commandline arguments
18 | return: criterion, criterion_val
19 | """
20 | if args.cls_wt_loss:
21 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
22 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
23 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
24 | else:
25 | ce_weight = None
26 |
27 | if args.img_wt_loss:
28 | criterion = ImageBasedCrossEntropyLoss2d(
29 | classes=datasets.num_classes, size_average=True,
30 | ignore_index=datasets.ignore_label,
31 | upper_bound=args.wt_bound).cuda()
32 | elif args.jointwtborder:
33 | criterion = ImgWtLossSoftNLL(classes=datasets.num_classes,
34 | ignore_index=datasets.ignore_label,
35 | upper_bound=args.wt_bound).cuda()
36 | else:
37 | print("standard cross entropy")
38 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean',
39 | ignore_index=datasets.ignore_label).cuda()
40 |
41 | criterion_val = nn.CrossEntropyLoss(reduction='mean',
42 | ignore_index=datasets.ignore_label).cuda()
43 | return criterion, criterion_val
44 |
45 | def get_loss_by_epoch(args):
46 | """
47 | Get the criterion based on the loss function
48 | args: commandline arguments
49 | return: criterion, criterion_val
50 | """
51 |
52 | if args.img_wt_loss:
53 | criterion = ImageBasedCrossEntropyLoss2d(
54 | classes=datasets.num_classes, size_average=True,
55 | ignore_index=datasets.ignore_label,
56 | upper_bound=args.wt_bound).cuda()
57 | elif args.jointwtborder:
58 | criterion = ImgWtLossSoftNLL_by_epoch(classes=datasets.num_classes,
59 | ignore_index=datasets.ignore_label,
60 | upper_bound=args.wt_bound).cuda()
61 | else:
62 | criterion = CrossEntropyLoss2d(size_average=True,
63 | ignore_index=datasets.ignore_label).cuda()
64 |
65 | criterion_val = CrossEntropyLoss2d(size_average=True,
66 | weight=None,
67 | ignore_index=datasets.ignore_label).cuda()
68 | return criterion, criterion_val
69 |
70 |
71 | def get_loss_aux(args):
72 | """
73 | Get the criterion based on the loss function
74 | args: commandline arguments
75 | return: criterion, criterion_val
76 | """
77 | if args.cls_wt_loss:
78 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
79 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
80 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
81 | else:
82 | ce_weight = None
83 |
84 | print("standard cross entropy")
85 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean',
86 | ignore_index=datasets.ignore_label).cuda()
87 |
88 | return criterion
89 |
90 | def get_loss_bcelogit(args):
91 | if args.cls_wt_loss:
92 | pos_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
93 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
94 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
95 | else:
96 | pos_weight = None
97 | print("standard bce with logit cross entropy")
98 | criterion = nn.BCEWithLogitsLoss(reduction='mean').cuda()
99 |
100 | return criterion
101 |
102 | def weighted_binary_cross_entropy(output, target):
103 |
104 | weights = torch.Tensor([0.1, 0.9])
105 |
106 | loss = weights[1] * (target * torch.log(output)) + \
107 | weights[0] * ((1 - target) * torch.log(1 - output))
108 |
109 | return torch.neg(torch.mean(loss))
110 |
111 |
112 | class L1Loss(nn.Module):
113 | def __init__(self):
114 | super(L1Loss, self).__init__()
115 |
116 | def __call__(self, in0, in1):
117 | return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True)
118 |
119 |
120 | class ImageBasedCrossEntropyLoss2d(nn.Module):
121 | """
122 | Image Weighted Cross Entropy Loss
123 | """
124 |
125 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255,
126 | norm=False, upper_bound=1.0):
127 | super(ImageBasedCrossEntropyLoss2d, self).__init__()
128 | logging.info("Using Per Image based weighted loss")
129 | self.num_classes = classes
130 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index)
131 | self.norm = norm
132 | self.upper_bound = upper_bound
133 | self.batch_weights = cfg.BATCH_WEIGHTING
134 | self.logsoftmax = nn.LogSoftmax(dim=1)
135 |
136 | def calculate_weights(self, target):
137 | """
138 | Calculate weights of classes based on the training crop
139 | """
140 | hist = np.histogram(target.flatten(), range(
141 | self.num_classes + 1), normed=True)[0]
142 | if self.norm:
143 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
144 | else:
145 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
146 | return hist
147 |
148 | def forward(self, inputs, targets):
149 |
150 | target_cpu = targets.data.cpu().numpy()
151 | if self.batch_weights:
152 | weights = self.calculate_weights(target_cpu)
153 | self.nll_loss.weight = torch.Tensor(weights).cuda()
154 |
155 | loss = 0.0
156 | for i in range(0, inputs.shape[0]):
157 | if not self.batch_weights:
158 | weights = self.calculate_weights(target_cpu[i])
159 | self.nll_loss.weight = torch.Tensor(weights).cuda()
160 |
161 | loss += self.nll_loss(self.logsoftmax(inputs[i].unsqueeze(0)),
162 | targets[i].unsqueeze(0))
163 | return loss
164 |
165 |
166 |
167 | class CrossEntropyLoss2d(nn.Module):
168 | """
169 |     Cross Entropy NLL Loss
170 | """
171 |
172 | def __init__(self, weight=None, size_average=True, ignore_index=255):
173 | super(CrossEntropyLoss2d, self).__init__()
174 | logging.info("Using Cross Entropy Loss")
175 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index)
176 | self.logsoftmax = nn.LogSoftmax(dim=1)
177 | # self.weight = weight
178 |
179 | def forward(self, inputs, targets):
180 | return self.nll_loss(self.logsoftmax(inputs), targets)
181 |
182 | def customsoftmax(inp, multihotmask):
183 | """
184 | Custom Softmax
185 | """
186 | soft = F.softmax(inp, dim=1)
187 |     # Multiply the softmax by the multi-hot mask and sum over classes (pooling the
188 |     # border classes), then take the element-wise max of the pooled and original softmax
189 | return torch.log(
190 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True)))
191 | )
192 |
193 | class ImgWtLossSoftNLL(nn.Module):
194 | """
195 | Relax Loss
196 | """
197 |
198 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
199 | norm=False):
200 | super(ImgWtLossSoftNLL, self).__init__()
201 | self.weights = weights
202 | self.num_classes = classes
203 | self.ignore_index = ignore_index
204 | self.upper_bound = upper_bound
205 | self.norm = norm
206 | self.batch_weights = cfg.BATCH_WEIGHTING
207 |
208 | def calculate_weights(self, target):
209 | """
210 | Calculate weights of the classes based on training crop
211 | """
212 | if len(target.shape) == 3:
213 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
214 | else:
215 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
216 | if self.norm:
217 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
218 | else:
219 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
220 | return hist[:-1]
221 |
222 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
223 | """
224 | NLL Relaxed Loss Implementation
225 | """
226 | if (cfg.REDUCE_BORDER_ITER != -1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
227 | border_weights = 1 / border_weights
228 | target[target > 1] = 1
229 |
230 | loss_matrix = (-1 / border_weights *
231 | (target[:, :-1, :, :].float() *
232 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
233 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
234 | (1. - mask.float())
235 |
236 | # loss_matrix[border_weights > 1] = 0
237 | loss = loss_matrix.sum()
238 |
239 | # +1 to prevent division by 0
240 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
241 | return loss
242 |
243 | def forward(self, inputs, target):
244 | weights = target[:, :-1, :, :].sum(1).float()
245 | ignore_mask = (weights == 0)
246 | weights[ignore_mask] = 1
247 |
248 | loss = 0
249 | target_cpu = target.data.cpu().numpy()
250 |
251 | if self.batch_weights:
252 | class_weights = self.calculate_weights(target_cpu)
253 |
254 | for i in range(0, inputs.shape[0]):
255 | if not self.batch_weights:
256 | class_weights = self.calculate_weights(target_cpu[i])
257 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
258 | target[i].unsqueeze(0),
259 | class_weights=torch.Tensor(class_weights).cuda(),
260 | border_weights=weights[i], mask=ignore_mask[i])
261 |
262 | loss = loss / inputs.shape[0]
263 | return loss
264 |
265 | class ImgWtLossSoftNLL_by_epoch(nn.Module):
266 | """
267 | Relax Loss
268 | """
269 |
270 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
271 | norm=False):
272 | super(ImgWtLossSoftNLL_by_epoch, self).__init__()
273 | self.weights = weights
274 | self.num_classes = classes
275 | self.ignore_index = ignore_index
276 | self.upper_bound = upper_bound
277 | self.norm = norm
278 | self.batch_weights = cfg.BATCH_WEIGHTING
279 | self.fp16 = False
280 |
281 |
282 | def calculate_weights(self, target):
283 | """
284 | Calculate weights of the classes based on training crop
285 | """
286 | if len(target.shape) == 3:
287 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
288 | else:
289 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
290 | if self.norm:
291 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
292 | else:
293 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
294 | return hist[:-1]
295 |
296 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
297 | """
298 | NLL Relaxed Loss Implementation
299 | """
300 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
301 | border_weights = 1 / border_weights
302 | target[target > 1] = 1
303 | if self.fp16:
304 | loss_matrix = (-1 / border_weights *
305 | (target[:, :-1, :, :].half() *
306 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
307 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \
308 | (1. - mask.half())
309 | else:
310 | loss_matrix = (-1 / border_weights *
311 | (target[:, :-1, :, :].float() *
312 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
313 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
314 | (1. - mask.float())
315 |
316 | # loss_matrix[border_weights > 1] = 0
317 | loss = loss_matrix.sum()
318 |
319 | # +1 to prevent division by 0
320 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
321 | return loss
322 |
323 | def forward(self, inputs, target):
324 | if self.fp16:
325 | weights = target[:, :-1, :, :].sum(1).half()
326 | else:
327 | weights = target[:, :-1, :, :].sum(1).float()
328 | ignore_mask = (weights == 0)
329 | weights[ignore_mask] = 1
330 |
331 | loss = 0
332 | target_cpu = target.data.cpu().numpy()
333 |
334 | if self.batch_weights:
335 | class_weights = self.calculate_weights(target_cpu)
336 |
337 | for i in range(0, inputs.shape[0]):
338 | if not self.batch_weights:
339 | class_weights = self.calculate_weights(target_cpu[i])
340 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
341 | target[i].unsqueeze(0),
342 | class_weights=torch.Tensor(class_weights).cuda(),
343 | border_weights=weights, mask=ignore_mask[i])
344 |
345 | return loss
346 |
--------------------------------------------------------------------------------
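`ImageBasedCrossEntropyLoss2d.calculate_weights` re-weights classes per crop: present-but-rare classes approach weight 2, dominant classes stay near 1, and absent classes remain exactly 1. A standalone numeric sketch (using NumPy's current `density=True` spelling of the deprecated `normed=True` argument used above):

```
import numpy as np

num_classes, upper_bound = 19, 1.0
target = np.array([0, 0, 0, 1, 1, 255])      # toy crop; 255 falls outside the bins

hist = np.histogram(target.flatten(), range(num_classes + 1), density=True)[0]
weights = ((hist != 0) * upper_bound * (1 - hist)) + 1
print(weights[:3])                           # [1.4 1.6 1. ]
```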
/network/Resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code Adapted from:
3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 | import torch
37 | import torch.nn as nn
38 | import torch.utils.model_zoo as model_zoo
39 | import network.mynn as mynn
40 | from network.adain import AdaptiveInstanceNormalization
41 |
42 |
43 | __all__ = ['ResNet', 'resnet50']
44 |
45 | model_urls = {
46 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
47 | }
48 |
49 |
50 | class Bottleneck(nn.Module):
51 | """
52 | Bottleneck Layer for Resnet
53 | """
54 | expansion = 4
55 |
56 | def __init__(self, inplanes, planes, stride=1, downsample=None, fs=0):
57 | super(Bottleneck, self).__init__()
58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
59 | self.bn1 = mynn.Norm2d(planes)
60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
61 | padding=1, bias=False)
62 | self.bn2 = mynn.Norm2d(planes)
63 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
64 | self.bn3 = mynn.Norm2d(planes * self.expansion)
65 | self.downsample = downsample
66 | self.stride = stride
67 |
68 | self.fs = fs
69 | if self.fs == 1:
70 | self.instance_norm_layer = AdaptiveInstanceNormalization()
71 | self.relu = nn.ReLU(inplace=False)
72 | else:
73 | self.relu = nn.ReLU(inplace=True)
74 |
75 | def forward(self, x_tuple):
76 | if len(x_tuple) == 1:
77 | x = x_tuple[0]
78 | elif len(x_tuple) == 3:
79 | x = x_tuple[0]
80 | x_w = x_tuple[1]
81 | x_sw = x_tuple[2]
82 | else:
83 |             raise NotImplementedError("%d is not a supported tuple length" % (len(x_tuple)))
84 |
85 | residual = x
86 |
87 | out = self.conv1(x)
88 | out = self.bn1(out)
89 | out = self.relu(out)
90 |
91 | out = self.conv2(out)
92 | out = self.bn2(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv3(out)
96 | out = self.bn3(out)
97 |
98 | if self.downsample is not None:
99 | residual = self.downsample(x)
100 |
101 | out += residual
102 |
103 | if len(x_tuple) == 3:
104 | with torch.no_grad():
105 | residual_w = x_w
106 |
107 | out_w = self.conv1(x_w)
108 | out_w = self.bn1(out_w)
109 | out_w = self.relu(out_w)
110 |
111 | out_w = self.conv2(out_w)
112 | out_w = self.bn2(out_w)
113 | out_w = self.relu(out_w)
114 |
115 | out_w = self.conv3(out_w)
116 | out_w = self.bn3(out_w)
117 |
118 | if self.downsample is not None:
119 | residual_w = self.downsample(x_w)
120 |
121 | out_w += residual_w
122 |
123 | residual_sw = x_sw
124 |
125 | out_sw = self.conv1(x_sw)
126 | out_sw = self.bn1(out_sw)
127 | out_sw = self.relu(out_sw)
128 |
129 | out_sw = self.conv2(out_sw)
130 | out_sw = self.bn2(out_sw)
131 | out_sw = self.relu(out_sw)
132 |
133 | out_sw = self.conv3(out_sw)
134 | out_sw = self.bn3(out_sw)
135 |
136 | if self.downsample is not None:
137 | residual_sw = self.downsample(x_sw)
138 |
139 | out_sw += residual_sw
140 |
141 |
142 | if self.fs == 1:
143 | out = self.instance_norm_layer(out)
144 | if len(x_tuple) == 3:
145 | out_sw = self.instance_norm_layer(out_sw, out_w)
146 | with torch.no_grad():
147 | out_w = self.instance_norm_layer(out_w)
148 |
149 | out = self.relu(out)
150 |
151 | if len(x_tuple) == 3:
152 | with torch.no_grad():
153 | out_w = self.relu(out_w)
154 | out_sw = self.relu(out_sw)
155 | return [out, out_w, out_sw]
156 | else:
157 | return [out]
158 |
159 |
160 | class ResNet(nn.Module):
161 | """
162 | Resnet Global Module for Initialization
163 | """
164 |
165 | def __init__(self, block, layers, fs_layer=None, num_classes=1000):
166 | self.inplanes = 64
167 | super(ResNet, self).__init__()
168 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
169 | bias=False)
170 | if fs_layer[0] == 1:
171 | self.bn1 = AdaptiveInstanceNormalization()
172 | self.relu = nn.ReLU(inplace=False)
173 | else:
174 | self.bn1 = mynn.Norm2d(64)
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
178 | self.layer1 = self._make_layer(block, 64, layers[0], fs_layer=fs_layer[1])
179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, fs_layer=fs_layer[2])
180 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, fs_layer=fs_layer[3])
181 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, fs_layer=fs_layer[4])
182 | self.avgpool = nn.AvgPool2d(7, stride=1)
183 | self.fc = nn.Linear(512 * block.expansion, num_classes)
184 |
185 | for m in self.modules():
186 | if isinstance(m, nn.Conv2d):
187 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
188 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.SyncBatchNorm):
189 | if m.weight is not None:
190 | nn.init.constant_(m.weight, 1)
191 | if m.bias is not None:
192 | nn.init.constant_(m.bias, 0)
193 |
194 | def _make_layer(self, block, planes, blocks, stride=1, fs_layer=0):
195 | downsample = None
196 | if stride != 1 or self.inplanes != planes * block.expansion:
197 | downsample = nn.Sequential(
198 | nn.Conv2d(self.inplanes, planes * block.expansion,
199 | kernel_size=1, stride=stride, bias=False),
200 | mynn.Norm2d(planes * block.expansion),
201 | )
202 |
203 | layers = []
204 | layers.append(block(self.inplanes, planes, stride, downsample, fs=0))
205 | self.inplanes = planes * block.expansion
206 | for index in range(1, blocks):
207 | layers.append(block(self.inplanes, planes,
208 | fs=0 if (fs_layer > 0 and index < blocks - 1) else fs_layer))
209 | return nn.Sequential(*layers)
210 |
211 | def forward(self, x):
212 | x = self.conv1(x)
213 | x = self.bn1(x)
214 | x = self.relu(x)
215 |
216 | x = self.maxpool(x)
217 |
218 | x = self.layer1(x)
219 | x = self.layer2(x)
220 | x = self.layer3(x)
221 | x = self.layer4(x)
222 |
223 | x = self.avgpool(x)
224 | x = x.view(x.size(0), -1)
225 | x = self.fc(x)
226 |
227 | return x
228 |
229 |
230 | def resnet50(pretrained=True, fs_layer=None, **kwargs):
231 | """Constructs a ResNet-50 model.
232 |
233 | Args:
234 | pretrained (bool): If True, returns a model pre-trained on ImageNet
235 | """
236 | if fs_layer is None:
237 | fs_layer = [0, 0, 0, 0, 0]
238 | model = ResNet(Bottleneck, [3, 4, 6, 3], fs_layer=fs_layer, **kwargs)
239 | if pretrained:
240 | print("########### pretrained ##############")
241 | mynn.forgiving_state_restore(model, model_zoo.load_url(model_urls['resnet50']))
242 | return model
243 |
--------------------------------------------------------------------------------
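In `_make_layer`, the AdaIN-based feature stylization is attached only to the final Bottleneck of a stage whose `fs_layer` flag is nonzero; every other block keeps `fs=0`. A toy rendering of that conditional (the block count is just an example):

```
blocks, fs_layer = 6, 1   # e.g. layer3 of ResNet-50 with stylization enabled
fs_flags = [0] + [0 if (fs_layer > 0 and index < blocks - 1) else fs_layer
                  for index in range(1, blocks)]
print(fs_flags)           # [0, 0, 0, 0, 0, 1] -> only the last block stylizes
```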
/network/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Network Initializations
3 | """
4 |
5 | import logging
6 | import importlib
7 | import torch
8 | import datasets
9 |
10 |
11 |
12 | def get_net(args, criterion, criterion_aux=None, cont_proj_head=0, wild_cont_dict_size=0):
13 | """
14 | Get Network Architecture based on arguments provided
15 | """
16 | net = get_model(args=args, num_classes=datasets.num_classes,
17 | criterion=criterion, criterion_aux=criterion_aux, cont_proj_head=cont_proj_head, wild_cont_dict_size=wild_cont_dict_size)
18 | num_params = sum([param.nelement() for param in net.parameters()])
19 | logging.info('Model params = {:2.3f}M'.format(num_params / 1000000))
20 |
21 | net = net.cuda()
22 | return net
23 |
24 |
25 | def warp_network_in_dataparallel(net, gpuid):
26 | """
27 | Wrap the network in Dataparallel
28 | """
29 | # torch.cuda.set_device(gpuid)
30 | # net.cuda(gpuid)
31 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid], find_unused_parameters=True)
32 | # net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid])#, find_unused_parameters=True)
33 | return net
34 |
35 |
36 | def get_model(args, num_classes, criterion, criterion_aux=None, cont_proj_head=0, wild_cont_dict_size=0):
37 | """
38 | Fetch Network Function Pointer
39 | """
40 | network = args.arch
41 | module = network[:network.rfind('.')]
42 | model = network[network.rfind('.') + 1:]
43 | mod = importlib.import_module(module)
44 | net_func = getattr(mod, model)
45 | net = net_func(args=args, num_classes=num_classes, criterion=criterion, criterion_aux=criterion_aux, cont_proj_head=cont_proj_head, wild_cont_dict_size=wild_cont_dict_size)
46 | return net
47 |
--------------------------------------------------------------------------------
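`get_model` resolves `args.arch` as a dotted module-plus-attribute path via `importlib`. The same pattern, demonstrated on a standard-library name so it runs anywhere (in the repo, the string would instead point at a model function inside `network.deepv3`):

```
import importlib

network = 'collections.OrderedDict'          # stand-in for args.arch
module = network[:network.rfind('.')]
model = network[network.rfind('.') + 1:]
net_func = getattr(importlib.import_module(module), model)
print(net_func)                              # <class 'collections.OrderedDict'>
```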
/network/adain.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class AdaptiveInstanceNormalization(nn.Module):
6 |
7 | def __init__(self):
8 | super(AdaptiveInstanceNormalization, self).__init__()
9 |
10 | def forward(self, x_cont, x_style=None):
11 | if x_style is not None:
12 | assert (x_cont.size()[:2] == x_style.size()[:2])
13 | size = x_cont.size()
14 | style_mean, style_std = calc_mean_std(x_style)
15 | content_mean, content_std = calc_mean_std(x_cont)
16 |
17 | normalized_x_cont = (x_cont - content_mean.expand(size))/content_std.expand(size)
18 | denormalized_x_cont = normalized_x_cont * style_std.expand(size) + style_mean.expand(size)
19 |
20 | return denormalized_x_cont
21 |
22 | else:
23 | return x_cont
24 |
25 |
26 | def calc_mean_std(feat, eps=1e-5):
27 | # eps is a small value added to the variance to avoid divide-by-zero.
28 | size = feat.size()
29 | assert (len(size) == 4)
30 | N, C = size[:2]
31 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
32 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
33 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
34 | return feat_mean, feat_std
35 |
36 |
--------------------------------------------------------------------------------
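A quick sanity check (run in the context of this file) that `AdaptiveInstanceNormalization` really transplants the style feature's per-channel statistics onto the content feature:

```
import torch

adain = AdaptiveInstanceNormalization()
x_cont = torch.randn(2, 8, 16, 16)
x_style = torch.randn(2, 8, 16, 16) * 3 + 1

out = adain(x_cont, x_style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(x_style)
print(torch.allclose(out_mean, style_mean, atol=1e-4),
      torch.allclose(out_std, style_std, atol=1e-3))   # True True
```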
/network/cel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | import torchvision
5 | import numpy as np
6 |
7 |
8 | def get_content_extension_loss(feats_s, feats_sw, feats_w, gts, queue):
9 |
10 | B, C, H, W = feats_s.shape # feat feature size (B X C X H X W)
11 |
12 | # uniform sampling with a size of 64 x 64 for source and wild-stylized source feature maps
13 | H_W_resize = 64
14 | HW = H_W_resize * H_W_resize
15 |
16 | upsample_n = nn.Upsample(size=[H_W_resize, H_W_resize], mode='nearest')
17 | feats_s_flat = upsample_n(feats_s)
18 | feats_sw_flat = upsample_n(feats_sw)
19 |
20 | feats_s_flat = feats_s_flat.contiguous().view(B, C, -1) # B X C X H X W > B X C X (H X W)
21 | feats_sw_flat = feats_sw_flat.contiguous().view(B, C, -1) # B X C X H X W > B X C X (H X W)
22 | gts_flat = upsample_n(gts.unsqueeze(1).float()).squeeze(1).long().view(B, HW)
23 |
24 | # uniform sampling with a size of 16 x 16 for wild feature map
25 | H_W_resize_w = 16
26 | HW_w = H_W_resize_w * H_W_resize_w
27 |
28 | upsample_n_w = nn.Upsample(size=[H_W_resize_w, H_W_resize_w], mode='nearest')
29 | feats_w_flat = upsample_n_w(feats_w)
30 | feats_w_flat = torch.einsum("bchw->bhwc", feats_w_flat).contiguous().view(B*H_W_resize_w*H_W_resize_w, C) # B X C X H X W > (B X H X W) X C
31 |
32 | # normalize feature of each pixel
33 | feats_s_flat = nn.functional.normalize(feats_s_flat, p=2, dim=1)
34 | feats_sw_flat = nn.functional.normalize(feats_sw_flat, p=2, dim=1)
35 | feats_w_flat = nn.functional.normalize(feats_w_flat, p=2, dim=1).detach() # (B X H X W) X C
36 |
37 |
38 | # log(dot(feats_s_flat, feats_sw_flat))
39 | T = 0.07
40 | logits_sce = torch.bmm(feats_s_flat.transpose(1,2), feats_sw_flat) # dot product: B X (H X W) X (H X W)
41 | logits_sce = (torch.clamp(logits_sce, min=-1, max=1))/T
42 |
43 | # compute ignore mask: same-class (excluding self) + unknown-labeled pixels
44 | # compute positive mask (same-class)
45 | logits_mask_sce_ignore = torch.eq(gts_flat.unsqueeze(2), gts_flat.unsqueeze(1)) # pos:1, neg:0. B X (H X W) X (H X W)
46 | # include unknown-labeled pixel
47 | logits_mask_sce_ignore = include_unknown(logits_mask_sce_ignore, gts_flat)
48 |
49 | # exclude self-pixel
50 | logits_mask_sce_ignore *= ~torch.eye(HW,HW).type(torch.cuda.BoolTensor).unsqueeze(0).expand([B, -1, -1]) # self:1, other:0. B X (H X W) X (H X W)
51 |
52 | # compute positive mask for cross entropy loss: B X (H X W)
53 | logits_mask_sce_pos = torch.linspace(start=0, end=HW-1, steps=HW).unsqueeze(0).expand([B, -1]).type(torch.cuda.LongTensor)
54 |
55 | # compute unknown-labeled mask for cross entropy loss: B X (H X W)
56 | logits_mask_sce_unk = torch.zeros_like(logits_mask_sce_pos, dtype=torch.bool)
57 | logits_mask_sce_unk[gts_flat>254] = True
58 |
59 | # compute loss_sce
60 | eps = 1e-5
61 | logits_sce[logits_mask_sce_ignore] = -1/T
62 | CELoss = nn.CrossEntropyLoss(reduction='none')
63 | loss_sce = CELoss(logits_sce.transpose(1,2), logits_mask_sce_pos)
64 | loss_sce = ((loss_sce * (~logits_mask_sce_unk)).sum(1) / ((~logits_mask_sce_unk).sum(1) + eps)).mean()
65 |
66 |
67 | # get wild content closest to wild-stylized source content
68 | idx_sim_bs = 512
69 | index_nearest_neighbours = (torch.randn(0)).type(torch.cuda.LongTensor)
70 | for idx_sim in range(int(np.ceil(HW/idx_sim_bs))):
71 | idx_sim_start = idx_sim*idx_sim_bs
72 | idx_sim_end = min((idx_sim+1)*idx_sim_bs, HW)
73 | similarity_matrix = torch.einsum("bcn,cq->bnq", feats_sw_flat[:,:,idx_sim_start:idx_sim_end].type(torch.cuda.HalfTensor), queue['wild'].type(torch.cuda.HalfTensor)) # B X (H X W) X Q
74 | index_nearest_neighbours = torch.cat((index_nearest_neighbours,torch.argmax(similarity_matrix, dim=2)),dim=1) # B X (H X W)
75 | # similarity_matrix = torch.einsum("bcn,cq->bnq", feats_sw_flat, queue['wild']) # B X (H X W) X Q
76 | # index_nearest_neighbours = torch.argmax(similarity_matrix, dim=2) # B X (H X W)
77 | del similarity_matrix
78 | nearest_neighbours = torch.index_select(queue['wild'], dim=1, index=index_nearest_neighbours.view(-1)).view(C, B, HW) # C X B X (H X W)
79 |
80 | # compute exp(dot(feats_s_flat, nearest_neighbours))
81 | logits_wce_pos = torch.einsum("bcn,cbn->bn", feats_s_flat, nearest_neighbours) # dot product: B X C X (H X W) & C X B X (H X W) => B X (H X W)
82 | logits_wce_pos = (torch.clamp(logits_wce_pos, min=-1, max=1))/T
83 | exp_logits_wce_pos = torch.exp(logits_wce_pos)
84 |
85 | # compute negative mask of logits_sce
86 | logits_mask_sce_neg = ~torch.eq(gts_flat.unsqueeze(2), gts_flat.unsqueeze(1)) # pos:0, neg:1. B X (H X W) X (H X W)
87 |
88 | # exclude unknown-labeled pixels from negative samples
89 | logits_mask_sce_neg = exclude_unknown(logits_mask_sce_neg, gts_flat)
90 |
91 | # sum exp(neg samples)
92 | exp_logits_sce_neg = (torch.exp(logits_sce) * logits_mask_sce_neg).sum(2) # B X (H X W)
93 |
94 | # Compute log_prob
95 | log_prob_wce = logits_wce_pos - torch.log(exp_logits_wce_pos + exp_logits_sce_neg) # B X (H X W)
96 |
97 | # Compute loss_wce
98 | loss_wce = -((log_prob_wce * (~logits_mask_sce_unk)).sum(1) / ((~logits_mask_sce_unk).sum(1) + eps)).mean()
99 |
100 |
101 | # enqueue wild contents
102 | sup_enqueue = feats_w_flat # Q X C # (B X H X W) X C
103 | _dequeue_and_enqueue(queue, sup_enqueue)
104 |
105 | # compute content extension learning loss
106 | loss_cel = loss_sce + loss_wce
107 |
108 | return loss_cel
109 |
110 | def exclude_unknown(mask, gts):
111 | '''
112 | mask: [B, HW, HW]
113 | gts: [B, HW]
114 | '''
115 | mask = mask.transpose(1,2).contiguous()
116 | mask[gts>254,:] = False
117 | mask = mask.transpose(1,2).contiguous()
118 |
119 | return mask
120 |
121 | def include_unknown(mask, gts):
122 | '''
123 | mask: [B, HW, HW]
124 | gts: [B, HW]
125 | '''
126 | mask = mask.transpose(1,2).contiguous()
127 | mask[gts>254,:] = True
128 | mask = mask.transpose(1,2).contiguous()
129 |
130 | return mask
131 |
132 | @torch.no_grad()
133 | def _dequeue_and_enqueue(queue, keys):
134 | # gather keys before updating queue
135 | keys = concat_all_gather(keys) # (B X H X W) X C
136 |
137 | batch_size = keys.shape[0]
138 |
139 | ptr = int(queue['wild_ptr'])
140 |
141 | # replace the keys at ptr (dequeue and enqueue)
142 | if (ptr + batch_size) <= queue['size']:
143 | # wild queue
144 | queue['wild'][:, ptr:ptr + batch_size] = keys.T
145 | ptr = (ptr + batch_size) % queue['size'] # move pointer
146 | else:
147 | # wild queue
148 | last_input_num = queue['size'] - ptr
149 | queue['wild'][:,ptr:] = (keys.T)[:,:last_input_num]
150 | ptr = (ptr + batch_size) % queue['size'] # move pointer
151 | queue['wild'][:,:ptr] = (keys.T)[:,last_input_num:]
152 | queue['wild_ptr'][0] = ptr
153 |
154 | # utils
155 | @torch.no_grad()
156 | def concat_all_gather(tensor):
157 | """
158 | Performs all_gather operation on the provided tensors.
159 | *** Warning ***: torch.distributed.all_gather has no gradient.
160 | """
161 |
162 | tensors_gather = varsize_tensor_all_gather(tensor)
163 |
164 | output = tensors_gather
165 | return output
166 |
167 | def varsize_tensor_all_gather(tensor: torch.Tensor):
168 | tensor = tensor.contiguous()
169 |
170 | cuda_device = f'cuda:{torch.distributed.get_rank()}'
171 | size_tens = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=cuda_device)
172 |
173 | size_tens2 = [torch.ones_like(size_tens)
174 | for _ in range(torch.distributed.get_world_size())]
175 |
176 | torch.distributed.all_gather(size_tens2, size_tens)
177 | size_tens2 = torch.cat(size_tens2, dim=0).cpu()
178 | max_size = size_tens2.max()
179 |
180 | padded = torch.empty(max_size, *tensor.shape[1:],
181 | dtype=tensor.dtype,
182 | device=cuda_device)
183 | padded[:tensor.shape[0]] = tensor
184 |
185 | ag = [torch.ones_like(padded)
186 | for _ in range(torch.distributed.get_world_size())]
187 |
188 | torch.distributed.all_gather(ag,padded)
189 | ag = torch.cat(ag, dim=0)
190 |
191 | slices = []
192 | for i, sz in enumerate(size_tens2):
193 | start_idx = i * max_size
194 | end_idx = start_idx + sz.item()
195 |
196 | if end_idx > start_idx:
197 | slices.append(ag[start_idx:end_idx])
198 |
199 | ret = torch.cat(slices, dim=0)
200 |
201 | return ret.to(tensor)
202 |
--------------------------------------------------------------------------------
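The wild-content dictionary consumed by `get_content_extension_loss` is a fixed-size ring buffer of normalized wild features. The wrap-around update in `_dequeue_and_enqueue` can be exercised in isolation (the distributed `all_gather` step is omitted, and the dictionary sizes here are toy values):

```
import torch

C, Q = 4, 6
queue = {'size': Q,
         'wild': torch.zeros(C, Q),
         'wild_ptr': torch.zeros(1, dtype=torch.long)}

keys = torch.ones(5, C)                       # 5 new wild feature vectors
ptr, batch_size = int(queue['wild_ptr']), keys.shape[0]

if ptr + batch_size <= queue['size']:
    queue['wild'][:, ptr:ptr + batch_size] = keys.T
else:                                         # wrap around the end of the buffer
    last = queue['size'] - ptr
    queue['wild'][:, ptr:] = keys.T[:, :last]
    queue['wild'][:, :(ptr + batch_size) % queue['size']] = keys.T[:, last:]
queue['wild_ptr'][0] = (ptr + batch_size) % queue['size']
print(int(queue['wild_ptr']))                 # 5
```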
/network/deepv3.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code Adapted from:
3 | # https://github.com/sthalles/deeplab_v3
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2018 Thalles Santos Silva
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 | """
26 | import logging
27 | import torch
28 | from torch import nn
29 | from network import Resnet
30 | from network.mynn import initialize_weights, Norm2d, Upsample
31 |
32 | import torchvision.models as models
33 |
34 | from network.cel import get_content_extension_loss
35 | import torchvision
36 |
37 |
38 | class _AtrousSpatialPyramidPoolingModule(nn.Module):
39 | """
40 | operations performed:
41 | 1x1 x depth
42 | 3x3 x depth dilation 6
43 | 3x3 x depth dilation 12
44 | 3x3 x depth dilation 18
45 | image pooling
46 | concatenate all together
47 | Final 1x1 conv
48 | """
49 |
50 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)):
51 | super(_AtrousSpatialPyramidPoolingModule, self).__init__()
52 |
53 | # Check if we are using distributed BN and use the nn from encoding.nn
54 | # library rather than using standard pytorch.nn
55 | print("output_stride = ", output_stride)
56 | if output_stride == 8:
57 | rates = [2 * r for r in rates]
58 | elif output_stride == 4:
59 | rates = [4 * r for r in rates]
60 | elif output_stride == 16:
61 | pass
62 | elif output_stride == 32:
63 | rates = [r // 2 for r in rates]
64 | else:
65 |             raise ValueError('output stride of {} not supported'.format(output_stride))
66 |
67 | self.features = []
68 | # 1x1
69 | self.features.append(
70 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
71 | Norm2d(reduction_dim), nn.ReLU(inplace=True)))
72 | # other rates
73 | for r in rates:
74 | self.features.append(nn.Sequential(
75 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
76 | dilation=r, padding=r, bias=False),
77 | Norm2d(reduction_dim),
78 | nn.ReLU(inplace=True)
79 | ))
80 | self.features = torch.nn.ModuleList(self.features)
81 |
82 | # img level features
83 | self.img_pooling = nn.AdaptiveAvgPool2d(1)
84 | self.img_conv = nn.Sequential(
85 | nn.Conv2d(in_dim, 256, kernel_size=1, bias=False),
86 | Norm2d(256), nn.ReLU(inplace=True))
87 |
88 | def forward(self, x):
89 | x_size = x.size()
90 |
91 | img_features = self.img_pooling(x)
92 | img_features = self.img_conv(img_features)
93 | img_features = Upsample(img_features, x_size[2:])
94 | out = img_features
95 |
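        # The image-pooling branch is concatenated with the 1x1 branch and the three
        # dilated 3x3 branches, each with reduction_dim (=256) channels, so the default
        # output has 5 * 256 = 1280 channels -- matching the 1280 in_channels expected
        # by bot_aspp in DeepV3Plus below.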
96 | for f in self.features:
97 | y = f(x)
98 | out = torch.cat((out, y), 1)
99 | return out
100 |
101 |
102 | class DeepV3Plus(nn.Module):
103 | """
104 | Implement DeepLab-V3 model
105 | A: stride8
106 | B: stride16
107 | with skip connections
108 | """
109 |
110 | def __init__(self, num_classes, trunk='resnet-50', criterion=None, criterion_aux=None, cont_proj_head=0, wild_cont_dict_size=0,
111 | variant='D16', skip='m1', skip_num=48, args=None):
112 | super(DeepV3Plus, self).__init__()
113 |
114 | self.args = args
115 |
116 | # loss
117 | self.criterion = criterion
118 | self.criterion_aux = criterion_aux
119 | self.criterion_kl = nn.KLDivLoss(reduction='batchmean').cuda()
120 |
121 | # create the wild-content dictionary
122 | self.cont_proj_head = cont_proj_head
123 | if wild_cont_dict_size > 0:
124 | if cont_proj_head > 0:
125 | self.cont_dict = {}
126 | self.cont_dict['size'] = wild_cont_dict_size
127 | self.cont_dict['dim'] = self.cont_proj_head
128 |
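                # The wild-content dictionary is a (dim x size) buffer of L2-normalized
                # feature columns with a FIFO write pointer, used as a memory bank of
                # wild content features by get_content_extension_loss (network/cel.py);
                # it behaves like a MoCo-style feature queue.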
129 | self.register_buffer("wild_cont_dict", torch.randn(self.cont_dict['dim'], self.cont_dict['size']))
130 | self.wild_cont_dict = nn.functional.normalize(self.wild_cont_dict, p=2, dim=0) # C X Q
131 | self.register_buffer("wild_cont_dict_ptr", torch.zeros(1, dtype=torch.long))
132 | self.cont_dict['wild'] = self.wild_cont_dict.cuda()
133 | self.cont_dict['wild_ptr'] = self.wild_cont_dict_ptr
134 | else:
135 |                 raise ValueError('dimension of wild-content dictionary is zero')
136 |
137 | # set backbone
138 | self.variant = variant
139 | self.trunk = trunk
140 |
141 | channel_1st = 3
142 | channel_2nd = 64
143 | channel_3rd = 256
144 | channel_4th = 512
145 | prev_final_channel = 1024
146 | final_channel = 2048
147 |
148 | if trunk == 'resnet-50':
149 | resnet = Resnet.resnet50(fs_layer=self.args.fs_layer)
150 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
151 | else:
152 | raise ValueError("Not a valid network arch")
153 |
154 | self.layer0 = resnet.layer0
155 | self.layer1, self.layer2, self.layer3, self.layer4 = \
156 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
157 |
158 | if self.variant == 'D16':
159 | os = 16
160 | for n, m in self.layer4.named_modules():
161 | if 'conv2' in n:
162 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
163 | elif 'downsample.0' in n:
164 | m.stride = (1, 1)
165 | else:
166 |             raise ValueError('unknown deepv3 variant: {}'.format(self.variant))
167 |
168 | self.output_stride = os
169 | self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256,
170 | output_stride=os)
171 |
172 | self.bot_fine = nn.Sequential(
173 | nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False),
174 | Norm2d(48),
175 | nn.ReLU(inplace=True))
176 |
177 | self.bot_aspp = nn.Sequential(
178 | nn.Conv2d(1280, 256, kernel_size=1, bias=False),
179 | Norm2d(256),
180 | nn.ReLU(inplace=True))
181 |
182 | self.final1 = nn.Sequential(
183 | nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False),
184 | Norm2d(256),
185 | nn.ReLU(inplace=True),
186 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
187 | Norm2d(256),
188 | nn.ReLU(inplace=True))
189 |
190 | self.final2 = nn.Sequential(
191 | nn.Conv2d(256, num_classes, kernel_size=1, bias=True))
192 |
193 | self.dsn = nn.Sequential(
194 | nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1),
195 | Norm2d(512),
196 | nn.ReLU(inplace=True),
197 | nn.Dropout2d(0.1),
198 | nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True)
199 | )
200 | initialize_weights(self.dsn)
201 |
202 | initialize_weights(self.aspp)
203 | initialize_weights(self.bot_aspp)
204 | initialize_weights(self.bot_fine)
205 | initialize_weights(self.final1)
206 | initialize_weights(self.final2)
207 |
208 | if self.cont_proj_head > 0:
209 | self.proj = nn.Sequential(
210 | nn.Linear(256, 256, bias=True),
211 | nn.ReLU(inplace=False),
212 | nn.Linear(256, self.cont_proj_head, bias=True))
213 | initialize_weights(self.proj)
214 |
215 | # Setting the flags
216 | self.eps = 1e-5
217 | self.whitening = False
218 |
219 | if trunk == 'resnet-50':
220 | in_channel_list = [0, 0, 64, 256, 512, 1024, 2048]
221 | out_channel_list = [0, 0, 32, 128, 256, 512, 1024]
222 | else:
223 | raise ValueError("Not a valid network arch")
224 |
225 |
226 | def forward(self, x, gts=None, aux_gts=None, x_w=None, apply_fs=False):
227 |
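        # During training with apply_fs, three feature streams are propagated:
        #   x    -- source image features,
        #   x_w  -- wild image features (e.g. ImageNet), computed under no_grad,
        #   x_sw -- source features re-stylized with wild statistics by layer0[1],
        #           the feature stylization (AdaIN-style) layer enabled via --fs_layer.
        # At inference only the plain source path is computed.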
228 | x_size = x.size() # 800
229 |
230 | # encoder
231 | x = self.layer0[0](x)
232 | if self.training & apply_fs:
233 | with torch.no_grad():
234 | x_w = self.layer0[0](x_w)
235 | x = self.layer0[1](x)
236 | if self.training & apply_fs:
237 | x_sw = self.layer0[1](x, x_w) # feature stylization
238 | with torch.no_grad():
239 | x_w = self.layer0[1](x_w)
240 | x = self.layer0[2](x)
241 | x = self.layer0[3](x)
242 | if self.training & apply_fs:
243 | with torch.no_grad():
244 | x_w = self.layer0[2](x_w)
245 | x_w = self.layer0[3](x_w)
246 | x_sw = self.layer0[2](x_sw)
247 | x_sw = self.layer0[3](x_sw)
248 |
249 | if self.training & apply_fs:
250 | x_tuple = self.layer1([x, x_w, x_sw])
251 | low_level = x_tuple[0]
252 | low_level_w = x_tuple[1]
253 | low_level_sw = x_tuple[2]
254 | else:
255 | x_tuple = self.layer1([x])
256 | low_level = x_tuple[0]
257 |
258 | x_tuple = self.layer2(x_tuple)
259 | x_tuple = self.layer3(x_tuple)
260 | aux_out = x_tuple[0]
261 | if self.training & apply_fs:
262 | aux_out_w = x_tuple[1]
263 | aux_out_sw = x_tuple[2]
264 | x_tuple = self.layer4(x_tuple)
265 | x = x_tuple[0]
266 | if self.training & apply_fs:
267 | x_w = x_tuple[1]
268 | x_sw = x_tuple[2]
269 |
270 | # decoder
271 | x = self.aspp(x)
272 | dec0_up = self.bot_aspp(x)
273 | dec0_fine = self.bot_fine(low_level)
274 | dec0_up = Upsample(dec0_up, low_level.size()[2:])
275 | dec0 = [dec0_fine, dec0_up]
276 | dec0 = torch.cat(dec0, 1)
277 | dec1 = self.final1(dec0)
278 | dec2 = self.final2(dec1)
279 | main_out = Upsample(dec2, x_size[2:])
280 |
281 | if self.training:
282 | # compute original semantic segmentation loss
283 | loss_orig = self.criterion(main_out, gts)
284 | aux_out = self.dsn(aux_out)
285 | if aux_gts.dim() == 1:
286 | aux_gts = gts
287 | aux_gts = aux_gts.unsqueeze(1).float()
288 | aux_gts = nn.functional.interpolate(aux_gts, size=aux_out.shape[2:], mode='nearest')
289 | aux_gts = aux_gts.squeeze(1).long()
290 | loss_orig_aux = self.criterion_aux(aux_out, aux_gts)
291 |
292 | return_loss = [loss_orig, loss_orig_aux]
293 |
294 | if apply_fs:
295 | x_sw = self.aspp(x_sw)
296 | dec0_up_sw = self.bot_aspp(x_sw)
297 | dec0_fine_sw = self.bot_fine(low_level_sw)
298 | dec0_up_sw = Upsample(dec0_up_sw, low_level_sw.size()[2:])
299 | dec0_sw = [dec0_fine_sw, dec0_up_sw]
300 | dec0_sw = torch.cat(dec0_sw, 1)
301 | dec1_sw = self.final1(dec0_sw)
302 | dec2_sw = self.final2(dec1_sw)
303 | main_out_sw = Upsample(dec2_sw, x_size[2:])
304 |
305 | with torch.no_grad():
306 | x_w = self.aspp(x_w)
307 | dec0_up_w = self.bot_aspp(x_w)
308 | dec0_fine_w = self.bot_fine(low_level_w)
309 | dec0_up_w = Upsample(dec0_up_w, low_level_w.size()[2:])
310 | dec0_w = [dec0_fine_w, dec0_up_w]
311 | dec0_w = torch.cat(dec0_w, 1)
312 | dec1_w = self.final1(dec0_w)
313 | dec2_w = self.final2(dec1_w)
314 | main_out_w = Upsample(dec2_w, x_size[2:])
315 |
316 | if self.args.use_cel:
317 | # projected features
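                    # nn.Linear acts on the last dimension, so the decoder features are
                    # permuted to NHWC, projected from 256-d to cont_proj_head-d per
                    # pixel, and permuted back to NCHW; the wild projection is kept out
                    # of the graph via no_grad.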
318 | assert (self.cont_proj_head > 0)
319 | proj2 = self.proj(dec1.permute(0,2,3,1)).permute(0,3,1,2)
320 | proj2_sw = self.proj(dec1_sw.permute(0,2,3,1)).permute(0,3,1,2)
321 | with torch.no_grad():
322 | proj2_w = self.proj(dec1_w.permute(0,2,3,1)).permute(0,3,1,2)
323 |
324 | # compute content extension learning loss
325 | loss_cel = get_content_extension_loss(proj2, proj2_sw, proj2_w, gts, self.cont_dict)
326 |
327 | return_loss.append(loss_cel)
328 |
329 | if self.args.use_sel:
330 | # compute style extension learning loss
331 | loss_sel = self.criterion(main_out_sw, gts)
332 | aux_out_sw = self.dsn(aux_out_sw)
333 | loss_sel_aux = self.criterion_aux(aux_out_sw, aux_gts)
334 | return_loss.append(loss_sel)
335 | return_loss.append(loss_sel_aux)
336 |
337 | if self.args.use_scr:
338 | # compute semantic consistency regularization loss
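                    # KLDivLoss(reduction='batchmean') sums over all non-batch elements,
                    # so dividing by prod(C, H, W) gives a per-element KL between the
                    # stylized prediction (log-probs) and the original prediction (probs);
                    # clamp(min=0) guards against tiny negative values from floating-point
                    # error.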
339 | loss_scr = torch.clamp((self.criterion_kl(nn.functional.log_softmax(main_out_sw, dim=1), nn.functional.softmax(main_out, dim=1)))/(torch.prod(torch.tensor(main_out.shape[1:]))), min=0)
340 | loss_scr_aux = torch.clamp((self.criterion_kl(nn.functional.log_softmax(aux_out_sw, dim=1), nn.functional.softmax(aux_out, dim=1)))/(torch.prod(torch.tensor(aux_out.shape[1:]))), min=0)
341 | return_loss.append(loss_scr)
342 | return_loss.append(loss_scr_aux)
343 |
344 | return return_loss
345 | else:
346 | return main_out
347 |
348 |
349 | def DeepR50V3PlusD(args, num_classes, criterion, criterion_aux, cont_proj_head, wild_cont_dict_size):
350 | """
351 | Resnet 50 Based Network
352 | """
353 | print("Model : DeepLabv3+, Backbone : ResNet-50")
354 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion, criterion_aux=criterion_aux, cont_proj_head=cont_proj_head, wild_cont_dict_size=wild_cont_dict_size,
355 | variant='D16', skip='m1', args=args)
356 |
--------------------------------------------------------------------------------
/network/mynn.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization
3 | """
4 | import torch.nn as nn
5 | import torch
6 | from config import cfg
7 |
8 | import numpy as np
9 |
10 | def Norm2d(in_channels):
11 | """
12 | Custom Norm Function to allow flexible switching
13 | """
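    # cfg.MODEL.BNFUNC is expected to be a norm-layer constructor taking the channel
    # count (e.g. torch.nn.BatchNorm2d); when running distributed, train.py additionally
    # converts the whole model with torch.nn.SyncBatchNorm.convert_sync_batchnorm.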
14 | layer = getattr(cfg.MODEL, 'BNFUNC')
15 | normalization_layer = layer(in_channels)
16 | return normalization_layer
17 |
18 |
19 | def freeze_weights(*models):
20 | for model in models:
21 | for k in model.parameters():
22 | k.requires_grad = False
23 |
24 | def unfreeze_weights(*models):
25 | for model in models:
26 | for k in model.parameters():
27 | k.requires_grad = True
28 |
29 | def initialize_weights(*models):
30 | """
31 | Initialize Model Weights
32 | """
33 | for model in models:
34 | for module in model.modules():
35 | if isinstance(module, (nn.Conv2d, nn.Linear)):
36 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
37 | if module.bias is not None:
38 | module.bias.data.zero_()
39 | elif isinstance(module, nn.Conv1d):
40 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
41 | if module.bias is not None:
42 | module.bias.data.zero_()
43 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or \
44 | isinstance(module, nn.GroupNorm) or isinstance(module, nn.SyncBatchNorm):
45 | module.weight.data.fill_(1)
46 | module.bias.data.zero_()
47 | elif isinstance(module, nn.ConvTranspose2d):
48 | assert module.kernel_size[0] == module.kernel_size[1]
49 | initial_weight = get_upsampling_weight(module.in_channels, module.out_channels, module.kernel_size[0])
50 | module.weight.data.copy_(initial_weight)
51 |
52 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
53 | """Make a 2D bilinear kernel suitable for upsampling"""
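    # Example: kernel_size=4 gives factor=2 and center=1.5, so the separable 1-D profile
    # is [0.25, 0.75, 0.75, 0.25] and the 2-D kernel is its outer product (peak 0.5625),
    # i.e. the standard bilinear weight for a stride-2 transposed convolution.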
54 | factor = (kernel_size + 1) // 2
55 | if kernel_size % 2 == 1:
56 | center = factor - 1
57 | else:
58 | center = factor - 0.5
59 | og = np.ogrid[:kernel_size, :kernel_size]
60 | filt = (1 - abs(og[0] - center) / factor) * \
61 | (1 - abs(og[1] - center) / factor)
62 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
63 | dtype=np.float64)
64 | weight[range(in_channels), range(out_channels), :, :] = filt
65 | return torch.from_numpy(weight).float()
66 |
67 | def initialize_embedding(*models):
68 | """
69 | Initialize Model Weights
70 | """
71 | for model in models:
72 | for module in model.modules():
73 | if isinstance(module, nn.Embedding):
74 | module.weight.data.zero_() #original
75 |
76 |
77 |
78 | def Upsample(x, size):
79 | """
80 | Wrapper Around the Upsample Call
81 | """
82 | return nn.functional.interpolate(x, size=size, mode='bilinear',
83 | align_corners=True)
84 |
85 | def forgiving_state_restore(net, loaded_dict):
86 | """
87 | Handle partial loading when some tensors don't match up in size.
88 | Because we want to use models that were trained off a different
89 | number of classes.
90 | """
91 | net_state_dict = net.state_dict()
92 | new_loaded_dict = {}
93 | for k in net_state_dict:
94 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
95 | new_loaded_dict[k] = loaded_dict[k]
96 | else:
97 | print("Skipped loading parameter", k)
98 | # logging.info("Skipped loading parameter %s", k)
99 | net_state_dict.update(new_loaded_dict)
100 | net.load_state_dict(net_state_dict)
101 | return net
102 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Pytorch Optimizer and Scheduler Related Task
3 | """
4 | import math
5 | import logging
6 | import torch
7 | from torch import optim
8 | from config import cfg
9 | import pdb
10 |
11 | def get_optimizer(args, net):
12 | """
13 | Decide Optimizer (Adam or SGD)
14 | """
15 | base_params = []
16 |
17 | for name, param in net.named_parameters():
18 | base_params.append(param)
19 |
20 | if args.sgd:
21 | optimizer = optim.SGD(base_params,
22 | lr=args.lr,
23 | weight_decay=args.weight_decay,
24 | momentum=args.momentum,
25 | nesterov=False)
26 | else:
27 | raise ValueError('Not a valid optimizer')
28 |
29 | if args.lr_schedule == 'scl-poly':
30 | if cfg.REDUCE_BORDER_ITER == -1:
31 | raise ValueError('ERROR Cannot Do Scale Poly')
32 |
33 | rescale_thresh = cfg.REDUCE_BORDER_ITER
34 | scale_value = args.rescale
35 | lambda1 = lambda iteration: \
36 | math.pow(1 - iteration / args.max_iter,
37 | args.poly_exp) if iteration < rescale_thresh else scale_value * math.pow(
38 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh),
39 | args.repoly)
40 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
41 | elif args.lr_schedule == 'poly':
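        # Polynomial decay: lr(iter) = base_lr * (1 - iter / max_iter) ** poly_exp.
        # With the default poly_exp of 0.9 the learning rate is ~0.54x the base value at
        # the halfway point and reaches 0 at max_iter.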
42 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp)
43 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
44 | else:
45 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
46 |
47 | return optimizer, scheduler
48 |
49 |
50 | def load_weights(net, optimizer, scheduler, snapshot_file, restore_optimizer_bool=False):
51 | """
52 | Load weights from snapshot file
53 | """
54 | logging.info("Loading weights from model %s", snapshot_file)
55 | net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file,
56 | restore_optimizer_bool)
57 | return epoch, mean_iu
58 |
59 |
60 | def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool):
61 | """
62 |     Restore weights and optimizer (if needed) for resuming job.
63 | """
64 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
65 |     logging.info("Checkpoint Load Complete")
66 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool:
67 | optimizer.load_state_dict(checkpoint['optimizer'])
68 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool:
69 | scheduler.load_state_dict(checkpoint['scheduler'])
70 |
71 | if 'state_dict' in checkpoint:
72 | net = forgiving_state_restore(net, checkpoint['state_dict'])
73 | else:
74 | net = forgiving_state_restore(net, checkpoint)
75 |
76 | if 'epoch' in checkpoint:
77 | epoch = checkpoint['epoch']
78 | else:
79 | epoch = 0
80 | if 'mean_iu' in checkpoint:
81 | mean_iu = checkpoint['mean_iu']
82 | else:
83 | mean_iu = 0.0
84 |
85 | return net, optimizer, scheduler, epoch, mean_iu
86 |
87 |
88 | def forgiving_state_restore(net, loaded_dict):
89 | """
90 | Handle partial loading when some tensors don't match up in size.
91 | Because we want to use models that were trained off a different
92 | number of classes.
93 | """
94 | net_state_dict = net.state_dict()
95 | new_loaded_dict = {}
96 | for k in net_state_dict:
97 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
98 | new_loaded_dict[k] = loaded_dict[k]
99 | else:
100 | print("Skipped loading parameter", k)
101 | # logging.info("Skipped loading parameter %s", k)
102 | net_state_dict.update(new_loaded_dict)
103 | net.load_state_dict(net_state_dict)
104 | return net
105 |
106 | def forgiving_state_copy(target_net, source_net):
107 | """
108 | Handle partial loading when some tensors don't match up in size.
109 | Because we want to use models that were trained off a different
110 | number of classes.
111 | """
112 | net_state_dict = target_net.state_dict()
113 | loaded_dict = source_net.state_dict()
114 | new_loaded_dict = {}
115 | for k in net_state_dict:
116 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
117 | new_loaded_dict[k] = loaded_dict[k]
118 | print("Matched", k)
119 | else:
120 | print("Skipped loading parameter ", k)
121 | # logging.info("Skipped loading parameter %s", k)
122 | net_state_dict.update(new_loaded_dict)
123 | target_net.load_state_dict(net_state_dict)
124 | return target_net
125 |
--------------------------------------------------------------------------------
/scripts/train_wildnet_r50os16_gtav.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Example on GTAV
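# Launches training with 2 processes (one per GPU); --nproc_per_node should match the
# number of available GPUs, and --bs_mult is the per-GPU batch size, so the effective
# batch size here is presumably 4 x 2 = 8.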
3 | python -m torch.distributed.launch --nproc_per_node=2 train.py \
4 | --dataset gtav \
5 | --val_dataset bdd100k cityscapes synthia mapillary \
6 | --wild_dataset imagenet \
7 | --arch network.deepv3.DeepR50V3PlusD \
8 | --city_mode 'train' \
9 | --sgd \
10 | --lr_schedule poly \
11 | --lr 0.0025 \
12 | --poly_exp 0.9 \
13 | --max_cu_epoch 10000 \
14 | --class_uniform_pct 0.5 \
15 | --class_uniform_tile 1024 \
16 | --crop_size 768 \
17 | --scale_min 0.5 \
18 | --scale_max 2.0 \
19 | --rrotate 0 \
20 | --max_iter 60000 \
21 | --bs_mult 4 \
22 | --gblur \
23 | --color_aug 0.5 \
24 | --fs_layer 1 1 1 0 0 \
25 | --cont_proj_head 256 \
26 | --wild_cont_dict_size 393216 \
27 | --lambda_cel 0.1 \
28 | --lambda_sel 1.0 \
29 | --lambda_scr 10.0 \
30 | --date 0101 \
31 | --exp r50os16_gtav_wildnet \
32 | --ckpt ./logs/ \
33 | --tb_path ./logs/
34 |
35 |
--------------------------------------------------------------------------------
/scripts/valid_wildnet_r50os16_gtav.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 |
4 | python -m torch.distributed.launch --nproc_per_node=2 valid.py \
5 | --dataset gtav \
6 | --val_dataset bdd100k cityscapes synthia mapillary \
7 | --wild_dataset imagenet \
8 | --arch network.deepv3.DeepR50V3PlusD \
9 | --city_mode 'train' \
10 | --sgd \
11 | --lr_schedule poly \
12 | --lr 0.0025 \
13 | --poly_exp 0.9 \
14 | --max_cu_epoch 10000 \
15 | --class_uniform_pct 0.5 \
16 | --class_uniform_tile 1024 \
17 | --crop_size 768 \
18 | --scale_min 0.5 \
19 | --scale_max 2.0 \
20 | --rrotate 0 \
21 | --max_iter 60000 \
22 | --bs_mult 4 \
23 | --gblur \
24 | --color_aug 0.5 \
25 | --fs_layer 1 1 1 0 0 \
26 | --cont_proj_head 256 \
27 | --wild_cont_dict_size 393216 \
28 | --lambda_cel 0.1 \
29 | --lambda_sel 1.0 \
30 | --lambda_scr 10.0 \
31 | --date 0101 \
32 | --exp r50os16_gtav_wildnet \
33 | --ckpt ./logs/ \
34 | --tb_path ./logs/ \
35 | --snapshot ${1}
36 |
37 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | training code
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | import argparse
7 | import logging
8 | import os
9 | import torch
10 |
11 | from config import cfg, assert_and_infer_cfg
12 | from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist
13 | import datasets
14 | import loss
15 | import network
16 | import optimizer
17 | import time
18 | import torchvision.utils as vutils
19 | import torch.nn.functional as F
20 | import numpy as np
21 | import random
22 |
23 |
24 | # Argument Parser
25 | parser = argparse.ArgumentParser(description='Semantic Segmentation')
26 | parser.add_argument('--lr', type=float, default=0.01)
27 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepR50V3PlusD',
28 | help='Network architecture.')
29 | parser.add_argument('--dataset', nargs='*', type=str, default=['gtav'],
30 | help='a list of datasets; cityscapes, mapillary, gtav, bdd100k, synthia')
31 | parser.add_argument('--image_uniform_sampling', action='store_true', default=False,
32 | help='uniformly sample images across the multiple source domains')
33 | parser.add_argument('--val_dataset', nargs='*', type=str, default=['bdd100k'],
34 | help='a list consists of cityscapes, mapillary, gtav, bdd100k, synthia')
35 | parser.add_argument('--wild_dataset', nargs='*', type=str, default=['imagenet'],
36 | help='a list consists of imagenet')
37 | parser.add_argument('--cv', type=int, default=0,
38 | help='cross-validation split id to use. Default # of splits set to 3 in config')
39 | parser.add_argument('--class_uniform_pct', type=float, default=0,
40 | help='What fraction of images is uniformly sampled')
41 | parser.add_argument('--class_uniform_tile', type=int, default=1024,
42 | help='tile size for class uniform sampling')
43 | parser.add_argument('--coarse_boost_classes', type=str, default=None,
44 | help='use coarse annotations to boost fine data with specific classes')
45 |
46 | parser.add_argument('--img_wt_loss', action='store_true', default=False,
47 | help='per-image class-weighted loss')
48 | parser.add_argument('--cls_wt_loss', action='store_true', default=False,
49 | help='class-weighted loss')
50 | parser.add_argument('--batch_weighting', action='store_true', default=False,
51 | help='Batch weighting for class (use nll class weighting using batch stats')
52 |
53 | parser.add_argument('--jointwtborder', action='store_true', default=False,
54 | help='Enable boundary label relaxation')
55 | parser.add_argument('--strict_bdr_cls', type=str, default='',
56 | help='Enable boundary label relaxation for specific classes')
57 | parser.add_argument('--rlx_off_iter', type=int, default=-1,
58 | help='Turn off border relaxation after specific epoch count')
59 | parser.add_argument('--rescale', type=float, default=1.0,
60 | help='Warm Restarts new learning rate ratio compared to original lr')
61 | parser.add_argument('--repoly', type=float, default=1.5,
62 | help='Warm Restart new poly exp')
63 |
64 | parser.add_argument('--fp16', action='store_true', default=False,
65 | help='Use Nvidia Apex AMP')
66 | parser.add_argument('--local_rank', default=0, type=int,
67 | help='parameter used by apex library')
68 |
69 | parser.add_argument('--sgd', action='store_true', default=False)
70 | parser.add_argument('--adam', action='store_true', default=False)
71 | parser.add_argument('--amsgrad', action='store_true', default=False)
72 |
73 | parser.add_argument('--freeze_trunk', action='store_true', default=False)
74 | parser.add_argument('--hardnm', default=0, type=int,
75 | help='0 means no aug, 1 means hard negative mining iter 1,' +
76 | '2 means hard negative mining iter 2')
77 |
78 | parser.add_argument('--trunk', type=str, default='resnet-50',
79 | help='trunk model, can be: resnet-50 (default)')
80 | parser.add_argument('--max_epoch', type=int, default=180)
81 | parser.add_argument('--max_iter', type=int, default=30000)
82 | parser.add_argument('--max_cu_epoch', type=int, default=100000,
83 | help='Class Uniform Max Epochs')
84 | parser.add_argument('--start_epoch', type=int, default=0)
85 | parser.add_argument('--crop_nopad', action='store_true', default=False)
86 | parser.add_argument('--rrotate', type=int,
87 |                     default=0, help='degree of random rotation')
88 | parser.add_argument('--color_aug', type=float,
89 | default=0.0, help='level of color augmentation')
90 | parser.add_argument('--gblur', action='store_true', default=False,
91 |                     help='Use Gaussian Blur Augmentation')
92 | parser.add_argument('--bblur', action='store_true', default=False,
93 | help='Use Bilateral Blur Augmentation')
94 | parser.add_argument('--lr_schedule', type=str, default='poly',
95 | help='name of lr schedule: poly')
96 | parser.add_argument('--poly_exp', type=float, default=0.9,
97 | help='polynomial LR exponent')
98 | parser.add_argument('--bs_mult', type=int, default=2,
99 | help='Batch size for training per gpu')
100 | parser.add_argument('--bs_mult_val', type=int, default=1,
101 | help='Batch size for Validation per gpu')
102 | parser.add_argument('--crop_size', type=int, default=720,
103 | help='training crop size')
104 | parser.add_argument('--pre_size', type=int, default=None,
105 | help='resize image shorter edge to this before augmentation')
106 | parser.add_argument('--scale_min', type=float, default=0.5,
107 | help='dynamically scale training images down to this size')
108 | parser.add_argument('--scale_max', type=float, default=2.0,
109 | help='dynamically scale training images up to this size')
110 | parser.add_argument('--weight_decay', type=float, default=5e-4)
111 | parser.add_argument('--momentum', type=float, default=0.9)
112 | parser.add_argument('--snapshot', type=str, default=None)
113 | parser.add_argument('--restore_optimizer', action='store_true', default=False)
114 |
115 | parser.add_argument('--city_mode', type=str, default='train',
116 | help='experiment directory date name')
117 | parser.add_argument('--date', type=str, default='default',
118 | help='experiment directory date name')
119 | parser.add_argument('--exp', type=str, default='default',
120 | help='experiment directory name')
121 | parser.add_argument('--tb_tag', type=str, default='',
122 | help='add tag to tb dir')
123 | parser.add_argument('--ckpt', type=str, default='logs/ckpt',
124 | help='Save Checkpoint Point')
125 | parser.add_argument('--tb_path', type=str, default='logs/tb',
126 | help='Save Tensorboard Path')
127 | parser.add_argument('--syncbn', action='store_true', default=True,
128 | help='Use Synchronized BN')
129 | parser.add_argument('--dump_augmentation_images', action='store_true', default=False,
130 | help='Dump Augmentated Images for sanity check')
131 | parser.add_argument('--test_mode', action='store_true', default=False,
132 | help='Minimum testing to verify nothing failed, ' +
133 | 'Runs code for 1 epoch of train and val')
134 | parser.add_argument('-wb', '--wt_bound', type=float, default=1.0,
135 | help='Weight Scaling for the losses')
136 | parser.add_argument('--maxSkip', type=int, default=0,
137 | help='Skip x number of frames of video augmented dataset')
138 | parser.add_argument('--scf', action='store_true', default=False,
139 | help='scale correction factor')
140 | parser.add_argument('--dist_url', default='tcp://127.0.0.1:', type=str,
141 | help='url used to set up distributed training')
142 |
143 | parser.add_argument('--image_in', action='store_true', default=False,
144 | help='Input Image Instance Norm')
145 |
146 | parser.add_argument('--fs_layer', nargs='*', type=int, default=[0,0,0,0,0],
147 | help='0: None, 1: AdaIN')
148 | parser.add_argument('--lambda_cel', type=float, default=0.0,
149 | help='lambda for content extension learning loss')
150 | parser.add_argument('--lambda_sel', type=float, default=0.0,
151 | help='lambda for style extension learning loss')
152 | parser.add_argument('--lambda_scr', type=float, default=0.0,
153 | help='lambda for semantic consistency regularization loss')
154 | parser.add_argument('--cont_proj_head', type=int, default=0,
155 | help='number of output channels of content projection head')
156 | parser.add_argument('--wild_cont_dict_size', type=int, default=0,
157 | help='wild-content dictionary size')
158 |
159 | parser.add_argument('--use_fs', action='store_true', default=False,
160 | help='Automatic setting from fs_layer. feature stylization with wild dataset')
161 | parser.add_argument('--use_scr', action='store_true', default=False,
162 | help='Automatic setting from lambda_scr')
163 | parser.add_argument('--use_sel', action='store_true', default=False,
164 | help='Automatic setting from lambda_sel')
165 | parser.add_argument('--use_cel', action='store_true', default=False,
166 | help='Automatic setting from lambda_cel')
167 |
168 | args = parser.parse_args()
169 |
170 | random_seed = cfg.RANDOM_SEED #304
171 | torch.manual_seed(random_seed)
172 | torch.cuda.manual_seed(random_seed)
173 | torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
174 | torch.backends.cudnn.deterministic = True
175 | torch.backends.cudnn.benchmark = False
176 | np.random.seed(random_seed)
177 | random.seed(random_seed)
178 |
179 | args.world_size = 1
180 |
181 | # Test Mode run two epochs with a few iterations of training and val
182 | if args.test_mode:
183 | args.max_epoch = 2
184 |
185 | if 'WORLD_SIZE' in os.environ:
186 | # args.apex = int(os.environ['WORLD_SIZE']) > 1
187 | args.world_size = int(os.environ['WORLD_SIZE'])
188 | print("Total world size: ", int(os.environ['WORLD_SIZE']))
189 |
190 | torch.cuda.set_device(args.local_rank)
191 | print('My Rank:', args.local_rank)
192 | # Initialize distributed communication
193 | args.dist_url = args.dist_url + str(8000 + (int(time.time()%1000))//10)
194 |
195 | torch.distributed.init_process_group(backend='nccl',
196 | init_method='env://',
197 | world_size=args.world_size,
198 | rank=args.local_rank)
199 | # torch.distributed.init_process_group(backend='nccl',
200 | # init_method=args.dist_url,
201 | # world_size=args.world_size,
202 | # rank=args.local_rank)
203 |
204 | for i in range(len(args.fs_layer)):
205 | if args.fs_layer[i] == 1:
206 | args.use_fs = True
207 |
208 | if args.lambda_cel > 0:
209 | args.use_cel = True
210 | if args.lambda_sel > 0:
211 | args.use_sel = True
212 | if args.lambda_scr > 0:
213 | args.use_scr = True
214 |
215 | def main():
216 | """
217 | Main Function
218 | """
219 | # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
220 | assert_and_infer_cfg(args)
221 | writer = prep_experiment(args, parser)
222 |
223 | train_source_loader, val_loaders, train_wild_loader, train_obj, extra_val_loaders = datasets.setup_loaders(args)
224 |
225 | criterion, criterion_val = loss.get_loss(args)
226 | criterion_aux = loss.get_loss_aux(args)
227 | net = network.get_net(args, criterion, criterion_aux, args.cont_proj_head, args.wild_cont_dict_size)
228 |
229 | optim, scheduler = optimizer.get_optimizer(args, net)
230 |
231 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
232 | net = network.warp_network_in_dataparallel(net, args.local_rank)
233 | epoch = 0
234 | i = 0
235 |
236 | if args.snapshot:
237 | epoch, mean_iu = optimizer.load_weights(net, optim, scheduler,
238 | args.snapshot, args.restore_optimizer)
239 | if args.restore_optimizer is True:
240 | iter_per_epoch = len(train_source_loader)
241 | i = iter_per_epoch * epoch
242 | epoch = epoch + 1
243 | else:
244 | epoch = 0
245 |
246 | print("#### iteration", i)
247 | torch.cuda.empty_cache()
248 |
249 | while i < args.max_iter:
250 | # Update EPOCH CTR
251 | cfg.immutable(False)
252 | cfg.ITER = i
253 | cfg.immutable(True)
254 |
255 | i = train(train_source_loader, train_wild_loader, net, optim, epoch, writer, scheduler, args.max_iter)
256 | train_source_loader.sampler.set_epoch(epoch + 1)
257 | train_wild_loader.sampler.set_epoch(epoch + 1)
258 |
259 | if args.local_rank == 0:
260 | print("Saving pth file...")
261 | evaluate_eval(args, net, optim, scheduler, None, None, [],
262 | writer, epoch, "None", None, i, save_pth=True)
263 |
264 | if args.class_uniform_pct:
265 | if epoch >= args.max_cu_epoch:
266 | train_obj.build_epoch(cut=True)
267 | train_source_loader.sampler.set_num_samples()
268 | else:
269 | train_obj.build_epoch()
270 |
271 | epoch += 1
272 |
273 | # Validation after epochs
274 | if len(val_loaders) == 1:
275 | # Run validation only one time - To save models
276 | for dataset, val_loader in val_loaders.items():
277 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i)
278 | else:
279 | if args.local_rank == 0:
280 | print("Saving pth file...")
281 | evaluate_eval(args, net, optim, scheduler, None, None, [],
282 | writer, epoch, "None", None, i, save_pth=True)
283 |
284 | for dataset, val_loader in extra_val_loaders.items():
285 | print("Extra validating... This won't save pth file")
286 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
287 |
288 |
289 | def train(source_loader, wild_loader, net, optim, curr_epoch, writer, scheduler, max_iter):
290 | """
291 | Runs the training loop per epoch
292 | source_loader: Source data loader for train
293 | wild_loader: Wild data loader for train
294 |     net: the network
295 | optim: optimizer
296 | curr_epoch: current epoch
297 | writer: tensorboard writer
298 |     return: the updated global iteration counter
299 | """
300 | net.train()
301 |
302 | train_total_loss = AverageMeter()
303 | time_meter = AverageMeter()
304 |
305 | curr_iter = curr_epoch * len(source_loader)
306 |
307 | wild_loader_iter = enumerate(wild_loader)
308 |
309 | for i, data in enumerate(source_loader):
310 | if curr_iter >= max_iter:
311 | break
312 |
313 | inputs, gts, _, aux_gts = data
314 |
315 | # Multi source and AGG case
316 | if len(inputs.shape) == 5:
317 | B, D, C, H, W = inputs.shape
318 | num_domains = D
319 | inputs = inputs.transpose(0, 1)
320 | gts = gts.transpose(0, 1).squeeze(2)
321 | aux_gts = aux_gts.transpose(0, 1).squeeze(2)
322 |
323 | inputs = [input.squeeze(0) for input in torch.chunk(inputs, num_domains, 0)]
324 | gts = [gt.squeeze(0) for gt in torch.chunk(gts, num_domains, 0)]
325 | aux_gts = [aux_gt.squeeze(0) for aux_gt in torch.chunk(aux_gts, num_domains, 0)]
326 | else:
327 | B, C, H, W = inputs.shape
328 | num_domains = 1
329 | inputs = [inputs]
330 | gts = [gts]
331 | aux_gts = [aux_gts]
332 |
333 | batch_pixel_size = C * H * W
334 |
335 | for di, ingredients in enumerate(zip(inputs, gts, aux_gts)):
336 | input, gt, aux_gt = ingredients
337 |
338 | _, inputs_wild = next(wild_loader_iter)
339 | input_wild = inputs_wild[0]
340 |
341 | start_ts = time.time()
342 |
343 | img_gt = None
344 | input, gt = input.cuda(), gt.cuda()
345 | input_wild = input_wild.cuda()
346 |
347 | optim.zero_grad()
348 | outputs = net(x=input, gts=gt, aux_gts=aux_gt, x_w=input_wild, apply_fs=args.use_fs)
349 |
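            # The network returns its losses in a fixed order:
            # [seg, seg_aux, (cel), (sel, sel_aux), (scr, scr_aux)], where the optional
            # entries are present only when the corresponding flag is enabled; they are
            # consumed below by advancing outputs_index.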
350 | outputs_index = 0
351 | main_loss = outputs[outputs_index]
352 | outputs_index += 1
353 | aux_loss = outputs[outputs_index]
354 | outputs_index += 1
355 | total_loss = main_loss + (0.4 * aux_loss)
356 |
357 | if args.use_fs:
358 | if args.use_cel:
359 | cel_loss = outputs[outputs_index]
360 | outputs_index += 1
361 | total_loss = total_loss + (args.lambda_cel * cel_loss)
362 | else:
363 | cel_loss = 0
364 |
365 | if args.use_sel:
366 | sel_loss_main = outputs[outputs_index]
367 | outputs_index += 1
368 | sel_loss_aux = outputs[outputs_index]
369 | outputs_index += 1
370 | total_loss = total_loss + args.lambda_sel * (sel_loss_main + (0.4 * sel_loss_aux))
371 | else:
372 | sel_loss_main = 0
373 | sel_loss_aux = 0
374 |
375 | if args.use_scr:
376 | scr_loss_main = outputs[outputs_index]
377 | outputs_index += 1
378 | scr_loss_aux = outputs[outputs_index]
379 | outputs_index += 1
380 | total_loss = total_loss + args.lambda_scr * (scr_loss_main + (0.4 * scr_loss_aux))
381 | else:
382 | scr_loss_main = 0
383 | scr_loss_aux = 0
384 |
385 |
386 | log_total_loss = total_loss.clone().detach_()
387 | torch.distributed.all_reduce(log_total_loss, torch.distributed.ReduceOp.SUM)
388 | log_total_loss = log_total_loss / args.world_size
389 | train_total_loss.update(log_total_loss.item(), batch_pixel_size)
390 |
391 | total_loss.backward()
392 | optim.step()
393 |
394 | time_meter.update(time.time() - start_ts)
395 |
396 | del total_loss, log_total_loss
397 |
398 | if args.local_rank == 0:
399 | if i % 50 == 49:
400 | msg = '[epoch {}], [iter {} / {} : {}], [loss {:0.6f}], [lr {:0.6f}], [time {:0.4f}]'.format(
401 | curr_epoch, i + 1, len(source_loader), curr_iter, train_total_loss.avg,
402 | optim.param_groups[-1]['lr'], time_meter.avg / args.train_batch_size)
403 |
404 | logging.info(msg)
405 |
406 | # Log tensorboard metrics for each iteration of the training phase
407 | writer.add_scalar('loss/train_loss', (train_total_loss.avg), curr_iter)
408 | train_total_loss.reset()
409 | time_meter.reset()
410 |
411 | curr_iter += 1
412 | scheduler.step()
413 |
414 | if i > 5 and args.test_mode:
415 | return curr_iter
416 |
417 | return curr_iter
418 |
419 | def validate(val_loader, dataset, net, criterion, optim, scheduler, curr_epoch, writer, curr_iter, save_pth=True):
420 | """
421 | Runs the validation loop after each training epoch
422 | val_loader: Data loader for validation
423 | dataset: dataset name (str)
424 |     net: the network
425 | criterion: loss fn
426 | optimizer: optimizer
427 | curr_epoch: current epoch
428 | writer: tensorboard writer
429 | return: val_avg for step function if required
430 | """
431 |
432 | net.eval()
433 | val_loss = AverageMeter()
434 | iou_acc = 0
435 | error_acc = 0
436 | dump_images = []
437 |
438 | for val_idx, data in enumerate(val_loader):
439 |
440 | inputs, gt_image, img_names, _ = data
441 |
442 | if len(inputs.shape) == 5:
443 | B, D, C, H, W = inputs.shape
444 | inputs = inputs.view(-1, C, H, W)
445 | gt_image = gt_image.view(-1, 1, H, W)
446 |
447 | assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
448 | assert inputs.size()[2:] == gt_image.size()[1:]
449 |
450 | batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
451 | inputs, gt_cuda = inputs.cuda(), gt_image.cuda()
452 |
453 | with torch.no_grad():
454 | output = net(inputs)
455 |
456 | del inputs
457 |
458 | assert output.size()[2:] == gt_image.size()[1:]
459 | assert output.size()[1] == datasets.num_classes
460 |
461 | val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)
462 |
463 | del gt_cuda
464 |
465 | # Collect data from different GPU to a single GPU since
466 | # encoding.parallel.criterionparallel function calculates distributed loss
467 | # functions
468 | predictions = output.data.max(1)[1].cpu()
469 |
470 | # Logging
471 | if val_idx % 20 == 0:
472 | if args.local_rank == 0:
473 | logging.info("validating: %d / %d", val_idx + 1, len(val_loader))
474 | if val_idx > 10 and args.test_mode:
475 | break
476 |
477 | # Image Dumps
478 | if val_idx < 10:
479 | dump_images.append([gt_image, predictions, img_names])
480 |
481 | iou_acc += fast_hist(predictions.numpy().flatten(), gt_image.numpy().flatten(),
482 | datasets.num_classes)
483 | del output, val_idx, data
484 |
485 | iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
486 | torch.distributed.all_reduce(iou_acc_tensor, op=torch.distributed.ReduceOp.SUM)
487 | iou_acc = iou_acc_tensor.cpu().numpy()
488 |
489 | if args.local_rank == 0:
490 | evaluate_eval(args, net, optim, scheduler, val_loss, iou_acc, dump_images,
491 | writer, curr_epoch, dataset, None, curr_iter, save_pth=save_pth)
492 |
493 | return val_loss.avg
494 |
495 |
496 | if __name__ == '__main__':
497 | main()
498 |
--------------------------------------------------------------------------------
/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suhyeonlee/WildNet/c83aa92b7cd591512045fc5093da25a280bde430/transforms/__init__.py
--------------------------------------------------------------------------------
/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code borrowed from:
3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py
4 | #
5 | #
6 | # MIT License
7 | #
8 | # Copyright (c) 2017 ZijunDeng
9 | #
10 | # Permission is hereby granted, free of charge, to any person obtaining a copy
11 | # of this software and associated documentation files (the "Software"), to deal
12 | # in the Software without restriction, including without limitation the rights
13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 | # copies of the Software, and to permit persons to whom the Software is
15 | # furnished to do so, subject to the following conditions:
16 | #
17 | # The above copyright notice and this permission notice shall be included in all
18 | # copies or substantial portions of the Software.
19 | #
20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 | # SOFTWARE.
27 |
28 | """
29 |
30 | """
31 | Standard Transform
32 | """
33 |
34 | import random
35 | import numpy as np
36 | from skimage.filters import gaussian
37 | from skimage.restoration import denoise_bilateral
38 | import torch
39 | from PIL import Image, ImageEnhance
40 | import torchvision.transforms as torch_tr
41 | from config import cfg
42 | from scipy.ndimage.interpolation import shift
43 |
44 | from skimage.segmentation import find_boundaries
45 | from skimage.util import random_noise
46 |
47 | try:
48 | import accimage
49 | except ImportError:
50 | accimage = None
51 |
52 |
53 | class RandomVerticalFlip(object):
54 | def __call__(self, img):
55 | if random.random() < 0.5:
56 | return img.transpose(Image.FLIP_TOP_BOTTOM)
57 | return img
58 |
59 |
60 | class DeNormalize(object):
61 | def __init__(self, mean, std):
62 | self.mean = mean
63 | self.std = std
64 |
65 | def __call__(self, tensor):
66 | for t, m, s in zip(tensor, self.mean, self.std):
67 | t.mul_(s).add_(m)
68 | return tensor
69 |
70 |
71 | class MaskToTensor(object):
72 | def __call__(self, img):
73 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
74 |
75 | class RelaxedBoundaryLossToTensor(object):
76 | """
77 | Boundary Relaxation
78 | """
79 | def __init__(self,ignore_id, num_classes):
80 | self.ignore_id=ignore_id
81 | self.num_classes= num_classes
82 |
83 |
84 | def new_one_hot_converter(self,a):
85 | ncols = self.num_classes+1
86 | out = np.zeros( (a.size,ncols), dtype=np.uint8)
87 | out[np.arange(a.size),a.ravel()] = 1
88 | out.shape = a.shape + (ncols,)
89 | return out
90 |
91 | def __call__(self,img):
92 |
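        # Builds a relaxed (multi-hot) target: each pixel may take any class appearing
        # within a (2*border+1) x (2*border+1) window around it, implemented by summing
        # shifted one-hot maps. After REDUCE_BORDER_ITER the window is halved and pixels
        # on label boundaries receive double weight.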
93 | img_arr = np.array(img)
94 | img_arr[img_arr==self.ignore_id]=self.num_classes
95 |
96 |         if cfg.STRICTBORDERCLASS is not None:
97 | one_hot_orig = self.new_one_hot_converter(img_arr)
98 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1]))
99 | for cls in cfg.STRICTBORDERCLASS:
100 | mask = np.logical_or(mask,(img_arr == cls))
101 | one_hot = 0
102 |
103 | border = cfg.BORDER_WINDOW
104 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
105 | border = border // 2
106 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8)
107 |
108 | for i in range(-border,border+1):
109 | for j in range(-border, border+1):
110 | shifted= shift(img_arr,(i,j), cval=self.num_classes)
111 | one_hot += self.new_one_hot_converter(shifted)
112 |
113 | one_hot[one_hot>1] = 1
114 |
115 |         if cfg.STRICTBORDERCLASS is not None:
116 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot)
117 |
118 | one_hot = np.moveaxis(one_hot,-1,0)
119 |
120 |
121 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
122 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot)
123 | # print(one_hot.shape)
124 | return torch.from_numpy(one_hot).byte()
125 |
126 | class ResizeHeight(object):
127 | def __init__(self, size, interpolation=Image.BILINEAR):
128 | self.target_h = size
129 | self.interpolation = interpolation
130 |
131 | def __call__(self, img):
132 | w, h = img.size
133 | target_w = int(w / h * self.target_h)
134 | return img.resize((target_w, self.target_h), self.interpolation)
135 |
136 |
137 | class FreeScale(object):
138 | def __init__(self, size, interpolation=Image.BILINEAR):
139 | self.size = tuple(reversed(size)) # size: (h, w)
140 | self.interpolation = interpolation
141 |
142 | def __call__(self, img):
143 | return img.resize(self.size, self.interpolation)
144 |
145 |
146 | class FlipChannels(object):
147 | """
148 |     Reverse the channel order (e.g. RGB <-> BGR)
149 | """
150 | def __call__(self, img):
151 | img = np.array(img)[:, :, ::-1]
152 | return Image.fromarray(img.astype(np.uint8))
153 |
154 |
155 | class RandomGaussianBlur(object):
156 | """
157 | Apply Gaussian Blur
158 | """
159 | def __call__(self, img):
160 | sigma = 0.15 + random.random() * 1.15
161 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
162 | blurred_img *= 255
163 | return Image.fromarray(blurred_img.astype(np.uint8))
164 |
165 |
166 | class RandomGaussianNoise(object):
167 | def __call__(self, img):
168 | noised_img = random_noise(np.array(img), mode='gaussian')
169 | noised_img *= 255
170 | return Image.fromarray(noised_img.astype(np.uint8))
171 |
172 |
173 | class RandomBilateralBlur(object):
174 | """
175 | Apply Bilateral Filtering
176 |
177 | """
178 | def __call__(self, img):
179 | sigma = random.uniform(0.05,0.75)
180 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True)
181 | blurred_img *= 255
182 | return Image.fromarray(blurred_img.astype(np.uint8))
183 |
184 | def _is_pil_image(img):
185 | if accimage is not None:
186 | return isinstance(img, (Image.Image, accimage.Image))
187 | else:
188 | return isinstance(img, Image.Image)
189 |
190 |
191 | def adjust_brightness(img, brightness_factor):
192 | """Adjust brightness of an Image.
193 |
194 | Args:
195 | img (PIL Image): PIL Image to be adjusted.
196 | brightness_factor (float): How much to adjust the brightness. Can be
197 | any non negative number. 0 gives a black image, 1 gives the
198 | original image while 2 increases the brightness by a factor of 2.
199 |
200 | Returns:
201 | PIL Image: Brightness adjusted image.
202 | """
203 | if not _is_pil_image(img):
204 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
205 |
206 | enhancer = ImageEnhance.Brightness(img)
207 | img = enhancer.enhance(brightness_factor)
208 | return img
209 |
210 |
211 | def adjust_contrast(img, contrast_factor):
212 | """Adjust contrast of an Image.
213 |
214 | Args:
215 | img (PIL Image): PIL Image to be adjusted.
216 | contrast_factor (float): How much to adjust the contrast. Can be any
217 | non negative number. 0 gives a solid gray image, 1 gives the
218 | original image while 2 increases the contrast by a factor of 2.
219 |
220 | Returns:
221 | PIL Image: Contrast adjusted image.
222 | """
223 | if not _is_pil_image(img):
224 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
225 |
226 | enhancer = ImageEnhance.Contrast(img)
227 | img = enhancer.enhance(contrast_factor)
228 | return img
229 |
230 |
231 | def adjust_saturation(img, saturation_factor):
232 | """Adjust color saturation of an image.
233 |
234 | Args:
235 | img (PIL Image): PIL Image to be adjusted.
236 | saturation_factor (float): How much to adjust the saturation. 0 will
237 | give a black and white image, 1 will give the original image while
238 | 2 will enhance the saturation by a factor of 2.
239 |
240 | Returns:
241 | PIL Image: Saturation adjusted image.
242 | """
243 | if not _is_pil_image(img):
244 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
245 |
246 | enhancer = ImageEnhance.Color(img)
247 | img = enhancer.enhance(saturation_factor)
248 | return img
249 |
250 |
251 | def adjust_hue(img, hue_factor):
252 | """Adjust hue of an image.
253 |
254 | The image hue is adjusted by converting the image to HSV and
255 | cyclically shifting the intensities in the hue channel (H).
256 | The image is then converted back to original image mode.
257 |
258 | `hue_factor` is the amount of shift in H channel and must be in the
259 | interval `[-0.5, 0.5]`.
260 |
261 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
262 |
263 | Args:
264 | img (PIL Image): PIL Image to be adjusted.
265 | hue_factor (float): How much to shift the hue channel. Should be in
266 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
267 | HSV space in positive and negative direction respectively.
268 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
269 | with complementary colors while 0 gives the original image.
270 |
271 | Returns:
272 | PIL Image: Hue adjusted image.
273 | """
274 | if not(-0.5 <= hue_factor <= 0.5):
275 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
276 |
277 | if not _is_pil_image(img):
278 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
279 | input_mode = img.mode
280 | if input_mode in {'L', '1', 'I', 'F'}:
281 | return img
282 |
283 | h, s, v = img.convert('HSV').split()
284 |
285 | np_h = np.array(h, dtype=np.uint8)
286 |     # uint8 addition takes care of rotation across boundaries
287 | with np.errstate(over='ignore'):
288 | np_h += np.uint8(hue_factor * 255)
289 | h = Image.fromarray(np_h, 'L')
290 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
291 | return img
292 |
293 |
294 | class ColorJitter(object):
295 | """Randomly change the brightness, contrast and saturation of an image.
296 |
297 | Args:
298 | brightness (float): How much to jitter brightness. brightness_factor
299 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
300 | contrast (float): How much to jitter contrast. contrast_factor
301 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
302 | saturation (float): How much to jitter saturation. saturation_factor
303 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
304 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
305 | [-hue, hue]. Should be >=0 and <= 0.5.
306 | """
307 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
308 | self.brightness = brightness
309 | self.contrast = contrast
310 | self.saturation = saturation
311 | self.hue = hue
312 |
313 | @staticmethod
314 | def get_params(brightness, contrast, saturation, hue):
315 | """Get a randomized transform to be applied on image.
316 |
317 | Arguments are same as that of __init__.
318 |
319 | Returns:
320 | Transform which randomly adjusts brightness, contrast and
321 | saturation in a random order.
322 | """
323 | transforms = []
324 | if brightness > 0:
325 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
326 | transforms.append(
327 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
328 |
329 | if contrast > 0:
330 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
331 | transforms.append(
332 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
333 |
334 | if saturation > 0:
335 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
336 | transforms.append(
337 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
338 |
339 | if hue > 0:
340 | hue_factor = np.random.uniform(-hue, hue)
341 | transforms.append(
342 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
343 |
344 | np.random.shuffle(transforms)
345 | transform = torch_tr.Compose(transforms)
346 |
347 | return transform
348 |
349 | def __call__(self, img):
350 | """
351 | Args:
352 | img (PIL Image): Input image.
353 |
354 | Returns:
355 | PIL Image: Color jittered image.
356 | """
357 | transform = self.get_params(self.brightness, self.contrast,
358 | self.saturation, self.hue)
359 | return transform(img)
360 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/suhyeonlee/WildNet/c83aa92b7cd591512045fc5093da25a280bde430/utils/__init__.py
--------------------------------------------------------------------------------
/utils/attr_dict.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 |
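# Minimal usage sketch (hypothetical values); attribute and item access are
# interchangeable, and immutable(True) freezes the dict recursively:
#
#   cfg = AttrDict()
#   cfg.MODEL = AttrDict()
#   cfg.MODEL.BNFUNC = None      # in this repo, set by config.py
#   cfg.immutable(True)
#   cfg.MODEL.BNFUNC = None      # now raises AttributeError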
30 | class AttrDict(dict):
31 |
32 | IMMUTABLE = '__immutable__'
33 |
34 | def __init__(self, *args, **kwargs):
35 | super(AttrDict, self).__init__(*args, **kwargs)
36 | self.__dict__[AttrDict.IMMUTABLE] = False
37 |
38 | def __getattr__(self, name):
39 | if name in self.__dict__:
40 | return self.__dict__[name]
41 | elif name in self:
42 | return self[name]
43 | else:
44 | raise AttributeError(name)
45 |
46 | def __setattr__(self, name, value):
47 | if not self.__dict__[AttrDict.IMMUTABLE]:
48 | if name in self.__dict__:
49 | self.__dict__[name] = value
50 | else:
51 | self[name] = value
52 | else:
53 | raise AttributeError(
54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
55 | format(name, value)
56 | )
57 |
58 | def immutable(self, is_immutable):
59 | """Set immutability to is_immutable and recursively apply the setting
60 | to all nested AttrDicts.
61 | """
62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
63 | # Recursively set immutable state
64 | for v in self.__dict__.values():
65 | if isinstance(v, AttrDict):
66 | v.immutable(is_immutable)
67 | for v in self.values():
68 | if isinstance(v, AttrDict):
69 | v.immutable(is_immutable)
70 |
71 | def is_immutable(self):
72 | return self.__dict__[AttrDict.IMMUTABLE]
73 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellaneous Functions
3 | """
4 |
5 | import sys
6 | import re
7 | import os
8 | import shutil
9 | import torch
10 | from datetime import datetime
11 | import logging
12 | from subprocess import call
13 | import shlex
14 | from tensorboardX import SummaryWriter
15 | import datasets
16 | import numpy as np
17 | import torchvision.transforms as standard_transforms
18 | import torchvision.utils as vutils
19 | from config import cfg
20 | import random
21 |
22 |
23 | # Create unique output dir name based on non-default command line args
24 | def make_exp_name(args, parser):
25 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:])
26 | dict_args = vars(args)
27 |
28 | # sort so that we get a consistent directory name
29 | argnames = sorted(dict_args)
30 | ignorelist = ['date', 'exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch',
31 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt', 'coarse_boost_classes',
32 | 'crop_size', 'dist_url', 'syncbn', 'max_iter', 'color_aug', 'scale_max', 'scale_min', 'bs_mult',
33 | 'class_uniform_pct', 'class_uniform_tile']
34 | # build experiment name with non-default args
35 | for argname in argnames:
36 | if dict_args[argname] != parser.get_default(argname):
37 | if argname in ignorelist:
38 | continue
39 | if argname == 'snapshot':
40 | arg_str = 'PT'
41 | argname = ''
42 | elif argname == 'nosave':
43 | arg_str = ''
44 | argname=''
45 | elif argname == 'freeze_trunk':
46 | argname = ''
47 | arg_str = 'ft'
48 | elif argname == 'syncbn':
49 | argname = ''
50 | arg_str = 'sbn'
51 | elif argname == 'jointwtborder':
52 | argname = ''
53 | arg_str = 'rlx_loss'
54 | elif isinstance(dict_args[argname], bool):
55 | arg_str = 'T' if dict_args[argname] else 'F'
56 | else:
57 | arg_str = str(dict_args[argname])[:7]
58 |             if argname != '':
59 | exp_name += '_{}_{}'.format(str(argname), arg_str)
60 | else:
61 | exp_name += '_{}'.format(arg_str)
62 |     # clean special chars out (kept disabled): exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name)
63 | return exp_name
64 |
65 | def fast_hist(label_pred, label_true, num_classes):
66 | mask = (label_true >= 0) & (label_true < num_classes)
67 | hist = np.bincount(
68 | num_classes * label_true[mask].astype(int) +
69 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
70 | return hist
71 |
72 | def per_class_iu(hist):
73 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
74 |
75 | def save_log(prefix, output_dir, date_str, rank=0):
76 | fmt = '%(asctime)s.%(msecs)03d %(message)s'
77 | date_fmt = '%m-%d %H:%M:%S'
78 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log')
79 | print("Logging :", filename)
80 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt,
81 | filename=filename, filemode='w')
82 | console = logging.StreamHandler()
83 | console.setLevel(logging.INFO)
84 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt)
85 | console.setFormatter(formatter)
86 | if rank == 0:
87 | logging.getLogger('').addHandler(console)
88 | else:
89 | fh = logging.FileHandler(filename)
90 | logging.getLogger('').addHandler(fh)
91 |
92 |
93 |
94 | def prep_experiment(args, parser):
95 | """
96 | Make output directories, setup logging, Tensorboard, snapshot code.
97 | """
98 | ckpt_path = args.ckpt
99 | tb_path = args.tb_path
100 | exp_name = make_exp_name(args, parser)
101 | args.exp_path = os.path.join(ckpt_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H')))
102 | args.tb_exp_path = os.path.join(tb_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H')))
103 | args.ngpu = torch.cuda.device_count()
104 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
105 | args.best_record = {}
106 | # args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
107 | # 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
108 | args.last_record = {}
109 | if args.local_rank == 0:
110 | os.makedirs(args.exp_path, exist_ok=True)
111 | os.makedirs(args.tb_exp_path, exist_ok=True)
112 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank)
113 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write(
114 | str(args) + '\n\n')
115 | writer = SummaryWriter(log_dir=args.tb_exp_path, comment=args.tb_tag)
116 | return writer
117 | return None
118 |
119 | def evaluate_eval_for_inference(hist, dataset=None):
120 | """
121 |     Modified IoU mechanism for on-the-fly IoU calculations (prevents memory overflow
122 |     on large datasets). Only applies to eval/eval.py.
123 | """
124 | # axis 0: gt, axis 1: prediction
125 | acc = np.diag(hist).sum() / hist.sum()
126 | acc_cls = np.diag(hist) / hist.sum(axis=1)
127 | acc_cls = np.nanmean(acc_cls)
128 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
129 |
130 | print_evaluate_results(hist, iu, dataset=dataset)
131 | freq = hist.sum(axis=1) / hist.sum()
132 | mean_iu = np.nanmean(iu)
133 | logging.info('mean {}'.format(mean_iu))
134 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
135 | return acc, acc_cls, mean_iu, fwavacc
136 |
137 |
138 |
139 | def evaluate_eval(args, net, optimizer, scheduler, val_loss, hist, dump_images, writer, epoch=0, dataset_name=None, dataset=None, curr_iter=0, optimizer_at=None, scheduler_at=None, save_pth=True):
140 | """
141 |     Modified IoU mechanism for on-the-fly IoU calculations (prevents memory overflow
142 |     on large datasets), plus checkpoint saving and logging of validation metrics.
143 | """
144 | if val_loss is not None and hist is not None:
145 | # axis 0: gt, axis 1: prediction
146 | acc = np.diag(hist).sum() / hist.sum()
147 | acc_cls = np.diag(hist) / hist.sum(axis=1)
148 | acc_cls = np.nanmean(acc_cls)
149 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
150 |
151 | print_evaluate_results(hist, iu, dataset_name=dataset_name, dataset=dataset)
152 | freq = hist.sum(axis=1) / hist.sum()
153 | mean_iu = np.nanmean(iu)
154 | logging.info('mean {}'.format(mean_iu))
155 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
156 | else:
157 | mean_iu = 0
158 |
159 | if dataset_name not in args.last_record.keys():
160 | args.last_record[dataset_name] = {}
161 |
162 | if save_pth:
163 | # update latest snapshot
164 | if 'mean_iu' in args.last_record[dataset_name]:
165 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
166 | dataset_name, args.last_record[dataset_name]['epoch'],
167 | args.last_record[dataset_name]['mean_iu'])
168 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
169 | # if dataset_name != "cityscapes":
170 | if dataset_name != "gtav":
171 | try:
172 | os.remove(last_snapshot)
173 | except OSError:
174 | pass
175 |
176 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(dataset_name, epoch, mean_iu)
177 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
178 | args.last_record[dataset_name]['mean_iu'] = mean_iu
179 | args.last_record[dataset_name]['epoch'] = epoch
180 |
181 | torch.cuda.synchronize()
182 |
183 | if optimizer_at is not None:
184 | torch.save({
185 | 'state_dict': net.state_dict(),
186 | 'optimizer': optimizer.state_dict(),
187 | 'optimizer_at': optimizer_at.state_dict(),
188 | 'scheduler': scheduler.state_dict(),
189 | 'scheduler_at': scheduler_at.state_dict(),
190 | 'epoch': epoch,
191 | 'mean_iu': mean_iu,
192 | 'command': ' '.join(sys.argv[1:])
193 | }, last_snapshot)
194 | else:
195 | torch.save({
196 | 'state_dict': net.state_dict(),
197 | 'optimizer': optimizer.state_dict(),
198 | 'scheduler': scheduler.state_dict(),
199 | 'epoch': epoch,
200 | 'mean_iu': mean_iu,
201 | 'command': ' '.join(sys.argv[1:])
202 | }, last_snapshot)
203 |
204 | if val_loss is not None and hist is not None:
205 | if dataset_name not in args.best_record.keys():
206 | args.best_record[dataset_name] = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
207 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
208 | # update best snapshot
209 | if mean_iu > args.best_record[dataset_name]['mean_iu'] :
210 | # remove old best snapshot
211 | if args.best_record[dataset_name]['epoch'] != -1:
212 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
213 | dataset_name, args.best_record[dataset_name]['epoch'],
214 | args.best_record[dataset_name]['mean_iu'])
215 |
216 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
217 | assert os.path.exists(best_snapshot), \
218 | 'cant find old snapshot {}'.format(best_snapshot)
219 | os.remove(best_snapshot)
220 |
221 | # save new best
222 | args.best_record[dataset_name]['val_loss'] = val_loss.avg
223 | args.best_record[dataset_name]['epoch'] = epoch
224 | args.best_record[dataset_name]['acc'] = acc
225 | args.best_record[dataset_name]['acc_cls'] = acc_cls
226 | args.best_record[dataset_name]['mean_iu'] = mean_iu
227 | args.best_record[dataset_name]['fwavacc'] = fwavacc
228 |
229 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
230 | dataset_name, args.best_record[dataset_name]['epoch'],
231 | args.best_record[dataset_name]['mean_iu'])
232 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
233 | shutil.copyfile(last_snapshot, best_snapshot)
234 | else:
235 | logging.info("Saved file to {}".format(last_snapshot))
236 |
237 | if val_loss is not None and hist is not None:
238 | logging.info('-' * 107)
239 | fmt_str = '[epoch %d], [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
240 | '[mean_iu %.5f], [fwavacc %.5f]'
241 | logging.info(fmt_str % (epoch, dataset_name, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))
242 | if save_pth:
243 | fmt_str = 'best record: [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
244 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], '
245 | logging.info(fmt_str % (dataset_name,
246 | args.best_record[dataset_name]['val_loss'], args.best_record[dataset_name]['acc'],
247 | args.best_record[dataset_name]['acc_cls'], args.best_record[dataset_name]['mean_iu'],
248 | args.best_record[dataset_name]['fwavacc'], args.best_record[dataset_name]['epoch']))
249 | logging.info('-' * 107)
250 |
251 | if writer:
252 | # tensorboard logging of validation phase metrics
253 | writer.add_scalar('{}/acc'.format(dataset_name), acc, curr_iter)
254 | writer.add_scalar('{}/acc_cls'.format(dataset_name), acc_cls, curr_iter)
255 | writer.add_scalar('{}/mean_iu'.format(dataset_name), mean_iu, curr_iter)
256 | writer.add_scalar('{}/val_loss'.format(dataset_name), val_loss.avg, curr_iter)
257 |
258 |
259 |
260 |
261 |
262 | def print_evaluate_results(hist, iu, dataset_name=None, dataset=None):
263 | # fixme: Need to refactor this dict
264 | try:
265 | id2cat = dataset.id2cat
266 |     except AttributeError:
267 | id2cat = {i: i for i in range(datasets.num_classes)}
268 | iu_false_positive = hist.sum(axis=1) - np.diag(hist)
269 | iu_false_negative = hist.sum(axis=0) - np.diag(hist)
270 | iu_true_positive = np.diag(hist)
271 |
272 | logging.info('Dataset name: {}'.format(dataset_name))
273 | logging.info('IoU:')
274 | logging.info('label_id label iU Precision Recall TP FP FN')
275 | for idx, i in enumerate(iu):
276 | # Format all of the strings:
277 | idx_string = "{:2d}".format(idx)
278 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
279 | iu_string = '{:5.1f}'.format(i * 100)
280 | total_pixels = hist.sum()
281 | tp = '{:5.1f}'.format(100 * iu_true_positive[idx] / total_pixels)
282 | fp = '{:5.1f}'.format(
283 | iu_false_positive[idx] / iu_true_positive[idx])
284 | fn = '{:5.1f}'.format(iu_false_negative[idx] / iu_true_positive[idx])
285 | precision = '{:5.1f}'.format(
286 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx]))
287 | recall = '{:5.1f}'.format(
288 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx]))
289 | logging.info('{} {} {} {} {} {} {} {}'.format(
290 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn))
291 |
292 |
293 |
294 |
295 | class AverageMeter(object):
296 |
297 | def __init__(self):
298 | self.reset()
299 |
300 | def reset(self):
301 | self.val = 0
302 | self.avg = 0
303 | self.sum = 0
304 | self.count = 0
305 |
306 | def update(self, val, n=1):
307 | self.val = val
308 | self.sum += val * n
309 | self.count += n
310 | self.avg = self.sum / self.count
311 |
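312 | # Minimal sketch (illustrative only) of how fast_hist and per_class_iu combine
313 | # into an mIoU score; `prediction_label_pairs` is a hypothetical iterable of
314 | # flattened numpy prediction/label arrays.
315 | #
316 | #   hist = np.zeros((datasets.num_classes, datasets.num_classes))
317 | #   for preds, labels in prediction_label_pairs:
318 | #       hist += fast_hist(preds, labels, datasets.num_classes)
319 | #   mean_iu = np.nanmean(per_class_iu(hist))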
--------------------------------------------------------------------------------
/utils/my_data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | # Code adapted from:
4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
5 | #
6 | # BSD 3-Clause License
7 | #
8 | # Copyright (c) 2017,
9 | # All rights reserved.
10 | #
11 | # Redistribution and use in source and binary forms, with or without
12 | # modification, are permitted provided that the following conditions are met:
13 | #
14 | # * Redistributions of source code must retain the above copyright notice, this
15 | # list of conditions and the following disclaimer.
16 | #
17 | # * Redistributions in binary form must reproduce the above copyright notice,
18 | # this list of conditions and the following disclaimer in the documentation
19 | # and/or other materials provided with the distribution.
20 | #
21 | # * Neither the name of the copyright holder nor the names of its
22 | # contributors may be used to endorse or promote products derived from
23 | # this software without specific prior written permission.
24 | #
25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35 | """
36 |
37 |
38 | import operator
39 | import torch
40 | import warnings
41 | from torch.nn.modules import Module
42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
43 | from torch.nn.parallel.replicate import replicate
44 | from torch.nn.parallel.parallel_apply import parallel_apply
45 |
46 |
47 | def _check_balance(device_ids):
48 | imbalance_warn = """
49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which
50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting
51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
52 | environment variable."""
53 |
54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
55 |
56 | def warn_imbalance(get_prop):
57 | values = [get_prop(props) for props in dev_props]
58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
60 | if min_val / max_val < 0.75:
61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
62 | return True
63 | return False
64 |
65 | if warn_imbalance(lambda props: props.total_memory):
66 | return
67 | if warn_imbalance(lambda props: props.multi_processor_count):
68 | return
69 |
70 |
71 |
72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather_output=True):
73 | """
74 | Evaluates module(input) in parallel across the GPUs given in device_ids.
75 | This is the functional version of the DataParallel module.
76 | Args:
77 | module: the module to evaluate in parallel
78 | inputs: inputs to the module
79 | device_ids: GPU ids on which to replicate module
80 |         output_device: GPU location of the output. Use -1 to indicate the CPU
81 | (default: device_ids[0])
82 | Returns:
83 | a Tensor containing the result of module(input) located on
84 | output_device
85 | """
86 | if not isinstance(inputs, tuple):
87 | inputs = (inputs,)
88 |
89 | if device_ids is None:
90 | device_ids = list(range(torch.cuda.device_count()))
91 |
92 | if output_device is None:
93 | output_device = device_ids[0]
94 |
95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
96 | if len(device_ids) == 1:
97 | return module(*inputs[0], **module_kwargs[0])
98 | used_device_ids = device_ids[:len(inputs)]
99 | replicas = replicate(module, used_device_ids)
100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
101 |     if gather_output:  # renamed from `gather`, which shadowed the imported gather()
102 | return gather(outputs, output_device, dim)
103 | else:
104 | return outputs
105 |
106 |
107 |
108 | class MyDataParallel(Module):
109 | """
110 | Implements data parallelism at the module level.
111 | This container parallelizes the application of the given module by
112 | splitting the input across the specified devices by chunking in the batch
113 | dimension. In the forward pass, the module is replicated on each device,
114 | and each replica handles a portion of the input. During the backwards
115 | pass, gradients from each replica are summed into the original module.
116 | The batch size should be larger than the number of GPUs used.
117 | See also: :ref:`cuda-nn-dataparallel-instead`
118 | Arbitrary positional and keyword inputs are allowed to be passed into
119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim
120 | specified (default 0). Primitive types will be broadcasted, but all
121 | other types will be a shallow copy and can be corrupted if written to in
122 | the model's forward pass.
123 | .. warning::
124 | Forward and backward hooks defined on :attr:`module` and its submodules
125 | will be invoked ``len(device_ids)`` times, each with inputs located on
126 | a particular device. Particularly, the hooks are only guaranteed to be
127 | executed in correct order with respect to operations on corresponding
128 | devices. For example, it is not guaranteed that hooks set via
129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
131 | that each such hook be executed before the corresponding
132 | :meth:`~torch.nn.Module.forward` call of that device.
133 | .. warning::
134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
135 | :func:`forward`, this wrapper will return a vector of length equal to
136 | number of devices used in data parallelism, containing the result from
137 | each device.
138 | .. note::
139 | There is a subtlety in using the
140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
143 | details.
144 | Args:
145 | module: module to be parallelized
146 | device_ids: CUDA devices (default: all devices)
147 | output_device: device location of output (default: device_ids[0])
148 | Attributes:
149 | module (Module): the module to be parallelized
150 | Example::
151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
152 | >>> output = net(input_var)
153 | """
154 |
155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
156 |
157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True):
158 | super(MyDataParallel, self).__init__()
159 |
160 | if not torch.cuda.is_available():
161 | self.module = module
162 | self.device_ids = []
163 | return
164 |
165 | if device_ids is None:
166 | device_ids = list(range(torch.cuda.device_count()))
167 | if output_device is None:
168 | output_device = device_ids[0]
169 | self.dim = dim
170 | self.module = module
171 | self.device_ids = device_ids
172 | self.output_device = output_device
173 | self.gather_bool = gather
174 |
175 | _check_balance(self.device_ids)
176 |
177 | if len(self.device_ids) == 1:
178 | self.module.cuda(device_ids[0])
179 |
180 | def forward(self, *inputs, **kwargs):
181 | if not self.device_ids:
182 | return self.module(*inputs, **kwargs)
183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
184 | if len(self.device_ids) == 1:
185 | return [self.module(*inputs[0], **kwargs[0])]
186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
187 | outputs = self.parallel_apply(replicas, inputs, kwargs)
188 | if self.gather_bool:
189 | return self.gather(outputs, self.output_device)
190 | else:
191 | return outputs
192 |
193 | def replicate(self, module, device_ids):
194 | return replicate(module, device_ids)
195 |
196 | def scatter(self, inputs, kwargs, device_ids):
197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
198 |
199 | def parallel_apply(self, replicas, inputs, kwargs):
200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
201 |
202 | def gather(self, outputs, output_device):
203 | return gather(outputs, output_device, dim=self.dim)
204 |
205 |
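206 | # Usage sketch (illustrative, assuming a machine with 2 visible GPUs): with
207 | # gather=False the wrapper returns one output per device instead of gathering
208 | # them onto output_device, as implemented in MyDataParallel.forward above.
209 | #
210 | #   net = MyDataParallel(model, device_ids=[0, 1], gather=False)
211 | #   per_gpu_outputs = net(images)   # list with one output per GPU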
--------------------------------------------------------------------------------
/valid.py:
--------------------------------------------------------------------------------
1 | """
2 | validation code
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | import argparse
7 | import logging
8 | import os
9 | import torch
10 |
11 | from config import cfg, assert_and_infer_cfg
12 | from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist
13 | import datasets
14 | import loss
15 | import network
16 | import optimizer
17 | import time
18 | import torchvision.utils as vutils
19 | import torch.nn.functional as F
20 | import numpy as np
21 | import random
22 | import pdb
23 |
24 | # Argument Parser
25 | parser = argparse.ArgumentParser(description='Semantic Segmentation')
26 | parser.add_argument('--lr', type=float, default=0.01)
27 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepR50V3PlusD',
28 | help='Network architecture.')
29 | parser.add_argument('--dataset', nargs='*', type=str, default=['gtav'],
30 | help='a list of datasets; cityscapes, mapillary, gtav, bdd100k, synthia')
31 | parser.add_argument('--image_uniform_sampling', action='store_true', default=False,
32 | help='uniformly sample images across the multiple source domains')
33 | parser.add_argument('--val_dataset', nargs='*', type=str, default=['bdd100k'],
34 | help='a list consists of cityscapes, mapillary, gtav, bdd100k, synthia')
35 | parser.add_argument('--wild_dataset', nargs='*', type=str, default=['imagenet'],
36 | help='a list consists of imagenet')
37 | parser.add_argument('--cv', type=int, default=0,
38 | help='cross-validation split id to use. Default # of splits set to 3 in config')
39 | parser.add_argument('--class_uniform_pct', type=float, default=0,
40 | help='What fraction of images is uniformly sampled')
41 | parser.add_argument('--class_uniform_tile', type=int, default=1024,
42 | help='tile size for class uniform sampling')
43 | parser.add_argument('--coarse_boost_classes', type=str, default=None,
44 | help='use coarse annotations to boost fine data with specific classes')
45 |
46 | parser.add_argument('--img_wt_loss', action='store_true', default=False,
47 | help='per-image class-weighted loss')
48 | parser.add_argument('--cls_wt_loss', action='store_true', default=False,
49 | help='class-weighted loss')
50 | parser.add_argument('--batch_weighting', action='store_true', default=False,
51 |                     help='Batch weighting for class (use nll class weighting using batch stats)')
52 |
53 | parser.add_argument('--jointwtborder', action='store_true', default=False,
54 | help='Enable boundary label relaxation')
55 | parser.add_argument('--strict_bdr_cls', type=str, default='',
56 | help='Enable boundary label relaxation for specific classes')
57 | parser.add_argument('--rlx_off_iter', type=int, default=-1,
58 | help='Turn off border relaxation after specific epoch count')
59 | parser.add_argument('--rescale', type=float, default=1.0,
60 | help='Warm Restarts new learning rate ratio compared to original lr')
61 | parser.add_argument('--repoly', type=float, default=1.5,
62 | help='Warm Restart new poly exp')
63 |
64 | parser.add_argument('--fp16', action='store_true', default=False,
65 | help='Use Nvidia Apex AMP')
66 | parser.add_argument('--local_rank', default=0, type=int,
67 | help='parameter used by apex library')
68 |
69 | parser.add_argument('--sgd', action='store_true', default=False)
70 | parser.add_argument('--adam', action='store_true', default=False)
71 | parser.add_argument('--amsgrad', action='store_true', default=False)
72 |
73 | parser.add_argument('--freeze_trunk', action='store_true', default=False)
74 | parser.add_argument('--hardnm', default=0, type=int,
75 |                     help='0 means no aug, 1 means hard negative mining iter 1, ' +
76 | '2 means hard negative mining iter 2')
77 |
78 | parser.add_argument('--trunk', type=str, default='resnet-50',
79 | help='trunk model, can be: resnet-50 (default)')
80 | parser.add_argument('--max_epoch', type=int, default=180)
81 | parser.add_argument('--max_iter', type=int, default=30000)
82 | parser.add_argument('--max_cu_epoch', type=int, default=100000,
83 | help='Class Uniform Max Epochs')
84 | parser.add_argument('--start_epoch', type=int, default=0)
85 | parser.add_argument('--crop_nopad', action='store_true', default=False)
86 | parser.add_argument('--rrotate', type=int,
87 |                     default=0, help='degree of random rotation')
88 | parser.add_argument('--color_aug', type=float,
89 | default=0.0, help='level of color augmentation')
90 | parser.add_argument('--gblur', action='store_true', default=False,
91 |                     help='Use Gaussian Blur Augmentation')
92 | parser.add_argument('--bblur', action='store_true', default=False,
93 | help='Use Bilateral Blur Augmentation')
94 | parser.add_argument('--lr_schedule', type=str, default='poly',
95 | help='name of lr schedule: poly')
96 | parser.add_argument('--poly_exp', type=float, default=0.9,
97 | help='polynomial LR exponent')
98 | parser.add_argument('--bs_mult', type=int, default=2,
99 | help='Batch size for training per gpu')
100 | parser.add_argument('--bs_mult_val', type=int, default=1,
101 | help='Batch size for Validation per gpu')
102 | parser.add_argument('--crop_size', type=int, default=720,
103 | help='training crop size')
104 | parser.add_argument('--pre_size', type=int, default=None,
105 | help='resize image shorter edge to this before augmentation')
106 | parser.add_argument('--scale_min', type=float, default=0.5,
107 | help='dynamically scale training images down to this size')
108 | parser.add_argument('--scale_max', type=float, default=2.0,
109 | help='dynamically scale training images up to this size')
110 | parser.add_argument('--weight_decay', type=float, default=5e-4)
111 | parser.add_argument('--momentum', type=float, default=0.9)
112 | parser.add_argument('--snapshot', type=str, default=None)
113 | parser.add_argument('--restore_optimizer', action='store_true', default=False)
114 |
115 | parser.add_argument('--city_mode', type=str, default='train',
116 |                     help='Cityscapes training split to use (e.g., train)')
117 | parser.add_argument('--date', type=str, default='default',
118 | help='experiment directory date name')
119 | parser.add_argument('--exp', type=str, default='default',
120 | help='experiment directory name')
121 | parser.add_argument('--tb_tag', type=str, default='',
122 | help='add tag to tb dir')
123 | parser.add_argument('--ckpt', type=str, default='logs/ckpt',
124 |                     help='path to save checkpoints')
125 | parser.add_argument('--tb_path', type=str, default='logs/tb',
126 | help='Save Tensorboard Path')
127 | parser.add_argument('--syncbn', action='store_true', default=True,
128 | help='Use Synchronized BN')
129 | parser.add_argument('--dump_augmentation_images', action='store_true', default=False,
130 |                     help='Dump Augmented Images for sanity check')
131 | parser.add_argument('--test_mode', action='store_true', default=False,
132 | help='Minimum testing to verify nothing failed, ' +
133 | 'Runs code for 1 epoch of train and val')
134 | parser.add_argument('-wb', '--wt_bound', type=float, default=1.0,
135 | help='Weight Scaling for the losses')
136 | parser.add_argument('--maxSkip', type=int, default=0,
137 | help='Skip x number of frames of video augmented dataset')
138 | parser.add_argument('--scf', action='store_true', default=False,
139 | help='scale correction factor')
140 | parser.add_argument('--dist_url', default='tcp://127.0.0.1:', type=str,
141 | help='url used to set up distributed training')
142 |
143 | parser.add_argument('--image_in', action='store_true', default=False,
144 | help='Input Image Instance Norm')
145 |
146 | parser.add_argument('--fs_layer', nargs='*', type=int, default=[0,0,0,0,0],
147 | help='0: None, 1: AdaIN')
148 | parser.add_argument('--lambda_cel', type=float, default=0.0,
149 | help='lambda for content extension learning loss')
150 | parser.add_argument('--lambda_sel', type=float, default=0.0,
151 | help='lambda for style extension learning loss')
152 | parser.add_argument('--lambda_scr', type=float, default=0.0,
153 | help='lambda for semantic consistency regularization loss')
154 | parser.add_argument('--cont_proj_head', type=int, default=0,
155 | help='number of output channels of content projection head')
156 | parser.add_argument('--wild_cont_dict_size', type=int, default=0,
157 | help='wild-content dictionary size')
158 |
159 | parser.add_argument('--use_fs', action='store_true', default=False,
160 | help='Automatic setting from fs_layer. feature stylization with wild dataset')
161 | parser.add_argument('--use_scr', action='store_true', default=False,
162 | help='Automatic setting from lambda_scr')
163 | parser.add_argument('--use_sel', action='store_true', default=False,
164 | help='Automatic setting from lambda_sel')
165 | parser.add_argument('--use_cel', action='store_true', default=False,
166 | help='Automatic setting from lambda_cel')
167 |
168 | args = parser.parse_args()
169 |
170 | random_seed = cfg.RANDOM_SEED #304
171 | torch.manual_seed(random_seed)
172 | torch.cuda.manual_seed(random_seed)
173 | torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
174 | torch.backends.cudnn.deterministic = True
175 | torch.backends.cudnn.benchmark = False
176 | np.random.seed(random_seed)
177 | random.seed(random_seed)
178 |
179 | args.world_size = 1
180 |
181 | # Test mode: run two epochs with a few iterations of training and val
182 | if args.test_mode:
183 | args.max_epoch = 2
184 |
185 | if 'WORLD_SIZE' in os.environ:
186 | # args.apex = int(os.environ['WORLD_SIZE']) > 1
187 | args.world_size = int(os.environ['WORLD_SIZE'])
188 | print("Total world size: ", int(os.environ['WORLD_SIZE']))
189 |
190 | torch.cuda.set_device(args.local_rank)
191 | print('My Rank:', args.local_rank)
192 | # Initialize distributed communication
193 | args.dist_url = args.dist_url + str(8000 + (int(time.time()%1000))//10)
194 |
195 | torch.distributed.init_process_group(backend='nccl',
196 | init_method='env://',
197 | world_size=args.world_size,
198 | rank=args.local_rank)
199 | # torch.distributed.init_process_group(backend='nccl',
200 | # init_method=args.dist_url,
201 | # world_size=args.world_size,
202 | # rank=args.local_rank)
203 |
204 | def main():
205 | """
206 | Main Function
207 | """
208 | # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
209 | assert_and_infer_cfg(args)
210 | prep_experiment(args, parser)
211 | writer = None
212 |
213 | _, val_loaders, _, _, extra_val_loaders = datasets.setup_loaders(args)
214 |
215 | criterion, criterion_val = loss.get_loss(args)
216 | criterion_aux = loss.get_loss_aux(args)
217 | net = network.get_net(args, criterion, criterion_aux, args.cont_proj_head, args.wild_cont_dict_size)
218 |
219 | optim, scheduler = optimizer.get_optimizer(args, net)
220 |
221 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
222 | net = network.warp_network_in_dataparallel(net, args.local_rank)
223 | epoch = 0
224 | i = 0
225 |
226 | if args.snapshot:
227 | epoch, mean_iu = optimizer.load_weights(net, optim, scheduler,
228 | args.snapshot, args.restore_optimizer)
229 |
230 | print("#### iteration", i)
231 | torch.cuda.empty_cache()
232 |
233 | for dataset, val_loader in val_loaders.items():
234 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
235 |
236 | for dataset, val_loader in extra_val_loaders.items():
237 | print("Extra validating... This won't save pth file")
238 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
239 |
240 | def validate(val_loader, dataset, net, criterion, optim, scheduler, curr_epoch, writer, curr_iter, save_pth=True):
241 | """
242 | Runs the validation loop after each training epoch
243 | val_loader: Data loader for validation
244 | dataset: dataset name (str)
245 |     net: the network
246 | criterion: loss fn
247 | optimizer: optimizer
248 | curr_epoch: current epoch
249 | writer: tensorboard writer
250 | return: val_avg for step function if required
251 | """
252 |
253 | net.eval()
254 | val_loss = AverageMeter()
255 | iou_acc = 0
256 | error_acc = 0
257 | dump_images = []
258 |
259 | for val_idx, data in enumerate(val_loader):
260 |
261 | inputs, gt_image, img_names, _ = data
262 |
263 | if len(inputs.shape) == 5:
264 | B, D, C, H, W = inputs.shape
265 | inputs = inputs.view(-1, C, H, W)
266 | gt_image = gt_image.view(-1, 1, H, W)
267 |
268 | assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
269 | assert inputs.size()[2:] == gt_image.size()[1:]
270 |
271 | batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
272 | inputs, gt_cuda = inputs.cuda(), gt_image.cuda()
273 |
274 | with torch.no_grad():
275 | output = net(inputs)
276 |
277 | del inputs
278 |
279 | assert output.size()[2:] == gt_image.size()[1:]
280 | assert output.size()[1] == datasets.num_classes
281 |
282 | val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)
283 |
284 | del gt_cuda
285 |
286 |         # Collect predictions from the different GPUs onto a single GPU, since
287 |         # encoding.parallel.criterionparallel calculates the loss functions in a
288 |         # distributed manner
289 | predictions = output.data.max(1)[1].cpu()
290 |
291 | # Logging
292 | if val_idx % 20 == 0:
293 | if args.local_rank == 0:
294 | logging.info("validating: %d / %d", val_idx + 1, len(val_loader))
295 | if val_idx > 10 and args.test_mode:
296 | break
297 |
298 | # Image Dumps
299 | if val_idx < 10:
300 | dump_images.append([gt_image, predictions, img_names])
301 |
302 | iou_acc += fast_hist(predictions.numpy().flatten(), gt_image.numpy().flatten(),
303 | datasets.num_classes)
304 | del output, val_idx, data
305 |
306 | iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
307 | torch.distributed.all_reduce(iou_acc_tensor, op=torch.distributed.ReduceOp.SUM)
308 | iou_acc = iou_acc_tensor.cpu().numpy()
309 |
310 | if args.local_rank == 0:
311 | evaluate_eval(args, net, optim, scheduler, val_loss, iou_acc, dump_images,
312 | writer, curr_epoch, dataset, None, curr_iter, save_pth=save_pth)
313 |
314 | return val_loss.avg
315 |
316 |
317 | if __name__ == '__main__':
318 | main()
319 |
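320 | # Launch sketch (illustrative; the exact flags used in the paper live in
321 | # scripts/valid_wildnet_r50os16_gtav.sh): valid.py expects to be started through
322 | # torch.distributed.launch so that WORLD_SIZE and --local_rank are populated,
323 | # roughly along the lines of:
324 | #
325 | #   python -m torch.distributed.launch --nproc_per_node=1 valid.py \
326 | #       --dataset gtav --val_dataset cityscapes bdd100k mapillary \
327 | #       --arch network.deepv3.DeepR50V3PlusD \
328 | #       --snapshot <path/to/trained_checkpoint.pth>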
--------------------------------------------------------------------------------