├── README.md
├── assets
│   └── comparison.png
├── config.py
├── datasets
│   ├── __init__.py
│   ├── bdd100k.py
│   ├── cityscapes.py
│   ├── cityscapes_labels.py
│   ├── gtav.py
│   ├── imagenet.py
│   ├── mapillary.py
│   ├── multi_loader.py
│   ├── nullloader.py
│   ├── sampler.py
│   ├── synthia.py
│   └── uniform.py
├── loss.py
├── network
│   ├── Resnet.py
│   ├── __init__.py
│   ├── adain.py
│   ├── cel.py
│   ├── deepv3.py
│   └── mynn.py
├── optimizer.py
├── scripts
│   ├── train_wildnet_r50os16_gtav.sh
│   └── valid_wildnet_r50os16_gtav.sh
├── split_data
│   ├── gtav_split_test.txt
│   ├── gtav_split_train.txt
│   ├── gtav_split_val.txt
│   ├── synthia_split_train.txt
│   └── synthia_split_val.txt
├── train.py
├── transforms
│   ├── __init__.py
│   ├── joint_transforms.py
│   └── transforms.py
├── utils
│   ├── __init__.py
│   ├── attr_dict.py
│   ├── misc.py
│   └── my_data_parallel.py
└── valid.py

/README.md:
--------------------------------------------------------------------------------
1 | ## WildNet (CVPR 2022): Official Project Webpage
2 | This repository provides the official PyTorch implementation of the following paper:
3 | > [**WildNet: Learning Domain Generalized Semantic Segmentation from the Wild**](https://arxiv.org/abs/2204.01446)
4 | > [Suhyeon Lee](https://suhyeonlee.github.io/), [Hongje Seong](https://hongje.github.io/), [Seongwon Lee](https://sungonce.github.io/), [Euntai Kim](https://cilab.yonsei.ac.kr/)
5 | > Yonsei University
6 | 
7 | > **Abstract:**
8 | *We present a new domain generalized semantic segmentation network named WildNet, which learns domain-generalized features by leveraging a variety of contents and styles from the wild.
9 | In domain generalization, the low generalization ability for unseen target domains is clearly due to overfitting to the source domain.
10 | To address this problem, previous works have focused on generalizing the domain by removing or diversifying the styles of the source domain.
11 | These alleviated overfitting to the source-style but overlooked overfitting to the source-content.
12 | In this paper, we propose to diversify both the content and style of the source domain with the help of the wild.
13 | Our main idea is for networks to naturally learn domain-generalized semantic information from the wild.
14 | To this end, we diversify styles by augmenting source features to resemble wild styles and enable networks to adapt to a variety of styles.
15 | Furthermore, we encourage networks to learn class-discriminant features by providing semantic variations borrowed from the wild to source contents in the feature space.
16 | Finally, we regularize networks to capture consistent semantic information even when both the content and style of the source domain are extended to the wild.
17 | Extensive experiments on five different datasets validate the effectiveness of our WildNet, and we significantly outperform state-of-the-art methods.*
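The style diversification described in the abstract amounts to re-normalizing source feature maps with the channel-wise statistics of "wild" (ImageNet) features, in the spirit of AdaIN; the repository's own implementation lives in `network/adain.py`. The snippet below is only a minimal sketch of that idea for orientation — the function name `stylize_features`, the mixing weight `alpha`, and the tensor shapes are illustrative assumptions, not the project's actual interface.

```
import torch

def stylize_features(source_feat, wild_feat, alpha=1.0, eps=1e-5):
    """AdaIN-style feature stylization (illustrative sketch, not WildNet's exact code).

    Both inputs are B x C x H x W feature maps; the source features are
    re-normalized to carry the channel-wise mean/std of the wild features.
    """
    b, c = source_feat.shape[:2]
    s_mean = source_feat.reshape(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    s_std = source_feat.reshape(b, c, -1).std(dim=2).view(b, c, 1, 1) + eps
    w_mean = wild_feat.reshape(b, c, -1).mean(dim=2).view(b, c, 1, 1)
    w_std = wild_feat.reshape(b, c, -1).std(dim=2).view(b, c, 1, 1) + eps

    stylized = w_std * (source_feat - s_mean) / s_std + w_mean
    # alpha blends the wild-styled features back toward the original source features.
    return alpha * stylized + (1.0 - alpha) * source_feat
```

In WildNet the wild features come from ImageNet images (loaded by `datasets/imagenet.py`), so the network is exposed to a far broader range of styles than the source domain alone provides.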
18 | 19 |

20 | 21 |

22 | 
23 | ## PyTorch Implementation
24 | Our PyTorch implementation is heavily derived from [RobustNet](https://github.com/shachoi/RobustNet) (CVPR 2021). If you use this code in your research, please also cite their work.
25 | [[link to license](https://github.com/shachoi/RobustNet/blob/main/LICENSE)]
26 | 
27 | ### Installation
28 | Clone this repository.
29 | ```
30 | git clone https://github.com/suhyeonlee/WildNet.git
31 | cd WildNet
32 | ```
33 | Install the following packages.
34 | ```
35 | conda create --name wildnet python=3.7
36 | conda activate wildnet
37 | conda install pytorch==1.9.1 torchvision==0.10.1 torchaudio==0.9.1 cudatoolkit=11.1 -c pytorch
38 | conda install scipy==1.1.0
39 | conda install tqdm==4.46.0
40 | conda install scikit-image==0.16.2
41 | pip install tensorboardX
42 | pip install thop
43 | pip install kmeans1d
44 | imageio_download_bin freeimage
45 | ```
46 | ### How to Run WildNet
47 | We trained our model with the source domain ([GTAV](https://download.visinf.tu-darmstadt.de/data/from_games/) or [Cityscapes](https://www.cityscapes-dataset.com/)) and the wild domain ([ImageNet](https://www.image-net.org/)).
48 | Then we evaluated the model on [Cityscapes](https://www.cityscapes-dataset.com/), [BDD-100K](https://bair.berkeley.edu/blog/2018/05/30/bdd/), [Synthia](https://synthia-dataset.net/downloads/) ([SYNTHIA-RAND-CITYSCAPES](http://synthia-dataset.net/download/808/)), [GTAV](https://download.visinf.tu-darmstadt.de/data/from_games/) and [Mapillary Vistas](https://www.mapillary.com/dataset/vistas?pKey=2ix3yvnjy9fwqdzwum3t9g&lat=20&lng=0&z=1.5).
49 | 
50 | We adopt the class uniform sampling proposed in [this paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhu_Improving_Semantic_Segmentation_via_Video_Propagation_and_Label_Relaxation_CVPR_2019_paper.pdf) to handle the class imbalance problem.
51 | 
52 | 
53 | 1. For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/
54 | Unzip the files and make the directory structures as follows.
55 | ```
56 | cityscapes
57 |  └ leftImg8bit_trainvaltest
58 |    └ leftImg8bit
59 |      └ train
60 |      └ val
61 |      └ test
62 |  └ gtFine_trainvaltest
63 |    └ gtFine
64 |      └ train
65 |      └ val
66 |      └ test
67 | ```
68 | ```
69 | bdd-100k
70 |  └ images
71 |    └ train
72 |    └ val
73 |    └ test
74 |  └ labels
75 |    └ train
76 |    └ val
77 | ```
78 | ```
79 | mapillary
80 |  └ training
81 |    └ images
82 |    └ labels
83 |  └ validation
84 |    └ images
85 |    └ labels
86 |  └ test
87 |    └ images
88 |    └ labels
89 | ```
90 | ```
91 | imagenet
92 |  └ data
93 |    └ train
94 |    └ val
95 | ```
96 | 
97 | 2. We used [GTAV_Split](https://download.visinf.tu-darmstadt.de/data/from_games/code/read_mapping.zip) to split the GTAV dataset into training/validation/test sets. Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data).
98 | 
99 | ```
100 | GTAV
101 |  └ images
102 |    └ train
103 |      └ folder
104 |    └ valid
105 |      └ folder
106 |    └ test
107 |      └ folder
108 |  └ labels
109 |    └ train
110 |      └ folder
111 |    └ valid
112 |      └ folder
113 |    └ test
114 |      └ folder
115 | ```
116 | 
117 | 3. We split the [Synthia dataset](http://synthia-dataset.net/download/808/) into train/val sets following [RobustNet](https://github.com/shachoi/RobustNet). Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data).
118 | 
119 | ```
120 | synthia
121 |  └ RGB
122 |    └ train
123 |    └ val
124 |  └ GT
125 |    └ COLOR
126 |      └ train
127 |      └ val
128 |    └ LABELS
129 |      └ train
130 |      └ val
131 | ```
132 | 
133 | 4. You should modify the paths in **"/config.py"** according to your dataset locations.
134 | ```
135 | #Cityscapes Dir Location
136 | __C.DATASET.CITYSCAPES_DIR =
137 | #Mapillary Dataset Dir Location
138 | __C.DATASET.MAPILLARY_DIR =
139 | #GTAV Dataset Dir Location
140 | __C.DATASET.GTAV_DIR =
141 | #BDD-100K Dataset Dir Location
142 | __C.DATASET.BDD_DIR =
143 | #Synthia Dataset Dir Location
144 | __C.DATASET.SYNTHIA_DIR =
145 | #ImageNet Dataset Dir Location
146 | __C.DATASET.IMAGENET_DIR =
147 | ```
148 | 
149 | 5. You can train WildNet with the following command.
150 | ```
151 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/train_wildnet_r50os16_gtav.sh
152 | ```
153 | 
154 | 6. You can download our ResNet-50 model at [Google Drive](https://drive.google.com/file/d/16V_cWtVbJuJoQ-DJq4Yui7Umf4wggemu/view) and validate the pretrained model with the following command.
155 | ```
156 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/valid_wildnet_r50os16_gtav.sh
157 | ```
158 | 
159 | ## Citation
160 | If you find this work useful in your research, please cite our paper:
161 | ```
162 | @inproceedings{lee2022wildnet,
163 |   title={WildNet: Learning Domain Generalized Semantic Segmentation from the Wild},
164 |   author={Lee, Suhyeon and Seong, Hongje and Lee, Seongwon and Kim, Euntai},
165 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
166 |   year={2022}
167 | }
168 | ```
169 | 
170 | ## Terms of Use
171 | This software is for non-commercial use only.
172 | The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) Licence 173 | (see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details) 174 | 175 | -------------------------------------------------------------------------------- /assets/comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suhyeonlee/WildNet/c83aa92b7cd591512045fc5093da25a280bde430/assets/comparison.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | ############################################################################## 30 | #Config 31 | ############################################################################## 32 | 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | from __future__ import unicode_literals 38 | 39 | 40 | import torch 41 | 42 | 43 | from utils.attr_dict import AttrDict 44 | 45 | 46 | __C = AttrDict() 47 | cfg = __C 48 | __C.ITER = 0 49 | __C.EPOCH = 0 50 | 51 | __C.RANDOM_SEED = 304 52 | # Use Class Uniform Sampling to give each class proper sampling 53 | __C.CLASS_UNIFORM_PCT = 0.0 54 | 55 | # Use class weighted loss per batch to increase loss for low pixel count classes per batch 56 | __C.BATCH_WEIGHTING = False 57 | 58 | # Border Relaxation Count 59 | __C.BORDER_WINDOW = 1 60 | # Number of epoch to use before turn off border restriction 61 | __C.REDUCE_BORDER_ITER = -1 62 | __C.REDUCE_BORDER_EPOCH = -1 63 | # Comma Seperated List of class id to relax 64 | __C.STRICTBORDERCLASS = None 65 | 66 | #Attribute Dictionary for Dataset 67 | __C.DATASET = AttrDict() 68 | #Cityscapes Dir Location 69 | __C.DATASET.CITYSCAPES_DIR = '/mnt/hdd2/DGSS/cityscapes' 70 | #SDC Augmented Cityscapes Dir Location 71 | __C.DATASET.CITYSCAPES_AUG_DIR = '' 72 | #Mapillary Dataset Dir Location 73 | __C.DATASET.MAPILLARY_DIR = '/mnt/hdd2/DGSS/mapillary' 74 | #GTAV Dataset Dir Location 75 | __C.DATASET.GTAV_DIR = '/mnt/hdd2/DGSS/gta5' 76 | #BDD-100K Dataset Dir Location 77 | __C.DATASET.BDD_DIR = '/mnt/hdd2/DGSS/bdd100k/seg' 78 | #Synthia Dataset Dir Location 79 | __C.DATASET.SYNTHIA_DIR = '/mnt/hdd2/DGSS/synthia' 
80 | #ImageNet Dataset Dir Location 81 | __C.DATASET.IMAGENET_DIR ='/mnt/hdd2/ImageNet/data' 82 | #Kitti Dataset Dir Location 83 | __C.DATASET.KITTI_DIR = '' 84 | #SDC Augmented Kitti Dataset Dir Location 85 | __C.DATASET.KITTI_AUG_DIR = '' 86 | #Camvid Dataset Dir Location 87 | __C.DATASET.CAMVID_DIR = '' 88 | #Number of splits to support 89 | __C.DATASET.CV_SPLITS = 3 90 | 91 | 92 | __C.MODEL = AttrDict() 93 | __C.MODEL.BN = 'pytorch-syncnorm' 94 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm 95 | 96 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True): 97 | """Call this function in your script after you have finished setting all cfg 98 | values that are necessary (e.g., merging a config from a file, merging 99 | command line config options, etc.). By default, this function will also 100 | mark the global cfg as immutable to prevent changing the global cfg settings 101 | during script execution (which can lead to hard to debug errors or code 102 | that's harder to understand than is necessary). 103 | """ 104 | 105 | if hasattr(args, 'syncbn') and args.syncbn: 106 | __C.MODEL.BN = 'pytorch-syncnorm' 107 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm 108 | print('Using pytorch sync batch norm') 109 | else: 110 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 111 | print('Using regular batch norm') 112 | 113 | if not train_mode: 114 | cfg.immutable(True) 115 | return 116 | if args.class_uniform_pct: 117 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct 118 | 119 | if args.batch_weighting: 120 | __C.BATCH_WEIGHTING = True 121 | 122 | if args.jointwtborder: 123 | if args.strict_bdr_cls != '': 124 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")] 125 | if args.rlx_off_iter > -1: 126 | __C.REDUCE_BORDER_ITER = args.rlx_off_iter 127 | 128 | if make_immutable: 129 | cfg.immutable(True) 130 | -------------------------------------------------------------------------------- /datasets/bdd100k.py: -------------------------------------------------------------------------------- 1 | """ 2 | BDD100K Dataset Loader 3 | """ 4 | import logging 5 | import json 6 | import os 7 | import numpy as np 8 | from PIL import Image 9 | from skimage import color 10 | 11 | from torch.utils import data 12 | import torch 13 | import torchvision.transforms as transforms 14 | import datasets.uniform as uniform 15 | import datasets.cityscapes_labels as cityscapes_labels 16 | 17 | from config import cfg 18 | 19 | trainid_to_name = cityscapes_labels.trainId2name 20 | id_to_trainid = cityscapes_labels.label2trainid 21 | trainid_to_trainid = cityscapes_labels.trainId2trainId 22 | color_to_trainid = cityscapes_labels.color2trainId 23 | num_classes = 19 24 | ignore_label = 255 25 | root = cfg.DATASET.BDD_DIR 26 | img_postfix = '.jpg' 27 | 28 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 29 | 153, 153, 153, 250, 170, 30, 30 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 31 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 32 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 33 | zero_pad = 256 * 3 - len(palette) 34 | for i in range(zero_pad): 35 | palette.append(0) 36 | 37 | 38 | def colorize_mask(mask): 39 | """ 40 | Colorize a segmentation mask. 
41 | """ 42 | # mask: numpy array of the mask 43 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 44 | new_mask.putpalette(palette) 45 | return new_mask 46 | 47 | 48 | def add_items(items, aug_items, img_path, mask_path, mask_postfix, mode, maxSkip): 49 | """ 50 | 51 | Add More items ot the list from the augmented dataset 52 | """ 53 | 54 | if mode == "train": 55 | img_path = os.path.join(img_path, 'train') 56 | mask_path = os.path.join(mask_path, 'train') 57 | elif mode == "val": 58 | img_path = os.path.join(img_path, 'val') 59 | mask_path = os.path.join(mask_path, 'val') 60 | 61 | list_items = [name.split(img_postfix)[0] for name in 62 | os.listdir(img_path)] 63 | for it in list_items: 64 | item = (os.path.join(img_path, it + img_postfix), 65 | os.path.join(mask_path, it + mask_postfix)) 66 | # ######################################################## 67 | # ###### dataset augmentation ############################ 68 | # ######################################################## 69 | # if mode == "train" and maxSkip > 0: 70 | # new_img_path = os.path.join(aug_root, 'leftImg8bit_trainvaltest', 'leftImg8bit') 71 | # new_mask_path = os.path.join(aug_root, 'gtFine_trainvaltest', 'gtFine') 72 | # file_info = it.split("_") 73 | # cur_seq_id = file_info[-1] 74 | 75 | # prev_seq_id = "%06d" % (int(cur_seq_id) - maxSkip) 76 | # next_seq_id = "%06d" % (int(cur_seq_id) + maxSkip) 77 | # prev_it = file_info[0] + "_" + file_info[1] + "_" + prev_seq_id 78 | # next_it = file_info[0] + "_" + file_info[1] + "_" + next_seq_id 79 | # prev_item = (os.path.join(new_img_path, c, prev_it + img_postfix), 80 | # os.path.join(new_mask_path, c, prev_it + mask_postfix)) 81 | # if os.path.isfile(prev_item[0]) and os.path.isfile(prev_item[1]): 82 | # aug_items.append(prev_item) 83 | # next_item = (os.path.join(new_img_path, c, next_it + img_postfix), 84 | # os.path.join(new_mask_path, c, next_it + mask_postfix)) 85 | # if os.path.isfile(next_item[0]) and os.path.isfile(next_item[1]): 86 | # aug_items.append(next_item) 87 | items.append(item) 88 | # items.extend(extra_items) 89 | 90 | 91 | def make_cv_splits(img_dir_name): 92 | """ 93 | Create splits of train/val data. 94 | A split is a lists of cities. 95 | split0 is aligned with the default Cityscapes train/val. 
96 | """ 97 | trn_path = os.path.join(root, img_dir_name, 'train') 98 | val_path = os.path.join(root, img_dir_name, 'val') 99 | 100 | trn_cities = ['train/' + c for c in os.listdir(trn_path)] 101 | val_cities = ['val/' + c for c in os.listdir(val_path)] 102 | 103 | # want reproducible randomly shuffled 104 | trn_cities = sorted(trn_cities) 105 | 106 | all_cities = val_cities + trn_cities 107 | num_val_cities = len(val_cities) 108 | num_cities = len(all_cities) 109 | 110 | cv_splits = [] 111 | for split_idx in range(cfg.DATASET.CV_SPLITS): 112 | split = {} 113 | split['train'] = [] 114 | split['val'] = [] 115 | offset = split_idx * num_cities // cfg.DATASET.CV_SPLITS 116 | for j in range(num_cities): 117 | if j >= offset and j < (offset + num_val_cities): 118 | split['val'].append(all_cities[j]) 119 | else: 120 | split['train'].append(all_cities[j]) 121 | cv_splits.append(split) 122 | 123 | return cv_splits 124 | 125 | 126 | def make_split_coarse(img_path): 127 | """ 128 | Create a train/val split for coarse 129 | return: city split in train 130 | """ 131 | all_cities = os.listdir(img_path) 132 | all_cities = sorted(all_cities) # needs to always be the same 133 | val_cities = [] # Can manually set cities to not be included into train split 134 | 135 | split = {} 136 | split['val'] = val_cities 137 | split['train'] = [c for c in all_cities if c not in val_cities] 138 | return split 139 | 140 | 141 | def make_test_split(img_dir_name): 142 | test_path = os.path.join(root, img_dir_name, 'leftImg8bit', 'test') 143 | test_cities = ['test/' + c for c in os.listdir(test_path)] 144 | 145 | return test_cities 146 | 147 | 148 | def make_dataset(mode, maxSkip=0, cv_split=0): 149 | """ 150 | Assemble list of images + mask files 151 | 152 | fine - modes: train/val/test/trainval cv:0,1,2 153 | coarse - modes: train/val cv:na 154 | 155 | path examples: 156 | leftImg8bit_trainextra/leftImg8bit/train_extra/augsburg 157 | gtCoarse/gtCoarse/train_extra/augsburg 158 | """ 159 | items = [] 160 | aug_items = [] 161 | 162 | assert mode in ['train', 'val', 'test', 'trainval'] 163 | img_dir_name = 'images' 164 | img_path = os.path.join(root, img_dir_name) 165 | mask_path = os.path.join(root, 'labels') 166 | mask_postfix = '_train_id.png' 167 | # cv_splits = make_cv_splits(img_dir_name) 168 | if mode == 'trainval': 169 | modes = ['train', 'val'] 170 | else: 171 | modes = [mode] 172 | for mode in modes: 173 | logging.info('{} fine cities: '.format(mode)) 174 | add_items(items, aug_items, img_path, mask_path, 175 | mask_postfix, mode, maxSkip) 176 | 177 | # logging.info('Cityscapes-{}: {} images'.format(mode, len(items))) 178 | logging.info('BDD100K-{}: {} images'.format(mode, len(items) + len(aug_items))) 179 | return items, aug_items 180 | 181 | 182 | class BDD100K(data.Dataset): 183 | 184 | def __init__(self, mode, maxSkip=0, joint_transform=None, sliding_crop=None, 185 | transform=None, target_transform=None, target_aux_transform=None, dump_images=False, 186 | cv_split=None, eval_mode=False, 187 | eval_scales=None, eval_flip=False, image_in=False, 188 | extract_feature=False): 189 | self.mode = mode 190 | self.maxSkip = maxSkip 191 | self.joint_transform = joint_transform 192 | self.sliding_crop = sliding_crop 193 | self.transform = transform 194 | self.target_transform = target_transform 195 | self.target_aux_transform = target_aux_transform 196 | self.dump_images = dump_images 197 | self.eval_mode = eval_mode 198 | self.eval_flip = eval_flip 199 | self.eval_scales = None 200 | self.image_in = image_in 201 | 
self.extract_feature = extract_feature 202 | 203 | 204 | if eval_scales != None: 205 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")] 206 | 207 | if cv_split: 208 | self.cv_split = cv_split 209 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 210 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 211 | cv_split, cfg.DATASET.CV_SPLITS) 212 | else: 213 | self.cv_split = 0 214 | self.imgs, _ = make_dataset(mode, self.maxSkip, cv_split=self.cv_split) 215 | if len(self.imgs) == 0: 216 | raise RuntimeError('Found 0 images, please check the data set') 217 | 218 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 219 | 220 | def _eval_get_item(self, img, mask, scales, flip_bool): 221 | return_imgs = [] 222 | for flip in range(int(flip_bool) + 1): 223 | imgs = [] 224 | if flip: 225 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 226 | for scale in scales: 227 | w, h = img.size 228 | target_w, target_h = int(w * scale), int(h * scale) 229 | resize_img = img.resize((target_w, target_h)) 230 | tensor_img = transforms.ToTensor()(resize_img) 231 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img) 232 | imgs.append(final_tensor) 233 | return_imgs.append(imgs) 234 | return return_imgs, mask 235 | 236 | def __getitem__(self, index): 237 | 238 | img_path, mask_path = self.imgs[index] 239 | 240 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 241 | img_name = os.path.splitext(os.path.basename(img_path))[0] 242 | 243 | mask = np.array(mask) 244 | mask_copy = mask.copy() 245 | for k, v in trainid_to_trainid.items(): 246 | mask_copy[mask == k] = v 247 | 248 | if self.eval_mode: 249 | return [transforms.ToTensor()(img)], self._eval_get_item(img, mask_copy, 250 | self.eval_scales, 251 | self.eval_flip), img_name 252 | 253 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 254 | 255 | # Image Transformations 256 | if self.extract_feature is not True: 257 | if self.joint_transform is not None: 258 | img, mask = self.joint_transform(img, mask) 259 | 260 | if self.transform is not None: 261 | img = self.transform(img) 262 | 263 | rgb_mean_std_gt = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 264 | img_gt = transforms.Normalize(*rgb_mean_std_gt)(img) 265 | 266 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 267 | if self.image_in: 268 | eps = 1e-5 269 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])], 270 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps]) 271 | img = transforms.Normalize(*rgb_mean_std)(img) 272 | 273 | if self.target_aux_transform is not None: 274 | mask_aux = self.target_aux_transform(mask) 275 | else: 276 | mask_aux = torch.tensor([0]) 277 | if self.target_transform is not None: 278 | mask = self.target_transform(mask) 279 | 280 | # Debug 281 | if self.dump_images: 282 | outdir = '../../dump_imgs_{}'.format(self.mode) 283 | os.makedirs(outdir, exist_ok=True) 284 | out_img_fn = os.path.join(outdir, img_name + '.png') 285 | out_msk_fn = os.path.join(outdir, img_name + '_mask.png') 286 | mask_img = colorize_mask(np.array(mask)) 287 | img.save(out_img_fn) 288 | mask_img.save(out_msk_fn) 289 | 290 | return img, mask, img_name, mask_aux 291 | 292 | def __len__(self): 293 | return len(self.imgs) 294 | 295 | class BDD100KUniform(data.Dataset): 296 | """ 297 | Please do not use this for AGG 298 | """ 299 | 300 | def __init__(self, mode, maxSkip=0, joint_transform_list=None, sliding_crop=None, 301 | transform=None, target_transform=None, target_aux_transform=None, dump_images=False, 
302 | cv_split=None, class_uniform_pct=0.5, class_uniform_tile=1024, 303 | test=False, coarse_boost_classes=None, image_in=False, extract_feature=False): 304 | self.mode = mode 305 | self.maxSkip = maxSkip 306 | self.joint_transform_list = joint_transform_list 307 | self.sliding_crop = sliding_crop 308 | self.transform = transform 309 | self.target_transform = target_transform 310 | self.target_aux_transform = target_aux_transform 311 | self.dump_images = dump_images 312 | self.class_uniform_pct = class_uniform_pct 313 | self.class_uniform_tile = class_uniform_tile 314 | self.coarse_boost_classes = coarse_boost_classes 315 | self.image_in = image_in 316 | self.extract_feature = extract_feature 317 | 318 | 319 | if cv_split: 320 | self.cv_split = cv_split 321 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 322 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 323 | cv_split, cfg.DATASET.CV_SPLITS) 324 | else: 325 | self.cv_split = 0 326 | 327 | self.imgs, self.aug_imgs = make_dataset(mode, self.maxSkip, cv_split=self.cv_split) 328 | assert len(self.imgs), 'Found 0 images, please check the data set' 329 | 330 | # Centroids for fine data 331 | json_fn = 'bdd100k_{}_cv{}_tile{}.json'.format( 332 | self.mode, self.cv_split, self.class_uniform_tile) 333 | if os.path.isfile(json_fn): 334 | with open(json_fn, 'r') as json_data: 335 | centroids = json.load(json_data) 336 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 337 | else: 338 | self.centroids = uniform.class_centroids_all( 339 | self.imgs, 340 | num_classes, 341 | id2trainid=trainid_to_trainid, 342 | tile_size=class_uniform_tile) 343 | with open(json_fn, 'w') as outfile: 344 | json.dump(self.centroids, outfile, indent=4) 345 | 346 | self.fine_centroids = self.centroids.copy() 347 | 348 | self.build_epoch() 349 | 350 | def cities_uniform(self, imgs, name): 351 | """ list out cities in imgs_uniform """ 352 | cities = {} 353 | for item in imgs: 354 | img_fn = item[0] 355 | img_fn = os.path.basename(img_fn) 356 | city = img_fn.split('_')[0] 357 | cities[city] = 1 358 | city_names = cities.keys() 359 | logging.info('Cities for {} '.format(name) + str(sorted(city_names))) 360 | 361 | def build_epoch(self, cut=False): 362 | """ 363 | Perform Uniform Sampling per epoch to create a new list for training such that it 364 | uniformly samples all classes 365 | """ 366 | if self.class_uniform_pct > 0: 367 | if cut: 368 | # after max_cu_epoch, we only fine images to fine tune 369 | self.imgs_uniform = uniform.build_epoch(self.imgs, 370 | self.fine_centroids, 371 | num_classes, 372 | cfg.CLASS_UNIFORM_PCT) 373 | else: 374 | self.imgs_uniform = uniform.build_epoch(self.imgs + self.aug_imgs, 375 | self.centroids, 376 | num_classes, 377 | cfg.CLASS_UNIFORM_PCT) 378 | else: 379 | self.imgs_uniform = self.imgs 380 | 381 | def __getitem__(self, index): 382 | elem = self.imgs_uniform[index] 383 | centroid = None 384 | if len(elem) == 4: 385 | img_path, mask_path, centroid, class_id = elem 386 | else: 387 | img_path, mask_path = elem 388 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 389 | img_name = os.path.splitext(os.path.basename(img_path))[0] 390 | 391 | mask = np.array(mask) 392 | mask_copy = mask.copy() 393 | for k, v in trainid_to_trainid.items(): 394 | mask_copy[mask == k] = v 395 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 396 | 397 | # Image Transformations 398 | if self.extract_feature is not True: 399 | if self.joint_transform_list is not None: 400 | for idx, xform in 
enumerate(self.joint_transform_list): 401 | if idx == 0 and centroid is not None: 402 | # HACK 403 | # We assume that the first transform is capable of taking 404 | # in a centroid 405 | img, mask = xform(img, mask, centroid) 406 | else: 407 | img, mask = xform(img, mask) 408 | 409 | # Debug 410 | if self.dump_images and centroid is not None: 411 | outdir = '../../dump_imgs_{}'.format(self.mode) 412 | os.makedirs(outdir, exist_ok=True) 413 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 414 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 415 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 416 | mask_img = colorize_mask(np.array(mask)) 417 | img.save(out_img_fn) 418 | mask_img.save(out_msk_fn) 419 | 420 | if self.transform is not None: 421 | img = self.transform(img) 422 | 423 | rgb_mean_std_gt = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 424 | img_gt = transforms.Normalize(*rgb_mean_std_gt)(img) 425 | 426 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 427 | if self.image_in: 428 | eps = 1e-5 429 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])], 430 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps]) 431 | img = transforms.Normalize(*rgb_mean_std)(img) 432 | 433 | if self.target_aux_transform is not None: 434 | mask_aux = self.target_aux_transform(mask) 435 | else: 436 | mask_aux = torch.tensor([0]) 437 | if self.target_transform is not None: 438 | mask = self.target_transform(mask) 439 | 440 | return img, mask, img_name, mask_aux 441 | 442 | def __len__(self): 443 | return len(self.imgs_uniform) 444 | 445 | class BDD100KAug(data.Dataset): 446 | 447 | def __init__(self, mode, maxSkip=0, joint_transform=None, sliding_crop=None, 448 | transform=None, color_transform=None, geometric_transform=None, target_transform=None, target_aux_transform=None, dump_images=False, 449 | cv_split=None, eval_mode=False, 450 | eval_scales=None, eval_flip=False, image_in=False, 451 | extract_feature=False): 452 | self.mode = mode 453 | self.maxSkip = maxSkip 454 | self.joint_transform = joint_transform 455 | self.sliding_crop = sliding_crop 456 | self.transform = transform 457 | self.color_transform = color_transform 458 | self.geometric_transform = geometric_transform 459 | self.target_transform = target_transform 460 | self.target_aux_transform = target_aux_transform 461 | self.dump_images = dump_images 462 | self.eval_mode = eval_mode 463 | self.eval_flip = eval_flip 464 | self.eval_scales = None 465 | self.image_in = image_in 466 | self.extract_feature = extract_feature 467 | 468 | 469 | if eval_scales != None: 470 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")] 471 | 472 | if cv_split: 473 | self.cv_split = cv_split 474 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 475 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 476 | cv_split, cfg.DATASET.CV_SPLITS) 477 | else: 478 | self.cv_split = 0 479 | self.imgs, _ = make_dataset(mode, self.maxSkip, cv_split=self.cv_split) 480 | if len(self.imgs) == 0: 481 | raise RuntimeError('Found 0 images, please check the data set') 482 | 483 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 484 | 485 | def _eval_get_item(self, img, mask, scales, flip_bool): 486 | return_imgs = [] 487 | for flip in range(int(flip_bool) + 1): 488 | imgs = [] 489 | if flip: 490 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 491 | for scale in scales: 492 | w, h = img.size 493 | target_w, target_h = int(w * scale), int(h * scale) 494 | resize_img = 
img.resize((target_w, target_h)) 495 | tensor_img = transforms.ToTensor()(resize_img) 496 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img) 497 | imgs.append(final_tensor) 498 | return_imgs.append(imgs) 499 | return return_imgs, mask 500 | 501 | def __getitem__(self, index): 502 | 503 | img_path, mask_path = self.imgs[index] 504 | 505 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 506 | img_name = os.path.splitext(os.path.basename(img_path))[0] 507 | 508 | mask = np.array(mask) 509 | mask_copy = mask.copy() 510 | for k, v in trainid_to_trainid.items(): 511 | mask_copy[mask == k] = v 512 | 513 | if self.eval_mode: 514 | return [transforms.ToTensor()(img)], self._eval_get_item(img, mask_copy, 515 | self.eval_scales, 516 | self.eval_flip), img_name 517 | 518 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 519 | 520 | if self.joint_transform is not None: 521 | img, mask = self.joint_transform(img, mask) 522 | 523 | if self.transform is not None: 524 | img_or = self.transform(img) 525 | 526 | if self.color_transform is not None: 527 | img_color = self.color_transform(img) 528 | 529 | if self.geometric_transform is not None: 530 | img_geometric = self.geometric_transform(img) 531 | 532 | rgb_mean_std_or = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 533 | rgb_mean_std_color = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 534 | rgb_mean_std_geometric = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 535 | if self.image_in: 536 | eps = 1e-5 537 | rgb_mean_std_or = ([torch.mean(img_or[0]), torch.mean(img_or[1]), torch.mean(img_or[2])], 538 | [torch.std(img_or[0])+eps, torch.std(img_or[1])+eps, torch.std(img_or[2])+eps]) 539 | rgb_mean_std_color = ([torch.mean(img_color[0]), torch.mean(img_color[1]), torch.mean(img_color[2])], 540 | [torch.std(img_color[0])+eps, torch.std(img_color[1])+eps, torch.std(img_color[2])+eps]) 541 | rgb_mean_std_geometric = ([torch.mean(img_geometric[0]), torch.mean(img_geometric[1]), torch.mean(img_geometric[2])], 542 | [torch.std(img_geometric[0])+eps, torch.std(img_geometric[1])+eps, torch.std(img_geometric[2])+eps]) 543 | img_or = transforms.Normalize(*rgb_mean_std_or)(img_or) 544 | img_color = transforms.Normalize(*rgb_mean_std_color)(img_color) 545 | img_geometric = transforms.Normalize(*rgb_mean_std_geometric)(img_geometric) 546 | 547 | return img_or, img_color, img_geometric, img_name 548 | 549 | def __len__(self): 550 | return len(self.imgs) 551 | 552 | -------------------------------------------------------------------------------- /datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | # File taken from https://github.com/mcordts/cityscapesScripts/ 3 | # License File Available at: 4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 5 | 6 | # ---------------------- 7 | # The Cityscapes Dataset 8 | # ---------------------- 9 | # 10 | # 11 | # License agreement 12 | # ----------------- 13 | # 14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 15 | # 16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 17 | # 2. 
That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 21 | # 22 | # 23 | # Contact 24 | # ------- 25 | # 26 | # Marius Cordts, Mohamed Omran 27 | # www.cityscapes-dataset.net 28 | 29 | """ 30 | from collections import namedtuple 31 | 32 | 33 | #-------------------------------------------------------------------------------- 34 | # Definitions 35 | #-------------------------------------------------------------------------------- 36 | 37 | # a label and all meta information 38 | Label = namedtuple( 'Label' , [ 39 | 40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 41 | # We use them to uniquely name a class 42 | 43 | 'id' , # An integer ID that is associated with this label. 44 | # The IDs are used to represent the label in ground truth images 45 | # An ID of -1 means that this label does not have an ID and thus 46 | # is ignored when creating ground truth images (e.g. license plate). 47 | # Do not modify these IDs, since exactly these IDs are expected by the 48 | # evaluation server. 49 | 50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 51 | # ground truth images with train IDs, using the tools provided in the 52 | # 'preparation' folder. However, make sure to validate or submit results 53 | # to our evaluation server using the regular IDs above! 54 | # For trainIds, multiple labels might have the same ID. Then, these labels 55 | # are mapped to the same class in the ground truth images. For the inverse 56 | # mapping, we use the label that is defined first in the list below. 57 | # For example, mapping all void-type classes to the same ID in training, 58 | # might make sense for some approaches. 59 | # Max value is 255! 60 | 61 | 'category' , # The name of the category that this label belongs to 62 | 63 | 'categoryId' , # The ID of this category. Used to create ground truth images 64 | # on category level. 65 | 66 | 'hasInstances', # Whether this label distinguishes between single instances or not 67 | 68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 69 | # during evaluations or not 70 | 71 | 'color' , # The color of this label 72 | ] ) 73 | 74 | 75 | #-------------------------------------------------------------------------------- 76 | # A list of all labels 77 | #-------------------------------------------------------------------------------- 78 | 79 | # Please adapt the train IDs as appropriate for you approach. 80 | # Note that you might want to ignore labels with ID 255 during training. 81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 
82 | # Make sure to provide your results using the original IDs and not the training IDs. 83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 84 | 85 | labels = [ 86 | # name id trainId category catId hasInstances ignoreInEval color 87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,154) ), # (153,153,153) 106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,143) ), # ( 0, 0,142) 122 | ] 123 | 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Create dictionaries for a fast lookup 127 | #-------------------------------------------------------------------------------- 128 | 129 | # Please refer to the main method below for example usages! 
130 | 131 | # name to label object 132 | name2label = { label.name : label for label in labels } 133 | # id to label object 134 | id2label = { label.id : label for label in labels } 135 | # trainId to label object 136 | trainId2label = { label.trainId : label for label in reversed(labels) } 137 | # label2trainid 138 | label2trainid = { label.id : label.trainId for label in labels } 139 | # trainId to label object 140 | trainId2name = { label.trainId : label.name for label in labels } 141 | trainId2color = { label.trainId : label.color for label in labels } 142 | 143 | color2trainId = { label.color : label.trainId for label in labels } 144 | 145 | trainId2trainId = { label.trainId : label.trainId for label in labels } 146 | 147 | # category to list of label objects 148 | category2labels = {} 149 | for label in labels: 150 | category = label.category 151 | if category in category2labels: 152 | category2labels[category].append(label) 153 | else: 154 | category2labels[category] = [label] 155 | 156 | #-------------------------------------------------------------------------------- 157 | # Assure single instance name 158 | #-------------------------------------------------------------------------------- 159 | 160 | # returns the label name that describes a single instance (if possible) 161 | # e.g. input | output 162 | # ---------------------- 163 | # car | car 164 | # cargroup | car 165 | # foo | None 166 | # foogroup | None 167 | # skygroup | None 168 | def assureSingleInstanceName( name ): 169 | # if the name is known, it is not a group 170 | if name in name2label: 171 | return name 172 | # test if the name actually denotes a group 173 | if not name.endswith("group"): 174 | return None 175 | # remove group 176 | name = name[:-len("group")] 177 | # test if the new name exists 178 | if not name in name2label: 179 | return None 180 | # test if the new name denotes a label that actually has instances 181 | if not name2label[name].hasInstances: 182 | return None 183 | # all good then 184 | return name 185 | 186 | #-------------------------------------------------------------------------------- 187 | # Main for testing 188 | #-------------------------------------------------------------------------------- 189 | 190 | # just a dummy main 191 | if __name__ == "__main__": 192 | # Print all the labels 193 | print("List of cityscapes labels:") 194 | print("") 195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 196 | print((" " + ('-' * 98))) 197 | for label in labels: 198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 199 | print("") 200 | 201 | print("Example usages:") 202 | 203 | # Map from name to label 204 | name = 'car' 205 | id = name2label[name].id 206 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 207 | 208 | # Map from ID to label 209 | category = id2label[id].category 210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 211 | 212 | # Map from trainID to label 213 | trainId = 0 214 | name = trainId2label[trainId].name 215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 216 | -------------------------------------------------------------------------------- /datasets/imagenet.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | ImageNet Dataset Loader 3 | """ 4 | import logging 5 | import json 6 | import os 7 | import numpy as np 8 | from PIL import Image 9 | from skimage import color 10 | 11 | from torch.utils import data 12 | import torch 13 | import torchvision.transforms as transforms 14 | import datasets.uniform as uniform 15 | import datasets.cityscapes_labels as cityscapes_labels 16 | import scipy.misc as m 17 | 18 | from config import cfg 19 | 20 | import pdb 21 | import random 22 | from itertools import cycle, islice 23 | 24 | root = cfg.DATASET.IMAGENET_DIR 25 | img_postfix = '.JPEG' 26 | 27 | 28 | def make_cv_splits(): 29 | """ 30 | Create splits of train/valid data. 31 | A split is a lists of cities. 32 | split0 is aligned with the default Cityscapes train/valid. 33 | """ 34 | trn_path = os.path.join(root, 'train') 35 | val_path = os.path.join(root, 'val') 36 | 37 | trn_cities = ['train/' + c for c in os.listdir(trn_path)] 38 | val_cities = ['val/' + c for c in os.listdir(val_path)] 39 | 40 | # want reproducible randomly shuffled 41 | trn_cities = sorted(trn_cities) 42 | 43 | all_cities = val_cities + trn_cities 44 | num_val_cities = len(val_cities) 45 | num_cities = len(all_cities) 46 | 47 | cv_splits = [] 48 | for split_idx in range(cfg.DATASET.CV_SPLITS): 49 | split = {} 50 | split['train'] = [] 51 | split['val'] = [] 52 | offset = split_idx * num_cities // cfg.DATASET.CV_SPLITS 53 | for j in range(num_cities): 54 | if j >= offset and j < (offset + num_val_cities): 55 | split['val'].append(all_cities[j]) 56 | else: 57 | split['train'].append(all_cities[j]) 58 | cv_splits.append(split) 59 | 60 | return cv_splits 61 | 62 | 63 | def add_items(items, aug_items, cities, img_path, mode, maxSkip): 64 | """ 65 | 66 | Add More items ot the list from the augmented dataset 67 | """ 68 | 69 | for c in cities: 70 | c_items = [name.split(img_postfix)[0] for name in 71 | os.listdir(os.path.join(img_path, c))] 72 | for it in c_items: 73 | item = (os.path.join(img_path, c, it + img_postfix)) 74 | items.append(item) 75 | 76 | 77 | def make_dataset(mode, maxSkip=0, cv_split=0): 78 | """ 79 | Assemble list of images + mask files 80 | 81 | fine - modes: train/valid/test/trainval cv:0,1,2 82 | coarse - modes: train/valid cv:na 83 | 84 | path examples: 85 | leftImg8bit_trainextra/leftImg8bit/train_extra/augsburg 86 | gtCoarse/gtCoarse/train_extra/augsburg 87 | """ 88 | items = [] 89 | aug_items = [] 90 | 91 | assert mode in ['train', 'val', 'trainval'] 92 | cv_splits = make_cv_splits() 93 | if mode == 'trainval': 94 | modes = ['train', 'val'] 95 | else: 96 | modes = [mode] 97 | for mode in modes: 98 | logging.info('{} imagenet: '.format(mode) + str(cv_splits[cv_split][mode])) 99 | add_items(items, aug_items, cv_splits[cv_split][mode], root, mode, maxSkip) 100 | 101 | logging.info('ImageNet-{}: {} images'.format(mode, len(items) + len(aug_items))) 102 | 103 | return items, aug_items 104 | 105 | 106 | class ImageNet(data.Dataset): 107 | 108 | def __init__(self, mode, maxSkip=0, joint_transform=None, sliding_crop=None, 109 | transform=None, dump_images=False, 110 | cv_split=None, eval_mode=False, 111 | eval_scales=None, eval_flip=False, image_in=False, 112 | extract_feature=False): 113 | self.mode = mode 114 | self.maxSkip = maxSkip 115 | self.joint_transform = joint_transform 116 | self.sliding_crop = sliding_crop 117 | self.transform = transform 118 | self.dump_images = dump_images 119 | self.eval_mode = eval_mode 120 | 
self.eval_flip = eval_flip 121 | self.eval_scales = None 122 | self.image_in = image_in 123 | self.extract_feature = extract_feature 124 | 125 | 126 | if eval_scales != None: 127 | self.eval_scales = [float(scale) for scale in eval_scales.split(",")] 128 | 129 | if cv_split: 130 | self.cv_split = cv_split 131 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 132 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 133 | cv_split, cfg.DATASET.CV_SPLITS) 134 | else: 135 | self.cv_split = 0 136 | self.imgs, _ = make_dataset(mode, self.maxSkip, cv_split=self.cv_split) 137 | if len(self.imgs) == 0: 138 | raise RuntimeError('Found 0 images, please check the data set') 139 | 140 | self.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 141 | 142 | def _eval_get_item(self, img, scales, flip_bool): 143 | return_imgs = [] 144 | for flip in range(int(flip_bool) + 1): 145 | imgs = [] 146 | if flip: 147 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 148 | for scale in scales: 149 | w, h = img.size 150 | target_w, target_h = int(w * scale), int(h * scale) 151 | resize_img = img.resize((target_w, target_h)) 152 | tensor_img = transforms.ToTensor()(resize_img) 153 | final_tensor = transforms.Normalize(*self.mean_std)(tensor_img) 154 | imgs.append(final_tensor) 155 | return_imgs.append(imgs) 156 | return return_imgs 157 | 158 | def __getitem__(self, index): 159 | 160 | img_path = self.imgs[index] 161 | img = Image.open(img_path).convert('RGB') 162 | img_name = os.path.splitext(os.path.basename(img_path))[0] 163 | 164 | if self.eval_mode: 165 | return [transforms.ToTensor()(img)], self._eval_get_item(img, 166 | self.eval_scales, 167 | self.eval_flip), img_name 168 | # Image Transformations 169 | if self.extract_feature is not True: 170 | if self.joint_transform is not None: 171 | img = self.joint_transform(img) 172 | 173 | if self.transform is not None: 174 | img = self.transform(img) 175 | 176 | rgb_mean_std_gt = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 177 | img_gt = transforms.Normalize(*rgb_mean_std_gt)(img) 178 | 179 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 180 | if self.image_in: 181 | eps = 1e-5 182 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])], 183 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps]) 184 | img = transforms.Normalize(*rgb_mean_std)(img) 185 | 186 | # Debug 187 | if self.dump_images: 188 | outdir = '../../dump_imgs_{}'.format(self.mode) 189 | os.makedirs(outdir, exist_ok=True) 190 | out_img_fn = os.path.join(outdir, img_name + '.png') 191 | img.save(out_img_fn) 192 | 193 | return img, img_name 194 | 195 | def __len__(self): 196 | return len(self.imgs) 197 | -------------------------------------------------------------------------------- /datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mapillary Dataset Loader 3 | """ 4 | import logging 5 | import json 6 | import os 7 | import numpy as np 8 | from PIL import Image, ImageCms 9 | from skimage import color 10 | 11 | from torch.utils import data 12 | import torch 13 | import torchvision.transforms as transforms 14 | import datasets.uniform as uniform 15 | import datasets.cityscapes_labels as cityscapes_labels 16 | import transforms.transforms as extended_transforms 17 | import copy 18 | 19 | from config import cfg 20 | 21 | # Convert this dataset to have labels from cityscapes 22 | num_classes = 19 #65 23 | ignore_label = 255 #65 24 | root = cfg.DATASET.MAPILLARY_DIR 25 | config_fn = os.path.join(root, 
'config.json') 26 | color_mapping = [] 27 | id_to_trainid = {} 28 | id_to_ignore_or_group = {} 29 | 30 | 31 | def gen_id_to_ignore(): 32 | global id_to_ignore_or_group 33 | for i in range(66): 34 | id_to_ignore_or_group[i] = ignore_label 35 | 36 | ### Convert each class to cityscapes one 37 | ### Road 38 | # Road 39 | id_to_ignore_or_group[13] = 0 40 | # Lane Marking - General 41 | id_to_ignore_or_group[24] = 0 42 | # Manhole 43 | id_to_ignore_or_group[41] = 0 44 | 45 | ### Sidewalk 46 | # Curb 47 | id_to_ignore_or_group[2] = 1 48 | # Sidewalk 49 | id_to_ignore_or_group[15] = 1 50 | 51 | ### Building 52 | # Building 53 | id_to_ignore_or_group[17] = 2 54 | 55 | ### Wall 56 | # Wall 57 | id_to_ignore_or_group[6] = 3 58 | 59 | ### Fence 60 | # Fence 61 | id_to_ignore_or_group[3] = 4 62 | 63 | ### Pole 64 | # Pole 65 | id_to_ignore_or_group[45] = 5 66 | # Utility Pole 67 | id_to_ignore_or_group[47] = 5 68 | 69 | ### Traffic Light 70 | # Traffic Light 71 | id_to_ignore_or_group[48] = 6 72 | 73 | ### Traffic Sign 74 | # Traffic Sign 75 | id_to_ignore_or_group[50] = 7 76 | 77 | ### Vegetation 78 | # Vegitation 79 | id_to_ignore_or_group[30] = 8 80 | 81 | ### Terrain 82 | # Terrain 83 | id_to_ignore_or_group[29] = 9 84 | 85 | ### Sky 86 | # Sky 87 | id_to_ignore_or_group[27] = 10 88 | 89 | ### Person 90 | # Person 91 | id_to_ignore_or_group[19] = 11 92 | 93 | ### Rider 94 | # Bicyclist 95 | id_to_ignore_or_group[20] = 12 96 | # Motorcyclist 97 | id_to_ignore_or_group[21] = 12 98 | # Other Rider 99 | id_to_ignore_or_group[22] = 12 100 | 101 | ### Car 102 | # Car 103 | id_to_ignore_or_group[55] = 13 104 | 105 | ### Truck 106 | # Truck 107 | id_to_ignore_or_group[61] = 14 108 | 109 | ### Bus 110 | # Bus 111 | id_to_ignore_or_group[54] = 15 112 | 113 | ### Train 114 | # On Rails 115 | id_to_ignore_or_group[58] = 16 116 | 117 | ### Motorcycle 118 | # Motorcycle 119 | id_to_ignore_or_group[57] = 17 120 | 121 | ### Bicycle 122 | # Bicycle 123 | id_to_ignore_or_group[52] = 18 124 | 125 | 126 | def colorize_mask(image_array): 127 | """ 128 | Colorize a segmentation mask 129 | """ 130 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 131 | new_mask.putpalette(color_mapping) 132 | return new_mask 133 | 134 | 135 | def make_dataset(quality, mode): 136 | """ 137 | Create File List 138 | """ 139 | assert (quality == 'semantic' and mode in ['train', 'val']) 140 | img_dir_name = None 141 | if quality == 'semantic': 142 | if mode == 'train': 143 | img_dir_name = 'training' 144 | if mode == 'val': 145 | img_dir_name = 'validation' 146 | mask_path = os.path.join(root, img_dir_name, 'labels') 147 | else: 148 | raise BaseException("Instance Segmentation Not support") 149 | 150 | img_path = os.path.join(root, img_dir_name, 'images') 151 | print(img_path) 152 | if quality != 'video': 153 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)]) 154 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)]) 155 | assert imgs == msks 156 | 157 | items = [] 158 | c_items = os.listdir(img_path) 159 | if '.DS_Store' in c_items: 160 | c_items.remove('.DS_Store') 161 | 162 | for it in c_items: 163 | if quality == 'video': 164 | item = (os.path.join(img_path, it), os.path.join(img_path, it)) 165 | else: 166 | item = (os.path.join(img_path, it), 167 | os.path.join(mask_path, it.replace(".jpg", ".png"))) 168 | items.append(item) 169 | return items 170 | 171 | 172 | def gen_colormap(): 173 | """ 174 | Get Color Map from file 175 | """ 176 | global color_mapping 177 | 178 | # load 
mapillary config 179 | with open(config_fn) as config_file: 180 | config = json.load(config_file) 181 | config_labels = config['labels'] 182 | 183 | # calculate label color mapping 184 | colormap = [] 185 | id2name = {} 186 | for i in range(0, len(config_labels)): 187 | colormap = colormap + config_labels[i]['color'] 188 | id2name[i] = config_labels[i]['readable'] 189 | color_mapping = colormap 190 | return id2name 191 | 192 | 193 | class Mapillary(data.Dataset): 194 | def __init__(self, quality, mode, joint_transform_list=None, 195 | transform=None, target_transform=None, target_aux_transform=None, 196 | image_in=False, dump_images=False, class_uniform_pct=0, 197 | class_uniform_tile=768, test=False): 198 | """ 199 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform. 200 | 0.0 means fully random. 201 | class_uniform_tile_size = Class uniform tile size 202 | """ 203 | gen_id_to_ignore() 204 | self.quality = quality 205 | self.mode = mode 206 | self.joint_transform_list = joint_transform_list 207 | self.transform = transform 208 | self.target_transform = target_transform 209 | self.image_in = image_in 210 | self.target_aux_transform = target_aux_transform 211 | self.dump_images = dump_images 212 | self.class_uniform_pct = class_uniform_pct 213 | self.class_uniform_tile = class_uniform_tile 214 | self.id2name = gen_colormap() 215 | self.imgs_uniform = None 216 | 217 | 218 | # find all images 219 | self.imgs = make_dataset(quality, mode) 220 | if len(self.imgs) == 0: 221 | raise RuntimeError('Found 0 images, please check the data set') 222 | if test: 223 | np.random.shuffle(self.imgs) 224 | self.imgs = self.imgs[:200] 225 | 226 | if self.class_uniform_pct: 227 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile) 228 | if os.path.isfile(json_fn): 229 | with open(json_fn, 'r') as json_data: 230 | centroids = json.load(json_data) 231 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 232 | else: 233 | # centroids is a dict (indexed by class) of lists of centroids 234 | self.centroids = uniform.class_centroids_all( 235 | self.imgs, 236 | num_classes, 237 | id2trainid=None, 238 | tile_size=self.class_uniform_tile) 239 | with open(json_fn, 'w') as outfile: 240 | json.dump(self.centroids, outfile, indent=4) 241 | else: 242 | self.centroids = [] 243 | self.build_epoch() 244 | 245 | def build_epoch(self): 246 | if self.class_uniform_pct != 0: 247 | self.imgs_uniform = uniform.build_epoch(self.imgs, 248 | self.centroids, 249 | num_classes, 250 | self.class_uniform_pct) 251 | else: 252 | self.imgs_uniform = self.imgs 253 | 254 | def __getitem__(self, index): 255 | if len(self.imgs_uniform[index]) == 2: 256 | img_path, mask_path = self.imgs_uniform[index] 257 | centroid = None 258 | class_id = None 259 | else: 260 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index] 261 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 262 | img_name = os.path.splitext(os.path.basename(img_path))[0] 263 | 264 | mask = np.array(mask) 265 | mask_copy = mask.copy() 266 | for k, v in id_to_ignore_or_group.items(): 267 | mask_copy[mask == k] = v 268 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 269 | 270 | # Image Transformations 271 | if self.joint_transform_list is not None: 272 | for idx, xform in enumerate(self.joint_transform_list): 273 | if idx == 0 and centroid is not None: 274 | # HACK! 
Assume the first transform accepts a centroid 275 | img, mask = xform(img, mask, centroid) 276 | else: 277 | img, mask = xform(img, mask) 278 | 279 | if self.dump_images: 280 | outdir = 'dump_imgs_{}'.format(self.mode) 281 | os.makedirs(outdir, exist_ok=True) 282 | if centroid is not None: 283 | dump_img_name = self.id2name[class_id] + '_' + img_name 284 | else: 285 | dump_img_name = img_name 286 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 287 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 288 | mask_img = colorize_mask(np.array(mask)) 289 | img.save(out_img_fn) 290 | mask_img.save(out_msk_fn) 291 | 292 | if self.transform is not None: 293 | img = self.transform(img) 294 | 295 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 296 | img_gt = transforms.Normalize(*rgb_mean_std)(img) 297 | if self.image_in: 298 | eps = 1e-5 299 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])], 300 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps]) 301 | img = transforms.Normalize(*rgb_mean_std)(img) 302 | 303 | if self.target_aux_transform is not None: 304 | mask_aux = self.target_aux_transform(mask) 305 | else: 306 | mask_aux = torch.tensor([0]) 307 | if self.target_transform is not None: 308 | mask = self.target_transform(mask) 309 | 310 | mask = extended_transforms.MaskToTensor()(mask) 311 | return img, mask, img_name, mask_aux 312 | 313 | def __len__(self): 314 | return len(self.imgs_uniform) 315 | 316 | def calculate_weights(self): 317 | raise BaseException("not supported yet") 318 | -------------------------------------------------------------------------------- /datasets/multi_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom DomainUniformConcatDataset 3 | """ 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | import torch 8 | from config import cfg 9 | 10 | 11 | np.random.seed(cfg.RANDOM_SEED) 12 | 13 | 14 | class DomainUniformConcatDataset(Dataset): 15 | """ 16 | DomainUniformConcatDataset 17 | 18 | Sample images uniformly across the domains 19 | If bs_mul is n, this outputs # of domains * n images per batch 20 | """ 21 | @staticmethod 22 | def cumsum(sequence): 23 | r, s = [], 0 24 | for e in sequence: 25 | l = len(e) 26 | r.append(l + s) 27 | s += l 28 | return r 29 | 30 | def __init__(self, args, datasets): 31 | """ 32 | This dataset is to return sample image (source) 33 | and augmented sample image (target) 34 | Args: 35 | args: input config arguments 36 | datasets: list of datasets to concat 37 | """ 38 | super(DomainUniformConcatDataset, self).__init__() 39 | self.datasets = datasets 40 | self.lengths = [len(d) for d in datasets] 41 | self.offsets = self.cumsum(datasets) 42 | self.length = np.sum(self.lengths) 43 | 44 | print("# domains: {}, Total length: {}, 1 epoch: {}, offsets: {}".format( 45 | str(len(datasets)), str(self.length), str(len(self)), str(self.offsets))) 46 | 47 | 48 | def __len__(self): 49 | """ 50 | Returns: 51 | The number of images in a domain that has minimum image samples 52 | """ 53 | return min(self.lengths) 54 | 55 | 56 | def _get_batch_from_dataset(self, dataset, idx): 57 | """ 58 | Get batch from dataset 59 | New idx = idx + random integer 60 | Args: 61 | dataset: dataset class object 62 | idx: integer 63 | 64 | Returns: 65 | One batch from dataset 66 | """ 67 | p_index = idx + np.random.randint(len(dataset)) 68 | if p_index > len(dataset) - 1: 69 | p_index -= len(dataset) 70 | 71 | return 
dataset[p_index] 72 | 73 | 74 | def __getitem__(self, idx): 75 | """ 76 | Args: 77 | idx (int): Index 78 | 79 | Returns: 80 | images corresonding to the index from each domain 81 | """ 82 | imgs = [] 83 | masks = [] 84 | img_names = [] 85 | mask_auxs = [] 86 | 87 | for dataset in self.datasets: 88 | img, mask, img_name, mask_aux = self._get_batch_from_dataset(dataset, idx) 89 | imgs.append(img) 90 | masks.append(mask) 91 | img_names.append(img_name) 92 | mask_auxs.append(mask_aux) 93 | imgs, masks, mask_auxs = torch.stack(imgs, 0), torch.stack(masks, 0), torch.stack(mask_auxs, 0) 94 | 95 | return imgs, masks, img_names, mask_auxs 96 | 97 | -------------------------------------------------------------------------------- /datasets/nullloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Null Loader 3 | """ 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | 8 | num_classes = 19 9 | ignore_label = 255 10 | 11 | class NullLoader(data.Dataset): 12 | """ 13 | Null Dataset for Performance 14 | """ 15 | def __init__(self,crop_size): 16 | self.imgs = range(200) 17 | self.crop_size = crop_size 18 | 19 | def __getitem__(self, index): 20 | #Return img, mask, name 21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index) 22 | 23 | def __len__(self): 24 | return len(self.imgs) -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | """ 35 | 36 | 37 | 38 | import math 39 | import torch 40 | from torch.distributed import get_world_size, get_rank 41 | from torch.utils.data import Sampler 42 | 43 | class DistributedSampler(Sampler): 44 | """Sampler that restricts data loading to a subset of the dataset. 45 | 46 | It is especially useful in conjunction with 47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 48 | process can pass a DistributedSampler instance as a DataLoader sampler, 49 | and load a subset of the original dataset that is exclusive to it. 50 | 51 | .. note:: 52 | Dataset is assumed to be of constant size. 53 | 54 | Arguments: 55 | dataset: Dataset used for sampling. 56 | num_replicas (optional): Number of processes participating in 57 | distributed training. 58 | rank (optional): Rank of the current process within num_replicas. 59 | """ 60 | 61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None): 62 | if num_replicas is None: 63 | num_replicas = get_world_size() 64 | if rank is None: 65 | rank = get_rank() 66 | self.dataset = dataset 67 | self.num_replicas = num_replicas 68 | self.rank = rank 69 | self.epoch = 0 70 | self.consecutive_sample = consecutive_sample 71 | self.permutation = permutation 72 | if pad: 73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | else: 75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 76 | self.total_size = self.num_samples * self.num_replicas 77 | 78 | def __iter__(self): 79 | # deterministically shuffle based on epoch 80 | g = torch.Generator() 81 | g.manual_seed(self.epoch) 82 | 83 | if self.permutation: 84 | indices = list(torch.randperm(len(self.dataset), generator=g)) 85 | else: 86 | indices = list([x for x in range(len(self.dataset))]) 87 | 88 | # add extra samples to make it evenly divisible 89 | if self.total_size > len(indices): 90 | indices += indices[:(self.total_size - len(indices))] 91 | 92 | # subsample 93 | if self.consecutive_sample: 94 | offset = self.num_samples * self.rank 95 | indices = indices[offset:offset + self.num_samples] 96 | else: 97 | indices = indices[self.rank:self.total_size:self.num_replicas] 98 | assert len(indices) == self.num_samples 99 | 100 | return iter(indices) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | def set_epoch(self, epoch): 106 | self.epoch = epoch 107 | 108 | def set_num_samples(self): 109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 110 | self.total_size = self.num_samples * self.num_replicas -------------------------------------------------------------------------------- /datasets/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uniform sampling of classes. 3 | For all images, for all classes, generate centroids around which to sample. 4 | 5 | All images are divided into tiles. 6 | For each tile, a class can be present or not. If it is 7 | present, calculate the centroid of the class and record it. 8 | 9 | We would like to thank Peter Kontschieder for the inspiration of this idea. 
10 | """ 11 | 12 | import logging 13 | from collections import defaultdict 14 | from PIL import Image 15 | import numpy as np 16 | from scipy import ndimage 17 | from tqdm import tqdm 18 | 19 | pbar = None 20 | 21 | class Point(): 22 | """ 23 | Point Class For X and Y Location 24 | """ 25 | def __init__(self, x, y): 26 | self.x = x 27 | self.y = y 28 | 29 | 30 | def calc_tile_locations(tile_size, image_size): 31 | """ 32 | Divide an image into tiles to help us cover classes that are spread out. 33 | tile_size: size of tile to distribute 34 | image_size: original image size 35 | return: locations of the tiles 36 | """ 37 | image_size_y, image_size_x = image_size 38 | locations = [] 39 | for y in range(image_size_y // tile_size): 40 | for x in range(image_size_x // tile_size): 41 | x_offs = x * tile_size 42 | y_offs = y * tile_size 43 | locations.append((x_offs, y_offs)) 44 | return locations 45 | 46 | 47 | def class_centroids_image(item, tile_size, num_classes, id2trainid): 48 | """ 49 | For one image, calculate centroids for all classes present in image. 50 | item: image, image_name 51 | tile_size: 52 | num_classes: 53 | id2trainid: mapping from original id to training ids 54 | return: Centroids are calculated for each tile. 55 | """ 56 | image_fn, label_fn = item 57 | centroids = defaultdict(list) 58 | mask = np.array(Image.open(label_fn)) 59 | if len(mask.shape) == 3: 60 | # Remove instance mask 61 | mask = mask[:,:,0] 62 | image_size = mask.shape 63 | tile_locations = calc_tile_locations(tile_size, image_size) 64 | 65 | mask_copy = mask.copy() 66 | if id2trainid: 67 | for k, v in id2trainid.items(): 68 | mask[mask_copy == k] = v 69 | 70 | for x_offs, y_offs in tile_locations: 71 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 72 | for class_id in range(num_classes): 73 | if class_id in patch: 74 | patch_class = (patch == class_id).astype(int) 75 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 76 | centroid_y = int(centroid_y) + y_offs 77 | centroid_x = int(centroid_x) + x_offs 78 | centroid = (centroid_x, centroid_y) 79 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 80 | pbar.update(1) 81 | return centroids 82 | 83 | import scipy.misc as m 84 | 85 | def class_centroids_image_from_color(item, tile_size, num_classes, id2trainid): 86 | """ 87 | For one image, calculate centroids for all classes present in image. 88 | item: image, image_name 89 | tile_size: 90 | num_classes: 91 | id2trainid: mapping from original id to training ids 92 | return: Centroids are calculated for each tile. 
93 | """ 94 | image_fn, label_fn = item 95 | centroids = defaultdict(list) 96 | mask = m.imread(label_fn) 97 | image_size = mask[:,:,0].shape 98 | tile_locations = calc_tile_locations(tile_size, image_size) 99 | 100 | # mask = m.imread(label_fn) 101 | # mask_copy = np.full((img.size[1], img.size[0]), 255, dtype=np.uint8) 102 | # for k, v in id2trainid.items(): 103 | # mask_copy[(mask == k)[:,:,0]] = v 104 | # mask = Image.fromarray(mask_copy.astype(np.uint8)) 105 | 106 | # mask_copy = mask.copy() 107 | # mask_copy = mask.copy() 108 | # if id2trainid: 109 | # for k, v in id2trainid.items(): 110 | # mask[mask_copy == k] = v 111 | 112 | mask_copy = np.full(image_size, 255, dtype=np.uint8) 113 | 114 | if id2trainid: 115 | for k, v in id2trainid.items(): 116 | # print("0", mask.shape) 117 | # print("1", ((mask == np.array(k))[:,:,0]).shape) # 1052, 1914 118 | # # print("2", mask == np.array(k)[:,:,0]) 119 | # break 120 | # if v != 255: 121 | # print(v) 122 | # if v == 2: 123 | # print(k, v, "num", np.count_nonzero(mask == np.array(k))) 124 | # break 125 | if v != 255 and v != -1: 126 | mask_copy[(mask == np.array(k))[:,:,0] & (mask == np.array(k))[:,:,1] & (mask == np.array(k))[:,:,2]] = v 127 | mask = mask_copy 128 | 129 | # mask_copy = mask.copy() 130 | # if id2trainid: 131 | # for k, v in id2trainid.items(): 132 | # mask[mask_copy == k] = v 133 | 134 | for x_offs, y_offs in tile_locations: 135 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 136 | for class_id in range(num_classes): 137 | if class_id in patch: 138 | patch_class = (patch == class_id).astype(int) 139 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 140 | centroid_y = int(centroid_y) + y_offs 141 | centroid_x = int(centroid_x) + x_offs 142 | centroid = (centroid_x, centroid_y) 143 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 144 | pbar.update(1) 145 | return centroids 146 | 147 | def pooled_class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024): 148 | """ 149 | Calculate class centroids for all classes for all images for all tiles. 150 | items: list of (image_fn, label_fn) 151 | tile size: size of tile 152 | returns: dict that contains a list of centroids for each class 153 | """ 154 | from multiprocessing.dummy import Pool 155 | from functools import partial 156 | pool = Pool(32) 157 | global pbar 158 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 159 | class_centroids_item = partial(class_centroids_image_from_color, 160 | num_classes=num_classes, 161 | id2trainid=id2trainid, 162 | tile_size=tile_size) 163 | 164 | centroids = defaultdict(list) 165 | new_centroids = pool.map(class_centroids_item, items) 166 | pool.close() 167 | pool.join() 168 | 169 | # combine each image's items into a single global dict 170 | for image_items in new_centroids: 171 | for class_id in image_items: 172 | centroids[class_id].extend(image_items[class_id]) 173 | return centroids 174 | 175 | 176 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 177 | """ 178 | Calculate class centroids for all classes for all images for all tiles. 
179 | items: list of (image_fn, label_fn) 180 | tile size: size of tile 181 | returns: dict that contains a list of centroids for each class 182 | """ 183 | from multiprocessing.dummy import Pool 184 | from functools import partial 185 | pool = Pool(80) 186 | global pbar 187 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 188 | class_centroids_item = partial(class_centroids_image, 189 | num_classes=num_classes, 190 | id2trainid=id2trainid, 191 | tile_size=tile_size) 192 | 193 | centroids = defaultdict(list) 194 | new_centroids = pool.map(class_centroids_item, items) 195 | pool.close() 196 | pool.join() 197 | 198 | # combine each image's items into a single global dict 199 | for image_items in new_centroids: 200 | for class_id in image_items: 201 | centroids[class_id].extend(image_items[class_id]) 202 | return centroids 203 | 204 | 205 | def unpooled_class_centroids_all(items, num_classes, tile_size=1024): 206 | """ 207 | Calculate class centroids for all classes for all images for all tiles. 208 | items: list of (image_fn, label_fn) 209 | tile size: size of tile 210 | returns: dict that contains a list of centroids for each class 211 | """ 212 | centroids = defaultdict(list) 213 | global pbar 214 | pbar = tqdm(total=len(items), desc='centroid extraction') 215 | for image, label in items: 216 | new_centroids = class_centroids_image((image, label), 217 | tile_size, 218 | num_classes) 219 | for class_id in new_centroids: 220 | centroids[class_id].extend(new_centroids[class_id]) 221 | 222 | return centroids 223 | 224 | 225 | def class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024): 226 | """ 227 | intermediate function to call pooled_class_centroid 228 | """ 229 | 230 | pooled_centroids = pooled_class_centroids_all_from_color(items, num_classes, 231 | id2trainid, tile_size) 232 | return pooled_centroids 233 | 234 | 235 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 236 | """ 237 | intermediate function to call pooled_class_centroid 238 | """ 239 | 240 | pooled_centroids = pooled_class_centroids_all(items, num_classes, 241 | id2trainid, tile_size) 242 | return pooled_centroids 243 | 244 | 245 | def random_sampling(alist, num): 246 | """ 247 | Randomly sample num items from the list 248 | alist: list of centroids to sample from 249 | num: can be larger than the list and if so, then wrap around 250 | return: class uniform samples from the list 251 | """ 252 | sampling = [] 253 | len_list = len(alist) 254 | assert len_list, 'len_list is zero!' 255 | indices = np.arange(len_list) 256 | np.random.shuffle(indices) 257 | 258 | for i in range(num): 259 | item = alist[indices[i % len_list]] 260 | sampling.append(item) 261 | return sampling 262 | 263 | 264 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct): 265 | """ 266 | Generate an epochs-worth of crops using uniform sampling. 
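Roughly class_uniform_pct of the epoch comes from the per-class centroid lists
(num_per_class samples per class, skipping classes without centroids) and the rest from
randomly sampled images; e.g. with 1000 images, 19 classes and class_uniform_pct=0.5,
num_per_class = 26 and the remaining 506 items are sampled at random.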
Needs to be called every 267 | imgs: list of imgs 268 | centroids: 269 | num_classes: 270 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch ) 271 | """ 272 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct)) 273 | num_epoch = int(len(imgs)) 274 | 275 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch)) 276 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes) 277 | num_rand = num_epoch - num_per_class * num_classes 278 | # create random crops 279 | imgs_uniform = random_sampling(imgs, num_rand) 280 | 281 | # now add uniform sampling 282 | for class_id in range(num_classes): 283 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id])) 284 | logging.info(string_format) 285 | for class_id in range(num_classes): 286 | centroid_len = len(centroids[class_id]) 287 | if centroid_len == 0: 288 | pass 289 | else: 290 | class_centroids = random_sampling(centroids[class_id], num_per_class) 291 | imgs_uniform.extend(class_centroids) 292 | 293 | return imgs_uniform 294 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss.py 3 | """ 4 | 5 | import logging 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import datasets 11 | from config import cfg 12 | 13 | 14 | def get_loss(args): 15 | """ 16 | Get the criterion based on the loss function 17 | args: commandline arguments 18 | return: criterion, criterion_val 19 | """ 20 | if args.cls_wt_loss: 21 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 22 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 23 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 24 | else: 25 | ce_weight = None 26 | 27 | if args.img_wt_loss: 28 | criterion = ImageBasedCrossEntropyLoss2d( 29 | classes=datasets.num_classes, size_average=True, 30 | ignore_index=datasets.ignore_label, 31 | upper_bound=args.wt_bound).cuda() 32 | elif args.jointwtborder: 33 | criterion = ImgWtLossSoftNLL(classes=datasets.num_classes, 34 | ignore_index=datasets.ignore_label, 35 | upper_bound=args.wt_bound).cuda() 36 | else: 37 | print("standard cross entropy") 38 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean', 39 | ignore_index=datasets.ignore_label).cuda() 40 | 41 | criterion_val = nn.CrossEntropyLoss(reduction='mean', 42 | ignore_index=datasets.ignore_label).cuda() 43 | return criterion, criterion_val 44 | 45 | def get_loss_by_epoch(args): 46 | """ 47 | Get the criterion based on the loss function 48 | args: commandline arguments 49 | return: criterion, criterion_val 50 | """ 51 | 52 | if args.img_wt_loss: 53 | criterion = ImageBasedCrossEntropyLoss2d( 54 | classes=datasets.num_classes, size_average=True, 55 | ignore_index=datasets.ignore_label, 56 | upper_bound=args.wt_bound).cuda() 57 | elif args.jointwtborder: 58 | criterion = ImgWtLossSoftNLL_by_epoch(classes=datasets.num_classes, 59 | ignore_index=datasets.ignore_label, 60 | upper_bound=args.wt_bound).cuda() 61 | else: 62 | criterion = CrossEntropyLoss2d(size_average=True, 63 | ignore_index=datasets.ignore_label).cuda() 64 | 65 | criterion_val = CrossEntropyLoss2d(size_average=True, 66 | weight=None, 67 | ignore_index=datasets.ignore_label).cuda() 68 | return criterion, criterion_val 69 | 70 | 71 | def get_loss_aux(args): 72 | """ 73 | Get the criterion based on the loss function 74 | args: commandline 
arguments 75 | return: criterion, criterion_val 76 | """ 77 | if args.cls_wt_loss: 78 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 79 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 80 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 81 | else: 82 | ce_weight = None 83 | 84 | print("standard cross entropy") 85 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean', 86 | ignore_index=datasets.ignore_label).cuda() 87 | 88 | return criterion 89 | 90 | def get_loss_bcelogit(args): 91 | if args.cls_wt_loss: 92 | pos_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 93 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 94 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 95 | else: 96 | pos_weight = None 97 | print("standard bce with logit cross entropy") 98 | criterion = nn.BCEWithLogitsLoss(reduction='mean').cuda() 99 | 100 | return criterion 101 | 102 | def weighted_binary_cross_entropy(output, target): 103 | 104 | weights = torch.Tensor([0.1, 0.9]) 105 | 106 | loss = weights[1] * (target * torch.log(output)) + \ 107 | weights[0] * ((1 - target) * torch.log(1 - output)) 108 | 109 | return torch.neg(torch.mean(loss)) 110 | 111 | 112 | class L1Loss(nn.Module): 113 | def __init__(self): 114 | super(L1Loss, self).__init__() 115 | 116 | def __call__(self, in0, in1): 117 | return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True) 118 | 119 | 120 | class ImageBasedCrossEntropyLoss2d(nn.Module): 121 | """ 122 | Image Weighted Cross Entropy Loss 123 | """ 124 | 125 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255, 126 | norm=False, upper_bound=1.0): 127 | super(ImageBasedCrossEntropyLoss2d, self).__init__() 128 | logging.info("Using Per Image based weighted loss") 129 | self.num_classes = classes 130 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index) 131 | self.norm = norm 132 | self.upper_bound = upper_bound 133 | self.batch_weights = cfg.BATCH_WEIGHTING 134 | self.logsoftmax = nn.LogSoftmax(dim=1) 135 | 136 | def calculate_weights(self, target): 137 | """ 138 | Calculate weights of classes based on the training crop 139 | """ 140 | hist = np.histogram(target.flatten(), range( 141 | self.num_classes + 1), normed=True)[0] 142 | if self.norm: 143 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 144 | else: 145 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 146 | return hist 147 | 148 | def forward(self, inputs, targets): 149 | 150 | target_cpu = targets.data.cpu().numpy() 151 | if self.batch_weights: 152 | weights = self.calculate_weights(target_cpu) 153 | self.nll_loss.weight = torch.Tensor(weights).cuda() 154 | 155 | loss = 0.0 156 | for i in range(0, inputs.shape[0]): 157 | if not self.batch_weights: 158 | weights = self.calculate_weights(target_cpu[i]) 159 | self.nll_loss.weight = torch.Tensor(weights).cuda() 160 | 161 | loss += self.nll_loss(self.logsoftmax(inputs[i].unsqueeze(0)), 162 | targets[i].unsqueeze(0)) 163 | return loss 164 | 165 | 166 | 167 | class CrossEntropyLoss2d(nn.Module): 168 | """ 169 | Cross Entroply NLL Loss 170 | """ 171 | 172 | def __init__(self, weight=None, size_average=True, ignore_index=255): 173 | super(CrossEntropyLoss2d, self).__init__() 174 | logging.info("Using Cross Entropy Loss") 175 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index) 176 | self.logsoftmax = nn.LogSoftmax(dim=1) 177 | # self.weight = weight 178 | 179 | def forward(self, inputs, 
targets): 180 | return self.nll_loss(self.logsoftmax(inputs), targets) 181 | 182 | def customsoftmax(inp, multihotmask): 183 | """ 184 | Custom Softmax 185 | """ 186 | soft = F.softmax(inp, dim=1) 187 | # This takes the mask * softmax ( sums it up hence summing up the classes in border 188 | # then takes of summed up version vs no summed version 189 | return torch.log( 190 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True))) 191 | ) 192 | 193 | class ImgWtLossSoftNLL(nn.Module): 194 | """ 195 | Relax Loss 196 | """ 197 | 198 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 199 | norm=False): 200 | super(ImgWtLossSoftNLL, self).__init__() 201 | self.weights = weights 202 | self.num_classes = classes 203 | self.ignore_index = ignore_index 204 | self.upper_bound = upper_bound 205 | self.norm = norm 206 | self.batch_weights = cfg.BATCH_WEIGHTING 207 | 208 | def calculate_weights(self, target): 209 | """ 210 | Calculate weights of the classes based on training crop 211 | """ 212 | if len(target.shape) == 3: 213 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 214 | else: 215 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 216 | if self.norm: 217 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 218 | else: 219 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 220 | return hist[:-1] 221 | 222 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 223 | """ 224 | NLL Relaxed Loss Implementation 225 | """ 226 | if (cfg.REDUCE_BORDER_ITER != -1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 227 | border_weights = 1 / border_weights 228 | target[target > 1] = 1 229 | 230 | loss_matrix = (-1 / border_weights * 231 | (target[:, :-1, :, :].float() * 232 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 233 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 234 | (1. 
- mask.float()) 235 | 236 | # loss_matrix[border_weights > 1] = 0 237 | loss = loss_matrix.sum() 238 | 239 | # +1 to prevent division by 0 240 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 241 | return loss 242 | 243 | def forward(self, inputs, target): 244 | weights = target[:, :-1, :, :].sum(1).float() 245 | ignore_mask = (weights == 0) 246 | weights[ignore_mask] = 1 247 | 248 | loss = 0 249 | target_cpu = target.data.cpu().numpy() 250 | 251 | if self.batch_weights: 252 | class_weights = self.calculate_weights(target_cpu) 253 | 254 | for i in range(0, inputs.shape[0]): 255 | if not self.batch_weights: 256 | class_weights = self.calculate_weights(target_cpu[i]) 257 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 258 | target[i].unsqueeze(0), 259 | class_weights=torch.Tensor(class_weights).cuda(), 260 | border_weights=weights[i], mask=ignore_mask[i]) 261 | 262 | loss = loss / inputs.shape[0] 263 | return loss 264 | 265 | class ImgWtLossSoftNLL_by_epoch(nn.Module): 266 | """ 267 | Relax Loss 268 | """ 269 | 270 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 271 | norm=False): 272 | super(ImgWtLossSoftNLL_by_epoch, self).__init__() 273 | self.weights = weights 274 | self.num_classes = classes 275 | self.ignore_index = ignore_index 276 | self.upper_bound = upper_bound 277 | self.norm = norm 278 | self.batch_weights = cfg.BATCH_WEIGHTING 279 | self.fp16 = False 280 | 281 | 282 | def calculate_weights(self, target): 283 | """ 284 | Calculate weights of the classes based on training crop 285 | """ 286 | if len(target.shape) == 3: 287 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 288 | else: 289 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 290 | if self.norm: 291 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 292 | else: 293 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 294 | return hist[:-1] 295 | 296 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 297 | """ 298 | NLL Relaxed Loss Implementation 299 | """ 300 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 301 | border_weights = 1 / border_weights 302 | target[target > 1] = 1 303 | if self.fp16: 304 | loss_matrix = (-1 / border_weights * 305 | (target[:, :-1, :, :].half() * 306 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 307 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \ 308 | (1. - mask.half()) 309 | else: 310 | loss_matrix = (-1 / border_weights * 311 | (target[:, :-1, :, :].float() * 312 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 313 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 314 | (1. 
- mask.float()) 315 | 316 | # loss_matrix[border_weights > 1] = 0 317 | loss = loss_matrix.sum() 318 | 319 | # +1 to prevent division by 0 320 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 321 | return loss 322 | 323 | def forward(self, inputs, target): 324 | if self.fp16: 325 | weights = target[:, :-1, :, :].sum(1).half() 326 | else: 327 | weights = target[:, :-1, :, :].sum(1).float() 328 | ignore_mask = (weights == 0) 329 | weights[ignore_mask] = 1 330 | 331 | loss = 0 332 | target_cpu = target.data.cpu().numpy() 333 | 334 | if self.batch_weights: 335 | class_weights = self.calculate_weights(target_cpu) 336 | 337 | for i in range(0, inputs.shape[0]): 338 | if not self.batch_weights: 339 | class_weights = self.calculate_weights(target_cpu[i]) 340 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 341 | target[i].unsqueeze(0), 342 | class_weights=torch.Tensor(class_weights).cuda(), 343 | border_weights=weights, mask=ignore_mask[i]) 344 | 345 | return loss 346 | -------------------------------------------------------------------------------- /network/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | """ 35 | 36 | import torch 37 | import torch.nn as nn 38 | import torch.utils.model_zoo as model_zoo 39 | import network.mynn as mynn 40 | from network.adain import AdaptiveInstanceNormalization 41 | 42 | 43 | __all__ = ['ResNet', 'resnet50'] 44 | 45 | model_urls = { 46 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 47 | } 48 | 49 | 50 | class Bottleneck(nn.Module): 51 | """ 52 | Bottleneck Layer for Resnet 53 | """ 54 | expansion = 4 55 | 56 | def __init__(self, inplanes, planes, stride=1, downsample=None, fs=0): 57 | super(Bottleneck, self).__init__() 58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 59 | self.bn1 = mynn.Norm2d(planes) 60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 61 | padding=1, bias=False) 62 | self.bn2 = mynn.Norm2d(planes) 63 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 64 | self.bn3 = mynn.Norm2d(planes * self.expansion) 65 | self.downsample = downsample 66 | self.stride = stride 67 | 68 | self.fs = fs 69 | if self.fs == 1: 70 | self.instance_norm_layer = AdaptiveInstanceNormalization() 71 | self.relu = nn.ReLU(inplace=False) 72 | else: 73 | self.relu = nn.ReLU(inplace=True) 74 | 75 | def forward(self, x_tuple): 76 | if len(x_tuple) == 1: 77 | x = x_tuple[0] 78 | elif len(x_tuple) == 3: 79 | x = x_tuple[0] 80 | x_w = x_tuple[1] 81 | x_sw = x_tuple[2] 82 | else: 83 | raise NotImplementedError("%d is not supported length of the tuple"%(len(x_tuple))) 84 | 85 | residual = x 86 | 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv3(out) 96 | out = self.bn3(out) 97 | 98 | if self.downsample is not None: 99 | residual = self.downsample(x) 100 | 101 | out += residual 102 | 103 | if len(x_tuple) == 3: 104 | with torch.no_grad(): 105 | residual_w = x_w 106 | 107 | out_w = self.conv1(x_w) 108 | out_w = self.bn1(out_w) 109 | out_w = self.relu(out_w) 110 | 111 | out_w = self.conv2(out_w) 112 | out_w = self.bn2(out_w) 113 | out_w = self.relu(out_w) 114 | 115 | out_w = self.conv3(out_w) 116 | out_w = self.bn3(out_w) 117 | 118 | if self.downsample is not None: 119 | residual_w = self.downsample(x_w) 120 | 121 | out_w += residual_w 122 | 123 | residual_sw = x_sw 124 | 125 | out_sw = self.conv1(x_sw) 126 | out_sw = self.bn1(out_sw) 127 | out_sw = self.relu(out_sw) 128 | 129 | out_sw = self.conv2(out_sw) 130 | out_sw = self.bn2(out_sw) 131 | out_sw = self.relu(out_sw) 132 | 133 | out_sw = self.conv3(out_sw) 134 | out_sw = self.bn3(out_sw) 135 | 136 | if self.downsample is not None: 137 | residual_sw = self.downsample(x_sw) 138 | 139 | out_sw += residual_sw 140 | 141 | 142 | if self.fs == 1: 143 | out = self.instance_norm_layer(out) 144 | if len(x_tuple) == 3: 145 | out_sw = self.instance_norm_layer(out_sw, out_w) 146 | with torch.no_grad(): 147 | out_w = self.instance_norm_layer(out_w) 148 | 149 | out = self.relu(out) 150 | 151 | if len(x_tuple) == 3: 152 | with torch.no_grad(): 153 | out_w = self.relu(out_w) 154 | out_sw = self.relu(out_sw) 155 | return [out, out_w, out_sw] 156 | else: 157 | return [out] 158 | 159 | 160 | class ResNet(nn.Module): 161 | """ 162 | Resnet Global Module for Initialization 163 | """ 164 | 165 | def __init__(self, block, layers, fs_layer=None, num_classes=1000): 166 | self.inplanes = 64 167 | super(ResNet, self).__init__() 168 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 169 | 
bias=False) 170 | if fs_layer[0] == 1: 171 | self.bn1 = AdaptiveInstanceNormalization() 172 | self.relu = nn.ReLU(inplace=False) 173 | else: 174 | self.bn1 = mynn.Norm2d(64) 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 178 | self.layer1 = self._make_layer(block, 64, layers[0], fs_layer=fs_layer[1]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, fs_layer=fs_layer[2]) 180 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, fs_layer=fs_layer[3]) 181 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, fs_layer=fs_layer[4]) 182 | self.avgpool = nn.AvgPool2d(7, stride=1) 183 | self.fc = nn.Linear(512 * block.expansion, num_classes) 184 | 185 | for m in self.modules(): 186 | if isinstance(m, nn.Conv2d): 187 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 188 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.SyncBatchNorm): 189 | if m.weight is not None: 190 | nn.init.constant_(m.weight, 1) 191 | if m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | 194 | def _make_layer(self, block, planes, blocks, stride=1, fs_layer=0): 195 | downsample = None 196 | if stride != 1 or self.inplanes != planes * block.expansion: 197 | downsample = nn.Sequential( 198 | nn.Conv2d(self.inplanes, planes * block.expansion, 199 | kernel_size=1, stride=stride, bias=False), 200 | mynn.Norm2d(planes * block.expansion), 201 | ) 202 | 203 | layers = [] 204 | layers.append(block(self.inplanes, planes, stride, downsample, fs=0)) 205 | self.inplanes = planes * block.expansion 206 | for index in range(1, blocks): 207 | layers.append(block(self.inplanes, planes, 208 | fs=0 if (fs_layer > 0 and index < blocks - 1) else fs_layer)) 209 | return nn.Sequential(*layers) 210 | 211 | def forward(self, x): 212 | x = self.conv1(x) 213 | x = self.bn1(x) 214 | x = self.relu(x) 215 | 216 | x = self.maxpool(x) 217 | 218 | x = self.layer1(x) 219 | x = self.layer2(x) 220 | x = self.layer3(x) 221 | x = self.layer4(x) 222 | 223 | x = self.avgpool(x) 224 | x = x.view(x.size(0), -1) 225 | x = self.fc(x) 226 | 227 | return x 228 | 229 | 230 | def resnet50(pretrained=True, fs_layer=None, **kwargs): 231 | """Constructs a ResNet-50 model. 
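fs_layer is a list of five flags (stem plus the four residual stages): a value of 1 swaps the
stem normalization for AdaptiveInstanceNormalization (index 0), or enables the AdaIN
feature stylization step in the last bottleneck of the corresponding stage (indices 1-4).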
232 | 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | """ 236 | if fs_layer is None: 237 | fs_layer = [0, 0, 0, 0, 0] 238 | model = ResNet(Bottleneck, [3, 4, 6, 3], fs_layer=fs_layer, **kwargs) 239 | if pretrained: 240 | print("########### pretrained ##############") 241 | mynn.forgiving_state_restore(model, model_zoo.load_url(model_urls['resnet50'])) 242 | return model 243 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network Initializations 3 | """ 4 | 5 | import logging 6 | import importlib 7 | import torch 8 | import datasets 9 | 10 | 11 | 12 | def get_net(args, criterion, criterion_aux=None, cont_proj_head=0, wild_cont_dict_size=0): 13 | """ 14 | Get Network Architecture based on arguments provided 15 | """ 16 | net = get_model(args=args, num_classes=datasets.num_classes, 17 | criterion=criterion, criterion_aux=criterion_aux, cont_proj_head=cont_proj_head, wild_cont_dict_size=wild_cont_dict_size) 18 | num_params = sum([param.nelement() for param in net.parameters()]) 19 | logging.info('Model params = {:2.3f}M'.format(num_params / 1000000)) 20 | 21 | net = net.cuda() 22 | return net 23 | 24 | 25 | def warp_network_in_dataparallel(net, gpuid): 26 | """ 27 | Wrap the network in Dataparallel 28 | """ 29 | # torch.cuda.set_device(gpuid) 30 | # net.cuda(gpuid) 31 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid], find_unused_parameters=True) 32 | # net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid])#, find_unused_parameters=True) 33 | return net 34 | 35 | 36 | def get_model(args, num_classes, criterion, criterion_aux=None, cont_proj_head=0, wild_cont_dict_size=0): 37 | """ 38 | Fetch Network Function Pointer 39 | """ 40 | network = args.arch 41 | module = network[:network.rfind('.')] 42 | model = network[network.rfind('.') + 1:] 43 | mod = importlib.import_module(module) 44 | net_func = getattr(mod, model) 45 | net = net_func(args=args, num_classes=num_classes, criterion=criterion, criterion_aux=criterion_aux, cont_proj_head=cont_proj_head, wild_cont_dict_size=wild_cont_dict_size) 46 | return net 47 | -------------------------------------------------------------------------------- /network/adain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class AdaptiveInstanceNormalization(nn.Module): 6 | 7 | def __init__(self): 8 | super(AdaptiveInstanceNormalization, self).__init__() 9 | 10 | def forward(self, x_cont, x_style=None): 11 | if x_style is not None: 12 | assert (x_cont.size()[:2] == x_style.size()[:2]) 13 | size = x_cont.size() 14 | style_mean, style_std = calc_mean_std(x_style) 15 | content_mean, content_std = calc_mean_std(x_cont) 16 | 17 | normalized_x_cont = (x_cont - content_mean.expand(size))/content_std.expand(size) 18 | denormalized_x_cont = normalized_x_cont * style_std.expand(size) + style_mean.expand(size) 19 | 20 | return denormalized_x_cont 21 | 22 | else: 23 | return x_cont 24 | 25 | 26 | def calc_mean_std(feat, eps=1e-5): 27 | # eps is a small value added to the variance to avoid divide-by-zero. 
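    # Statistics are per sample and per channel over the spatial dimensions, so
    # forward() above realises AdaIN(x_cont, x_style) =
    #     sigma(x_style) * (x_cont - mu(x_cont)) / sigma(x_cont) + mu(x_style).
    # Illustrative usage (hypothetical tensors): AdaptiveInstanceNormalization()(
    #     torch.randn(2, 64, 32, 32), torch.randn(2, 64, 32, 32)) gives the content
    # features the style features' channel-wise mean and std.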
28 | size = feat.size() 29 | assert (len(size) == 4) 30 | N, C = size[:2] 31 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 32 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 33 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 34 | return feat_mean, feat_std 35 | 36 | -------------------------------------------------------------------------------- /network/cel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import torchvision 5 | import numpy as np 6 | 7 | 8 | def get_content_extension_loss(feats_s, feats_sw, feats_w, gts, queue): 9 | 10 | B, C, H, W = feats_s.shape # feat feature size (B X C X H X W) 11 | 12 | # uniform sampling with a size of 64 x 64 for source and wild-stylized source feature maps 13 | H_W_resize = 64 14 | HW = H_W_resize * H_W_resize 15 | 16 | upsample_n = nn.Upsample(size=[H_W_resize, H_W_resize], mode='nearest') 17 | feats_s_flat = upsample_n(feats_s) 18 | feats_sw_flat = upsample_n(feats_sw) 19 | 20 | feats_s_flat = feats_s_flat.contiguous().view(B, C, -1) # B X C X H X W > B X C X (H X W) 21 | feats_sw_flat = feats_sw_flat.contiguous().view(B, C, -1) # B X C X H X W > B X C X (H X W) 22 | gts_flat = upsample_n(gts.unsqueeze(1).float()).squeeze(1).long().view(B, HW) 23 | 24 | # uniform sampling with a size of 16 x 16 for wild feature map 25 | H_W_resize_w = 16 26 | HW_w = H_W_resize_w * H_W_resize_w 27 | 28 | upsample_n_w = nn.Upsample(size=[H_W_resize_w, H_W_resize_w], mode='nearest') 29 | feats_w_flat = upsample_n_w(feats_w) 30 | feats_w_flat = torch.einsum("bchw->bhwc", feats_w_flat).contiguous().view(B*H_W_resize_w*H_W_resize_w, C) # B X C X H X W > (B X H X W) X C 31 | 32 | # normalize feature of each pixel 33 | feats_s_flat = nn.functional.normalize(feats_s_flat, p=2, dim=1) 34 | feats_sw_flat = nn.functional.normalize(feats_sw_flat, p=2, dim=1) 35 | feats_w_flat = nn.functional.normalize(feats_w_flat, p=2, dim=1).detach() # (B X H X W) X C 36 | 37 | 38 | # log(dot(feats_s_flat, feats_sw_flat)) 39 | T = 0.07 40 | logits_sce = torch.bmm(feats_s_flat.transpose(1,2), feats_sw_flat) # dot product: B X (H X W) X (H X W) 41 | logits_sce = (torch.clamp(logits_sce, min=-1, max=1))/T 42 | 43 | # compute ignore mask: same-class (excluding self) + unknown-labeled pixels 44 | # compute positive mask (same-class) 45 | logits_mask_sce_ignore = torch.eq(gts_flat.unsqueeze(2), gts_flat.unsqueeze(1)) # pos:1, neg:0. B X (H X W) X (H X W) 46 | # include unknown-labeled pixel 47 | logits_mask_sce_ignore = include_unknown(logits_mask_sce_ignore, gts_flat) 48 | 49 | # exclude self-pixel 50 | logits_mask_sce_ignore *= ~torch.eye(HW,HW).type(torch.cuda.BoolTensor).unsqueeze(0).expand([B, -1, -1]) # self:1, other:0. 
B X (H X W) X (H X W) 51 | 52 | # compute positive mask for cross entropy loss: B X (H X W) 53 | logits_mask_sce_pos = torch.linspace(start=0, end=HW-1, steps=HW).unsqueeze(0).expand([B, -1]).type(torch.cuda.LongTensor) 54 | 55 | # compute unknown-labeled mask for cross entropy loss: B X (H X W) 56 | logits_mask_sce_unk = torch.zeros_like(logits_mask_sce_pos, dtype=torch.bool) 57 | logits_mask_sce_unk[gts_flat>254] = True 58 | 59 | # compute loss_sce 60 | eps = 1e-5 61 | logits_sce[logits_mask_sce_ignore] = -1/T 62 | CELoss = nn.CrossEntropyLoss(reduction='none') 63 | loss_sce = CELoss(logits_sce.transpose(1,2), logits_mask_sce_pos) 64 | loss_sce = ((loss_sce * (~logits_mask_sce_unk)).sum(1) / ((~logits_mask_sce_unk).sum(1) + eps)).mean() 65 | 66 | 67 | # get wild content closest to wild-stylized source content 68 | idx_sim_bs = 512 69 | index_nearest_neighbours = (torch.randn(0)).type(torch.cuda.LongTensor) 70 | for idx_sim in range(int(np.ceil(HW/idx_sim_bs))): 71 | idx_sim_start = idx_sim*idx_sim_bs 72 | idx_sim_end = min((idx_sim+1)*idx_sim_bs, HW) 73 | similarity_matrix = torch.einsum("bcn,cq->bnq", feats_sw_flat[:,:,idx_sim_start:idx_sim_end].type(torch.cuda.HalfTensor), queue['wild'].type(torch.cuda.HalfTensor)) # B X (H X W) X Q 74 | index_nearest_neighbours = torch.cat((index_nearest_neighbours,torch.argmax(similarity_matrix, dim=2)),dim=1) # B X (H X W) 75 | # similarity_matrix = torch.einsum("bcn,cq->bnq", feats_sw_flat, queue['wild']) # B X (H X W) X Q 76 | # index_nearest_neighbours = torch.argmax(similarity_matrix, dim=2) # B X (H X W) 77 | del similarity_matrix 78 | nearest_neighbours = torch.index_select(queue['wild'], dim=1, index=index_nearest_neighbours.view(-1)).view(C, B, HW) # C X B X (H X W) 79 | 80 | # compute exp(dot(feats_s_flat, nearest_neighbours)) 81 | logits_wce_pos = torch.einsum("bcn,cbn->bn", feats_s_flat, nearest_neighbours) # dot product: B X C X (H X W) & C X B X (H X W) => B X (H X W) 82 | logits_wce_pos = (torch.clamp(logits_wce_pos, min=-1, max=1))/T 83 | exp_logits_wce_pos = torch.exp(logits_wce_pos) 84 | 85 | # compute negative mask of logits_sce 86 | logits_mask_sce_neg = ~torch.eq(gts_flat.unsqueeze(2), gts_flat.unsqueeze(1)) # pos:0, neg:1. 
B X (H X W) X (H X W) 87 | 88 | # exclude unknown-labeled pixels from negative samples 89 | logits_mask_sce_neg = exclude_unknown(logits_mask_sce_neg, gts_flat) 90 | 91 | # sum exp(neg samples) 92 | exp_logits_sce_neg = (torch.exp(logits_sce) * logits_mask_sce_neg).sum(2) # B X (H X W) 93 | 94 | # Compute log_prob 95 | log_prob_wce = logits_wce_pos - torch.log(exp_logits_wce_pos + exp_logits_sce_neg) # B X (H X W) 96 | 97 | # Compute loss_wce 98 | loss_wce = -((log_prob_wce * (~logits_mask_sce_unk)).sum(1) / ((~logits_mask_sce_unk).sum(1) + eps)).mean() 99 | 100 | 101 | # enqueue wild contents 102 | sup_enqueue = feats_w_flat # Q X C # (B X H X W) X C 103 | _dequeue_and_enqueue(queue, sup_enqueue) 104 | 105 | # compute content extension learning loss 106 | loss_cel = loss_sce + loss_wce 107 | 108 | return loss_cel 109 | 110 | def exclude_unknown(mask, gts): 111 | ''' 112 | mask: [B, HW, HW] 113 | gts: [B, HW] 114 | ''' 115 | mask = mask.transpose(1,2).contiguous() 116 | mask[gts>254,:] = False 117 | mask = mask.transpose(1,2).contiguous() 118 | 119 | return mask 120 | 121 | def include_unknown(mask, gts): 122 | ''' 123 | mask: [B, HW, HW] 124 | gts: [B, HW] 125 | ''' 126 | mask = mask.transpose(1,2).contiguous() 127 | mask[gts>254,:] = True 128 | mask = mask.transpose(1,2).contiguous() 129 | 130 | return mask 131 | 132 | @torch.no_grad() 133 | def _dequeue_and_enqueue(queue, keys): 134 | # gather keys before updating queue 135 | keys = concat_all_gather(keys) # (B X H X W) X C 136 | 137 | batch_size = keys.shape[0] 138 | 139 | ptr = int(queue['wild_ptr']) 140 | 141 | # replace the keys at ptr (dequeue and enqueue) 142 | if (ptr + batch_size) <= queue['size']: 143 | # wild queue 144 | queue['wild'][:, ptr:ptr + batch_size] = keys.T 145 | ptr = (ptr + batch_size) % queue['size'] # move pointer 146 | else: 147 | # wild queue 148 | last_input_num = queue['size'] - ptr 149 | queue['wild'][:,ptr:] = (keys.T)[:,:last_input_num] 150 | ptr = (ptr + batch_size) % queue['size'] # move pointer 151 | queue['wild'][:,:ptr] = (keys.T)[:,last_input_num:] 152 | queue['wild_ptr'][0] = ptr 153 | 154 | # utils 155 | @torch.no_grad() 156 | def concat_all_gather(tensor): 157 | """ 158 | Performs all_gather operation on the provided tensors. 159 | *** Warning ***: torch.distributed.all_gather has no gradient. 
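    varsize_tensor_all_gather below pads each rank's tensor to the largest first dimension,
    all-gathers the padded tensors, then trims the padding, so ranks may contribute
    different numbers of wild features to the dictionary.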
160 | """ 161 | 162 | tensors_gather = varsize_tensor_all_gather(tensor) 163 | 164 | output = tensors_gather 165 | return output 166 | 167 | def varsize_tensor_all_gather(tensor: torch.Tensor): 168 | tensor = tensor.contiguous() 169 | 170 | cuda_device = f'cuda:{torch.distributed.get_rank()}' 171 | size_tens = torch.tensor([tensor.shape[0]], dtype=torch.int64, device=cuda_device) 172 | 173 | size_tens2 = [torch.ones_like(size_tens) 174 | for _ in range(torch.distributed.get_world_size())] 175 | 176 | torch.distributed.all_gather(size_tens2, size_tens) 177 | size_tens2 = torch.cat(size_tens2, dim=0).cpu() 178 | max_size = size_tens2.max() 179 | 180 | padded = torch.empty(max_size, *tensor.shape[1:], 181 | dtype=tensor.dtype, 182 | device=cuda_device) 183 | padded[:tensor.shape[0]] = tensor 184 | 185 | ag = [torch.ones_like(padded) 186 | for _ in range(torch.distributed.get_world_size())] 187 | 188 | torch.distributed.all_gather(ag,padded) 189 | ag = torch.cat(ag, dim=0) 190 | 191 | slices = [] 192 | for i, sz in enumerate(size_tens2): 193 | start_idx = i * max_size 194 | end_idx = start_idx + sz.item() 195 | 196 | if end_idx > start_idx: 197 | slices.append(ag[start_idx:end_idx]) 198 | 199 | ret = torch.cat(slices, dim=0) 200 | 201 | return ret.to(tensor) 202 | -------------------------------------------------------------------------------- /network/deepv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/sthalles/deeplab_v3 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2018 Thalles Santos Silva 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | """ 26 | import logging 27 | import torch 28 | from torch import nn 29 | from network import Resnet 30 | from network.mynn import initialize_weights, Norm2d, Upsample 31 | 32 | import torchvision.models as models 33 | 34 | from network.cel import get_content_extension_loss 35 | import torchvision 36 | 37 | 38 | class _AtrousSpatialPyramidPoolingModule(nn.Module): 39 | """ 40 | operations performed: 41 | 1x1 x depth 42 | 3x3 x depth dilation 6 43 | 3x3 x depth dilation 12 44 | 3x3 x depth dilation 18 45 | image pooling 46 | concatenate all together 47 | Final 1x1 conv 48 | """ 49 | 50 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)): 51 | super(_AtrousSpatialPyramidPoolingModule, self).__init__() 52 | 53 | # Check if we are using distributed BN and use the nn from encoding.nn 54 | # library rather than using standard pytorch.nn 55 | print("output_stride = ", output_stride) 56 | if output_stride == 8: 57 | rates = [2 * r for r in rates] 58 | elif output_stride == 4: 59 | rates = [4 * r for r in rates] 60 | elif output_stride == 16: 61 | pass 62 | elif output_stride == 32: 63 | rates = [r // 2 for r in rates] 64 | else: 65 | raise 'output stride of {} not supported'.format(output_stride) 66 | 67 | self.features = [] 68 | # 1x1 69 | self.features.append( 70 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 71 | Norm2d(reduction_dim), nn.ReLU(inplace=True))) 72 | # other rates 73 | for r in rates: 74 | self.features.append(nn.Sequential( 75 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, 76 | dilation=r, padding=r, bias=False), 77 | Norm2d(reduction_dim), 78 | nn.ReLU(inplace=True) 79 | )) 80 | self.features = torch.nn.ModuleList(self.features) 81 | 82 | # img level features 83 | self.img_pooling = nn.AdaptiveAvgPool2d(1) 84 | self.img_conv = nn.Sequential( 85 | nn.Conv2d(in_dim, 256, kernel_size=1, bias=False), 86 | Norm2d(256), nn.ReLU(inplace=True)) 87 | 88 | def forward(self, x): 89 | x_size = x.size() 90 | 91 | img_features = self.img_pooling(x) 92 | img_features = self.img_conv(img_features) 93 | img_features = Upsample(img_features, x_size[2:]) 94 | out = img_features 95 | 96 | for f in self.features: 97 | y = f(x) 98 | out = torch.cat((out, y), 1) 99 | return out 100 | 101 | 102 | class DeepV3Plus(nn.Module): 103 | """ 104 | Implement DeepLab-V3 model 105 | A: stride8 106 | B: stride16 107 | with skip connections 108 | """ 109 | 110 | def __init__(self, num_classes, trunk='resnet-50', criterion=None, criterion_aux=None, cont_proj_head=0, wild_cont_dict_size=0, 111 | variant='D16', skip='m1', skip_num=48, args=None): 112 | super(DeepV3Plus, self).__init__() 113 | 114 | self.args = args 115 | 116 | # loss 117 | self.criterion = criterion 118 | self.criterion_aux = criterion_aux 119 | self.criterion_kl = nn.KLDivLoss(reduction='batchmean').cuda() 120 | 121 | # create the wild-content dictionary 122 | self.cont_proj_head = cont_proj_head 123 | if wild_cont_dict_size > 0: 124 | if cont_proj_head > 0: 125 | self.cont_dict = {} 126 | self.cont_dict['size'] = wild_cont_dict_size 127 | self.cont_dict['dim'] = self.cont_proj_head 128 | 129 | self.register_buffer("wild_cont_dict", torch.randn(self.cont_dict['dim'], self.cont_dict['size'])) 130 | 
self.wild_cont_dict = nn.functional.normalize(self.wild_cont_dict, p=2, dim=0) # C X Q 131 | self.register_buffer("wild_cont_dict_ptr", torch.zeros(1, dtype=torch.long)) 132 | self.cont_dict['wild'] = self.wild_cont_dict.cuda() 133 | self.cont_dict['wild_ptr'] = self.wild_cont_dict_ptr 134 | else: 135 | raise 'dimension of wild-content dictionary is zero' 136 | 137 | # set backbone 138 | self.variant = variant 139 | self.trunk = trunk 140 | 141 | channel_1st = 3 142 | channel_2nd = 64 143 | channel_3rd = 256 144 | channel_4th = 512 145 | prev_final_channel = 1024 146 | final_channel = 2048 147 | 148 | if trunk == 'resnet-50': 149 | resnet = Resnet.resnet50(fs_layer=self.args.fs_layer) 150 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 151 | else: 152 | raise ValueError("Not a valid network arch") 153 | 154 | self.layer0 = resnet.layer0 155 | self.layer1, self.layer2, self.layer3, self.layer4 = \ 156 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 157 | 158 | if self.variant == 'D16': 159 | os = 16 160 | for n, m in self.layer4.named_modules(): 161 | if 'conv2' in n: 162 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 163 | elif 'downsample.0' in n: 164 | m.stride = (1, 1) 165 | else: 166 | raise 'unknown deepv3 variant: {}'.format(self.variant) 167 | 168 | self.output_stride = os 169 | self.aspp = _AtrousSpatialPyramidPoolingModule(final_channel, 256, 170 | output_stride=os) 171 | 172 | self.bot_fine = nn.Sequential( 173 | nn.Conv2d(channel_3rd, 48, kernel_size=1, bias=False), 174 | Norm2d(48), 175 | nn.ReLU(inplace=True)) 176 | 177 | self.bot_aspp = nn.Sequential( 178 | nn.Conv2d(1280, 256, kernel_size=1, bias=False), 179 | Norm2d(256), 180 | nn.ReLU(inplace=True)) 181 | 182 | self.final1 = nn.Sequential( 183 | nn.Conv2d(304, 256, kernel_size=3, padding=1, bias=False), 184 | Norm2d(256), 185 | nn.ReLU(inplace=True), 186 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 187 | Norm2d(256), 188 | nn.ReLU(inplace=True)) 189 | 190 | self.final2 = nn.Sequential( 191 | nn.Conv2d(256, num_classes, kernel_size=1, bias=True)) 192 | 193 | self.dsn = nn.Sequential( 194 | nn.Conv2d(prev_final_channel, 512, kernel_size=3, stride=1, padding=1), 195 | Norm2d(512), 196 | nn.ReLU(inplace=True), 197 | nn.Dropout2d(0.1), 198 | nn.Conv2d(512, num_classes, kernel_size=1, stride=1, padding=0, bias=True) 199 | ) 200 | initialize_weights(self.dsn) 201 | 202 | initialize_weights(self.aspp) 203 | initialize_weights(self.bot_aspp) 204 | initialize_weights(self.bot_fine) 205 | initialize_weights(self.final1) 206 | initialize_weights(self.final2) 207 | 208 | if self.cont_proj_head > 0: 209 | self.proj = nn.Sequential( 210 | nn.Linear(256, 256, bias=True), 211 | nn.ReLU(inplace=False), 212 | nn.Linear(256, self.cont_proj_head, bias=True)) 213 | initialize_weights(self.proj) 214 | 215 | # Setting the flags 216 | self.eps = 1e-5 217 | self.whitening = False 218 | 219 | if trunk == 'resnet-50': 220 | in_channel_list = [0, 0, 64, 256, 512, 1024, 2048] 221 | out_channel_list = [0, 0, 32, 128, 256, 512, 1024] 222 | else: 223 | raise ValueError("Not a valid network arch") 224 | 225 | 226 | def forward(self, x, gts=None, aux_gts=None, x_w=None, apply_fs=False): 227 | 228 | x_size = x.size() # 800 229 | 230 | # encoder 231 | x = self.layer0[0](x) 232 | if self.training & apply_fs: 233 | with torch.no_grad(): 234 | x_w = self.layer0[0](x_w) 235 | x = self.layer0[1](x) 236 | if self.training & apply_fs: 237 | x_sw = self.layer0[1](x, x_w) # feature 
stylization 238 | with torch.no_grad(): 239 | x_w = self.layer0[1](x_w) 240 | x = self.layer0[2](x) 241 | x = self.layer0[3](x) 242 | if self.training & apply_fs: 243 | with torch.no_grad(): 244 | x_w = self.layer0[2](x_w) 245 | x_w = self.layer0[3](x_w) 246 | x_sw = self.layer0[2](x_sw) 247 | x_sw = self.layer0[3](x_sw) 248 | 249 | if self.training & apply_fs: 250 | x_tuple = self.layer1([x, x_w, x_sw]) 251 | low_level = x_tuple[0] 252 | low_level_w = x_tuple[1] 253 | low_level_sw = x_tuple[2] 254 | else: 255 | x_tuple = self.layer1([x]) 256 | low_level = x_tuple[0] 257 | 258 | x_tuple = self.layer2(x_tuple) 259 | x_tuple = self.layer3(x_tuple) 260 | aux_out = x_tuple[0] 261 | if self.training & apply_fs: 262 | aux_out_w = x_tuple[1] 263 | aux_out_sw = x_tuple[2] 264 | x_tuple = self.layer4(x_tuple) 265 | x = x_tuple[0] 266 | if self.training & apply_fs: 267 | x_w = x_tuple[1] 268 | x_sw = x_tuple[2] 269 | 270 | # decoder 271 | x = self.aspp(x) 272 | dec0_up = self.bot_aspp(x) 273 | dec0_fine = self.bot_fine(low_level) 274 | dec0_up = Upsample(dec0_up, low_level.size()[2:]) 275 | dec0 = [dec0_fine, dec0_up] 276 | dec0 = torch.cat(dec0, 1) 277 | dec1 = self.final1(dec0) 278 | dec2 = self.final2(dec1) 279 | main_out = Upsample(dec2, x_size[2:]) 280 | 281 | if self.training: 282 | # compute original semantic segmentation loss 283 | loss_orig = self.criterion(main_out, gts) 284 | aux_out = self.dsn(aux_out) 285 | if aux_gts.dim() == 1: 286 | aux_gts = gts 287 | aux_gts = aux_gts.unsqueeze(1).float() 288 | aux_gts = nn.functional.interpolate(aux_gts, size=aux_out.shape[2:], mode='nearest') 289 | aux_gts = aux_gts.squeeze(1).long() 290 | loss_orig_aux = self.criterion_aux(aux_out, aux_gts) 291 | 292 | return_loss = [loss_orig, loss_orig_aux] 293 | 294 | if apply_fs: 295 | x_sw = self.aspp(x_sw) 296 | dec0_up_sw = self.bot_aspp(x_sw) 297 | dec0_fine_sw = self.bot_fine(low_level_sw) 298 | dec0_up_sw = Upsample(dec0_up_sw, low_level_sw.size()[2:]) 299 | dec0_sw = [dec0_fine_sw, dec0_up_sw] 300 | dec0_sw = torch.cat(dec0_sw, 1) 301 | dec1_sw = self.final1(dec0_sw) 302 | dec2_sw = self.final2(dec1_sw) 303 | main_out_sw = Upsample(dec2_sw, x_size[2:]) 304 | 305 | with torch.no_grad(): 306 | x_w = self.aspp(x_w) 307 | dec0_up_w = self.bot_aspp(x_w) 308 | dec0_fine_w = self.bot_fine(low_level_w) 309 | dec0_up_w = Upsample(dec0_up_w, low_level_w.size()[2:]) 310 | dec0_w = [dec0_fine_w, dec0_up_w] 311 | dec0_w = torch.cat(dec0_w, 1) 312 | dec1_w = self.final1(dec0_w) 313 | dec2_w = self.final2(dec1_w) 314 | main_out_w = Upsample(dec2_w, x_size[2:]) 315 | 316 | if self.args.use_cel: 317 | # projected features 318 | assert (self.cont_proj_head > 0) 319 | proj2 = self.proj(dec1.permute(0,2,3,1)).permute(0,3,1,2) 320 | proj2_sw = self.proj(dec1_sw.permute(0,2,3,1)).permute(0,3,1,2) 321 | with torch.no_grad(): 322 | proj2_w = self.proj(dec1_w.permute(0,2,3,1)).permute(0,3,1,2) 323 | 324 | # compute content extension learning loss 325 | loss_cel = get_content_extension_loss(proj2, proj2_sw, proj2_w, gts, self.cont_dict) 326 | 327 | return_loss.append(loss_cel) 328 | 329 | if self.args.use_sel: 330 | # compute style extension learning loss 331 | loss_sel = self.criterion(main_out_sw, gts) 332 | aux_out_sw = self.dsn(aux_out_sw) 333 | loss_sel_aux = self.criterion_aux(aux_out_sw, aux_gts) 334 | return_loss.append(loss_sel) 335 | return_loss.append(loss_sel_aux) 336 | 337 | if self.args.use_scr: 338 | # compute semantic consistency regularization loss 339 | loss_scr = 
torch.clamp((self.criterion_kl(nn.functional.log_softmax(main_out_sw, dim=1), nn.functional.softmax(main_out, dim=1)))/(torch.prod(torch.tensor(main_out.shape[1:]))), min=0) 340 | loss_scr_aux = torch.clamp((self.criterion_kl(nn.functional.log_softmax(aux_out_sw, dim=1), nn.functional.softmax(aux_out, dim=1)))/(torch.prod(torch.tensor(aux_out.shape[1:]))), min=0) 341 | return_loss.append(loss_scr) 342 | return_loss.append(loss_scr_aux) 343 | 344 | return return_loss 345 | else: 346 | return main_out 347 | 348 | 349 | def DeepR50V3PlusD(args, num_classes, criterion, criterion_aux, cont_proj_head, wild_cont_dict_size): 350 | """ 351 | Resnet 50 Based Network 352 | """ 353 | print("Model : DeepLabv3+, Backbone : ResNet-50") 354 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion, criterion_aux=criterion_aux, cont_proj_head=cont_proj_head, wild_cont_dict_size=wild_cont_dict_size, 355 | variant='D16', skip='m1', args=args) 356 | -------------------------------------------------------------------------------- /network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | from config import cfg 7 | 8 | import numpy as np 9 | 10 | def Norm2d(in_channels): 11 | """ 12 | Custom Norm Function to allow flexible switching 13 | """ 14 | layer = getattr(cfg.MODEL, 'BNFUNC') 15 | normalization_layer = layer(in_channels) 16 | return normalization_layer 17 | 18 | 19 | def freeze_weights(*models): 20 | for model in models: 21 | for k in model.parameters(): 22 | k.requires_grad = False 23 | 24 | def unfreeze_weights(*models): 25 | for model in models: 26 | for k in model.parameters(): 27 | k.requires_grad = True 28 | 29 | def initialize_weights(*models): 30 | """ 31 | Initialize Model Weights 32 | """ 33 | for model in models: 34 | for module in model.modules(): 35 | if isinstance(module, (nn.Conv2d, nn.Linear)): 36 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 37 | if module.bias is not None: 38 | module.bias.data.zero_() 39 | elif isinstance(module, nn.Conv1d): 40 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 41 | if module.bias is not None: 42 | module.bias.data.zero_() 43 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or \ 44 | isinstance(module, nn.GroupNorm) or isinstance(module, nn.SyncBatchNorm): 45 | module.weight.data.fill_(1) 46 | module.bias.data.zero_() 47 | elif isinstance(module, nn.ConvTranspose2d): 48 | assert module.kernel_size[0] == module.kernel_size[1] 49 | initial_weight = get_upsampling_weight(module.in_channels, module.out_channels, module.kernel_size[0]) 50 | module.weight.data.copy_(initial_weight) 51 | 52 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 53 | """Make a 2D bilinear kernel suitable for upsampling""" 54 | factor = (kernel_size + 1) // 2 55 | if kernel_size % 2 == 1: 56 | center = factor - 1 57 | else: 58 | center = factor - 0.5 59 | og = np.ogrid[:kernel_size, :kernel_size] 60 | filt = (1 - abs(og[0] - center) / factor) * \ 61 | (1 - abs(og[1] - center) / factor) 62 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), 63 | dtype=np.float64) 64 | weight[range(in_channels), range(out_channels), :, :] = filt 65 | return torch.from_numpy(weight).float() 66 | 67 | def initialize_embedding(*models): 68 | """ 69 | Initialize Model Weights 70 | """ 71 | for model in 
models: 72 | for module in model.modules(): 73 | if isinstance(module, nn.Embedding): 74 | module.weight.data.zero_() #original 75 | 76 | 77 | 78 | def Upsample(x, size): 79 | """ 80 | Wrapper Around the Upsample Call 81 | """ 82 | return nn.functional.interpolate(x, size=size, mode='bilinear', 83 | align_corners=True) 84 | 85 | def forgiving_state_restore(net, loaded_dict): 86 | """ 87 | Handle partial loading when some tensors don't match up in size. 88 | Because we want to use models that were trained off a different 89 | number of classes. 90 | """ 91 | net_state_dict = net.state_dict() 92 | new_loaded_dict = {} 93 | for k in net_state_dict: 94 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 95 | new_loaded_dict[k] = loaded_dict[k] 96 | else: 97 | print("Skipped loading parameter", k) 98 | # logging.info("Skipped loading parameter %s", k) 99 | net_state_dict.update(new_loaded_dict) 100 | net.load_state_dict(net_state_dict) 101 | return net 102 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pytorch Optimizer and Scheduler Related Task 3 | """ 4 | import math 5 | import logging 6 | import torch 7 | from torch import optim 8 | from config import cfg 9 | import pdb 10 | 11 | def get_optimizer(args, net): 12 | """ 13 | Decide Optimizer (Adam or SGD) 14 | """ 15 | base_params = [] 16 | 17 | for name, param in net.named_parameters(): 18 | base_params.append(param) 19 | 20 | if args.sgd: 21 | optimizer = optim.SGD(base_params, 22 | lr=args.lr, 23 | weight_decay=args.weight_decay, 24 | momentum=args.momentum, 25 | nesterov=False) 26 | else: 27 | raise ValueError('Not a valid optimizer') 28 | 29 | if args.lr_schedule == 'scl-poly': 30 | if cfg.REDUCE_BORDER_ITER == -1: 31 | raise ValueError('ERROR Cannot Do Scale Poly') 32 | 33 | rescale_thresh = cfg.REDUCE_BORDER_ITER 34 | scale_value = args.rescale 35 | lambda1 = lambda iteration: \ 36 | math.pow(1 - iteration / args.max_iter, 37 | args.poly_exp) if iteration < rescale_thresh else scale_value * math.pow( 38 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh), 39 | args.repoly) 40 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 41 | elif args.lr_schedule == 'poly': 42 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp) 43 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 44 | else: 45 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 46 | 47 | return optimizer, scheduler 48 | 49 | 50 | def load_weights(net, optimizer, scheduler, snapshot_file, restore_optimizer_bool=False): 51 | """ 52 | Load weights from snapshot file 53 | """ 54 | logging.info("Loading weights from model %s", snapshot_file) 55 | net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file, 56 | restore_optimizer_bool) 57 | return epoch, mean_iu 58 | 59 | 60 | def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool): 61 | """ 62 | Restore weights and optimizer (if needed ) for resuming job. 
63 | """ 64 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 65 | logging.info("Checkpoint Load Complete") 66 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 67 | optimizer.load_state_dict(checkpoint['optimizer']) 68 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool: 69 | scheduler.load_state_dict(checkpoint['scheduler']) 70 | 71 | if 'state_dict' in checkpoint: 72 | net = forgiving_state_restore(net, checkpoint['state_dict']) 73 | else: 74 | net = forgiving_state_restore(net, checkpoint) 75 | 76 | if 'epoch' in checkpoint: 77 | epoch = checkpoint['epoch'] 78 | else: 79 | epoch = 0 80 | if 'mean_iu' in checkpoint: 81 | mean_iu = checkpoint['mean_iu'] 82 | else: 83 | mean_iu = 0.0 84 | 85 | return net, optimizer, scheduler, epoch, mean_iu 86 | 87 | 88 | def forgiving_state_restore(net, loaded_dict): 89 | """ 90 | Handle partial loading when some tensors don't match up in size. 91 | Because we want to use models that were trained off a different 92 | number of classes. 93 | """ 94 | net_state_dict = net.state_dict() 95 | new_loaded_dict = {} 96 | for k in net_state_dict: 97 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 98 | new_loaded_dict[k] = loaded_dict[k] 99 | else: 100 | print("Skipped loading parameter", k) 101 | # logging.info("Skipped loading parameter %s", k) 102 | net_state_dict.update(new_loaded_dict) 103 | net.load_state_dict(net_state_dict) 104 | return net 105 | 106 | def forgiving_state_copy(target_net, source_net): 107 | """ 108 | Handle partial loading when some tensors don't match up in size. 109 | Because we want to use models that were trained off a different 110 | number of classes. 111 | """ 112 | net_state_dict = target_net.state_dict() 113 | loaded_dict = source_net.state_dict() 114 | new_loaded_dict = {} 115 | for k in net_state_dict: 116 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 117 | new_loaded_dict[k] = loaded_dict[k] 118 | print("Matched", k) 119 | else: 120 | print("Skipped loading parameter ", k) 121 | # logging.info("Skipped loading parameter %s", k) 122 | net_state_dict.update(new_loaded_dict) 123 | target_net.load_state_dict(net_state_dict) 124 | return target_net 125 | -------------------------------------------------------------------------------- /scripts/train_wildnet_r50os16_gtav.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on GTAV 3 | python -m torch.distributed.launch --nproc_per_node=2 train.py \ 4 | --dataset gtav \ 5 | --val_dataset bdd100k cityscapes synthia mapillary \ 6 | --wild_dataset imagenet \ 7 | --arch network.deepv3.DeepR50V3PlusD \ 8 | --city_mode 'train' \ 9 | --sgd \ 10 | --lr_schedule poly \ 11 | --lr 0.0025 \ 12 | --poly_exp 0.9 \ 13 | --max_cu_epoch 10000 \ 14 | --class_uniform_pct 0.5 \ 15 | --class_uniform_tile 1024 \ 16 | --crop_size 768 \ 17 | --scale_min 0.5 \ 18 | --scale_max 2.0 \ 19 | --rrotate 0 \ 20 | --max_iter 60000 \ 21 | --bs_mult 4 \ 22 | --gblur \ 23 | --color_aug 0.5 \ 24 | --fs_layer 1 1 1 0 0 \ 25 | --cont_proj_head 256 \ 26 | --wild_cont_dict_size 393216 \ 27 | --lambda_cel 0.1 \ 28 | --lambda_sel 1.0 \ 29 | --lambda_scr 10.0 \ 30 | --date 0101 \ 31 | --exp r50os16_gtav_wildnet \ 32 | --ckpt ./logs/ \ 33 | --tb_path ./logs/ 34 | 35 | -------------------------------------------------------------------------------- /scripts/valid_wildnet_r50os16_gtav.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | 4 | python -m torch.distributed.launch --nproc_per_node=2 valid.py \ 5 | --dataset gtav \ 6 | --val_dataset bdd100k cityscapes synthia mapillary \ 7 | --wild_dataset imagenet \ 8 | --arch network.deepv3.DeepR50V3PlusD \ 9 | --city_mode 'train' \ 10 | --sgd \ 11 | --lr_schedule poly \ 12 | --lr 0.0025 \ 13 | --poly_exp 0.9 \ 14 | --max_cu_epoch 10000 \ 15 | --class_uniform_pct 0.5 \ 16 | --class_uniform_tile 1024 \ 17 | --crop_size 768 \ 18 | --scale_min 0.5 \ 19 | --scale_max 2.0 \ 20 | --rrotate 0 \ 21 | --max_iter 60000 \ 22 | --bs_mult 4 \ 23 | --gblur \ 24 | --color_aug 0.5 \ 25 | --fs_layer 1 1 1 0 0 \ 26 | --cont_proj_head 256 \ 27 | --wild_cont_dict_size 393216 \ 28 | --lambda_cel 0.1 \ 29 | --lambda_sel 1.0 \ 30 | --lambda_scr 10.0 \ 31 | --date 0101 \ 32 | --exp r50os16_gtav_wildnet \ 33 | --ckpt ./logs/ \ 34 | --tb_path ./logs/ \ 35 | --snapshot ${1} 36 | 37 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | training code 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | import argparse 7 | import logging 8 | import os 9 | import torch 10 | 11 | from config import cfg, assert_and_infer_cfg 12 | from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist 13 | import datasets 14 | import loss 15 | import network 16 | import optimizer 17 | import time 18 | import torchvision.utils as vutils 19 | import torch.nn.functional as F 20 | import numpy as np 21 | import random 22 | 23 | 24 | # Argument Parser 25 | parser = argparse.ArgumentParser(description='Semantic Segmentation') 26 | parser.add_argument('--lr', type=float, default=0.01) 27 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepR50V3PlusD', 28 | help='Network architecture.') 29 | parser.add_argument('--dataset', nargs='*', type=str, default=['gtav'], 30 | help='a list of datasets; cityscapes, mapillary, gtav, bdd100k, synthia') 31 | parser.add_argument('--image_uniform_sampling', action='store_true', default=False, 32 | help='uniformly sample images across the multiple source domains') 33 | parser.add_argument('--val_dataset', nargs='*', type=str, default=['bdd100k'], 34 | help='a list consists of cityscapes, mapillary, gtav, bdd100k, synthia') 35 | parser.add_argument('--wild_dataset', nargs='*', type=str, default=['imagenet'], 36 | help='a list consists of imagenet') 37 | parser.add_argument('--cv', type=int, default=0, 38 | help='cross-validation split id to use. 
Default # of splits set to 3 in config') 39 | parser.add_argument('--class_uniform_pct', type=float, default=0, 40 | help='What fraction of images is uniformly sampled') 41 | parser.add_argument('--class_uniform_tile', type=int, default=1024, 42 | help='tile size for class uniform sampling') 43 | parser.add_argument('--coarse_boost_classes', type=str, default=None, 44 | help='use coarse annotations to boost fine data with specific classes') 45 | 46 | parser.add_argument('--img_wt_loss', action='store_true', default=False, 47 | help='per-image class-weighted loss') 48 | parser.add_argument('--cls_wt_loss', action='store_true', default=False, 49 | help='class-weighted loss') 50 | parser.add_argument('--batch_weighting', action='store_true', default=False, 51 | help='Batch weighting for class (use nll class weighting using batch stats)') 52 | 53 | parser.add_argument('--jointwtborder', action='store_true', default=False, 54 | help='Enable boundary label relaxation') 55 | parser.add_argument('--strict_bdr_cls', type=str, default='', 56 | help='Enable boundary label relaxation for specific classes') 57 | parser.add_argument('--rlx_off_iter', type=int, default=-1, 58 | help='Turn off border relaxation after specific epoch count') 59 | parser.add_argument('--rescale', type=float, default=1.0, 60 | help='Warm Restarts new learning rate ratio compared to original lr') 61 | parser.add_argument('--repoly', type=float, default=1.5, 62 | help='Warm Restart new poly exp') 63 | 64 | parser.add_argument('--fp16', action='store_true', default=False, 65 | help='Use Nvidia Apex AMP') 66 | parser.add_argument('--local_rank', default=0, type=int, 67 | help='parameter used by apex library') 68 | 69 | parser.add_argument('--sgd', action='store_true', default=False) 70 | parser.add_argument('--adam', action='store_true', default=False) 71 | parser.add_argument('--amsgrad', action='store_true', default=False) 72 | 73 | parser.add_argument('--freeze_trunk', action='store_true', default=False) 74 | parser.add_argument('--hardnm', default=0, type=int, 75 | help='0 means no aug, 1 means hard negative mining iter 1,' + 76 | '2 means hard negative mining iter 2') 77 | 78 | parser.add_argument('--trunk', type=str, default='resnet-50', 79 | help='trunk model, can be: resnet-50 (default)') 80 | parser.add_argument('--max_epoch', type=int, default=180) 81 | parser.add_argument('--max_iter', type=int, default=30000) 82 | parser.add_argument('--max_cu_epoch', type=int, default=100000, 83 | help='Class Uniform Max Epochs') 84 | parser.add_argument('--start_epoch', type=int, default=0) 85 | parser.add_argument('--crop_nopad', action='store_true', default=False) 86 | parser.add_argument('--rrotate', type=int, 87 | default=0, help='degree of random rotation') 88 | parser.add_argument('--color_aug', type=float, 89 | default=0.0, help='level of color augmentation') 90 | parser.add_argument('--gblur', action='store_true', default=False, 91 | help='Use Gaussian Blur Augmentation') 92 | parser.add_argument('--bblur', action='store_true', default=False, 93 | help='Use Bilateral Blur Augmentation') 94 | parser.add_argument('--lr_schedule', type=str, default='poly', 95 | help='name of lr schedule: poly') 96 | parser.add_argument('--poly_exp', type=float, default=0.9, 97 | help='polynomial LR exponent') 98 | parser.add_argument('--bs_mult', type=int, default=2, 99 | help='Batch size for training per gpu') 100 | parser.add_argument('--bs_mult_val', type=int, default=1, 101 | help='Batch size for Validation per gpu') 102 | 
parser.add_argument('--crop_size', type=int, default=720, 103 | help='training crop size') 104 | parser.add_argument('--pre_size', type=int, default=None, 105 | help='resize image shorter edge to this before augmentation') 106 | parser.add_argument('--scale_min', type=float, default=0.5, 107 | help='dynamically scale training images down to this size') 108 | parser.add_argument('--scale_max', type=float, default=2.0, 109 | help='dynamically scale training images up to this size') 110 | parser.add_argument('--weight_decay', type=float, default=5e-4) 111 | parser.add_argument('--momentum', type=float, default=0.9) 112 | parser.add_argument('--snapshot', type=str, default=None) 113 | parser.add_argument('--restore_optimizer', action='store_true', default=False) 114 | 115 | parser.add_argument('--city_mode', type=str, default='train', 116 | help='experiment directory date name') 117 | parser.add_argument('--date', type=str, default='default', 118 | help='experiment directory date name') 119 | parser.add_argument('--exp', type=str, default='default', 120 | help='experiment directory name') 121 | parser.add_argument('--tb_tag', type=str, default='', 122 | help='add tag to tb dir') 123 | parser.add_argument('--ckpt', type=str, default='logs/ckpt', 124 | help='Save Checkpoint Point') 125 | parser.add_argument('--tb_path', type=str, default='logs/tb', 126 | help='Save Tensorboard Path') 127 | parser.add_argument('--syncbn', action='store_true', default=True, 128 | help='Use Synchronized BN') 129 | parser.add_argument('--dump_augmentation_images', action='store_true', default=False, 130 | help='Dump Augmentated Images for sanity check') 131 | parser.add_argument('--test_mode', action='store_true', default=False, 132 | help='Minimum testing to verify nothing failed, ' + 133 | 'Runs code for 1 epoch of train and val') 134 | parser.add_argument('-wb', '--wt_bound', type=float, default=1.0, 135 | help='Weight Scaling for the losses') 136 | parser.add_argument('--maxSkip', type=int, default=0, 137 | help='Skip x number of frames of video augmented dataset') 138 | parser.add_argument('--scf', action='store_true', default=False, 139 | help='scale correction factor') 140 | parser.add_argument('--dist_url', default='tcp://127.0.0.1:', type=str, 141 | help='url used to set up distributed training') 142 | 143 | parser.add_argument('--image_in', action='store_true', default=False, 144 | help='Input Image Instance Norm') 145 | 146 | parser.add_argument('--fs_layer', nargs='*', type=int, default=[0,0,0,0,0], 147 | help='0: None, 1: AdaIN') 148 | parser.add_argument('--lambda_cel', type=float, default=0.0, 149 | help='lambda for content extension learning loss') 150 | parser.add_argument('--lambda_sel', type=float, default=0.0, 151 | help='lambda for style extension learning loss') 152 | parser.add_argument('--lambda_scr', type=float, default=0.0, 153 | help='lambda for semantic consistency regularization loss') 154 | parser.add_argument('--cont_proj_head', type=int, default=0, 155 | help='number of output channels of content projection head') 156 | parser.add_argument('--wild_cont_dict_size', type=int, default=0, 157 | help='wild-content dictionary size') 158 | 159 | parser.add_argument('--use_fs', action='store_true', default=False, 160 | help='Automatic setting from fs_layer. 
feature stylization with wild dataset') 161 | parser.add_argument('--use_scr', action='store_true', default=False, 162 | help='Automatic setting from lambda_scr') 163 | parser.add_argument('--use_sel', action='store_true', default=False, 164 | help='Automatic setting from lambda_sel') 165 | parser.add_argument('--use_cel', action='store_true', default=False, 166 | help='Automatic setting from lambda_cel') 167 | 168 | args = parser.parse_args() 169 | 170 | random_seed = cfg.RANDOM_SEED #304 171 | torch.manual_seed(random_seed) 172 | torch.cuda.manual_seed(random_seed) 173 | torch.cuda.manual_seed_all(random_seed) # if use multi-GPU 174 | torch.backends.cudnn.deterministic = True 175 | torch.backends.cudnn.benchmark = False 176 | np.random.seed(random_seed) 177 | random.seed(random_seed) 178 | 179 | args.world_size = 1 180 | 181 | # Test Mode run two epochs with a few iterations of training and val 182 | if args.test_mode: 183 | args.max_epoch = 2 184 | 185 | if 'WORLD_SIZE' in os.environ: 186 | # args.apex = int(os.environ['WORLD_SIZE']) > 1 187 | args.world_size = int(os.environ['WORLD_SIZE']) 188 | print("Total world size: ", int(os.environ['WORLD_SIZE'])) 189 | 190 | torch.cuda.set_device(args.local_rank) 191 | print('My Rank:', args.local_rank) 192 | # Initialize distributed communication 193 | args.dist_url = args.dist_url + str(8000 + (int(time.time()%1000))//10) 194 | 195 | torch.distributed.init_process_group(backend='nccl', 196 | init_method='env://', 197 | world_size=args.world_size, 198 | rank=args.local_rank) 199 | # torch.distributed.init_process_group(backend='nccl', 200 | # init_method=args.dist_url, 201 | # world_size=args.world_size, 202 | # rank=args.local_rank) 203 | 204 | for i in range(len(args.fs_layer)): 205 | if args.fs_layer[i] == 1: 206 | args.use_fs = True 207 | 208 | if args.lambda_cel > 0: 209 | args.use_cel = True 210 | if args.lambda_sel > 0: 211 | args.use_sel = True 212 | if args.lambda_scr > 0: 213 | args.use_scr = True 214 | 215 | def main(): 216 | """ 217 | Main Function 218 | """ 219 | # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer 220 | assert_and_infer_cfg(args) 221 | writer = prep_experiment(args, parser) 222 | 223 | train_source_loader, val_loaders, train_wild_loader, train_obj, extra_val_loaders = datasets.setup_loaders(args) 224 | 225 | criterion, criterion_val = loss.get_loss(args) 226 | criterion_aux = loss.get_loss_aux(args) 227 | net = network.get_net(args, criterion, criterion_aux, args.cont_proj_head, args.wild_cont_dict_size) 228 | 229 | optim, scheduler = optimizer.get_optimizer(args, net) 230 | 231 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net) 232 | net = network.warp_network_in_dataparallel(net, args.local_rank) 233 | epoch = 0 234 | i = 0 235 | 236 | if args.snapshot: 237 | epoch, mean_iu = optimizer.load_weights(net, optim, scheduler, 238 | args.snapshot, args.restore_optimizer) 239 | if args.restore_optimizer is True: 240 | iter_per_epoch = len(train_source_loader) 241 | i = iter_per_epoch * epoch 242 | epoch = epoch + 1 243 | else: 244 | epoch = 0 245 | 246 | print("#### iteration", i) 247 | torch.cuda.empty_cache() 248 | 249 | while i < args.max_iter: 250 | # Update EPOCH CTR 251 | cfg.immutable(False) 252 | cfg.ITER = i 253 | cfg.immutable(True) 254 | 255 | i = train(train_source_loader, train_wild_loader, net, optim, epoch, writer, scheduler, args.max_iter) 256 | train_source_loader.sampler.set_epoch(epoch + 1) 257 | train_wild_loader.sampler.set_epoch(epoch + 1) 258 | 259 | if 
args.local_rank == 0: 260 | print("Saving pth file...") 261 | evaluate_eval(args, net, optim, scheduler, None, None, [], 262 | writer, epoch, "None", None, i, save_pth=True) 263 | 264 | if args.class_uniform_pct: 265 | if epoch >= args.max_cu_epoch: 266 | train_obj.build_epoch(cut=True) 267 | train_source_loader.sampler.set_num_samples() 268 | else: 269 | train_obj.build_epoch() 270 | 271 | epoch += 1 272 | 273 | # Validation after epochs 274 | if len(val_loaders) == 1: 275 | # Run validation only one time - To save models 276 | for dataset, val_loader in val_loaders.items(): 277 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i) 278 | else: 279 | if args.local_rank == 0: 280 | print("Saving pth file...") 281 | evaluate_eval(args, net, optim, scheduler, None, None, [], 282 | writer, epoch, "None", None, i, save_pth=True) 283 | 284 | for dataset, val_loader in extra_val_loaders.items(): 285 | print("Extra validating... This won't save pth file") 286 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False) 287 | 288 | 289 | def train(source_loader, wild_loader, net, optim, curr_epoch, writer, scheduler, max_iter): 290 | """ 291 | Runs the training loop per epoch 292 | source_loader: Source data loader for train 293 | wild_loader: Wild data loader for train 294 | net: thet network 295 | optim: optimizer 296 | curr_epoch: current epoch 297 | writer: tensorboard writer 298 | return: 299 | """ 300 | net.train() 301 | 302 | train_total_loss = AverageMeter() 303 | time_meter = AverageMeter() 304 | 305 | curr_iter = curr_epoch * len(source_loader) 306 | 307 | wild_loader_iter = enumerate(wild_loader) 308 | 309 | for i, data in enumerate(source_loader): 310 | if curr_iter >= max_iter: 311 | break 312 | 313 | inputs, gts, _, aux_gts = data 314 | 315 | # Multi source and AGG case 316 | if len(inputs.shape) == 5: 317 | B, D, C, H, W = inputs.shape 318 | num_domains = D 319 | inputs = inputs.transpose(0, 1) 320 | gts = gts.transpose(0, 1).squeeze(2) 321 | aux_gts = aux_gts.transpose(0, 1).squeeze(2) 322 | 323 | inputs = [input.squeeze(0) for input in torch.chunk(inputs, num_domains, 0)] 324 | gts = [gt.squeeze(0) for gt in torch.chunk(gts, num_domains, 0)] 325 | aux_gts = [aux_gt.squeeze(0) for aux_gt in torch.chunk(aux_gts, num_domains, 0)] 326 | else: 327 | B, C, H, W = inputs.shape 328 | num_domains = 1 329 | inputs = [inputs] 330 | gts = [gts] 331 | aux_gts = [aux_gts] 332 | 333 | batch_pixel_size = C * H * W 334 | 335 | for di, ingredients in enumerate(zip(inputs, gts, aux_gts)): 336 | input, gt, aux_gt = ingredients 337 | 338 | _, inputs_wild = next(wild_loader_iter) 339 | input_wild = inputs_wild[0] 340 | 341 | start_ts = time.time() 342 | 343 | img_gt = None 344 | input, gt = input.cuda(), gt.cuda() 345 | input_wild = input_wild.cuda() 346 | 347 | optim.zero_grad() 348 | outputs = net(x=input, gts=gt, aux_gts=aux_gt, x_w=input_wild, apply_fs=args.use_fs) 349 | 350 | outputs_index = 0 351 | main_loss = outputs[outputs_index] 352 | outputs_index += 1 353 | aux_loss = outputs[outputs_index] 354 | outputs_index += 1 355 | total_loss = main_loss + (0.4 * aux_loss) 356 | 357 | if args.use_fs: 358 | if args.use_cel: 359 | cel_loss = outputs[outputs_index] 360 | outputs_index += 1 361 | total_loss = total_loss + (args.lambda_cel * cel_loss) 362 | else: 363 | cel_loss = 0 364 | 365 | if args.use_sel: 366 | sel_loss_main = outputs[outputs_index] 367 | outputs_index += 1 368 | sel_loss_aux = outputs[outputs_index] 369 | 
outputs_index += 1 370 | total_loss = total_loss + args.lambda_sel * (sel_loss_main + (0.4 * sel_loss_aux)) 371 | else: 372 | sel_loss_main = 0 373 | sel_loss_aux = 0 374 | 375 | if args.use_scr: 376 | scr_loss_main = outputs[outputs_index] 377 | outputs_index += 1 378 | scr_loss_aux = outputs[outputs_index] 379 | outputs_index += 1 380 | total_loss = total_loss + args.lambda_scr * (scr_loss_main + (0.4 * scr_loss_aux)) 381 | else: 382 | scr_loss_main = 0 383 | scr_loss_aux = 0 384 | 385 | 386 | log_total_loss = total_loss.clone().detach_() 387 | torch.distributed.all_reduce(log_total_loss, torch.distributed.ReduceOp.SUM) 388 | log_total_loss = log_total_loss / args.world_size 389 | train_total_loss.update(log_total_loss.item(), batch_pixel_size) 390 | 391 | total_loss.backward() 392 | optim.step() 393 | 394 | time_meter.update(time.time() - start_ts) 395 | 396 | del total_loss, log_total_loss 397 | 398 | if args.local_rank == 0: 399 | if i % 50 == 49: 400 | msg = '[epoch {}], [iter {} / {} : {}], [loss {:0.6f}], [lr {:0.6f}], [time {:0.4f}]'.format( 401 | curr_epoch, i + 1, len(source_loader), curr_iter, train_total_loss.avg, 402 | optim.param_groups[-1]['lr'], time_meter.avg / args.train_batch_size) 403 | 404 | logging.info(msg) 405 | 406 | # Log tensorboard metrics for each iteration of the training phase 407 | writer.add_scalar('loss/train_loss', (train_total_loss.avg), curr_iter) 408 | train_total_loss.reset() 409 | time_meter.reset() 410 | 411 | curr_iter += 1 412 | scheduler.step() 413 | 414 | if i > 5 and args.test_mode: 415 | return curr_iter 416 | 417 | return curr_iter 418 | 419 | def validate(val_loader, dataset, net, criterion, optim, scheduler, curr_epoch, writer, curr_iter, save_pth=True): 420 | """ 421 | Runs the validation loop after each training epoch 422 | val_loader: Data loader for validation 423 | dataset: dataset name (str) 424 | net: thet network 425 | criterion: loss fn 426 | optimizer: optimizer 427 | curr_epoch: current epoch 428 | writer: tensorboard writer 429 | return: val_avg for step function if required 430 | """ 431 | 432 | net.eval() 433 | val_loss = AverageMeter() 434 | iou_acc = 0 435 | error_acc = 0 436 | dump_images = [] 437 | 438 | for val_idx, data in enumerate(val_loader): 439 | 440 | inputs, gt_image, img_names, _ = data 441 | 442 | if len(inputs.shape) == 5: 443 | B, D, C, H, W = inputs.shape 444 | inputs = inputs.view(-1, C, H, W) 445 | gt_image = gt_image.view(-1, 1, H, W) 446 | 447 | assert len(inputs.size()) == 4 and len(gt_image.size()) == 3 448 | assert inputs.size()[2:] == gt_image.size()[1:] 449 | 450 | batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3) 451 | inputs, gt_cuda = inputs.cuda(), gt_image.cuda() 452 | 453 | with torch.no_grad(): 454 | output = net(inputs) 455 | 456 | del inputs 457 | 458 | assert output.size()[2:] == gt_image.size()[1:] 459 | assert output.size()[1] == datasets.num_classes 460 | 461 | val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size) 462 | 463 | del gt_cuda 464 | 465 | # Collect data from different GPU to a single GPU since 466 | # encoding.parallel.criterionparallel function calculates distributed loss 467 | # functions 468 | predictions = output.data.max(1)[1].cpu() 469 | 470 | # Logging 471 | if val_idx % 20 == 0: 472 | if args.local_rank == 0: 473 | logging.info("validating: %d / %d", val_idx + 1, len(val_loader)) 474 | if val_idx > 10 and args.test_mode: 475 | break 476 | 477 | # Image Dumps 478 | if val_idx < 10: 479 | dump_images.append([gt_image, predictions, 
img_names]) 480 | 481 | iou_acc += fast_hist(predictions.numpy().flatten(), gt_image.numpy().flatten(), 482 | datasets.num_classes) 483 | del output, val_idx, data 484 | 485 | iou_acc_tensor = torch.cuda.FloatTensor(iou_acc) 486 | torch.distributed.all_reduce(iou_acc_tensor, op=torch.distributed.ReduceOp.SUM) 487 | iou_acc = iou_acc_tensor.cpu().numpy() 488 | 489 | if args.local_rank == 0: 490 | evaluate_eval(args, net, optim, scheduler, val_loss, iou_acc, dump_images, 491 | writer, curr_epoch, dataset, None, curr_iter, save_pth=save_pth) 492 | 493 | return val_loss.avg 494 | 495 | 496 | if __name__ == '__main__': 497 | main() 498 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suhyeonlee/WildNet/c83aa92b7cd591512045fc5093da25a280bde430/transforms/__init__.py -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code borrowded from: 3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 4 | # 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2017 ZijunDeng 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 
27 | 28 | """ 29 | 30 | """ 31 | Standard Transform 32 | """ 33 | 34 | import random 35 | import numpy as np 36 | from skimage.filters import gaussian 37 | from skimage.restoration import denoise_bilateral 38 | import torch 39 | from PIL import Image, ImageEnhance 40 | import torchvision.transforms as torch_tr 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | 44 | from skimage.segmentation import find_boundaries 45 | from skimage.util import random_noise 46 | 47 | try: 48 | import accimage 49 | except ImportError: 50 | accimage = None 51 | 52 | 53 | class RandomVerticalFlip(object): 54 | def __call__(self, img): 55 | if random.random() < 0.5: 56 | return img.transpose(Image.FLIP_TOP_BOTTOM) 57 | return img 58 | 59 | 60 | class DeNormalize(object): 61 | def __init__(self, mean, std): 62 | self.mean = mean 63 | self.std = std 64 | 65 | def __call__(self, tensor): 66 | for t, m, s in zip(tensor, self.mean, self.std): 67 | t.mul_(s).add_(m) 68 | return tensor 69 | 70 | 71 | class MaskToTensor(object): 72 | def __call__(self, img): 73 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 74 | 75 | class RelaxedBoundaryLossToTensor(object): 76 | """ 77 | Boundary Relaxation 78 | """ 79 | def __init__(self,ignore_id, num_classes): 80 | self.ignore_id=ignore_id 81 | self.num_classes= num_classes 82 | 83 | 84 | def new_one_hot_converter(self,a): 85 | ncols = self.num_classes+1 86 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 87 | out[np.arange(a.size),a.ravel()] = 1 88 | out.shape = a.shape + (ncols,) 89 | return out 90 | 91 | def __call__(self,img): 92 | 93 | img_arr = np.array(img) 94 | img_arr[img_arr==self.ignore_id]=self.num_classes 95 | 96 | if cfg.STRICTBORDERCLASS != None: 97 | one_hot_orig = self.new_one_hot_converter(img_arr) 98 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 99 | for cls in cfg.STRICTBORDERCLASS: 100 | mask = np.logical_or(mask,(img_arr == cls)) 101 | one_hot = 0 102 | 103 | border = cfg.BORDER_WINDOW 104 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 105 | border = border // 2 106 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 107 | 108 | for i in range(-border,border+1): 109 | for j in range(-border, border+1): 110 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 111 | one_hot += self.new_one_hot_converter(shifted) 112 | 113 | one_hot[one_hot>1] = 1 114 | 115 | if cfg.STRICTBORDERCLASS != None: 116 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 117 | 118 | one_hot = np.moveaxis(one_hot,-1,0) 119 | 120 | 121 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 122 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 123 | # print(one_hot.shape) 124 | return torch.from_numpy(one_hot).byte() 125 | 126 | class ResizeHeight(object): 127 | def __init__(self, size, interpolation=Image.BILINEAR): 128 | self.target_h = size 129 | self.interpolation = interpolation 130 | 131 | def __call__(self, img): 132 | w, h = img.size 133 | target_w = int(w / h * self.target_h) 134 | return img.resize((target_w, self.target_h), self.interpolation) 135 | 136 | 137 | class FreeScale(object): 138 | def __init__(self, size, interpolation=Image.BILINEAR): 139 | self.size = tuple(reversed(size)) # size: (h, w) 140 | self.interpolation = interpolation 141 | 142 | def __call__(self, img): 143 | return img.resize(self.size, self.interpolation) 144 | 145 | 146 | class FlipChannels(object): 147 | """ 148 | Flip around the x-axis 149 | 
""" 150 | def __call__(self, img): 151 | img = np.array(img)[:, :, ::-1] 152 | return Image.fromarray(img.astype(np.uint8)) 153 | 154 | 155 | class RandomGaussianBlur(object): 156 | """ 157 | Apply Gaussian Blur 158 | """ 159 | def __call__(self, img): 160 | sigma = 0.15 + random.random() * 1.15 161 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 162 | blurred_img *= 255 163 | return Image.fromarray(blurred_img.astype(np.uint8)) 164 | 165 | 166 | class RandomGaussianNoise(object): 167 | def __call__(self, img): 168 | noised_img = random_noise(np.array(img), mode='gaussian') 169 | noised_img *= 255 170 | return Image.fromarray(noised_img.astype(np.uint8)) 171 | 172 | 173 | class RandomBilateralBlur(object): 174 | """ 175 | Apply Bilateral Filtering 176 | 177 | """ 178 | def __call__(self, img): 179 | sigma = random.uniform(0.05,0.75) 180 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 181 | blurred_img *= 255 182 | return Image.fromarray(blurred_img.astype(np.uint8)) 183 | 184 | def _is_pil_image(img): 185 | if accimage is not None: 186 | return isinstance(img, (Image.Image, accimage.Image)) 187 | else: 188 | return isinstance(img, Image.Image) 189 | 190 | 191 | def adjust_brightness(img, brightness_factor): 192 | """Adjust brightness of an Image. 193 | 194 | Args: 195 | img (PIL Image): PIL Image to be adjusted. 196 | brightness_factor (float): How much to adjust the brightness. Can be 197 | any non negative number. 0 gives a black image, 1 gives the 198 | original image while 2 increases the brightness by a factor of 2. 199 | 200 | Returns: 201 | PIL Image: Brightness adjusted image. 202 | """ 203 | if not _is_pil_image(img): 204 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 205 | 206 | enhancer = ImageEnhance.Brightness(img) 207 | img = enhancer.enhance(brightness_factor) 208 | return img 209 | 210 | 211 | def adjust_contrast(img, contrast_factor): 212 | """Adjust contrast of an Image. 213 | 214 | Args: 215 | img (PIL Image): PIL Image to be adjusted. 216 | contrast_factor (float): How much to adjust the contrast. Can be any 217 | non negative number. 0 gives a solid gray image, 1 gives the 218 | original image while 2 increases the contrast by a factor of 2. 219 | 220 | Returns: 221 | PIL Image: Contrast adjusted image. 222 | """ 223 | if not _is_pil_image(img): 224 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 225 | 226 | enhancer = ImageEnhance.Contrast(img) 227 | img = enhancer.enhance(contrast_factor) 228 | return img 229 | 230 | 231 | def adjust_saturation(img, saturation_factor): 232 | """Adjust color saturation of an image. 233 | 234 | Args: 235 | img (PIL Image): PIL Image to be adjusted. 236 | saturation_factor (float): How much to adjust the saturation. 0 will 237 | give a black and white image, 1 will give the original image while 238 | 2 will enhance the saturation by a factor of 2. 239 | 240 | Returns: 241 | PIL Image: Saturation adjusted image. 242 | """ 243 | if not _is_pil_image(img): 244 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 245 | 246 | enhancer = ImageEnhance.Color(img) 247 | img = enhancer.enhance(saturation_factor) 248 | return img 249 | 250 | 251 | def adjust_hue(img, hue_factor): 252 | """Adjust hue of an image. 253 | 254 | The image hue is adjusted by converting the image to HSV and 255 | cyclically shifting the intensities in the hue channel (H). 256 | The image is then converted back to original image mode. 
257 | 258 | `hue_factor` is the amount of shift in H channel and must be in the 259 | interval `[-0.5, 0.5]`. 260 | 261 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 262 | 263 | Args: 264 | img (PIL Image): PIL Image to be adjusted. 265 | hue_factor (float): How much to shift the hue channel. Should be in 266 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 267 | HSV space in positive and negative direction respectively. 268 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 269 | with complementary colors while 0 gives the original image. 270 | 271 | Returns: 272 | PIL Image: Hue adjusted image. 273 | """ 274 | if not(-0.5 <= hue_factor <= 0.5): 275 | raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) 276 | 277 | if not _is_pil_image(img): 278 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 279 | input_mode = img.mode 280 | if input_mode in {'L', '1', 'I', 'F'}: 281 | return img 282 | 283 | h, s, v = img.convert('HSV').split() 284 | 285 | np_h = np.array(h, dtype=np.uint8) 286 | # uint8 addition takes care of rotation across boundaries 287 | with np.errstate(over='ignore'): 288 | np_h += np.uint8(hue_factor * 255) 289 | h = Image.fromarray(np_h, 'L') 290 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 291 | return img 292 | 293 | 294 | class ColorJitter(object): 295 | """Randomly change the brightness, contrast and saturation of an image. 296 | 297 | Args: 298 | brightness (float): How much to jitter brightness. brightness_factor 299 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 300 | contrast (float): How much to jitter contrast. contrast_factor 301 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 302 | saturation (float): How much to jitter saturation. saturation_factor 303 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 304 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 305 | [-hue, hue]. Should be >=0 and <= 0.5. 306 | """ 307 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 308 | self.brightness = brightness 309 | self.contrast = contrast 310 | self.saturation = saturation 311 | self.hue = hue 312 | 313 | @staticmethod 314 | def get_params(brightness, contrast, saturation, hue): 315 | """Get a randomized transform to be applied on image. 316 | 317 | Arguments are same as that of __init__. 318 | 319 | Returns: 320 | Transform which randomly adjusts brightness, contrast and 321 | saturation in a random order. 
322 | """ 323 | transforms = [] 324 | if brightness > 0: 325 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 326 | transforms.append( 327 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 328 | 329 | if contrast > 0: 330 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 331 | transforms.append( 332 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 333 | 334 | if saturation > 0: 335 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 336 | transforms.append( 337 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 338 | 339 | if hue > 0: 340 | hue_factor = np.random.uniform(-hue, hue) 341 | transforms.append( 342 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 343 | 344 | np.random.shuffle(transforms) 345 | transform = torch_tr.Compose(transforms) 346 | 347 | return transform 348 | 349 | def __call__(self, img): 350 | """ 351 | Args: 352 | img (PIL Image): Input image. 353 | 354 | Returns: 355 | PIL Image: Color jittered image. 356 | """ 357 | transform = self.get_params(self.brightness, self.contrast, 358 | self.saturation, self.hue) 359 | return transform(img) 360 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suhyeonlee/WildNet/c83aa92b7cd591512045fc5093da25a280bde430/utils/__init__.py -------------------------------------------------------------------------------- /utils/attr_dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | 30 | class AttrDict(dict): 31 | 32 | IMMUTABLE = '__immutable__' 33 | 34 | def __init__(self, *args, **kwargs): 35 | super(AttrDict, self).__init__(*args, **kwargs) 36 | self.__dict__[AttrDict.IMMUTABLE] = False 37 | 38 | def __getattr__(self, name): 39 | if name in self.__dict__: 40 | return self.__dict__[name] 41 | elif name in self: 42 | return self[name] 43 | else: 44 | raise AttributeError(name) 45 | 46 | def __setattr__(self, name, value): 47 | if not self.__dict__[AttrDict.IMMUTABLE]: 48 | if name in self.__dict__: 49 | self.__dict__[name] = value 50 | else: 51 | self[name] = value 52 | else: 53 | raise AttributeError( 54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 55 | format(name, value) 56 | ) 57 | 58 | def immutable(self, is_immutable): 59 | """Set immutability to is_immutable and recursively apply the setting 60 | to all nested AttrDicts. 61 | """ 62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 63 | # Recursively set immutable state 64 | for v in self.__dict__.values(): 65 | if isinstance(v, AttrDict): 66 | v.immutable(is_immutable) 67 | for v in self.values(): 68 | if isinstance(v, AttrDict): 69 | v.immutable(is_immutable) 70 | 71 | def is_immutable(self): 72 | return self.__dict__[AttrDict.IMMUTABLE] 73 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellanous Functions 3 | """ 4 | 5 | import sys 6 | import re 7 | import os 8 | import shutil 9 | import torch 10 | from datetime import datetime 11 | import logging 12 | from subprocess import call 13 | import shlex 14 | from tensorboardX import SummaryWriter 15 | import datasets 16 | import numpy as np 17 | import torchvision.transforms as standard_transforms 18 | import torchvision.utils as vutils 19 | from config import cfg 20 | import random 21 | 22 | 23 | # Create unique output dir name based on non-default command line args 24 | def make_exp_name(args, parser): 25 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:]) 26 | dict_args = vars(args) 27 | 28 | # sort so that we get a consistent directory name 29 | argnames = sorted(dict_args) 30 | ignorelist = ['date', 'exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch', 31 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt', 'coarse_boost_classes', 32 | 'crop_size', 'dist_url', 'syncbn', 'max_iter', 'color_aug', 'scale_max', 'scale_min', 'bs_mult', 33 | 'class_uniform_pct', 'class_uniform_tile'] 34 | # build experiment name with non-default args 35 | for argname in argnames: 36 | if dict_args[argname] != parser.get_default(argname): 37 | if argname in ignorelist: 38 | continue 39 | if argname == 'snapshot': 40 | arg_str = 'PT' 41 | argname = '' 42 | elif argname == 'nosave': 43 | arg_str = '' 44 | argname='' 45 | elif argname == 'freeze_trunk': 46 | argname = '' 47 | arg_str = 'ft' 48 | elif argname == 'syncbn': 49 | argname = '' 50 | arg_str = 'sbn' 51 | elif argname == 'jointwtborder': 52 | argname = '' 53 | arg_str = 'rlx_loss' 54 | elif 
isinstance(dict_args[argname], bool): 55 | arg_str = 'T' if dict_args[argname] else 'F' 56 | else: 57 | arg_str = str(dict_args[argname])[:7] 58 | if argname is not '': 59 | exp_name += '_{}_{}'.format(str(argname), arg_str) 60 | else: 61 | exp_name += '_{}'.format(arg_str) 62 | # clean special chars out exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name) 63 | return exp_name 64 | 65 | def fast_hist(label_pred, label_true, num_classes): 66 | mask = (label_true >= 0) & (label_true < num_classes) 67 | hist = np.bincount( 68 | num_classes * label_true[mask].astype(int) + 69 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 70 | return hist 71 | 72 | def per_class_iu(hist): 73 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 74 | 75 | def save_log(prefix, output_dir, date_str, rank=0): 76 | fmt = '%(asctime)s.%(msecs)03d %(message)s' 77 | date_fmt = '%m-%d %H:%M:%S' 78 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log') 79 | print("Logging :", filename) 80 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt, 81 | filename=filename, filemode='w') 82 | console = logging.StreamHandler() 83 | console.setLevel(logging.INFO) 84 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt) 85 | console.setFormatter(formatter) 86 | if rank == 0: 87 | logging.getLogger('').addHandler(console) 88 | else: 89 | fh = logging.FileHandler(filename) 90 | logging.getLogger('').addHandler(fh) 91 | 92 | 93 | 94 | def prep_experiment(args, parser): 95 | """ 96 | Make output directories, setup logging, Tensorboard, snapshot code. 97 | """ 98 | ckpt_path = args.ckpt 99 | tb_path = args.tb_path 100 | exp_name = make_exp_name(args, parser) 101 | args.exp_path = os.path.join(ckpt_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H'))) 102 | args.tb_exp_path = os.path.join(tb_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H'))) 103 | args.ngpu = torch.cuda.device_count() 104 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S')) 105 | args.best_record = {} 106 | # args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 107 | # 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 108 | args.last_record = {} 109 | if args.local_rank == 0: 110 | os.makedirs(args.exp_path, exist_ok=True) 111 | os.makedirs(args.tb_exp_path, exist_ok=True) 112 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank) 113 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write( 114 | str(args) + '\n\n') 115 | writer = SummaryWriter(log_dir=args.tb_exp_path, comment=args.tb_tag) 116 | return writer 117 | return None 118 | 119 | def evaluate_eval_for_inference(hist, dataset=None): 120 | """ 121 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 122 | large dataset) Only applies to eval/eval.py 123 | """ 124 | # axis 0: gt, axis 1: prediction 125 | acc = np.diag(hist).sum() / hist.sum() 126 | acc_cls = np.diag(hist) / hist.sum(axis=1) 127 | acc_cls = np.nanmean(acc_cls) 128 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 129 | 130 | print_evaluate_results(hist, iu, dataset=dataset) 131 | freq = hist.sum(axis=1) / hist.sum() 132 | mean_iu = np.nanmean(iu) 133 | logging.info('mean {}'.format(mean_iu)) 134 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 135 | return acc, acc_cls, mean_iu, fwavacc 136 | 137 | 138 | 139 | def evaluate_eval(args, net, optimizer, scheduler, val_loss, hist, dump_images, 
writer, epoch=0, dataset_name=None, dataset=None, curr_iter=0, optimizer_at=None, scheduler_at=None, save_pth=True): 140 | """ 141 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 142 | large dataset) Only applies to eval/eval.py 143 | """ 144 | if val_loss is not None and hist is not None: 145 | # axis 0: gt, axis 1: prediction 146 | acc = np.diag(hist).sum() / hist.sum() 147 | acc_cls = np.diag(hist) / hist.sum(axis=1) 148 | acc_cls = np.nanmean(acc_cls) 149 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 150 | 151 | print_evaluate_results(hist, iu, dataset_name=dataset_name, dataset=dataset) 152 | freq = hist.sum(axis=1) / hist.sum() 153 | mean_iu = np.nanmean(iu) 154 | logging.info('mean {}'.format(mean_iu)) 155 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 156 | else: 157 | mean_iu = 0 158 | 159 | if dataset_name not in args.last_record.keys(): 160 | args.last_record[dataset_name] = {} 161 | 162 | if save_pth: 163 | # update latest snapshot 164 | if 'mean_iu' in args.last_record[dataset_name]: 165 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format( 166 | dataset_name, args.last_record[dataset_name]['epoch'], 167 | args.last_record[dataset_name]['mean_iu']) 168 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 169 | # if dataset_name != "cityscapes": 170 | if dataset_name != "gtav": 171 | try: 172 | os.remove(last_snapshot) 173 | except OSError: 174 | pass 175 | 176 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(dataset_name, epoch, mean_iu) 177 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 178 | args.last_record[dataset_name]['mean_iu'] = mean_iu 179 | args.last_record[dataset_name]['epoch'] = epoch 180 | 181 | torch.cuda.synchronize() 182 | 183 | if optimizer_at is not None: 184 | torch.save({ 185 | 'state_dict': net.state_dict(), 186 | 'optimizer': optimizer.state_dict(), 187 | 'optimizer_at': optimizer_at.state_dict(), 188 | 'scheduler': scheduler.state_dict(), 189 | 'scheduler_at': scheduler_at.state_dict(), 190 | 'epoch': epoch, 191 | 'mean_iu': mean_iu, 192 | 'command': ' '.join(sys.argv[1:]) 193 | }, last_snapshot) 194 | else: 195 | torch.save({ 196 | 'state_dict': net.state_dict(), 197 | 'optimizer': optimizer.state_dict(), 198 | 'scheduler': scheduler.state_dict(), 199 | 'epoch': epoch, 200 | 'mean_iu': mean_iu, 201 | 'command': ' '.join(sys.argv[1:]) 202 | }, last_snapshot) 203 | 204 | if val_loss is not None and hist is not None: 205 | if dataset_name not in args.best_record.keys(): 206 | args.best_record[dataset_name] = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 207 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 208 | # update best snapshot 209 | if mean_iu > args.best_record[dataset_name]['mean_iu'] : 210 | # remove old best snapshot 211 | if args.best_record[dataset_name]['epoch'] != -1: 212 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format( 213 | dataset_name, args.best_record[dataset_name]['epoch'], 214 | args.best_record[dataset_name]['mean_iu']) 215 | 216 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 217 | assert os.path.exists(best_snapshot), \ 218 | 'cant find old snapshot {}'.format(best_snapshot) 219 | os.remove(best_snapshot) 220 | 221 | # save new best 222 | args.best_record[dataset_name]['val_loss'] = val_loss.avg 223 | args.best_record[dataset_name]['epoch'] = epoch 224 | args.best_record[dataset_name]['acc'] = acc 225 | args.best_record[dataset_name]['acc_cls'] = acc_cls 226 | 
226 | args.best_record[dataset_name]['mean_iu'] = mean_iu
227 | args.best_record[dataset_name]['fwavacc'] = fwavacc
228 | 
229 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
230 | dataset_name, args.best_record[dataset_name]['epoch'],
231 | args.best_record[dataset_name]['mean_iu'])
232 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
233 | shutil.copyfile(last_snapshot, best_snapshot)
234 | else:
235 | logging.info("Saved file to {}".format(last_snapshot))
236 | 
237 | if val_loss is not None and hist is not None:
238 | logging.info('-' * 107)
239 | fmt_str = '[epoch %d], [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
240 | '[mean_iu %.5f], [fwavacc %.5f]'
241 | logging.info(fmt_str % (epoch, dataset_name, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))
242 | if save_pth:
243 | fmt_str = 'best record: [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
244 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], '
245 | logging.info(fmt_str % (dataset_name,
246 | args.best_record[dataset_name]['val_loss'], args.best_record[dataset_name]['acc'],
247 | args.best_record[dataset_name]['acc_cls'], args.best_record[dataset_name]['mean_iu'],
248 | args.best_record[dataset_name]['fwavacc'], args.best_record[dataset_name]['epoch']))
249 | logging.info('-' * 107)
250 | 
251 | if writer:
252 | # tensorboard logging of validation phase metrics
253 | writer.add_scalar('{}/acc'.format(dataset_name), acc, curr_iter)
254 | writer.add_scalar('{}/acc_cls'.format(dataset_name), acc_cls, curr_iter)
255 | writer.add_scalar('{}/mean_iu'.format(dataset_name), mean_iu, curr_iter)
256 | writer.add_scalar('{}/val_loss'.format(dataset_name), val_loss.avg, curr_iter)
257 | 
258 | 
259 | 
260 | 
261 | 
262 | def print_evaluate_results(hist, iu, dataset_name=None, dataset=None):
263 | # fixme: Need to refactor this dict
264 | try:
265 | id2cat = dataset.id2cat
266 | except AttributeError:
267 | id2cat = {i: i for i in range(datasets.num_classes)}
268 | iu_false_positive = hist.sum(axis=1) - np.diag(hist)
269 | iu_false_negative = hist.sum(axis=0) - np.diag(hist)
270 | iu_true_positive = np.diag(hist)
271 | 
272 | logging.info('Dataset name: {}'.format(dataset_name))
273 | logging.info('IoU:')
274 | logging.info('label_id label iU Precision Recall TP FP FN')
275 | for idx, i in enumerate(iu):
276 | # Format all of the strings:
277 | idx_string = "{:2d}".format(idx)
278 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
279 | iu_string = '{:5.1f}'.format(i * 100)
280 | total_pixels = hist.sum()
281 | tp = '{:5.1f}'.format(100 * iu_true_positive[idx] / total_pixels)
282 | fp = '{:5.1f}'.format(
283 | iu_false_positive[idx] / iu_true_positive[idx])
284 | fn = '{:5.1f}'.format(iu_false_negative[idx] / iu_true_positive[idx])
285 | precision = '{:5.1f}'.format(
286 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx]))
287 | recall = '{:5.1f}'.format(
288 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx]))
289 | logging.info('{} {} {} {} {} {} {} {}'.format(
290 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn))
291 | 
292 | 
293 | 
294 | 
295 | class AverageMeter(object):
296 | 
297 | def __init__(self):
298 | self.reset()
299 | 
300 | def reset(self):
301 | self.val = 0
302 | self.avg = 0
303 | self.sum = 0
304 | self.count = 0
305 | 
306 | def update(self, val, n=1):
307 | self.val = val
308 | self.sum += val * n
309 | self.count += n
310 | self.avg = self.sum / self.count
311 | 
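The IoU bookkeeping in this file never stores per-image predictions: fast_hist folds each batch of (ground truth, prediction) labels into a single num_classes x num_classes confusion matrix, and per-class IoU is only derived from that matrix at the end, which is what keeps memory bounded for large validation sets. A minimal, self-contained sketch of the same pattern, using random dummy labels rather than the repository's loaders:

```
import numpy as np

def fast_hist(label_pred, label_true, num_classes):
    # same as utils/misc.py: bin (gt, pred) pairs into a confusion matrix
    mask = (label_true >= 0) & (label_true < num_classes)
    return np.bincount(
        num_classes * label_true[mask].astype(int) + label_pred[mask],
        minlength=num_classes ** 2).reshape(num_classes, num_classes)

num_classes = 19
hist = np.zeros((num_classes, num_classes), dtype=np.int64)

# accumulate the histogram batch by batch; memory stays O(num_classes^2)
for _ in range(4):
    gt = np.random.randint(0, num_classes, size=(2, 64, 64))
    pred = np.random.randint(0, num_classes, size=(2, 64, 64))
    hist += fast_hist(pred.flatten(), gt.flatten(), num_classes)

# per-class IoU = TP / (TP + FP + FN), exactly as per_class_iu() computes it
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print('mean IoU: {:.4f}'.format(np.nanmean(iu)))
```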
--------------------------------------------------------------------------------
/utils/my_data_parallel.py:
--------------------------------------------------------------------------------
1 | 
2 | """
3 | # Code adapted from:
4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
5 | #
6 | # BSD 3-Clause License
7 | #
8 | # Copyright (c) 2017,
9 | # All rights reserved.
10 | #
11 | # Redistribution and use in source and binary forms, with or without
12 | # modification, are permitted provided that the following conditions are met:
13 | #
14 | # * Redistributions of source code must retain the above copyright notice, this
15 | # list of conditions and the following disclaimer.
16 | #
17 | # * Redistributions in binary form must reproduce the above copyright notice,
18 | # this list of conditions and the following disclaimer in the documentation
19 | # and/or other materials provided with the distribution.
20 | #
21 | # * Neither the name of the copyright holder nor the names of its
22 | # contributors may be used to endorse or promote products derived from
23 | # this software without specific prior written permission.
24 | #
25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35 | """
36 | 
37 | 
38 | import operator
39 | import torch
40 | import warnings
41 | from torch.nn.modules import Module
42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
43 | from torch.nn.parallel.replicate import replicate
44 | from torch.nn.parallel.parallel_apply import parallel_apply
45 | 
46 | 
47 | def _check_balance(device_ids):
48 | imbalance_warn = """
49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which
50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting
51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
52 | environment variable."""
53 | 
54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
55 | 
56 | def warn_imbalance(get_prop):
57 | values = [get_prop(props) for props in dev_props]
58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
60 | if min_val / max_val < 0.75:
61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
62 | return True
63 | return False
64 | 
65 | if warn_imbalance(lambda props: props.total_memory):
66 | return
67 | if warn_imbalance(lambda props: props.multi_processor_count):
68 | return
69 | 
70 | 
71 | 
72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather=True):
73 | """
74 | Evaluates module(input) in parallel across the GPUs given in device_ids.
75 | This is the functional version of the DataParallel module.
76 | Args:
77 | module: the module to evaluate in parallel
78 | inputs: inputs to the module
79 | device_ids: GPU ids on which to replicate module
80 | output_device: GPU location of the output. Use -1 to indicate the CPU.
81 | (default: device_ids[0])
82 | Returns:
83 | a Tensor containing the result of module(input) located on
84 | output_device
85 | """
86 | if not isinstance(inputs, tuple):
87 | inputs = (inputs,)
88 | 
89 | if device_ids is None:
90 | device_ids = list(range(torch.cuda.device_count()))
91 | 
92 | if output_device is None:
93 | output_device = device_ids[0]
94 | 
95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
96 | if len(device_ids) == 1:
97 | return module(*inputs[0], **module_kwargs[0])
98 | used_device_ids = device_ids[:len(inputs)]
99 | replicas = replicate(module, used_device_ids)
100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
101 | if gather:
102 | return torch.nn.parallel.gather(outputs, output_device, dim)  # the 'gather' argument shadows the imported gather(), so call it via its module path
103 | else:
104 | return outputs
105 | 
106 | 
107 | 
108 | class MyDataParallel(Module):
109 | """
110 | Implements data parallelism at the module level.
111 | This container parallelizes the application of the given module by
112 | splitting the input across the specified devices by chunking in the batch
113 | dimension. In the forward pass, the module is replicated on each device,
114 | and each replica handles a portion of the input. During the backwards
115 | pass, gradients from each replica are summed into the original module.
116 | The batch size should be larger than the number of GPUs used.
117 | See also: :ref:`cuda-nn-dataparallel-instead`
118 | Arbitrary positional and keyword inputs are allowed to be passed into
119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim
120 | specified (default 0). Primitive types will be broadcasted, but all
121 | other types will be a shallow copy and can be corrupted if written to in
122 | the model's forward pass.
123 | .. warning::
124 | Forward and backward hooks defined on :attr:`module` and its submodules
125 | will be invoked ``len(device_ids)`` times, each with inputs located on
126 | a particular device. Particularly, the hooks are only guaranteed to be
127 | executed in correct order with respect to operations on corresponding
128 | devices. For example, it is not guaranteed that hooks set via
129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
131 | that each such hook be executed before the corresponding
132 | :meth:`~torch.nn.Module.forward` call of that device.
133 | .. warning::
134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
135 | :func:`forward`, this wrapper will return a vector of length equal to
136 | number of devices used in data parallelism, containing the result from
137 | each device.
138 | .. note::
139 | There is a subtlety in using the
140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
143 | details.
144 | Args:
145 | module: module to be parallelized
146 | device_ids: CUDA devices (default: all devices)
147 | output_device: device location of output (default: device_ids[0])
148 | Attributes:
149 | module (Module): the module to be parallelized
150 | Example::
151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
152 | >>> output = net(input_var)
153 | """
154 | 
155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
156 | 
157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True):
158 | super(MyDataParallel, self).__init__()
159 | 
160 | if not torch.cuda.is_available():
161 | self.module = module
162 | self.device_ids = []
163 | return
164 | 
165 | if device_ids is None:
166 | device_ids = list(range(torch.cuda.device_count()))
167 | if output_device is None:
168 | output_device = device_ids[0]
169 | self.dim = dim
170 | self.module = module
171 | self.device_ids = device_ids
172 | self.output_device = output_device
173 | self.gather_bool = gather
174 | 
175 | _check_balance(self.device_ids)
176 | 
177 | if len(self.device_ids) == 1:
178 | self.module.cuda(device_ids[0])
179 | 
180 | def forward(self, *inputs, **kwargs):
181 | if not self.device_ids:
182 | return self.module(*inputs, **kwargs)
183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
184 | if len(self.device_ids) == 1:
185 | return [self.module(*inputs[0], **kwargs[0])]
186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
187 | outputs = self.parallel_apply(replicas, inputs, kwargs)
188 | if self.gather_bool:
189 | return self.gather(outputs, self.output_device)
190 | else:
191 | return outputs
192 | 
193 | def replicate(self, module, device_ids):
194 | return replicate(module, device_ids)
195 | 
196 | def scatter(self, inputs, kwargs, device_ids):
197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
198 | 
199 | def parallel_apply(self, replicas, inputs, kwargs):
200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
201 | 
202 | def gather(self, outputs, output_device):
203 | return gather(outputs, output_device, dim=self.dim)
204 | 
205 | 
--------------------------------------------------------------------------------
/valid.py:
--------------------------------------------------------------------------------
1 | """
2 | validation code
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | import argparse
7 | import logging
8 | import os
9 | import torch
10 | 
11 | from config import cfg, assert_and_infer_cfg
12 | from utils.misc import AverageMeter, prep_experiment, evaluate_eval, fast_hist
13 | import datasets
14 | import loss
15 | import network
16 | import optimizer
17 | import time
18 | import torchvision.utils as vutils
19 | import torch.nn.functional as F
20 | import numpy as np
21 | import random
22 | import pdb
23 | 
24 | # Argument Parser
25 | parser = argparse.ArgumentParser(description='Semantic Segmentation')
26 | parser.add_argument('--lr', type=float, default=0.01)
27 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepR50V3PlusD',
28 | help='Network architecture.')
29 | parser.add_argument('--dataset', nargs='*', type=str, default=['gtav'],
30 | help='a list of datasets; cityscapes, mapillary, gtav, bdd100k, synthia')
31 | parser.add_argument('--image_uniform_sampling', action='store_true', default=False,
32 | help='uniformly sample images across the multiple source domains')
33 | parser.add_argument('--val_dataset', nargs='*', type=str, default=['bdd100k'],
34 | help='a list of datasets; cityscapes, mapillary, gtav, bdd100k, synthia')
35 | parser.add_argument('--wild_dataset', nargs='*', type=str, default=['imagenet'],
36 | help='a list of wild datasets; imagenet')
37 | parser.add_argument('--cv', type=int, default=0,
38 | help='cross-validation split id to use. Default # of splits set to 3 in config')
39 | parser.add_argument('--class_uniform_pct', type=float, default=0,
40 | help='What fraction of images is uniformly sampled')
41 | parser.add_argument('--class_uniform_tile', type=int, default=1024,
42 | help='tile size for class uniform sampling')
43 | parser.add_argument('--coarse_boost_classes', type=str, default=None,
44 | help='use coarse annotations to boost fine data with specific classes')
45 | 
46 | parser.add_argument('--img_wt_loss', action='store_true', default=False,
47 | help='per-image class-weighted loss')
48 | parser.add_argument('--cls_wt_loss', action='store_true', default=False,
49 | help='class-weighted loss')
50 | parser.add_argument('--batch_weighting', action='store_true', default=False,
51 | help='Batch weighting for class (use nll class weighting using batch stats)')
52 | 
53 | parser.add_argument('--jointwtborder', action='store_true', default=False,
54 | help='Enable boundary label relaxation')
55 | parser.add_argument('--strict_bdr_cls', type=str, default='',
56 | help='Enable boundary label relaxation for specific classes')
57 | parser.add_argument('--rlx_off_iter', type=int, default=-1,
58 | help='Turn off border relaxation after specific epoch count')
59 | parser.add_argument('--rescale', type=float, default=1.0,
60 | help='Warm Restarts new learning rate ratio compared to original lr')
61 | parser.add_argument('--repoly', type=float, default=1.5,
62 | help='Warm Restart new poly exp')
63 | 
64 | parser.add_argument('--fp16', action='store_true', default=False,
65 | help='Use Nvidia Apex AMP')
66 | parser.add_argument('--local_rank', default=0, type=int,
67 | help='parameter used by apex library')
68 | 
69 | parser.add_argument('--sgd', action='store_true', default=False)
70 | parser.add_argument('--adam', action='store_true', default=False)
71 | parser.add_argument('--amsgrad', action='store_true', default=False)
72 | 
73 | parser.add_argument('--freeze_trunk', action='store_true', default=False)
74 | parser.add_argument('--hardnm', default=0, type=int,
75 | help='0 means no aug, 1 means hard negative mining iter 1,' +
76 | '2 means hard negative mining iter 2')
77 | 
78 | parser.add_argument('--trunk', type=str, default='resnet-50',
79 | help='trunk model, can be: resnet-50 (default)')
80 | parser.add_argument('--max_epoch', type=int, default=180)
81 | parser.add_argument('--max_iter', type=int, default=30000)
82 | parser.add_argument('--max_cu_epoch', type=int, default=100000,
83 | help='Class Uniform Max Epochs')
84 | parser.add_argument('--start_epoch', type=int, default=0)
85 | parser.add_argument('--crop_nopad', action='store_true', default=False)
86 | parser.add_argument('--rrotate', type=int,
87 | default=0, help='degree of random rotation')
88 | parser.add_argument('--color_aug', type=float,
89 | default=0.0, help='level of color augmentation')
90 | parser.add_argument('--gblur', action='store_true', default=False,
91 | help='Use Gaussian Blur Augmentation')
92 | parser.add_argument('--bblur', action='store_true', default=False,
93 | help='Use Bilateral Blur Augmentation')
94 | parser.add_argument('--lr_schedule', type=str, default='poly',
95 | help='name of lr schedule: poly')
96 | parser.add_argument('--poly_exp', type=float, default=0.9,
97 | help='polynomial LR exponent')
98 | parser.add_argument('--bs_mult', type=int, default=2,
99 | help='Batch size for training per gpu')
100 | parser.add_argument('--bs_mult_val', type=int, default=1,
101 | help='Batch size for Validation per gpu')
102 | parser.add_argument('--crop_size', type=int, default=720,
103 | help='training crop size')
104 | parser.add_argument('--pre_size', type=int, default=None,
105 | help='resize image shorter edge to this before augmentation')
106 | parser.add_argument('--scale_min', type=float, default=0.5,
107 | help='dynamically scale training images down to this size')
108 | parser.add_argument('--scale_max', type=float, default=2.0,
109 | help='dynamically scale training images up to this size')
110 | parser.add_argument('--weight_decay', type=float, default=5e-4)
111 | parser.add_argument('--momentum', type=float, default=0.9)
112 | parser.add_argument('--snapshot', type=str, default=None)
113 | parser.add_argument('--restore_optimizer', action='store_true', default=False)
114 | 
115 | parser.add_argument('--city_mode', type=str, default='train',
116 | help='Cityscapes split to use for training')
117 | parser.add_argument('--date', type=str, default='default',
118 | help='experiment directory date name')
119 | parser.add_argument('--exp', type=str, default='default',
120 | help='experiment directory name')
121 | parser.add_argument('--tb_tag', type=str, default='',
122 | help='add tag to tb dir')
123 | parser.add_argument('--ckpt', type=str, default='logs/ckpt',
124 | help='path to save checkpoints')
125 | parser.add_argument('--tb_path', type=str, default='logs/tb',
126 | help='path to save Tensorboard logs')
127 | parser.add_argument('--syncbn', action='store_true', default=True,
128 | help='Use Synchronized BN')
129 | parser.add_argument('--dump_augmentation_images', action='store_true', default=False,
130 | help='Dump Augmented Images for sanity check')
131 | parser.add_argument('--test_mode', action='store_true', default=False,
132 | help='Minimum testing to verify nothing failed; ' +
133 | 'runs code for 1 epoch of train and val')
134 | parser.add_argument('-wb', '--wt_bound', type=float, default=1.0,
135 | help='Weight Scaling for the losses')
136 | parser.add_argument('--maxSkip', type=int, default=0,
137 | help='Skip x number of frames of video augmented dataset')
138 | parser.add_argument('--scf', action='store_true', default=False,
139 | help='scale correction factor')
140 | parser.add_argument('--dist_url', default='tcp://127.0.0.1:', type=str,
141 | help='url used to set up distributed training')
142 | 
143 | parser.add_argument('--image_in', action='store_true', default=False,
144 | help='Input Image Instance Norm')
145 | 
146 | parser.add_argument('--fs_layer', nargs='*', type=int, default=[0,0,0,0,0],
147 | help='0: None, 1: AdaIN')
148 | parser.add_argument('--lambda_cel', type=float, default=0.0,
149 | help='lambda for content extension learning loss')
150 | parser.add_argument('--lambda_sel', type=float, default=0.0,
151 | help='lambda for style extension learning loss')
152 | parser.add_argument('--lambda_scr', type=float, default=0.0,
153 | help='lambda for semantic consistency regularization loss')
154 | parser.add_argument('--cont_proj_head', type=int, default=0,
155 | help='number of output channels of content projection head')
156 | parser.add_argument('--wild_cont_dict_size', type=int, default=0,
157 | help='wild-content dictionary size')
158 | 
159 | parser.add_argument('--use_fs', action='store_true', default=False,
160 | help='Automatically set from fs_layer: feature stylization with the wild dataset')
161 | parser.add_argument('--use_scr', action='store_true', default=False,
162 | help='Automatically set from lambda_scr')
163 | parser.add_argument('--use_sel', action='store_true', default=False,
164 | help='Automatically set from lambda_sel')
165 | parser.add_argument('--use_cel', action='store_true', default=False,
166 | help='Automatically set from lambda_cel')
167 | 
168 | args = parser.parse_args()
169 | 
170 | random_seed = cfg.RANDOM_SEED #304
171 | torch.manual_seed(random_seed)
172 | torch.cuda.manual_seed(random_seed)
173 | torch.cuda.manual_seed_all(random_seed) # if using multi-GPU
174 | torch.backends.cudnn.deterministic = True
175 | torch.backends.cudnn.benchmark = False
176 | np.random.seed(random_seed)
177 | random.seed(random_seed)
178 | 
179 | args.world_size = 1
180 | 
181 | # Test mode runs two epochs with a few iterations of training and val
182 | if args.test_mode:
183 | args.max_epoch = 2
184 | 
185 | if 'WORLD_SIZE' in os.environ:
186 | # args.apex = int(os.environ['WORLD_SIZE']) > 1
187 | args.world_size = int(os.environ['WORLD_SIZE'])
188 | print("Total world size: ", int(os.environ['WORLD_SIZE']))
189 | 
190 | torch.cuda.set_device(args.local_rank)
191 | print('My Rank:', args.local_rank)
192 | # Initialize distributed communication
193 | args.dist_url = args.dist_url + str(8000 + (int(time.time()%1000))//10)
194 | 
195 | torch.distributed.init_process_group(backend='nccl',
196 | init_method='env://',
197 | world_size=args.world_size,
198 | rank=args.local_rank)
199 | # torch.distributed.init_process_group(backend='nccl',
200 | # init_method=args.dist_url,
201 | # world_size=args.world_size,
202 | # rank=args.local_rank)
203 | 
204 | def main():
205 | """
206 | Main Function
207 | """
208 | # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
209 | assert_and_infer_cfg(args)
210 | prep_experiment(args, parser)
211 | writer = None
212 | 
213 | _, val_loaders, _, _, extra_val_loaders = datasets.setup_loaders(args)
214 | 
215 | criterion, criterion_val = loss.get_loss(args)
216 | criterion_aux = loss.get_loss_aux(args)
217 | net = network.get_net(args, criterion, criterion_aux, args.cont_proj_head, args.wild_cont_dict_size)
218 | 
219 | optim, scheduler = optimizer.get_optimizer(args, net)
220 | 
221 | net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
222 | net = network.warp_network_in_dataparallel(net, args.local_rank)
223 | epoch = 0
224 | i = 0
225 | 
226 | if args.snapshot:
227 | epoch, mean_iu = optimizer.load_weights(net, optim, scheduler,
228 | args.snapshot, args.restore_optimizer)
229 | 
230 | print("#### iteration", i)
231 | torch.cuda.empty_cache()
232 | 
233 | for dataset, val_loader in val_loaders.items():
234 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
235 | 
236 | for dataset, val_loader in extra_val_loaders.items():
237 | print("Extra validating... This won't save pth file")
238 | validate(val_loader, dataset, net, criterion_val, optim, scheduler, epoch, writer, i, save_pth=False)
239 | 
240 | def validate(val_loader, dataset, net, criterion, optim, scheduler, curr_epoch, writer, curr_iter, save_pth=True):
241 | """
242 | Runs the validation loop after each training epoch
243 | val_loader: Data loader for validation
244 | dataset: dataset name (str)
245 | net: the network
246 | criterion: loss fn
247 | optim: optimizer
248 | curr_epoch: current epoch
249 | writer: tensorboard writer
250 | return: val_avg for step function if required
251 | """
252 | 
253 | net.eval()
254 | val_loss = AverageMeter()
255 | iou_acc = 0
256 | error_acc = 0
257 | dump_images = []
258 | 
259 | for val_idx, data in enumerate(val_loader):
260 | 
261 | inputs, gt_image, img_names, _ = data
262 | 
263 | if len(inputs.shape) == 5:
264 | B, D, C, H, W = inputs.shape
265 | inputs = inputs.view(-1, C, H, W)
266 | gt_image = gt_image.view(-1, 1, H, W)
267 | 
268 | assert len(inputs.size()) == 4 and len(gt_image.size()) == 3
269 | assert inputs.size()[2:] == gt_image.size()[1:]
270 | 
271 | batch_pixel_size = inputs.size(0) * inputs.size(2) * inputs.size(3)
272 | inputs, gt_cuda = inputs.cuda(), gt_image.cuda()
273 | 
274 | with torch.no_grad():
275 | output = net(inputs)
276 | 
277 | del inputs
278 | 
279 | assert output.size()[2:] == gt_image.size()[1:]
280 | assert output.size()[1] == datasets.num_classes
281 | 
282 | val_loss.update(criterion(output, gt_cuda).item(), batch_pixel_size)
283 | 
284 | del gt_cuda
285 | 
286 | # Collect data from different GPUs to a single GPU since
287 | # encoding.parallel.criterionparallel function calculates distributed loss
288 | # functions
289 | predictions = output.data.max(1)[1].cpu()
290 | 
291 | # Logging
292 | if val_idx % 20 == 0:
293 | if args.local_rank == 0:
294 | logging.info("validating: %d / %d", val_idx + 1, len(val_loader))
295 | if val_idx > 10 and args.test_mode:
296 | break
297 | 
298 | # Image Dumps
299 | if val_idx < 10:
300 | dump_images.append([gt_image, predictions, img_names])
301 | 
302 | iou_acc += fast_hist(predictions.numpy().flatten(), gt_image.numpy().flatten(),
303 | datasets.num_classes)
304 | del output, val_idx, data
305 | 
306 | iou_acc_tensor = torch.cuda.FloatTensor(iou_acc)
307 | torch.distributed.all_reduce(iou_acc_tensor, op=torch.distributed.ReduceOp.SUM)
308 | iou_acc = iou_acc_tensor.cpu().numpy()
309 | 
310 | if args.local_rank == 0:
311 | evaluate_eval(args, net, optim, scheduler, val_loss, iou_acc, dump_images,
312 | writer, curr_epoch, dataset, None, curr_iter, save_pth=save_pth)
313 | 
314 | return val_loss.avg
315 | 
316 | 
317 | if __name__ == '__main__':
318 | main()
319 | 
--------------------------------------------------------------------------------
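For reference, the cross-GPU aggregation in validate() boils down to a single all_reduce over the accumulated confusion matrix: each rank builds its own iou_acc, the matrices are summed across ranks, and rank 0 derives the global mean IoU from the reduced histogram. A minimal sketch of that reduction, using a single-process gloo group so it runs on CPU (valid.py itself initializes the NCCL backend with one process per GPU):

```
import os
import numpy as np
import torch
import torch.distributed as dist

# single-process group purely for illustration; valid.py uses backend='nccl'
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

num_classes = 19
# stand-in for the per-rank confusion matrix accumulated with fast_hist()
local_hist = np.random.randint(0, 100, size=(num_classes, num_classes)).astype(np.float64)

# same pattern as validate(): sum the local histograms over all ranks,
# then compute mean IoU from the reduced matrix
hist_tensor = torch.from_numpy(local_hist)
dist.all_reduce(hist_tensor, op=dist.ReduceOp.SUM)
hist = hist_tensor.numpy()

iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print('global mean IoU: {:.4f}'.format(np.nanmean(iu)))

dist.destroy_process_group()
```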