├── .runx ├── Dockerfile ├── LICENSE ├── PREPARE_DATASETS.md ├── README.md ├── config.py ├── datasets ├── __init__.py ├── base_loader.py ├── cityscapes.py ├── cityscapes_labels.py ├── mapillary.py ├── nullloader.py ├── randaugment.py ├── sampler.py ├── uniform.py └── utils.py ├── imgs ├── composited_sf.png └── test_imgs │ ├── nyc.jpg │ └── sf.jpg ├── loss ├── optimizer.py ├── radam.py ├── rmi.py ├── rmi_utils.py └── utils.py ├── network ├── Resnet.py ├── SEresnext.py ├── __init__.py ├── attnscale.py ├── basic.py ├── bn_helper.py ├── deeper.py ├── deepv3.py ├── hrnetv2.py ├── mscale.py ├── mscale2.py ├── mynn.py ├── ocr_utils.py ├── ocrnet.py ├── utils.py ├── wider_resnet.py └── xception.py ├── requirements.txt ├── scripts ├── dump_cityscapes.yml ├── dump_folder.yml ├── eval_cityscapes.yml ├── eval_mapillary.yml ├── train_cityscapes.yml ├── train_cityscapes_deepv3.yml ├── train_cityscapes_sota.yml └── train_mapillary.yml ├── train.py ├── transforms ├── __init__.py ├── joint_transforms.py └── transforms.py └── utils ├── __init__.py ├── attr_dict.py ├── f_boundary.py ├── misc.py ├── my_data_parallel.py ├── results_page.py └── trnval_utils.py /.runx: -------------------------------------------------------------------------------- 1 | # The log root directory: 2 | LOGROOT: logs 3 | 4 | # Code to ignore for copy 5 | CODE_IGNORE_PATTERNS: '*.pth,*.pyc,docs,test,.git,*__pycache__,logs' 6 | 7 | # Farm-specific items: 8 | FARM: myfarm 9 | 10 | myfarm: 11 | SUBMIT_CMD: 'submit_job' 12 | RESOURCES: 13 | image: some_docker_image 14 | gpu: 8 15 | mem: 400 16 | 17 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:19.10-py3 2 | 3 | RUN pip install --no-cache-dir runx==0.0.6 4 | RUN pip install --no-cache-dir numpy 5 | RUN pip install --no-cache-dir sklearn 6 | RUN pip install --no-cache-dir h5py 7 | RUN pip install --no-cache-dir jupyter 8 | RUN pip install --no-cache-dir scikit-image 9 | RUN pip install --no-cache-dir pillow 10 | RUN pip install --no-cache-dir piexif 11 | RUN pip install --no-cache-dir cffi 12 | RUN pip install --no-cache-dir tqdm 13 | RUN pip install --no-cache-dir dominate 14 | RUN pip install --no-cache-dir opencv-python 15 | RUN pip install --no-cache-dir nose 16 | RUN pip install --no-cache-dir ninja 17 | 18 | RUN apt-get update && apt-get install libgtk2.0-dev -y && rm -rf /var/lib/apt/lists/* 19 | 20 | # Install Apex 21 | RUN cd /home/ && git clone https://github.com/NVIDIA/apex.git apex && cd apex && python setup.py install --cuda_ext --cpp_ext 22 | WORKDIR /home/ 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Nvidia Corporation 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions are met: 5 | 6 | 1. Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | 3. 
Neither the name of the copyright holder nor the names of its contributors 14 | may be used to endorse or promote products derived from this software 15 | without specific prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 18 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 19 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 20 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 21 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 22 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 23 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 24 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 25 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 26 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 27 | POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /PREPARE_DATASETS.md: -------------------------------------------------------------------------------- 1 | ## Mapillary Vistas Dataset 2 | 3 | First of all, please request the research edition dataset from [here](https://www.mapillary.com/dataset/vistas/). The downloaded file is named as `mapillary-vistas-dataset_public_v1.1.zip`. 4 | 5 | Then simply unzip the file by 6 | ```shell 7 | unzip mapillary-vistas-dataset_public_v1.1.zip 8 | ``` 9 | 10 | The folder structure will look like: 11 | ``` 12 | Mapillary 13 | ├── config.json 14 | ├── demo.py 15 | ├── Mapillary Vistas Research Edition License.pdf 16 | ├── README 17 | ├── requirements.txt 18 | ├── training 19 | │ ├── images 20 | │ ├── instances 21 | │ ├── labels 22 | │ ├── panoptic 23 | ├── validation 24 | │ ├── images 25 | │ ├── instances 26 | │ ├── labels 27 | │ ├── panoptic 28 | ├── testing 29 | │ ├── images 30 | │ ├── instances 31 | │ ├── labels 32 | │ ├── panoptic 33 | ``` 34 | Note that, the `instances`, `labels` and `panoptic` folders inside `testing` are empty. 35 | 36 | Suppose you store your dataset at `~/username/data/Mapillary`, please update the dataset path in `config.py`, 37 | ``` 38 | __C.DATASET.MAPILLARY_DIR = '~/username/data/Mapillary' 39 | ``` 40 | 41 | ## Cityscapes Dataset 42 | 43 | ### Download Dataset 44 | First of all, please request the dataset from [here](https://www.cityscapes-dataset.com/). You need multiple files. 45 | ``` 46 | - leftImg8bit_trainvaltest.zip 47 | - gtFine_trainvaltest.zip 48 | - leftImg8bit_trainextra.zip 49 | - gtCoarse.zip 50 | - refinement_final_v0.zip [link] (https://drive.google.com/file/d/1DtPo-WP-hjaOwsbj6ZxTtOo_7R_4TKRG/) # This file is only needed for autolabelled training for recreating SOTA 51 | ``` 52 | 53 | If you prefer to use command lines (e.g., `wget`) to download the dataset, 54 | ``` 55 | # First step, obtain your login credentials. 56 | Please register an account at https://www.cityscapes-dataset.com/login/. 57 | 58 | # Second step, log into cityscapes system, suppose you already have a USERNAME and a PASSWORD. 59 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=USERNAME&password=PASSWORD&submit=Login' https://www.cityscapes-dataset.com/login/ 60 | 61 | # Third step, download the zip files you need. 
62 | wget -c -t 0 --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 63 | 64 | # The corresponding packageID is listed below, 65 | 1 -> gtFine_trainvaltest.zip (241MB) md5sum: 4237c19de34c8a376e9ba46b495d6f66 66 | 2 -> gtCoarse.zip (1.3GB) md5sum: 1c7b95c84b1d36cc59a9194d8e5b989f 67 | 3 -> leftImg8bit_trainvaltest.zip (11GB) md5sum: 0a6e97e94b616a514066c9e2adb0c97f 68 | 4 -> leftImg8bit_trainextra.zip (44GB) md5sum: 9167a331a158ce3e8989e166c95d56d4 69 | 5 -> refinement_final_v0.zip (5GB) md5sum: 82aa6698ef7358457894c7cc924534fb 70 | ``` 71 | 72 | ### Prepare Folder Structure 73 | 74 | Now unzip those files, the desired folder structure will look like, 75 | ``` 76 | Cityscapes 77 | ├── leftImg8bit_trainvaltest 78 | │ ├── leftImg8bit 79 | │ │ ├── train 80 | │ │ │ ├── aachen 81 | │ │ │ │ ├── aachen_000000_000019_leftImg8bit.png 82 | │ │ │ │ ├── aachen_000001_000019_leftImg8bit.png 83 | │ │ │ │ ├── ... 84 | │ │ │ ├── bochum 85 | │ │ │ ├── ... 86 | │ │ ├── val 87 | │ │ ├── test 88 | ├── gtFine_trainvaltest 89 | │ ├── gtFine 90 | │ │ ├── train 91 | │ │ │ ├── aachen 92 | │ │ │ │ ├── aachen_000000_000019_gtFine_color.png 93 | │ │ │ │ ├── aachen_000000_000019_gtFine_instanceIds.png 94 | │ │ │ │ ├── aachen_000000_000019_gtFine_labelIds.png 95 | │ │ │ │ ├── aachen_000000_000019_gtFine_polygons.json 96 | │ │ │ │ ├── ... 97 | │ │ │ ├── bochum 98 | │ │ │ ├── ... 99 | │ │ ├── val 100 | │ │ ├── test 101 | ├── leftImg8bit_trainextra 102 | │ ├── leftImg8bit 103 | │ │ ├── train_extra 104 | │ │ │ ├── augsburg 105 | │ │ │ ├── bad-honnef 106 | │ │ │ ├── ... 107 | ├── gtCoarse 108 | │ ├── gtCoarse 109 | │ │ ├── train 110 | │ │ ├── train_extra 111 | │ │ ├── val 112 | ├── autolabelled 113 | │ ├── train_extra 114 | │ │ ├── augsburg 115 | │ │ ├── bad-honnef 116 | │ │ ├── ... 117 | ``` 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### [Paper](https://arxiv.org/abs/2005.10821) | [YouTube](https://youtu.be/odAGA7pFBGA) | [Cityscapes Score](https://www.cityscapes-dataset.com/method-details/?submissionID=7836)
2 | 3 | PyTorch implementation of our paper [Hierarchical Multi-Scale Attention for Semantic Segmentation](https://arxiv.org/abs/2005.10821).
4 | 5 | Please refer to the `sdcnet` branch if you are looking for the code corresponding to [Improving Semantic Segmentation via Video Prediction and Label Relaxation](https://nv-adlr.github.io/publication/2018-Segmentation). 6 | 7 | ## Installation 8 | 9 | * The code is tested with pytorch 1.3 and python 3.6 10 | * You can use ./Dockerfile to build an image. 11 | 12 | 13 | ## Download Weights 14 | 15 | * Create a directory where you can keep large files. Ideally, not in this directory. 16 | ```bash 17 | > mkdir 18 | ``` 19 | 20 | * Update `__C.ASSETS_PATH` in `config.py` to point at that directory 21 | 22 | __C.ASSETS_PATH= 23 | 24 | * Download pretrained weights from [google drive](https://drive.google.com/open?id=1fs-uLzXvmsISbS635eRZCc5uzQdBIZ_U) and put into `/seg_weights` 25 | 26 | ## Download/Prepare Data 27 | 28 | If using Cityscapes, download Cityscapes data, then update `config.py` to set the path: 29 | ```python 30 | __C.DATASET.CITYSCAPES_DIR= 31 | ``` 32 | 33 | * Download Autolabelled-Data from [google drive](https://drive.google.com/file/d/1DtPo-WP-hjaOwsbj6ZxTtOo_7R_4TKRG/view?usp=sharing) 34 | 35 | If using Cityscapes Autolabelled Images, download Cityscapes data, then update `config.py` to set the path: 36 | ```python 37 | __C.DATASET.CITYSCAPES_CUSTOMCOARSE= 38 | ``` 39 | 40 | If using Mapillary, download Mapillary data, then update `config.py` to set the path: 41 | ```python 42 | __C.DATASET.MAPILLARY_DIR= 43 | ``` 44 | 45 | 46 | ## Running the code 47 | 48 | The instructions below make use of a tool called `runx`, which we find useful to help automate experiment running and summarization. For more information about this tool, please see [runx](https://github.com/NVIDIA/runx). 49 | In general, you can either use the runx-style commandlines shown below. Or you can call `python train.py ` directly if you like. 50 | 51 | 52 | ### Run inference on Cityscapes 53 | 54 | Dry run: 55 | ```bash 56 | > python -m runx.runx scripts/eval_cityscapes.yml -i -n 57 | ``` 58 | This will just print out the command but not run. It's a good way to inspect the commandline. 59 | 60 | Real run: 61 | ```bash 62 | > python -m runx.runx scripts/eval_cityscapes.yml -i 63 | ``` 64 | 65 | The reported IOU should be 86.92. This evaluates with scales of 0.5, 1.0. and 2.0. You will find evaluation results in ./logs/eval_cityscapes/... 66 | 67 | ### Run inference on Mapillary 68 | 69 | ```bash 70 | > python -m runx.runx scripts/eval_mapillary.yml -i 71 | ``` 72 | 73 | The reported IOU should be 61.05. Note that this must be run on a 32GB node and the use of 'O3' mode for amp is critical in order to avoid GPU out of memory. Results in logs/eval_mapillary/... 74 | 75 | ### Dump images for Cityscapes 76 | 77 | ```bash 78 | > python -m runx.runx scripts/dump_cityscapes.yml -i 79 | ``` 80 | 81 | This will dump network output and composited images from running evaluation with the Cityscapes validation set. 
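For reference, the `config.py` edits described in the Download Weights and Download/Prepare Data sections above might end up looking like the minimal sketch below. The variable names come from this README; the right-hand-side paths are placeholders for wherever you keep the weights and datasets, not values shipped with the repo.

```python
# config.py -- illustrative sketch only; substitute your own locations.
# Only the right-hand-side paths below are hypothetical.
__C.ASSETS_PATH = '/home/user/seg_assets'                 # contains seg_weights/
__C.DATASET.CITYSCAPES_DIR = '/home/user/data/Cityscapes'
__C.DATASET.CITYSCAPES_CUSTOMCOARSE = '/home/user/data/Cityscapes/autolabelled'
__C.DATASET.MAPILLARY_DIR = '/home/user/data/Mapillary'
```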
82 | 83 | ### Run inference and dump images on a folder of images 84 | 85 | ```bash 86 | > python -m runx.runx scripts/dump_folder.yml -i 87 | ``` 88 | 89 | You should end up seeing images that look like the following: 90 | 91 | ![alt text](imgs/composited_sf.png "example inference, composited") 92 | 93 | ## Train a model 94 | 95 | Train cityscapes, using HRNet + OCR + multi-scale attention with fine data and mapillary-pretrained model 96 | ```bash 97 | > python -m runx.runx scripts/train_cityscapes.yml -i 98 | ``` 99 | 100 | The first time this command is run, a centroid file has to be built for the dataset. It'll take about 10 minutes. The centroid file is used during training to know how to sample from the dataset in a class-uniform way. 101 | 102 | This training run should deliver a model that achieves 84.7 IOU. 103 | 104 | ## Train SOTA default train-val split 105 | ```bash 106 | > python -m runx.runx scripts/train_cityscapes_sota.yml -i 107 | ``` 108 | Again, use `-n` to do a dry run and just print out the command. This should result in a model with 86.8 IOU. If you run out of memory, try to lower the crop size or turn off rmi_loss. 109 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | Dataset setup and loaders 31 | """ 32 | 33 | import importlib 34 | import torchvision.transforms as standard_transforms 35 | 36 | import transforms.joint_transforms as joint_transforms 37 | import transforms.transforms as extended_transforms 38 | from torch.utils.data import DataLoader 39 | 40 | from config import cfg, update_dataset_cfg, update_dataset_inst 41 | from runx.logx import logx 42 | from datasets.randaugment import RandAugment 43 | 44 | 45 | def setup_loaders(args): 46 | """ 47 | Setup Data Loaders[Currently supports Cityscapes, Mapillary and ADE20kin] 48 | input: argument passed by the user 49 | return: training data loader, validation data loader loader, train_set 50 | """ 51 | 52 | # TODO add error checking to make sure class exists 53 | logx.msg(f'dataset = {args.dataset}') 54 | 55 | mod = importlib.import_module('datasets.{}'.format(args.dataset)) 56 | dataset_cls = getattr(mod, 'Loader') 57 | 58 | logx.msg(f'ignore_label = {dataset_cls.ignore_label}') 59 | 60 | update_dataset_cfg(num_classes=dataset_cls.num_classes, 61 | ignore_label=dataset_cls.ignore_label) 62 | 63 | ###################################################################### 64 | # Define transformations, augmentations 65 | ###################################################################### 66 | 67 | # Joint transformations that must happen on both image and mask 68 | if ',' in args.crop_size: 69 | args.crop_size = [int(x) for x in args.crop_size.split(',')] 70 | else: 71 | args.crop_size = int(args.crop_size) 72 | train_joint_transform_list = [ 73 | # TODO FIXME: move these hparams into cfg 74 | joint_transforms.RandomSizeAndCrop(args.crop_size, 75 | False, 76 | scale_min=args.scale_min, 77 | scale_max=args.scale_max, 78 | full_size=args.full_crop_training, 79 | pre_size=args.pre_size)] 80 | train_joint_transform_list.append( 81 | joint_transforms.RandomHorizontallyFlip()) 82 | 83 | if args.rand_augment is not None: 84 | N, M = [int(i) for i in args.rand_augment.split(',')] 85 | assert isinstance(N, int) and isinstance(M, int), \ 86 | f'Either N {N} or M {M} not integer' 87 | train_joint_transform_list.append(RandAugment(N, M)) 88 | 89 | ###################################################################### 90 | # Image only augmentations 91 | ###################################################################### 92 | train_input_transform = [] 93 | 94 | if args.color_aug: 95 | train_input_transform += [extended_transforms.ColorJitter( 96 | brightness=args.color_aug, 97 | contrast=args.color_aug, 98 | saturation=args.color_aug, 99 | hue=args.color_aug)] 100 | if args.bblur: 101 | train_input_transform += [extended_transforms.RandomBilateralBlur()] 102 | elif args.gblur: 103 | train_input_transform += [extended_transforms.RandomGaussianBlur()] 104 | 105 | mean_std = (cfg.DATASET.MEAN, cfg.DATASET.STD) 106 | train_input_transform += [standard_transforms.ToTensor(), 107 | standard_transforms.Normalize(*mean_std)] 108 | train_input_transform = standard_transforms.Compose(train_input_transform) 109 | 110 | val_input_transform = standard_transforms.Compose([ 111 | standard_transforms.ToTensor(), 112 | standard_transforms.Normalize(*mean_std) 113 | ]) 114 | 115 | target_transform = extended_transforms.MaskToTensor() 116 | 117 | if args.jointwtborder: 118 | target_train_transform = \ 119 | extended_transforms.RelaxedBoundaryLossToTensor() 120 | else: 121 | target_train_transform = extended_transforms.MaskToTensor() 122 | 123 | if args.eval == 'folder': 124 | val_joint_transform_list = None 125 | elif 
'mapillary' in args.dataset: 126 | if args.pre_size is None: 127 | eval_size = 2177 128 | else: 129 | eval_size = args.pre_size 130 | if cfg.DATASET.MAPILLARY_CROP_VAL: 131 | val_joint_transform_list = [ 132 | joint_transforms.ResizeHeight(eval_size), 133 | joint_transforms.CenterCropPad(eval_size)] 134 | else: 135 | val_joint_transform_list = [ 136 | joint_transforms.Scale(eval_size)] 137 | else: 138 | val_joint_transform_list = None 139 | 140 | if args.eval is None or args.eval == 'val': 141 | val_name = 'val' 142 | elif args.eval == 'trn': 143 | val_name = 'train' 144 | elif args.eval == 'folder': 145 | val_name = 'folder' 146 | else: 147 | raise 'unknown eval mode {}'.format(args.eval) 148 | 149 | ###################################################################### 150 | # Create loaders 151 | ###################################################################### 152 | val_set = dataset_cls( 153 | mode=val_name, 154 | joint_transform_list=val_joint_transform_list, 155 | img_transform=val_input_transform, 156 | label_transform=target_transform, 157 | eval_folder=args.eval_folder) 158 | 159 | update_dataset_inst(dataset_inst=val_set) 160 | 161 | if args.apex: 162 | from datasets.sampler import DistributedSampler 163 | val_sampler = DistributedSampler(val_set, pad=False, permutation=False, 164 | consecutive_sample=False) 165 | else: 166 | val_sampler = None 167 | 168 | val_loader = DataLoader(val_set, batch_size=args.bs_val, 169 | num_workers=args.num_workers // 2, 170 | shuffle=False, drop_last=False, 171 | sampler=val_sampler) 172 | 173 | if args.eval is not None: 174 | # Don't create train dataloader if eval 175 | train_set = None 176 | train_loader = None 177 | else: 178 | train_set = dataset_cls( 179 | mode='train', 180 | joint_transform_list=train_joint_transform_list, 181 | img_transform=train_input_transform, 182 | label_transform=target_train_transform) 183 | 184 | if args.apex: 185 | from datasets.sampler import DistributedSampler 186 | train_sampler = DistributedSampler(train_set, pad=True, 187 | permutation=True, 188 | consecutive_sample=False) 189 | train_batch_size = args.bs_trn 190 | else: 191 | train_sampler = None 192 | train_batch_size = args.bs_trn * args.ngpu 193 | 194 | train_loader = DataLoader(train_set, batch_size=train_batch_size, 195 | num_workers=args.num_workers, 196 | shuffle=(train_sampler is None), 197 | drop_last=True, sampler=train_sampler) 198 | 199 | return train_loader, val_loader, train_set 200 | -------------------------------------------------------------------------------- /datasets/base_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | Generic dataloader base class 31 | """ 32 | import os 33 | import glob 34 | import numpy as np 35 | import torch 36 | 37 | from PIL import Image 38 | from torch.utils import data 39 | from config import cfg 40 | from datasets import uniform 41 | from runx.logx import logx 42 | from utils.misc import tensor_to_pil 43 | 44 | 45 | class BaseLoader(data.Dataset): 46 | def __init__(self, quality, mode, joint_transform_list, img_transform, 47 | label_transform): 48 | 49 | super(BaseLoader, self).__init__() 50 | self.quality = quality 51 | self.mode = mode 52 | self.joint_transform_list = joint_transform_list 53 | self.img_transform = img_transform 54 | self.label_transform = label_transform 55 | self.train = mode == 'train' 56 | self.id_to_trainid = {} 57 | self.centroids = None 58 | self.all_imgs = None 59 | self.drop_mask = np.zeros((1024, 2048)) 60 | self.drop_mask[15:840, 14:2030] = 1.0 61 | 62 | def build_epoch(self): 63 | """ 64 | For class uniform sampling ... every epoch, we want to recompute 65 | which tiles from which images we want to sample from, so that the 66 | sampling is uniformly random. 67 | """ 68 | self.imgs = uniform.build_epoch(self.all_imgs, 69 | self.centroids, 70 | self.num_classes, 71 | self.train) 72 | 73 | @staticmethod 74 | def find_images(img_root, mask_root, img_ext, mask_ext): 75 | """ 76 | Find image and segmentation mask files and return a list of 77 | tuples of them. 
78 | """ 79 | img_path = '{}/*.{}'.format(img_root, img_ext) 80 | imgs = glob.glob(img_path) 81 | items = [] 82 | for full_img_fn in imgs: 83 | img_dir, img_fn = os.path.split(full_img_fn) 84 | img_name, _ = os.path.splitext(img_fn) 85 | full_mask_fn = '{}.{}'.format(img_name, mask_ext) 86 | full_mask_fn = os.path.join(mask_root, full_mask_fn) 87 | assert os.path.exists(full_mask_fn) 88 | items.append((full_img_fn, full_mask_fn)) 89 | return items 90 | 91 | def disable_coarse(self): 92 | pass 93 | 94 | def colorize_mask(self, image_array): 95 | """ 96 | Colorize the segmentation mask 97 | """ 98 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 99 | new_mask.putpalette(self.color_mapping) 100 | return new_mask 101 | 102 | def dump_images(self, img_name, mask, centroid, class_id, img): 103 | img = tensor_to_pil(img) 104 | outdir = 'new_dump_imgs_{}'.format(self.mode) 105 | os.makedirs(outdir, exist_ok=True) 106 | if centroid is not None: 107 | dump_img_name = '{}_{}'.format(self.trainid_to_name[class_id], 108 | img_name) 109 | else: 110 | dump_img_name = img_name 111 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 112 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 113 | out_raw_fn = os.path.join(outdir, dump_img_name + '_mask_raw.png') 114 | mask_img = self.colorize_mask(np.array(mask)) 115 | raw_img = Image.fromarray(np.array(mask)) 116 | img.save(out_img_fn) 117 | mask_img.save(out_msk_fn) 118 | raw_img.save(out_raw_fn) 119 | 120 | def do_transforms(self, img, mask, centroid, img_name, class_id): 121 | """ 122 | Do transformations to image and mask 123 | 124 | :returns: image, mask 125 | """ 126 | scale_float = 1.0 127 | 128 | if self.joint_transform_list is not None: 129 | for idx, xform in enumerate(self.joint_transform_list): 130 | if idx == 0 and centroid is not None: 131 | # HACK! 
Assume the first transform accepts a centroid 132 | outputs = xform(img, mask, centroid) 133 | else: 134 | outputs = xform(img, mask) 135 | 136 | if len(outputs) == 3: 137 | img, mask, scale_float = outputs 138 | else: 139 | img, mask = outputs 140 | 141 | if self.img_transform is not None: 142 | img = self.img_transform(img) 143 | 144 | if cfg.DATASET.DUMP_IMAGES: 145 | self.dump_images(img_name, mask, centroid, class_id, img) 146 | 147 | if self.label_transform is not None: 148 | mask = self.label_transform(mask) 149 | 150 | return img, mask, scale_float 151 | 152 | def read_images(self, img_path, mask_path, mask_out=False): 153 | img = Image.open(img_path).convert('RGB') 154 | if mask_path is None or mask_path == '': 155 | w, h = img.size 156 | mask = np.zeros((h, w)) 157 | else: 158 | mask = Image.open(mask_path) 159 | 160 | drop_out_mask = None 161 | # This code is specific to cityscapes 162 | if(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE in mask_path): 163 | 164 | gtCoarse_mask_path = mask_path.replace(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE, os.path.join(cfg.DATASET.CITYSCAPES_DIR, 'gtCoarse/gtCoarse') ) 165 | gtCoarse_mask_path = gtCoarse_mask_path.replace('leftImg8bit','gtCoarse_labelIds') 166 | gtCoarse=np.array(Image.open(gtCoarse_mask_path)) 167 | 168 | 169 | 170 | img_name = os.path.splitext(os.path.basename(img_path))[0] 171 | 172 | mask = np.array(mask) 173 | if (mask_out): 174 | mask = self.drop_mask * mask 175 | 176 | mask = mask.copy() 177 | for k, v in self.id_to_trainid.items(): 178 | binary_mask = (mask == k) #+ (gtCoarse == k) 179 | if ('refinement' in mask_path) and cfg.DROPOUT_COARSE_BOOST_CLASSES != None and v in cfg.DROPOUT_COARSE_BOOST_CLASSES and binary_mask.sum() > 0 and 'vidseq' not in mask_path: 180 | binary_mask += (gtCoarse == k) 181 | binary_mask[binary_mask >= 1] = 1 182 | mask[binary_mask] = gtCoarse[binary_mask] 183 | mask[binary_mask] = v 184 | 185 | 186 | mask = Image.fromarray(mask.astype(np.uint8)) 187 | return img, mask, img_name 188 | 189 | def __getitem__(self, index): 190 | """ 191 | Generate data: 192 | 193 | :return: 194 | - image: image, tensor 195 | - mask: mask, tensor 196 | - image_name: basename of file, string 197 | """ 198 | # Pick an image, fill in defaults if not using class uniform 199 | if len(self.imgs[index]) == 2: 200 | img_path, mask_path = self.imgs[index] 201 | centroid = None 202 | class_id = None 203 | else: 204 | img_path, mask_path, centroid, class_id = self.imgs[index] 205 | 206 | mask_out = cfg.DATASET.MASK_OUT_CITYSCAPES and \ 207 | cfg.DATASET.CUSTOM_COARSE_PROB is not None and \ 208 | 'refinement' in mask_path 209 | 210 | img, mask, img_name = self.read_images(img_path, mask_path, 211 | mask_out=mask_out) 212 | 213 | ###################################################################### 214 | # Thresholding is done when using coarse-labelled Cityscapes images 215 | ###################################################################### 216 | if 'refinement' in mask_path: 217 | 218 | mask = np.array(mask) 219 | prob_mask_path = mask_path.replace('.png', '_prob.png') 220 | # put it in 0 to 1 221 | prob_map = np.array(Image.open(prob_mask_path)) / 255.0 222 | prob_map_threshold = (prob_map < cfg.DATASET.CUSTOM_COARSE_PROB) 223 | mask[prob_map_threshold] = cfg.DATASET.IGNORE_LABEL 224 | mask = Image.fromarray(mask.astype(np.uint8)) 225 | 226 | img, mask, scale_float = self.do_transforms(img, mask, centroid, 227 | img_name, class_id) 228 | 229 | return img, mask, img_name, scale_float 230 | 231 | def __len__(self): 232 | return 
len(self.imgs) 233 | 234 | def calculate_weights(self): 235 | raise BaseException("not supported yet") 236 | -------------------------------------------------------------------------------- /datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | import os 31 | import os.path as path 32 | 33 | from config import cfg 34 | from runx.logx import logx 35 | from datasets.base_loader import BaseLoader 36 | import datasets.cityscapes_labels as cityscapes_labels 37 | import datasets.uniform as uniform 38 | from datasets.utils import make_dataset_folder 39 | 40 | 41 | def cities_cv_split(root, split, cv_split): 42 | """ 43 | Find cities that correspond to a given split of the data. We split the data 44 | such that a given city belongs to either train or val, but never both. cv0 45 | is defined to be the default split. 
46 | 47 | all_cities = [x x x x x x x x x x x x] 48 | val: 49 | split0 [x x x ] 50 | split1 [ x x x ] 51 | split2 [ x x x ] 52 | trn: 53 | split0 [ x x x x x x x x x] 54 | split1 [x x x x x x x x x] 55 | split2 [x x x x x x x x ] 56 | 57 | split - train/val/test 58 | cv_split - 0,1,2,3 59 | 60 | cv_split == 3 means use train + val 61 | """ 62 | trn_path = path.join(root, 'leftImg8bit_trainvaltest/leftImg8bit', 'train') 63 | val_path = path.join(root, 'leftImg8bit_trainvaltest/leftImg8bit', 'val') 64 | 65 | trn_cities = ['train/' + c for c in os.listdir(trn_path)] 66 | trn_cities = sorted(trn_cities) # sort to insure reproducibility 67 | val_cities = ['val/' + c for c in os.listdir(val_path)] 68 | 69 | all_cities = val_cities + trn_cities 70 | 71 | if cv_split == 3: 72 | logx.msg('cv split {} {} {}'.format(split, cv_split, all_cities)) 73 | return all_cities 74 | 75 | num_val_cities = len(val_cities) 76 | num_cities = len(all_cities) 77 | 78 | offset = cv_split * num_cities // cfg.DATASET.CV_SPLITS 79 | cities = [] 80 | for j in range(num_cities): 81 | if j >= offset and j < (offset + num_val_cities): 82 | if split == 'val': 83 | cities.append(all_cities[j]) 84 | else: 85 | if split == 'train': 86 | cities.append(all_cities[j]) 87 | 88 | logx.msg('cv split {} {} {}'.format(split, cv_split, cities)) 89 | return cities 90 | 91 | 92 | def coarse_cities(root): 93 | """ 94 | Find coarse cities 95 | """ 96 | split = 'train_extra' 97 | coarse_path = path.join(root, 'leftImg8bit_trainextra/leftImg8bit', 98 | split) 99 | coarse_cities = [f'{split}/' + c for c in os.listdir(coarse_path)] 100 | 101 | logx.msg(f'found {len(coarse_cities)} coarse cities') 102 | return coarse_cities 103 | 104 | 105 | class Loader(BaseLoader): 106 | num_classes = 19 107 | ignore_label = 255 108 | trainid_to_name = {} 109 | color_mapping = [] 110 | 111 | def __init__(self, mode, quality='fine', joint_transform_list=None, 112 | img_transform=None, label_transform=None, eval_folder=None): 113 | 114 | super(Loader, self).__init__(quality=quality, mode=mode, 115 | joint_transform_list=joint_transform_list, 116 | img_transform=img_transform, 117 | label_transform=label_transform) 118 | 119 | ###################################################################### 120 | # Cityscapes-specific stuff: 121 | ###################################################################### 122 | self.root = cfg.DATASET.CITYSCAPES_DIR 123 | self.id_to_trainid = cityscapes_labels.label2trainid 124 | self.trainid_to_name = cityscapes_labels.trainId2name 125 | self.fill_colormap() 126 | img_ext = 'png' 127 | mask_ext = 'png' 128 | img_root = path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit') 129 | mask_root = path.join(self.root, 'gtFine_trainvaltest/gtFine') 130 | if mode == 'folder': 131 | self.all_imgs = make_dataset_folder(eval_folder) 132 | else: 133 | self.fine_cities = cities_cv_split(self.root, mode, cfg.DATASET.CV) 134 | self.all_imgs = self.find_cityscapes_images( 135 | self.fine_cities, img_root, mask_root, img_ext, mask_ext) 136 | 137 | logx.msg(f'cn num_classes {self.num_classes}') 138 | self.fine_centroids = uniform.build_centroids(self.all_imgs, 139 | self.num_classes, 140 | self.train, 141 | cv=cfg.DATASET.CV, 142 | id2trainid=self.id_to_trainid) 143 | self.centroids = self.fine_centroids 144 | 145 | if cfg.DATASET.COARSE_BOOST_CLASSES and mode == 'train': 146 | self.coarse_cities = coarse_cities(self.root) 147 | img_root = path.join(self.root, 148 | 'leftImg8bit_trainextra/leftImg8bit') 149 | mask_root = 
path.join(self.root, 'gtCoarse', 'gtCoarse') 150 | self.coarse_imgs = self.find_cityscapes_images( 151 | self.coarse_cities, img_root, mask_root, img_ext, mask_ext, 152 | fine_coarse='gtCoarse') 153 | 154 | if cfg.DATASET.CLASS_UNIFORM_PCT: 155 | 156 | custom_coarse = (cfg.DATASET.CUSTOM_COARSE_PROB is not None) 157 | self.coarse_centroids = uniform.build_centroids( 158 | self.coarse_imgs, self.num_classes, self.train, 159 | coarse=(not custom_coarse), custom_coarse=custom_coarse, 160 | id2trainid=self.id_to_trainid) 161 | 162 | for cid in cfg.DATASET.COARSE_BOOST_CLASSES: 163 | self.centroids[cid].extend(self.coarse_centroids[cid]) 164 | else: 165 | self.all_imgs.extend(self.coarse_imgs) 166 | 167 | self.build_epoch() 168 | 169 | def disable_coarse(self): 170 | """ 171 | Turn off using coarse images in training 172 | """ 173 | self.centroids = self.fine_centroids 174 | 175 | def only_coarse(self): 176 | """ 177 | Turn on using coarse images in training 178 | """ 179 | print('==============+Running Only Coarse+===============') 180 | self.centroids = self.coarse_centroids 181 | 182 | def find_cityscapes_images(self, cities, img_root, mask_root, img_ext, 183 | mask_ext, fine_coarse='gtFine'): 184 | """ 185 | Find image and segmentation mask files and return a list of 186 | tuples of them. 187 | 188 | Inputs: 189 | img_root: path to parent directory of train/val/test dirs 190 | mask_root: path to parent directory of train/val/test dirs 191 | img_ext: image file extension 192 | mask_ext: mask file extension 193 | cities: a list of cities, each element in the form of 'train/a_city' 194 | or 'val/a_city', for example. 195 | """ 196 | items = [] 197 | for city in cities: 198 | img_dir = '{root}/{city}'.format(root=img_root, city=city) 199 | for file_name in os.listdir(img_dir): 200 | basename, ext = os.path.splitext(file_name) 201 | assert ext == '.' 
+ img_ext, '{} {}'.format(ext, img_ext) 202 | full_img_fn = os.path.join(img_dir, file_name) 203 | basename, ext = file_name.split('_leftImg8bit') 204 | if cfg.DATASET.CUSTOM_COARSE_PROB and fine_coarse != 'gtFine': 205 | mask_fn = f'{basename}_leftImg8bit.png' 206 | cc_path = cfg.DATASET.CITYSCAPES_CUSTOMCOARSE 207 | full_mask_fn = os.path.join(cc_path, city, mask_fn) 208 | os.path.isfile(full_mask_fn) 209 | else: 210 | mask_fn = f'{basename}_{fine_coarse}_labelIds{ext}' 211 | full_mask_fn = os.path.join(mask_root, city, mask_fn) 212 | items.append((full_img_fn, full_mask_fn)) 213 | 214 | logx.msg('mode {} found {} images'.format(self.mode, len(items))) 215 | 216 | return items 217 | 218 | def fill_colormap(self): 219 | palette = [128, 64, 128, 220 | 244, 35, 232, 221 | 70, 70, 70, 222 | 102, 102, 156, 223 | 190, 153, 153, 224 | 153, 153, 153, 225 | 250, 170, 30, 226 | 220, 220, 0, 227 | 107, 142, 35, 228 | 152, 251, 152, 229 | 70, 130, 180, 230 | 220, 20, 60, 231 | 255, 0, 0, 232 | 0, 0, 142, 233 | 0, 0, 70, 234 | 0, 60, 100, 235 | 0, 80, 100, 236 | 0, 0, 230, 237 | 119, 11, 32] 238 | zero_pad = 256 * 3 - len(palette) 239 | for i in range(zero_pad): 240 | palette.append(0) 241 | self.color_mapping = palette 242 | -------------------------------------------------------------------------------- /datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | # File taken from https://github.com/mcordts/cityscapesScripts/ 3 | # License File Available at: 4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 5 | 6 | # ---------------------- 7 | # The Cityscapes Dataset 8 | # ---------------------- 9 | # 10 | # 11 | # License agreement 12 | # ----------------- 13 | # 14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 15 | # 16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 
21 | # 22 | # 23 | # Contact 24 | # ------- 25 | # 26 | # Marius Cordts, Mohamed Omran 27 | # www.cityscapes-dataset.net 28 | 29 | """ 30 | from collections import namedtuple 31 | 32 | 33 | #-------------------------------------------------------------------------------- 34 | # Definitions 35 | #-------------------------------------------------------------------------------- 36 | 37 | # a label and all meta information 38 | Label = namedtuple( 'Label' , [ 39 | 40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 41 | # We use them to uniquely name a class 42 | 43 | 'id' , # An integer ID that is associated with this label. 44 | # The IDs are used to represent the label in ground truth images 45 | # An ID of -1 means that this label does not have an ID and thus 46 | # is ignored when creating ground truth images (e.g. license plate). 47 | # Do not modify these IDs, since exactly these IDs are expected by the 48 | # evaluation server. 49 | 50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 51 | # ground truth images with train IDs, using the tools provided in the 52 | # 'preparation' folder. However, make sure to validate or submit results 53 | # to our evaluation server using the regular IDs above! 54 | # For trainIds, multiple labels might have the same ID. Then, these labels 55 | # are mapped to the same class in the ground truth images. For the inverse 56 | # mapping, we use the label that is defined first in the list below. 57 | # For example, mapping all void-type classes to the same ID in training, 58 | # might make sense for some approaches. 59 | # Max value is 255! 60 | 61 | 'category' , # The name of the category that this label belongs to 62 | 63 | 'categoryId' , # The ID of this category. Used to create ground truth images 64 | # on category level. 65 | 66 | 'hasInstances', # Whether this label distinguishes between single instances or not 67 | 68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 69 | # during evaluations or not 70 | 71 | 'color' , # The color of this label 72 | ] ) 73 | 74 | 75 | #-------------------------------------------------------------------------------- 76 | # A list of all labels 77 | #-------------------------------------------------------------------------------- 78 | 79 | # Please adapt the train IDs as appropriate for you approach. 80 | # Note that you might want to ignore labels with ID 255 during training. 81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 82 | # Make sure to provide your results using the original IDs and not the training IDs. 83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
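# ---------------------------------------------------------------------------
# Illustrative sketch (an addition for clarity, not part of the original
# cityscapesScripts file): the notes above describe how every label `id` is
# remapped to a `trainId` for training.  In this repo that remapping is
# consumed via the `label2trainid` dict built further below (see
# BaseLoader.read_images); applied to a *_gtFine_labelIds.png mask it looks
# roughly like this:
#
#     import numpy as np
#     from PIL import Image
#     from datasets.cityscapes_labels import label2trainid
#
#     mask = np.array(Image.open('aachen_000000_000019_gtFine_labelIds.png'))
#     train_mask = np.full_like(mask, 255)      # 255 == ignore_label
#     for label_id, train_id in label2trainid.items():
#         train_mask[mask == label_id] = train_id
# ---------------------------------------------------------------------------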
84 | 85 | labels = [ 86 | # name id trainId category catId hasInstances ignoreInEval color 87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 122 | ] 123 | 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Create dictionaries for a fast lookup 127 | #-------------------------------------------------------------------------------- 128 | 129 | # Please refer to the main method below for example usages! 
130 | 131 | # name to label object 132 | name2label = { label.name : label for label in labels } 133 | # id to label object 134 | id2label = { label.id : label for label in labels } 135 | # trainId to label object 136 | trainId2label = { label.trainId : label for label in reversed(labels) } 137 | # label2trainid 138 | label2trainid = { label.id : label.trainId for label in labels } 139 | # trainId to label object 140 | trainId2name = { label.trainId : label.name for label in labels } 141 | trainId2color = { label.trainId : label.color for label in labels } 142 | # category to list of label objects 143 | category2labels = {} 144 | for label in labels: 145 | category = label.category 146 | if category in category2labels: 147 | category2labels[category].append(label) 148 | else: 149 | category2labels[category] = [label] 150 | 151 | #-------------------------------------------------------------------------------- 152 | # Assure single instance name 153 | #-------------------------------------------------------------------------------- 154 | 155 | # returns the label name that describes a single instance (if possible) 156 | # e.g. input | output 157 | # ---------------------- 158 | # car | car 159 | # cargroup | car 160 | # foo | None 161 | # foogroup | None 162 | # skygroup | None 163 | def assureSingleInstanceName( name ): 164 | # if the name is known, it is not a group 165 | if name in name2label: 166 | return name 167 | # test if the name actually denotes a group 168 | if not name.endswith("group"): 169 | return None 170 | # remove group 171 | name = name[:-len("group")] 172 | # test if the new name exists 173 | if not name in name2label: 174 | return None 175 | # test if the new name denotes a label that actually has instances 176 | if not name2label[name].hasInstances: 177 | return None 178 | # all good then 179 | return name 180 | 181 | #-------------------------------------------------------------------------------- 182 | # Main for testing 183 | #-------------------------------------------------------------------------------- 184 | 185 | # just a dummy main 186 | if __name__ == "__main__": 187 | # Print all the labels 188 | print("List of cityscapes labels:") 189 | print("") 190 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 191 | print((" " + ('-' * 98))) 192 | for label in labels: 193 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 194 | print("") 195 | 196 | print("Example usages:") 197 | 198 | # Map from name to label 199 | name = 'car' 200 | id = name2label[name].id 201 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 202 | 203 | # Map from ID to label 204 | category = id2label[id].category 205 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 206 | 207 | # Map from trainID to label 208 | trainId = 0 209 | name = trainId2label[trainId].name 210 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 211 | -------------------------------------------------------------------------------- /datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are 
permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | Mapillary Dataset Loader 32 | """ 33 | import os 34 | import json 35 | 36 | from config import cfg 37 | from runx.logx import logx 38 | from datasets.base_loader import BaseLoader 39 | from datasets.utils import make_dataset_folder 40 | from datasets import uniform 41 | 42 | 43 | class Loader(BaseLoader): 44 | num_classes = 65 45 | ignore_label = 65 46 | trainid_to_name = {} 47 | color_mapping = [] 48 | 49 | def __init__(self, mode, quality='semantic', joint_transform_list=None, 50 | img_transform=None, label_transform=None, eval_folder=None): 51 | 52 | super(Loader, self).__init__(quality=quality, 53 | mode=mode, 54 | joint_transform_list=joint_transform_list, 55 | img_transform=img_transform, 56 | label_transform=label_transform) 57 | 58 | root = cfg.DATASET.MAPILLARY_DIR 59 | config_fn = os.path.join(root, 'config.json') 60 | self.fill_colormap_and_names(config_fn) 61 | 62 | ###################################################################### 63 | # Assemble image lists 64 | ###################################################################### 65 | if mode == 'folder': 66 | self.all_imgs = make_dataset_folder(eval_folder) 67 | else: 68 | splits = {'train': 'training', 69 | 'val': 'validation', 70 | 'test': 'testing'} 71 | split_name = splits[mode] 72 | img_ext = 'jpg' 73 | mask_ext = 'png' 74 | img_root = os.path.join(root, split_name, 'images') 75 | mask_root = os.path.join(root, split_name, 'labels') 76 | self.all_imgs = self.find_images(img_root, mask_root, img_ext, 77 | mask_ext) 78 | logx.msg('all imgs {}'.format(len(self.all_imgs))) 79 | self.centroids = uniform.build_centroids(self.all_imgs, 80 | self.num_classes, 81 | self.train, 82 | cv=cfg.DATASET.CV) 83 | self.build_epoch() 84 | 85 | def fill_colormap_and_names(self, config_fn): 86 | """ 87 | Mapillary code for color map and class names 88 | 89 | Outputs 90 | ------- 91 | self.trainid_to_name 92 | self.color_mapping 93 | """ 94 | with open(config_fn) as config_file: 95 | config = json.load(config_file) 96 | config_labels = config['labels'] 97 | 98 | # calculate label color mapping 99 | 
colormap = [] 100 | self.trainid_to_name = {} 101 | for i in range(0, len(config_labels)): 102 | colormap = colormap + config_labels[i]['color'] 103 | name = config_labels[i]['readable'] 104 | name = name.replace(' ', '_') 105 | self.trainid_to_name[i] = name 106 | self.color_mapping = colormap 107 | -------------------------------------------------------------------------------- /datasets/nullloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | Null Loader 31 | """ 32 | from config import cfg 33 | from runx.logx import logx 34 | from datasets.base_loader import BaseLoader 35 | from datasets.utils import make_dataset_folder 36 | from datasets import uniform 37 | import numpy as np 38 | import torch 39 | from torch.utils import data 40 | 41 | class Loader(BaseLoader): 42 | """ 43 | Null Dataset for Performance 44 | """ 45 | num_classes = 19 46 | ignore_label = 255 47 | trainid_to_name = {} 48 | color_mapping = [] 49 | 50 | def __init__(self, mode, quality=None, joint_transform_list=None, 51 | img_transform=None, label_transform=None, eval_folder=None): 52 | super(Loader, self).__init__(quality=quality, 53 | mode=mode, 54 | joint_transform_list=joint_transform_list, 55 | img_transform=img_transform, 56 | label_transform=label_transform) 57 | 58 | def __getitem__(self, index): 59 | # return img, mask, img_name, scale_float 60 | crop_size = cfg.DATASET.CROP_SIZE 61 | if ',' in crop_size: 62 | crop_size = [int(x) for x in crop_size.split(',')] 63 | else: 64 | crop_size = int(crop_size) 65 | crop_size = [crop_size, crop_size] 66 | 67 | img = torch.FloatTensor(np.zeros([3] + crop_size)) 68 | mask = torch.LongTensor(np.zeros(crop_size)) 69 | img_name = f'img{index}' 70 | scale_float = 0.0 71 | return img, mask, img_name, scale_float 72 | 73 | def __len__(self): 74 | return 3000 75 | -------------------------------------------------------------------------------- /datasets/randaugment.py: -------------------------------------------------------------------------------- 1 | # this code from: https://github.com/ildoonet/pytorch-randaugment 2 | # code in this file is adpated from rpmcruz/autoaugment 3 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 4 | import random 5 | import numpy as np 6 | import torch 7 | 8 | from PIL import Image, ImageOps, ImageEnhance, ImageDraw 9 | from config import cfg 10 | 11 | 12 | fillmask = cfg.DATASET.IGNORE_LABEL 13 | fillcolor = (0, 0, 0) 14 | 15 | 16 | def affine_transform(pair, affine_params): 17 | img, mask = pair 18 | img = img.transform(img.size, Image.AFFINE, affine_params, 19 | resample=Image.BILINEAR, fillcolor=fillcolor) 20 | mask = mask.transform(mask.size, Image.AFFINE, affine_params, 21 | resample=Image.NEAREST, fillcolor=fillmask) 22 | return img, mask 23 | 24 | 25 | def ShearX(pair, v): # [-0.3, 0.3] 26 | assert -0.3 <= v <= 0.3 27 | if random.random() > 0.5: 28 | v = -v 29 | return affine_transform(pair, (1, v, 0, 0, 1, 0)) 30 | 31 | 32 | def ShearY(pair, v): # [-0.3, 0.3] 33 | assert -0.3 <= v <= 0.3 34 | if random.random() > 0.5: 35 | v = -v 36 | return affine_transform(pair, (1, 0, 0, v, 1, 0)) 37 | 38 | 39 | def TranslateX(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 40 | assert -0.45 <= v <= 0.45 41 | if random.random() > 0.5: 42 | v = -v 43 | img, _ = pair 44 | v = v * img.size[0] 45 | return affine_transform(pair, (1, 0, v, 0, 1, 0)) 46 | 47 | 48 | def TranslateY(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert -0.45 <= v <= 0.45 50 | if random.random() > 0.5: 51 | v = -v 52 | img, _ = pair 53 | v = v * img.size[1] 54 | return affine_transform(pair, (1, 0, 0, 0, 1, v)) 55 | 56 | 57 | def TranslateXAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 58 | assert 0 <= v <= 10 59 | if random.random() > 0.5: 60 | v = -v 61 | return affine_transform(pair, (1, 0, v, 0, 1, 0)) 62 | 63 | 64 | def TranslateYAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 65 | assert 0 <= v <= 10 66 | if random.random() > 0.5: 67 | v = -v 68 | 
return affine_transform(pair, (1, 0, 0, 0, 1, v)) 69 | 70 | 71 | def Rotate(pair, v): # [-30, 30] 72 | assert -30 <= v <= 30 73 | if random.random() > 0.5: 74 | v = -v 75 | img, mask = pair 76 | img = img.rotate(v, fillcolor=fillcolor) 77 | mask = mask.rotate(v, resample=Image.NEAREST, fillcolor=fillmask) 78 | return img, mask 79 | 80 | 81 | def AutoContrast(pair, _): 82 | img, mask = pair 83 | return ImageOps.autocontrast(img), mask 84 | 85 | 86 | def Invert(pair, _): 87 | img, mask = pair 88 | return ImageOps.invert(img), mask 89 | 90 | 91 | def Equalize(pair, _): 92 | img, mask = pair 93 | return ImageOps.equalize(img), mask 94 | 95 | 96 | def Flip(pair, _): # not from the paper 97 | img, mask = pair 98 | return ImageOps.mirror(img), ImageOps.mirror(mask) 99 | 100 | 101 | def Solarize(pair, v): # [0, 256] 102 | img, mask = pair 103 | assert 0 <= v <= 256 104 | return ImageOps.solarize(img, v), mask 105 | 106 | 107 | def Posterize(pair, v): # [4, 8] 108 | img, mask = pair 109 | assert 4 <= v <= 8 110 | v = int(v) 111 | return ImageOps.posterize(img, v), mask 112 | 113 | 114 | def Posterize2(pair, v): # [0, 4] 115 | img, mask = pair 116 | assert 0 <= v <= 4 117 | v = int(v) 118 | return ImageOps.posterize(img, v), mask 119 | 120 | 121 | def Contrast(pair, v): # [0.1,1.9] 122 | img, mask = pair 123 | assert 0.1 <= v <= 1.9 124 | return ImageEnhance.Contrast(img).enhance(v), mask 125 | 126 | 127 | def Color(pair, v): # [0.1,1.9] 128 | img, mask = pair 129 | assert 0.1 <= v <= 1.9 130 | return ImageEnhance.Color(img).enhance(v), mask 131 | 132 | 133 | def Brightness(pair, v): # [0.1,1.9] 134 | img, mask = pair 135 | assert 0.1 <= v <= 1.9 136 | return ImageEnhance.Brightness(img).enhance(v), mask 137 | 138 | 139 | def Sharpness(pair, v): # [0.1,1.9] 140 | img, mask = pair 141 | assert 0.1 <= v <= 1.9 142 | return ImageEnhance.Sharpness(img).enhance(v), mask 143 | 144 | 145 | def Cutout(pair, v): # [0, 60] => percentage: [0, 0.2] 146 | assert 0.0 <= v <= 0.2 147 | if v <= 0.: 148 | return pair 149 | img, mask = pair 150 | v = v * img.size[0] 151 | return CutoutAbs(img, v), mask 152 | 153 | 154 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 155 | # assert 0 <= v <= 20 156 | if v < 0: 157 | return img 158 | w, h = img.size 159 | x0 = np.random.uniform(w) 160 | y0 = np.random.uniform(h) 161 | 162 | x0 = int(max(0, x0 - v / 2.)) 163 | y0 = int(max(0, y0 - v / 2.)) 164 | x1 = min(w, x0 + v) 165 | y1 = min(h, y0 + v) 166 | 167 | xy = (x0, y0, x1, y1) 168 | color = (125, 123, 114) 169 | # color = (0, 0, 0) 170 | img = img.copy() 171 | ImageDraw.Draw(img).rectangle(xy, color) 172 | return img 173 | 174 | 175 | def Identity(pair, v): 176 | return pair 177 | 178 | 179 | def augment_list(): # 16 oeprations and their ranges 180 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 181 | l = [ 182 | (Identity, 0., 1.0), 183 | (ShearX, 0., 0.3), # 0 184 | (ShearY, 0., 0.3), # 1 185 | (TranslateX, 0., 0.33), # 2 186 | (TranslateY, 0., 0.33), # 3 187 | (Rotate, 0, 30), # 4 188 | (AutoContrast, 0, 1), # 5 189 | (Invert, 0, 1), # 6 190 | (Equalize, 0, 1), # 7 191 | (Solarize, 0, 110), # 8 192 | (Posterize, 4, 8), # 9 193 | # (Contrast, 0.1, 1.9), # 10 194 | (Color, 0.1, 1.9), # 11 195 | (Brightness, 0.1, 1.9), # 12 196 | (Sharpness, 0.1, 1.9), # 13 197 | # (Cutout, 0, 0.2), # 14 198 | # (SamplePairing(imgs), 0, 0.4), # 15 199 | # (Flip, 1, 1), 200 | ] 201 | return l 202 | 203 | 204 | class Lighting(object): 205 | """Lighting noise(AlexNet - style PCA - based 
noise)""" 206 | 207 | def __init__(self, alphastd, eigval, eigvec): 208 | self.alphastd = alphastd 209 | self.eigval = torch.Tensor(eigval) 210 | self.eigvec = torch.Tensor(eigvec) 211 | 212 | def __call__(self, img): 213 | if self.alphastd == 0: 214 | return img 215 | 216 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 217 | rgb = self.eigvec.type_as(img).clone() \ 218 | .mul(alpha.view(1, 3).expand(3, 3)) \ 219 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 220 | .sum(1).squeeze() 221 | 222 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 223 | 224 | 225 | class CutoutDefault(object): 226 | """ 227 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 228 | """ 229 | def __init__(self, length): 230 | self.length = length 231 | 232 | def __call__(self, img): 233 | h, w = img.size(1), img.size(2) 234 | mask = np.ones((h, w), np.float32) 235 | y = np.random.randint(h) 236 | x = np.random.randint(w) 237 | 238 | y1 = np.clip(y - self.length // 2, 0, h) 239 | y2 = np.clip(y + self.length // 2, 0, h) 240 | x1 = np.clip(x - self.length // 2, 0, w) 241 | x2 = np.clip(x + self.length // 2, 0, w) 242 | 243 | mask[y1: y2, x1: x2] = 0. 244 | mask = torch.from_numpy(mask) 245 | mask = mask.expand_as(img) 246 | img *= mask 247 | return img 248 | 249 | 250 | class RandAugment: 251 | def __init__(self, n, m): 252 | self.n = n 253 | self.m = m # [0, 30] 254 | self.augment_list = augment_list() 255 | 256 | def __call__(self, img, mask): 257 | pair = img, mask 258 | ops = random.choices(self.augment_list, k=self.n) 259 | for op, minval, maxval in ops: 260 | val = (float(self.m) / 30) * float(maxval - minval) + minval 261 | pair = op(pair, val) 262 | 263 | return pair 264 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | 37 | 38 | import math 39 | import torch 40 | from torch.distributed import get_world_size, get_rank 41 | from torch.utils.data import Sampler 42 | 43 | class DistributedSampler(Sampler): 44 | """Sampler that restricts data loading to a subset of the dataset. 45 | 46 | It is especially useful in conjunction with 47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 48 | process can pass a DistributedSampler instance as a DataLoader sampler, 49 | and load a subset of the original dataset that is exclusive to it. 50 | 51 | .. note:: 52 | Dataset is assumed to be of constant size. 53 | 54 | Arguments: 55 | dataset: Dataset used for sampling. 56 | num_replicas (optional): Number of processes participating in 57 | distributed training. 58 | rank (optional): Rank of the current process within num_replicas. 59 | """ 60 | 61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None): 62 | if num_replicas is None: 63 | num_replicas = get_world_size() 64 | if rank is None: 65 | rank = get_rank() 66 | self.dataset = dataset 67 | self.num_replicas = num_replicas 68 | self.rank = rank 69 | self.epoch = 0 70 | self.consecutive_sample = consecutive_sample 71 | self.permutation = permutation 72 | if pad: 73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | else: 75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 76 | self.total_size = self.num_samples * self.num_replicas 77 | 78 | def __iter__(self): 79 | # deterministically shuffle based on epoch 80 | g = torch.Generator() 81 | g.manual_seed(self.epoch) 82 | 83 | if self.permutation: 84 | indices = list(torch.randperm(len(self.dataset), generator=g)) 85 | else: 86 | indices = list([x for x in range(len(self.dataset))]) 87 | 88 | # add extra samples to make it evenly divisible 89 | if self.total_size > len(indices): 90 | indices += indices[:(self.total_size - len(indices))] 91 | 92 | # subsample 93 | if self.consecutive_sample: 94 | offset = self.num_samples * self.rank 95 | indices = indices[offset:offset + self.num_samples] 96 | else: 97 | indices = indices[self.rank:self.total_size:self.num_replicas] 98 | assert len(indices) == self.num_samples 99 | 100 | return iter(indices) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | def set_epoch(self, epoch): 106 | self.epoch = epoch 107 | 108 | def set_num_samples(self): 109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 110 | self.total_size = self.num_samples * self.num_replicas -------------------------------------------------------------------------------- /datasets/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following 
conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | Uniform sampling of classes. 32 | For all images, for all classes, generate centroids around which to sample. 33 | 34 | All images are divided into tiles. 35 | For each tile, a class can be present or not. If it is 36 | present, calculate the centroid of the class and record it. 37 | 38 | We would like to thank Peter Kontschieder for the inspiration of this idea. 39 | """ 40 | 41 | import sys 42 | import os 43 | import json 44 | import numpy as np 45 | 46 | import torch 47 | 48 | from collections import defaultdict 49 | from scipy.ndimage.measurements import center_of_mass 50 | from PIL import Image 51 | from tqdm import tqdm 52 | from config import cfg 53 | from runx.logx import logx 54 | 55 | pbar = None 56 | 57 | 58 | class Point(): 59 | """ 60 | Point Class For X and Y Location 61 | """ 62 | def __init__(self, x, y): 63 | self.x = x 64 | self.y = y 65 | 66 | 67 | def calc_tile_locations(tile_size, image_size): 68 | """ 69 | Divide an image into tiles to help us cover classes that are spread out. 70 | tile_size: size of tile to distribute 71 | image_size: original image size 72 | return: locations of the tiles 73 | """ 74 | image_size_y, image_size_x = image_size 75 | locations = [] 76 | for y in range(image_size_y // tile_size): 77 | for x in range(image_size_x // tile_size): 78 | x_offs = x * tile_size 79 | y_offs = y * tile_size 80 | locations.append((x_offs, y_offs)) 81 | return locations 82 | 83 | 84 | def class_centroids_image(item, tile_size, num_classes, id2trainid): 85 | """ 86 | For one image, calculate centroids for all classes present in image. 87 | item: image, image_name 88 | tile_size: 89 | num_classes: 90 | id2trainid: mapping from original id to training ids 91 | return: Centroids are calculated for each tile. 
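Concretely (per the code below), the result is a dict keyed by class id, where each entry is a list of (image_fn, label_fn, (centroid_x, centroid_y), class_id) tuples.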
92 | """ 93 | image_fn, label_fn = item 94 | centroids = defaultdict(list) 95 | mask = np.array(Image.open(label_fn)) 96 | image_size = mask.shape 97 | tile_locations = calc_tile_locations(tile_size, image_size) 98 | 99 | drop_mask = np.zeros((1024,2048)) 100 | drop_mask[15:840, 14:2030] = 1.0 101 | 102 | 103 | ##### 104 | if(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE in label_fn): 105 | gtCoarse_mask_path = label_fn.replace(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE, os.path.join(cfg.DATASET.CITYSCAPES_DIR, 'gtCoarse/gtCoarse') ) 106 | gtCoarse_mask_path = gtCoarse_mask_path.replace('leftImg8bit','gtCoarse_labelIds') 107 | gtCoarse=np.array(Image.open(gtCoarse_mask_path)) 108 | 109 | 110 | #### 111 | 112 | mask_copy = mask.copy() 113 | if id2trainid: 114 | for k, v in id2trainid.items(): 115 | binary_mask = (mask_copy == k) 116 | #This should only apply to auto labelled images 117 | if ('refinement' in label_fn) and cfg.DROPOUT_COARSE_BOOST_CLASSES != None and v in cfg.DROPOUT_COARSE_BOOST_CLASSES and binary_mask.sum() > 0: 118 | binary_mask += (gtCoarse == k) 119 | binary_mask[binary_mask >= 1] = 1 120 | mask[binary_mask] = gtCoarse[binary_mask] 121 | mask[binary_mask] = v 122 | 123 | for x_offs, y_offs in tile_locations: 124 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 125 | for class_id in range(num_classes): 126 | if class_id in patch: 127 | patch_class = (patch == class_id).astype(int) 128 | centroid_y, centroid_x = center_of_mass(patch_class) 129 | centroid_y = int(centroid_y) + y_offs 130 | centroid_x = int(centroid_x) + x_offs 131 | centroid = (centroid_x, centroid_y) 132 | centroids[class_id].append((image_fn, label_fn, centroid, 133 | class_id)) 134 | pbar.update(1) 135 | return centroids 136 | 137 | 138 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 139 | """ 140 | Calculate class centroids for all classes for all images for all tiles. 141 | items: list of (image_fn, label_fn) 142 | tile size: size of tile 143 | returns: dict that contains a list of centroids for each class 144 | """ 145 | from multiprocessing.dummy import Pool 146 | from functools import partial 147 | pool = Pool(80) 148 | global pbar 149 | pbar = tqdm(total=len(items), desc='pooled centroid extraction', file=sys.stdout) 150 | class_centroids_item = partial(class_centroids_image, 151 | num_classes=num_classes, 152 | id2trainid=id2trainid, 153 | tile_size=tile_size) 154 | 155 | centroids = defaultdict(list) 156 | new_centroids = pool.map(class_centroids_item, items) 157 | pool.close() 158 | pool.join() 159 | 160 | # combine each image's items into a single global dict 161 | for image_items in new_centroids: 162 | for class_id in image_items: 163 | centroids[class_id].extend(image_items[class_id]) 164 | return centroids 165 | 166 | 167 | def unpooled_class_centroids_all(items, num_classes, id2trainid, 168 | tile_size=1024): 169 | """ 170 | Calculate class centroids for all classes for all images for all tiles. 
171 | items: list of (image_fn, label_fn) 172 | tile size: size of tile 173 | returns: dict that contains a list of centroids for each class 174 | """ 175 | centroids = defaultdict(list) 176 | global pbar 177 | pbar = tqdm(total=len(items), desc='centroid extraction', file=sys.stdout) 178 | for image, label in items: 179 | new_centroids = class_centroids_image(item=(image, label), 180 | tile_size=tile_size, 181 | num_classes=num_classes, 182 | id2trainid=id2trainid) 183 | for class_id in new_centroids: 184 | centroids[class_id].extend(new_centroids[class_id]) 185 | 186 | return centroids 187 | 188 | 189 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 190 | """ 191 | intermediate function to call pooled_class_centroid 192 | """ 193 | pooled_centroids = pooled_class_centroids_all(items, num_classes, 194 | id2trainid, tile_size) 195 | # pooled_centroids = unpooled_class_centroids_all(items, num_classes, 196 | # id2trainid, tile_size) 197 | return pooled_centroids 198 | 199 | 200 | def random_sampling(alist, num): 201 | """ 202 | Randomly sample num items from the list 203 | alist: list of centroids to sample from 204 | num: can be larger than the list and if so, then wrap around 205 | return: class uniform samples from the list 206 | """ 207 | sampling = [] 208 | len_list = len(alist) 209 | assert len_list, 'len_list is zero!' 210 | indices = np.arange(len_list) 211 | np.random.shuffle(indices) 212 | 213 | for i in range(num): 214 | item = alist[indices[i % len_list]] 215 | sampling.append(item) 216 | return sampling 217 | 218 | 219 | def build_centroids(imgs, num_classes, train, cv=None, coarse=False, 220 | custom_coarse=False, id2trainid=None): 221 | """ 222 | The first step of uniform sampling is to decide sampling centers. 223 | The idea is to divide each image into tiles and within each tile, 224 | we compute a centroid for each class to indicate roughly where to 225 | sample a crop during training. 226 | 227 | This function computes these centroids and returns a list of them. 
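In practice the centroids are returned as a dict keyed by class id, cached to / loaded from a JSON file under cfg.DATASET.CENTROID_ROOT; an empty list is returned when class-uniform sampling is disabled or when not training.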
228 | """ 229 | if not (cfg.DATASET.CLASS_UNIFORM_PCT and train): 230 | return [] 231 | 232 | centroid_fn = cfg.DATASET.NAME 233 | 234 | if coarse or custom_coarse: 235 | if coarse: 236 | centroid_fn += '_coarse' 237 | if custom_coarse: 238 | centroid_fn += '_customcoarse_final' 239 | else: 240 | centroid_fn += '_cv{}'.format(cv) 241 | centroid_fn += '_tile{}.json'.format(cfg.DATASET.CLASS_UNIFORM_TILE) 242 | json_fn = os.path.join(cfg.DATASET.CENTROID_ROOT, 243 | centroid_fn) 244 | if os.path.isfile(json_fn): 245 | logx.msg('Loading centroid file {}'.format(json_fn)) 246 | with open(json_fn, 'r') as json_data: 247 | centroids = json.load(json_data) 248 | centroids = {int(idx): centroids[idx] for idx in centroids} 249 | logx.msg('Found {} centroids'.format(len(centroids))) 250 | else: 251 | logx.msg('Didn\'t find {}, so building it'.format(json_fn)) 252 | 253 | if cfg.GLOBAL_RANK==0: 254 | 255 | os.makedirs(cfg.DATASET.CENTROID_ROOT, exist_ok=True) 256 | # centroids is a dict (indexed by class) of lists of centroids 257 | centroids = class_centroids_all( 258 | imgs, 259 | num_classes, 260 | id2trainid=id2trainid) 261 | with open(json_fn, 'w') as outfile: 262 | json.dump(centroids, outfile, indent=4) 263 | 264 | # wait for everyone to be at the same point 265 | torch.distributed.barrier() 266 | 267 | # GPUs (except rank0) read in the just-created centroid file 268 | if cfg.GLOBAL_RANK != 0: 269 | msg = f'Expected to find {json_fn}' 270 | assert os.path.isfile(json_fn), msg 271 | with open(json_fn, 'r') as json_data: 272 | centroids = json.load(json_data) 273 | centroids = {int(idx): centroids[idx] for idx in centroids} 274 | 275 | return centroids 276 | 277 | 278 | def build_epoch(imgs, centroids, num_classes, train): 279 | """ 280 | Generate an epoch of crops using uniform sampling. 281 | Needs to be called every epoch. 282 | Will not apply uniform sampling if not train or class uniform is off. 
283 | 284 | Inputs: 285 | imgs - list of imgs 286 | centroids - list of class centroids 287 | num_classes - number of classes 288 | class_uniform_pct: % of uniform images in one epoch 289 | Outputs: 290 | imgs - list of images to use this epoch 291 | """ 292 | class_uniform_pct = cfg.DATASET.CLASS_UNIFORM_PCT 293 | if not (train and class_uniform_pct): 294 | return imgs 295 | 296 | logx.msg("Class Uniform Percentage: {}".format(str(class_uniform_pct))) 297 | num_epoch = int(len(imgs)) 298 | 299 | logx.msg('Class Uniform items per Epoch: {}'.format(str(num_epoch))) 300 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes) 301 | class_uniform_count = num_per_class * num_classes 302 | num_rand = num_epoch - class_uniform_count 303 | # create random crops 304 | imgs_uniform = random_sampling(imgs, num_rand) 305 | 306 | # now add uniform sampling 307 | for class_id in range(num_classes): 308 | msg = "cls {} len {}".format(class_id, len(centroids[class_id])) 309 | logx.msg(msg) 310 | for class_id in range(num_classes): 311 | if cfg.DATASET.CLASS_UNIFORM_BIAS is not None: 312 | bias = cfg.DATASET.CLASS_UNIFORM_BIAS[class_id] 313 | num_per_class_biased = int(num_per_class * bias) 314 | else: 315 | num_per_class_biased = num_per_class 316 | centroid_len = len(centroids[class_id]) 317 | if centroid_len == 0: 318 | pass 319 | else: 320 | class_centroids = random_sampling(centroids[class_id], 321 | num_per_class_biased) 322 | imgs_uniform.extend(class_centroids) 323 | 324 | return imgs_uniform 325 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def make_dataset_folder(folder): 5 | """ 6 | Create Filename list for images in the provided path 7 | 8 | input: path to directory with *only* images files 9 | returns: items list with None filled for mask path 10 | """ 11 | items = os.listdir(folder) 12 | items = [(os.path.join(folder, f), '') for f in items] 13 | items = sorted(items) 14 | 15 | print(f'Found {len(items)} folder imgs') 16 | 17 | """ 18 | orig_len = len(items) 19 | rem = orig_len % 8 20 | if rem != 0: 21 | items = items[:-rem] 22 | 23 | msg = 'Found {} folder imgs but altered to {} to be modulo-8' 24 | msg = msg.format(orig_len, len(items)) 25 | print(msg) 26 | """ 27 | 28 | return items 29 | -------------------------------------------------------------------------------- /imgs/composited_sf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/semantic-segmentation/7726b144c2cc0b8e09c67eabb78f027efdf3f0fa/imgs/composited_sf.png -------------------------------------------------------------------------------- /imgs/test_imgs/nyc.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/semantic-segmentation/7726b144c2cc0b8e09c67eabb78f027efdf3f0fa/imgs/test_imgs/nyc.jpg -------------------------------------------------------------------------------- /imgs/test_imgs/sf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/semantic-segmentation/7726b144c2cc0b8e09c67eabb78f027efdf3f0fa/imgs/test_imgs/sf.jpg -------------------------------------------------------------------------------- /loss/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 
Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | 31 | # Optimizer and scheduler related tasks 32 | 33 | import math 34 | import torch 35 | 36 | from torch import optim 37 | from runx.logx import logx 38 | 39 | from config import cfg 40 | from loss.radam import RAdam 41 | 42 | 43 | def get_optimizer(args, net): 44 | """ 45 | Decide Optimizer (Adam or SGD) 46 | """ 47 | param_groups = net.parameters() 48 | 49 | if args.optimizer == 'sgd': 50 | optimizer = optim.SGD(param_groups, 51 | lr=args.lr, 52 | weight_decay=args.weight_decay, 53 | momentum=args.momentum, 54 | nesterov=False) 55 | elif args.optimizer == 'adam': 56 | optimizer = optim.Adam(param_groups, 57 | lr=args.lr, 58 | weight_decay=args.weight_decay, 59 | amsgrad=args.amsgrad) 60 | elif args.optimizer == 'radam': 61 | optimizer = RAdam(param_groups, 62 | lr=args.lr, 63 | weight_decay=args.weight_decay) 64 | else: 65 | raise ValueError('Not a valid optimizer') 66 | 67 | def poly_schd(epoch): 68 | return math.pow(1 - epoch / args.max_epoch, args.poly_exp) 69 | 70 | def poly2_schd(epoch): 71 | if epoch < args.poly_step: 72 | poly_exp = args.poly_exp 73 | else: 74 | poly_exp = 2 * args.poly_exp 75 | return math.pow(1 - epoch / args.max_epoch, poly_exp) 76 | 77 | if args.lr_schedule == 'scl-poly': 78 | if cfg.REDUCE_BORDER_EPOCH == -1: 79 | raise ValueError('ERROR Cannot Do Scale Poly') 80 | 81 | rescale_thresh = cfg.REDUCE_BORDER_EPOCH 82 | scale_value = args.rescale 83 | lambda1 = lambda epoch: \ 84 | math.pow(1 - epoch / args.max_epoch, 85 | args.poly_exp) if epoch < rescale_thresh else scale_value * math.pow( 86 | 1 - (epoch - rescale_thresh) / (args.max_epoch - rescale_thresh), 87 | args.repoly) 88 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 89 | elif args.lr_schedule == 'poly2': 90 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, 91 | lr_lambda=poly2_schd) 92 | elif args.lr_schedule == 'poly': 93 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, 94 | lr_lambda=poly_schd) 95 | else: 96 | raise 
ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 97 | 98 | return optimizer, scheduler 99 | 100 | 101 | def load_weights(net, optimizer, snapshot_file, restore_optimizer_bool=False): 102 | """ 103 | Load weights from snapshot file 104 | """ 105 | logx.msg("Loading weights from model {}".format(snapshot_file)) 106 | net, optimizer = restore_snapshot(net, optimizer, snapshot_file, restore_optimizer_bool) 107 | return net, optimizer 108 | 109 | 110 | def restore_snapshot(net, optimizer, snapshot, restore_optimizer_bool): 111 | """ 112 | Restore weights and optimizer (if needed ) for resuming job. 113 | """ 114 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 115 | logx.msg("Checkpoint Load Compelete") 116 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 117 | optimizer.load_state_dict(checkpoint['optimizer']) 118 | 119 | if 'state_dict' in checkpoint: 120 | net = forgiving_state_restore(net, checkpoint['state_dict']) 121 | else: 122 | net = forgiving_state_restore(net, checkpoint) 123 | 124 | return net, optimizer 125 | 126 | 127 | def restore_opt(optimizer, checkpoint): 128 | assert 'optimizer' in checkpoint, 'cant find optimizer in checkpoint' 129 | optimizer.load_state_dict(checkpoint['optimizer']) 130 | 131 | 132 | def restore_net(net, checkpoint): 133 | assert 'state_dict' in checkpoint, 'cant find state_dict in checkpoint' 134 | forgiving_state_restore(net, checkpoint['state_dict']) 135 | 136 | 137 | def forgiving_state_restore(net, loaded_dict): 138 | """ 139 | Handle partial loading when some tensors don't match up in size. 140 | Because we want to use models that were trained off a different 141 | number of classes. 142 | """ 143 | 144 | net_state_dict = net.state_dict() 145 | new_loaded_dict = {} 146 | for k in net_state_dict: 147 | new_k = k 148 | if new_k in loaded_dict and net_state_dict[k].size() == loaded_dict[new_k].size(): 149 | new_loaded_dict[k] = loaded_dict[new_k] 150 | else: 151 | logx.msg("Skipped loading parameter {}".format(k)) 152 | net_state_dict.update(new_loaded_dict) 153 | net.load_state_dict(net_state_dict) 154 | return net 155 | -------------------------------------------------------------------------------- /loss/radam.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code adapted from: https://github.com/LiyuanLucasLiu/RAdam 3 | From the paper: https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | # pylint: disable=no-name-in-module 8 | from torch.optim.optimizer import Optimizer 9 | 10 | 11 | class RAdam(Optimizer): 12 | """RAdam optimizer""" 13 | 14 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 15 | weight_decay=0): 16 | """ 17 | Init 18 | 19 | :param params: parameters to optimize 20 | :param lr: learning rate 21 | :param betas: beta 22 | :param eps: numerical precision 23 | :param weight_decay: weight decay weight 24 | """ 25 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 26 | self.buffer = [[None, None, None] for _ in range(10)] 27 | super().__init__(params, defaults) 28 | 29 | def step(self, closure=None): 30 | 31 | loss = None 32 | if closure is not None: 33 | loss = closure() 34 | 35 | for group in self.param_groups: 36 | 37 | for p in group['params']: 38 | if p.grad is None: 39 | continue 40 | grad = p.grad.data.float() 41 | if grad.is_sparse: 42 | raise RuntimeError( 43 | 'RAdam does not support sparse gradients' 44 | ) 45 | 46 | p_data_fp32 = p.data.float() 
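# Per-parameter optimizer state (step count plus exponential moving averages of the
# gradient and its square) is created lazily below and kept as float32 copies, so the
# update math stays in full precision even for half-precision parameters; the result
# is copied back into p.data at the end of the step.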
47 | 48 | state = self.state[p] 49 | 50 | if len(state) == 0: 51 | state['step'] = 0 52 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 53 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 54 | else: 55 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 56 | state['exp_avg_sq'] = ( 57 | state['exp_avg_sq'].type_as(p_data_fp32) 58 | ) 59 | 60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 61 | beta1, beta2 = group['betas'] 62 | 63 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 64 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 65 | 66 | state['step'] += 1 67 | buffered = self.buffer[int(state['step'] % 10)] 68 | if state['step'] == buffered[0]: 69 | N_sma, step_size = buffered[1], buffered[2] 70 | else: 71 | buffered[0] = state['step'] 72 | beta2_t = beta2 ** state['step'] 73 | N_sma_max = 2 / (1 - beta2) - 1 74 | N_sma = ( 75 | N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 76 | ) 77 | buffered[1] = N_sma 78 | 79 | # more conservative since it's an approximated value 80 | if N_sma >= 5: 81 | step_size = ( 82 | group['lr'] * 83 | math.sqrt( 84 | (1 - beta2_t) * (N_sma - 4) / 85 | (N_sma_max - 4) * (N_sma - 2) / 86 | N_sma * N_sma_max / (N_sma_max - 2) 87 | ) / (1 - beta1 ** state['step']) 88 | ) 89 | else: 90 | step_size = group['lr'] / (1 - beta1 ** state['step']) 91 | buffered[2] = step_size 92 | 93 | if group['weight_decay'] != 0: 94 | p_data_fp32.add_( 95 | -group['weight_decay'] * group['lr'], p_data_fp32 96 | ) 97 | 98 | # more conservative since it's an approximated value 99 | if N_sma >= 5: 100 | denom = exp_avg_sq.sqrt().add_(group['eps']) 101 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 102 | else: 103 | p_data_fp32.add_(-step_size, exp_avg) 104 | 105 | p.data.copy_(p_data_fp32) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /loss/rmi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from: https://github.com/ZJULearning/RMI 3 | 4 | The implementation of the paper: 5 | Region Mutual Information Loss for Semantic Segmentation. 6 | """ 7 | 8 | # python 2.X, 3.X compatibility 9 | from __future__ import print_function 10 | from __future__ import division 11 | from __future__ import absolute_import 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from loss import rmi_utils 18 | from config import cfg 19 | from apex import amp 20 | 21 | _euler_num = 2.718281828 # euler number 22 | _pi = 3.14159265 # pi 23 | _ln_2_pi = 1.837877 # ln(2 * pi) 24 | _CLIP_MIN = 1e-6 # min clip value after softmax or sigmoid operations 25 | _CLIP_MAX = 1.0 # max clip value after softmax or sigmoid operations 26 | _POS_ALPHA = 5e-4 # add this factor to ensure the AA^T is positive definite 27 | _IS_SUM = 1 # sum the loss per channel 28 | 29 | 30 | __all__ = ['RMILoss'] 31 | 32 | 33 | class RMILoss(nn.Module): 34 | """ 35 | region mutual information 36 | I(A, B) = H(A) + H(B) - H(A, B) 37 | This version need a lot of memory if do not dwonsample. 
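In this implementation, downsampling happens inside rmi_lower_bound and is controlled by rmi_pool_size / rmi_pool_stride, with rmi_pool_way selecting max pooling, average pooling, or interpolation.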
38 | """ 39 | def __init__(self, 40 | num_classes=21, 41 | rmi_radius=3, 42 | rmi_pool_way=1, 43 | rmi_pool_size=4, 44 | rmi_pool_stride=4, 45 | loss_weight_lambda=0.5, 46 | lambda_way=1, 47 | ignore_index=255): 48 | super(RMILoss, self).__init__() 49 | self.num_classes = num_classes 50 | # radius choices 51 | assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 52 | self.rmi_radius = rmi_radius 53 | assert rmi_pool_way in [0, 1, 2, 3] 54 | self.rmi_pool_way = rmi_pool_way 55 | 56 | # set the pool_size = rmi_pool_stride 57 | assert rmi_pool_size == rmi_pool_stride 58 | self.rmi_pool_size = rmi_pool_size 59 | self.rmi_pool_stride = rmi_pool_stride 60 | self.weight_lambda = loss_weight_lambda 61 | self.lambda_way = lambda_way 62 | 63 | # dimension of the distribution 64 | self.half_d = self.rmi_radius * self.rmi_radius 65 | self.d = 2 * self.half_d 66 | self.kernel_padding = self.rmi_pool_size // 2 67 | # ignore class 68 | self.ignore_index = ignore_index 69 | 70 | def forward(self, logits_4D, labels_4D, do_rmi=True): 71 | # explicitly disable fp16 mode because torch.cholesky and 72 | # torch.inverse aren't supported by half 73 | logits_4D.float() 74 | labels_4D.float() 75 | if cfg.TRAIN.FP16: 76 | with amp.disable_casts(): 77 | loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi) 78 | else: 79 | loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi) 80 | return loss 81 | 82 | def forward_sigmoid(self, logits_4D, labels_4D, do_rmi=False): 83 | """ 84 | Using the sigmiod operation both. 85 | Args: 86 | logits_4D : [N, C, H, W], dtype=float32 87 | labels_4D : [N, H, W], dtype=long 88 | do_rmi : bool 89 | """ 90 | # label mask -- [N, H, W, 1] 91 | label_mask_3D = labels_4D < self.num_classes 92 | 93 | # valid label 94 | valid_onehot_labels_4D = \ 95 | F.one_hot(labels_4D.long() * label_mask_3D.long(), 96 | num_classes=self.num_classes).float() 97 | label_mask_3D = label_mask_3D.float() 98 | label_mask_flat = label_mask_3D.view([-1, ]) 99 | valid_onehot_labels_4D = valid_onehot_labels_4D * \ 100 | label_mask_3D.unsqueeze(dim=3) 101 | valid_onehot_labels_4D.requires_grad_(False) 102 | 103 | # PART I -- calculate the sigmoid binary cross entropy loss 104 | valid_onehot_label_flat = \ 105 | valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False) 106 | logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes]) 107 | 108 | # binary loss, multiplied by the not_ignore_mask 109 | valid_pixels = torch.sum(label_mask_flat) 110 | binary_loss = F.binary_cross_entropy_with_logits(logits_flat, 111 | target=valid_onehot_label_flat, 112 | weight=label_mask_flat.unsqueeze(dim=1), 113 | reduction='sum') 114 | bce_loss = torch.div(binary_loss, valid_pixels + 1.0) 115 | if not do_rmi: 116 | return bce_loss 117 | 118 | # PART II -- get rmi loss 119 | # onehot_labels_4D -- [N, C, H, W] 120 | probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + _CLIP_MIN 121 | valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) 122 | 123 | # get region mutual information 124 | rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) 125 | 126 | # add together 127 | #logx.msg(f'lambda_way {self.lambda_way}') 128 | #logx.msg(f'bce_loss {bce_loss} weight_lambda {self.weight_lambda} rmi_loss {rmi_loss}') 129 | if self.lambda_way: 130 | final_loss = self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) 131 | else: 132 | final_loss = bce_loss + rmi_loss * self.weight_lambda 133 | 134 | return 
final_loss 135 | 136 | def inverse(self, x): 137 | return torch.inverse(x) 138 | 139 | def rmi_lower_bound(self, labels_4D, probs_4D): 140 | """ 141 | calculate the lower bound of the region mutual information. 142 | Args: 143 | labels_4D : [N, C, H, W], dtype=float32 144 | probs_4D : [N, C, H, W], dtype=float32 145 | """ 146 | assert labels_4D.size() == probs_4D.size() 147 | 148 | p, s = self.rmi_pool_size, self.rmi_pool_stride 149 | if self.rmi_pool_stride > 1: 150 | if self.rmi_pool_way == 0: 151 | labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 152 | probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 153 | elif self.rmi_pool_way == 1: 154 | labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 155 | probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 156 | elif self.rmi_pool_way == 2: 157 | # interpolation 158 | shape = labels_4D.size() 159 | new_h, new_w = shape[2] // s, shape[3] // s 160 | labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest') 161 | probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True) 162 | else: 163 | raise NotImplementedError("Pool way of RMI is not defined!") 164 | # we do not need the gradient of label. 165 | label_shape = labels_4D.size() 166 | n, c = label_shape[0], label_shape[1] 167 | 168 | # combine the high dimension points from label and probability map. new shape [N, C, radius * radius, H, W] 169 | la_vectors, pr_vectors = rmi_utils.map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0) 170 | 171 | la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor).requires_grad_(False) 172 | pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor) 173 | 174 | # small diagonal matrix, shape = [1, 1, radius * radius, radius * radius] 175 | diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0) 176 | 177 | # the mean and covariance of these high dimension points 178 | # Var(X) = E(X^2) - E(X) E(X), N * Var(X) = X^2 - X E(X) 179 | la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True) 180 | la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3)) 181 | 182 | pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True) 183 | pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3)) 184 | # https://github.com/pytorch/pytorch/issues/7500 185 | # waiting for batched torch.cholesky_inverse() 186 | # pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) 187 | pr_cov_inv = self.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) 188 | # if the dimension of the point is less than 9, you can use the below function 189 | # to acceleration computational speed. 190 | #pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) 191 | 192 | la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3)) 193 | # the approxiamation of the variance, det(c A) = c^n det(A), A is in n x n shape; 194 | # then log det(c A) = n log(c) + log det(A). 195 | # appro_var = appro_var / n_points, we do not divide the appro_var by number of points here, 196 | # and the purpose is to avoid underflow issue. 197 | # If A = A^T, A^-1 = (A^-1)^T. 
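# The next line is the empirical conditional (Schur-complement) covariance of the label
# vectors given the prediction vectors: Var(Y|X) = Cov(Y) - Cov(Y,X) Cov(X)^-1 Cov(X,Y).
# Per the RMI paper, H(Y|X) is upper-bounded by the entropy of a Gaussian with this
# covariance, so minimizing 0.5 * log det(Var(Y|X)) (computed as rmi_now below)
# maximizes a lower bound on the region mutual information I(Y; X).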
198 | appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1)) 199 | #appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1)) 200 | #appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6 201 | 202 | # The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ). 203 | rmi_now = 0.5 * rmi_utils.log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) 204 | #rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) 205 | 206 | # mean over N samples. sum over classes. 207 | rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float() 208 | #is_half = False 209 | #if is_half: 210 | # rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0)) 211 | #else: 212 | rmi_per_class = torch.div(rmi_per_class, float(self.half_d)) 213 | 214 | rmi_loss = torch.sum(rmi_per_class) if _IS_SUM else torch.mean(rmi_per_class) 215 | return rmi_loss 216 | -------------------------------------------------------------------------------- /loss/rmi_utils.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from: https://github.com/ZJULearning/RMI 2 | 3 | # python 2.X, 3.X compatibility 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | __all__ = ['map_get_pairs', 'log_det_by_cholesky'] 13 | 14 | 15 | def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True): 16 | """get map pairs 17 | Args: 18 | labels_4D : labels, shape [N, C, H, W] 19 | probs_4D : probabilities, shape [N, C, H, W] 20 | radius : the square radius 21 | Return: 22 | tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)] 23 | """ 24 | # pad to ensure the following slice operation is valid 25 | #pad_beg = int(radius // 2) 26 | #pad_end = radius - pad_beg 27 | 28 | # the original height and width 29 | label_shape = labels_4D.size() 30 | h, w = label_shape[2], label_shape[3] 31 | new_h, new_w = h - (radius - 1), w - (radius - 1) 32 | # https://pytorch.org/docs/stable/nn.html?highlight=f%20pad#torch.nn.functional.pad 33 | #padding = (pad_beg, pad_end, pad_beg, pad_end) 34 | #labels_4D, probs_4D = F.pad(labels_4D, padding), F.pad(probs_4D, padding) 35 | 36 | # get the neighbors 37 | la_ns = [] 38 | pr_ns = [] 39 | #for x in range(0, radius, 1): 40 | for y in range(0, radius, 1): 41 | for x in range(0, radius, 1): 42 | la_now = labels_4D[:, :, y:y + new_h, x:x + new_w] 43 | pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w] 44 | la_ns.append(la_now) 45 | pr_ns.append(pr_now) 46 | 47 | if is_combine: 48 | # for calculating RMI 49 | pair_ns = la_ns + pr_ns 50 | p_vectors = torch.stack(pair_ns, dim=2) 51 | return p_vectors 52 | else: 53 | # for other purpose 54 | la_vectors = torch.stack(la_ns, dim=2) 55 | pr_vectors = torch.stack(pr_ns, dim=2) 56 | return la_vectors, pr_vectors 57 | 58 | 59 | def map_get_pairs_region(labels_4D, probs_4D, radius=3, is_combine=0, num_classeses=21): 60 | """get map pairs 61 | Args: 62 | labels_4D : labels, shape [N, C, H, W]. 63 | probs_4D : probabilities, shape [N, C, H, W]. 64 | radius : The side length of the square region. 
65 | Return: 66 | A tensor with shape [N, C, radiu * radius, H // radius, W // raidius] 67 | """ 68 | kernel = torch.zeros([num_classeses, 1, radius, radius]).type_as(probs_4D) 69 | padding = radius // 2 70 | # get the neighbours 71 | la_ns = [] 72 | pr_ns = [] 73 | for y in range(0, radius, 1): 74 | for x in range(0, radius, 1): 75 | kernel_now = kernel.clone() 76 | kernel_now[:, :, y, x] = 1.0 77 | la_now = F.conv2d(labels_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) 78 | pr_now = F.conv2d(probs_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) 79 | la_ns.append(la_now) 80 | pr_ns.append(pr_now) 81 | 82 | if is_combine: 83 | # for calculating RMI 84 | pair_ns = la_ns + pr_ns 85 | p_vectors = torch.stack(pair_ns, dim=2) 86 | return p_vectors 87 | else: 88 | # for other purpose 89 | la_vectors = torch.stack(la_ns, dim=2) 90 | pr_vectors = torch.stack(pr_ns, dim=2) 91 | return la_vectors, pr_vectors 92 | return 93 | 94 | 95 | def log_det_by_cholesky(matrix): 96 | """ 97 | Args: 98 | matrix: matrix must be a positive define matrix. 99 | shape [N, C, D, D]. 100 | Ref: 101 | https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py 102 | """ 103 | # This uses the property that the log det(A) = 2 * sum(log(real(diag(C)))) 104 | # where C is the cholesky decomposition of A. 105 | chol = torch.cholesky(matrix) 106 | #return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-6), dim=-1) 107 | return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-8), dim=-1) 108 | 109 | 110 | def batch_cholesky_inverse(matrix): 111 | """ 112 | Args: matrix, 4-D tensor, [N, C, M, M]. 113 | matrix must be a symmetric positive define matrix. 114 | """ 115 | chol_low = torch.cholesky(matrix, upper=False) 116 | chol_low_inv = batch_low_tri_inv(chol_low) 117 | return torch.matmul(chol_low_inv.transpose(-2, -1), chol_low_inv) 118 | 119 | 120 | def batch_low_tri_inv(L): 121 | """ 122 | Batched inverse of lower triangular matrices 123 | Args: 124 | L : a lower triangular matrix 125 | Ref: 126 | https://www.pugetsystems.com/labs/hpc/PyTorch-for-Scientific-Computing 127 | """ 128 | n = L.shape[-1] 129 | invL = torch.zeros_like(L) 130 | for j in range(0, n): 131 | invL[..., j, j] = 1.0 / L[..., j, j] 132 | for i in range(j + 1, n): 133 | S = 0.0 134 | for k in range(0, i + 1): 135 | S = S - L[..., i, k] * invL[..., k, j].clone() 136 | invL[..., i, j] = S / L[..., i, i] 137 | return invL 138 | 139 | 140 | def log_det_by_cholesky_test(): 141 | """ 142 | test for function log_det_by_cholesky() 143 | """ 144 | a = torch.randn(1, 4, 4) 145 | a = torch.matmul(a, a.transpose(2, 1)) 146 | print(a) 147 | res_1 = torch.logdet(torch.squeeze(a)) 148 | res_2 = log_det_by_cholesky(a) 149 | print(res_1, res_2) 150 | 151 | 152 | def batch_inv_test(): 153 | """ 154 | test for function batch_cholesky_inverse() 155 | """ 156 | a = torch.randn(1, 1, 4, 4) 157 | a = torch.matmul(a, a.transpose(-2, -1)) 158 | print(a) 159 | res_1 = torch.inverse(a) 160 | res_2 = batch_cholesky_inverse(a) 161 | print(res_1, '\n', res_2) 162 | 163 | 164 | def mean_var_test(): 165 | x = torch.randn(3, 4) 166 | y = torch.randn(3, 4) 167 | 168 | x_mean = x.mean(dim=1, keepdim=True) 169 | x_sum = x.sum(dim=1, keepdim=True) / 2.0 170 | y_mean = y.mean(dim=1, keepdim=True) 171 | y_sum = y.sum(dim=1, keepdim=True) / 2.0 172 | 173 | x_var_1 = torch.matmul(x - x_mean, (x - x_mean).t()) 174 | x_var_2 = torch.matmul(x, x.t()) - torch.matmul(x_sum, 
x_sum.t()) 175 | xy_cov = torch.matmul(x - x_mean, (y - y_mean).t()) 176 | xy_cov_1 = torch.matmul(x, y.t()) - x_sum.matmul(y_sum.t()) 177 | 178 | print(x_var_1) 179 | print(x_var_2) 180 | 181 | print(xy_cov, '\n', xy_cov_1) 182 | 183 | 184 | if __name__ == '__main__': 185 | batch_inv_test() 186 | -------------------------------------------------------------------------------- /network/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | """ 35 | 36 | import torch.nn as nn 37 | import torch.utils.model_zoo as model_zoo 38 | import network.mynn as mynn 39 | 40 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 41 | 'resnet152'] 42 | 43 | 44 | model_urls = { 45 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 46 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 47 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 48 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 49 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 50 | } 51 | 52 | 53 | def conv3x3(in_planes, out_planes, stride=1): 54 | """3x3 convolution with padding""" 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | """ 61 | Basic Block for Resnet 62 | """ 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(BasicBlock, self).__init__() 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = mynn.Norm2d(planes) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(planes, planes) 71 | self.bn2 = mynn.Norm2d(planes) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Bottleneck(nn.Module): 95 | """ 96 | Bottleneck Layer for Resnet 97 | """ 98 | expansion = 4 99 | 100 | def __init__(self, inplanes, planes, stride=1, downsample=None): 101 | super(Bottleneck, self).__init__() 102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 103 | self.bn1 = mynn.Norm2d(planes) 104 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 105 | padding=1, bias=False) 106 | self.bn2 = mynn.Norm2d(planes) 107 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 108 | self.bn3 = mynn.Norm2d(planes * self.expansion) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out += residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | """ 138 | Resnet Global Module for Initialization 139 | """ 140 | def __init__(self, block, layers, num_classes=1000): 141 | self.inplanes = 64 142 | super(ResNet, self).__init__() 143 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = mynn.Norm2d(64) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 151 | self.layer4 = 
self._make_layer(block, 512, layers[3], stride=2) 152 | self.avgpool = nn.AvgPool2d(7, stride=1) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, nn.BatchNorm2d): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1): 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | nn.Conv2d(self.inplanes, planes * block.expansion, 167 | kernel_size=1, stride=stride, bias=False), 168 | mynn.Norm2d(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample)) 173 | self.inplanes = planes * block.expansion 174 | for index in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | 190 | x = self.avgpool(x) 191 | x = x.view(x.size(0), -1) 192 | x = self.fc(x) 193 | 194 | return x 195 | 196 | 197 | def resnet18(pretrained=True, **kwargs): 198 | """Constructs a ResNet-18 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 204 | if pretrained: 205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 206 | return model 207 | 208 | 209 | def resnet34(pretrained=True, **kwargs): 210 | """Constructs a ResNet-34 model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 218 | return model 219 | 220 | 221 | def resnet50(pretrained=True, **kwargs): 222 | """Constructs a ResNet-50 model. 223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | """ 227 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 228 | if pretrained: 229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 230 | return model 231 | 232 | 233 | def resnet101(pretrained=True, **kwargs): 234 | """Constructs a ResNet-101 model. 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | """ 239 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 240 | if pretrained: 241 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 242 | return model 243 | 244 | 245 | def resnet152(pretrained=True, **kwargs): 246 | """Constructs a ResNet-152 model. 
247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 252 | if pretrained: 253 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 254 | return model 255 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network Initializations 3 | """ 4 | 5 | import importlib 6 | import torch 7 | 8 | from runx.logx import logx 9 | from config import cfg 10 | 11 | 12 | def get_net(args, criterion): 13 | """ 14 | Get Network Architecture based on arguments provided 15 | """ 16 | net = get_model(network='network.' + args.arch, 17 | num_classes=cfg.DATASET.NUM_CLASSES, 18 | criterion=criterion) 19 | num_params = sum([param.nelement() for param in net.parameters()]) 20 | logx.msg('Model params = {:2.1f}M'.format(num_params / 1000000)) 21 | 22 | net = net.cuda() 23 | return net 24 | 25 | 26 | def is_gscnn_arch(args): 27 | """ 28 | Network is a GSCNN network 29 | """ 30 | return 'gscnn' in args.arch 31 | 32 | 33 | def wrap_network_in_dataparallel(net, use_apex_data_parallel=False): 34 | """ 35 | Wrap the network in Dataparallel 36 | """ 37 | if use_apex_data_parallel: 38 | import apex 39 | net = apex.parallel.DistributedDataParallel(net) 40 | else: 41 | net = torch.nn.DataParallel(net) 42 | return net 43 | 44 | 45 | def get_model(network, num_classes, criterion): 46 | """ 47 | Fetch Network Function Pointer 48 | """ 49 | module = network[:network.rfind('.')] 50 | model = network[network.rfind('.') + 1:] 51 | mod = importlib.import_module(module) 52 | net_func = getattr(mod, model) 53 | net = net_func(num_classes=num_classes, criterion=criterion) 54 | return net 55 | -------------------------------------------------------------------------------- /network/basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | from torch import nn 31 | 32 | from network.mynn import initialize_weights, Upsample 33 | from network.mynn import scale_as 34 | from network.utils import get_aspp, get_trunk, make_seg_head 35 | from config import cfg 36 | 37 | 38 | class Basic(nn.Module): 39 | """ 40 | Basic segmentation network, no ASPP, no Mscale 41 | """ 42 | def __init__(self, num_classes, trunk='hrnetv2', criterion=None): 43 | super(Basic, self).__init__() 44 | self.criterion = criterion 45 | self.backbone, _, _, high_level_ch = get_trunk( 46 | trunk_name=trunk, output_stride=8) 47 | self.seg_head = make_seg_head(in_ch=high_level_ch, 48 | out_ch=num_classes) 49 | initialize_weights(self.seg_head) 50 | 51 | def forward(self, inputs): 52 | x = inputs['images'] 53 | _, _, final_features = self.backbone(x) 54 | pred = self.seg_head(final_features) 55 | pred = scale_as(pred, x) 56 | 57 | if self.training: 58 | assert 'gts' in inputs 59 | gts = inputs['gts'] 60 | loss = self.criterion(pred, gts) 61 | return loss 62 | else: 63 | output_dict = {'pred': pred} 64 | return output_dict 65 | 66 | 67 | class ASPP(nn.Module): 68 | """ 69 | ASPP-based Segmentation network 70 | """ 71 | def __init__(self, num_classes, trunk='hrnetv2', criterion=None): 72 | super(ASPP, self).__init__() 73 | self.criterion = criterion 74 | self.backbone, _, _, high_level_ch = get_trunk(trunk) 75 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 76 | bottleneck_ch=cfg.MODEL.ASPP_BOT_CH, 77 | output_stride=8) 78 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 79 | self.final = make_seg_head(in_ch=256, 80 | out_ch=num_classes) 81 | 82 | initialize_weights(self.final, self.bot_aspp, self.aspp) 83 | 84 | def forward(self, inputs): 85 | x = inputs['images'] 86 | x_size = x.size() 87 | 88 | _, _, final_features = self.backbone(x) 89 | aspp = self.aspp(final_features) 90 | aspp = self.bot_aspp(aspp) 91 | pred = self.final(aspp) 92 | pred = Upsample(pred, x_size[2:]) 93 | 94 | if self.training: 95 | assert 'gts' in inputs 96 | gts = inputs['gts'] 97 | loss = self.criterion(pred, gts) 98 | return loss 99 | else: 100 | output_dict = {'pred': pred} 101 | return output_dict 102 | 103 | 104 | def HRNet(num_classes, criterion, s2s4=None): 105 | return Basic(num_classes=num_classes, criterion=criterion, 106 | trunk='hrnetv2') 107 | 108 | 109 | def HRNet_ASP(num_classes, criterion, s2s4=None): 110 | return ASPP(num_classes=num_classes, criterion=criterion, 111 | trunk='hrnetv2') 112 | -------------------------------------------------------------------------------- /network/bn_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | 4 | if torch.__version__.startswith('0'): 5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync 6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 7 | BatchNorm2d_class = InPlaceABNSync 8 | relu_inplace = False 9 | else: 10 | BatchNorm2d_class = BatchNorm2d = 
torch.nn.SyncBatchNorm 11 | relu_inplace = True 12 | -------------------------------------------------------------------------------- /network/deeper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
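# ---------------------------------------------------------------------------
# Minimal sketch of the pattern used by network/bn_helper.py above: pick the
# BatchNorm implementation once, at import time, based on the installed torch
# version, and use functools.partial to pre-bind constructor arguments so the
# result can be called like an ordinary layer class. nn.BatchNorm2d stands in
# here for the repo's InPlaceABNSync / SyncBatchNorm choices.
import functools
import torch
from torch import nn

if torch.__version__.startswith('0'):
    NormLayer = functools.partial(nn.BatchNorm2d, momentum=0.01)  # kwargs frozen up front
else:
    NormLayer = nn.SyncBatchNorm

bn = NormLayer(64)            # constructed exactly like a plain layer class
print(type(bn).__name__)      # SyncBatchNorm on torch >= 1.0
# ---------------------------------------------------------------------------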
29 | """ 30 | import torch 31 | from torch import nn 32 | from network.mynn import Upsample2 33 | from network.utils import ConvBnRelu, get_trunk, get_aspp 34 | 35 | 36 | class DeeperS8(nn.Module): 37 | """ 38 | Panoptic DeepLab-style semantic segmentation network 39 | stride8 only 40 | """ 41 | def __init__(self, num_classes, trunk='wrn38', criterion=None): 42 | super(DeeperS8, self).__init__() 43 | 44 | self.criterion = criterion 45 | self.trunk, s2_ch, s4_ch, high_level_ch = get_trunk(trunk_name=trunk, 46 | output_stride=8) 47 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, bottleneck_ch=256, 48 | output_stride=8) 49 | 50 | self.convs2 = nn.Conv2d(s2_ch, 32, kernel_size=1, bias=False) 51 | self.convs4 = nn.Conv2d(s4_ch, 64, kernel_size=1, bias=False) 52 | self.conv_up1 = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 53 | self.conv_up2 = ConvBnRelu(256 + 64, 256, kernel_size=5, padding=2) 54 | self.conv_up3 = ConvBnRelu(256 + 32, 256, kernel_size=5, padding=2) 55 | self.conv_up5 = nn.Conv2d(256, num_classes, kernel_size=1, bias=False) 56 | 57 | def forward(self, inputs, gts=None): 58 | assert 'images' in inputs 59 | x = inputs['images'] 60 | 61 | s2_features, s4_features, final_features = self.trunk(x) 62 | s2_features = self.convs2(s2_features) 63 | s4_features = self.convs4(s4_features) 64 | aspp = self.aspp(final_features) 65 | x = self.conv_up1(aspp) 66 | x = Upsample2(x) 67 | x = torch.cat([x, s4_features], 1) 68 | x = self.conv_up2(x) 69 | x = Upsample2(x) 70 | x = torch.cat([x, s2_features], 1) 71 | x = self.conv_up3(x) 72 | x = self.conv_up5(x) 73 | x = Upsample2(x) 74 | 75 | if self.training: 76 | assert 'gts' in inputs 77 | gts = inputs['gts'] 78 | return self.criterion(x, gts) 79 | return {'pred': x} 80 | 81 | 82 | def DeeperW38(num_classes, criterion, s2s4=True): 83 | return DeeperS8(num_classes, criterion=criterion, trunk='wrn38') 84 | 85 | 86 | def DeeperX71(num_classes, criterion, s2s4=True): 87 | return DeeperS8(num_classes, criterion=criterion, trunk='xception71') 88 | 89 | 90 | def DeeperEffB4(num_classes, criterion, s2s4=True): 91 | return DeeperS8(num_classes, criterion=criterion, trunk='efficientnet_b4') 92 | -------------------------------------------------------------------------------- /network/deepv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from: 3 | https://github.com/sthalles/deeplab_v3 4 | 5 | Copyright 2020 Nvidia Corporation 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its contributors 18 | may be used to endorse or promote products derived from this software 19 | without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | ARE DISCLAIMED. 
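# ---------------------------------------------------------------------------
# Shape walkthrough for the DeeperS8 decoder in network/deeper.py above,
# assuming a 256x256 input, 19 output classes (Cityscapes) and the channel
# widths hard-coded there (32 / 64 skip channels, 256 after conv_up1). Random
# tensors and plain Conv2d layers stand in for the trunk, ASPP and ConvBnRelu.
import torch
from torch import nn
import torch.nn.functional as F

def up2(x):  # stand-in for network.mynn.Upsample2
    return F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)

s2 = torch.randn(1, 32, 128, 128)    # convs2 output, stride 2
s4 = torch.randn(1, 64, 64, 64)      # convs4 output, stride 4
x = torch.randn(1, 256, 32, 32)      # conv_up1(aspp) output, stride 8

conv_up2 = nn.Conv2d(256 + 64, 256, 5, padding=2)
conv_up3 = nn.Conv2d(256 + 32, 256, 5, padding=2)
conv_up5 = nn.Conv2d(256, 19, 1)

x = conv_up2(torch.cat([up2(x), s4], 1))   # stride 8 -> 4, fuse the stride-4 skip
x = conv_up3(torch.cat([up2(x), s2], 1))   # stride 4 -> 2, fuse the stride-2 skip
x = up2(conv_up5(x))                       # per-class logits at full resolution
print(x.shape)                             # torch.Size([1, 19, 256, 256])
# ---------------------------------------------------------------------------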
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | POSSIBILITY OF SUCH DAMAGE. 32 | """ 33 | import torch 34 | from torch import nn 35 | 36 | from network.mynn import initialize_weights, Norm2d, Upsample 37 | from network.utils import get_aspp, get_trunk, make_seg_head 38 | 39 | 40 | class DeepV3Plus(nn.Module): 41 | """ 42 | DeepLabV3+ with various trunks supported 43 | Always stride8 44 | """ 45 | def __init__(self, num_classes, trunk='wrn38', criterion=None, 46 | use_dpc=False, init_all=False): 47 | super(DeepV3Plus, self).__init__() 48 | self.criterion = criterion 49 | self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk) 50 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 51 | bottleneck_ch=256, 52 | output_stride=8, 53 | dpc=use_dpc) 54 | self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False) 55 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 56 | self.final = nn.Sequential( 57 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 58 | Norm2d(256), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 61 | Norm2d(256), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 64 | 65 | if init_all: 66 | initialize_weights(self.aspp) 67 | initialize_weights(self.bot_aspp) 68 | initialize_weights(self.bot_fine) 69 | initialize_weights(self.final) 70 | else: 71 | initialize_weights(self.final) 72 | 73 | def forward(self, inputs): 74 | assert 'images' in inputs 75 | x = inputs['images'] 76 | 77 | x_size = x.size() 78 | s2_features, _, final_features = self.backbone(x) 79 | aspp = self.aspp(final_features) 80 | conv_aspp = self.bot_aspp(aspp) 81 | conv_s2 = self.bot_fine(s2_features) 82 | conv_aspp = Upsample(conv_aspp, s2_features.size()[2:]) 83 | cat_s4 = [conv_s2, conv_aspp] 84 | cat_s4 = torch.cat(cat_s4, 1) 85 | final = self.final(cat_s4) 86 | out = Upsample(final, x_size[2:]) 87 | 88 | if self.training: 89 | assert 'gts' in inputs 90 | gts = inputs['gts'] 91 | return self.criterion(out, gts) 92 | 93 | return {'pred': out} 94 | 95 | 96 | def DeepV3PlusSRNX50(num_classes, criterion): 97 | return DeepV3Plus(num_classes, trunk='seresnext-50', criterion=criterion) 98 | 99 | 100 | def DeepV3PlusR50(num_classes, criterion): 101 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion) 102 | 103 | 104 | def DeepV3PlusSRNX101(num_classes, criterion): 105 | return DeepV3Plus(num_classes, trunk='seresnext-101', criterion=criterion) 106 | 107 | 108 | def DeepV3PlusW38(num_classes, criterion): 109 | return DeepV3Plus(num_classes, trunk='wrn38', criterion=criterion) 110 | 111 | 112 | def DeepV3PlusW38I(num_classes, criterion): 113 | return DeepV3Plus(num_classes, trunk='wrn38', criterion=criterion, 114 | init_all=True) 115 | 116 | 117 | def DeepV3PlusX71(num_classes, criterion): 118 | return DeepV3Plus(num_classes, trunk='xception71', criterion=criterion) 119 | 120 | 121 | def DeepV3PlusEffB4(num_classes, criterion): 122 | return DeepV3Plus(num_classes, trunk='efficientnet_b4', 123 | 
criterion=criterion) 124 | 125 | 126 | class DeepV3(nn.Module): 127 | """ 128 | DeepLabV3 with various trunks supported 129 | """ 130 | def __init__(self, num_classes, trunk='resnet-50', criterion=None, 131 | use_dpc=False, init_all=False, output_stride=8): 132 | super(DeepV3, self).__init__() 133 | self.criterion = criterion 134 | 135 | self.backbone, _s2_ch, _s4_ch, high_level_ch = \ 136 | get_trunk(trunk, output_stride=output_stride) 137 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 138 | bottleneck_ch=256, 139 | output_stride=output_stride, 140 | dpc=use_dpc) 141 | self.final = make_seg_head(in_ch=aspp_out_ch, out_ch=num_classes) 142 | 143 | initialize_weights(self.aspp) 144 | initialize_weights(self.final) 145 | 146 | def forward(self, inputs): 147 | assert 'images' in inputs 148 | x = inputs['images'] 149 | 150 | x_size = x.size() 151 | _, _, final_features = self.backbone(x) 152 | aspp = self.aspp(final_features) 153 | final = self.final(aspp) 154 | out = Upsample(final, x_size[2:]) 155 | 156 | if self.training: 157 | assert 'gts' in inputs 158 | gts = inputs['gts'] 159 | return self.criterion(out, gts) 160 | 161 | return {'pred': out} 162 | 163 | 164 | def DeepV3R50(num_classes, criterion): 165 | return DeepV3(num_classes, trunk='resnet-50', criterion=criterion) 166 | 167 | -------------------------------------------------------------------------------- /network/mscale2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | This is an alternative implementation of mscale, where we feed pairs of 32 | features from both lower and higher resolution images into the attention head. 
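# ---------------------------------------------------------------------------
# Minimal sketch of the pairwise-feature attention described in the docstring
# above (see MscaleBase.two_scale_forward below): features from the 0.5x and
# 1.0x images are concatenated and fed to one attention head, and the
# resulting mask blends the two per-scale predictions. A single 1x1 conv plus
# sigmoid stands in for the deeper scale_attn head; 19 classes is assumed.
import torch
from torch import nn
import torch.nn.functional as F

def scale_to(x, ref):  # stand-in for network.mynn.scale_as
    return F.interpolate(x, size=ref.shape[2:], mode='bilinear', align_corners=False)

feats_lo = torch.randn(1, 256, 32, 32)   # features from the 0.5x image
feats_hi = torch.randn(1, 256, 64, 64)   # features from the 1.0x image
p_lo = torch.randn(1, 19, 32, 32)        # logits from the 0.5x image
p_1x = torch.randn(1, 19, 64, 64)        # logits from the 1.0x image

attn_head = nn.Sequential(nn.Conv2d(2 * 256, 1, kernel_size=1), nn.Sigmoid())

attn = attn_head(torch.cat([feats_lo, scale_to(feats_hi, feats_lo)], 1))
joint = scale_to(attn * p_lo, p_1x) + (1 - scale_to(attn, p_1x)) * p_1x
print(joint.shape)                       # torch.Size([1, 19, 64, 64])
# ---------------------------------------------------------------------------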
33 | """ 34 | import torch 35 | from torch import nn 36 | 37 | from network.mynn import initialize_weights, Norm2d, Upsample 38 | from network.mynn import ResizeX, scale_as 39 | from network.utils import get_aspp, get_trunk 40 | from network.utils import make_seg_head, make_attn_head 41 | from config import cfg 42 | 43 | 44 | class MscaleBase(nn.Module): 45 | """ 46 | Multi-scale attention segmentation model base class 47 | """ 48 | def __init__(self): 49 | super(MscaleBase, self).__init__() 50 | self.criterion = None 51 | 52 | def _fwd(self, x): 53 | pass 54 | 55 | def nscale_forward(self, inputs, scales): 56 | """ 57 | Hierarchical attention, primarily used for getting best inference 58 | results. 59 | 60 | We use attention at multiple scales, giving priority to the lower 61 | resolutions. For example, if we have 4 scales {0.5, 1.0, 1.5, 2.0}, 62 | then evaluation is done as follows: 63 | 64 | p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0) 65 | p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint) 66 | p_joint = up(attn_0.5 * p_0.5) * (1 - up(attn_0.5)) * p_joint 67 | 68 | The target scale is always 1.0, and 1.0 is expected to be part of the 69 | list of scales. When predictions are done at greater than 1.0 scale, 70 | the predictions are downsampled before combining with the next lower 71 | scale. 72 | 73 | Inputs: 74 | scales - a list of scales to evaluate 75 | inputs - dict containing 'images', the input, and 'gts', the ground 76 | truth mask 77 | 78 | Output: 79 | If training, return loss, else return prediction + attention 80 | """ 81 | x_1x = inputs['images'] 82 | 83 | assert 1.0 in scales, 'expected 1.0 to be the target scale' 84 | # Lower resolution provides attention for higher rez predictions, 85 | # so we evaluate in order: high to low 86 | scales = sorted(scales, reverse=True) 87 | pred = None 88 | last_feats = None 89 | 90 | for idx, s in enumerate(scales): 91 | x = ResizeX(x_1x, s) 92 | p, feats = self._fwd(x) 93 | 94 | # Generate attention prediction 95 | if idx > 0: 96 | assert last_feats is not None 97 | # downscale feats 98 | last_feats = scale_as(last_feats, feats) 99 | cat_feats = torch.cat([feats, last_feats], 1) 100 | attn = self.scale_attn(cat_feats) 101 | attn = scale_as(attn, p) 102 | 103 | if pred is None: 104 | # This is the top scale prediction 105 | pred = p 106 | elif s >= 1.0: 107 | # downscale previous 108 | pred = scale_as(pred, p) 109 | pred = attn * p + (1 - attn) * pred 110 | else: 111 | # upscale current 112 | p = attn * p 113 | p = scale_as(p, pred) 114 | attn = scale_as(attn, pred) 115 | pred = p + (1 - attn) * pred 116 | 117 | last_feats = feats 118 | 119 | if self.training: 120 | assert 'gts' in inputs 121 | gts = inputs['gts'] 122 | loss = self.criterion(pred, gts) 123 | return loss 124 | else: 125 | # FIXME: should add multi-scale values for pred and attn 126 | return {'pred': pred, 127 | 'attn_10x': attn} 128 | 129 | def two_scale_forward(self, inputs): 130 | assert 'images' in inputs 131 | 132 | x_1x = inputs['images'] 133 | x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE) 134 | 135 | p_lo, feats_lo = self._fwd(x_lo) 136 | p_1x, feats_hi = self._fwd(x_1x) 137 | 138 | feats_hi = scale_as(feats_hi, feats_lo) 139 | cat_feats = torch.cat([feats_lo, feats_hi], 1) 140 | logit_attn = self.scale_attn(cat_feats) 141 | logit_attn = scale_as(logit_attn, p_lo) 142 | 143 | p_lo = logit_attn * p_lo 144 | p_lo = scale_as(p_lo, p_1x) 145 | logit_attn = scale_as(logit_attn, p_1x) 146 | joint_pred = p_lo + (1 - logit_attn) * p_1x 147 | 148 | if 
self.training: 149 | assert 'gts' in inputs 150 | gts = inputs['gts'] 151 | loss = self.criterion(joint_pred, gts) 152 | return loss 153 | else: 154 | # FIXME: should add multi-scale values for pred and attn 155 | return {'pred': joint_pred, 156 | 'attn_10x': logit_attn} 157 | 158 | def forward(self, inputs): 159 | if cfg.MODEL.N_SCALES and not self.training: 160 | return self.nscale_forward(inputs, cfg.MODEL.N_SCALES) 161 | 162 | return self.two_scale_forward(inputs) 163 | 164 | 165 | class MscaleV3Plus(MscaleBase): 166 | """ 167 | DeepLabV3Plus-based mscale segmentation model 168 | """ 169 | def __init__(self, num_classes, trunk='wrn38', criterion=None): 170 | super(MscaleV3Plus, self).__init__() 171 | self.criterion = criterion 172 | self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk) 173 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 174 | bottleneck_ch=256, 175 | output_stride=8) 176 | self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False) 177 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 178 | 179 | # Semantic segmentation prediction head 180 | self.final = nn.Sequential( 181 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 182 | Norm2d(256), 183 | nn.ReLU(inplace=True), 184 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 185 | Norm2d(256), 186 | nn.ReLU(inplace=True), 187 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 188 | 189 | # Scale-attention prediction head 190 | scale_in_ch = 2 * (256 + 48) 191 | 192 | self.scale_attn = nn.Sequential( 193 | nn.Conv2d(scale_in_ch, 256, kernel_size=3, padding=1, bias=False), 194 | Norm2d(256), 195 | nn.ReLU(inplace=True), 196 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 197 | Norm2d(256), 198 | nn.ReLU(inplace=True), 199 | nn.Conv2d(256, 1, kernel_size=1, bias=False), 200 | nn.Sigmoid()) 201 | 202 | if cfg.OPTIONS.INIT_DECODER: 203 | initialize_weights(self.bot_fine) 204 | initialize_weights(self.bot_aspp) 205 | initialize_weights(self.scale_attn) 206 | initialize_weights(self.final) 207 | else: 208 | initialize_weights(self.final) 209 | 210 | def _fwd(self, x): 211 | x_size = x.size() 212 | s2_features, _, final_features = self.backbone(x) 213 | aspp = self.aspp(final_features) 214 | 215 | conv_aspp = self.bot_aspp(aspp) 216 | conv_s2 = self.bot_fine(s2_features) 217 | conv_aspp = Upsample(conv_aspp, s2_features.size()[2:]) 218 | cat_s4 = [conv_s2, conv_aspp] 219 | cat_s4 = torch.cat(cat_s4, 1) 220 | 221 | final = self.final(cat_s4) 222 | out = Upsample(final, x_size[2:]) 223 | 224 | return out, cat_s4 225 | 226 | 227 | def DeepV3R50(num_classes, criterion): 228 | return MscaleV3Plus(num_classes, trunk='resnet-50', criterion=criterion) 229 | 230 | 231 | class Basic(MscaleBase): 232 | """ 233 | """ 234 | def __init__(self, num_classes, trunk='hrnetv2', criterion=None): 235 | super(Basic, self).__init__() 236 | self.criterion = criterion 237 | self.backbone, _, _, high_level_ch = get_trunk( 238 | trunk_name=trunk, output_stride=8) 239 | 240 | self.cls_head = make_seg_head(in_ch=high_level_ch, bot_ch=256, 241 | out_ch=num_classes) 242 | self.scale_attn = make_attn_head(in_ch=high_level_ch * 2, bot_ch=256, 243 | out_ch=1) 244 | 245 | def two_scale_forward(self, inputs): 246 | assert 'images' in inputs 247 | 248 | x_1x = inputs['images'] 249 | x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE) 250 | 251 | p_lo, feats_lo = self._fwd(x_lo) 252 | p_1x, feats_hi = self._fwd(x_1x) 253 | 254 | feats_lo = scale_as(feats_lo, feats_hi) 255 | cat_feats = 
torch.cat([feats_lo, feats_hi], 1) 256 | logit_attn = self.scale_attn(cat_feats) 257 | logit_attn_lo = scale_as(logit_attn, p_lo) 258 | logit_attn_1x = scale_as(logit_attn, p_1x) 259 | 260 | p_lo = logit_attn_lo * p_lo 261 | p_lo = scale_as(p_lo, p_1x) 262 | joint_pred = p_lo + (1 - logit_attn_1x) * p_1x 263 | 264 | if self.training: 265 | assert 'gts' in inputs 266 | gts = inputs['gts'] 267 | loss = self.criterion(joint_pred, gts) 268 | return loss 269 | else: 270 | return joint_pred, logit_attn_1x 271 | 272 | def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None): 273 | _, _, final_features = self.backbone(x) 274 | pred = self.cls_head(final_features) 275 | pred = scale_as(pred, x) 276 | 277 | return pred, final_features 278 | 279 | 280 | def HRNet(num_classes, criterion, s2s4=None): 281 | return Basic(num_classes=num_classes, criterion=criterion, 282 | trunk='hrnetv2') 283 | -------------------------------------------------------------------------------- /network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight 3 | initialization 4 | """ 5 | import re 6 | import torch 7 | import torch.nn as nn 8 | from config import cfg 9 | 10 | from apex import amp 11 | 12 | from runx.logx import logx 13 | 14 | 15 | align_corners = cfg.MODEL.ALIGN_CORNERS 16 | 17 | 18 | def Norm2d(in_channels, **kwargs): 19 | """ 20 | Custom Norm Function to allow flexible switching 21 | """ 22 | layer = getattr(cfg.MODEL, 'BNFUNC') 23 | normalization_layer = layer(in_channels, **kwargs) 24 | return normalization_layer 25 | 26 | 27 | def initialize_weights(*models): 28 | """ 29 | Initialize Model Weights 30 | """ 31 | for model in models: 32 | for module in model.modules(): 33 | if isinstance(module, (nn.Conv2d, nn.Linear)): 34 | nn.init.kaiming_normal_(module.weight) 35 | if module.bias is not None: 36 | module.bias.data.zero_() 37 | elif isinstance(module, cfg.MODEL.BNFUNC): 38 | module.weight.data.fill_(1) 39 | module.bias.data.zero_() 40 | 41 | 42 | @amp.float_function 43 | def Upsample(x, size): 44 | """ 45 | Wrapper Around the Upsample Call 46 | """ 47 | return nn.functional.interpolate(x, size=size, mode='bilinear', 48 | align_corners=align_corners) 49 | 50 | 51 | @amp.float_function 52 | def Upsample2(x): 53 | """ 54 | Wrapper Around the Upsample Call 55 | """ 56 | return nn.functional.interpolate(x, scale_factor=2, mode='bilinear', 57 | align_corners=align_corners) 58 | 59 | 60 | def Down2x(x): 61 | return torch.nn.functional.interpolate( 62 | x, scale_factor=0.5, mode='bilinear', align_corners=align_corners) 63 | 64 | 65 | def Up15x(x): 66 | return torch.nn.functional.interpolate( 67 | x, scale_factor=1.5, mode='bilinear', align_corners=align_corners) 68 | 69 | 70 | def scale_as(x, y): 71 | ''' 72 | scale x to the same size as y 73 | ''' 74 | y_size = y.size(2), y.size(3) 75 | 76 | if cfg.OPTIONS.TORCH_VERSION >= 1.5: 77 | x_scaled = torch.nn.functional.interpolate( 78 | x, size=y_size, mode='bilinear', 79 | align_corners=align_corners) 80 | else: 81 | x_scaled = torch.nn.functional.interpolate( 82 | x, size=y_size, mode='bilinear', 83 | align_corners=align_corners) 84 | return x_scaled 85 | 86 | 87 | def DownX(x, scale_factor): 88 | ''' 89 | scale x to the same size as y 90 | ''' 91 | if cfg.OPTIONS.TORCH_VERSION >= 1.5: 92 | x_scaled = torch.nn.functional.interpolate( 93 | x, scale_factor=scale_factor, mode='bilinear', 94 | align_corners=align_corners, 
recompute_scale_factor=True) 95 | else: 96 | x_scaled = torch.nn.functional.interpolate( 97 | x, scale_factor=scale_factor, mode='bilinear', 98 | align_corners=align_corners) 99 | return x_scaled 100 | 101 | 102 | def ResizeX(x, scale_factor): 103 | ''' 104 | scale x by some factor 105 | ''' 106 | if cfg.OPTIONS.TORCH_VERSION >= 1.5: 107 | x_scaled = torch.nn.functional.interpolate( 108 | x, scale_factor=scale_factor, mode='bilinear', 109 | align_corners=align_corners, recompute_scale_factor=True) 110 | else: 111 | x_scaled = torch.nn.functional.interpolate( 112 | x, scale_factor=scale_factor, mode='bilinear', 113 | align_corners=align_corners) 114 | return x_scaled 115 | -------------------------------------------------------------------------------- /network/ocr_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn), Jingyi Xie (hsfzxjy@gmail.com) 5 | # 6 | # This code is from: https://github.com/HRNet/HRNet-Semantic-Segmentation 7 | # ------------------------------------------------------------------------------ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from config import cfg 14 | from network.utils import BNReLU 15 | 16 | 17 | class SpatialGather_Module(nn.Module): 18 | """ 19 | Aggregate the context features according to the initial 20 | predicted probability distribution. 21 | Employ the soft-weighted method to aggregate the context. 22 | 23 | Output: 24 | The correlation of every class map with every feature map 25 | shape = [n, num_feats, num_classes, 1] 26 | 27 | 28 | """ 29 | def __init__(self, cls_num=0, scale=1): 30 | super(SpatialGather_Module, self).__init__() 31 | self.cls_num = cls_num 32 | self.scale = scale 33 | 34 | def forward(self, feats, probs): 35 | batch_size, c, _, _ = probs.size(0), probs.size(1), probs.size(2), \ 36 | probs.size(3) 37 | 38 | # each class image now a vector 39 | probs = probs.view(batch_size, c, -1) 40 | feats = feats.view(batch_size, feats.size(1), -1) 41 | 42 | feats = feats.permute(0, 2, 1) # batch x hw x c 43 | probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw 44 | ocr_context = torch.matmul(probs, feats) 45 | ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3) 46 | return ocr_context 47 | 48 | 49 | class ObjectAttentionBlock(nn.Module): 50 | ''' 51 | The basic implementation for object context block 52 | Input: 53 | N X C X H X W 54 | Parameters: 55 | in_channels : the dimension of the input feature map 56 | key_channels : the dimension after the key/query transform 57 | scale : choose the scale to downsample the input feature 58 | maps (save memory cost) 59 | Return: 60 | N X C X H X W 61 | ''' 62 | def __init__(self, in_channels, key_channels, scale=1): 63 | super(ObjectAttentionBlock, self).__init__() 64 | self.scale = scale 65 | self.in_channels = in_channels 66 | self.key_channels = key_channels 67 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) 68 | self.f_pixel = nn.Sequential( 69 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 70 | kernel_size=1, stride=1, padding=0, bias=False), 71 | BNReLU(self.key_channels), 72 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, 73 | kernel_size=1, stride=1, padding=0, bias=False), 74 | BNReLU(self.key_channels), 75 | ) 76 | self.f_object = 
nn.Sequential( 77 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 78 | kernel_size=1, stride=1, padding=0, bias=False), 79 | BNReLU(self.key_channels), 80 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, 81 | kernel_size=1, stride=1, padding=0, bias=False), 82 | BNReLU(self.key_channels), 83 | ) 84 | self.f_down = nn.Sequential( 85 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 86 | kernel_size=1, stride=1, padding=0, bias=False), 87 | BNReLU(self.key_channels), 88 | ) 89 | self.f_up = nn.Sequential( 90 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, 91 | kernel_size=1, stride=1, padding=0, bias=False), 92 | BNReLU(self.in_channels), 93 | ) 94 | 95 | def forward(self, x, proxy): 96 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 97 | if self.scale > 1: 98 | x = self.pool(x) 99 | 100 | query = self.f_pixel(x).view(batch_size, self.key_channels, -1) 101 | query = query.permute(0, 2, 1) 102 | key = self.f_object(proxy).view(batch_size, self.key_channels, -1) 103 | value = self.f_down(proxy).view(batch_size, self.key_channels, -1) 104 | value = value.permute(0, 2, 1) 105 | 106 | sim_map = torch.matmul(query, key) 107 | sim_map = (self.key_channels**-.5) * sim_map 108 | sim_map = F.softmax(sim_map, dim=-1) 109 | 110 | # add bg context ... 111 | context = torch.matmul(sim_map, value) 112 | context = context.permute(0, 2, 1).contiguous() 113 | context = context.view(batch_size, self.key_channels, *x.size()[2:]) 114 | context = self.f_up(context) 115 | if self.scale > 1: 116 | context = F.interpolate(input=context, size=(h, w), mode='bilinear', 117 | align_corners=cfg.MODEL.ALIGN_CORNERS) 118 | 119 | return context 120 | 121 | 122 | class SpatialOCR_Module(nn.Module): 123 | """ 124 | Implementation of the OCR module: 125 | We aggregate the global object representation to update the representation 126 | for each pixel. 
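# ---------------------------------------------------------------------------
# Shape sketch for SpatialGather_Module defined earlier in this file: the
# coarse per-class probabilities act as soft attention weights that pool pixel
# features into one "object region" vector per class, which SpatialOCR_Module
# below consumes as `proxy_feats`. Sizes are assumptions (19 classes, 512
# feature channels).
import torch
import torch.nn.functional as F

N, C, K, H, W = 2, 512, 19, 32, 64
feats = torch.randn(N, C, H, W)                   # pixel features
probs = torch.randn(N, K, H, W)                   # coarse per-class logits

probs = F.softmax(probs.view(N, K, -1), dim=2)    # (N, K, HW), weights sum to 1 per class
feats = feats.view(N, C, -1).permute(0, 2, 1)     # (N, HW, C)
ocr_context = torch.matmul(probs, feats)          # (N, K, C): one pooled feature per class
ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3)
print(ocr_context.shape)                          # torch.Size([2, 512, 19, 1]), as the docstring states
# ---------------------------------------------------------------------------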
127 | """ 128 | def __init__(self, in_channels, key_channels, out_channels, scale=1, 129 | dropout=0.1): 130 | super(SpatialOCR_Module, self).__init__() 131 | self.object_context_block = ObjectAttentionBlock(in_channels, 132 | key_channels, 133 | scale) 134 | if cfg.MODEL.OCR_ASPP: 135 | self.aspp, aspp_out_ch = get_aspp( 136 | in_channels, bottleneck_ch=cfg.MODEL.ASPP_BOT_CH, 137 | output_stride=8) 138 | _in_channels = 2 * in_channels + aspp_out_ch 139 | else: 140 | _in_channels = 2 * in_channels 141 | 142 | self.conv_bn_dropout = nn.Sequential( 143 | nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, 144 | bias=False), 145 | BNReLU(out_channels), 146 | nn.Dropout2d(dropout) 147 | ) 148 | 149 | def forward(self, feats, proxy_feats): 150 | context = self.object_context_block(feats, proxy_feats) 151 | 152 | if cfg.MODEL.OCR_ASPP: 153 | aspp = self.aspp(feats) 154 | output = self.conv_bn_dropout(torch.cat([context, aspp, feats], 1)) 155 | else: 156 | output = self.conv_bn_dropout(torch.cat([context, feats], 1)) 157 | 158 | return output 159 | -------------------------------------------------------------------------------- /network/xception.py: -------------------------------------------------------------------------------- 1 | # Xception71 2 | # Code Adapted from: 3 | # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/xception.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from config import cfg 10 | from network.mynn import Norm2d 11 | from apex.parallel import SyncBatchNorm 12 | from runx.logx import logx 13 | 14 | 15 | def fixed_padding(inputs, kernel_size, dilation): 16 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 17 | pad_total = kernel_size_effective - 1 18 | pad_beg = pad_total // 2 19 | pad_end = pad_total - pad_beg 20 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 21 | return padded_inputs 22 | 23 | 24 | class SeparableConv2d(nn.Module): 25 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, 26 | bias=False, BatchNorm=None): 27 | super(SeparableConv2d, self).__init__() 28 | 29 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, 30 | dilation, groups=inplanes, bias=bias) 31 | self.bn = BatchNorm(inplanes) 32 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 33 | 34 | def forward(self, x): 35 | x = fixed_padding(x, self.conv1.kernel_size[0], 36 | dilation=self.conv1.dilation[0]) 37 | x = self.conv1(x) 38 | x = self.bn(x) 39 | x = self.pointwise(x) 40 | return x 41 | 42 | 43 | class Block(nn.Module): 44 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, 45 | BatchNorm=None, start_with_relu=True, grow_first=True, 46 | is_last=False): 47 | super(Block, self).__init__() 48 | 49 | if planes != inplanes or stride != 1: 50 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, 51 | bias=False) 52 | self.skipbn = BatchNorm(planes) 53 | else: 54 | self.skip = None 55 | 56 | self.relu = nn.ReLU(inplace=True) 57 | rep = [] 58 | 59 | filters = inplanes 60 | if grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, 63 | BatchNorm=BatchNorm)) 64 | rep.append(BatchNorm(planes)) 65 | filters = planes 66 | 67 | for i in range(reps - 1): 68 | rep.append(self.relu) 69 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, 70 | BatchNorm=BatchNorm)) 71 | rep.append(BatchNorm(filters)) 72 | 73 | if not 
grow_first: 74 | rep.append(self.relu) 75 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, 76 | BatchNorm=BatchNorm)) 77 | rep.append(BatchNorm(planes)) 78 | 79 | if stride != 1: 80 | rep.append(self.relu) 81 | rep.append(SeparableConv2d(planes, planes, 3, 2, 82 | BatchNorm=BatchNorm)) 83 | rep.append(BatchNorm(planes)) 84 | 85 | if stride == 1 and is_last: 86 | rep.append(self.relu) 87 | rep.append(SeparableConv2d(planes, planes, 3, 1, 88 | BatchNorm=BatchNorm)) 89 | rep.append(BatchNorm(planes)) 90 | 91 | if not start_with_relu: 92 | rep = rep[1:] 93 | 94 | self.rep = nn.Sequential(*rep) 95 | 96 | def forward(self, inp): 97 | x = self.rep(inp) 98 | 99 | if self.skip is not None: 100 | skip = self.skip(inp) 101 | skip = self.skipbn(skip) 102 | else: 103 | skip = inp 104 | 105 | x = x + skip 106 | 107 | return x 108 | 109 | 110 | class xception71(nn.Module): 111 | """ 112 | Modified Alighed Xception 113 | """ 114 | def __init__(self, output_stride, BatchNorm, 115 | pretrained=True): 116 | super(xception71, self).__init__() 117 | 118 | self.output_stride = output_stride 119 | if self.output_stride == 16: 120 | middle_block_dilation = 1 121 | exit_block_dilations = (1, 2) 122 | exit_stride = 2 123 | elif self.output_stride == 8: 124 | middle_block_dilation = 2 125 | exit_block_dilations = (2, 4) 126 | exit_stride = 1 127 | else: 128 | raise NotImplementedError 129 | 130 | # Entry flow 131 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 132 | self.bn1 = BatchNorm(32) 133 | self.relu = nn.ReLU(inplace=True) 134 | 135 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 136 | self.bn2 = BatchNorm(64) 137 | 138 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 139 | # stride4 140 | 141 | self.block2 = Block(128, 256, reps=2, stride=1, BatchNorm=BatchNorm, start_with_relu=False, 142 | grow_first=True) 143 | self.block3 = Block(256, 728, reps=2, stride=2, BatchNorm=BatchNorm, 144 | start_with_relu=True, grow_first=True, is_last=True) 145 | # stride8 146 | 147 | # Middle flow 148 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 149 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 150 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 151 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 152 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 153 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 154 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 155 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 156 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 157 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 158 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 159 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 160 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 161 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 162 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 163 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 164 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 165 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 166 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 
167 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 168 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 169 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 170 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 171 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 172 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 173 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 174 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 175 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 176 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 177 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 178 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 179 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 180 | 181 | # Exit flow 182 | self.block20 = Block(728, 1024, reps=2, stride=exit_stride, dilation=exit_block_dilations[0], 183 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 184 | 185 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 186 | self.bn3 = BatchNorm(1536) 187 | 188 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 189 | self.bn4 = BatchNorm(1536) 190 | 191 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 192 | self.bn5 = BatchNorm(2048) 193 | 194 | # Init weights 195 | self._init_weight() 196 | 197 | # Load pretrained model 198 | if pretrained: 199 | self._load_pretrained_model() 200 | 201 | def forward(self, x): 202 | # Entry flow 203 | x = self.conv1(x) 204 | x = self.bn1(x) 205 | x = self.relu(x) 206 | 207 | x = self.conv2(x) 208 | x = self.bn2(x) 209 | str2 = self.relu(x) 210 | # s2 211 | str4 = self.block1(str2) 212 | str4 = self.relu(str4) 213 | # s4 214 | x = self.block2(str4) 215 | str8 = self.block3(x) 216 | # s8 217 | 218 | if self.output_stride == 8: 219 | low_level_feat, high_level_feat = str2, str4 220 | else: 221 | low_level_feat, high_level_feat = str4, str8 222 | 223 | # Middle flow 224 | x = self.block4(str8) 225 | x = self.block5(x) 226 | x = self.block6(x) 227 | x = self.block7(x) 228 | x = self.block8(x) 229 | x = self.block9(x) 230 | x = self.block10(x) 231 | x = self.block11(x) 232 | x = self.block12(x) 233 | x = self.block13(x) 234 | x = self.block14(x) 235 | x = self.block15(x) 236 | x = self.block16(x) 237 | x = self.block17(x) 238 | x = self.block18(x) 239 | x = self.block19(x) 240 | 241 | # Exit flow 242 | x = self.block20(x) 243 | x = self.relu(x) 244 | x = self.conv3(x) 245 | x = self.bn3(x) 246 | x = self.relu(x) 247 | 248 | x = self.conv4(x) 249 | x = self.bn4(x) 250 | x = self.relu(x) 251 | 252 | x = self.conv5(x) 253 | x = self.bn5(x) 254 | x = self.relu(x) 255 | 256 | return low_level_feat, high_level_feat, x 257 | 258 | def _init_weight(self): 259 | for m in self.modules(): 260 | if isinstance(m, nn.Conv2d): 261 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 262 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 263 | elif isinstance(m, SyncBatchNorm): 264 | m.weight.data.fill_(1) 265 | m.bias.data.zero_() 266 | elif isinstance(m, nn.BatchNorm2d): 267 | m.weight.data.fill_(1) 268 | m.bias.data.zero_() 269 | 270 | def _load_pretrained_model(self): 271 | pretrained_model = cfg.MODEL.X71_CHECKPOINT 272 | ckpt = torch.load(pretrained_model, map_location='cpu') 273 | model_dict = {k.replace('module.', ''): v for k, v in 274 | ckpt['model_dict'].items()} 275 | state_dict = self.state_dict() 276 | state_dict.update(model_dict) 277 | self.load_state_dict(state_dict, strict=False) 278 | del ckpt 279 | logx.msg('Loaded {} weights'.format(pretrained_model)) 280 | 281 | 282 | if __name__ == "__main__": 283 | model = xception71(BatchNorm=Norm2d, pretrained=True, 284 | output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) 289 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml>=5.1.1 2 | coolname>=1.1.0 3 | tabulate>=0.8.3 4 | tensorboardX>=1.4 5 | runx==0.0.6 6 | -------------------------------------------------------------------------------- /scripts/dump_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation and Dump Images on Cityscapes with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=1 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: cityscapes, 8 | cv: 0, 9 | syncbn: true, 10 | apex: true, 11 | fp16: true, 12 | bs_val: 1, 13 | eval: val, 14 | dump_assets: true, 15 | dump_all_images: true, 16 | n_scales: "0.5,1.0,2.0", 17 | snapshot: "ASSETS_PATH/seg_weights/cityscapes_ocrnet.HRNet_Mscale_outstanding-turtle.pth", 18 | arch: ocrnet.HRNet_Mscale, 19 | result_dir: LOGDIR, 20 | }, 21 | ] 22 | -------------------------------------------------------------------------------- /scripts/dump_folder.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation and Dump Images on Cityscapes with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=1 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: cityscapes, 8 | cv: 0, 9 | syncbn: true, 10 | apex: true, 11 | fp16: true, 12 | bs_val: 1, 13 | eval: folder, 14 | eval_folder: './imgs/test_imgs', 15 | dump_assets: true, 16 | dump_all_images: true, 17 | n_scales: "0.5,1.0,2.0", 18 | snapshot: "ASSETS_PATH/seg_weights/cityscapes_ocrnet.HRNet_Mscale_outstanding-turtle.pth", 19 | arch: ocrnet.HRNet_Mscale, 20 | result_dir: LOGDIR, 21 | }, 22 | ] 23 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation on Cityscapes with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: cityscapes, 8 | cv: 0, 9 | syncbn: true, 10 | apex: true, 11 | fp16: true, 12 | bs_val: 2, 13 | eval: val, 14 | n_scales: "0.5,1.0,2.0", 15 | snapshot: "ASSETS_PATH/seg_weights/cityscapes_ocrnet.HRNet_Mscale_outstanding-turtle.pth", 16 | arch: ocrnet.HRNet_Mscale, 17 | result_dir: LOGDIR, 18 | }, 19 | ] 20 | -------------------------------------------------------------------------------- /scripts/eval_mapillary.yml: 
-------------------------------------------------------------------------------- 1 | # Run Evaluation on Mapillary with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: mapillary, 8 | syncbn: true, 9 | apex: true, 10 | fp16: true, 11 | bs_val: 1, 12 | eval: val, 13 | pre_size: 2177, 14 | amp_opt_level: O3, 15 | n_scales: "0.25,0.5,1.0,2.0", 16 | do_flip: true, 17 | snapshot: "ASSETS_PATH/seg_weights/mapillary_ocrnet.HRNet_Mscale_fast-rattlesnake.pth", 18 | arch: ocrnet.HRNet_Mscale, 19 | result_dir: LOGDIR, 20 | }, 21 | ] 22 | -------------------------------------------------------------------------------- /scripts/train_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Train cityscapes using Mapillary-pretrained weights 2 | # Requires 32GB GPU 3 | # Adjust nproc_per_node according to how many GPUs you have 4 | 5 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 6 | 7 | HPARAMS: [ 8 | { 9 | dataset: cityscapes, 10 | cv: 0, 11 | syncbn: true, 12 | apex: true, 13 | fp16: true, 14 | crop_size: "1024,2048", 15 | bs_trn: 1, 16 | poly_exp: 2, 17 | lr: 5e-3, 18 | rmi_loss: true, 19 | max_epoch: 175, 20 | n_scales: "0.5,1.0,2.0", 21 | supervised_mscale_loss_wt: 0.05, 22 | snapshot: "ASSETS_PATH/seg_weights/ocrnet.HRNet_industrious-chicken.pth", 23 | arch: ocrnet.HRNet_Mscale, 24 | result_dir: LOGDIR, 25 | RUNX.TAG: '{arch}', 26 | }, 27 | ] 28 | -------------------------------------------------------------------------------- /scripts/train_cityscapes_deepv3.yml: -------------------------------------------------------------------------------- 1 | # Train cityscapes with deeplab v3+ and wide-resnet-38 trunk 2 | # Only requires 16GB gpus 3 | 4 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 5 | 6 | HPARAMS: [ 7 | { 8 | dataset: cityscapes, 9 | cv: 0, 10 | syncbn: true, 11 | apex: true, 12 | fp16: true, 13 | crop_size: "800,800", 14 | bs_trn: 1, 15 | poly_exp: 2, 16 | lr: 5e-3, 17 | max_epoch: 175, 18 | arch: deepv3.DeepV3PlusW38, 19 | result_dir: LOGDIR, 20 | RUNX.TAG: '{arch}', 21 | }, 22 | ] 23 | -------------------------------------------------------------------------------- /scripts/train_cityscapes_sota.yml: -------------------------------------------------------------------------------- 1 | # Train cityscapes using Mapillary-pretrained weights 2 | # Requires 32GB GPU 3 | # Adjust nproc_per_node according to how many GPUs you have 4 | 5 | CMD: "python -m torch.distributed.launch --nproc_per_node=16 train.py" 6 | 7 | HPARAMS: [ 8 | { 9 | dataset: cityscapes, 10 | cv: 0, 11 | syncbn: true, 12 | apex: true, 13 | fp16: true, 14 | crop_size: "1024,2048", 15 | bs_trn: 1, 16 | poly_exp: 2, 17 | lr: 1e-2, 18 | max_epoch: 175, 19 | max_cu_epoch: 150, 20 | rmi_loss: true, 21 | n_scales: ['0.5,1.0,2.0'], 22 | supervised_mscale_loss_wt: 0.05, 23 | 24 | arch: ocrnet.HRNet_Mscale, 25 | snapshot: "ASSETS_PATH/seg_weights/ocrnet.HRNet_industrious-chicken.pth", 26 | result_dir: LOGDIR, 27 | RUNX.TAG: 'sota-cv0-{arch}', 28 | 29 | coarse_boost_classes: "3,4,6,7,9,11,12,13,14,15,16,17,18", 30 | custom_coarse_dropout_classes: "14,15,16", 31 | mask_out_cityscapes: true, 32 | custom_coarse_prob: 0.5, 33 | }, 34 | ] 35 | -------------------------------------------------------------------------------- /scripts/train_mapillary.yml: -------------------------------------------------------------------------------- 1 | # Single node Mapillary 
training recipe 2 | # Requires 32GB GPU 3 | 4 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 5 | 6 | HPARAMS: [ 7 | { 8 | dataset: mapillary, 9 | cv: 0, 10 | result_dir: LOGDIR, 11 | 12 | pre_size: 2177, 13 | crop_size: "1024,1024", 14 | syncbn: true, 15 | apex: true, 16 | fp16: true, 17 | gblur: true, 18 | 19 | bs_trn: 2, 20 | 21 | lr_schedule: poly, 22 | poly_exp: 1.0, 23 | optimizer: sgd, 24 | lr: 5e-3, 25 | max_epoch: 200, 26 | rmi_loss: true, 27 | 28 | arch: ocrnet.HRNet_Mscale, 29 | n_scales: '0.5,1.0,2.0', 30 | } 31 | ] 32 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/semantic-segmentation/7726b144c2cc0b8e09c67eabb78f027efdf3f0fa/transforms/__init__.py -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code borrowded from: 3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 4 | # 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2017 ZijunDeng 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 
27 | 28 | """ 29 | 30 | """ 31 | Standard Transform 32 | """ 33 | 34 | import random 35 | import numpy as np 36 | from skimage.filters import gaussian 37 | from skimage.restoration import denoise_bilateral 38 | import torch 39 | from PIL import Image, ImageEnhance 40 | import torchvision.transforms as torch_tr 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | 44 | from skimage.segmentation import find_boundaries 45 | 46 | try: 47 | import accimage 48 | except ImportError: 49 | accimage = None 50 | 51 | 52 | class RandomVerticalFlip(object): 53 | def __call__(self, img): 54 | if random.random() < 0.5: 55 | return img.transpose(Image.FLIP_TOP_BOTTOM) 56 | return img 57 | 58 | 59 | class DeNormalize(object): 60 | def __init__(self, mean, std): 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, tensor): 65 | for t, m, s in zip(tensor, self.mean, self.std): 66 | t.mul_(s).add_(m) 67 | return tensor 68 | 69 | 70 | class MaskToTensor(object): 71 | def __call__(self, img, blockout_predefined_area=False): 72 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 73 | 74 | class RelaxedBoundaryLossToTensor(object): 75 | """ 76 | Boundary Relaxation 77 | """ 78 | def __init__(self,ignore_id, num_classes): 79 | self.ignore_id=ignore_id 80 | self.num_classes= num_classes 81 | 82 | 83 | def new_one_hot_converter(self,a): 84 | ncols = self.num_classes+1 85 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 86 | out[np.arange(a.size),a.ravel()] = 1 87 | out.shape = a.shape + (ncols,) 88 | return out 89 | 90 | def __call__(self,img): 91 | 92 | img_arr = np.array(img) 93 | img_arr[img_arr==self.ignore_id]=self.num_classes 94 | 95 | if cfg.STRICTBORDERCLASS != None: 96 | one_hot_orig = self.new_one_hot_converter(img_arr) 97 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 98 | for cls in cfg.STRICTBORDERCLASS: 99 | mask = np.logical_or(mask,(img_arr == cls)) 100 | one_hot = 0 101 | 102 | border = cfg.BORDER_WINDOW 103 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 104 | border = border // 2 105 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 106 | 107 | for i in range(-border,border+1): 108 | for j in range(-border, border+1): 109 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 110 | one_hot += self.new_one_hot_converter(shifted) 111 | 112 | one_hot[one_hot>1] = 1 113 | 114 | if cfg.STRICTBORDERCLASS != None: 115 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 116 | 117 | one_hot = np.moveaxis(one_hot,-1,0) 118 | 119 | 120 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 121 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 122 | # print(one_hot.shape) 123 | return torch.from_numpy(one_hot).byte() 124 | 125 | class ResizeHeight(object): 126 | def __init__(self, size, interpolation=Image.BILINEAR): 127 | self.target_h = size 128 | self.interpolation = interpolation 129 | 130 | def __call__(self, img): 131 | w, h = img.size 132 | target_w = int(w / h * self.target_h) 133 | return img.resize((target_w, self.target_h), self.interpolation) 134 | 135 | 136 | class FreeScale(object): 137 | def __init__(self, size, interpolation=Image.BILINEAR): 138 | self.size = tuple(reversed(size)) # size: (h, w) 139 | self.interpolation = interpolation 140 | 141 | def __call__(self, img): 142 | return img.resize(self.size, self.interpolation) 143 | 144 | 145 | class FlipChannels(object): 146 | """ 147 | Flip around the x-axis 148 | """ 149 
| def __call__(self, img): 150 | img = np.array(img)[:, :, ::-1] 151 | return Image.fromarray(img.astype(np.uint8)) 152 | 153 | 154 | class RandomGaussianBlur(object): 155 | """ 156 | Apply Gaussian Blur 157 | """ 158 | def __call__(self, img): 159 | sigma = 0.15 + random.random() * 1.15 160 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 161 | blurred_img *= 255 162 | return Image.fromarray(blurred_img.astype(np.uint8)) 163 | 164 | 165 | class RandomBrightness(object): 166 | def __call__(self, img): 167 | if random.random() < 0.5: 168 | return img 169 | v = random.uniform(0.1, 1.9) 170 | return ImageEnhance.Brightness(img).enhance(v) 171 | 172 | 173 | class RandomBilateralBlur(object): 174 | """ 175 | Apply Bilateral Filtering 176 | 177 | """ 178 | def __call__(self, img): 179 | sigma = random.uniform(0.05, 0.75) 180 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 181 | blurred_img *= 255 182 | return Image.fromarray(blurred_img.astype(np.uint8)) 183 | 184 | 185 | def _is_pil_image(img): 186 | if accimage is not None: 187 | return isinstance(img, (Image.Image, accimage.Image)) 188 | else: 189 | return isinstance(img, Image.Image) 190 | 191 | 192 | def adjust_brightness(img, brightness_factor): 193 | """Adjust brightness of an Image. 194 | 195 | Args: 196 | img (PIL Image): PIL Image to be adjusted. 197 | brightness_factor (float): How much to adjust the brightness. Can be 198 | any non negative number. 0 gives a black image, 1 gives the 199 | original image while 2 increases the brightness by a factor of 2. 200 | 201 | Returns: 202 | PIL Image: Brightness adjusted image. 203 | """ 204 | if not _is_pil_image(img): 205 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 206 | 207 | enhancer = ImageEnhance.Brightness(img) 208 | img = enhancer.enhance(brightness_factor) 209 | return img 210 | 211 | 212 | def adjust_contrast(img, contrast_factor): 213 | """Adjust contrast of an Image. 214 | 215 | Args: 216 | img (PIL Image): PIL Image to be adjusted. 217 | contrast_factor (float): How much to adjust the contrast. Can be any 218 | non negative number. 0 gives a solid gray image, 1 gives the 219 | original image while 2 increases the contrast by a factor of 2. 220 | 221 | Returns: 222 | PIL Image: Contrast adjusted image. 223 | """ 224 | if not _is_pil_image(img): 225 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 226 | 227 | enhancer = ImageEnhance.Contrast(img) 228 | img = enhancer.enhance(contrast_factor) 229 | return img 230 | 231 | 232 | def adjust_saturation(img, saturation_factor): 233 | """Adjust color saturation of an image. 234 | 235 | Args: 236 | img (PIL Image): PIL Image to be adjusted. 237 | saturation_factor (float): How much to adjust the saturation. 0 will 238 | give a black and white image, 1 will give the original image while 239 | 2 will enhance the saturation by a factor of 2. 240 | 241 | Returns: 242 | PIL Image: Saturation adjusted image. 243 | """ 244 | if not _is_pil_image(img): 245 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 246 | 247 | enhancer = ImageEnhance.Color(img) 248 | img = enhancer.enhance(saturation_factor) 249 | return img 250 | 251 | 252 | def adjust_hue(img, hue_factor): 253 | """Adjust hue of an image. 254 | 255 | The image hue is adjusted by converting the image to HSV and 256 | cyclically shifting the intensities in the hue channel (H). 257 | The image is then converted back to original image mode. 
258 | 259 | `hue_factor` is the amount of shift in H channel and must be in the 260 | interval `[-0.5, 0.5]`. 261 | 262 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 263 | 264 | Args: 265 | img (PIL Image): PIL Image to be adjusted. 266 | hue_factor (float): How much to shift the hue channel. Should be in 267 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 268 | HSV space in positive and negative direction respectively. 269 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 270 | with complementary colors while 0 gives the original image. 271 | 272 | Returns: 273 | PIL Image: Hue adjusted image. 274 | """ 275 | if not(-0.5 <= hue_factor <= 0.5): 276 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 277 | 278 | if not _is_pil_image(img): 279 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 280 | 281 | input_mode = img.mode 282 | if input_mode in {'L', '1', 'I', 'F'}: 283 | return img 284 | 285 | h, s, v = img.convert('HSV').split() 286 | 287 | np_h = np.array(h, dtype=np.uint8) 288 | # uint8 addition take cares of rotation across boundaries 289 | with np.errstate(over='ignore'): 290 | np_h += np.uint8(hue_factor * 255) 291 | h = Image.fromarray(np_h, 'L') 292 | 293 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 294 | return img 295 | 296 | 297 | class ColorJitter(object): 298 | """Randomly change the brightness, contrast and saturation of an image. 299 | 300 | Args: 301 | brightness (float): How much to jitter brightness. brightness_factor 302 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 303 | contrast (float): How much to jitter contrast. contrast_factor 304 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 305 | saturation (float): How much to jitter saturation. saturation_factor 306 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 307 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 308 | [-hue, hue]. Should be >=0 and <= 0.5. 309 | """ 310 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 311 | self.brightness = brightness 312 | self.contrast = contrast 313 | self.saturation = saturation 314 | self.hue = hue 315 | 316 | @staticmethod 317 | def get_params(brightness, contrast, saturation, hue): 318 | """Get a randomized transform to be applied on image. 319 | 320 | Arguments are same as that of __init__. 321 | 322 | Returns: 323 | Transform which randomly adjusts brightness, contrast and 324 | saturation in a random order. 
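        Example (illustrative sketch; the jitter strengths are arbitrary):
            transform = ColorJitter.get_params(0.4, 0.4, 0.4, 0.1)
            jittered = transform(img)   # img is a PIL Image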
325 | """ 326 | transforms = [] 327 | if brightness > 0: 328 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 329 | transforms.append( 330 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 331 | 332 | if contrast > 0: 333 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 334 | transforms.append( 335 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 336 | 337 | if saturation > 0: 338 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 339 | transforms.append( 340 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 341 | 342 | if hue > 0: 343 | hue_factor = np.random.uniform(-hue, hue) 344 | transforms.append( 345 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 346 | 347 | np.random.shuffle(transforms) 348 | transform = torch_tr.Compose(transforms) 349 | 350 | return transform 351 | 352 | def __call__(self, img): 353 | """ 354 | Args: 355 | img (PIL Image): Input image. 356 | 357 | Returns: 358 | PIL Image: Color jittered image. 359 | """ 360 | transform = self.get_params(self.brightness, self.contrast, 361 | self.saturation, self.hue) 362 | return transform(img) 363 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/semantic-segmentation/7726b144c2cc0b8e09c67eabb78f027efdf3f0fa/utils/__init__.py -------------------------------------------------------------------------------- /utils/attr_dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | 30 | class AttrDict(dict): 31 | 32 | IMMUTABLE = '__immutable__' 33 | 34 | def __init__(self, *args, **kwargs): 35 | super(AttrDict, self).__init__(*args, **kwargs) 36 | self.__dict__[AttrDict.IMMUTABLE] = False 37 | 38 | def __getattr__(self, name): 39 | if name in self.__dict__: 40 | return self.__dict__[name] 41 | elif name in self: 42 | return self[name] 43 | else: 44 | raise AttributeError(name) 45 | 46 | def __setattr__(self, name, value): 47 | if not self.__dict__[AttrDict.IMMUTABLE]: 48 | if name in self.__dict__: 49 | self.__dict__[name] = value 50 | else: 51 | self[name] = value 52 | else: 53 | raise AttributeError( 54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 55 | format(name, value) 56 | ) 57 | 58 | def immutable(self, is_immutable): 59 | """Set immutability to is_immutable and recursively apply the setting 60 | to all nested AttrDicts. 61 | """ 62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 63 | # Recursively set immutable state 64 | for v in self.__dict__.values(): 65 | if isinstance(v, AttrDict): 66 | v.immutable(is_immutable) 67 | for v in self.values(): 68 | if isinstance(v, AttrDict): 69 | v.immutable(is_immutable) 70 | 71 | def is_immutable(self): 72 | return self.__dict__[AttrDict.IMMUTABLE] 73 | -------------------------------------------------------------------------------- /utils/f_boundary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | 5 | # Code adapted from: 6 | # https://github.com/fperazzi/davis/blob/master/python/lib/davis/measures/f_boundary.py 7 | # 8 | # Source License 9 | # 10 | # BSD 3-Clause License 11 | # 12 | # Copyright (c) 2017, 13 | # All rights reserved. 14 | # 15 | # Redistribution and use in source and binary forms, with or without 16 | # modification, are permitted provided that the following conditions are met: 17 | # 18 | # * Redistributions of source code must retain the above copyright notice, this 19 | # list of conditions and the following disclaimer. 20 | # 21 | # * Redistributions in binary form must reproduce the above copyright notice, 22 | # this list of conditions and the following disclaimer in the documentation 23 | # and/or other materials provided with the distribution. 24 | # 25 | # * Neither the name of the copyright holder nor the names of its 26 | # contributors may be used to endorse or promote products derived from 27 | # this software without specific prior written permission. 28 | # 29 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 30 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 31 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 32 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 33 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 34 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 35 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 36 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 37 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 38 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.s 39 | ############################################################################## 40 | # 41 | # Based on: 42 | # ---------------------------------------------------------------------------- 43 | # A Benchmark Dataset and Evaluation Methodology for Video Object Segmentation 44 | # Copyright (c) 2016 Federico Perazzi 45 | # Licensed under the BSD License [see LICENSE for details] 46 | # Written by Federico Perazzi 47 | # ---------------------------------------------------------------------------- 48 | """ 49 | 50 | 51 | 52 | 53 | import numpy as np 54 | from multiprocessing import Pool 55 | from tqdm import tqdm 56 | from config import cfg 57 | 58 | 59 | """ Utilities for computing, reading and saving benchmark evaluation.""" 60 | 61 | def eval_mask_boundary(seg_mask,gt_mask,num_classes,num_proc=10,bound_th=0.008): 62 | """ 63 | Compute F score for a segmentation mask 64 | 65 | Arguments: 66 | seg_mask (ndarray): segmentation mask prediction 67 | gt_mask (ndarray): segmentation mask ground truth 68 | num_classes (int): number of classes 69 | 70 | Returns: 71 | F (float): mean F score across all classes 72 | Fpc (listof float): F score per class 73 | """ 74 | p = Pool(processes=num_proc) 75 | batch_size = seg_mask.shape[0] 76 | 77 | Fpc = np.zeros(num_classes) 78 | Fc = np.zeros(num_classes) 79 | for class_id in tqdm(range(num_classes)): 80 | args = [((seg_mask[i] == class_id).astype(np.uint8), 81 | (gt_mask[i] == class_id).astype(np.uint8), 82 | gt_mask[i] == cfg.DATASET.IGNORE_LABEL, 83 | bound_th) 84 | for i in range(batch_size)] 85 | temp = p.map(db_eval_boundary_wrapper, args) 86 | temp = np.array(temp) 87 | Fs = temp[:,0] 88 | _valid = ~np.isnan(Fs) 89 | Fc[class_id] = np.sum(_valid) 90 | Fs[np.isnan(Fs)] = 0 91 | Fpc[class_id] = sum(Fs) 92 | return Fpc, Fc 93 | 94 | 95 | #def db_eval_boundary_wrapper_wrapper(args): 96 | # seg_mask, gt_mask, class_id, batch_size, Fpc = args 97 | # print("class_id:" + str(class_id)) 98 | # p = Pool(processes=10) 99 | # args = [((seg_mask[i] == class_id).astype(np.uint8), 100 | # (gt_mask[i] == class_id).astype(np.uint8)) 101 | # for i in range(batch_size)] 102 | # Fs = p.map(db_eval_boundary_wrapper, args) 103 | # Fpc[class_id] = sum(Fs) 104 | # return 105 | 106 | def db_eval_boundary_wrapper(args): 107 | foreground_mask, gt_mask, ignore, bound_th = args 108 | return db_eval_boundary(foreground_mask, gt_mask,ignore, bound_th) 109 | 110 | def db_eval_boundary(foreground_mask,gt_mask, ignore_mask,bound_th=0.008): 111 | """ 112 | Compute mean,recall and decay from per-frame evaluation. 113 | Calculates precision/recall for boundaries between foreground_mask and 114 | gt_mask using morphological operators to speed it up. 115 | 116 | Arguments: 117 | foreground_mask (ndarray): binary segmentation image. 118 | gt_mask (ndarray): binary annotated image. 
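            ignore_mask (ndarray): boolean mask of pixels to skip; these pixels
                are zeroed out in both masks before the boundaries are computed.
            bound_th (float): boundary tolerance; values below 1 are treated as
                a fraction of the image diagonal, values >= 1 as a pixel radius.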
119 | 120 | Returns: 121 | F (float): boundaries F-measure 122 | P (float): boundaries precision 123 | R (float): boundaries recall 124 | """ 125 | assert np.atleast_3d(foreground_mask).shape[2] == 1 126 | 127 | bound_pix = bound_th if bound_th >= 1 else \ 128 | np.ceil(bound_th*np.linalg.norm(foreground_mask.shape)) 129 | 130 | #print(bound_pix) 131 | #print(gt.shape) 132 | #print(np.unique(gt)) 133 | foreground_mask[ignore_mask] = 0 134 | gt_mask[ignore_mask] = 0 135 | 136 | # Get the pixel boundaries of both masks 137 | fg_boundary = seg2bmap(foreground_mask); 138 | gt_boundary = seg2bmap(gt_mask); 139 | 140 | from skimage.morphology import binary_dilation,disk 141 | 142 | fg_dil = binary_dilation(fg_boundary,disk(bound_pix)) 143 | gt_dil = binary_dilation(gt_boundary,disk(bound_pix)) 144 | 145 | # Get the intersection 146 | gt_match = gt_boundary * fg_dil 147 | fg_match = fg_boundary * gt_dil 148 | 149 | # Area of the intersection 150 | n_fg = np.sum(fg_boundary) 151 | n_gt = np.sum(gt_boundary) 152 | 153 | #% Compute precision and recall 154 | if n_fg == 0 and n_gt > 0: 155 | precision = 1 156 | recall = 0 157 | elif n_fg > 0 and n_gt == 0: 158 | precision = 0 159 | recall = 1 160 | elif n_fg == 0 and n_gt == 0: 161 | precision = 1 162 | recall = 1 163 | else: 164 | precision = np.sum(fg_match)/float(n_fg) 165 | recall = np.sum(gt_match)/float(n_gt) 166 | 167 | # Compute F measure 168 | if precision + recall == 0: 169 | F = 0 170 | else: 171 | F = 2*precision*recall/(precision+recall); 172 | 173 | return F, precision 174 | 175 | def seg2bmap(seg,width=None,height=None): 176 | """ 177 | From a segmentation, compute a binary boundary map with 1 pixel wide 178 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 179 | origin from the actual segment boundary. 180 | 181 | Arguments: 182 | seg : Segments labeled from 1..k. 183 | width : Width of desired bmap <= seg.shape[1] 184 | height : Height of desired bmap <= seg.shape[0] 185 | 186 | Returns: 187 | bmap (ndarray): Binary boundary map. 
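        Example (illustrative sketch; `label_map` is a hypothetical H x W class map):
            car_mask = (label_map == 13).astype(np.uint8)  # binary mask for one class
            bmap = seg2bmap(car_mask)                      # same H x W, True on boundary pixels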
188 | 189 | David Martin 190 | January 2003 191 | """ 192 | 193 | seg = seg.astype(np.bool) 194 | seg[seg>0] = 1 195 | 196 | assert np.atleast_3d(seg).shape[2] == 1 197 | 198 | width = seg.shape[1] if width is None else width 199 | height = seg.shape[0] if height is None else height 200 | 201 | h,w = seg.shape[:2] 202 | 203 | ar1 = float(width) / float(height) 204 | ar2 = float(w) / float(h) 205 | 206 | assert not (width>w | height>h | abs(ar1-ar2)>0.01),\ 207 | 'Can''t convert %dx%d seg to %dx%d bmap.'%(w,h,width,height) 208 | 209 | e = np.zeros_like(seg) 210 | s = np.zeros_like(seg) 211 | se = np.zeros_like(seg) 212 | 213 | e[:,:-1] = seg[:,1:] 214 | s[:-1,:] = seg[1:,:] 215 | se[:-1,:-1] = seg[1:,1:] 216 | 217 | b = seg^e | seg^s | seg^se 218 | b[-1,:] = seg[-1,:]^e[-1,:] 219 | b[:,-1] = seg[:,-1]^s[:,-1] 220 | b[-1,-1] = 0 221 | 222 | if w == width and h == height: 223 | bmap = b 224 | else: 225 | bmap = np.zeros((height,width)) 226 | for x in range(w): 227 | for y in range(h): 228 | if b[y,x]: 229 | j = 1+floor((y-1)+height / h) 230 | i = 1+floor((x-1)+width / h) 231 | bmap[j,i] = 1; 232 | 233 | return bmap 234 | -------------------------------------------------------------------------------- /utils/my_data_parallel.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.s 34 | """ 35 | 36 | 37 | import operator 38 | import torch 39 | import warnings 40 | from torch.nn.modules import Module 41 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 42 | from torch.nn.parallel.replicate import replicate 43 | from torch.nn.parallel.parallel_apply import parallel_apply 44 | 45 | 46 | def _check_balance(device_ids): 47 | imbalance_warn = """ 48 | There is an imbalance between your GPUs. You may want to exclude GPU {} which 49 | has less than 75% of the memory or cores of GPU {}. You can do so by setting 50 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES 51 | environment variable.""" 52 | 53 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids] 54 | 55 | def warn_imbalance(get_prop): 56 | values = [get_prop(props) for props in dev_props] 57 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) 58 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) 59 | if min_val / max_val < 0.75: 60 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) 61 | return True 62 | return False 63 | 64 | if warn_imbalance(lambda props: props.total_memory): 65 | return 66 | if warn_imbalance(lambda props: props.multi_processor_count): 67 | return 68 | 69 | 70 | 71 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather=True): 72 | """ 73 | Evaluates module(input) in parallel across the GPUs given in device_ids. 74 | This is the functional version of the DataParallel module. 75 | Args: 76 | module: the module to evaluate in parallel 77 | inputs: inputs to the module 78 | device_ids: GPU ids on which to replicate module 79 | output_device: GPU location of the output Use -1 to indicate the CPU. 80 | (default: device_ids[0]) 81 | Returns: 82 | a Tensor containing the result of module(input) located on 83 | output_device 84 | """ 85 | if not isinstance(inputs, tuple): 86 | inputs = (inputs,) 87 | 88 | if device_ids is None: 89 | device_ids = list(range(torch.cuda.device_count())) 90 | 91 | if output_device is None: 92 | output_device = device_ids[0] 93 | 94 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) 95 | if len(device_ids) == 1: 96 | return module(*inputs[0], **module_kwargs[0]) 97 | used_device_ids = device_ids[:len(inputs)] 98 | replicas = replicate(module, used_device_ids) 99 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) 100 | if gather: 101 | return gather(outputs, output_device, dim) 102 | else: 103 | return outputs 104 | 105 | 106 | 107 | class MyDataParallel(Module): 108 | """ 109 | Implements data parallelism at the module level. 110 | This container parallelizes the application of the given module by 111 | splitting the input across the specified devices by chunking in the batch 112 | dimension. 
In the forward pass, the module is replicated on each device, 113 | and each replica handles a portion of the input. During the backwards 114 | pass, gradients from each replica are summed into the original module. 115 | The batch size should be larger than the number of GPUs used. 116 | See also: :ref:`cuda-nn-dataparallel-instead` 117 | Arbitrary positional and keyword inputs are allowed to be passed into 118 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim 119 | specified (default 0). Primitive types will be broadcasted, but all 120 | other types will be a shallow copy and can be corrupted if written to in 121 | the model's forward pass. 122 | .. warning:: 123 | Forward and backward hooks defined on :attr:`module` and its submodules 124 | will be invoked ``len(device_ids)`` times, each with inputs located on 125 | a particular device. Particularly, the hooks are only guaranteed to be 126 | executed in correct order with respect to operations on corresponding 127 | devices. For example, it is not guaranteed that hooks set via 128 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before 129 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but 130 | that each such hook be executed before the corresponding 131 | :meth:`~torch.nn.Module.forward` call of that device. 132 | .. warning:: 133 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in 134 | :func:`forward`, this wrapper will return a vector of length equal to 135 | number of devices used in data parallelism, containing the result from 136 | each device. 137 | .. note:: 138 | There is a subtlety in using the 139 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a 140 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. 141 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for 142 | details. 
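    .. note::
        Unlike ``torch.nn.DataParallel``, this wrapper takes an extra ``gather``
        flag; with ``gather=False`` the per-device outputs are returned as a
        list instead of being gathered onto ``output_device``.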
143 | Args: 144 | module: module to be parallelized 145 | device_ids: CUDA devices (default: all devices) 146 | output_device: device location of output (default: device_ids[0]) 147 | Attributes: 148 | module (Module): the module to be parallelized 149 | Example:: 150 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 151 | >>> output = net(input_var) 152 | """ 153 | 154 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well 155 | 156 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True): 157 | super(MyDataParallel, self).__init__() 158 | 159 | if not torch.cuda.is_available(): 160 | self.module = module 161 | self.device_ids = [] 162 | return 163 | 164 | if device_ids is None: 165 | device_ids = list(range(torch.cuda.device_count())) 166 | if output_device is None: 167 | output_device = device_ids[0] 168 | self.dim = dim 169 | self.module = module 170 | self.device_ids = device_ids 171 | self.output_device = output_device 172 | self.gather_bool = gather 173 | 174 | _check_balance(self.device_ids) 175 | 176 | if len(self.device_ids) == 1: 177 | self.module.cuda(device_ids[0]) 178 | 179 | def forward(self, *inputs, **kwargs): 180 | if not self.device_ids: 181 | return self.module(*inputs, **kwargs) 182 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 183 | if len(self.device_ids) == 1: 184 | return [self.module(*inputs[0], **kwargs[0])] 185 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 186 | outputs = self.parallel_apply(replicas, inputs, kwargs) 187 | if self.gather_bool: 188 | return self.gather(outputs, self.output_device) 189 | else: 190 | return outputs 191 | 192 | def replicate(self, module, device_ids): 193 | return replicate(module, device_ids) 194 | 195 | def scatter(self, inputs, kwargs, device_ids): 196 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 197 | 198 | def parallel_apply(self, replicas, inputs, kwargs): 199 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 200 | 201 | def gather(self, outputs, output_device): 202 | return gather(outputs, output_device, dim=self.dim) 203 | 204 | -------------------------------------------------------------------------------- /utils/results_page.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | 31 | import glob 32 | import os 33 | import numpy as np 34 | 35 | id2cat = { 36 | 0: 'road', 37 | 1: 'sidewalk', 38 | 2: 'building', 39 | 3: 'wall', 40 | 4: 'fence', 41 | 5: 'pole', 42 | 6: 'traffic_light', 43 | 7: 'traffic_sign', 44 | 8: 'vegetation', 45 | 9: 'terrain', 46 | 10: 'sky', 47 | 11: 'person', 48 | 12: 'rider', 49 | 13: 'car', 50 | 14: 'truck', 51 | 15: 'bus', 52 | 16: 'train', 53 | 17: 'motorcycle', 54 | 18: 'bicycle'} 55 | 56 | # Leaderboard mapillary 57 | sota_iu_results = { 58 | 0: 98.4046, 59 | 1: 85.0224, 60 | 2: 93.6462, 61 | 3: 61.7487, 62 | 4: 63.8885, 63 | 5: 67.6745, 64 | 6: 77.43, 65 | 7: 80.8351, 66 | 8: 93.7341, 67 | 9: 71.8774, 68 | 10: 95.6122, 69 | 11: 86.7228, 70 | 12: 72.7778, 71 | 13: 95.7033, 72 | 14: 79.9019, 73 | 15: 93.0954, 74 | 16: 89.7196, 75 | 17: 72.5731, 76 | 18: 78.2172, 77 | 255: 0} 78 | 79 | 80 | class ResultsPage(object): 81 | ''' 82 | This creates an HTML page of embedded images, useful for showing evaluation results. 83 | 84 | Usage: 85 | ip = ImagePage(html_fn) 86 | 87 | # Add a table with N images ... 88 | ip.add_table((img, descr), (img, descr), ...) 89 | 90 | # Generate html page 91 | ip.write_page() 92 | ''' 93 | 94 | def __init__(self, experiment_name, html_filename): 95 | self.experiment_name = experiment_name 96 | self.html_filename = html_filename 97 | self.outfile = open(self.html_filename, 'w') 98 | self.items = [] 99 | 100 | def _print_header(self): 101 | header = ''' 102 | 103 | 104 | Experiment = {} 105 | 106 | '''.format(self.experiment_name) 107 | self.outfile.write(header) 108 | 109 | def _print_footer(self): 110 | self.outfile.write(''' 111 | ''') 112 | 113 | def _print_table_header(self, table_name): 114 | table_hdr = '''

{}

115 | 116 | '''.format(table_name) 117 | self.outfile.write(table_hdr) 118 | 119 | def _print_table_footer(self): 120 | table_ftr = ''' 121 |
''' 122 | self.outfile.write(table_ftr) 123 | 124 | def _print_table_guts(self, img_fn, descr): 125 | table = ''' 126 |

 127 | 128 | 129 | 130 | {descr} 131 |
132 | '''.format(img_fn=img_fn, descr=descr) 133 | self.outfile.write(table) 134 | 135 | def add_table(self, img_label_pairs, table_heading=''): 136 | """ 137 | :img_label_pairs: A list of pairs of [img,label] 138 | """ 139 | self.items.append([img_label_pairs, table_heading]) 140 | 141 | def _write_table(self, table, heading): 142 | img, _descr = table[0] 143 | self._print_table_header(heading) 144 | for img, descr in table: 145 | self._print_table_guts(img, descr) 146 | self._print_table_footer() 147 | 148 | def write_page(self): 149 | self._print_header() 150 | 151 | for table, heading in self.items: 152 | self._write_table(table, heading) 153 | 154 | self._print_footer() 155 | 156 | def _print_page_start(self): 157 | page_start = ''' 158 | 159 | 160 | Experiment = EXP_NAME 161 | 171 | 172 | ''' 173 | self.outfile.write(page_start) 174 | 175 | def _print_table_start(self, caption, hdr): 176 | self.outfile.write(''' 177 | 178 | '''.format(caption)) 179 | for hdr_col in hdr: 180 | self.outfile.write(' '.format(hdr_col)) 181 | self.outfile.write(' ') 182 | 183 | def _print_table_row(self, row): 184 | self.outfile.write(' ') 185 | for i in row: 186 | self.outfile.write(' '.format(i)) 187 | # Create Links 188 | fp_link = 'false positive Top N'.format(row[ 189 | 1]) 190 | fn_link = 'false_negative Top N'.format(row[ 191 | 1]) 192 | self.outfile.write(' '.format(fp_link)) 193 | self.outfile.write(' '.format(fn_link)) 194 | self.outfile.write(' ') 195 | 196 | def _print_table_end(self): 197 | self.outfile.write('
{} {} {}{}{}
') 198 | 199 | def _print_page_end(self): 200 | self.outfile.write(''' 201 | 202 | ''') 203 | 204 | def create_main(self, iu, hist): 205 | self._print_page_start() 206 | #_print_table_style() 207 | # Calculate all of the terms: 208 | iu_false_positive = hist.sum(axis=1) - np.diag(hist) 209 | iu_false_negative = hist.sum(axis=0) - np.diag(hist) 210 | iu_true_positive = np.diag(hist) 211 | 212 | hdr = ("Class ID", "Class", "IoU", "Sota-IU", "TP", 213 | "FP", "FN", "precision", "recall", "", "") 214 | self._print_table_start("Mean IoU Results", hdr) 215 | for iu_score, index in iu: 216 | class_name = id2cat[index] 217 | iu_string = '{:5.2f}'.format(iu_score * 100) 218 | total_pixels = hist.sum() 219 | tp = '{:5.2f}'.format(100 * iu_true_positive[index] / total_pixels) 220 | fp = '{:5.2f}'.format( 221 | iu_false_positive[index] / iu_true_positive[index]) 222 | fn = '{:5.2f}'.format( 223 | iu_false_negative[index] / iu_true_positive[index]) 224 | precision = '{:5.2f}'.format( 225 | iu_true_positive[index] / (iu_true_positive[index] + iu_false_positive[index])) 226 | recall = '{:5.2f}'.format( 227 | iu_true_positive[index] / (iu_true_positive[index] + iu_false_negative[index])) 228 | sota = '{:5.2f}'.format(sota_iu_results[index]) 229 | row = (index, class_name, iu_string, sota, 230 | tp, fp, fn, precision, recall) 231 | self._print_table_row(row) 232 | self._print_table_end() 233 | self._print_page_end() 234 | 235 | 236 | def main(): 237 | images = glob.glob('dump_imgs_train/*.png') 238 | images = [i for i in images if 'mask' not in i] 239 | 240 | ip = ResultsPage('test page', 'dd.html') 241 | for img in images: 242 | basename = os.path.splitext(img)[0] 243 | mask_img = basename + '_mask.png' 244 | ip.add_table(((img, 'image'), (mask_img, 'mask'))) 245 | ip.write_page() 246 | --------------------------------------------------------------------------------
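As a point of reference, here is a minimal sketch of calling `eval_mask_boundary` from `utils/f_boundary.py` above. The toy shapes, class count, and the final per-class mean are illustrative assumptions, not code from this repository, and the snippet assumes it is run from the repository root so that `config.py` resolves:

    import numpy as np
    from utils.f_boundary import eval_mask_boundary

    if __name__ == '__main__':
        num_classes = 19
        # Toy (N, H, W) integer class maps standing in for predictions and labels.
        pred = np.random.randint(0, num_classes, size=(2, 64, 128))
        gt = pred.copy()

        # Fpc[c] sums the per-image F-scores for class c; Fc[c] counts the images
        # where the score was defined, so the mean boundary F per class is Fpc / Fc.
        Fpc, Fc = eval_mask_boundary(pred, gt, num_classes, num_proc=2, bound_th=0.008)
        mean_f = np.divide(Fpc, Fc, out=np.zeros_like(Fpc), where=Fc > 0)
        print(mean_f)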