├── semantic-segmentation ├── imgs │ ├── munster_000095_000019_leftImg8bit_input.png │ └── composited_munster_000095_000019_leftImg8bit.png ├── network │ ├── bn_helper.py │ ├── __init__.py │ ├── mynn.py │ ├── deeper.py │ ├── basic.py │ ├── deepv3.py │ ├── ocr_utils.py │ ├── Resnet.py │ ├── PSA.py │ ├── mscale2.py │ └── xception.py ├── scripts │ ├── eval_cityscapes.yml │ ├── train_cityscapes_deepv3.yml │ ├── eval_mapillary.yml │ ├── dump_cityscapes.yml │ ├── train_mapillary.yml │ ├── dump_folder.yml │ ├── train_cityscapes.yml │ └── train_cityscapes_sota.yml ├── datasets │ ├── utils.py │ ├── nullloader.py │ ├── mapillary.py │ ├── sampler.py │ ├── randaugment.py │ ├── __init__.py │ ├── base_loader.py │ ├── cityscapes.py │ ├── cityscapes_labels.py │ └── uniform.py ├── PREPARE_DATASETS.md ├── loss │ ├── radam.py │ ├── rmi_utils.py │ ├── optimizer.py │ └── rmi.py ├── transforms │ └── transforms.py └── config.py ├── README.md └── LICENSE /semantic-segmentation/imgs/munster_000095_000019_leftImg8bit_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeLightCMU/PSA/HEAD/semantic-segmentation/imgs/munster_000095_000019_leftImg8bit_input.png -------------------------------------------------------------------------------- /semantic-segmentation/imgs/composited_munster_000095_000019_leftImg8bit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeLightCMU/PSA/HEAD/semantic-segmentation/imgs/composited_munster_000095_000019_leftImg8bit.png -------------------------------------------------------------------------------- /semantic-segmentation/network/bn_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | 4 | if torch.__version__.startswith('0'): 5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync 6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 7 | BatchNorm2d_class = InPlaceABNSync 8 | relu_inplace = False 9 | else: 10 | BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm 11 | relu_inplace = True 12 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/eval_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation on Cityscapes with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: cityscapes, 8 | cv: 0, 9 | syncbn: true, 10 | apex: true, 11 | fp16: true, 12 | bs_val: 2, 13 | eval: val, 14 | n_scales: "0.5,1.0,2.0", 15 | snapshot: "ASSETS_PATH/seg_weights/cityscapes_ocrnet.HRNet_Mscale_outstanding-turtle.pth", 16 | arch: ocrnet.HRNet_Mscale, 17 | result_dir: LOGDIR, 18 | }, 19 | ] 20 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/train_cityscapes_deepv3.yml: -------------------------------------------------------------------------------- 1 | # Train cityscapes with deeplab v3+ and wide-resnet-38 trunk 2 | # Only requires 16GB gpus 3 | 4 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 5 | 6 | HPARAMS: [ 7 | { 8 | dataset: cityscapes, 9 | cv: 0, 10 | syncbn: true, 11 | apex: true, 12 | fp16: true, 13 | crop_size: "800,800", 14 | bs_trn: 1, 15 | poly_exp: 2, 16 | lr: 5e-3, 17 | max_epoch: 175, 18 | arch: deepv3.DeepV3PlusW38, 19 | result_dir: LOGDIR, 20 | RUNX.TAG: 
'{arch}', 21 | }, 22 | ] 23 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/eval_mapillary.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation on Mapillary with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: mapillary, 8 | syncbn: true, 9 | apex: true, 10 | fp16: true, 11 | bs_val: 1, 12 | eval: val, 13 | pre_size: 2177, 14 | amp_opt_level: O3, 15 | n_scales: "0.25,0.5,1.0,2.0", 16 | do_flip: true, 17 | snapshot: "ASSETS_PATH/seg_weights/mapillary_ocrnet.HRNet_Mscale_fast-rattlesnake.pth", 18 | arch: ocrnet.HRNet_Mscale, 19 | result_dir: LOGDIR, 20 | }, 21 | ] 22 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/dump_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation and Dump Images on Cityscapes with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=1 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: cityscapes, 8 | cv: 0, 9 | syncbn: true, 10 | apex: true, 11 | fp16: true, 12 | bs_val: 1, 13 | eval: val, 14 | dump_assets: true, 15 | dump_all_images: true, 16 | n_scales: "0.5,1.0,2.0", 17 | snapshot: "ASSETS_PATH/seg_weights/cityscapes_ocrnet.HRNet_Mscale_outstanding-turtle.pth", 18 | arch: ocrnet.HRNet_Mscale, 19 | result_dir: LOGDIR, 20 | }, 21 | ] 22 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/train_mapillary.yml: -------------------------------------------------------------------------------- 1 | # Single node Mapillary training recipe 2 | # Requires 32GB GPU 3 | 4 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 5 | 6 | hparams: [ 7 | { 8 | dataset: mapillary, 9 | cv: 0, 10 | result_dir: LOGDIR, 11 | 12 | pre_size: 2177, 13 | crop_size: "1024,1024", 14 | syncbn: true, 15 | apex: true, 16 | fp16: true, 17 | gblur: true, 18 | 19 | bs_trn: 2, 20 | 21 | lr_schedule: poly, 22 | poly_exp: 1.0, 23 | optimizer: sgd, 24 | lr: 5e-3, 25 | max_epoch: 200, 26 | rmi_loss: true, 27 | 28 | arch: ocrnet.HRNet_Mscale, 29 | n_scales: '0.5,1.0,2.0', 30 | } 31 | ] 32 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/dump_folder.yml: -------------------------------------------------------------------------------- 1 | # Run Evaluation and Dump Images on Cityscapes with a pretrained model 2 | 3 | CMD: "python -m torch.distributed.launch --nproc_per_node=1 train.py" 4 | 5 | HPARAMS: [ 6 | { 7 | dataset: cityscapes, 8 | cv: 0, 9 | syncbn: true, 10 | apex: true, 11 | fp16: true, 12 | bs_val: 1, 13 | eval: folder, 14 | eval_folder: './large_assets/data/cityscapes/leftImg8bit_trainvaltest/leftImg8bit/test/', 15 | dump_assets: true, 16 | dump_all_images: true, 17 | n_scales: "0.5,1.0,2.0", 18 | snapshot: "ASSETS_PATH/seg_weights/cityscapes_ocrnet.HRNet_Mscale_outstanding-turtle.pth", 19 | arch: ocrnet.HRNet_Mscale, 20 | result_dir: LOGDIR, 21 | }, 22 | ] 23 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/train_cityscapes.yml: -------------------------------------------------------------------------------- 1 | # Train cityscapes using Mapillary-pretrained weights 2 | # Requires 32GB GPU 3 | # Adjust nproc_per_node according to 
how many GPUs you have 4 | 5 | CMD: "python -m torch.distributed.launch --nproc_per_node=8 train.py" 6 | 7 | HPARAMS: [ 8 | { 9 | dataset: cityscapes, 10 | cv: 0, 11 | syncbn: true, 12 | apex: true, 13 | fp16: true, 14 | crop_size: "1024,2048", 15 | bs_trn: 1, 16 | poly_exp: 2, 17 | lr: 5e-3, 18 | rmi_loss: true, 19 | max_epoch: 175, 20 | n_scales: "0.5,1.0,2.0", 21 | supervised_mscale_loss_wt: 0.05, 22 | snapshot: "ASSETS_PATH/seg_weights/ocrnet.HRNet_industrious-chicken.pth", 23 | arch: ocrnet.HRNet_Mscale, 24 | result_dir: LOGDIR, 25 | RUNX.TAG: '{arch}', 26 | }, 27 | ] 28 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def make_dataset_folder(folder): 5 | """ 6 | Create Filename list for images in the provided path 7 | 8 | input: path to directory with *only* images files 9 | returns: items list with None filled for mask path 10 | """ 11 | items = os.listdir(folder) 12 | items = [(os.path.join(folder, f), '') for f in items] 13 | items = sorted(items) 14 | 15 | print(f'Found {len(items)} folder imgs') 16 | 17 | """ 18 | orig_len = len(items) 19 | rem = orig_len % 8 20 | if rem != 0: 21 | items = items[:-rem] 22 | 23 | msg = 'Found {} folder imgs but altered to {} to be modulo-8' 24 | msg = msg.format(orig_len, len(items)) 25 | print(msg) 26 | """ 27 | 28 | return items 29 | -------------------------------------------------------------------------------- /semantic-segmentation/scripts/train_cityscapes_sota.yml: -------------------------------------------------------------------------------- 1 | # Train cityscapes using Mapillary-pretrained weights 2 | # Requires 32GB GPU 3 | # Adjust nproc_per_node according to how many GPUs you have 4 | 5 | CMD: "python -m torch.distributed.launch --nproc_per_node=16 train.py" 6 | 7 | HPARAMS: [ 8 | { 9 | dataset: cityscapes, 10 | cv: 0, 11 | syncbn: true, 12 | apex: true, 13 | fp16: true, 14 | crop_size: "1024,2048", 15 | bs_trn: 1, 16 | poly_exp: 2, 17 | lr: 1e-2, 18 | max_epoch: 175, 19 | max_cu_epoch: 150, 20 | rmi_loss: true, 21 | n_scales: ['0.5,1.0,2.0'], 22 | supervised_mscale_loss_wt: 0.05, 23 | 24 | arch: ocrnet.HRNet_Mscale, 25 | snapshot: "ASSETS_PATH/seg_weights/ocrnet.HRNet_industrious-chicken.pth", 26 | result_dir: LOGDIR, 27 | RUNX.TAG: 'sota-cv0-{arch}', 28 | 29 | coarse_boost_classes: "3,4,6,7,9,11,12,13,14,15,16,17,18", 30 | custom_coarse_dropout_classes: "14,15,16", 31 | mask_out_cityscapes: true, 32 | custom_coarse_prob: 0.5, 33 | }, 34 | ] 35 | -------------------------------------------------------------------------------- /semantic-segmentation/network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network Initializations 3 | """ 4 | 5 | import importlib 6 | import torch 7 | 8 | from runx.logx import logx 9 | from config import cfg 10 | 11 | 12 | def get_net(args, criterion): 13 | """ 14 | Get Network Architecture based on arguments provided 15 | """ 16 | net = get_model(network='network.' 
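                        # args.arch is a dotted path under the network package
                        # (e.g. deepv3.DeepV3PlusW38 or ocrnet.HRNet_Mscale in the
                        # scripts above); get_model() below resolves it with
                        # importlib.import_module + getattr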
+ args.arch, 17 | num_classes=cfg.DATASET.NUM_CLASSES, 18 | criterion=criterion) 19 | num_params = sum([param.nelement() for param in net.parameters()]) 20 | logx.msg('Model params = {:2.1f}M'.format(num_params / 1000000)) 21 | 22 | net = net.cuda() 23 | return net 24 | 25 | 26 | def is_gscnn_arch(args): 27 | """ 28 | Network is a GSCNN network 29 | """ 30 | return 'gscnn' in args.arch 31 | 32 | 33 | def wrap_network_in_dataparallel(net, use_apex_data_parallel=False): 34 | """ 35 | Wrap the network in Dataparallel 36 | """ 37 | if use_apex_data_parallel: 38 | import apex 39 | net = apex.parallel.DistributedDataParallel(net) 40 | else: 41 | net = torch.nn.DataParallel(net) 42 | return net 43 | 44 | 45 | def get_model(network, num_classes, criterion): 46 | """ 47 | Fetch Network Function Pointer 48 | """ 49 | module = network[:network.rfind('.')] 50 | model = network[network.rfind('.') + 1:] 51 | mod = importlib.import_module(module) 52 | net_func = getattr(mod, model) 53 | net = net_func(num_classes=num_classes, criterion=criterion) 54 | return net 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Polarized Self-Attention: Towards High-quality Pixel-wise Regression 2 | This is an official implementation of: 3 | 4 | **Huajun Liu, Fuqiang Liu, Xinyi Fan and Dong Huang**. ***Polarized Self-Attention: Towards High-quality Pixel-wise Regression*** [Arxiv Version](https://arxiv.org/abs/2107.00782) 5 | 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polarized-self-attention-towards-high-quality-1/pose-estimation-on-coco-test-dev)](https://paperswithcode.com/sota/pose-estimation-on-coco-test-dev?p=polarized-self-attention-towards-high-quality-1) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polarized-self-attention-towards-high-quality-1/keypoint-detection-on-coco)](https://paperswithcode.com/sota/keypoint-detection-on-coco?p=polarized-self-attention-towards-high-quality-1) 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/polarized-self-attention-towards-high-quality-1/semantic-segmentation-on-cityscapes-val)](https://paperswithcode.com/sota/semantic-segmentation-on-cityscapes-val?p=polarized-self-attention-towards-high-quality-1) 9 | 10 | ### Citation: 11 | 12 | ```bash 13 | @article{Liu2021PSA, 14 | title={Polarized Self-Attention: Towards High-quality Pixel-wise Regression}, 15 | author={Huajun Liu and Fuqiang Liu and Xinyi Fan and Dong Huang}, 16 | journal={Arxiv Pre-Print arXiv:2107.00782 }, 17 | year={2021} 18 | } 19 | ``` 20 | 21 | ## Codes and Pre-trained models will be uploaded soon~ 22 | 23 | ### Top-down 2D pose estimation models pre-trained on the MS-COCO keypoint task(Table4 in the Arxiv version). 24 | 25 | | Model Name | Backbone |Input Size | AP | pth file | 26 | | :----------------------: | :---------------------:| :--------------: | :--------------: | :------------: | 27 | | UDP-Pose-PSA(p) | HRNet-W48 |256x192 |78.9 | to be uploaded | 28 | | UDP-Pose-PSA(p) | HRNet-W48 |384x288 |79.5 | to be uploaded | 29 | | UDP-Pose-PSA(s) | HRNet-W48 |384x288 |79.4 |to be uploaded | 30 | 31 | #### Setup and inference: 32 | 33 | 34 | ### Semantic segmentation models pre-trained on Cityscapes (Table5 in the Arxiv version). 
35 | 36 | | Model Name | Backbone | val mIoU | pth file | 37 | | :----------------------: | :---------------------:| :--------------: | :------------: | 38 | | HRNetV2-OCR+PSA(p) | HRNetV2-W48 |86.95 | [download](https://cmu.box.com/s/if90kw6r66q2y6c5xparflhnbwi6c2yi) | 39 | | HRNetV2-OCR+PSA(s) | HRNetV2-W48 |86.72 | [download](https://cmu.box.com/s/uyzzfmkx8p2ipcznpzdtf14ng63s65sq) | 40 | 41 | #### Setup and inference: 42 | 43 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/nullloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | Null Loader 31 | """ 32 | from config import cfg 33 | from runx.logx import logx 34 | from datasets.base_loader import BaseLoader 35 | from datasets.utils import make_dataset_folder 36 | from datasets import uniform 37 | import numpy as np 38 | import torch 39 | from torch.utils import data 40 | 41 | class Loader(BaseLoader): 42 | """ 43 | Null Dataset for Performance 44 | """ 45 | num_classes = 19 46 | ignore_label = 255 47 | trainid_to_name = {} 48 | color_mapping = [] 49 | 50 | def __init__(self, mode, quality=None, joint_transform_list=None, 51 | img_transform=None, label_transform=None, eval_folder=None): 52 | super(Loader, self).__init__(quality=quality, 53 | mode=mode, 54 | joint_transform_list=joint_transform_list, 55 | img_transform=img_transform, 56 | label_transform=label_transform) 57 | 58 | def __getitem__(self, index): 59 | # return img, mask, img_name, scale_float 60 | crop_size = cfg.DATASET.CROP_SIZE 61 | if ',' in crop_size: 62 | crop_size = [int(x) for x in crop_size.split(',')] 63 | else: 64 | crop_size = int(crop_size) 65 | crop_size = [crop_size, crop_size] 66 | 67 | img = torch.FloatTensor(np.zeros([3] + crop_size)) 68 | mask = torch.LongTensor(np.zeros(crop_size)) 69 | img_name = f'img{index}' 70 | scale_float = 0.0 71 | return img, mask, img_name, scale_float 72 | 73 | def __len__(self): 74 | return 3000 75 | -------------------------------------------------------------------------------- /semantic-segmentation/network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight 3 | initialization 4 | """ 5 | import re 6 | import torch 7 | import torch.nn as nn 8 | from config import cfg 9 | 10 | from apex import amp 11 | 12 | from runx.logx import logx 13 | 14 | 15 | align_corners = cfg.MODEL.ALIGN_CORNERS 16 | 17 | 18 | def Norm2d(in_channels, **kwargs): 19 | """ 20 | Custom Norm Function to allow flexible switching 21 | """ 22 | layer = getattr(cfg.MODEL, 'BNFUNC') 23 | normalization_layer = layer(in_channels, **kwargs) 24 | return normalization_layer 25 | 26 | 27 | def initialize_weights(*models): 28 | """ 29 | Initialize Model Weights 30 | """ 31 | for model in models: 32 | for module in model.modules(): 33 | if isinstance(module, (nn.Conv2d, nn.Linear)): 34 | nn.init.kaiming_normal_(module.weight) 35 | if module.bias is not None: 36 | module.bias.data.zero_() 37 | elif isinstance(module, cfg.MODEL.BNFUNC): 38 | module.weight.data.fill_(1) 39 | module.bias.data.zero_() 40 | 41 | 42 | @amp.float_function 43 | def Upsample(x, size): 44 | """ 45 | Wrapper Around the Upsample Call 46 | """ 47 | return nn.functional.interpolate(x, size=size, mode='bilinear', 48 | align_corners=align_corners) 49 | 50 | 51 | @amp.float_function 52 | def Upsample2(x): 53 | """ 54 | Wrapper Around the Upsample Call 55 | """ 56 | return nn.functional.interpolate(x, scale_factor=2, mode='bilinear', 57 | align_corners=align_corners) 58 | 59 | 60 | def Down2x(x): 61 | return torch.nn.functional.interpolate( 62 | x, scale_factor=0.5, mode='bilinear', align_corners=align_corners) 63 | 64 | 65 | def Up15x(x): 66 | return torch.nn.functional.interpolate( 67 | x, scale_factor=1.5, mode='bilinear', align_corners=align_corners) 68 | 69 | 70 | def scale_as(x, y): 71 | ''' 72 | scale x to the same size as y 73 | ''' 74 | y_size = y.size(2), y.size(3) 75 | 76 | if cfg.OPTIONS.TORCH_VERSION >= 1.5: 77 | x_scaled = torch.nn.functional.interpolate( 78 | x, size=y_size, mode='bilinear', 
79 | align_corners=align_corners) 80 | else: 81 | x_scaled = torch.nn.functional.interpolate( 82 | x, size=y_size, mode='bilinear', 83 | align_corners=align_corners) 84 | return x_scaled 85 | 86 | 87 | def DownX(x, scale_factor): 88 | ''' 89 | scale x to the same size as y 90 | ''' 91 | if cfg.OPTIONS.TORCH_VERSION >= 1.5: 92 | x_scaled = torch.nn.functional.interpolate( 93 | x, scale_factor=scale_factor, mode='bilinear', 94 | align_corners=align_corners, recompute_scale_factor=True) 95 | else: 96 | x_scaled = torch.nn.functional.interpolate( 97 | x, scale_factor=scale_factor, mode='bilinear', 98 | align_corners=align_corners) 99 | return x_scaled 100 | 101 | 102 | def ResizeX(x, scale_factor): 103 | ''' 104 | scale x by some factor 105 | ''' 106 | if cfg.OPTIONS.TORCH_VERSION >= 1.5: 107 | x_scaled = torch.nn.functional.interpolate( 108 | x, scale_factor=scale_factor, mode='bilinear', 109 | align_corners=align_corners, recompute_scale_factor=True) 110 | else: 111 | x_scaled = torch.nn.functional.interpolate( 112 | x, scale_factor=scale_factor, mode='bilinear', 113 | align_corners=align_corners) 114 | return x_scaled 115 | -------------------------------------------------------------------------------- /semantic-segmentation/network/deeper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | """ 30 | import torch 31 | from torch import nn 32 | from network.mynn import Upsample2 33 | from network.utils import ConvBnRelu, get_trunk, get_aspp 34 | 35 | 36 | class DeeperS8(nn.Module): 37 | """ 38 | Panoptic DeepLab-style semantic segmentation network 39 | stride8 only 40 | """ 41 | def __init__(self, num_classes, trunk='wrn38', criterion=None): 42 | super(DeeperS8, self).__init__() 43 | 44 | self.criterion = criterion 45 | self.trunk, s2_ch, s4_ch, high_level_ch = get_trunk(trunk_name=trunk, 46 | output_stride=8) 47 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, bottleneck_ch=256, 48 | output_stride=8) 49 | 50 | self.convs2 = nn.Conv2d(s2_ch, 32, kernel_size=1, bias=False) 51 | self.convs4 = nn.Conv2d(s4_ch, 64, kernel_size=1, bias=False) 52 | self.conv_up1 = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 53 | self.conv_up2 = ConvBnRelu(256 + 64, 256, kernel_size=5, padding=2) 54 | self.conv_up3 = ConvBnRelu(256 + 32, 256, kernel_size=5, padding=2) 55 | self.conv_up5 = nn.Conv2d(256, num_classes, kernel_size=1, bias=False) 56 | 57 | def forward(self, inputs, gts=None): 58 | assert 'images' in inputs 59 | x = inputs['images'] 60 | 61 | s2_features, s4_features, final_features = self.trunk(x) 62 | s2_features = self.convs2(s2_features) 63 | s4_features = self.convs4(s4_features) 64 | aspp = self.aspp(final_features) 65 | x = self.conv_up1(aspp) 66 | x = Upsample2(x) 67 | x = torch.cat([x, s4_features], 1) 68 | x = self.conv_up2(x) 69 | x = Upsample2(x) 70 | x = torch.cat([x, s2_features], 1) 71 | x = self.conv_up3(x) 72 | x = self.conv_up5(x) 73 | x = Upsample2(x) 74 | 75 | if self.training: 76 | assert 'gts' in inputs 77 | gts = inputs['gts'] 78 | return self.criterion(x, gts) 79 | return {'pred': x} 80 | 81 | 82 | def DeeperW38(num_classes, criterion, s2s4=True): 83 | return DeeperS8(num_classes, criterion=criterion, trunk='wrn38') 84 | 85 | 86 | def DeeperX71(num_classes, criterion, s2s4=True): 87 | return DeeperS8(num_classes, criterion=criterion, trunk='xception71') 88 | 89 | 90 | def DeeperEffB4(num_classes, criterion, s2s4=True): 91 | return DeeperS8(num_classes, criterion=criterion, trunk='efficientnet_b4') 92 | -------------------------------------------------------------------------------- /semantic-segmentation/PREPARE_DATASETS.md: -------------------------------------------------------------------------------- 1 | ## Mapillary Vistas Dataset 2 | 3 | First of all, please request the research edition dataset from [here](https://www.mapillary.com/dataset/vistas/). The downloaded file is named as `mapillary-vistas-dataset_public_v1.1.zip`. 4 | 5 | Then simply unzip the file by 6 | ```shell 7 | unzip mapillary-vistas-dataset_public_v1.1.zip 8 | ``` 9 | 10 | The folder structure will look like: 11 | ``` 12 | Mapillary 13 | ├── config.json 14 | ├── demo.py 15 | ├── Mapillary Vistas Research Edition License.pdf 16 | ├── README 17 | ├── requirements.txt 18 | ├── training 19 | │ ├── images 20 | │ ├── instances 21 | │ ├── labels 22 | │ ├── panoptic 23 | ├── validation 24 | │ ├── images 25 | │ ├── instances 26 | │ ├── labels 27 | │ ├── panoptic 28 | ├── testing 29 | │ ├── images 30 | │ ├── instances 31 | │ ├── labels 32 | │ ├── panoptic 33 | ``` 34 | Note that, the `instances`, `labels` and `panoptic` folders inside `testing` are empty. 
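If you want to sanity-check the extraction before training, a short script along these lines (a sketch; `root` is a placeholder for wherever you unzipped the archive) counts the images and labels in each split; `testing` should report 0 labels:

```python
import os

root = 'Mapillary'  # placeholder: path to the unzipped dataset
for split in ('training', 'validation', 'testing'):
    images = os.listdir(os.path.join(root, split, 'images'))
    labels = os.listdir(os.path.join(root, split, 'labels'))
    print(f'{split}: {len(images)} images, {len(labels)} labels')
```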
35 | 36 | Suppose you store your dataset at `~/username/data/Mapillary`, please update the dataset path in `config.py`, 37 | ``` 38 | __C.DATASET.MAPILLARY_DIR = '~/username/data/Mapillary' 39 | ``` 40 | 41 | ## Cityscapes Dataset 42 | 43 | ### Download Dataset 44 | First of all, please request the dataset from [here](https://www.cityscapes-dataset.com/). You need multiple files. 45 | ``` 46 | - leftImg8bit_trainvaltest.zip 47 | - gtFine_trainvaltest.zip 48 | - leftImg8bit_trainextra.zip 49 | - gtCoarse.zip 50 | - refinement_final_v0.zip [link] (https://drive.google.com/file/d/1DtPo-WP-hjaOwsbj6ZxTtOo_7R_4TKRG/) # This file is only needed for autolabelled training for recreating SOTA 51 | ``` 52 | 53 | If you prefer to use command lines (e.g., `wget`) to download the dataset, 54 | ``` 55 | # First step, obtain your login credentials. 56 | Please register an account at https://www.cityscapes-dataset.com/login/. 57 | 58 | # Second step, log into cityscapes system, suppose you already have a USERNAME and a PASSWORD. 59 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=USERNAME&password=PASSWORD&submit=Login' https://www.cityscapes-dataset.com/login/ 60 | 61 | # Third step, download the zip files you need. 62 | wget -c -t 0 --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 63 | 64 | # The corresponding packageID is listed below, 65 | 1 -> gtFine_trainvaltest.zip (241MB) md5sum: 4237c19de34c8a376e9ba46b495d6f66 66 | 2 -> gtCoarse.zip (1.3GB) md5sum: 1c7b95c84b1d36cc59a9194d8e5b989f 67 | 3 -> leftImg8bit_trainvaltest.zip (11GB) md5sum: 0a6e97e94b616a514066c9e2adb0c97f 68 | 4 -> leftImg8bit_trainextra.zip (44GB) md5sum: 9167a331a158ce3e8989e166c95d56d4 69 | 5 -> refinement_final_v0.zip (5GB) md5sum: 82aa6698ef7358457894c7cc924534fb 70 | ``` 71 | 72 | ### Prepare Folder Structure 73 | 74 | Now unzip those files, the desired folder structure will look like, 75 | ``` 76 | Cityscapes 77 | ├── leftImg8bit_trainvaltest 78 | │ ├── leftImg8bit 79 | │ │ ├── train 80 | │ │ │ ├── aachen 81 | │ │ │ │ ├── aachen_000000_000019_leftImg8bit.png 82 | │ │ │ │ ├── aachen_000001_000019_leftImg8bit.png 83 | │ │ │ │ ├── ... 84 | │ │ │ ├── bochum 85 | │ │ │ ├── ... 86 | │ │ ├── val 87 | │ │ ├── test 88 | ├── gtFine_trainvaltest 89 | │ ├── gtFine 90 | │ │ ├── train 91 | │ │ │ ├── aachen 92 | │ │ │ │ ├── aachen_000000_000019_gtFine_color.png 93 | │ │ │ │ ├── aachen_000000_000019_gtFine_instanceIds.png 94 | │ │ │ │ ├── aachen_000000_000019_gtFine_labelIds.png 95 | │ │ │ │ ├── aachen_000000_000019_gtFine_polygons.json 96 | │ │ │ │ ├── ... 97 | │ │ │ ├── bochum 98 | │ │ │ ├── ... 99 | │ │ ├── val 100 | │ │ ├── test 101 | ├── leftImg8bit_trainextra 102 | │ ├── leftImg8bit 103 | │ │ ├── train_extra 104 | │ │ │ ├── augsburg 105 | │ │ │ ├── bad-honnef 106 | │ │ │ ├── ... 107 | ├── gtCoarse 108 | │ ├── gtCoarse 109 | │ │ ├── train 110 | │ │ ├── train_extra 111 | │ │ ├── val 112 | ├── autolabelled 113 | │ ├── train_extra 114 | │ │ ├── augsburg 115 | │ │ ├── bad-honnef 116 | │ │ ├── ... 
117 | ``` 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /semantic-segmentation/loss/radam.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code adapted from: https://github.com/LiyuanLucasLiu/RAdam 3 | From the paper: https://arxiv.org/abs/1908.03265 4 | """ 5 | import math 6 | import torch 7 | # pylint: disable=no-name-in-module 8 | from torch.optim.optimizer import Optimizer 9 | 10 | 11 | class RAdam(Optimizer): 12 | """RAdam optimizer""" 13 | 14 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 15 | weight_decay=0): 16 | """ 17 | Init 18 | 19 | :param params: parameters to optimize 20 | :param lr: learning rate 21 | :param betas: beta 22 | :param eps: numerical precision 23 | :param weight_decay: weight decay weight 24 | """ 25 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 26 | self.buffer = [[None, None, None] for _ in range(10)] 27 | super().__init__(params, defaults) 28 | 29 | def step(self, closure=None): 30 | 31 | loss = None 32 | if closure is not None: 33 | loss = closure() 34 | 35 | for group in self.param_groups: 36 | 37 | for p in group['params']: 38 | if p.grad is None: 39 | continue 40 | grad = p.grad.data.float() 41 | if grad.is_sparse: 42 | raise RuntimeError( 43 | 'RAdam does not support sparse gradients' 44 | ) 45 | 46 | p_data_fp32 = p.data.float() 47 | 48 | state = self.state[p] 49 | 50 | if len(state) == 0: 51 | state['step'] = 0 52 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 53 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 54 | else: 55 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 56 | state['exp_avg_sq'] = ( 57 | state['exp_avg_sq'].type_as(p_data_fp32) 58 | ) 59 | 60 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 61 | beta1, beta2 = group['betas'] 62 | 63 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 64 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 65 | 66 | state['step'] += 1 67 | buffered = self.buffer[int(state['step'] % 10)] 68 | if state['step'] == buffered[0]: 69 | N_sma, step_size = buffered[1], buffered[2] 70 | else: 71 | buffered[0] = state['step'] 72 | beta2_t = beta2 ** state['step'] 73 | N_sma_max = 2 / (1 - beta2) - 1 74 | N_sma = ( 75 | N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 76 | ) 77 | buffered[1] = N_sma 78 | 79 | # more conservative since it's an approximated value 80 | if N_sma >= 5: 81 | step_size = ( 82 | group['lr'] * 83 | math.sqrt( 84 | (1 - beta2_t) * (N_sma - 4) / 85 | (N_sma_max - 4) * (N_sma - 2) / 86 | N_sma * N_sma_max / (N_sma_max - 2) 87 | ) / (1 - beta1 ** state['step']) 88 | ) 89 | else: 90 | step_size = group['lr'] / (1 - beta1 ** state['step']) 91 | buffered[2] = step_size 92 | 93 | if group['weight_decay'] != 0: 94 | p_data_fp32.add_( 95 | -group['weight_decay'] * group['lr'], p_data_fp32 96 | ) 97 | 98 | # more conservative since it's an approximated value 99 | if N_sma >= 5: 100 | denom = exp_avg_sq.sqrt().add_(group['eps']) 101 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 102 | else: 103 | p_data_fp32.add_(-step_size, exp_avg) 104 | 105 | p.data.copy_(p_data_fp32) 106 | 107 | return loss 108 | -------------------------------------------------------------------------------- /semantic-segmentation/network/basic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, 
with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | from torch import nn 31 | 32 | from network.mynn import initialize_weights, Upsample 33 | from network.mynn import scale_as 34 | from network.utils import get_aspp, get_trunk, make_seg_head 35 | from config import cfg 36 | 37 | 38 | class Basic(nn.Module): 39 | """ 40 | Basic segmentation network, no ASPP, no Mscale 41 | """ 42 | def __init__(self, num_classes, trunk='hrnetv2', criterion=None): 43 | super(Basic, self).__init__() 44 | self.criterion = criterion 45 | self.backbone, _, _, high_level_ch = get_trunk( 46 | trunk_name=trunk, output_stride=8) 47 | self.seg_head = make_seg_head(in_ch=high_level_ch, 48 | out_ch=num_classes) 49 | initialize_weights(self.seg_head) 50 | 51 | def forward(self, inputs): 52 | x = inputs['images'] 53 | _, _, final_features = self.backbone(x) 54 | pred = self.seg_head(final_features) 55 | pred = scale_as(pred, x) 56 | 57 | if self.training: 58 | assert 'gts' in inputs 59 | gts = inputs['gts'] 60 | loss = self.criterion(pred, gts) 61 | return loss 62 | else: 63 | output_dict = {'pred': pred} 64 | return output_dict 65 | 66 | 67 | class ASPP(nn.Module): 68 | """ 69 | ASPP-based Segmentation network 70 | """ 71 | def __init__(self, num_classes, trunk='hrnetv2', criterion=None): 72 | super(ASPP, self).__init__() 73 | self.criterion = criterion 74 | self.backbone, _, _, high_level_ch = get_trunk(trunk) 75 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 76 | bottleneck_ch=cfg.MODEL.ASPP_BOT_CH, 77 | output_stride=8) 78 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 79 | self.final = make_seg_head(in_ch=256, 80 | out_ch=num_classes) 81 | 82 | initialize_weights(self.final, self.bot_aspp, self.aspp) 83 | 84 | def forward(self, inputs): 85 | x = inputs['images'] 86 | x_size = x.size() 87 | 88 | _, _, final_features = self.backbone(x) 89 | aspp = self.aspp(final_features) 90 | aspp = self.bot_aspp(aspp) 91 | pred = self.final(aspp) 92 | pred = Upsample(pred, x_size[2:]) 93 | 94 | if self.training: 95 | assert 'gts' in inputs 96 | gts = inputs['gts'] 97 | loss = 
self.criterion(pred, gts) 98 | return loss 99 | else: 100 | output_dict = {'pred': pred} 101 | return output_dict 102 | 103 | 104 | def HRNet(num_classes, criterion, s2s4=None): 105 | return Basic(num_classes=num_classes, criterion=criterion, 106 | trunk='hrnetv2') 107 | 108 | 109 | def HRNet_ASP(num_classes, criterion, s2s4=None): 110 | return ASPP(num_classes=num_classes, criterion=criterion, 111 | trunk='hrnetv2') 112 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | 31 | Mapillary Dataset Loader 32 | """ 33 | import os 34 | import json 35 | 36 | from config import cfg 37 | from runx.logx import logx 38 | from datasets.base_loader import BaseLoader 39 | from datasets.utils import make_dataset_folder 40 | from datasets import uniform 41 | 42 | 43 | class Loader(BaseLoader): 44 | num_classes = 65 45 | ignore_label = 65 46 | trainid_to_name = {} 47 | color_mapping = [] 48 | 49 | def __init__(self, mode, quality='semantic', joint_transform_list=None, 50 | img_transform=None, label_transform=None, eval_folder=None): 51 | 52 | super(Loader, self).__init__(quality=quality, 53 | mode=mode, 54 | joint_transform_list=joint_transform_list, 55 | img_transform=img_transform, 56 | label_transform=label_transform) 57 | 58 | root = cfg.DATASET.MAPILLARY_DIR 59 | config_fn = os.path.join(root, 'config.json') 60 | self.fill_colormap_and_names(config_fn) 61 | 62 | ###################################################################### 63 | # Assemble image lists 64 | ###################################################################### 65 | if mode == 'folder': 66 | self.all_imgs = make_dataset_folder(eval_folder) 67 | else: 68 | splits = {'train': 'training', 69 | 'val': 'validation', 70 | 'test': 'testing'} 71 | split_name = splits[mode] 72 | img_ext = 'jpg' 73 | mask_ext = 'png' 74 | img_root = os.path.join(root, split_name, 'images') 75 | mask_root = os.path.join(root, split_name, 'labels') 76 | self.all_imgs = self.find_images(img_root, mask_root, img_ext, 77 | mask_ext) 78 | logx.msg('all imgs {}'.format(len(self.all_imgs))) 79 | self.centroids = uniform.build_centroids(self.all_imgs, 80 | self.num_classes, 81 | self.train, 82 | cv=cfg.DATASET.CV) 83 | self.build_epoch() 84 | 85 | def fill_colormap_and_names(self, config_fn): 86 | """ 87 | Mapillary code for color map and class names 88 | 89 | Outputs 90 | ------- 91 | self.trainid_to_name 92 | self.color_mapping 93 | """ 94 | with open(config_fn) as config_file: 95 | config = json.load(config_file) 96 | config_labels = config['labels'] 97 | 98 | # calculate label color mapping 99 | colormap = [] 100 | self.trainid_to_name = {} 101 | for i in range(0, len(config_labels)): 102 | colormap = colormap + config_labels[i]['color'] 103 | name = config_labels[i]['readable'] 104 | name = name.replace(' ', '_') 105 | self.trainid_to_name[i] = name 106 | self.color_mapping = colormap 107 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 
19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | 37 | 38 | import math 39 | import torch 40 | from torch.distributed import get_world_size, get_rank 41 | from torch.utils.data import Sampler 42 | 43 | class DistributedSampler(Sampler): 44 | """Sampler that restricts data loading to a subset of the dataset. 45 | 46 | It is especially useful in conjunction with 47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 48 | process can pass a DistributedSampler instance as a DataLoader sampler, 49 | and load a subset of the original dataset that is exclusive to it. 50 | 51 | .. note:: 52 | Dataset is assumed to be of constant size. 53 | 54 | Arguments: 55 | dataset: Dataset used for sampling. 56 | num_replicas (optional): Number of processes participating in 57 | distributed training. 58 | rank (optional): Rank of the current process within num_replicas. 
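        pad (optional): if True, pad the index list so that every rank draws
            ceil(len(dataset) / num_replicas) samples per epoch; otherwise
            truncate to floor(len(dataset) / num_replicas).
        consecutive_sample (optional): if True, each rank takes a contiguous
            block of indices instead of a strided slice.
        permutation (optional): if True, indices are shuffled with a generator
            seeded by the current epoch (call set_epoch every epoch).

    Example (a usage sketch; `train_set` and `max_epoch` are placeholders)::

        sampler = DistributedSampler(train_set, pad=True, permutation=True)
        loader = torch.utils.data.DataLoader(train_set, batch_size=2,
                                             sampler=sampler)
        for epoch in range(max_epoch):
            sampler.set_epoch(epoch)
            for batch in loader:
                ...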
59 | """ 60 | 61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None): 62 | if num_replicas is None: 63 | num_replicas = get_world_size() 64 | if rank is None: 65 | rank = get_rank() 66 | self.dataset = dataset 67 | self.num_replicas = num_replicas 68 | self.rank = rank 69 | self.epoch = 0 70 | self.consecutive_sample = consecutive_sample 71 | self.permutation = permutation 72 | if pad: 73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | else: 75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 76 | self.total_size = self.num_samples * self.num_replicas 77 | 78 | def __iter__(self): 79 | # deterministically shuffle based on epoch 80 | g = torch.Generator() 81 | g.manual_seed(self.epoch) 82 | 83 | if self.permutation: 84 | indices = list(torch.randperm(len(self.dataset), generator=g)) 85 | else: 86 | indices = list([x for x in range(len(self.dataset))]) 87 | 88 | # add extra samples to make it evenly divisible 89 | if self.total_size > len(indices): 90 | indices += indices[:(self.total_size - len(indices))] 91 | 92 | # subsample 93 | if self.consecutive_sample: 94 | offset = self.num_samples * self.rank 95 | indices = indices[offset:offset + self.num_samples] 96 | else: 97 | indices = indices[self.rank:self.total_size:self.num_replicas] 98 | assert len(indices) == self.num_samples 99 | 100 | return iter(indices) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | def set_epoch(self, epoch): 106 | self.epoch = epoch 107 | 108 | def set_num_samples(self): 109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 110 | self.total_size = self.num_samples * self.num_replicas -------------------------------------------------------------------------------- /semantic-segmentation/loss/rmi_utils.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from: https://github.com/ZJULearning/RMI 2 | 3 | # python 2.X, 3.X compatibility 4 | from __future__ import print_function 5 | from __future__ import division 6 | from __future__ import absolute_import 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | 12 | __all__ = ['map_get_pairs', 'log_det_by_cholesky'] 13 | 14 | 15 | def map_get_pairs(labels_4D, probs_4D, radius=3, is_combine=True): 16 | """get map pairs 17 | Args: 18 | labels_4D : labels, shape [N, C, H, W] 19 | probs_4D : probabilities, shape [N, C, H, W] 20 | radius : the square radius 21 | Return: 22 | tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)] 23 | """ 24 | # pad to ensure the following slice operation is valid 25 | #pad_beg = int(radius // 2) 26 | #pad_end = radius - pad_beg 27 | 28 | # the original height and width 29 | label_shape = labels_4D.size() 30 | h, w = label_shape[2], label_shape[3] 31 | new_h, new_w = h - (radius - 1), w - (radius - 1) 32 | # https://pytorch.org/docs/stable/nn.html?highlight=f%20pad#torch.nn.functional.pad 33 | #padding = (pad_beg, pad_end, pad_beg, pad_end) 34 | #labels_4D, probs_4D = F.pad(labels_4D, padding), F.pad(probs_4D, padding) 35 | 36 | # get the neighbors 37 | la_ns = [] 38 | pr_ns = [] 39 | #for x in range(0, radius, 1): 40 | for y in range(0, radius, 1): 41 | for x in range(0, radius, 1): 42 | la_now = labels_4D[:, :, y:y + new_h, x:x + new_w] 43 | pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w] 44 | la_ns.append(la_now) 45 | pr_ns.append(pr_now) 46 | 47 | if 
is_combine: 48 | # for calculating RMI 49 | pair_ns = la_ns + pr_ns 50 | p_vectors = torch.stack(pair_ns, dim=2) 51 | return p_vectors 52 | else: 53 | # for other purpose 54 | la_vectors = torch.stack(la_ns, dim=2) 55 | pr_vectors = torch.stack(pr_ns, dim=2) 56 | return la_vectors, pr_vectors 57 | 58 | 59 | def map_get_pairs_region(labels_4D, probs_4D, radius=3, is_combine=0, num_classeses=21): 60 | """get map pairs 61 | Args: 62 | labels_4D : labels, shape [N, C, H, W]. 63 | probs_4D : probabilities, shape [N, C, H, W]. 64 | radius : The side length of the square region. 65 | Return: 66 | A tensor with shape [N, C, radiu * radius, H // radius, W // raidius] 67 | """ 68 | kernel = torch.zeros([num_classeses, 1, radius, radius]).type_as(probs_4D) 69 | padding = radius // 2 70 | # get the neighbours 71 | la_ns = [] 72 | pr_ns = [] 73 | for y in range(0, radius, 1): 74 | for x in range(0, radius, 1): 75 | kernel_now = kernel.clone() 76 | kernel_now[:, :, y, x] = 1.0 77 | la_now = F.conv2d(labels_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) 78 | pr_now = F.conv2d(probs_4D, kernel_now, stride=radius, padding=padding, groups=num_classeses) 79 | la_ns.append(la_now) 80 | pr_ns.append(pr_now) 81 | 82 | if is_combine: 83 | # for calculating RMI 84 | pair_ns = la_ns + pr_ns 85 | p_vectors = torch.stack(pair_ns, dim=2) 86 | return p_vectors 87 | else: 88 | # for other purpose 89 | la_vectors = torch.stack(la_ns, dim=2) 90 | pr_vectors = torch.stack(pr_ns, dim=2) 91 | return la_vectors, pr_vectors 92 | return 93 | 94 | 95 | def log_det_by_cholesky(matrix): 96 | """ 97 | Args: 98 | matrix: matrix must be a positive define matrix. 99 | shape [N, C, D, D]. 100 | Ref: 101 | https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/linalg/linalg_impl.py 102 | """ 103 | # This uses the property that the log det(A) = 2 * sum(log(real(diag(C)))) 104 | # where C is the cholesky decomposition of A. 105 | chol = torch.cholesky(matrix) 106 | #return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-6), dim=-1) 107 | return 2.0 * torch.sum(torch.log(torch.diagonal(chol, dim1=-2, dim2=-1) + 1e-8), dim=-1) 108 | 109 | 110 | def batch_cholesky_inverse(matrix): 111 | """ 112 | Args: matrix, 4-D tensor, [N, C, M, M]. 113 | matrix must be a symmetric positive define matrix. 
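    Returns:
        the batched inverse of `matrix`, shape [N, C, M, M], computed from the
        lower Cholesky factor (A = L L^T, so A^{-1} = L^{-T} L^{-1});
        batch_inv_test() below checks the result against torch.inverse.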
114 | """ 115 | chol_low = torch.cholesky(matrix, upper=False) 116 | chol_low_inv = batch_low_tri_inv(chol_low) 117 | return torch.matmul(chol_low_inv.transpose(-2, -1), chol_low_inv) 118 | 119 | 120 | def batch_low_tri_inv(L): 121 | """ 122 | Batched inverse of lower triangular matrices 123 | Args: 124 | L : a lower triangular matrix 125 | Ref: 126 | https://www.pugetsystems.com/labs/hpc/PyTorch-for-Scientific-Computing 127 | """ 128 | n = L.shape[-1] 129 | invL = torch.zeros_like(L) 130 | for j in range(0, n): 131 | invL[..., j, j] = 1.0 / L[..., j, j] 132 | for i in range(j + 1, n): 133 | S = 0.0 134 | for k in range(0, i + 1): 135 | S = S - L[..., i, k] * invL[..., k, j].clone() 136 | invL[..., i, j] = S / L[..., i, i] 137 | return invL 138 | 139 | 140 | def log_det_by_cholesky_test(): 141 | """ 142 | test for function log_det_by_cholesky() 143 | """ 144 | a = torch.randn(1, 4, 4) 145 | a = torch.matmul(a, a.transpose(2, 1)) 146 | print(a) 147 | res_1 = torch.logdet(torch.squeeze(a)) 148 | res_2 = log_det_by_cholesky(a) 149 | print(res_1, res_2) 150 | 151 | 152 | def batch_inv_test(): 153 | """ 154 | test for function batch_cholesky_inverse() 155 | """ 156 | a = torch.randn(1, 1, 4, 4) 157 | a = torch.matmul(a, a.transpose(-2, -1)) 158 | print(a) 159 | res_1 = torch.inverse(a) 160 | res_2 = batch_cholesky_inverse(a) 161 | print(res_1, '\n', res_2) 162 | 163 | 164 | def mean_var_test(): 165 | x = torch.randn(3, 4) 166 | y = torch.randn(3, 4) 167 | 168 | x_mean = x.mean(dim=1, keepdim=True) 169 | x_sum = x.sum(dim=1, keepdim=True) / 2.0 170 | y_mean = y.mean(dim=1, keepdim=True) 171 | y_sum = y.sum(dim=1, keepdim=True) / 2.0 172 | 173 | x_var_1 = torch.matmul(x - x_mean, (x - x_mean).t()) 174 | x_var_2 = torch.matmul(x, x.t()) - torch.matmul(x_sum, x_sum.t()) 175 | xy_cov = torch.matmul(x - x_mean, (y - y_mean).t()) 176 | xy_cov_1 = torch.matmul(x, y.t()) - x_sum.matmul(y_sum.t()) 177 | 178 | print(x_var_1) 179 | print(x_var_2) 180 | 181 | print(xy_cov, '\n', xy_cov_1) 182 | 183 | 184 | if __name__ == '__main__': 185 | batch_inv_test() 186 | -------------------------------------------------------------------------------- /semantic-segmentation/loss/optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | 31 | # Optimizer and scheduler related tasks 32 | 33 | import math 34 | import torch 35 | 36 | from torch import optim 37 | from runx.logx import logx 38 | 39 | from config import cfg 40 | from loss.radam import RAdam 41 | 42 | 43 | def get_optimizer(args, net): 44 | """ 45 | Decide Optimizer (Adam or SGD) 46 | """ 47 | param_groups = net.parameters() 48 | 49 | if args.optimizer == 'sgd': 50 | optimizer = optim.SGD(param_groups, 51 | lr=args.lr, 52 | weight_decay=args.weight_decay, 53 | momentum=args.momentum, 54 | nesterov=False) 55 | elif args.optimizer == 'adam': 56 | optimizer = optim.Adam(param_groups, 57 | lr=args.lr, 58 | weight_decay=args.weight_decay, 59 | amsgrad=args.amsgrad) 60 | elif args.optimizer == 'radam': 61 | optimizer = RAdam(param_groups, 62 | lr=args.lr, 63 | weight_decay=args.weight_decay) 64 | else: 65 | raise ValueError('Not a valid optimizer') 66 | 67 | def poly_schd(epoch): 68 | return math.pow(1 - epoch / args.max_epoch, args.poly_exp) 69 | 70 | def poly2_schd(epoch): 71 | if epoch < args.poly_step: 72 | poly_exp = args.poly_exp 73 | else: 74 | poly_exp = 2 * args.poly_exp 75 | return math.pow(1 - epoch / args.max_epoch, poly_exp) 76 | 77 | if args.lr_schedule == 'scl-poly': 78 | if cfg.REDUCE_BORDER_EPOCH == -1: 79 | raise ValueError('ERROR Cannot Do Scale Poly') 80 | 81 | rescale_thresh = cfg.REDUCE_BORDER_EPOCH 82 | scale_value = args.rescale 83 | lambda1 = lambda epoch: \ 84 | math.pow(1 - epoch / args.max_epoch, 85 | args.poly_exp) if epoch < rescale_thresh else scale_value * math.pow( 86 | 1 - (epoch - rescale_thresh) / (args.max_epoch - rescale_thresh), 87 | args.repoly) 88 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 89 | elif args.lr_schedule == 'poly2': 90 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, 91 | lr_lambda=poly2_schd) 92 | elif args.lr_schedule == 'poly': 93 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, 94 | lr_lambda=poly_schd) 95 | else: 96 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 97 | 98 | return optimizer, scheduler 99 | 100 | 101 | def load_weights(net, optimizer, snapshot_file, restore_optimizer_bool=False): 102 | """ 103 | Load weights from snapshot file 104 | """ 105 | logx.msg("Loading weights from model {}".format(snapshot_file)) 106 | net, optimizer = restore_snapshot(net, optimizer, snapshot_file, restore_optimizer_bool) 107 | return net, optimizer 108 | 109 | 110 | def restore_snapshot(net, optimizer, snapshot, restore_optimizer_bool): 111 | """ 112 | Restore weights and optimizer (if needed ) for resuming job. 
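    Accepts either a checkpoint dict with a 'state_dict' key (and optionally an
    'optimizer' key) or a bare state_dict. Weights are loaded through
    forgiving_state_restore(), so parameters whose sizes do not match the
    current model are skipped instead of raising an error.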
113 | """ 114 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 115 | logx.msg("Checkpoint Load Compelete") 116 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 117 | optimizer.load_state_dict(checkpoint['optimizer']) 118 | 119 | if 'state_dict' in checkpoint: 120 | net = forgiving_state_restore(net, checkpoint['state_dict']) 121 | else: 122 | net = forgiving_state_restore(net, checkpoint) 123 | 124 | return net, optimizer 125 | 126 | 127 | def restore_opt(optimizer, checkpoint): 128 | assert 'optimizer' in checkpoint, 'cant find optimizer in checkpoint' 129 | optimizer.load_state_dict(checkpoint['optimizer']) 130 | 131 | 132 | def restore_net(net, checkpoint): 133 | assert 'state_dict' in checkpoint, 'cant find state_dict in checkpoint' 134 | forgiving_state_restore(net, checkpoint['state_dict']) 135 | 136 | 137 | def forgiving_state_restore(net, loaded_dict): 138 | """ 139 | Handle partial loading when some tensors don't match up in size. 140 | Because we want to use models that were trained off a different 141 | number of classes. 142 | """ 143 | 144 | net_state_dict = net.state_dict() 145 | new_loaded_dict = {} 146 | for k in net_state_dict: 147 | new_k = k 148 | if new_k in loaded_dict and net_state_dict[k].size() == loaded_dict[new_k].size(): 149 | new_loaded_dict[k] = loaded_dict[new_k] 150 | else: 151 | logx.msg("Skipped loading parameter {}".format(k)) 152 | net_state_dict.update(new_loaded_dict) 153 | net.load_state_dict(net_state_dict) 154 | return net 155 | -------------------------------------------------------------------------------- /semantic-segmentation/network/deepv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from: 3 | https://github.com/sthalles/deeplab_v3 4 | 5 | Copyright 2020 Nvidia Corporation 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | 1. Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | 2. Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its contributors 18 | may be used to endorse or promote products derived from this software 19 | without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 24 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 25 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 26 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 27 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 28 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 29 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 30 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 31 | POSSIBILITY OF SUCH DAMAGE. 
32 | """ 33 | import torch 34 | from torch import nn 35 | 36 | from network.mynn import initialize_weights, Norm2d, Upsample 37 | from network.utils import get_aspp, get_trunk, make_seg_head 38 | 39 | 40 | class DeepV3Plus(nn.Module): 41 | """ 42 | DeepLabV3+ with various trunks supported 43 | Always stride8 44 | """ 45 | def __init__(self, num_classes, trunk='wrn38', criterion=None, 46 | use_dpc=False, init_all=False): 47 | super(DeepV3Plus, self).__init__() 48 | self.criterion = criterion 49 | self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk) 50 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 51 | bottleneck_ch=256, 52 | output_stride=8, 53 | dpc=use_dpc) 54 | self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False) 55 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 56 | self.final = nn.Sequential( 57 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 58 | Norm2d(256), 59 | nn.ReLU(inplace=True), 60 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 61 | Norm2d(256), 62 | nn.ReLU(inplace=True), 63 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 64 | 65 | if init_all: 66 | initialize_weights(self.aspp) 67 | initialize_weights(self.bot_aspp) 68 | initialize_weights(self.bot_fine) 69 | initialize_weights(self.final) 70 | else: 71 | initialize_weights(self.final) 72 | 73 | def forward(self, inputs): 74 | assert 'images' in inputs 75 | x = inputs['images'] 76 | 77 | x_size = x.size() 78 | s2_features, _, final_features = self.backbone(x) 79 | aspp = self.aspp(final_features) 80 | conv_aspp = self.bot_aspp(aspp) 81 | conv_s2 = self.bot_fine(s2_features) 82 | conv_aspp = Upsample(conv_aspp, s2_features.size()[2:]) 83 | cat_s4 = [conv_s2, conv_aspp] 84 | cat_s4 = torch.cat(cat_s4, 1) 85 | final = self.final(cat_s4) 86 | out = Upsample(final, x_size[2:]) 87 | 88 | if self.training: 89 | assert 'gts' in inputs 90 | gts = inputs['gts'] 91 | return self.criterion(out, gts) 92 | 93 | return {'pred': out} 94 | 95 | 96 | def DeepV3PlusSRNX50(num_classes, criterion): 97 | return DeepV3Plus(num_classes, trunk='seresnext-50', criterion=criterion) 98 | 99 | 100 | def DeepV3PlusR50(num_classes, criterion): 101 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion) 102 | 103 | 104 | def DeepV3PlusSRNX101(num_classes, criterion): 105 | return DeepV3Plus(num_classes, trunk='seresnext-101', criterion=criterion) 106 | 107 | 108 | def DeepV3PlusW38(num_classes, criterion): 109 | return DeepV3Plus(num_classes, trunk='wrn38', criterion=criterion) 110 | 111 | 112 | def DeepV3PlusW38I(num_classes, criterion): 113 | return DeepV3Plus(num_classes, trunk='wrn38', criterion=criterion, 114 | init_all=True) 115 | 116 | 117 | def DeepV3PlusX71(num_classes, criterion): 118 | return DeepV3Plus(num_classes, trunk='xception71', criterion=criterion) 119 | 120 | 121 | def DeepV3PlusEffB4(num_classes, criterion): 122 | return DeepV3Plus(num_classes, trunk='efficientnet_b4', 123 | criterion=criterion) 124 | 125 | 126 | class DeepV3(nn.Module): 127 | """ 128 | DeepLabV3 with various trunks supported 129 | """ 130 | def __init__(self, num_classes, trunk='resnet-50', criterion=None, 131 | use_dpc=False, init_all=False, output_stride=8): 132 | super(DeepV3, self).__init__() 133 | self.criterion = criterion 134 | 135 | self.backbone, _s2_ch, _s4_ch, high_level_ch = \ 136 | get_trunk(trunk, output_stride=output_stride) 137 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 138 | bottleneck_ch=256, 139 | output_stride=output_stride, 140 | 
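                                          # ASPP dilation rates are commonly scaled with the output
                                          # stride (larger rates at stride 8 than at stride 16) so the
                                          # effective receptive field stays comparable; the concrete
                                          # rates are determined inside get_aspp() in network/utils.py.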
dpc=use_dpc) 141 | self.final = make_seg_head(in_ch=aspp_out_ch, out_ch=num_classes) 142 | 143 | initialize_weights(self.aspp) 144 | initialize_weights(self.final) 145 | 146 | def forward(self, inputs): 147 | assert 'images' in inputs 148 | x = inputs['images'] 149 | 150 | x_size = x.size() 151 | _, _, final_features = self.backbone(x) 152 | aspp = self.aspp(final_features) 153 | final = self.final(aspp) 154 | out = Upsample(final, x_size[2:]) 155 | 156 | if self.training: 157 | assert 'gts' in inputs 158 | gts = inputs['gts'] 159 | return self.criterion(out, gts) 160 | 161 | return {'pred': out} 162 | 163 | 164 | def DeepV3R50(num_classes, criterion): 165 | return DeepV3(num_classes, trunk='resnet-50', criterion=criterion) 166 | 167 | -------------------------------------------------------------------------------- /semantic-segmentation/network/ocr_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn), Jingyi Xie (hsfzxjy@gmail.com) 5 | # 6 | # This code is from: https://github.com/HRNet/HRNet-Semantic-Segmentation 7 | # ------------------------------------------------------------------------------ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from config import cfg 14 | from network.utils import BNReLU 15 | 16 | 17 | class SpatialGather_Module(nn.Module): 18 | """ 19 | Aggregate the context features according to the initial 20 | predicted probability distribution. 21 | Employ the soft-weighted method to aggregate the context. 22 | 23 | Output: 24 | The correlation of every class map with every feature map 25 | shape = [n, num_feats, num_classes, 1] 26 | 27 | 28 | """ 29 | def __init__(self, cls_num=0, scale=1): 30 | super(SpatialGather_Module, self).__init__() 31 | self.cls_num = cls_num 32 | self.scale = scale 33 | 34 | def forward(self, feats, probs): 35 | batch_size, c, _, _ = probs.size(0), probs.size(1), probs.size(2), \ 36 | probs.size(3) 37 | 38 | # each class image now a vector 39 | probs = probs.view(batch_size, c, -1) 40 | feats = feats.view(batch_size, feats.size(1), -1) 41 | 42 | feats = feats.permute(0, 2, 1) # batch x hw x c 43 | probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw 44 | ocr_context = torch.matmul(probs, feats) 45 | ocr_context = ocr_context.permute(0, 2, 1).unsqueeze(3) 46 | return ocr_context 47 | 48 | 49 | class ObjectAttentionBlock(nn.Module): 50 | ''' 51 | The basic implementation for object context block 52 | Input: 53 | N X C X H X W 54 | Parameters: 55 | in_channels : the dimension of the input feature map 56 | key_channels : the dimension after the key/query transform 57 | scale : choose the scale to downsample the input feature 58 | maps (save memory cost) 59 | Return: 60 | N X C X H X W 61 | ''' 62 | def __init__(self, in_channels, key_channels, scale=1): 63 | super(ObjectAttentionBlock, self).__init__() 64 | self.scale = scale 65 | self.in_channels = in_channels 66 | self.key_channels = key_channels 67 | self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) 68 | self.f_pixel = nn.Sequential( 69 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 70 | kernel_size=1, stride=1, padding=0, bias=False), 71 | BNReLU(self.key_channels), 72 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, 73 | kernel_size=1, stride=1, 
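                      # f_pixel projects pixel features into the query space; f_object and
                      # f_down below project the object-region "proxy" features into keys
                      # and values for the attention computed in forward().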
padding=0, bias=False), 74 | BNReLU(self.key_channels), 75 | ) 76 | self.f_object = nn.Sequential( 77 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 78 | kernel_size=1, stride=1, padding=0, bias=False), 79 | BNReLU(self.key_channels), 80 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, 81 | kernel_size=1, stride=1, padding=0, bias=False), 82 | BNReLU(self.key_channels), 83 | ) 84 | self.f_down = nn.Sequential( 85 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, 86 | kernel_size=1, stride=1, padding=0, bias=False), 87 | BNReLU(self.key_channels), 88 | ) 89 | self.f_up = nn.Sequential( 90 | nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, 91 | kernel_size=1, stride=1, padding=0, bias=False), 92 | BNReLU(self.in_channels), 93 | ) 94 | 95 | def forward(self, x, proxy): 96 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 97 | if self.scale > 1: 98 | x = self.pool(x) 99 | 100 | query = self.f_pixel(x).view(batch_size, self.key_channels, -1) 101 | query = query.permute(0, 2, 1) 102 | key = self.f_object(proxy).view(batch_size, self.key_channels, -1) 103 | value = self.f_down(proxy).view(batch_size, self.key_channels, -1) 104 | value = value.permute(0, 2, 1) 105 | 106 | sim_map = torch.matmul(query, key) 107 | sim_map = (self.key_channels**-.5) * sim_map 108 | sim_map = F.softmax(sim_map, dim=-1) 109 | 110 | # add bg context ... 111 | context = torch.matmul(sim_map, value) 112 | context = context.permute(0, 2, 1).contiguous() 113 | context = context.view(batch_size, self.key_channels, *x.size()[2:]) 114 | context = self.f_up(context) 115 | if self.scale > 1: 116 | context = F.interpolate(input=context, size=(h, w), mode='bilinear', 117 | align_corners=cfg.MODEL.ALIGN_CORNERS) 118 | 119 | return context 120 | 121 | 122 | class SpatialOCR_Module(nn.Module): 123 | """ 124 | Implementation of the OCR module: 125 | We aggregate the global object representation to update the representation 126 | for each pixel. 
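    In short: the ObjectAttentionBlock above computes pixel-to-region
    attention, softmax(query @ key / sqrt(key_channels)), applies it to the
    region values, and this module then concatenates that context with the
    original features and fuses them with a 1x1 conv + BN/ReLU + dropout.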
127 | """ 128 | def __init__(self, in_channels, key_channels, out_channels, scale=1, 129 | dropout=0.1): 130 | super(SpatialOCR_Module, self).__init__() 131 | self.object_context_block = ObjectAttentionBlock(in_channels, 132 | key_channels, 133 | scale) 134 | if cfg.MODEL.OCR_ASPP: 135 | self.aspp, aspp_out_ch = get_aspp( 136 | in_channels, bottleneck_ch=cfg.MODEL.ASPP_BOT_CH, 137 | output_stride=8) 138 | _in_channels = 2 * in_channels + aspp_out_ch 139 | else: 140 | _in_channels = 2 * in_channels 141 | 142 | self.conv_bn_dropout = nn.Sequential( 143 | nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, 144 | bias=False), 145 | BNReLU(out_channels), 146 | nn.Dropout2d(dropout) 147 | ) 148 | 149 | def forward(self, feats, proxy_feats): 150 | context = self.object_context_block(feats, proxy_feats) 151 | 152 | if cfg.MODEL.OCR_ASPP: 153 | aspp = self.aspp(feats) 154 | output = self.conv_bn_dropout(torch.cat([context, aspp, feats], 1)) 155 | else: 156 | output = self.conv_bn_dropout(torch.cat([context, feats], 1)) 157 | 158 | return output 159 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/randaugment.py: -------------------------------------------------------------------------------- 1 | # this code from: https://github.com/ildoonet/pytorch-randaugment 2 | # code in this file is adpated from rpmcruz/autoaugment 3 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 4 | import random 5 | import numpy as np 6 | import torch 7 | 8 | from PIL import Image, ImageOps, ImageEnhance, ImageDraw 9 | from config import cfg 10 | 11 | 12 | fillmask = cfg.DATASET.IGNORE_LABEL 13 | fillcolor = (0, 0, 0) 14 | 15 | 16 | def affine_transform(pair, affine_params): 17 | img, mask = pair 18 | img = img.transform(img.size, Image.AFFINE, affine_params, 19 | resample=Image.BILINEAR, fillcolor=fillcolor) 20 | mask = mask.transform(mask.size, Image.AFFINE, affine_params, 21 | resample=Image.NEAREST, fillcolor=fillmask) 22 | return img, mask 23 | 24 | 25 | def ShearX(pair, v): # [-0.3, 0.3] 26 | assert -0.3 <= v <= 0.3 27 | if random.random() > 0.5: 28 | v = -v 29 | return affine_transform(pair, (1, v, 0, 0, 1, 0)) 30 | 31 | 32 | def ShearY(pair, v): # [-0.3, 0.3] 33 | assert -0.3 <= v <= 0.3 34 | if random.random() > 0.5: 35 | v = -v 36 | return affine_transform(pair, (1, 0, 0, v, 1, 0)) 37 | 38 | 39 | def TranslateX(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 40 | assert -0.45 <= v <= 0.45 41 | if random.random() > 0.5: 42 | v = -v 43 | img, _ = pair 44 | v = v * img.size[0] 45 | return affine_transform(pair, (1, 0, v, 0, 1, 0)) 46 | 47 | 48 | def TranslateY(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert -0.45 <= v <= 0.45 50 | if random.random() > 0.5: 51 | v = -v 52 | img, _ = pair 53 | v = v * img.size[1] 54 | return affine_transform(pair, (1, 0, 0, 0, 1, v)) 55 | 56 | 57 | def TranslateXAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 58 | assert 0 <= v <= 10 59 | if random.random() > 0.5: 60 | v = -v 61 | return affine_transform(pair, (1, 0, v, 0, 1, 0)) 62 | 63 | 64 | def TranslateYAbs(pair, v): # [-150, 150] => percentage: [-0.45, 0.45] 65 | assert 0 <= v <= 10 66 | if random.random() > 0.5: 67 | v = -v 68 | return affine_transform(pair, (1, 0, 0, 0, 1, v)) 69 | 70 | 71 | def Rotate(pair, v): # [-30, 30] 72 | assert -30 <= v <= 30 73 | if random.random() > 0.5: 74 | v = -v 75 | img, mask = pair 76 | img = img.rotate(v, fillcolor=fillcolor) 77 | mask = mask.rotate(v, 
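                      # The mask is rotated with NEAREST resampling and padded with the
                      # ignore label, so interpolation never invents new class ids.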
resample=Image.NEAREST, fillcolor=fillmask) 78 | return img, mask 79 | 80 | 81 | def AutoContrast(pair, _): 82 | img, mask = pair 83 | return ImageOps.autocontrast(img), mask 84 | 85 | 86 | def Invert(pair, _): 87 | img, mask = pair 88 | return ImageOps.invert(img), mask 89 | 90 | 91 | def Equalize(pair, _): 92 | img, mask = pair 93 | return ImageOps.equalize(img), mask 94 | 95 | 96 | def Flip(pair, _): # not from the paper 97 | img, mask = pair 98 | return ImageOps.mirror(img), ImageOps.mirror(mask) 99 | 100 | 101 | def Solarize(pair, v): # [0, 256] 102 | img, mask = pair 103 | assert 0 <= v <= 256 104 | return ImageOps.solarize(img, v), mask 105 | 106 | 107 | def Posterize(pair, v): # [4, 8] 108 | img, mask = pair 109 | assert 4 <= v <= 8 110 | v = int(v) 111 | return ImageOps.posterize(img, v), mask 112 | 113 | 114 | def Posterize2(pair, v): # [0, 4] 115 | img, mask = pair 116 | assert 0 <= v <= 4 117 | v = int(v) 118 | return ImageOps.posterize(img, v), mask 119 | 120 | 121 | def Contrast(pair, v): # [0.1,1.9] 122 | img, mask = pair 123 | assert 0.1 <= v <= 1.9 124 | return ImageEnhance.Contrast(img).enhance(v), mask 125 | 126 | 127 | def Color(pair, v): # [0.1,1.9] 128 | img, mask = pair 129 | assert 0.1 <= v <= 1.9 130 | return ImageEnhance.Color(img).enhance(v), mask 131 | 132 | 133 | def Brightness(pair, v): # [0.1,1.9] 134 | img, mask = pair 135 | assert 0.1 <= v <= 1.9 136 | return ImageEnhance.Brightness(img).enhance(v), mask 137 | 138 | 139 | def Sharpness(pair, v): # [0.1,1.9] 140 | img, mask = pair 141 | assert 0.1 <= v <= 1.9 142 | return ImageEnhance.Sharpness(img).enhance(v), mask 143 | 144 | 145 | def Cutout(pair, v): # [0, 60] => percentage: [0, 0.2] 146 | assert 0.0 <= v <= 0.2 147 | if v <= 0.: 148 | return pair 149 | img, mask = pair 150 | v = v * img.size[0] 151 | return CutoutAbs(img, v), mask 152 | 153 | 154 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 155 | # assert 0 <= v <= 20 156 | if v < 0: 157 | return img 158 | w, h = img.size 159 | x0 = np.random.uniform(w) 160 | y0 = np.random.uniform(h) 161 | 162 | x0 = int(max(0, x0 - v / 2.)) 163 | y0 = int(max(0, y0 - v / 2.)) 164 | x1 = min(w, x0 + v) 165 | y1 = min(h, y0 + v) 166 | 167 | xy = (x0, y0, x1, y1) 168 | color = (125, 123, 114) 169 | # color = (0, 0, 0) 170 | img = img.copy() 171 | ImageDraw.Draw(img).rectangle(xy, color) 172 | return img 173 | 174 | 175 | def Identity(pair, v): 176 | return pair 177 | 178 | 179 | def augment_list(): # 16 oeprations and their ranges 180 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 181 | l = [ 182 | (Identity, 0., 1.0), 183 | (ShearX, 0., 0.3), # 0 184 | (ShearY, 0., 0.3), # 1 185 | (TranslateX, 0., 0.33), # 2 186 | (TranslateY, 0., 0.33), # 3 187 | (Rotate, 0, 30), # 4 188 | (AutoContrast, 0, 1), # 5 189 | (Invert, 0, 1), # 6 190 | (Equalize, 0, 1), # 7 191 | (Solarize, 0, 110), # 8 192 | (Posterize, 4, 8), # 9 193 | # (Contrast, 0.1, 1.9), # 10 194 | (Color, 0.1, 1.9), # 11 195 | (Brightness, 0.1, 1.9), # 12 196 | (Sharpness, 0.1, 1.9), # 13 197 | # (Cutout, 0, 0.2), # 14 198 | # (SamplePairing(imgs), 0, 0.4), # 15 199 | # (Flip, 1, 1), 200 | ] 201 | return l 202 | 203 | 204 | class Lighting(object): 205 | """Lighting noise(AlexNet - style PCA - based noise)""" 206 | 207 | def __init__(self, alphastd, eigval, eigvec): 208 | self.alphastd = alphastd 209 | self.eigval = torch.Tensor(eigval) 210 | self.eigvec = torch.Tensor(eigvec) 211 | 212 | def __call__(self, img): 213 | if self.alphastd == 0: 214 | return img 215 
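        # AlexNet-style PCA lighting noise: sample a per-eigenvector coefficient
        # alpha ~ N(0, alphastd), scale the RGB eigenvectors by alpha and the
        # eigenvalues, and add the summed RGB offset to every pixel of the image.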
| 216 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 217 | rgb = self.eigvec.type_as(img).clone() \ 218 | .mul(alpha.view(1, 3).expand(3, 3)) \ 219 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 220 | .sum(1).squeeze() 221 | 222 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 223 | 224 | 225 | class CutoutDefault(object): 226 | """ 227 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 228 | """ 229 | def __init__(self, length): 230 | self.length = length 231 | 232 | def __call__(self, img): 233 | h, w = img.size(1), img.size(2) 234 | mask = np.ones((h, w), np.float32) 235 | y = np.random.randint(h) 236 | x = np.random.randint(w) 237 | 238 | y1 = np.clip(y - self.length // 2, 0, h) 239 | y2 = np.clip(y + self.length // 2, 0, h) 240 | x1 = np.clip(x - self.length // 2, 0, w) 241 | x2 = np.clip(x + self.length // 2, 0, w) 242 | 243 | mask[y1: y2, x1: x2] = 0. 244 | mask = torch.from_numpy(mask) 245 | mask = mask.expand_as(img) 246 | img *= mask 247 | return img 248 | 249 | 250 | class RandAugment: 251 | def __init__(self, n, m): 252 | self.n = n 253 | self.m = m # [0, 30] 254 | self.augment_list = augment_list() 255 | 256 | def __call__(self, img, mask): 257 | pair = img, mask 258 | ops = random.choices(self.augment_list, k=self.n) 259 | for op, minval, maxval in ops: 260 | val = (float(self.m) / 30) * float(maxval - minval) + minval 261 | pair = op(pair, val) 262 | 263 | return pair 264 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | Dataset setup and loaders 31 | """ 32 | 33 | import importlib 34 | import torchvision.transforms as standard_transforms 35 | 36 | import transforms.joint_transforms as joint_transforms 37 | import transforms.transforms as extended_transforms 38 | from torch.utils.data import DataLoader 39 | 40 | from config import cfg, update_dataset_cfg, update_dataset_inst 41 | from runx.logx import logx 42 | from datasets.randaugment import RandAugment 43 | 44 | 45 | def setup_loaders(args): 46 | """ 47 | Setup Data Loaders[Currently supports Cityscapes, Mapillary and ADE20kin] 48 | input: argument passed by the user 49 | return: training data loader, validation data loader loader, train_set 50 | """ 51 | 52 | # TODO add error checking to make sure class exists 53 | logx.msg(f'dataset = {args.dataset}') 54 | 55 | mod = importlib.import_module('datasets.{}'.format(args.dataset)) 56 | dataset_cls = getattr(mod, 'Loader') 57 | 58 | logx.msg(f'ignore_label = {dataset_cls.ignore_label}') 59 | 60 | update_dataset_cfg(num_classes=dataset_cls.num_classes, 61 | ignore_label=dataset_cls.ignore_label) 62 | 63 | ###################################################################### 64 | # Define transformations, augmentations 65 | ###################################################################### 66 | 67 | # Joint transformations that must happen on both image and mask 68 | if ',' in args.crop_size: 69 | args.crop_size = [int(x) for x in args.crop_size.split(',')] 70 | else: 71 | args.crop_size = int(args.crop_size) 72 | train_joint_transform_list = [ 73 | # TODO FIXME: move these hparams into cfg 74 | joint_transforms.RandomSizeAndCrop(args.crop_size, 75 | False, 76 | scale_min=args.scale_min, 77 | scale_max=args.scale_max, 78 | full_size=args.full_crop_training, 79 | pre_size=args.pre_size)] 80 | train_joint_transform_list.append( 81 | joint_transforms.RandomHorizontallyFlip()) 82 | 83 | if args.rand_augment is not None: 84 | N, M = [int(i) for i in args.rand_augment.split(',')] 85 | assert isinstance(N, int) and isinstance(M, int), \ 86 | f'Either N {N} or M {M} not integer' 87 | train_joint_transform_list.append(RandAugment(N, M)) 88 | 89 | ###################################################################### 90 | # Image only augmentations 91 | ###################################################################### 92 | train_input_transform = [] 93 | 94 | if args.color_aug: 95 | train_input_transform += [extended_transforms.ColorJitter( 96 | brightness=args.color_aug, 97 | contrast=args.color_aug, 98 | saturation=args.color_aug, 99 | hue=args.color_aug)] 100 | if args.bblur: 101 | train_input_transform += [extended_transforms.RandomBilateralBlur()] 102 | elif args.gblur: 103 | train_input_transform += [extended_transforms.RandomGaussianBlur()] 104 | 105 | mean_std = (cfg.DATASET.MEAN, cfg.DATASET.STD) 106 | train_input_transform += [standard_transforms.ToTensor(), 107 | standard_transforms.Normalize(*mean_std)] 108 | train_input_transform = standard_transforms.Compose(train_input_transform) 109 | 110 | val_input_transform = standard_transforms.Compose([ 111 | standard_transforms.ToTensor(), 112 | standard_transforms.Normalize(*mean_std) 113 | ]) 114 | 115 | target_transform = extended_transforms.MaskToTensor() 116 | 117 | if args.jointwtborder: 118 | target_train_transform = \ 119 | extended_transforms.RelaxedBoundaryLossToTensor() 120 | else: 121 | target_train_transform = extended_transforms.MaskToTensor() 122 | 123 | if args.eval == 'folder': 124 | val_joint_transform_list = None 125 | elif 
'mapillary' in args.dataset: 126 | if args.pre_size is None: 127 | eval_size = 2177 128 | else: 129 | eval_size = args.pre_size 130 | if cfg.DATASET.MAPILLARY_CROP_VAL: 131 | val_joint_transform_list = [ 132 | joint_transforms.ResizeHeight(eval_size), 133 | joint_transforms.CenterCropPad(eval_size)] 134 | else: 135 | val_joint_transform_list = [ 136 | joint_transforms.Scale(eval_size)] 137 | else: 138 | val_joint_transform_list = None 139 | 140 | if args.eval is None or args.eval == 'val': 141 | val_name = 'val' 142 | elif args.eval == 'trn': 143 | val_name = 'train' 144 | elif args.eval == 'folder': 145 | val_name = 'folder' 146 | else: 147 | raise 'unknown eval mode {}'.format(args.eval) 148 | 149 | ###################################################################### 150 | # Create loaders 151 | ###################################################################### 152 | val_set = dataset_cls( 153 | mode=val_name, 154 | joint_transform_list=val_joint_transform_list, 155 | img_transform=val_input_transform, 156 | label_transform=target_transform, 157 | eval_folder=args.eval_folder) 158 | 159 | update_dataset_inst(dataset_inst=val_set) 160 | 161 | if args.apex: 162 | from datasets.sampler import DistributedSampler 163 | val_sampler = DistributedSampler(val_set, pad=False, permutation=False, 164 | consecutive_sample=False) 165 | else: 166 | val_sampler = None 167 | 168 | val_loader = DataLoader(val_set, batch_size=args.bs_val, 169 | num_workers=args.num_workers // 2, 170 | shuffle=False, drop_last=False, 171 | sampler=val_sampler) 172 | 173 | if args.eval is not None: 174 | # Don't create train dataloader if eval 175 | train_set = None 176 | train_loader = None 177 | else: 178 | train_set = dataset_cls( 179 | mode='train', 180 | joint_transform_list=train_joint_transform_list, 181 | img_transform=train_input_transform, 182 | label_transform=target_train_transform) 183 | 184 | if args.apex: 185 | from datasets.sampler import DistributedSampler 186 | train_sampler = DistributedSampler(train_set, pad=True, 187 | permutation=True, 188 | consecutive_sample=False) 189 | train_batch_size = args.bs_trn 190 | else: 191 | train_sampler = None 192 | train_batch_size = args.bs_trn * args.ngpu 193 | 194 | train_loader = DataLoader(train_set, batch_size=train_batch_size, 195 | num_workers=args.num_workers, 196 | shuffle=(train_sampler is None), 197 | drop_last=True, sampler=train_sampler) 198 | 199 | return train_loader, val_loader, train_set 200 | -------------------------------------------------------------------------------- /semantic-segmentation/network/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 
19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | import torch.nn as nn 37 | import torch.utils.model_zoo as model_zoo 38 | import network.mynn as mynn 39 | 40 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 41 | 'resnet152'] 42 | 43 | 44 | model_urls = { 45 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 46 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 47 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 48 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 49 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 50 | } 51 | 52 | 53 | def conv3x3(in_planes, out_planes, stride=1): 54 | """3x3 convolution with padding""" 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | """ 61 | Basic Block for Resnet 62 | """ 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(BasicBlock, self).__init__() 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = mynn.Norm2d(planes) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(planes, planes) 71 | self.bn2 = mynn.Norm2d(planes) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Bottleneck(nn.Module): 95 | """ 96 | Bottleneck Layer for Resnet 97 | """ 98 | expansion = 4 99 | 100 | def __init__(self, inplanes, planes, stride=1, downsample=None): 101 | super(Bottleneck, self).__init__() 102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 103 | self.bn1 = mynn.Norm2d(planes) 104 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 105 | padding=1, bias=False) 106 | self.bn2 = mynn.Norm2d(planes) 107 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 108 | self.bn3 = mynn.Norm2d(planes * self.expansion) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out 
= self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out += residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | """ 138 | Resnet Global Module for Initialization 139 | """ 140 | def __init__(self, block, layers, num_classes=1000): 141 | self.inplanes = 64 142 | super(ResNet, self).__init__() 143 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = mynn.Norm2d(64) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 152 | self.avgpool = nn.AvgPool2d(7, stride=1) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, nn.BatchNorm2d): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1): 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | nn.Conv2d(self.inplanes, planes * block.expansion, 167 | kernel_size=1, stride=stride, bias=False), 168 | mynn.Norm2d(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample)) 173 | self.inplanes = planes * block.expansion 174 | for index in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | 190 | x = self.avgpool(x) 191 | x = x.view(x.size(0), -1) 192 | x = self.fc(x) 193 | 194 | return x 195 | 196 | 197 | def resnet18(pretrained=True, **kwargs): 198 | """Constructs a ResNet-18 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 204 | if pretrained: 205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 206 | return model 207 | 208 | 209 | def resnet34(pretrained=True, **kwargs): 210 | """Constructs a ResNet-34 model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 218 | return model 219 | 220 | 221 | def resnet50(pretrained=True, **kwargs): 222 | """Constructs a ResNet-50 model. 
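    The layer spec [3, 4, 6, 3] below is the number of Bottleneck blocks per
    stage; with expansion 4 plus the stem and classifier this is the standard
    50-layer configuration, and pretrained=True pulls the ImageNet weights
    from model_urls via model_zoo.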
223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | """ 227 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 228 | if pretrained: 229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 230 | return model 231 | 232 | 233 | def resnet101(pretrained=True, **kwargs): 234 | """Constructs a ResNet-101 model. 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | """ 239 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 240 | if pretrained: 241 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 242 | return model 243 | 244 | 245 | def resnet152(pretrained=True, **kwargs): 246 | """Constructs a ResNet-152 model. 247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 252 | if pretrained: 253 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 254 | return model 255 | -------------------------------------------------------------------------------- /semantic-segmentation/network/PSA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch._utils 4 | import torch.nn.functional as F 5 | 6 | def constant_init(module, val, bias=0): 7 | if hasattr(module, 'weight') and module.weight is not None: 8 | nn.init.constant_(module.weight, val) 9 | if hasattr(module, 'bias') and module.bias is not None: 10 | nn.init.constant_(module.bias, bias) 11 | 12 | 13 | def kaiming_init(module, 14 | a=0, 15 | mode='fan_out', 16 | nonlinearity='relu', 17 | bias=0, 18 | distribution='normal'): 19 | assert distribution in ['uniform', 'normal'] 20 | if distribution == 'uniform': 21 | nn.init.kaiming_uniform_( 22 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 23 | else: 24 | nn.init.kaiming_normal_( 25 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 26 | if hasattr(module, 'bias') and module.bias is not None: 27 | nn.init.constant_(module.bias, bias) 28 | 29 | class PSA_p(nn.Module): 30 | def __init__(self, inplanes, planes, kernel_size=1, stride=1): 31 | super(PSA_p, self).__init__() 32 | 33 | self.inplanes = inplanes 34 | self.inter_planes = planes // 2 35 | self.planes = planes 36 | self.kernel_size = kernel_size 37 | self.stride = stride 38 | self.padding = (kernel_size-1)//2 39 | 40 | self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False) 41 | self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) 42 | self.conv_up = nn.Conv2d(self.inter_planes, self.planes, kernel_size=1, stride=1, padding=0, bias=False) 43 | self.softmax_right = nn.Softmax(dim=2) 44 | self.sigmoid = nn.Sigmoid() 45 | 46 | self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) #g 47 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 48 | self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, bias=False) #theta 49 | self.softmax_left = nn.Softmax(dim=2) 50 | 51 | self.reset_parameters() 52 | 53 | def reset_parameters(self): 54 | kaiming_init(self.conv_q_right, mode='fan_in') 55 | kaiming_init(self.conv_v_right, mode='fan_in') 56 | kaiming_init(self.conv_q_left, mode='fan_in') 57 | kaiming_init(self.conv_v_left, mode='fan_in') 58 | 59 | self.conv_q_right.inited = True 60 | 
self.conv_v_right.inited = True 61 | self.conv_q_left.inited = True 62 | self.conv_v_left.inited = True 63 | 64 | def spatial_pool(self, x): 65 | input_x = self.conv_v_right(x) 66 | 67 | batch, channel, height, width = input_x.size() 68 | 69 | # [N, IC, H*W] 70 | input_x = input_x.view(batch, channel, height * width) 71 | 72 | # [N, 1, H, W] 73 | context_mask = self.conv_q_right(x) 74 | 75 | # [N, 1, H*W] 76 | context_mask = context_mask.view(batch, 1, height * width) 77 | 78 | # [N, 1, H*W] 79 | context_mask = self.softmax_right(context_mask) 80 | 81 | # [N, IC, 1] 82 | # context = torch.einsum('ndw,new->nde', input_x, context_mask) 83 | context = torch.matmul(input_x, context_mask.transpose(1,2)) 84 | # [N, IC, 1, 1] 85 | context = context.unsqueeze(-1) 86 | 87 | # [N, OC, 1, 1] 88 | context = self.conv_up(context) 89 | 90 | # [N, OC, 1, 1] 91 | mask_ch = self.sigmoid(context) 92 | 93 | out = x * mask_ch 94 | 95 | return out 96 | 97 | def channel_pool(self, x): 98 | # [N, IC, H, W] 99 | g_x = self.conv_q_left(x) 100 | 101 | batch, channel, height, width = g_x.size() 102 | 103 | # [N, IC, 1, 1] 104 | avg_x = self.avg_pool(g_x) 105 | 106 | batch, channel, avg_x_h, avg_x_w = avg_x.size() 107 | 108 | # [N, 1, IC] 109 | avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1) 110 | 111 | # [N, IC, H*W] 112 | theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width) 113 | 114 | # [N, 1, H*W] 115 | # context = torch.einsum('nde,new->ndw', avg_x, theta_x) 116 | context = torch.matmul(avg_x, theta_x) 117 | # [N, 1, H*W] 118 | context = self.softmax_left(context) 119 | 120 | # [N, 1, H, W] 121 | context = context.view(batch, 1, height, width) 122 | 123 | # [N, 1, H, W] 124 | mask_sp = self.sigmoid(context) 125 | 126 | out = x * mask_sp 127 | 128 | return out 129 | 130 | def forward(self, x): 131 | # [N, C, H, W] 132 | context_channel = self.spatial_pool(x) 133 | # [N, C, H, W] 134 | context_spatial = self.channel_pool(x) 135 | # [N, C, H, W] 136 | out = context_spatial + context_channel 137 | return out 138 | 139 | class PSA_s(nn.Module): 140 | def __init__(self, inplanes, planes, kernel_size=1, stride=1): 141 | super(PSA_s, self).__init__() 142 | 143 | self.inplanes = inplanes 144 | self.inter_planes = planes // 2 145 | self.planes = planes 146 | self.kernel_size = kernel_size 147 | self.stride = stride 148 | self.padding = (kernel_size - 1) // 2 149 | ratio = 4 150 | 151 | self.conv_q_right = nn.Conv2d(self.inplanes, 1, kernel_size=1, stride=stride, padding=0, bias=False) 152 | self.conv_v_right = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, 153 | bias=False) 154 | # self.conv_up = nn.Conv2d(self.inter_planes, self.planes, kernel_size=1, stride=1, padding=0, bias=False) 155 | self.conv_up = nn.Sequential( 156 | nn.Conv2d(self.inter_planes, self.inter_planes // ratio, kernel_size=1), 157 | nn.LayerNorm([self.inter_planes // ratio, 1, 1]), 158 | nn.ReLU(inplace=True), 159 | nn.Conv2d(self.inter_planes // ratio, self.planes, kernel_size=1) 160 | ) 161 | self.softmax_right = nn.Softmax(dim=2) 162 | self.sigmoid = nn.Sigmoid() 163 | 164 | self.conv_q_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, 165 | bias=False) # g 166 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 167 | self.conv_v_left = nn.Conv2d(self.inplanes, self.inter_planes, kernel_size=1, stride=stride, padding=0, 168 | bias=False) # theta 169 | self.softmax_left = nn.Softmax(dim=2) 170 | 171 | self.reset_parameters() 172 | 173 
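    # PSA_s chains the two attention branches sequentially: spatial_pool()
    # first gates channels with a globally pooled context vector, then
    # channel_pool() gates spatial positions, whereas PSA_p above runs the
    # same two branches in parallel and sums them. conv_up here is also a
    # bottlenecked (ratio=4) 1x1-conv MLP with LayerNorm rather than the
    # single 1x1 conv used in PSA_p.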
| def reset_parameters(self): 174 | kaiming_init(self.conv_q_right, mode='fan_in') 175 | kaiming_init(self.conv_v_right, mode='fan_in') 176 | kaiming_init(self.conv_q_left, mode='fan_in') 177 | kaiming_init(self.conv_v_left, mode='fan_in') 178 | 179 | self.conv_q_right.inited = True 180 | self.conv_v_right.inited = True 181 | self.conv_q_left.inited = True 182 | self.conv_v_left.inited = True 183 | 184 | def spatial_pool(self, x): 185 | input_x = self.conv_v_right(x) 186 | 187 | batch, channel, height, width = input_x.size() 188 | 189 | # [N, IC, H*W] 190 | input_x = input_x.view(batch, channel, height * width) 191 | 192 | # [N, 1, H, W] 193 | context_mask = self.conv_q_right(x) 194 | 195 | # [N, 1, H*W] 196 | context_mask = context_mask.view(batch, 1, height * width) 197 | 198 | # [N, 1, H*W] 199 | context_mask = self.softmax_right(context_mask) 200 | 201 | # [N, IC, 1] 202 | # context = torch.einsum('ndw,new->nde', input_x, context_mask) 203 | context = torch.matmul(input_x, context_mask.transpose(1, 2)) 204 | 205 | # [N, IC, 1, 1] 206 | context = context.unsqueeze(-1) 207 | 208 | # [N, OC, 1, 1] 209 | context = self.conv_up(context) 210 | 211 | # [N, OC, 1, 1] 212 | mask_ch = self.sigmoid(context) 213 | 214 | out = x * mask_ch 215 | 216 | return out 217 | 218 | def channel_pool(self, x): 219 | # [N, IC, H, W] 220 | g_x = self.conv_q_left(x) 221 | 222 | batch, channel, height, width = g_x.size() 223 | 224 | # [N, IC, 1, 1] 225 | avg_x = self.avg_pool(g_x) 226 | 227 | batch, channel, avg_x_h, avg_x_w = avg_x.size() 228 | 229 | # [N, 1, IC] 230 | avg_x = avg_x.view(batch, channel, avg_x_h * avg_x_w).permute(0, 2, 1) 231 | 232 | # [N, IC, H*W] 233 | theta_x = self.conv_v_left(x).view(batch, self.inter_planes, height * width) 234 | 235 | # [N, IC, H*W] 236 | theta_x = self.softmax_left(theta_x) 237 | 238 | # [N, 1, H*W] 239 | # context = torch.einsum('nde,new->ndw', avg_x, theta_x) 240 | context = torch.matmul(avg_x, theta_x) 241 | 242 | # [N, 1, H, W] 243 | context = context.view(batch, 1, height, width) 244 | 245 | # [N, 1, H, W] 246 | mask_sp = self.sigmoid(context) 247 | 248 | out = x * mask_sp 249 | 250 | return out 251 | 252 | def forward(self, x): 253 | # [N, C, H, W] 254 | out = self.spatial_pool(x) 255 | 256 | # [N, C, H, W] 257 | out = self.channel_pool(out) 258 | 259 | # [N, C, H, W] 260 | # out = context_spatial + context_channel 261 | 262 | return out 263 | 264 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/base_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 
17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | Generic dataloader base class 31 | """ 32 | import os 33 | import glob 34 | import numpy as np 35 | import torch 36 | 37 | from PIL import Image 38 | from torch.utils import data 39 | from config import cfg 40 | from datasets import uniform 41 | from runx.logx import logx 42 | from utils.misc import tensor_to_pil 43 | 44 | 45 | class BaseLoader(data.Dataset): 46 | def __init__(self, quality, mode, joint_transform_list, img_transform, 47 | label_transform): 48 | 49 | super(BaseLoader, self).__init__() 50 | self.quality = quality 51 | self.mode = mode 52 | self.joint_transform_list = joint_transform_list 53 | self.img_transform = img_transform 54 | self.label_transform = label_transform 55 | self.train = mode == 'train' 56 | self.id_to_trainid = {} 57 | self.centroids = None 58 | self.all_imgs = None 59 | self.drop_mask = np.zeros((1024, 2048)) 60 | self.drop_mask[15:840, 14:2030] = 1.0 61 | 62 | def build_epoch(self): 63 | """ 64 | For class uniform sampling ... every epoch, we want to recompute 65 | which tiles from which images we want to sample from, so that the 66 | sampling is uniformly random. 67 | """ 68 | self.imgs = uniform.build_epoch(self.all_imgs, 69 | self.centroids, 70 | self.num_classes, 71 | self.train) 72 | 73 | @staticmethod 74 | def find_images(img_root, mask_root, img_ext, mask_ext): 75 | """ 76 | Find image and segmentation mask files and return a list of 77 | tuples of them. 
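        Illustrative layout (file names are examples only): an image
        img_root/aachen_000001.png is paired with
        mask_root/aachen_000001.png when img_ext and mask_ext are both
        'png'; the assertion below requires every mask file to exist.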
78 | """ 79 | img_path = '{}/*.{}'.format(img_root, img_ext) 80 | imgs = glob.glob(img_path) 81 | items = [] 82 | for full_img_fn in imgs: 83 | img_dir, img_fn = os.path.split(full_img_fn) 84 | img_name, _ = os.path.splitext(img_fn) 85 | full_mask_fn = '{}.{}'.format(img_name, mask_ext) 86 | full_mask_fn = os.path.join(mask_root, full_mask_fn) 87 | assert os.path.exists(full_mask_fn) 88 | items.append((full_img_fn, full_mask_fn)) 89 | return items 90 | 91 | def disable_coarse(self): 92 | pass 93 | 94 | def colorize_mask(self, image_array): 95 | """ 96 | Colorize the segmentation mask 97 | """ 98 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 99 | new_mask.putpalette(self.color_mapping) 100 | return new_mask 101 | 102 | def dump_images(self, img_name, mask, centroid, class_id, img): 103 | img = tensor_to_pil(img) 104 | outdir = 'new_dump_imgs_{}'.format(self.mode) 105 | os.makedirs(outdir, exist_ok=True) 106 | if centroid is not None: 107 | dump_img_name = '{}_{}'.format(self.trainid_to_name[class_id], 108 | img_name) 109 | else: 110 | dump_img_name = img_name 111 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 112 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 113 | out_raw_fn = os.path.join(outdir, dump_img_name + '_mask_raw.png') 114 | mask_img = self.colorize_mask(np.array(mask)) 115 | raw_img = Image.fromarray(np.array(mask)) 116 | img.save(out_img_fn) 117 | mask_img.save(out_msk_fn) 118 | raw_img.save(out_raw_fn) 119 | 120 | def do_transforms(self, img, mask, centroid, img_name, class_id): 121 | """ 122 | Do transformations to image and mask 123 | 124 | :returns: image, mask 125 | """ 126 | scale_float = 1.0 127 | 128 | if self.joint_transform_list is not None: 129 | for idx, xform in enumerate(self.joint_transform_list): 130 | if idx == 0 and centroid is not None: 131 | # HACK! 
Assume the first transform accepts a centroid 132 | outputs = xform(img, mask, centroid) 133 | else: 134 | outputs = xform(img, mask) 135 | 136 | if len(outputs) == 3: 137 | img, mask, scale_float = outputs 138 | else: 139 | img, mask = outputs 140 | 141 | if self.img_transform is not None: 142 | img = self.img_transform(img) 143 | 144 | if cfg.DATASET.DUMP_IMAGES: 145 | self.dump_images(img_name, mask, centroid, class_id, img) 146 | 147 | if self.label_transform is not None: 148 | mask = self.label_transform(mask) 149 | 150 | return img, mask, scale_float 151 | 152 | def read_images(self, img_path, mask_path, mask_out=False): 153 | img = Image.open(img_path).convert('RGB') 154 | if mask_path is None or mask_path == '': 155 | w, h = img.size 156 | mask = np.zeros((h, w)) 157 | else: 158 | mask = Image.open(mask_path) 159 | 160 | drop_out_mask = None 161 | # This code is specific to cityscapes 162 | if(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE in mask_path): 163 | 164 | gtCoarse_mask_path = mask_path.replace(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE, os.path.join(cfg.DATASET.CITYSCAPES_DIR, 'gtCoarse/gtCoarse') ) 165 | gtCoarse_mask_path = gtCoarse_mask_path.replace('leftImg8bit','gtCoarse_labelIds') 166 | gtCoarse=np.array(Image.open(gtCoarse_mask_path)) 167 | 168 | 169 | 170 | img_name = os.path.splitext(os.path.basename(img_path))[0] 171 | 172 | mask = np.array(mask) 173 | if (mask_out): 174 | mask = self.drop_mask * mask 175 | 176 | mask = mask.copy() 177 | for k, v in self.id_to_trainid.items(): 178 | binary_mask = (mask == k) #+ (gtCoarse == k) 179 | if ('refinement' in mask_path) and cfg.DROPOUT_COARSE_BOOST_CLASSES != None and v in cfg.DROPOUT_COARSE_BOOST_CLASSES and binary_mask.sum() > 0 and 'vidseq' not in mask_path: 180 | binary_mask += (gtCoarse == k) 181 | binary_mask[binary_mask >= 1] = 1 182 | mask[binary_mask] = gtCoarse[binary_mask] 183 | mask[binary_mask] = v 184 | 185 | 186 | mask = Image.fromarray(mask.astype(np.uint8)) 187 | return img, mask, img_name 188 | 189 | def __getitem__(self, index): 190 | """ 191 | Generate data: 192 | 193 | :return: 194 | - image: image, tensor 195 | - mask: mask, tensor 196 | - image_name: basename of file, string 197 | """ 198 | # Pick an image, fill in defaults if not using class uniform 199 | if len(self.imgs[index]) == 2: 200 | img_path, mask_path = self.imgs[index] 201 | centroid = None 202 | class_id = None 203 | else: 204 | img_path, mask_path, centroid, class_id = self.imgs[index] 205 | 206 | mask_out = cfg.DATASET.MASK_OUT_CITYSCAPES and \ 207 | cfg.DATASET.CUSTOM_COARSE_PROB is not None and \ 208 | 'refinement' in mask_path 209 | 210 | img, mask, img_name = self.read_images(img_path, mask_path, 211 | mask_out=mask_out) 212 | 213 | ###################################################################### 214 | # Thresholding is done when using coarse-labelled Cityscapes images 215 | ###################################################################### 216 | if 'refinement' in mask_path: 217 | 218 | mask = np.array(mask) 219 | prob_mask_path = mask_path.replace('.png', '_prob.png') 220 | # put it in 0 to 1 221 | prob_map = np.array(Image.open(prob_mask_path)) / 255.0 222 | prob_map_threshold = (prob_map < cfg.DATASET.CUSTOM_COARSE_PROB) 223 | mask[prob_map_threshold] = cfg.DATASET.IGNORE_LABEL 224 | mask = Image.fromarray(mask.astype(np.uint8)) 225 | 226 | img, mask, scale_float = self.do_transforms(img, mask, centroid, 227 | img_name, class_id) 228 | 229 | return img, mask, img_name, scale_float 230 | 231 | def __len__(self): 232 | return 
len(self.imgs) 233 | 234 | def calculate_weights(self): 235 | raise BaseException("not supported yet") 236 | -------------------------------------------------------------------------------- /semantic-segmentation/loss/rmi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is adapted from: https://github.com/ZJULearning/RMI 3 | 4 | The implementation of the paper: 5 | Region Mutual Information Loss for Semantic Segmentation. 6 | """ 7 | 8 | # python 2.X, 3.X compatibility 9 | from __future__ import print_function 10 | from __future__ import division 11 | from __future__ import absolute_import 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | 17 | from loss import rmi_utils 18 | from config import cfg 19 | from apex import amp 20 | 21 | _euler_num = 2.718281828 # euler number 22 | _pi = 3.14159265 # pi 23 | _ln_2_pi = 1.837877 # ln(2 * pi) 24 | _CLIP_MIN = 1e-6 # min clip value after softmax or sigmoid operations 25 | _CLIP_MAX = 1.0 # max clip value after softmax or sigmoid operations 26 | _POS_ALPHA = 5e-4 # add this factor to ensure the AA^T is positive definite 27 | _IS_SUM = 1 # sum the loss per channel 28 | 29 | 30 | __all__ = ['RMILoss'] 31 | 32 | 33 | class RMILoss(nn.Module): 34 | """ 35 | region mutual information 36 | I(A, B) = H(A) + H(B) - H(A, B) 37 | This version need a lot of memory if do not dwonsample. 38 | """ 39 | def __init__(self, 40 | num_classes=21, 41 | rmi_radius=3, 42 | rmi_pool_way=1, 43 | rmi_pool_size=4, 44 | rmi_pool_stride=4, 45 | loss_weight_lambda=0.5, 46 | lambda_way=1, 47 | ignore_index=255): 48 | super(RMILoss, self).__init__() 49 | self.num_classes = num_classes 50 | # radius choices 51 | assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 52 | self.rmi_radius = rmi_radius 53 | assert rmi_pool_way in [0, 1, 2, 3] 54 | self.rmi_pool_way = rmi_pool_way 55 | 56 | # set the pool_size = rmi_pool_stride 57 | assert rmi_pool_size == rmi_pool_stride 58 | self.rmi_pool_size = rmi_pool_size 59 | self.rmi_pool_stride = rmi_pool_stride 60 | self.weight_lambda = loss_weight_lambda 61 | self.lambda_way = lambda_way 62 | 63 | # dimension of the distribution 64 | self.half_d = self.rmi_radius * self.rmi_radius 65 | self.d = 2 * self.half_d 66 | self.kernel_padding = self.rmi_pool_size // 2 67 | # ignore class 68 | self.ignore_index = ignore_index 69 | 70 | def forward(self, logits_4D, labels_4D, do_rmi=True): 71 | # explicitly disable fp16 mode because torch.cholesky and 72 | # torch.inverse aren't supported by half 73 | logits_4D.float() 74 | labels_4D.float() 75 | if cfg.TRAIN.FP16: 76 | with amp.disable_casts(): 77 | loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi) 78 | else: 79 | loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi) 80 | return loss 81 | 82 | def forward_sigmoid(self, logits_4D, labels_4D, do_rmi=False): 83 | """ 84 | Using the sigmiod operation both. 
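        In outline: a sigmoid binary cross-entropy term is computed over the
        valid one-hot labels (pixels at or above num_classes are masked out),
        and when do_rmi is set the RMI lower bound from rmi_lower_bound() is
        added, combined with the BCE term through loss_weight_lambda (the
        exact weighting depends on lambda_way).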
85 | Args: 86 | logits_4D : [N, C, H, W], dtype=float32 87 | labels_4D : [N, H, W], dtype=long 88 | do_rmi : bool 89 | """ 90 | # label mask -- [N, H, W, 1] 91 | label_mask_3D = labels_4D < self.num_classes 92 | 93 | # valid label 94 | valid_onehot_labels_4D = \ 95 | F.one_hot(labels_4D.long() * label_mask_3D.long(), 96 | num_classes=self.num_classes).float() 97 | label_mask_3D = label_mask_3D.float() 98 | label_mask_flat = label_mask_3D.view([-1, ]) 99 | valid_onehot_labels_4D = valid_onehot_labels_4D * \ 100 | label_mask_3D.unsqueeze(dim=3) 101 | valid_onehot_labels_4D.requires_grad_(False) 102 | 103 | # PART I -- calculate the sigmoid binary cross entropy loss 104 | valid_onehot_label_flat = \ 105 | valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False) 106 | logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes]) 107 | 108 | # binary loss, multiplied by the not_ignore_mask 109 | valid_pixels = torch.sum(label_mask_flat) 110 | binary_loss = F.binary_cross_entropy_with_logits(logits_flat, 111 | target=valid_onehot_label_flat, 112 | weight=label_mask_flat.unsqueeze(dim=1), 113 | reduction='sum') 114 | bce_loss = torch.div(binary_loss, valid_pixels + 1.0) 115 | if not do_rmi: 116 | return bce_loss 117 | 118 | # PART II -- get rmi loss 119 | # onehot_labels_4D -- [N, C, H, W] 120 | probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + _CLIP_MIN 121 | valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) 122 | 123 | # get region mutual information 124 | rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) 125 | 126 | # add together 127 | #logx.msg(f'lambda_way {self.lambda_way}') 128 | #logx.msg(f'bce_loss {bce_loss} weight_lambda {self.weight_lambda} rmi_loss {rmi_loss}') 129 | if self.lambda_way: 130 | final_loss = self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) 131 | else: 132 | final_loss = bce_loss + rmi_loss * self.weight_lambda 133 | 134 | return final_loss 135 | 136 | def inverse(self, x): 137 | return torch.inverse(x) 138 | 139 | def rmi_lower_bound(self, labels_4D, probs_4D): 140 | """ 141 | calculate the lower bound of the region mutual information. 142 | Args: 143 | labels_4D : [N, C, H, W], dtype=float32 144 | probs_4D : [N, C, H, W], dtype=float32 145 | """ 146 | assert labels_4D.size() == probs_4D.size() 147 | 148 | p, s = self.rmi_pool_size, self.rmi_pool_stride 149 | if self.rmi_pool_stride > 1: 150 | if self.rmi_pool_way == 0: 151 | labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 152 | probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 153 | elif self.rmi_pool_way == 1: 154 | labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 155 | probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) 156 | elif self.rmi_pool_way == 2: 157 | # interpolation 158 | shape = labels_4D.size() 159 | new_h, new_w = shape[2] // s, shape[3] // s 160 | labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest') 161 | probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True) 162 | else: 163 | raise NotImplementedError("Pool way of RMI is not defined!") 164 | # we do not need the gradient of label. 165 | label_shape = labels_4D.size() 166 | n, c = label_shape[0], label_shape[1] 167 | 168 | # combine the high dimension points from label and probability map. 
new shape [N, C, radius * radius, H, W] 169 | la_vectors, pr_vectors = rmi_utils.map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0) 170 | 171 | la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor).requires_grad_(False) 172 | pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor) 173 | 174 | # small diagonal matrix, shape = [1, 1, radius * radius, radius * radius] 175 | diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0) 176 | 177 | # the mean and covariance of these high dimension points 178 | # Var(X) = E(X^2) - E(X) E(X), N * Var(X) = X^2 - X E(X) 179 | la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True) 180 | la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3)) 181 | 182 | pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True) 183 | pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3)) 184 | # https://github.com/pytorch/pytorch/issues/7500 185 | # waiting for batched torch.cholesky_inverse() 186 | # pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) 187 | pr_cov_inv = self.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) 188 | # if the dimension of the point is less than 9, you can use the below function 189 | # to acceleration computational speed. 190 | #pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) 191 | 192 | la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3)) 193 | # the approxiamation of the variance, det(c A) = c^n det(A), A is in n x n shape; 194 | # then log det(c A) = n log(c) + log det(A). 195 | # appro_var = appro_var / n_points, we do not divide the appro_var by number of points here, 196 | # and the purpose is to avoid underflow issue. 197 | # If A = A^T, A^-1 = (A^-1)^T. 198 | appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1)) 199 | #appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1)) 200 | #appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6 201 | 202 | # The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ). 203 | rmi_now = 0.5 * rmi_utils.log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) 204 | #rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) 205 | 206 | # mean over N samples. sum over classes. 207 | rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float() 208 | #is_half = False 209 | #if is_half: 210 | # rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0)) 211 | #else: 212 | rmi_per_class = torch.div(rmi_per_class, float(self.half_d)) 213 | 214 | rmi_loss = torch.sum(rmi_per_class) if _IS_SUM else torch.mean(rmi_per_class) 215 | return rmi_loss 216 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | import os 31 | import os.path as path 32 | 33 | from config import cfg 34 | from runx.logx import logx 35 | from datasets.base_loader import BaseLoader 36 | import datasets.cityscapes_labels as cityscapes_labels 37 | import datasets.uniform as uniform 38 | from datasets.utils import make_dataset_folder 39 | 40 | 41 | def cities_cv_split(root, split, cv_split): 42 | """ 43 | Find cities that correspond to a given split of the data. We split the data 44 | such that a given city belongs to either train or val, but never both. cv0 45 | is defined to be the default split. 
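For illustration only (the exact counts depend on the dataset on disk): with 21 cities in total, 3 of them val cities, and cfg.DATASET.CV_SPLITS == 3, cv_split 1 takes the cities at indices 7..9 of all_cities as val and the remaining 18 as train, as sketched below.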
46 | 47 | all_cities = [x x x x x x x x x x x x] 48 | val: 49 | split0 [x x x ] 50 | split1 [ x x x ] 51 | split2 [ x x x ] 52 | trn: 53 | split0 [ x x x x x x x x x] 54 | split1 [x x x x x x x x x] 55 | split2 [x x x x x x x x ] 56 | 57 | split - train/val/test 58 | cv_split - 0,1,2,3 59 | 60 | cv_split == 3 means use train + val 61 | """ 62 | trn_path = path.join(root, 'leftImg8bit_trainvaltest/leftImg8bit', 'train') 63 | val_path = path.join(root, 'leftImg8bit_trainvaltest/leftImg8bit', 'val') 64 | 65 | trn_cities = ['train/' + c for c in os.listdir(trn_path)] 66 | trn_cities = sorted(trn_cities) # sort to insure reproducibility 67 | val_cities = ['val/' + c for c in os.listdir(val_path)] 68 | 69 | all_cities = val_cities + trn_cities 70 | 71 | if cv_split == 3: 72 | logx.msg('cv split {} {} {}'.format(split, cv_split, all_cities)) 73 | return all_cities 74 | 75 | num_val_cities = len(val_cities) 76 | num_cities = len(all_cities) 77 | 78 | offset = cv_split * num_cities // cfg.DATASET.CV_SPLITS 79 | cities = [] 80 | for j in range(num_cities): 81 | if j >= offset and j < (offset + num_val_cities): 82 | if split == 'val': 83 | cities.append(all_cities[j]) 84 | else: 85 | if split == 'train': 86 | cities.append(all_cities[j]) 87 | 88 | logx.msg('cv split {} {} {}'.format(split, cv_split, cities)) 89 | return cities 90 | 91 | 92 | def coarse_cities(root): 93 | """ 94 | Find coarse cities 95 | """ 96 | split = 'train_extra' 97 | coarse_path = path.join(root, 'leftImg8bit_trainextra/leftImg8bit', 98 | split) 99 | coarse_cities = [f'{split}/' + c for c in os.listdir(coarse_path)] 100 | 101 | logx.msg(f'found {len(coarse_cities)} coarse cities') 102 | return coarse_cities 103 | 104 | 105 | class Loader(BaseLoader): 106 | num_classes = 19 107 | ignore_label = 255 108 | trainid_to_name = {} 109 | color_mapping = [] 110 | 111 | def __init__(self, mode, quality='fine', joint_transform_list=None, 112 | img_transform=None, label_transform=None, eval_folder=None): 113 | 114 | super(Loader, self).__init__(quality=quality, mode=mode, 115 | joint_transform_list=joint_transform_list, 116 | img_transform=img_transform, 117 | label_transform=label_transform) 118 | 119 | ###################################################################### 120 | # Cityscapes-specific stuff: 121 | ###################################################################### 122 | self.root = cfg.DATASET.CITYSCAPES_DIR 123 | self.id_to_trainid = cityscapes_labels.label2trainid 124 | self.trainid_to_name = cityscapes_labels.trainId2name 125 | self.fill_colormap() 126 | img_ext = 'png' 127 | mask_ext = 'png' 128 | img_root = path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit') 129 | mask_root = path.join(self.root, 'gtFine_trainvaltest/gtFine') 130 | if mode == 'folder': 131 | self.all_imgs = make_dataset_folder(eval_folder) 132 | else: 133 | self.fine_cities = cities_cv_split(self.root, mode, cfg.DATASET.CV) 134 | self.all_imgs = self.find_cityscapes_images( 135 | self.fine_cities, img_root, mask_root, img_ext, mask_ext) 136 | 137 | logx.msg(f'cn num_classes {self.num_classes}') 138 | self.fine_centroids = uniform.build_centroids(self.all_imgs, 139 | self.num_classes, 140 | self.train, 141 | cv=cfg.DATASET.CV, 142 | id2trainid=self.id_to_trainid) 143 | self.centroids = self.fine_centroids 144 | 145 | if cfg.DATASET.COARSE_BOOST_CLASSES and mode == 'train': 146 | self.coarse_cities = coarse_cities(self.root) 147 | img_root = path.join(self.root, 148 | 'leftImg8bit_trainextra/leftImg8bit') 149 | mask_root = 
path.join(self.root, 'gtCoarse', 'gtCoarse') 150 | self.coarse_imgs = self.find_cityscapes_images( 151 | self.coarse_cities, img_root, mask_root, img_ext, mask_ext, 152 | fine_coarse='gtCoarse') 153 | 154 | if cfg.DATASET.CLASS_UNIFORM_PCT: 155 | 156 | custom_coarse = (cfg.DATASET.CUSTOM_COARSE_PROB is not None) 157 | self.coarse_centroids = uniform.build_centroids( 158 | self.coarse_imgs, self.num_classes, self.train, 159 | coarse=(not custom_coarse), custom_coarse=custom_coarse, 160 | id2trainid=self.id_to_trainid) 161 | 162 | for cid in cfg.DATASET.COARSE_BOOST_CLASSES: 163 | self.centroids[cid].extend(self.coarse_centroids[cid]) 164 | else: 165 | self.all_imgs.extend(self.coarse_imgs) 166 | 167 | self.build_epoch() 168 | 169 | def disable_coarse(self): 170 | """ 171 | Turn off using coarse images in training 172 | """ 173 | self.centroids = self.fine_centroids 174 | 175 | def only_coarse(self): 176 | """ 177 | Turn on using coarse images in training 178 | """ 179 | print('==============+Running Only Coarse+===============') 180 | self.centroids = self.coarse_centroids 181 | 182 | def find_cityscapes_images(self, cities, img_root, mask_root, img_ext, 183 | mask_ext, fine_coarse='gtFine'): 184 | """ 185 | Find image and segmentation mask files and return a list of 186 | tuples of them. 187 | 188 | Inputs: 189 | img_root: path to parent directory of train/val/test dirs 190 | mask_root: path to parent directory of train/val/test dirs 191 | img_ext: image file extension 192 | mask_ext: mask file extension 193 | cities: a list of cities, each element in the form of 'train/a_city' 194 | or 'val/a_city', for example. 195 | """ 196 | items = [] 197 | for city in cities: 198 | img_dir = '{root}/{city}'.format(root=img_root, city=city) 199 | for file_name in os.listdir(img_dir): 200 | basename, ext = os.path.splitext(file_name) 201 | assert ext == '.' + img_ext, '{} {}'.format(ext, img_ext) 202 | full_img_fn = os.path.join(img_dir, file_name) 203 | basename, ext = file_name.split('_leftImg8bit') 204 | if cfg.DATASET.CUSTOM_COARSE_PROB and fine_coarse != 'gtFine': 205 | mask_fn = f'{basename}_leftImg8bit.png' 206 | cc_path = cfg.DATASET.CITYSCAPES_CUSTOMCOARSE 207 | full_mask_fn = os.path.join(cc_path, city, mask_fn) 208 | os.path.isfile(full_mask_fn) 209 | else: 210 | mask_fn = f'{basename}_{fine_coarse}_labelIds{ext}' 211 | full_mask_fn = os.path.join(mask_root, city, mask_fn) 212 | items.append((full_img_fn, full_mask_fn)) 213 | 214 | logx.msg('mode {} found {} images'.format(self.mode, len(items))) 215 | 216 | return items 217 | 218 | def fill_colormap(self): 219 | palette = [128, 64, 128, 220 | 244, 35, 232, 221 | 70, 70, 70, 222 | 102, 102, 156, 223 | 190, 153, 153, 224 | 153, 153, 153, 225 | 250, 170, 30, 226 | 220, 220, 0, 227 | 107, 142, 35, 228 | 152, 251, 152, 229 | 70, 130, 180, 230 | 220, 20, 60, 231 | 255, 0, 0, 232 | 0, 0, 142, 233 | 0, 0, 70, 234 | 0, 60, 100, 235 | 0, 80, 100, 236 | 0, 0, 230, 237 | 119, 11, 32] 238 | zero_pad = 256 * 3 - len(palette) 239 | for i in range(zero_pad): 240 | palette.append(0) 241 | self.color_mapping = palette 242 | -------------------------------------------------------------------------------- /semantic-segmentation/network/mscale2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | 1. 
Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | This is an alternative implementation of mscale, where we feed pairs of 32 | features from both lower and higher resolution images into the attention head. 33 | """ 34 | import torch 35 | from torch import nn 36 | 37 | from network.mynn import initialize_weights, Norm2d, Upsample 38 | from network.mynn import ResizeX, scale_as 39 | from network.utils import get_aspp, get_trunk 40 | from network.utils import make_seg_head, make_attn_head 41 | from config import cfg 42 | 43 | 44 | class MscaleBase(nn.Module): 45 | """ 46 | Multi-scale attention segmentation model base class 47 | """ 48 | def __init__(self): 49 | super(MscaleBase, self).__init__() 50 | self.criterion = None 51 | 52 | def _fwd(self, x): 53 | pass 54 | 55 | def nscale_forward(self, inputs, scales): 56 | """ 57 | Hierarchical attention, primarily used for getting best inference 58 | results. 59 | 60 | We use attention at multiple scales, giving priority to the lower 61 | resolutions. For example, if we have 4 scales {0.5, 1.0, 1.5, 2.0}, 62 | then evaluation is done as follows: 63 | 64 | p_joint = attn_1.5 * p_1.5 + (1 - attn_1.5) * down(p_2.0) 65 | p_joint = attn_1.0 * p_1.0 + (1 - attn_1.0) * down(p_joint) 66 | p_joint = up(attn_0.5 * p_0.5) + (1 - up(attn_0.5)) * p_joint 67 | 68 | The target scale is always 1.0, and 1.0 is expected to be part of the 69 | list of scales. When predictions are done at greater than 1.0 scale, 70 | the predictions are downsampled before combining with the next lower 71 | scale.
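In the loop below this is computed incrementally: scales are visited from highest to lowest, the attention for the current scale is predicted from the concatenation of the current and previous scales' features, and the running prediction is blended as pred = attn * p + (1 - attn) * pred after resizing everything to a common resolution.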
72 | 73 | Inputs: 74 | scales - a list of scales to evaluate 75 | inputs - dict containing 'images', the input, and 'gts', the ground 76 | truth mask 77 | 78 | Output: 79 | If training, return loss, else return prediction + attention 80 | """ 81 | x_1x = inputs['images'] 82 | 83 | assert 1.0 in scales, 'expected 1.0 to be the target scale' 84 | # Lower resolution provides attention for higher rez predictions, 85 | # so we evaluate in order: high to low 86 | scales = sorted(scales, reverse=True) 87 | pred = None 88 | last_feats = None 89 | 90 | for idx, s in enumerate(scales): 91 | x = ResizeX(x_1x, s) 92 | p, feats = self._fwd(x) 93 | 94 | # Generate attention prediction 95 | if idx > 0: 96 | assert last_feats is not None 97 | # downscale feats 98 | last_feats = scale_as(last_feats, feats) 99 | cat_feats = torch.cat([feats, last_feats], 1) 100 | attn = self.scale_attn(cat_feats) 101 | attn = scale_as(attn, p) 102 | 103 | if pred is None: 104 | # This is the top scale prediction 105 | pred = p 106 | elif s >= 1.0: 107 | # downscale previous 108 | pred = scale_as(pred, p) 109 | pred = attn * p + (1 - attn) * pred 110 | else: 111 | # upscale current 112 | p = attn * p 113 | p = scale_as(p, pred) 114 | attn = scale_as(attn, pred) 115 | pred = p + (1 - attn) * pred 116 | 117 | last_feats = feats 118 | 119 | if self.training: 120 | assert 'gts' in inputs 121 | gts = inputs['gts'] 122 | loss = self.criterion(pred, gts) 123 | return loss 124 | else: 125 | # FIXME: should add multi-scale values for pred and attn 126 | return {'pred': pred, 127 | 'attn_10x': attn} 128 | 129 | def two_scale_forward(self, inputs): 130 | assert 'images' in inputs 131 | 132 | x_1x = inputs['images'] 133 | x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE) 134 | 135 | p_lo, feats_lo = self._fwd(x_lo) 136 | p_1x, feats_hi = self._fwd(x_1x) 137 | 138 | feats_hi = scale_as(feats_hi, feats_lo) 139 | cat_feats = torch.cat([feats_lo, feats_hi], 1) 140 | logit_attn = self.scale_attn(cat_feats) 141 | logit_attn = scale_as(logit_attn, p_lo) 142 | 143 | p_lo = logit_attn * p_lo 144 | p_lo = scale_as(p_lo, p_1x) 145 | logit_attn = scale_as(logit_attn, p_1x) 146 | joint_pred = p_lo + (1 - logit_attn) * p_1x 147 | 148 | if self.training: 149 | assert 'gts' in inputs 150 | gts = inputs['gts'] 151 | loss = self.criterion(joint_pred, gts) 152 | return loss 153 | else: 154 | # FIXME: should add multi-scale values for pred and attn 155 | return {'pred': joint_pred, 156 | 'attn_10x': logit_attn} 157 | 158 | def forward(self, inputs): 159 | if cfg.MODEL.N_SCALES and not self.training: 160 | return self.nscale_forward(inputs, cfg.MODEL.N_SCALES) 161 | 162 | return self.two_scale_forward(inputs) 163 | 164 | 165 | class MscaleV3Plus(MscaleBase): 166 | """ 167 | DeepLabV3Plus-based mscale segmentation model 168 | """ 169 | def __init__(self, num_classes, trunk='wrn38', criterion=None): 170 | super(MscaleV3Plus, self).__init__() 171 | self.criterion = criterion 172 | self.backbone, s2_ch, _s4_ch, high_level_ch = get_trunk(trunk) 173 | self.aspp, aspp_out_ch = get_aspp(high_level_ch, 174 | bottleneck_ch=256, 175 | output_stride=8) 176 | self.bot_fine = nn.Conv2d(s2_ch, 48, kernel_size=1, bias=False) 177 | self.bot_aspp = nn.Conv2d(aspp_out_ch, 256, kernel_size=1, bias=False) 178 | 179 | # Semantic segmentation prediction head 180 | self.final = nn.Sequential( 181 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 182 | Norm2d(256), 183 | nn.ReLU(inplace=True), 184 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 185 | 
Norm2d(256), 186 | nn.ReLU(inplace=True), 187 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 188 | 189 | # Scale-attention prediction head 190 | scale_in_ch = 2 * (256 + 48) 191 | 192 | self.scale_attn = nn.Sequential( 193 | nn.Conv2d(scale_in_ch, 256, kernel_size=3, padding=1, bias=False), 194 | Norm2d(256), 195 | nn.ReLU(inplace=True), 196 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 197 | Norm2d(256), 198 | nn.ReLU(inplace=True), 199 | nn.Conv2d(256, 1, kernel_size=1, bias=False), 200 | nn.Sigmoid()) 201 | 202 | if cfg.OPTIONS.INIT_DECODER: 203 | initialize_weights(self.bot_fine) 204 | initialize_weights(self.bot_aspp) 205 | initialize_weights(self.scale_attn) 206 | initialize_weights(self.final) 207 | else: 208 | initialize_weights(self.final) 209 | 210 | def _fwd(self, x): 211 | x_size = x.size() 212 | s2_features, _, final_features = self.backbone(x) 213 | aspp = self.aspp(final_features) 214 | 215 | conv_aspp = self.bot_aspp(aspp) 216 | conv_s2 = self.bot_fine(s2_features) 217 | conv_aspp = Upsample(conv_aspp, s2_features.size()[2:]) 218 | cat_s4 = [conv_s2, conv_aspp] 219 | cat_s4 = torch.cat(cat_s4, 1) 220 | 221 | final = self.final(cat_s4) 222 | out = Upsample(final, x_size[2:]) 223 | 224 | return out, cat_s4 225 | 226 | 227 | def DeepV3R50(num_classes, criterion): 228 | return MscaleV3Plus(num_classes, trunk='resnet-50', criterion=criterion) 229 | 230 | 231 | class Basic(MscaleBase): 232 | """ 233 | """ 234 | def __init__(self, num_classes, trunk='hrnetv2', criterion=None): 235 | super(Basic, self).__init__() 236 | self.criterion = criterion 237 | self.backbone, _, _, high_level_ch = get_trunk( 238 | trunk_name=trunk, output_stride=8) 239 | 240 | self.cls_head = make_seg_head(in_ch=high_level_ch, bot_ch=256, 241 | out_ch=num_classes) 242 | self.scale_attn = make_attn_head(in_ch=high_level_ch * 2, bot_ch=256, 243 | out_ch=1) 244 | 245 | def two_scale_forward(self, inputs): 246 | assert 'images' in inputs 247 | 248 | x_1x = inputs['images'] 249 | x_lo = ResizeX(x_1x, cfg.MODEL.MSCALE_LO_SCALE) 250 | 251 | p_lo, feats_lo = self._fwd(x_lo) 252 | p_1x, feats_hi = self._fwd(x_1x) 253 | 254 | feats_lo = scale_as(feats_lo, feats_hi) 255 | cat_feats = torch.cat([feats_lo, feats_hi], 1) 256 | logit_attn = self.scale_attn(cat_feats) 257 | logit_attn_lo = scale_as(logit_attn, p_lo) 258 | logit_attn_1x = scale_as(logit_attn, p_1x) 259 | 260 | p_lo = logit_attn_lo * p_lo 261 | p_lo = scale_as(p_lo, p_1x) 262 | joint_pred = p_lo + (1 - logit_attn_1x) * p_1x 263 | 264 | if self.training: 265 | assert 'gts' in inputs 266 | gts = inputs['gts'] 267 | loss = self.criterion(joint_pred, gts) 268 | return loss 269 | else: 270 | return joint_pred, logit_attn_1x 271 | 272 | def _fwd(self, x, aspp_lo=None, aspp_attn=None, scale_float=None): 273 | _, _, final_features = self.backbone(x) 274 | pred = self.cls_head(final_features) 275 | pred = scale_as(pred, x) 276 | 277 | return pred, final_features 278 | 279 | 280 | def HRNet(num_classes, criterion, s2s4=None): 281 | return Basic(num_classes=num_classes, criterion=criterion, 282 | trunk='hrnetv2') 283 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /semantic-segmentation/network/xception.py: -------------------------------------------------------------------------------- 1 | # Xception71 2 | # Code Adapted from: 3 | # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/xception.py 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from config import cfg 10 | from network.mynn import Norm2d 11 | from apex.parallel import SyncBatchNorm 12 | from runx.logx import logx 13 | 14 | 15 | def fixed_padding(inputs, kernel_size, dilation): 16 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 17 | pad_total = kernel_size_effective - 1 18 | pad_beg = pad_total // 2 19 | pad_end = pad_total - pad_beg 20 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 21 | return padded_inputs 22 | 23 | 24 | class SeparableConv2d(nn.Module): 25 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, 26 | bias=False, BatchNorm=None): 27 | super(SeparableConv2d, self).__init__() 28 | 29 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, 30 | dilation, groups=inplanes, bias=bias) 31 | self.bn = BatchNorm(inplanes) 32 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 33 | 34 | def forward(self, x): 35 | x = fixed_padding(x, self.conv1.kernel_size[0], 36 | dilation=self.conv1.dilation[0]) 37 | x = self.conv1(x) 38 | x = self.bn(x) 39 | x = self.pointwise(x) 40 | return x 41 | 42 | 43 | class Block(nn.Module): 44 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, 45 | BatchNorm=None, start_with_relu=True, grow_first=True, 46 | is_last=False): 47 | super(Block, self).__init__() 48 | 49 | if planes != inplanes or stride != 1: 50 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, 51 | bias=False) 52 | self.skipbn = BatchNorm(planes) 53 | else: 54 | self.skip = None 55 | 56 | self.relu = nn.ReLU(inplace=True) 57 | rep = [] 58 | 59 | filters = inplanes 60 | if grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, 63 | BatchNorm=BatchNorm)) 64 | rep.append(BatchNorm(planes)) 65 | filters = planes 66 | 67 | for i in range(reps - 1): 68 | rep.append(self.relu) 69 | 
rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, 70 | BatchNorm=BatchNorm)) 71 | rep.append(BatchNorm(filters)) 72 | 73 | if not grow_first: 74 | rep.append(self.relu) 75 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, 76 | BatchNorm=BatchNorm)) 77 | rep.append(BatchNorm(planes)) 78 | 79 | if stride != 1: 80 | rep.append(self.relu) 81 | rep.append(SeparableConv2d(planes, planes, 3, 2, 82 | BatchNorm=BatchNorm)) 83 | rep.append(BatchNorm(planes)) 84 | 85 | if stride == 1 and is_last: 86 | rep.append(self.relu) 87 | rep.append(SeparableConv2d(planes, planes, 3, 1, 88 | BatchNorm=BatchNorm)) 89 | rep.append(BatchNorm(planes)) 90 | 91 | if not start_with_relu: 92 | rep = rep[1:] 93 | 94 | self.rep = nn.Sequential(*rep) 95 | 96 | def forward(self, inp): 97 | x = self.rep(inp) 98 | 99 | if self.skip is not None: 100 | skip = self.skip(inp) 101 | skip = self.skipbn(skip) 102 | else: 103 | skip = inp 104 | 105 | x = x + skip 106 | 107 | return x 108 | 109 | 110 | class xception71(nn.Module): 111 | """ 112 | Modified Alighed Xception 113 | """ 114 | def __init__(self, output_stride, BatchNorm, 115 | pretrained=True): 116 | super(xception71, self).__init__() 117 | 118 | self.output_stride = output_stride 119 | if self.output_stride == 16: 120 | middle_block_dilation = 1 121 | exit_block_dilations = (1, 2) 122 | exit_stride = 2 123 | elif self.output_stride == 8: 124 | middle_block_dilation = 2 125 | exit_block_dilations = (2, 4) 126 | exit_stride = 1 127 | else: 128 | raise NotImplementedError 129 | 130 | # Entry flow 131 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 132 | self.bn1 = BatchNorm(32) 133 | self.relu = nn.ReLU(inplace=True) 134 | 135 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 136 | self.bn2 = BatchNorm(64) 137 | 138 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 139 | # stride4 140 | 141 | self.block2 = Block(128, 256, reps=2, stride=1, BatchNorm=BatchNorm, start_with_relu=False, 142 | grow_first=True) 143 | self.block3 = Block(256, 728, reps=2, stride=2, BatchNorm=BatchNorm, 144 | start_with_relu=True, grow_first=True, is_last=True) 145 | # stride8 146 | 147 | # Middle flow 148 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 149 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 150 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 151 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 152 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 153 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 154 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 155 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 156 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 157 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 158 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 159 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 160 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 161 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 162 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 163 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 164 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 165 | 
BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 166 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 167 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 168 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 169 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 170 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 171 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 172 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 173 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 174 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 175 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 176 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 177 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 178 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 179 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 180 | 181 | # Exit flow 182 | self.block20 = Block(728, 1024, reps=2, stride=exit_stride, dilation=exit_block_dilations[0], 183 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 184 | 185 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 186 | self.bn3 = BatchNorm(1536) 187 | 188 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 189 | self.bn4 = BatchNorm(1536) 190 | 191 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 192 | self.bn5 = BatchNorm(2048) 193 | 194 | # Init weights 195 | self._init_weight() 196 | 197 | # Load pretrained model 198 | if pretrained: 199 | self._load_pretrained_model() 200 | 201 | def forward(self, x): 202 | # Entry flow 203 | x = self.conv1(x) 204 | x = self.bn1(x) 205 | x = self.relu(x) 206 | 207 | x = self.conv2(x) 208 | x = self.bn2(x) 209 | str2 = self.relu(x) 210 | # s2 211 | str4 = self.block1(str2) 212 | str4 = self.relu(str4) 213 | # s4 214 | x = self.block2(str4) 215 | str8 = self.block3(x) 216 | # s8 217 | 218 | if self.output_stride == 8: 219 | low_level_feat, high_level_feat = str2, str4 220 | else: 221 | low_level_feat, high_level_feat = str4, str8 222 | 223 | # Middle flow 224 | x = self.block4(str8) 225 | x = self.block5(x) 226 | x = self.block6(x) 227 | x = self.block7(x) 228 | x = self.block8(x) 229 | x = self.block9(x) 230 | x = self.block10(x) 231 | x = self.block11(x) 232 | x = self.block12(x) 233 | x = self.block13(x) 234 | x = self.block14(x) 235 | x = self.block15(x) 236 | x = self.block16(x) 237 | x = self.block17(x) 238 | x = self.block18(x) 239 | x = self.block19(x) 240 | 241 | # Exit flow 242 | x = self.block20(x) 243 | x = self.relu(x) 244 | x = self.conv3(x) 245 | x = self.bn3(x) 246 | x = self.relu(x) 247 | 248 | x = self.conv4(x) 249 | x = self.bn4(x) 250 | x = self.relu(x) 251 | 252 | x = self.conv5(x) 253 | x = self.bn5(x) 254 | x = self.relu(x) 255 | 256 | return low_level_feat, high_level_feat, x 257 | 258 | def _init_weight(self): 259 | for m in self.modules(): 260 | if isinstance(m, nn.Conv2d): 261 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 262 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 263 | elif isinstance(m, SyncBatchNorm): 264 | m.weight.data.fill_(1) 265 | m.bias.data.zero_() 266 | elif isinstance(m, nn.BatchNorm2d): 267 | m.weight.data.fill_(1) 268 | m.bias.data.zero_() 269 | 270 | def _load_pretrained_model(self): 271 | pretrained_model = cfg.MODEL.X71_CHECKPOINT 272 | ckpt = torch.load(pretrained_model, map_location='cpu') 273 | model_dict = {k.replace('module.', ''): v for k, v in 274 | ckpt['model_dict'].items()} 275 | state_dict = self.state_dict() 276 | state_dict.update(model_dict) 277 | self.load_state_dict(state_dict, strict=False) 278 | del ckpt 279 | logx.msg('Loaded {} weights'.format(pretrained_model)) 280 | 281 | 282 | if __name__ == "__main__": 283 | model = xception71(BatchNorm=Norm2d, pretrained=True, 284 | output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) 289 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | # File taken from https://github.com/mcordts/cityscapesScripts/ 3 | # License File Available at: 4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 5 | 6 | # ---------------------- 7 | # The Cityscapes Dataset 8 | # ---------------------- 9 | # 10 | # 11 | # License agreement 12 | # ----------------- 13 | # 14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 15 | # 16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 21 | # 22 | # 23 | # Contact 24 | # ------- 25 | # 26 | # Marius Cordts, Mohamed Omran 27 | # www.cityscapes-dataset.net 28 | 29 | """ 30 | from collections import namedtuple 31 | 32 | 33 | #-------------------------------------------------------------------------------- 34 | # Definitions 35 | #-------------------------------------------------------------------------------- 36 | 37 | # a label and all meta information 38 | Label = namedtuple( 'Label' , [ 39 | 40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 
41 | # We use them to uniquely name a class 42 | 43 | 'id' , # An integer ID that is associated with this label. 44 | # The IDs are used to represent the label in ground truth images 45 | # An ID of -1 means that this label does not have an ID and thus 46 | # is ignored when creating ground truth images (e.g. license plate). 47 | # Do not modify these IDs, since exactly these IDs are expected by the 48 | # evaluation server. 49 | 50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 51 | # ground truth images with train IDs, using the tools provided in the 52 | # 'preparation' folder. However, make sure to validate or submit results 53 | # to our evaluation server using the regular IDs above! 54 | # For trainIds, multiple labels might have the same ID. Then, these labels 55 | # are mapped to the same class in the ground truth images. For the inverse 56 | # mapping, we use the label that is defined first in the list below. 57 | # For example, mapping all void-type classes to the same ID in training, 58 | # might make sense for some approaches. 59 | # Max value is 255! 60 | 61 | 'category' , # The name of the category that this label belongs to 62 | 63 | 'categoryId' , # The ID of this category. Used to create ground truth images 64 | # on category level. 65 | 66 | 'hasInstances', # Whether this label distinguishes between single instances or not 67 | 68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 69 | # during evaluations or not 70 | 71 | 'color' , # The color of this label 72 | ] ) 73 | 74 | 75 | #-------------------------------------------------------------------------------- 76 | # A list of all labels 77 | #-------------------------------------------------------------------------------- 78 | 79 | # Please adapt the train IDs as appropriate for you approach. 80 | # Note that you might want to ignore labels with ID 255 during training. 81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 82 | # Make sure to provide your results using the original IDs and not the training IDs. 83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
84 | 85 | labels = [ 86 | # name id trainId category catId hasInstances ignoreInEval color 87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 122 | ] 123 | 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Create dictionaries for a fast lookup 127 | #-------------------------------------------------------------------------------- 128 | 129 | # Please refer to the main method below for example usages! 
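#--------------------------------------------------------------------------------
# Illustrative sketch (not part of the original cityscapesScripts file): a
# minimal example of how the label2trainid mapping defined just below is
# typically applied to a ground-truth label image. The helper name and the use
# of numpy here are assumptions made for this example only.
#--------------------------------------------------------------------------------

def example_ids_to_trainids(label_img):
    """Remap original label IDs to train IDs; unmapped pixels stay at 255."""
    import numpy as np
    out = np.full_like(label_img, 255)
    for label_id, train_id in label2trainid.items():
        if 0 <= train_id <= 255:  # skip the -1 entry (license plate)
            out[label_img == label_id] = train_id
    return out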
130 | 131 | # name to label object 132 | name2label = { label.name : label for label in labels } 133 | # id to label object 134 | id2label = { label.id : label for label in labels } 135 | # trainId to label object 136 | trainId2label = { label.trainId : label for label in reversed(labels) } 137 | # label2trainid 138 | label2trainid = { label.id : label.trainId for label in labels } 139 | # trainId to label object 140 | trainId2name = { label.trainId : label.name for label in labels } 141 | trainId2color = { label.trainId : label.color for label in labels } 142 | # category to list of label objects 143 | category2labels = {} 144 | for label in labels: 145 | category = label.category 146 | if category in category2labels: 147 | category2labels[category].append(label) 148 | else: 149 | category2labels[category] = [label] 150 | 151 | #-------------------------------------------------------------------------------- 152 | # Assure single instance name 153 | #-------------------------------------------------------------------------------- 154 | 155 | # returns the label name that describes a single instance (if possible) 156 | # e.g. input | output 157 | # ---------------------- 158 | # car | car 159 | # cargroup | car 160 | # foo | None 161 | # foogroup | None 162 | # skygroup | None 163 | def assureSingleInstanceName( name ): 164 | # if the name is known, it is not a group 165 | if name in name2label: 166 | return name 167 | # test if the name actually denotes a group 168 | if not name.endswith("group"): 169 | return None 170 | # remove group 171 | name = name[:-len("group")] 172 | # test if the new name exists 173 | if not name in name2label: 174 | return None 175 | # test if the new name denotes a label that actually has instances 176 | if not name2label[name].hasInstances: 177 | return None 178 | # all good then 179 | return name 180 | 181 | #-------------------------------------------------------------------------------- 182 | # Main for testing 183 | #-------------------------------------------------------------------------------- 184 | 185 | # just a dummy main 186 | if __name__ == "__main__": 187 | # Print all the labels 188 | print("List of cityscapes labels:") 189 | print("") 190 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 191 | print((" " + ('-' * 98))) 192 | for label in labels: 193 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 194 | print("") 195 | 196 | print("Example usages:") 197 | 198 | # Map from name to label 199 | name = 'car' 200 | id = name2label[name].id 201 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 202 | 203 | # Map from ID to label 204 | category = id2label[id].category 205 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 206 | 207 | # Map from trainID to label 208 | trainId = 0 209 | name = trainId2label[trainId].name 210 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 211 | -------------------------------------------------------------------------------- /semantic-segmentation/datasets/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright 2020 Nvidia Corporation 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | 
modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | 2. Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | 3. Neither the name of the copyright holder nor the names of its contributors 15 | may be used to endorse or promote products derived from this software 16 | without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE 22 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | Uniform sampling of classes. 32 | For all images, for all classes, generate centroids around which to sample. 33 | 34 | All images are divided into tiles. 35 | For each tile, a class can be present or not. If it is 36 | present, calculate the centroid of the class and record it. 37 | 38 | We would like to thank Peter Kontschieder for the inspiration of this idea. 39 | """ 40 | 41 | import sys 42 | import os 43 | import json 44 | import numpy as np 45 | 46 | import torch 47 | 48 | from collections import defaultdict 49 | from scipy.ndimage.measurements import center_of_mass 50 | from PIL import Image 51 | from tqdm import tqdm 52 | from config import cfg 53 | from runx.logx import logx 54 | 55 | pbar = None 56 | 57 | 58 | class Point(): 59 | """ 60 | Point Class For X and Y Location 61 | """ 62 | def __init__(self, x, y): 63 | self.x = x 64 | self.y = y 65 | 66 | 67 | def calc_tile_locations(tile_size, image_size): 68 | """ 69 | Divide an image into tiles to help us cover classes that are spread out. 70 | tile_size: size of tile to distribute 71 | image_size: original image size 72 | return: locations of the tiles 73 | """ 74 | image_size_y, image_size_x = image_size 75 | locations = [] 76 | for y in range(image_size_y // tile_size): 77 | for x in range(image_size_x // tile_size): 78 | x_offs = x * tile_size 79 | y_offs = y * tile_size 80 | locations.append((x_offs, y_offs)) 81 | return locations 82 | 83 | 84 | def class_centroids_image(item, tile_size, num_classes, id2trainid): 85 | """ 86 | For one image, calculate centroids for all classes present in image. 87 | item: image, image_name 88 | tile_size: 89 | num_classes: 90 | id2trainid: mapping from original id to training ids 91 | return: Centroids are calculated for each tile. 
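For example (illustrative numbers): a 1024 x 2048 Cityscapes mask with the default tile_size of 1024 is covered by two 1024 x 1024 tiles, and each class present in a tile contributes one (image_fn, label_fn, (x, y), class_id) record to the returned dict.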
92 | """ 93 | image_fn, label_fn = item 94 | centroids = defaultdict(list) 95 | mask = np.array(Image.open(label_fn)) 96 | image_size = mask.shape 97 | tile_locations = calc_tile_locations(tile_size, image_size) 98 | 99 | drop_mask = np.zeros((1024,2048)) 100 | drop_mask[15:840, 14:2030] = 1.0 101 | 102 | 103 | ##### 104 | if(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE in label_fn): 105 | gtCoarse_mask_path = label_fn.replace(cfg.DATASET.CITYSCAPES_CUSTOMCOARSE, os.path.join(cfg.DATASET.CITYSCAPES_DIR, 'gtCoarse/gtCoarse') ) 106 | gtCoarse_mask_path = gtCoarse_mask_path.replace('leftImg8bit','gtCoarse_labelIds') 107 | gtCoarse=np.array(Image.open(gtCoarse_mask_path)) 108 | 109 | 110 | #### 111 | 112 | mask_copy = mask.copy() 113 | if id2trainid: 114 | for k, v in id2trainid.items(): 115 | binary_mask = (mask_copy == k) 116 | #This should only apply to auto labelled images 117 | if ('refinement' in label_fn) and cfg.DROPOUT_COARSE_BOOST_CLASSES != None and v in cfg.DROPOUT_COARSE_BOOST_CLASSES and binary_mask.sum() > 0: 118 | binary_mask += (gtCoarse == k) 119 | binary_mask[binary_mask >= 1] = 1 120 | mask[binary_mask] = gtCoarse[binary_mask] 121 | mask[binary_mask] = v 122 | 123 | for x_offs, y_offs in tile_locations: 124 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 125 | for class_id in range(num_classes): 126 | if class_id in patch: 127 | patch_class = (patch == class_id).astype(int) 128 | centroid_y, centroid_x = center_of_mass(patch_class) 129 | centroid_y = int(centroid_y) + y_offs 130 | centroid_x = int(centroid_x) + x_offs 131 | centroid = (centroid_x, centroid_y) 132 | centroids[class_id].append((image_fn, label_fn, centroid, 133 | class_id)) 134 | pbar.update(1) 135 | return centroids 136 | 137 | 138 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 139 | """ 140 | Calculate class centroids for all classes for all images for all tiles. 141 | items: list of (image_fn, label_fn) 142 | tile size: size of tile 143 | returns: dict that contains a list of centroids for each class 144 | """ 145 | from multiprocessing.dummy import Pool 146 | from functools import partial 147 | pool = Pool(80) 148 | global pbar 149 | pbar = tqdm(total=len(items), desc='pooled centroid extraction', file=sys.stdout) 150 | class_centroids_item = partial(class_centroids_image, 151 | num_classes=num_classes, 152 | id2trainid=id2trainid, 153 | tile_size=tile_size) 154 | 155 | centroids = defaultdict(list) 156 | new_centroids = pool.map(class_centroids_item, items) 157 | pool.close() 158 | pool.join() 159 | 160 | # combine each image's items into a single global dict 161 | for image_items in new_centroids: 162 | for class_id in image_items: 163 | centroids[class_id].extend(image_items[class_id]) 164 | return centroids 165 | 166 | 167 | def unpooled_class_centroids_all(items, num_classes, id2trainid, 168 | tile_size=1024): 169 | """ 170 | Calculate class centroids for all classes for all images for all tiles. 
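        A minimal sketch of the per-tile centroid step used in class_centroids_image
        above (a toy 4x4 patch with class 1 in its lower-right 2x2 corner; numpy/scipy
        only, values chosen purely for illustration):

            import numpy as np
            from scipy.ndimage.measurements import center_of_mass

            patch = np.zeros((4, 4), dtype=int)
            patch[2:, 2:] = 1                                    # class 1 occupies the corner
            cy, cx = center_of_mass((patch == 1).astype(int))    # (2.5, 2.5)
            centroid = (int(cx) + 0, int(cy) + 0)                # tile offsets added back -> (2, 2)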
171 | items: list of (image_fn, label_fn) 172 | tile size: size of tile 173 | returns: dict that contains a list of centroids for each class 174 | """ 175 | centroids = defaultdict(list) 176 | global pbar 177 | pbar = tqdm(total=len(items), desc='centroid extraction', file=sys.stdout) 178 | for image, label in items: 179 | new_centroids = class_centroids_image(item=(image, label), 180 | tile_size=tile_size, 181 | num_classes=num_classes, 182 | id2trainid=id2trainid) 183 | for class_id in new_centroids: 184 | centroids[class_id].extend(new_centroids[class_id]) 185 | 186 | return centroids 187 | 188 | 189 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 190 | """ 191 | intermediate function to call pooled_class_centroid 192 | """ 193 | pooled_centroids = pooled_class_centroids_all(items, num_classes, 194 | id2trainid, tile_size) 195 | # pooled_centroids = unpooled_class_centroids_all(items, num_classes, 196 | # id2trainid, tile_size) 197 | return pooled_centroids 198 | 199 | 200 | def random_sampling(alist, num): 201 | """ 202 | Randomly sample num items from the list 203 | alist: list of centroids to sample from 204 | num: can be larger than the list and if so, then wrap around 205 | return: class uniform samples from the list 206 | """ 207 | sampling = [] 208 | len_list = len(alist) 209 | assert len_list, 'len_list is zero!' 210 | indices = np.arange(len_list) 211 | np.random.shuffle(indices) 212 | 213 | for i in range(num): 214 | item = alist[indices[i % len_list]] 215 | sampling.append(item) 216 | return sampling 217 | 218 | 219 | def build_centroids(imgs, num_classes, train, cv=None, coarse=False, 220 | custom_coarse=False, id2trainid=None): 221 | """ 222 | The first step of uniform sampling is to decide sampling centers. 223 | The idea is to divide each image into tiles and within each tile, 224 | we compute a centroid for each class to indicate roughly where to 225 | sample a crop during training. 226 | 227 | This function computes these centroids and returns a list of them. 
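        The result is cached on disk so that the (slow) extraction below only runs
        once. For example, with dataset 'cityscapes', cv=0, no coarse data and the
        default tile size of 1024, the code that follows looks for roughly:

            os.path.join(cfg.DATASET.CENTROID_ROOT, 'cityscapes_cv0_tile1024.json')

        and rebuilds the centroids (on rank 0 only) if that file is missing.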
228 | """ 229 | if not (cfg.DATASET.CLASS_UNIFORM_PCT and train): 230 | return [] 231 | 232 | centroid_fn = cfg.DATASET.NAME 233 | 234 | if coarse or custom_coarse: 235 | if coarse: 236 | centroid_fn += '_coarse' 237 | if custom_coarse: 238 | centroid_fn += '_customcoarse_final' 239 | else: 240 | centroid_fn += '_cv{}'.format(cv) 241 | centroid_fn += '_tile{}.json'.format(cfg.DATASET.CLASS_UNIFORM_TILE) 242 | json_fn = os.path.join(cfg.DATASET.CENTROID_ROOT, 243 | centroid_fn) 244 | if os.path.isfile(json_fn): 245 | logx.msg('Loading centroid file {}'.format(json_fn)) 246 | with open(json_fn, 'r') as json_data: 247 | centroids = json.load(json_data) 248 | centroids = {int(idx): centroids[idx] for idx in centroids} 249 | logx.msg('Found {} centroids'.format(len(centroids))) 250 | else: 251 | logx.msg('Didn\'t find {}, so building it'.format(json_fn)) 252 | 253 | if cfg.GLOBAL_RANK==0: 254 | 255 | os.makedirs(cfg.DATASET.CENTROID_ROOT, exist_ok=True) 256 | # centroids is a dict (indexed by class) of lists of centroids 257 | centroids = class_centroids_all( 258 | imgs, 259 | num_classes, 260 | id2trainid=id2trainid) 261 | with open(json_fn, 'w') as outfile: 262 | json.dump(centroids, outfile, indent=4) 263 | 264 | # wait for everyone to be at the same point 265 | torch.distributed.barrier() 266 | 267 | # GPUs (except rank0) read in the just-created centroid file 268 | if cfg.GLOBAL_RANK != 0: 269 | msg = f'Expected to find {json_fn}' 270 | assert os.path.isfile(json_fn), msg 271 | with open(json_fn, 'r') as json_data: 272 | centroids = json.load(json_data) 273 | centroids = {int(idx): centroids[idx] for idx in centroids} 274 | 275 | return centroids 276 | 277 | 278 | def build_epoch(imgs, centroids, num_classes, train): 279 | """ 280 | Generate an epoch of crops using uniform sampling. 281 | Needs to be called every epoch. 282 | Will not apply uniform sampling if not train or class uniform is off. 
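        As a ballpark example (assuming the 2975-image Cityscapes fine training set,
        19 classes and class_uniform_pct=0.5):

            num_per_class       = int(2975 * 0.5 / 19)   # 78
            class_uniform_count = 78 * 19                # 1482 centroid-guided crops
            num_rand            = 2975 - 1482            # 1493 randomly sampled crops

        so roughly half of each epoch is centroid-guided and half is random.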
283 | 284 | Inputs: 285 | imgs - list of imgs 286 | centroids - list of class centroids 287 | num_classes - number of classes 288 | class_uniform_pct: % of uniform images in one epoch 289 | Outputs: 290 | imgs - list of images to use this epoch 291 | """ 292 | class_uniform_pct = cfg.DATASET.CLASS_UNIFORM_PCT 293 | if not (train and class_uniform_pct): 294 | return imgs 295 | 296 | logx.msg("Class Uniform Percentage: {}".format(str(class_uniform_pct))) 297 | num_epoch = int(len(imgs)) 298 | 299 | logx.msg('Class Uniform items per Epoch: {}'.format(str(num_epoch))) 300 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes) 301 | class_uniform_count = num_per_class * num_classes 302 | num_rand = num_epoch - class_uniform_count 303 | # create random crops 304 | imgs_uniform = random_sampling(imgs, num_rand) 305 | 306 | # now add uniform sampling 307 | for class_id in range(num_classes): 308 | msg = "cls {} len {}".format(class_id, len(centroids[class_id])) 309 | logx.msg(msg) 310 | for class_id in range(num_classes): 311 | if cfg.DATASET.CLASS_UNIFORM_BIAS is not None: 312 | bias = cfg.DATASET.CLASS_UNIFORM_BIAS[class_id] 313 | num_per_class_biased = int(num_per_class * bias) 314 | else: 315 | num_per_class_biased = num_per_class 316 | centroid_len = len(centroids[class_id]) 317 | if centroid_len == 0: 318 | pass 319 | else: 320 | class_centroids = random_sampling(centroids[class_id], 321 | num_per_class_biased) 322 | imgs_uniform.extend(class_centroids) 323 | 324 | return imgs_uniform 325 | -------------------------------------------------------------------------------- /semantic-segmentation/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code borrowded from: 3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 4 | # 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2017 ZijunDeng 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 
27 | 28 | """ 29 | 30 | """ 31 | Standard Transform 32 | """ 33 | 34 | import random 35 | import numpy as np 36 | from skimage.filters import gaussian 37 | from skimage.restoration import denoise_bilateral 38 | import torch 39 | from PIL import Image, ImageEnhance 40 | import torchvision.transforms as torch_tr 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | 44 | from skimage.segmentation import find_boundaries 45 | 46 | try: 47 | import accimage 48 | except ImportError: 49 | accimage = None 50 | 51 | 52 | class RandomVerticalFlip(object): 53 | def __call__(self, img): 54 | if random.random() < 0.5: 55 | return img.transpose(Image.FLIP_TOP_BOTTOM) 56 | return img 57 | 58 | 59 | class DeNormalize(object): 60 | def __init__(self, mean, std): 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, tensor): 65 | for t, m, s in zip(tensor, self.mean, self.std): 66 | t.mul_(s).add_(m) 67 | return tensor 68 | 69 | 70 | class MaskToTensor(object): 71 | def __call__(self, img, blockout_predefined_area=False): 72 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 73 | 74 | class RelaxedBoundaryLossToTensor(object): 75 | """ 76 | Boundary Relaxation 77 | """ 78 | def __init__(self,ignore_id, num_classes): 79 | self.ignore_id=ignore_id 80 | self.num_classes= num_classes 81 | 82 | 83 | def new_one_hot_converter(self,a): 84 | ncols = self.num_classes+1 85 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 86 | out[np.arange(a.size),a.ravel()] = 1 87 | out.shape = a.shape + (ncols,) 88 | return out 89 | 90 | def __call__(self,img): 91 | 92 | img_arr = np.array(img) 93 | img_arr[img_arr==self.ignore_id]=self.num_classes 94 | 95 | if cfg.STRICTBORDERCLASS != None: 96 | one_hot_orig = self.new_one_hot_converter(img_arr) 97 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 98 | for cls in cfg.STRICTBORDERCLASS: 99 | mask = np.logical_or(mask,(img_arr == cls)) 100 | one_hot = 0 101 | 102 | border = cfg.BORDER_WINDOW 103 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 104 | border = border // 2 105 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 106 | 107 | for i in range(-border,border+1): 108 | for j in range(-border, border+1): 109 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 110 | one_hot += self.new_one_hot_converter(shifted) 111 | 112 | one_hot[one_hot>1] = 1 113 | 114 | if cfg.STRICTBORDERCLASS != None: 115 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 116 | 117 | one_hot = np.moveaxis(one_hot,-1,0) 118 | 119 | 120 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 121 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 122 | # print(one_hot.shape) 123 | return torch.from_numpy(one_hot).byte() 124 | 125 | class ResizeHeight(object): 126 | def __init__(self, size, interpolation=Image.BILINEAR): 127 | self.target_h = size 128 | self.interpolation = interpolation 129 | 130 | def __call__(self, img): 131 | w, h = img.size 132 | target_w = int(w / h * self.target_h) 133 | return img.resize((target_w, self.target_h), self.interpolation) 134 | 135 | 136 | class FreeScale(object): 137 | def __init__(self, size, interpolation=Image.BILINEAR): 138 | self.size = tuple(reversed(size)) # size: (h, w) 139 | self.interpolation = interpolation 140 | 141 | def __call__(self, img): 142 | return img.resize(self.size, self.interpolation) 143 | 144 | 145 | class FlipChannels(object): 146 | """ 147 | Flip around the x-axis 148 | """ 149 
| def __call__(self, img): 150 | img = np.array(img)[:, :, ::-1] 151 | return Image.fromarray(img.astype(np.uint8)) 152 | 153 | 154 | class RandomGaussianBlur(object): 155 | """ 156 | Apply Gaussian Blur 157 | """ 158 | def __call__(self, img): 159 | sigma = 0.15 + random.random() * 1.15 160 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 161 | blurred_img *= 255 162 | return Image.fromarray(blurred_img.astype(np.uint8)) 163 | 164 | 165 | class RandomBrightness(object): 166 | def __call__(self, img): 167 | if random.random() < 0.5: 168 | return img 169 | v = random.uniform(0.1, 1.9) 170 | return ImageEnhance.Brightness(img).enhance(v) 171 | 172 | 173 | class RandomBilateralBlur(object): 174 | """ 175 | Apply Bilateral Filtering 176 | 177 | """ 178 | def __call__(self, img): 179 | sigma = random.uniform(0.05, 0.75) 180 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 181 | blurred_img *= 255 182 | return Image.fromarray(blurred_img.astype(np.uint8)) 183 | 184 | 185 | def _is_pil_image(img): 186 | if accimage is not None: 187 | return isinstance(img, (Image.Image, accimage.Image)) 188 | else: 189 | return isinstance(img, Image.Image) 190 | 191 | 192 | def adjust_brightness(img, brightness_factor): 193 | """Adjust brightness of an Image. 194 | 195 | Args: 196 | img (PIL Image): PIL Image to be adjusted. 197 | brightness_factor (float): How much to adjust the brightness. Can be 198 | any non negative number. 0 gives a black image, 1 gives the 199 | original image while 2 increases the brightness by a factor of 2. 200 | 201 | Returns: 202 | PIL Image: Brightness adjusted image. 203 | """ 204 | if not _is_pil_image(img): 205 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 206 | 207 | enhancer = ImageEnhance.Brightness(img) 208 | img = enhancer.enhance(brightness_factor) 209 | return img 210 | 211 | 212 | def adjust_contrast(img, contrast_factor): 213 | """Adjust contrast of an Image. 214 | 215 | Args: 216 | img (PIL Image): PIL Image to be adjusted. 217 | contrast_factor (float): How much to adjust the contrast. Can be any 218 | non negative number. 0 gives a solid gray image, 1 gives the 219 | original image while 2 increases the contrast by a factor of 2. 220 | 221 | Returns: 222 | PIL Image: Contrast adjusted image. 223 | """ 224 | if not _is_pil_image(img): 225 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 226 | 227 | enhancer = ImageEnhance.Contrast(img) 228 | img = enhancer.enhance(contrast_factor) 229 | return img 230 | 231 | 232 | def adjust_saturation(img, saturation_factor): 233 | """Adjust color saturation of an image. 234 | 235 | Args: 236 | img (PIL Image): PIL Image to be adjusted. 237 | saturation_factor (float): How much to adjust the saturation. 0 will 238 | give a black and white image, 1 will give the original image while 239 | 2 will enhance the saturation by a factor of 2. 240 | 241 | Returns: 242 | PIL Image: Saturation adjusted image. 243 | """ 244 | if not _is_pil_image(img): 245 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 246 | 247 | enhancer = ImageEnhance.Color(img) 248 | img = enhancer.enhance(saturation_factor) 249 | return img 250 | 251 | 252 | def adjust_hue(img, hue_factor): 253 | """Adjust hue of an image. 254 | 255 | The image hue is adjusted by converting the image to HSV and 256 | cyclically shifting the intensities in the hue channel (H). 257 | The image is then converted back to original image mode. 
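        As a small numeric illustration (hue_factor=0.2 is just an example value),
        the shift applied below to the H channel is:

            np.uint8(0.2 * 255)   # 51; uint8 addition then wraps modulo 256,
                                  # which is what walks around the hue circle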
258 | 259 | `hue_factor` is the amount of shift in H channel and must be in the 260 | interval `[-0.5, 0.5]`. 261 | 262 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 263 | 264 | Args: 265 | img (PIL Image): PIL Image to be adjusted. 266 | hue_factor (float): How much to shift the hue channel. Should be in 267 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 268 | HSV space in positive and negative direction respectively. 269 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 270 | with complementary colors while 0 gives the original image. 271 | 272 | Returns: 273 | PIL Image: Hue adjusted image. 274 | """ 275 | if not(-0.5 <= hue_factor <= 0.5): 276 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 277 | 278 | if not _is_pil_image(img): 279 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 280 | 281 | input_mode = img.mode 282 | if input_mode in {'L', '1', 'I', 'F'}: 283 | return img 284 | 285 | h, s, v = img.convert('HSV').split() 286 | 287 | np_h = np.array(h, dtype=np.uint8) 288 | # uint8 addition take cares of rotation across boundaries 289 | with np.errstate(over='ignore'): 290 | np_h += np.uint8(hue_factor * 255) 291 | h = Image.fromarray(np_h, 'L') 292 | 293 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 294 | return img 295 | 296 | 297 | class ColorJitter(object): 298 | """Randomly change the brightness, contrast and saturation of an image. 299 | 300 | Args: 301 | brightness (float): How much to jitter brightness. brightness_factor 302 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 303 | contrast (float): How much to jitter contrast. contrast_factor 304 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 305 | saturation (float): How much to jitter saturation. saturation_factor 306 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 307 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 308 | [-hue, hue]. Should be >=0 and <= 0.5. 309 | """ 310 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 311 | self.brightness = brightness 312 | self.contrast = contrast 313 | self.saturation = saturation 314 | self.hue = hue 315 | 316 | @staticmethod 317 | def get_params(brightness, contrast, saturation, hue): 318 | """Get a randomized transform to be applied on image. 319 | 320 | Arguments are same as that of __init__. 321 | 322 | Returns: 323 | Transform which randomly adjusts brightness, contrast and 324 | saturation in a random order. 
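            A possible usage sketch of the class as a whole (the jitter strengths
            here are arbitrary example values, not recommendations):

                jitter = ColorJitter(brightness=0.25, contrast=0.25,
                                     saturation=0.25, hue=0.1)
                augmented = jitter(pil_image)   # pil_image: any RGB PIL.Image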
325 | """ 326 | transforms = [] 327 | if brightness > 0: 328 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 329 | transforms.append( 330 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 331 | 332 | if contrast > 0: 333 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 334 | transforms.append( 335 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 336 | 337 | if saturation > 0: 338 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 339 | transforms.append( 340 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 341 | 342 | if hue > 0: 343 | hue_factor = np.random.uniform(-hue, hue) 344 | transforms.append( 345 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 346 | 347 | np.random.shuffle(transforms) 348 | transform = torch_tr.Compose(transforms) 349 | 350 | return transform 351 | 352 | def __call__(self, img): 353 | """ 354 | Args: 355 | img (PIL Image): Input image. 356 | 357 | Returns: 358 | PIL Image: Color jittered image. 359 | """ 360 | transform = self.get_params(self.brightness, self.contrast, 361 | self.saturation, self.hue) 362 | return transform(img) 363 | -------------------------------------------------------------------------------- /semantic-segmentation/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | ############################################################################## 30 | # Config 31 | ############################################################################## 32 | 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | from __future__ import unicode_literals 38 | 39 | import os 40 | import re 41 | import torch 42 | 43 | from utils.attr_dict import AttrDict 44 | from runx.logx import logx 45 | 46 | 47 | __C = AttrDict() 48 | cfg = __C 49 | __C.GLOBAL_RANK = 0 50 | __C.EPOCH = 0 51 | # Absolute path to a location to keep some large files, not in this dir. 
52 | __C.ASSETS_PATH = '/home/louis/SemSeg/semantic-segmentation/large_assets' 53 | 54 | # Use class weighted loss per batch to increase loss for low pixel count classes per batch 55 | __C.BATCH_WEIGHTING = False 56 | 57 | # Border Relaxation Count 58 | __C.BORDER_WINDOW = 1 59 | # Number of epoch to use before turn off border restriction 60 | __C.REDUCE_BORDER_EPOCH = -1 61 | # Comma Seperated List of class id to relax 62 | __C.STRICTBORDERCLASS = None 63 | # Where output results get written 64 | __C.RESULT_DIR = None 65 | 66 | __C.OPTIONS = AttrDict() 67 | __C.OPTIONS.TEST_MODE = False 68 | __C.OPTIONS.INIT_DECODER = False 69 | __C.OPTIONS.TORCH_VERSION = None 70 | 71 | __C.TRAIN = AttrDict() 72 | __C.TRAIN.RANDOM_BRIGHTNESS_SHIFT_VALUE = 10 73 | __C.TRAIN.FP16 = False 74 | 75 | #Attribute Dictionary for Dataset 76 | __C.DATASET = AttrDict() 77 | #Cityscapes Dir Location 78 | __C.DATASET.CITYSCAPES_DIR = \ 79 | os.path.join(__C.ASSETS_PATH, 'data/cityscapes') 80 | __C.DATASET.CITYSCAPES_CUSTOMCOARSE = \ 81 | os.path.join(__C.ASSETS_PATH, 'data/cityscapes/autolabelled') 82 | __C.DATASET.CENTROID_ROOT = \ 83 | os.path.join(__C.ASSETS_PATH, 'uniform_centroids') 84 | #SDC Augmented Cityscapes Dir Location 85 | __C.DATASET.CITYSCAPES_AUG_DIR = '' 86 | #Mapillary Dataset Dir Location 87 | __C.DATASET.MAPILLARY_DIR = os.path.join(__C.ASSETS_PATH, 'data/Mapillary/data') 88 | #Kitti Dataset Dir Location 89 | __C.DATASET.KITTI_DIR = '' 90 | #SDC Augmented Kitti Dataset Dir Location 91 | __C.DATASET.KITTI_AUG_DIR = '' 92 | #Camvid Dataset Dir Location 93 | __C.DATASET.CAMVID_DIR = '' 94 | #Number of splits to support 95 | __C.DATASET.CITYSCAPES_SPLITS = 3 96 | __C.DATASET.MEAN = [0.485, 0.456, 0.406] 97 | __C.DATASET.STD = [0.229, 0.224, 0.225] 98 | __C.DATASET.NAME = '' 99 | __C.DATASET.NUM_CLASSES = 0 100 | __C.DATASET.IGNORE_LABEL = 255 101 | __C.DATASET.DUMP_IMAGES = False 102 | __C.DATASET.CLASS_UNIFORM_PCT = 0.5 103 | __C.DATASET.CLASS_UNIFORM_TILE = 1024 104 | __C.DATASET.COARSE_BOOST_CLASSES = None 105 | __C.DATASET.CV = 0 106 | __C.DATASET.COLORIZE_MASK_FN = None 107 | __C.DATASET.CUSTOM_COARSE_PROB = None 108 | __C.DATASET.MASK_OUT_CITYSCAPES = False 109 | 110 | # This enables there to always be translation augmentation during random crop 111 | # process, even if image is smaller than crop size. 
112 | __C.DATASET.TRANSLATE_AUG_FIX = False 113 | __C.DATASET.LANCZOS_SCALES = False 114 | # Use a center crop of size args.pre_size for mapillary validation 115 | # Need to use this if you want to dump images 116 | __C.DATASET.MAPILLARY_CROP_VAL = False 117 | __C.DATASET.CROP_SIZE = '896' 118 | 119 | __C.MODEL = AttrDict() 120 | __C.MODEL.BN = 'regularnorm' 121 | __C.MODEL.BNFUNC = None 122 | __C.MODEL.MSCALE = False 123 | __C.MODEL.THREE_SCALE = False 124 | __C.MODEL.ALT_TWO_SCALE = False 125 | __C.MODEL.EXTRA_SCALES = '0.5,1.5' 126 | __C.MODEL.N_SCALES = None 127 | __C.MODEL.ALIGN_CORNERS = False 128 | __C.MODEL.MSCALE_LO_SCALE = 0.5 129 | __C.MODEL.OCR_ASPP = False 130 | __C.MODEL.SEGATTN_BOT_CH = 256 131 | __C.MODEL.ASPP_BOT_CH = 256 132 | __C.MODEL.MSCALE_CAT_SCALE_FLT = False 133 | __C.MODEL.MSCALE_INNER_3x3 = True 134 | __C.MODEL.MSCALE_DROPOUT = False 135 | __C.MODEL.MSCALE_OLDARCH = False 136 | __C.MODEL.MSCALE_INIT = 0.5 137 | __C.MODEL.ATTNSCALE_BN_HEAD = False 138 | __C.MODEL.GRAD_CKPT = False 139 | 140 | WEIGHTS_PATH = os.path.join(__C.ASSETS_PATH, 'seg_weights') 141 | __C.MODEL.WRN38_CHECKPOINT = \ 142 | os.path.join(WEIGHTS_PATH, 'wider_resnet38.pth.tar') 143 | __C.MODEL.WRN41_CHECKPOINT = \ 144 | os.path.join(WEIGHTS_PATH, 'wider_resnet41_cornflower_sunfish.pth') 145 | __C.MODEL.X71_CHECKPOINT = \ 146 | os.path.join(WEIGHTS_PATH, 'aligned_xception71.pth') 147 | __C.MODEL.HRNET_CHECKPOINT = \ 148 | os.path.join(WEIGHTS_PATH, 'hrnetv2_w48_imagenet_pretrained.pth') 149 | 150 | __C.LOSS = AttrDict() 151 | # Weight for OCR aux loss 152 | __C.LOSS.OCR_ALPHA = 0.4 153 | # Use RMI for the OCR aux loss 154 | __C.LOSS.OCR_AUX_RMI = False 155 | # Supervise the multi-scale predictions directly 156 | __C.LOSS.SUPERVISED_MSCALE_WT = 0 157 | 158 | __C.MODEL.OCR = AttrDict() 159 | __C.MODEL.OCR.MID_CHANNELS = 512 160 | __C.MODEL.OCR.KEY_CHANNELS = 256 161 | __C.MODEL.OCR_EXTRA = AttrDict() 162 | __C.MODEL.OCR_EXTRA.FINAL_CONV_KERNEL = 1 163 | __C.MODEL.OCR_EXTRA.STAGE1 = AttrDict() 164 | __C.MODEL.OCR_EXTRA.STAGE1.NUM_MODULES = 1 165 | __C.MODEL.OCR_EXTRA.STAGE1.NUM_RANCHES = 1 166 | __C.MODEL.OCR_EXTRA.STAGE1.BLOCK = 'BOTTLENECK' 167 | __C.MODEL.OCR_EXTRA.STAGE1.NUM_BLOCKS = [4] 168 | __C.MODEL.OCR_EXTRA.STAGE1.NUM_CHANNELS = [64] 169 | __C.MODEL.OCR_EXTRA.STAGE1.FUSE_METHOD = 'SUM' 170 | __C.MODEL.OCR_EXTRA.STAGE2 = AttrDict() 171 | __C.MODEL.OCR_EXTRA.STAGE2.NUM_MODULES = 1 172 | __C.MODEL.OCR_EXTRA.STAGE2.NUM_BRANCHES = 2 173 | __C.MODEL.OCR_EXTRA.STAGE2.BLOCK = 'BASIC' 174 | __C.MODEL.OCR_EXTRA.STAGE2.NUM_BLOCKS = [4, 4] 175 | __C.MODEL.OCR_EXTRA.STAGE2.NUM_CHANNELS = [48, 96] 176 | __C.MODEL.OCR_EXTRA.STAGE2.FUSE_METHOD = 'SUM' 177 | __C.MODEL.OCR_EXTRA.STAGE3 = AttrDict() 178 | __C.MODEL.OCR_EXTRA.STAGE3.NUM_MODULES = 4 179 | __C.MODEL.OCR_EXTRA.STAGE3.NUM_BRANCHES = 3 180 | __C.MODEL.OCR_EXTRA.STAGE3.BLOCK = 'BASIC' 181 | __C.MODEL.OCR_EXTRA.STAGE3.NUM_BLOCKS = [4, 4, 4] 182 | __C.MODEL.OCR_EXTRA.STAGE3.NUM_CHANNELS = [48, 96, 192] 183 | __C.MODEL.OCR_EXTRA.STAGE3.FUSE_METHOD = 'SUM' 184 | __C.MODEL.OCR_EXTRA.STAGE4 = AttrDict() 185 | __C.MODEL.OCR_EXTRA.STAGE4.NUM_MODULES = 3 186 | __C.MODEL.OCR_EXTRA.STAGE4.NUM_BRANCHES = 4 187 | __C.MODEL.OCR_EXTRA.STAGE4.BLOCK = 'BASIC' 188 | __C.MODEL.OCR_EXTRA.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 189 | __C.MODEL.OCR_EXTRA.STAGE4.NUM_CHANNELS = [48, 96, 192, 384] 190 | __C.MODEL.OCR_EXTRA.STAGE4.FUSE_METHOD = 'SUM' 191 | 192 | 193 | def torch_version_float(): 194 | version_str = torch.__version__ 195 | version_re = re.search(r'^([0-9]+\.[0-9]+)', 
version_str) 196 | if version_re: 197 | version = float(version_re.group(1)) 198 | logx.msg(f'Torch version: {version}, {version_str}') 199 | else: 200 | version = 1.0 201 | logx.msg(f'Can\'t parse torch version ({version}), assuming {version}') 202 | return version 203 | 204 | 205 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True): 206 | """Call this function in your script after you have finished setting all cfg 207 | values that are necessary (e.g., merging a config from a file, merging 208 | command line config options, etc.). By default, this function will also 209 | mark the global cfg as immutable to prevent changing the global cfg 210 | settings during script execution (which can lead to hard to debug errors 211 | or code that's harder to understand than is necessary). 212 | """ 213 | 214 | __C.OPTIONS.TORCH_VERSION = torch_version_float() 215 | 216 | if hasattr(args, 'syncbn') and args.syncbn: 217 | if args.apex: 218 | import apex 219 | __C.MODEL.BN = 'apex-syncnorm' 220 | __C.MODEL.BNFUNC = apex.parallel.SyncBatchNorm 221 | else: 222 | raise Exception('No Support for SyncBN without Apex') 223 | else: 224 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 225 | print('Using regular batch norm') 226 | 227 | if not train_mode: 228 | cfg.immutable(True) 229 | return 230 | 231 | if args.batch_weighting: 232 | __C.BATCH_WEIGHTING = True 233 | 234 | if args.custom_coarse_prob: 235 | __C.DATASET.CUSTOM_COARSE_PROB = args.custom_coarse_prob 236 | 237 | if args.jointwtborder: 238 | if args.strict_bdr_cls != '': 239 | strict_classes = [int(i) for i in args.strict_bdr_cls.split(",")] 240 | __C.STRICTBORDERCLASS = strict_classes 241 | if args.rlx_off_epoch > -1: 242 | __C.REDUCE_BORDER_EPOCH = args.rlx_off_epoch 243 | 244 | cfg.DATASET.NAME = args.dataset 245 | cfg.DATASET.DUMP_IMAGES = args.dump_augmentation_images 246 | 247 | cfg.DATASET.CLASS_UNIFORM_PCT = args.class_uniform_pct 248 | cfg.DATASET.CLASS_UNIFORM_TILE = args.class_uniform_tile 249 | if args.coarse_boost_classes: 250 | cfg.DATASET.COARSE_BOOST_CLASSES = \ 251 | [int(i) for i in args.coarse_boost_classes.split(',')] 252 | 253 | cfg.DATASET.CLASS_UNIFORM_BIAS = None 254 | 255 | if args.dump_assets and args.dataset == 'cityscapes': 256 | # A hacky way to force that when we dump cityscapes 257 | logx.msg('*' * 70) 258 | logx.msg(f'ALERT: forcing cv=3 to allow all images to be evaluated') 259 | logx.msg('*' * 70) 260 | cfg.DATASET.CV = 3 261 | else: 262 | cfg.DATASET.CV = args.cv 263 | # Total number of splits 264 | cfg.DATASET.CV_SPLITS = 3 265 | 266 | if args.translate_aug_fix: 267 | cfg.DATASET.TRANSLATE_AUG_FIX = True 268 | 269 | cfg.MODEL.MSCALE = ('mscale' in args.arch.lower() or 'attnscale' in 270 | args.arch.lower()) 271 | 272 | if args.three_scale: 273 | cfg.MODEL.THREE_SCALE = True 274 | 275 | if args.alt_two_scale: 276 | cfg.MODEL.ALT_TWO_SCALE = True 277 | 278 | cfg.MODEL.MSCALE_LO_SCALE = args.mscale_lo_scale 279 | 280 | def str2list(s): 281 | alist = s.split(',') 282 | alist = [float(x) for x in alist] 283 | return alist 284 | 285 | if args.n_scales: 286 | cfg.MODEL.N_SCALES = str2list(args.n_scales) 287 | logx.msg('n scales {}'.format(cfg.MODEL.N_SCALES)) 288 | 289 | if args.extra_scales: 290 | cfg.MODEL.EXTRA_SCALES = str2list(args.extra_scales) 291 | 292 | if args.align_corners: 293 | cfg.MODEL.ALIGN_CORNERS = True 294 | 295 | if args.init_decoder: 296 | cfg.OPTIONS.INIT_DECODER = True 297 | 298 | cfg.RESULT_DIR = args.result_dir 299 | 300 | if args.mask_out_cityscapes: 301 | 
cfg.DATASET.MASK_OUT_CITYSCAPES = True 302 | 303 | if args.fp16: 304 | cfg.TRAIN.FP16 = True 305 | 306 | if args.map_crop_val: 307 | __C.DATASET.MAPILLARY_CROP_VAL = True 308 | 309 | __C.DATASET.CROP_SIZE = args.crop_size 310 | 311 | if args.aspp_bot_ch is not None: 312 | # todo fixme: make all code use this cfg 313 | __C.MODEL.ASPP_BOT_CH = int(args.aspp_bot_ch) 314 | 315 | if args.mscale_cat_scale_flt: 316 | __C.MODEL.MSCALE_CAT_SCALE_FLT = True 317 | 318 | if args.mscale_no3x3: 319 | __C.MODEL.MSCALE_INNER_3x3 = False 320 | 321 | if args.mscale_dropout: 322 | __C.MODEL.MSCALE_DROPOUT = True 323 | 324 | if args.mscale_old_arch: 325 | __C.MODEL.MSCALE_OLDARCH = True 326 | 327 | if args.mscale_init is not None: 328 | __C.MODEL.MSCALE_INIT = args.mscale_init 329 | 330 | if args.attnscale_bn_head: 331 | __C.MODEL.ATTNSCALE_BN_HEAD = True 332 | 333 | if args.segattn_bot_ch is not None: 334 | __C.MODEL.SEGATTN_BOT_CH = args.segattn_bot_ch 335 | 336 | if args.set_cityscapes_root is not None: 337 | # '/data/cs_imgs_cv0' 338 | # '/data/cs_imgs_cv2' 339 | __C.DATASET.CITYSCAPES_DIR = args.set_cityscapes_root 340 | 341 | if args.ocr_alpha is not None: 342 | __C.LOSS.OCR_ALPHA = args.ocr_alpha 343 | 344 | if args.ocr_aux_loss_rmi: 345 | __C.LOSS.OCR_AUX_RMI = True 346 | 347 | if args.supervised_mscale_loss_wt is not None: 348 | __C.LOSS.SUPERVISED_MSCALE_WT = args.supervised_mscale_loss_wt 349 | 350 | cfg.DROPOUT_COARSE_BOOST_CLASSES = None 351 | if args.custom_coarse_dropout_classes: 352 | cfg.DROPOUT_COARSE_BOOST_CLASSES = \ 353 | [int(i) for i in args.custom_coarse_dropout_classes.split(',')] 354 | 355 | if args.grad_ckpt: 356 | __C.MODEL.GRAD_CKPT = True 357 | 358 | __C.GLOBAL_RANK = args.global_rank 359 | 360 | if make_immutable: 361 | cfg.immutable(True) 362 | 363 | 364 | def update_epoch(epoch): 365 | # Update EPOCH CTR 366 | cfg.immutable(False) 367 | cfg.EPOCH = epoch 368 | cfg.immutable(True) 369 | 370 | 371 | def update_dataset_cfg(num_classes, ignore_label): 372 | cfg.immutable(False) 373 | cfg.DATASET.NUM_CLASSES = num_classes 374 | cfg.DATASET.IGNORE_LABEL = ignore_label 375 | logx.msg('num_classes = {}'.format(num_classes)) 376 | cfg.immutable(True) 377 | 378 | 379 | def update_dataset_inst(dataset_inst): 380 | cfg.immutable(False) 381 | cfg.DATASET_INST = dataset_inst 382 | cfg.immutable(True) 383 | --------------------------------------------------------------------------------
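A minimal sketch of how the config above is typically driven from a training
script. The argparse namespace and its fields (e.g. max_epoch) are assumptions
for illustration; only cfg, assert_and_infer_cfg, update_dataset_cfg and
update_epoch come from config.py itself:

    from config import cfg, assert_and_infer_cfg, update_dataset_cfg, update_epoch

    args = parser.parse_args()            # hypothetical argparse namespace from train.py
    assert_and_infer_cfg(args)            # copy args into cfg, then mark cfg immutable
    update_dataset_cfg(num_classes=19, ignore_label=255)   # e.g. Cityscapes values

    for epoch in range(args.max_epoch):   # max_epoch assumed to be a CLI flag
        update_epoch(epoch)               # briefly unfreezes cfg to bump cfg.EPOCH
        # ... train / validate ...

Once assert_and_infer_cfg() has run, ad-hoc writes to cfg are blocked by the
immutable AttrDict; hyper-parameters should only change through helpers such as
update_epoch and update_dataset_cfg shown above.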