├── utils ├── __init__.py ├── attr_dict.py ├── my_data_parallel.py └── misc.py ├── transforms ├── __init__.py └── transforms.py ├── assets ├── 6529-teaser.gif ├── 6529-architecture.png └── youtube_capture_p.png ├── scripts ├── eval_r101_os8.sh ├── eval_resnext_os4.sh ├── eval_resnext_os8.sh ├── submit_r101_os8.sh ├── submit_cityscapes_resnext.sh ├── train_m_os16_hanet.sh ├── train_r101_os16_hanet.sh ├── train_r101_os8_hanet.sh ├── train_resnext_pretrain.sh ├── train_r101_os8_hanet_best.sh └── train_resnext_fr_pretrain.sh ├── datasets ├── nullloader.py ├── sampler.py ├── mapillary.py ├── kitti.py ├── uniform.py ├── cityscapes_labels.py └── camvid.py ├── network ├── __init__.py ├── mynn.py ├── HANet.py ├── PosEmbedding.py ├── Resnet.py ├── wider_resnet.py └── SEresnext.py ├── config.py ├── README.md ├── loss.py └── optimizer.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/6529-teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shachoi/HANet/HEAD/assets/6529-teaser.gif -------------------------------------------------------------------------------- /assets/6529-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shachoi/HANet/HEAD/assets/6529-architecture.png -------------------------------------------------------------------------------- /assets/youtube_capture_p.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shachoi/HANet/HEAD/assets/youtube_capture_p.png -------------------------------------------------------------------------------- /scripts/eval_r101_os8.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \ 7 | --inference_mode sliding \ 8 | --scales 0.5,1.0,2.0 \ 9 | --split val \ 10 | --cv_split 0 \ 11 | --ckpt_path ${2} \ 12 | --snapshot ${1} \ 13 | --pos_rfactor 8 \ 14 | --hanet 1 1 1 1 0 \ 15 | --hanet_set 3 64 3 \ 16 | --hanet_pos 2 1 \ 17 | --dropout 0.1 \ 18 | --aux_loss \ 19 | -------------------------------------------------------------------------------- /scripts/eval_resnext_os4.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS4 \ 7 | --inference_mode sliding \ 8 | --scales 0.5,1.0,2.0 \ 9 | --split val \ 10 | --cv_split 0 \ 11 | --ckpt_path ${2} \ 12 | --snapshot ${1} \ 13 | --pos_rfactor 8 \ 14 | --hanet 1 1 1 1 0 \ 15 | --hanet_set 3 64 3 \ 16 | --hanet_pos 2 1 \ 17 | --dropout 0.1 \ 18 | --aux_loss \ 19 | -------------------------------------------------------------------------------- /scripts/eval_resnext_os8.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \ 7 | --inference_mode sliding \ 8 | --scales 0.5,1.0,2.0 \ 9 | --split val \ 10 | --cv_split 0 \ 11 | --ckpt_path ${2} \ 12 | --snapshot ${1} \ 13 | --pos_rfactor 8 \ 14 | --hanet 1 1 1 1 0 \ 15 | --hanet_set 3 64 3 \ 16 | --hanet_pos 2 1 \ 17 | --dropout 0.1 \ 18 | --aux_loss \ 19 | -------------------------------------------------------------------------------- /scripts/submit_r101_os8.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \ 7 | --inference_mode sliding \ 8 | --scales 0.5,1.0,2.0 \ 9 | --split test \ 10 | --cv_split 0 \ 11 | --ckpt_path ${2} \ 12 | --snapshot ${1} \ 13 | --pos_rfactor 8 \ 14 | --hanet 1 1 1 1 0 \ 15 | --hanet_set 3 64 3 \ 16 | --hanet_pos 2 1 \ 17 | --dropout 0.3 \ 18 | --aux_loss \ 19 | --dump_images \ 20 | -------------------------------------------------------------------------------- /scripts/submit_cityscapes_resnext.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \ 7 | --inference_mode sliding \ 8 | --scales 0.5,1.0,2.0 \ 9 | --split test \ 10 | --cv_split 0 \ 11 | --ckpt_path ${2} \ 12 | --snapshot ${1} \ 13 | --pos_rfactor 8 \ 14 | --hanet 1 1 1 1 0 \ 15 | --hanet_set 3 64 3 \ 16 | --hanet_pos 2 1 \ 17 | --pos_noise 0.0 \ 18 | --dropout 0.1 \ 19 | --aux_loss \ 20 | --dump_images \ 21 | -------------------------------------------------------------------------------- /datasets/nullloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Null Loader 3 | """ 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | 8 | num_classes = 19 9 | ignore_label = 255 10 | 11 | class NullLoader(data.Dataset): 12 | """ 13 | Null Dataset for Performance 14 | """ 15 | def __init__(self,crop_size): 16 | self.imgs = range(200) 17 | self.crop_size = crop_size 18 | 19 | def __getitem__(self, index): 20 | #Return img, mask, name 21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index) 22 | 23 | def __len__(self): 24 | return len(self.imgs) -------------------------------------------------------------------------------- /scripts/train_m_os16_hanet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on Cityscapes 3 | python -m torch.distributed.launch --nproc_per_node=1 train.py \ 4 | --dataset cityscapes \ 5 | --arch network.deepv3.DeepMobileNetV3PlusD_HANet \ 6 | --city_mode 'train' \ 7 | --lr_schedule poly \ 8 | --lr 0.01 \ 9 | --poly_exp 0.9 \ 10 | --hanet_lr 0.01 \ 11 | --hanet_poly_exp 0.9 \ 12 | --max_cu_epoch 10000 \ 13 | --class_uniform_pct 0.5 \ 14 | 
--class_uniform_tile 1024 \ 15 | --syncbn \ 16 | --sgd \ 17 | --crop_size 768 \ 18 | --scale_min 0.5 \ 19 | --scale_max 2.0 \ 20 | --rrotate 0 \ 21 | --color_aug 0.25 \ 22 | --gblur \ 23 | --max_iter 40000 \ 24 | --bs_mult 8 \ 25 | --hanet 1 1 1 1 0 \ 26 | --hanet_set 3 16 3 \ 27 | --hanet_pos 2 1 \ 28 | --pos_rfactor 8 \ 29 | --dropout 0.1 \ 30 | --pos_noise 0.3 \ 31 | --aux_loss \ 32 | --date 0101 \ 33 | --exp m_os16_hanet \ 34 | --ckpt ./logs/ \ 35 | --tb_path ./logs/ 36 | -------------------------------------------------------------------------------- /scripts/train_r101_os16_hanet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on Cityscapes 3 | python -m torch.distributed.launch --nproc_per_node=2 train.py \ 4 | --dataset cityscapes \ 5 | --arch network.deepv3.DeepR101V3PlusD_HANet \ 6 | --city_mode 'train' \ 7 | --lr_schedule poly \ 8 | --lr 0.01 \ 9 | --poly_exp 0.9 \ 10 | --hanet_lr 0.01 \ 11 | --hanet_poly_exp 0.9 \ 12 | --max_cu_epoch 10000 \ 13 | --class_uniform_pct 0.5 \ 14 | --class_uniform_tile 1024 \ 15 | --syncbn \ 16 | --sgd \ 17 | --crop_size 768 \ 18 | --scale_min 0.5 \ 19 | --scale_max 2.0 \ 20 | --rrotate 0 \ 21 | --color_aug 0.25 \ 22 | --gblur \ 23 | --max_iter 40000 \ 24 | --bs_mult 4 \ 25 | --hanet 1 1 1 1 0 \ 26 | --hanet_set 3 32 3 \ 27 | --hanet_pos 2 1 \ 28 | --pos_rfactor 8 \ 29 | --dropout 0.1 \ 30 | --pos_noise 0.5 \ 31 | --aux_loss \ 32 | --date 0101 \ 33 | --exp r101_os16_hanet \ 34 | --ckpt ./logs/ \ 35 | --tb_path ./logs/ 36 | -------------------------------------------------------------------------------- /scripts/train_r101_os8_hanet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on Cityscapes 3 | python -m torch.distributed.launch --nproc_per_node=2 train.py \ 4 | --dataset cityscapes \ 5 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \ 6 | --city_mode 'train' \ 7 | --lr_schedule poly \ 8 | --lr 0.01 \ 9 | --poly_exp 0.9 \ 10 | --hanet_lr 0.01 \ 11 | --hanet_poly_exp 0.9 \ 12 | --max_cu_epoch 10000 \ 13 | --class_uniform_pct 0.5 \ 14 | --class_uniform_tile 1024 \ 15 | --syncbn \ 16 | --sgd \ 17 | --crop_size 768 \ 18 | --scale_min 0.5 \ 19 | --scale_max 2.0 \ 20 | --rrotate 0 \ 21 | --color_aug 0.25 \ 22 | --gblur \ 23 | --max_iter 40000 \ 24 | --bs_mult 4 \ 25 | --hanet 1 1 1 1 0 \ 26 | --hanet_set 3 64 3 \ 27 | --hanet_pos 2 1 \ 28 | --pos_rfactor 8 \ 29 | --dropout 0.1 \ 30 | --pos_noise 0.5 \ 31 | --aux_loss \ 32 | --date 0104 \ 33 | --exp r101_os8_hanet_64_01 \ 34 | --ckpt ./logs/ \ 35 | --tb_path ./logs/ 36 | -------------------------------------------------------------------------------- /scripts/train_resnext_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on Cityscapes 3 | python -m torch.distributed.launch --nproc_per_node=6 train.py \ 4 | --dataset mapillary \ 5 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \ 6 | --city_mode 'train' \ 7 | --lr_schedule poly \ 8 | --lr 0.01 \ 9 | --poly_exp 0.9 \ 10 | --hanet_lr 0.01 \ 11 | --hanet_poly_exp 0.9 \ 12 | --max_cu_epoch 10000 \ 13 | --class_uniform_pct 0.5 \ 14 | --class_uniform_tile 1024 \ 15 | --syncbn \ 16 | --sgd \ 17 | --crop_size 864 \ 18 | --scale_min 0.5 \ 19 | --scale_max 2.0 \ 20 | --rrotate 0 \ 21 | --color_aug 0.25 \ 22 | --gblur \ 23 | --max_iter 200000 \ 24 | --bs_mult 2 \ 25 | --hanet 0 0 0 0 0 \ 26 | --hanet_set 3 64 3 \ 27 | 
--hanet_pos 2 1 \ 28 | --no_pos_dataset \ 29 | --dropout 0.1 \ 30 | --pos_noise 0.5 \ 31 | --img_wt_loss \ 32 | --date 0101 \ 33 | --exp resnext_pretrain \ 34 | --ckpt ./logs/ \ 35 | --tb_path ./logs/ 36 | -------------------------------------------------------------------------------- /scripts/train_r101_os8_hanet_best.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on Cityscapes 3 | python -m torch.distributed.launch --nproc_per_node=4 train.py \ 4 | --dataset cityscapes \ 5 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \ 6 | --city_mode 'trainval' \ 7 | --lr_schedule poly \ 8 | --lr 0.01 \ 9 | --poly_exp 0.9 \ 10 | --hanet_lr 0.01 \ 11 | --hanet_poly_exp 0.9 \ 12 | --max_cu_epoch 10000 \ 13 | --class_uniform_pct 0.5 \ 14 | --class_uniform_tile 1024 \ 15 | --syncbn \ 16 | --sgd \ 17 | --crop_size 864 \ 18 | --scale_min 0.5 \ 19 | --scale_max 2.0 \ 20 | --rrotate 0 \ 21 | --color_aug 0.25 \ 22 | --gblur \ 23 | --max_iter 90000 \ 24 | --bs_mult 3 \ 25 | --hanet 1 1 1 1 0 \ 26 | --hanet_set 3 64 3 \ 27 | --hanet_pos 2 1 \ 28 | --pos_rfactor 8 \ 29 | --dropout 0.1 \ 30 | --pos_noise 0.5 \ 31 | --aux_loss \ 32 | --cls_wt_loss \ 33 | --date 0101 \ 34 | --exp r101_os8_hanet_best \ 35 | --ckpt ./logs/ \ 36 | --tb_path ./logs/ 37 | -------------------------------------------------------------------------------- /scripts/train_resnext_fr_pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Example on Cityscapes 3 | python -m torch.distributed.launch --nproc_per_node=6 train.py \ 4 | --dataset cityscapes \ 5 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \ 6 | --snapshot pretrained/resnext_mapillary_0.47475.pth \ 7 | --city_mode 'trainval' \ 8 | --lr_schedule poly \ 9 | --lr 0.01 \ 10 | --poly_exp 0.9 \ 11 | --hanet_lr 0.01 \ 12 | --hanet_poly_exp 0.9 \ 13 | --max_cu_epoch 10000 \ 14 | --class_uniform_pct 0.5 \ 15 | --class_uniform_tile 1024 \ 16 | --coarse_boost_classes 14,15,16,3,12,17,4 \ 17 | --syncbn \ 18 | --sgd \ 19 | --crop_size 864 \ 20 | --scale_min 0.5 \ 21 | --scale_max 2.0 \ 22 | --rrotate 0 \ 23 | --color_aug 0.25 \ 24 | --gblur \ 25 | --max_iter 90000 \ 26 | --bs_mult 2 \ 27 | --hanet 1 1 1 1 0 \ 28 | --hanet_set 3 64 3 \ 29 | --hanet_pos 2 1 \ 30 | --pos_rfactor 8 \ 31 | --dropout 0.1 \ 32 | --pos_noise 0.5 \ 33 | --aux_loss \ 34 | --cls_wt_loss \ 35 | --date 0101 \ 36 | --exp resnext_from_pretrain \ 37 | --ckpt ./logs/ \ 38 | --tb_path ./logs/ 39 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network Initializations 3 | """ 4 | 5 | import logging 6 | import importlib 7 | import torch 8 | 9 | 10 | 11 | def get_net(args, criterion, criterion_aux=None): 12 | """ 13 | Get Network Architecture based on arguments provided 14 | """ 15 | net = get_model(args=args, num_classes=args.dataset_cls.num_classes, 16 | criterion=criterion, criterion_aux=criterion_aux) 17 | num_params = sum([param.nelement() for param in net.parameters()]) 18 | logging.info('Model params = {:2.3f}M'.format(num_params / 1000000)) 19 | 20 | net = net.cuda() 21 | return net 22 | 23 | 24 | def warp_network_in_dataparallel(net, gpuid): 25 | """ 26 | Wrap the network in Dataparallel 27 | """ 28 | # torch.cuda.set_device(gpuid) 29 | # net.cuda(gpuid) 30 | net = torch.nn.parallel.DistributedDataParallel(net, 
device_ids=[gpuid], find_unused_parameters=True) 31 | # net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid])#, find_unused_parameters=True) 32 | return net 33 | 34 | 35 | def get_model(args, num_classes, criterion, criterion_aux=None): 36 | """ 37 | Fetch Network Function Pointer 38 | """ 39 | network = args.arch 40 | module = network[:network.rfind('.')] 41 | model = network[network.rfind('.') + 1:] 42 | mod = importlib.import_module(module) 43 | net_func = getattr(mod, model) 44 | net = net_func(args=args, num_classes=num_classes, criterion=criterion, criterion_aux=criterion_aux) 45 | return net 46 | -------------------------------------------------------------------------------- /utils/attr_dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | 30 | class AttrDict(dict): 31 | 32 | IMMUTABLE = '__immutable__' 33 | 34 | def __init__(self, *args, **kwargs): 35 | super(AttrDict, self).__init__(*args, **kwargs) 36 | self.__dict__[AttrDict.IMMUTABLE] = False 37 | 38 | def __getattr__(self, name): 39 | if name in self.__dict__: 40 | return self.__dict__[name] 41 | elif name in self: 42 | return self[name] 43 | else: 44 | raise AttributeError(name) 45 | 46 | def __setattr__(self, name, value): 47 | if not self.__dict__[AttrDict.IMMUTABLE]: 48 | if name in self.__dict__: 49 | self.__dict__[name] = value 50 | else: 51 | self[name] = value 52 | else: 53 | raise AttributeError( 54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 55 | format(name, value) 56 | ) 57 | 58 | def immutable(self, is_immutable): 59 | """Set immutability to is_immutable and recursively apply the setting 60 | to all nested AttrDicts. 
61 | """ 62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 63 | # Recursively set immutable state 64 | for v in self.__dict__.values(): 65 | if isinstance(v, AttrDict): 66 | v.immutable(is_immutable) 67 | for v in self.values(): 68 | if isinstance(v, AttrDict): 69 | v.immutable(is_immutable) 70 | 71 | def is_immutable(self): 72 | return self.__dict__[AttrDict.IMMUTABLE] 73 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | ############################################################################## 30 | #Config 31 | ############################################################################## 32 | 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | from __future__ import unicode_literals 38 | 39 | 40 | import torch 41 | 42 | 43 | from utils.attr_dict import AttrDict 44 | 45 | 46 | __C = AttrDict() 47 | cfg = __C 48 | __C.ITER = 0 49 | __C.EPOCH = 0 50 | # Use Class Uniform Sampling to give each class proper sampling 51 | __C.CLASS_UNIFORM_PCT = 0.0 52 | 53 | # Use class weighted loss per batch to increase loss for low pixel count classes per batch 54 | __C.BATCH_WEIGHTING = False 55 | 56 | # Border Relaxation Count 57 | __C.BORDER_WINDOW = 1 58 | # Number of epoch to use before turn off border restriction 59 | __C.REDUCE_BORDER_ITER = -1 60 | __C.REDUCE_BORDER_EPOCH = -1 61 | # Comma Seperated List of class id to relax 62 | __C.STRICTBORDERCLASS = None 63 | 64 | 65 | 66 | #Attribute Dictionary for Dataset 67 | __C.DATASET = AttrDict() 68 | #Cityscapes Dir Location 69 | __C.DATASET.CITYSCAPES_DIR = '/home/nas_datasets/segmentation/cityscapes' 70 | #SDC Augmented Cityscapes Dir Location 71 | __C.DATASET.CITYSCAPES_AUG_DIR = '' 72 | #Mapillary Dataset Dir Location 73 | __C.DATASET.MAPILLARY_DIR = '/home/nas_datasets/segmentation/mapillary' 74 | #GTAV, BDD100K Dataset Dir Location 75 | __C.DATASET.GTAV_DIR = '/home/nas_datasets/segmentation/gtav' 76 | __C.DATASET.BDD_DIR = '/home/nas_datasets/segmentation/bdd100k/seg' 77 | #Kitti Dataset Dir Location 78 | __C.DATASET.KITTI_DIR = '' 79 | #SDC Augmented Kitti Dataset Dir Location 80 | __C.DATASET.KITTI_AUG_DIR = '' 81 | #Camvid Dataset Dir Location 82 | 
__C.DATASET.CAMVID_DIR = '/home/nas_datasets/segmentation/SegNet-Tutorial/CamVid' 83 | #Number of splits to support 84 | __C.DATASET.CV_SPLITS = 3 85 | 86 | 87 | __C.MODEL = AttrDict() 88 | __C.MODEL.BN = 'pytorch-syncnorm' 89 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm 90 | 91 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True): 92 | """Call this function in your script after you have finished setting all cfg 93 | values that are necessary (e.g., merging a config from a file, merging 94 | command line config options, etc.). By default, this function will also 95 | mark the global cfg as immutable to prevent changing the global cfg settings 96 | during script execution (which can lead to hard to debug errors or code 97 | that's harder to understand than is necessary). 98 | """ 99 | 100 | if hasattr(args, 'syncbn') and args.syncbn: 101 | __C.MODEL.BN = 'pytorch-syncnorm' 102 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm 103 | print('Using pytorch sync batch norm') 104 | else: 105 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 106 | print('Using regular batch norm') 107 | 108 | if not train_mode: 109 | cfg.immutable(True) 110 | return 111 | if args.class_uniform_pct: 112 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct 113 | 114 | if args.batch_weighting: 115 | __C.BATCH_WEIGHTING = True 116 | 117 | if args.jointwtborder: 118 | if args.strict_bdr_cls != '': 119 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")] 120 | if args.rlx_off_iter > -1: 121 | __C.REDUCE_BORDER_ITER = args.rlx_off_iter 122 | 123 | if make_immutable: 124 | cfg.immutable(True) 125 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | 37 | 38 | import math 39 | import torch 40 | from torch.distributed import get_world_size, get_rank 41 | from torch.utils.data import Sampler 42 | 43 | class DistributedSampler(Sampler): 44 | """Sampler that restricts data loading to a subset of the dataset. 45 | 46 | It is especially useful in conjunction with 47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 48 | process can pass a DistributedSampler instance as a DataLoader sampler, 49 | and load a subset of the original dataset that is exclusive to it. 50 | 51 | .. note:: 52 | Dataset is assumed to be of constant size. 53 | 54 | Arguments: 55 | dataset: Dataset used for sampling. 56 | num_replicas (optional): Number of processes participating in 57 | distributed training. 58 | rank (optional): Rank of the current process within num_replicas. 59 | """ 60 | 61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None): 62 | if num_replicas is None: 63 | num_replicas = get_world_size() 64 | if rank is None: 65 | rank = get_rank() 66 | self.dataset = dataset 67 | self.num_replicas = num_replicas 68 | self.rank = rank 69 | self.epoch = 0 70 | self.consecutive_sample = consecutive_sample 71 | self.permutation = permutation 72 | if pad: 73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | else: 75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 76 | self.total_size = self.num_samples * self.num_replicas 77 | 78 | def __iter__(self): 79 | # deterministically shuffle based on epoch 80 | g = torch.Generator() 81 | g.manual_seed(self.epoch) 82 | 83 | if self.permutation: 84 | indices = list(torch.randperm(len(self.dataset), generator=g)) 85 | else: 86 | indices = list([x for x in range(len(self.dataset))]) 87 | 88 | # add extra samples to make it evenly divisible 89 | if self.total_size > len(indices): 90 | indices += indices[:(self.total_size - len(indices))] 91 | 92 | # subsample 93 | if self.consecutive_sample: 94 | offset = self.num_samples * self.rank 95 | indices = indices[offset:offset + self.num_samples] 96 | else: 97 | indices = indices[self.rank:self.total_size:self.num_replicas] 98 | assert len(indices) == self.num_samples 99 | 100 | return iter(indices) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | def set_epoch(self, epoch): 106 | self.epoch = epoch 107 | 108 | def set_num_samples(self): 109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 110 | self.total_size = self.num_samples * self.num_replicas -------------------------------------------------------------------------------- /network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | from config import cfg 7 | 8 | def 
Norm2d(in_channels): 9 | """ 10 | Custom Norm Function to allow flexible switching 11 | """ 12 | layer = getattr(cfg.MODEL, 'BNFUNC') 13 | normalization_layer = layer(in_channels) 14 | return normalization_layer 15 | 16 | 17 | def freeze_weights(*models): 18 | for model in models: 19 | for k in model.parameters(): 20 | k.requires_grad = False 21 | 22 | def unfreeze_weights(*models): 23 | for model in models: 24 | for k in model.parameters(): 25 | k.requires_grad = True 26 | 27 | def initialize_weights(*models): 28 | """ 29 | Initialize Model Weights 30 | """ 31 | for model in models: 32 | for module in model.modules(): 33 | if isinstance(module, (nn.Conv2d, nn.Linear)): 34 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 35 | if module.bias is not None: 36 | module.bias.data.zero_() 37 | elif isinstance(module, nn.Conv1d): 38 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 39 | if module.bias is not None: 40 | module.bias.data.zero_() 41 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or \ 42 | isinstance(module, nn.GroupNorm) or isinstance(module, nn.SyncBatchNorm): 43 | module.weight.data.fill_(1) 44 | module.bias.data.zero_() 45 | 46 | def initialize_embedding(*models): 47 | """ 48 | Initialize Model Weights 49 | """ 50 | for model in models: 51 | for module in model.modules(): 52 | if isinstance(module, nn.Embedding): 53 | module.weight.data.zero_() #original 54 | 55 | 56 | 57 | def Upsample(x, size): 58 | """ 59 | Wrapper Around the Upsample Call 60 | """ 61 | return nn.functional.interpolate(x, size=size, mode='bilinear', 62 | align_corners=True) 63 | 64 | def Zero_Masking(input_tensor, mask_org): 65 | output = input_tensor.clone() 66 | output.mul_(mask_org) 67 | return output 68 | 69 | def RandomPosZero_Masking(input_tensor, p=0.5): 70 | output = input_tensor.clone() 71 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 72 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), input_tensor.size(3)) 73 | noise_b.bernoulli_(1 - p) 74 | noise_b = noise_b.expand_as(input_tensor) 75 | output.mul_(noise_b) 76 | return output 77 | 78 | def RandomVal_Masking(input_tensor, mask_org): 79 | output = input_tensor.clone() 80 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), input_tensor.size(3)) 81 | mask = (mask_org==0).type(input_tensor.type()) 82 | mask = mask.expand_as(input_tensor) 83 | mask = torch.mul(mask, noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item())) 84 | mask_org = mask_org.expand_as(input_tensor) 85 | output.mul_(mask_org) 86 | output.add_(mask) 87 | return output 88 | 89 | def RandomPosVal_Masking(input_tensor, p=0.5): 90 | output = input_tensor.clone() 91 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 92 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), input_tensor.size(3)) 93 | mask = noise_b.bernoulli_(1 - p) 94 | mask = (mask==0).type(input_tensor.type()) 95 | mask = mask.expand_as(input_tensor) 96 | mask = torch.mul(mask, noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item())) 97 | noise_b = noise_b.expand_as(input_tensor) 98 | output.mul_(noise_b) 99 | output.add_(mask) 100 | return output 101 | 102 | def masking(input_tensor, p=0.5): 103 | output = input_tensor.clone() 104 | noise_b = 
input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 105 | noise_u = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3)) 106 | mask = noise_b.bernoulli_(1 - p) 107 | mask = (mask==0).type(input_tensor.type()) 108 | mask.mul_(noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item())) 109 | # mask.mul_(noise_u.uniform_(5, 10)) 110 | noise_b = noise_b.expand_as(input_tensor) 111 | mask = mask.expand_as(input_tensor) 112 | output.mul_(noise_b) 113 | output.add_(mask) 114 | return output 115 | -------------------------------------------------------------------------------- /network/HANet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from network.mynn import Norm2d, Upsample 6 | from network.PosEmbedding import PosEmbedding1D, PosEncoding1D 7 | 8 | 9 | class HANet_Conv(nn.Module): 10 | 11 | def __init__(self, in_channel, out_channel, kernel_size=3, r_factor=64, layer=3, pos_injection=2, is_encoding=1, 12 | pos_rfactor=8, pooling='mean', dropout_prob=0.0, pos_noise=0.0): 13 | super(HANet_Conv, self).__init__() 14 | 15 | self.pooling = pooling 16 | self.pos_injection = pos_injection 17 | self.layer = layer 18 | self.dropout_prob = dropout_prob 19 | self.sigmoid = nn.Sigmoid() 20 | 21 | if r_factor > 0: 22 | mid_1_channel = math.ceil(in_channel / r_factor) 23 | elif r_factor < 0: 24 | r_factor = r_factor * -1 25 | mid_1_channel = in_channel * r_factor 26 | 27 | if self.dropout_prob > 0: 28 | self.dropout = nn.Dropout2d(self.dropout_prob) 29 | 30 | self.attention_first = nn.Sequential( 31 | nn.Conv1d(in_channels=in_channel, out_channels=mid_1_channel, 32 | kernel_size=1, stride=1, padding=0, bias=False), 33 | Norm2d(mid_1_channel), 34 | nn.ReLU(inplace=True)) 35 | 36 | if layer == 2: 37 | self.attention_second = nn.Sequential( 38 | nn.Conv1d(in_channels=mid_1_channel, out_channels=out_channel, 39 | kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True)) 40 | elif layer == 3: 41 | mid_2_channel = (mid_1_channel * 2) 42 | self.attention_second = nn.Sequential( 43 | nn.Conv1d(in_channels=mid_1_channel, out_channels=mid_2_channel, 44 | kernel_size=3, stride=1, padding=1, bias=True), 45 | Norm2d(mid_2_channel), 46 | nn.ReLU(inplace=True)) 47 | self.attention_third = nn.Sequential( 48 | nn.Conv1d(in_channels=mid_2_channel, out_channels=out_channel, 49 | kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True)) 50 | 51 | if self.pooling == 'mean': 52 | #print("##### average pooling") 53 | self.rowpool = nn.AdaptiveAvgPool2d((128//pos_rfactor,1)) 54 | else: 55 | #print("##### max pooling") 56 | self.rowpool = nn.AdaptiveMaxPool2d((128//pos_rfactor,1)) 57 | 58 | if pos_rfactor > 0: 59 | if is_encoding == 0: 60 | if self.pos_injection == 1: 61 | self.pos_emb1d_1st = PosEmbedding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise) 62 | elif self.pos_injection == 2: 63 | self.pos_emb1d_2nd = PosEmbedding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise) 64 | elif is_encoding == 1: 65 | if self.pos_injection == 1: 66 | self.pos_emb1d_1st = PosEncoding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise) 67 | elif self.pos_injection == 2: 68 | self.pos_emb1d_2nd = PosEncoding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise) 69 | else: 70 | print("Not supported position encoding") 71 | exit() 72 | 73 | 74 | def forward(self, x, out, 
pos=None, return_attention=False, return_posmap=False, attention_loss=False): 75 | """ 76 | inputs : 77 | x : input feature maps( B X C X W X H) 78 | returns : 79 | out : self attention value + input feature 80 | attention: B X N X N (N is Width*Height) 81 | """ 82 | H = out.size(2) 83 | x1d = self.rowpool(x).squeeze(3) 84 | 85 | if pos is not None and self.pos_injection == 1: 86 | if return_posmap: 87 | x1d, pos_map1 = self.pos_emb1d_1st(x1d, pos, True) 88 | else: 89 | x1d = self.pos_emb1d_1st(x1d, pos) 90 | 91 | if self.dropout_prob > 0: 92 | x1d = self.dropout(x1d) 93 | x1d = self.attention_first(x1d) 94 | 95 | if pos is not None and self.pos_injection == 2: 96 | if return_posmap: 97 | x1d, pos_map2 = self.pos_emb1d_2nd(x1d, pos, True) 98 | else: 99 | x1d = self.pos_emb1d_2nd(x1d, pos) 100 | 101 | x1d = self.attention_second(x1d) 102 | 103 | if self.layer == 3: 104 | x1d = self.attention_third(x1d) 105 | if attention_loss: 106 | last_attention = x1d 107 | x1d = self.sigmoid(x1d) 108 | else: 109 | if attention_loss: 110 | last_attention = x1d 111 | x1d = self.sigmoid(x1d) 112 | 113 | x1d = F.interpolate(x1d, size=H, mode='linear') 114 | out = torch.mul(out, x1d.unsqueeze(3)) 115 | 116 | if return_attention: 117 | if return_posmap: 118 | if self.pos_injection == 1: 119 | pos_map = (pos_map1) 120 | elif self.pos_injection == 2: 121 | pos_map = (pos_map2) 122 | return out, x1d, pos_map 123 | else: 124 | return out, x1d 125 | else: 126 | if attention_loss: 127 | return out, last_attention 128 | else: 129 | return out 130 | -------------------------------------------------------------------------------- /network/PosEmbedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from network.mynn import Norm2d, Upsample, initialize_embedding 5 | import numpy as np 6 | 7 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 8 | ''' Sinusoid position encoding table ''' 9 | def cal_angle(position, hid_idx): 10 | if d_hid > 50: 11 | cycle = 10 12 | elif d_hid > 5: 13 | cycle = 100 14 | else: 15 | cycle = 10000 16 | cycle = 10 if d_hid > 50 else 100 17 | return position / np.power(cycle, 2 * (hid_idx // 2) / d_hid) 18 | def get_posi_angle_vec(position): 19 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 20 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 21 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 22 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 23 | if padding_idx is not None: 24 | # zero vector for padding dimension 25 | sinusoid_table[padding_idx] = 0. 
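# Note: the helpers above fill entry (pos, j) with pos / cycle ** (2 * (j // 2) / d_hid),
# where the second assignment leaves cycle = 10 when d_hid > 50 and 100 otherwise; even
# columns are then passed through sin() and odd columns through cos(). As a quick check,
# with d_hid = 2 (so cycle = 100) the row for position p reduces to [sin(p), cos(p)].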
26 | return torch.FloatTensor(sinusoid_table) 27 | 28 | class PosEmbedding2D(nn.Module): 29 | 30 | def __init__(self, pos_rfactor, dim): 31 | super(PosEmbedding2D, self).__init__() 32 | 33 | self.pos_layer_h = nn.Embedding((128//pos_rfactor)+1, dim) 34 | self.pos_layer_w = nn.Embedding((128//pos_rfactor)+1, dim) 35 | initialize_embedding(self.pos_layer_h) 36 | initialize_embedding(self.pos_layer_w) 37 | 38 | def forward(self, x, pos): 39 | pos_h, pos_w = pos 40 | pos_h = pos_h.unsqueeze(1) 41 | pos_w = pos_w.unsqueeze(1) 42 | pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2:], mode='nearest').long() # B X 1 X H X W 43 | pos_w = nn.functional.interpolate(pos_w.float(), size=x.shape[2:], mode='nearest').long() # B X 1 X H X W 44 | pos_h = self.pos_layer_h(pos_h).transpose(1,4).squeeze(4) # B X 1 X H X W X C 45 | pos_w = self.pos_layer_w(pos_w).transpose(1,4).squeeze(4) # B X 1 X H X W X C 46 | x = x + pos_h + pos_w 47 | return x 48 | 49 | class PosEncoding1D(nn.Module): 50 | 51 | def __init__(self, pos_rfactor, dim, pos_noise=0.0): 52 | super(PosEncoding1D, self).__init__() 53 | print("use PosEncoding1D") 54 | self.sel_index = torch.tensor([0]).cuda() 55 | pos_enc = (get_sinusoid_encoding_table((128//pos_rfactor)+1, dim) + 1) 56 | self.pos_layer = nn.Embedding.from_pretrained(embeddings=pos_enc, freeze=True) 57 | self.pos_noise = pos_noise 58 | self.noise_clamp = 16 // pos_rfactor # 4: 4, 8: 2, 16: 1 59 | 60 | self.pos_rfactor = pos_rfactor 61 | if pos_noise > 0.0: 62 | self.min = 0.0 #torch.tensor([0]).cuda() 63 | self.max = 128//pos_rfactor #torch.tensor([128//pos_rfactor]).cuda() 64 | self.noise = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([pos_noise])) 65 | 66 | def forward(self, x, pos, return_posmap=False): 67 | pos_h, _ = pos # B X H X W 68 | pos_h = pos_h//self.pos_rfactor 69 | pos_h = pos_h.index_select(2, self.sel_index).unsqueeze(1).squeeze(3) # B X 1 X H 70 | pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2], mode='nearest').long() # B X 1 X 48 71 | 72 | if self.training is True and self.pos_noise > 0.0: 73 | #pos_h = pos_h + (self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long() 74 | pos_h = pos_h + torch.clamp((self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long(), 75 | min=-self.noise_clamp, max=self.noise_clamp) 76 | pos_h = torch.clamp(pos_h, min=self.min, max=self.max) 77 | #pos_h = torch.where(pos_h < self.min_tensor, self.min_tensor, pos_h) 78 | #pos_h = torch.where(pos_h > self.max_tensor, self.max_tensor, pos_h) 79 | 80 | pos_h = self.pos_layer(pos_h).transpose(1,3).squeeze(3) # B X 1 X 48 X 80 > B X 80 X 48 X 1 81 | x = x + pos_h 82 | if return_posmap: 83 | return x, self.pos_layer.weight # 33 X 80 84 | return x 85 | 86 | class PosEmbedding1D(nn.Module): 87 | 88 | def __init__(self, pos_rfactor, dim, pos_noise=0.0): 89 | super(PosEmbedding1D, self).__init__() 90 | print("use PosEmbedding1D") 91 | self.sel_index = torch.tensor([0]).cuda() 92 | self.pos_layer = nn.Embedding((128//pos_rfactor)+1, dim) 93 | initialize_embedding(self.pos_layer) 94 | self.pos_noise = pos_noise 95 | self.pos_rfactor = pos_rfactor 96 | self.noise_clamp = 16 // pos_rfactor # 4: 4, 8: 2, 16: 1 97 | 98 | if pos_noise > 0.0: 99 | self.min = 0.0 #torch.tensor([0]).cuda() 100 | self.max = 128//pos_rfactor #torch.tensor([128//pos_rfactor]).cuda() 101 | self.noise = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([pos_noise])) 102 | 103 | def forward(self, x, pos, return_posmap=False): 104 | pos_h, _ = pos # B X H X 
W 105 | pos_h = pos_h//self.pos_rfactor 106 | pos_h = pos_h.index_select(2, self.sel_index).unsqueeze(1).squeeze(3) # B X 1 X H 107 | pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2], mode='nearest').long() # B X 1 X 48 108 | 109 | if self.training is True and self.pos_noise > 0.0: 110 | #pos_h = pos_h + (self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long() 111 | pos_h = pos_h + torch.clamp((self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long(), 112 | min=-self.noise_clamp, max=self.noise_clamp) 113 | pos_h = torch.clamp(pos_h, min=self.min, max=self.max) 114 | 115 | pos_h = self.pos_layer(pos_h).transpose(1,3).squeeze(3) # B X 1 X 48 X 80 > B X 80 X 48 X 1 116 | x = x + pos_h 117 | if return_posmap: 118 | return x, self.pos_layer.weight # 33 X 80 119 | return x 120 | -------------------------------------------------------------------------------- /datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mapillary Dataset Loader 3 | """ 4 | from PIL import Image 5 | from torch.utils import data 6 | import os 7 | import numpy as np 8 | import json 9 | import datasets.uniform as uniform 10 | from config import cfg 11 | 12 | num_classes = 65 13 | ignore_label = 65 14 | root = cfg.DATASET.MAPILLARY_DIR 15 | config_fn = os.path.join(root, 'config.json') 16 | id_to_ignore_or_group = {} 17 | color_mapping = [] 18 | id_to_trainid = {} 19 | 20 | 21 | def colorize_mask(image_array): 22 | """ 23 | Colorize a segmentation mask 24 | """ 25 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 26 | new_mask.putpalette(color_mapping) 27 | return new_mask 28 | 29 | 30 | def make_dataset(quality, mode): 31 | """ 32 | Create File List 33 | """ 34 | assert (quality == 'semantic' and mode in ['train', 'val']) 35 | img_dir_name = None 36 | if quality == 'semantic': 37 | if mode == 'train': 38 | img_dir_name = 'training' 39 | if mode == 'val': 40 | img_dir_name = 'validation' 41 | mask_path = os.path.join(root, img_dir_name, 'labels') 42 | else: 43 | raise BaseException("Instance Segmentation Not support") 44 | 45 | img_path = os.path.join(root, img_dir_name, 'images') 46 | print(img_path) 47 | if quality != 'video': 48 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)]) 49 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)]) 50 | assert imgs == msks 51 | 52 | items = [] 53 | c_items = os.listdir(img_path) 54 | if '.DS_Store' in c_items: 55 | c_items.remove('.DS_Store') 56 | 57 | for it in c_items: 58 | if quality == 'video': 59 | item = (os.path.join(img_path, it), os.path.join(img_path, it)) 60 | else: 61 | item = (os.path.join(img_path, it), 62 | os.path.join(mask_path, it.replace(".jpg", ".png"))) 63 | items.append(item) 64 | return items 65 | 66 | 67 | def gen_colormap(): 68 | """ 69 | Get Color Map from file 70 | """ 71 | global color_mapping 72 | 73 | # load mapillary config 74 | with open(config_fn) as config_file: 75 | config = json.load(config_file) 76 | config_labels = config['labels'] 77 | 78 | # calculate label color mapping 79 | colormap = [] 80 | id2name = {} 81 | for i in range(0, len(config_labels)): 82 | colormap = colormap + config_labels[i]['color'] 83 | id2name[i] = config_labels[i]['readable'] 84 | color_mapping = colormap 85 | return id2name 86 | 87 | 88 | class Mapillary(data.Dataset): 89 | def __init__(self, quality, mode, joint_transform_list=None, 90 | transform=None, target_transform=None, dump_images=False, 91 | class_uniform_pct=0, 
class_uniform_tile=768, test=False): 92 | """ 93 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform. 94 | 0.0 means fully random. 95 | class_uniform_tile_size = Class uniform tile size 96 | """ 97 | self.quality = quality 98 | self.mode = mode 99 | self.joint_transform_list = joint_transform_list 100 | self.transform = transform 101 | self.target_transform = target_transform 102 | self.dump_images = dump_images 103 | self.class_uniform_pct = class_uniform_pct 104 | self.class_uniform_tile = class_uniform_tile 105 | self.id2name = gen_colormap() 106 | self.imgs_uniform = None 107 | for i in range(num_classes): 108 | id_to_trainid[i] = i 109 | 110 | # find all images 111 | self.imgs = make_dataset(quality, mode) 112 | if len(self.imgs) == 0: 113 | raise RuntimeError('Found 0 images, please check the data set') 114 | if test: 115 | np.random.shuffle(self.imgs) 116 | self.imgs = self.imgs[:200] 117 | 118 | if self.class_uniform_pct: 119 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile) 120 | if os.path.isfile(json_fn): 121 | with open(json_fn, 'r') as json_data: 122 | centroids = json.load(json_data) 123 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 124 | else: 125 | # centroids is a dict (indexed by class) of lists of centroids 126 | self.centroids = uniform.class_centroids_all( 127 | self.imgs, 128 | num_classes, 129 | id2trainid=None, 130 | tile_size=self.class_uniform_tile) 131 | with open(json_fn, 'w') as outfile: 132 | json.dump(self.centroids, outfile, indent=4) 133 | else: 134 | self.centroids = [] 135 | self.build_epoch() 136 | 137 | def build_epoch(self): 138 | if self.class_uniform_pct != 0: 139 | self.imgs_uniform = uniform.build_epoch(self.imgs, 140 | self.centroids, 141 | num_classes, 142 | self.class_uniform_pct) 143 | else: 144 | self.imgs_uniform = self.imgs 145 | 146 | def __getitem__(self, index): 147 | if len(self.imgs_uniform[index]) == 2: 148 | img_path, mask_path = self.imgs_uniform[index] 149 | centroid = None 150 | class_id = None 151 | else: 152 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index] 153 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 154 | img_name = os.path.splitext(os.path.basename(img_path))[0] 155 | 156 | mask = np.array(mask) 157 | mask_copy = mask.copy() 158 | for k, v in id_to_ignore_or_group.items(): 159 | mask_copy[mask == k] = v 160 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 161 | 162 | # Image Transformations 163 | if self.joint_transform_list is not None: 164 | for idx, xform in enumerate(self.joint_transform_list): 165 | if idx == 0 and centroid is not None: 166 | # HACK! 
Assume the first transform accepts a centroid 167 | img, mask = xform(img, mask, centroid) 168 | else: 169 | img, mask = xform(img, mask) 170 | 171 | if self.dump_images: 172 | outdir = 'dump_imgs_{}'.format(self.mode) 173 | os.makedirs(outdir, exist_ok=True) 174 | if centroid is not None: 175 | dump_img_name = self.id2name[class_id] + '_' + img_name 176 | else: 177 | dump_img_name = img_name 178 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 179 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 180 | mask_img = colorize_mask(np.array(mask)) 181 | img.save(out_img_fn) 182 | mask_img.save(out_msk_fn) 183 | 184 | if self.transform is not None: 185 | img = self.transform(img) 186 | if self.target_transform is not None: 187 | mask = self.target_transform(mask) 188 | return img, mask, img_name 189 | 190 | def __len__(self): 191 | return len(self.imgs_uniform) 192 | 193 | def calculate_weights(self): 194 | raise BaseException("not supported yet") 195 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## HANet: Official Project Webpage 2 | HANet is an add-on module for urban-scene segmentation to exploit the structural priors existing in urban scenes. It is effective and widely applicable! 3 | 4 |

5 | 6 |

7 | 8 | This repository provides the official PyTorch implementation of the following paper: 9 | > **Cars Can’t Fly up in the Sky:** Improving Urban-Scene Segmentation via Height-driven Attention Networks
10 | > Sungha Choi (LGE, Korea Univ.), Joanne T. Kim (LLNL, Korea Univ.), Jaegul Choo (KAIST)
11 | > In CVPR 2020
12 | 13 | > Paper : [[pdf]](http://openaccess.thecvf.com/content_CVPR_2020/papers/Choi_Cars_Cant_Fly_Up_in_the_Sky_Improving_Urban-Scene_Segmentation_CVPR_2020_paper.pdf) [[supp]](http://openaccess.thecvf.com/content_CVPR_2020/supplemental/Choi_Cars_Cant_Fly_CVPR_2020_supplemental.pdf)
14 | 15 |

16 | 17 |

18 | 19 | > **Abstract:** *This paper exploits the intrinsic features of urban-scene images and proposes a general add-on module, called height driven attention networks (HANet), for improving semantic segmentation for urban-scene images. It emphasizes informative features or classes selectively according to the vertical position of a pixel. The pixel-wise class distributions are significantly different from each other among horizontally segmented sections in the urban-scene images. Likewise, urban-scene images have their own distinct characteristics, but most semantic segmentation networks do not reflect such unique attributes in the architecture. The proposed network architecture incorporates the capability exploiting the attributes to handle the urban scene dataset effectively. We validate the consistent performance (mIoU) increase of various semantic segmentation models on two datasets when HANet is adopted. This extensive quantitative analysis demonstrates that adding our module to existing models is easy and cost-effective. Our method achieves a new state-of-the-art performance on the Cityscapes benchmark with a large margin among ResNet-101 based segmentation models. Also, we show that the proposed model is coherent with the facts observed in the urban scene by visualizing and interpreting the attention map*
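In short, HANet computes per-row channel attention: features are pooled along the width to obtain one descriptor per horizontal section, a small 1D convolutional stack turns that descriptor into per-row channel weights, and those weights rescale the rows of the target feature map. The minimal sketch below illustrates only this pooling, conv, sigmoid, and rescaling flow; the class name and hyperparameter values are illustrative, and the full module (with height positional encodings, dropout, and configurable depth) is `HANet_Conv` in `network/HANet.py` included in this repository.
```
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyHeightAttention(nn.Module):
    """Width-pool -> 1D convs -> sigmoid -> per-row channel weights."""
    def __init__(self, in_ch, out_ch, pooled_height=16, r_factor=8):
        super().__init__()
        self.rowpool = nn.AdaptiveAvgPool2d((pooled_height, 1))  # B x C x H' x 1
        mid = max(in_ch // r_factor, 1)
        self.conv1 = nn.Conv1d(in_ch, mid, kernel_size=1)
        self.conv2 = nn.Conv1d(mid, out_ch, kernel_size=3, padding=1)

    def forward(self, x, out):
        att = self.rowpool(x).squeeze(3)                  # B x C x H'
        att = torch.sigmoid(self.conv2(F.relu(self.conv1(att))))
        att = F.interpolate(att, size=out.size(2), mode='linear', align_corners=False)
        return out * att.unsqueeze(3)                     # rescale each row of 'out'

# x: lower-level features, out: feature map to be re-weighted row by row
x, out = torch.randn(2, 256, 96, 96), torch.randn(2, 512, 96, 96)
print(ToyHeightAttention(256, 512)(x, out).shape)         # torch.Size([2, 512, 96, 96])
```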
20 | 21 | ## Concept Video 22 | Click the figure to watch the YouTube video of our paper! 23 | 24 |

![Youtube Video](assets/youtube_capture_p.png)
26 |

27 | 28 | ## Cityscapes Benchmark 29 | | Models | Data | Crop Size | Batch Size | Output Stride | mIoU | External Link | 30 | |:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:| 31 | | HANet (ResNext-101) | Fine train/val + Coarse | 864X864 | 12 | 8 | 83.2% | [Benchmark](https://www.cityscapes-dataset.com/anonymous-results/?id=9a8b7333dcb66360b4f38ba00db7c84e7997f7203084bf6e92ca9bbabbc34640) | 32 | | HANet (ResNet-101) | Fine train/val | 864X864 | 12 | 8 | 82.1% | [Benchmark](https://www.cityscapes-dataset.com/anonymous-results/?id=f96818d678c67c82449323203d144e530fb66102a5b5a101f599a96cc62458e7) | 33 | | HANet (ResNet-101) | Fine train | 768X768 | 8 | 8 | 80.9% | [Benchmark](https://www.cityscapes-dataset.com/anonymous-results/?id=1e5e85818e439332fdae01037259706d9091be2b9fca850eb4a851805f5ed44d) | 34 | 35 | 36 | ## PyTorch Implementation 37 | ### Installation 38 | Clone this repository. 39 | ``` 40 | git clone https://github.com/shachoi/HANet.git 41 | cd HANet 42 | ``` 43 | Install the following packages. 44 | ``` 45 | conda create --name hanet python=3.6 46 | conda activate hanet 47 | conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.1 -c pytorch 48 | conda install scipy==1.4.1 49 | conda install tqdm==4.46.0 50 | conda install scikit-image==0.16.2 51 | pip install tensorboardX==2.0 52 | pip install thop 53 | ``` 54 | 55 | ### Datasets 56 | We evaluated HANet on [Cityscapes](https://www.cityscapes-dataset.com/) and [BDD-100K](https://bair.berkeley.edu/blog/2018/05/30/bdd/). 57 | 58 | For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/
59 | Unzip the files and make the directory structure as follows. 60 | ``` 61 | cityscapes 62 | └ leftImg8bit_trainvaltest 63 | └ leftImg8bit 64 | └ train 65 | └ val 66 | └ test 67 | └ gtFine_trainvaltest 68 | └ gtFine 69 | └ train 70 | └ val 71 | └ test 72 | ``` 73 | You should modify the path in **"/config.py"** according to your Cityscapes dataset path. 74 | 75 | ``` 76 | #Cityscapes Dir Location 77 | __C.DATASET.CITYSCAPES_DIR = 78 | ``` 79 | 80 | Additionally, you can use the Cityscapes coarse dataset to get the best mIoU score. 81 | 82 | Please refer to the training script **"/scripts/train_resnext_fr_pretrain.sh"**. 83 | 84 | The other training scripts don't use the Cityscapes coarse dataset. 85 | 86 | ### Pretrained Models 87 | #### All models trained for our paper 88 | You can download all models evaluated in our paper at [Google Drive](https://drive.google.com/drive/folders/1PfrG1d3fq4T9c96FmfOIkfKPShTTRz2G?usp=sharing) 89 | 90 | #### ImageNet pretrained ResNet-101 which has three 3×3 convolutions in the first layer 91 | To train ResNet-101 based HANet, you should download the ImageNet pretrained ResNet-101 from [this link](https://drive.google.com/file/d/1Sx1Clf9Q9BsXKklZuUIqSJJhjMNF3jAa/view?usp=sharing). Put it into the following directory. 92 | ``` 93 | /pretrained/resnet101-imagenet.pth 94 | ``` 95 | This pretrained model is from the [MIT CSAIL Computer Vision Group](http://sceneparsing.csail.mit.edu/). 96 | 97 | #### Mapillary pretrained ResNext-101 98 | You can finetune HANet from the Mapillary pretrained ResNext-101 using the training script **"/scripts/train_resnext_fr_pretrain.sh"**. 99 | Download it from [this link](https://drive.google.com/file/d/1GJ4VOSiLwNuyqOgRqQoe9FbvnklI2TYe/view?usp=sharing) and put it into the following directory. 100 | ``` 101 | /pretrained/resnext_mapillary_0.47475.pth 102 | ``` 103 | ### Training Networks 104 | According to the specification of your GPU system, you may need to modify the training script. 105 | ``` 106 | python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE train.py \ 107 | ... 108 | --bs_mult NUM_BATCH_PER_SINGLE_GPU \ 109 | ... 110 | ``` 111 | You can train HANet (based on ResNet-101) using the **finely annotated training and validation sets** with the following command. 112 | ``` 113 | $ CUDA_VISIBLE_DEVICES=0,1,2,3 ./scripts/train_r101_os8_hanet_best.sh 114 | ``` 115 | Otherwise, you can train HANet (based on ResNet-101) using only the **finely annotated training set** with the following command. 116 | ``` 117 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/train_r101_os8_hanet.sh 118 | ``` 119 | To run the script "train_r101_os8_hanet.sh", two Titan RTX GPUs (2 X 24GB GPU Memory) are required. 120 | 121 | Additionally, we provide various training scripts, such as MobileNet-based HANet. 122 | 123 | The results will be stored in **"/logs/"**. 124 | ### Inference 125 | ``` 126 | $ CUDA_VISIBLE_DEVICES=0 ./scripts/eval_r101_os8.sh 127 | ``` 128 | ### Submit to Cityscapes benchmark server 129 | ``` 130 | $ CUDA_VISIBLE_DEVICES=0 ./scripts/submit_r101_os8.sh 131 | ``` 132 | 133 | ## Citation 134 | If you find this work useful for your research, please cite our paper: 135 | ``` 136 | @InProceedings{Choi_2020_CVPR, 137 | author = {Choi, Sungha and Kim, Joanne T.
and Choo, Jaegul}, 138 | title = {Cars Can't Fly Up in the Sky: Improving Urban-Scene Segmentation via Height-Driven Attention Networks}, 139 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 140 | month = {June}, 141 | year = {2020} 142 | } 143 | ``` 144 | 145 | ## Acknowledgments 146 | Our pytorch implementation is heavily derived from [NVIDIA segmentation](https://github.com/NVIDIA/semantic-segmentation). 147 | Thanks to the NVIDIA implementations. 148 | -------------------------------------------------------------------------------- /utils/my_data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | # Code adapted from: 4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py 5 | # 6 | # BSD 3-Clause License 7 | # 8 | # Copyright (c) 2017, 9 | # All rights reserved. 10 | # 11 | # Redistribution and use in source and binary forms, with or without 12 | # modification, are permitted provided that the following conditions are met: 13 | # 14 | # * Redistributions of source code must retain the above copyright notice, this 15 | # list of conditions and the following disclaimer. 16 | # 17 | # * Redistributions in binary form must reproduce the above copyright notice, 18 | # this list of conditions and the following disclaimer in the documentation 19 | # and/or other materials provided with the distribution. 20 | # 21 | # * Neither the name of the copyright holder nor the names of its 22 | # contributors may be used to endorse or promote products derived from 23 | # this software without specific prior written permission. 24 | # 25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.s 35 | """ 36 | 37 | 38 | import operator 39 | import torch 40 | import warnings 41 | from torch.nn.modules import Module 42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 43 | from torch.nn.parallel.replicate import replicate 44 | from torch.nn.parallel.parallel_apply import parallel_apply 45 | 46 | 47 | def _check_balance(device_ids): 48 | imbalance_warn = """ 49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which 50 | has less than 75% of the memory or cores of GPU {}. 
You can do so by setting
51 |     the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
52 |     environment variable."""
53 | 
54 |     dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
55 | 
56 |     def warn_imbalance(get_prop):
57 |         values = [get_prop(props) for props in dev_props]
58 |         min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
59 |         max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
60 |         if min_val / max_val < 0.75:
61 |             warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
62 |             return True
63 |         return False
64 | 
65 |     if warn_imbalance(lambda props: props.total_memory):
66 |         return
67 |     if warn_imbalance(lambda props: props.multi_processor_count):
68 |         return
69 | 
70 | 
71 | 
72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather=True):
73 |     """
74 |     Evaluates module(input) in parallel across the GPUs given in device_ids.
75 |     This is the functional version of the DataParallel module.
76 |     Args:
77 |         module: the module to evaluate in parallel
78 |         inputs: inputs to the module
79 |         device_ids: GPU ids on which to replicate module
80 |         output_device: GPU location of the output. Use -1 to indicate the CPU.
81 |             (default: device_ids[0])
82 |     Returns:
83 |         a Tensor containing the result of module(input) located on
84 |         output_device
85 |     """
86 |     if not isinstance(inputs, tuple):
87 |         inputs = (inputs,)
88 | 
89 |     if device_ids is None:
90 |         device_ids = list(range(torch.cuda.device_count()))
91 | 
92 |     if output_device is None:
93 |         output_device = device_ids[0]
94 | 
95 |     inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
96 |     if len(device_ids) == 1:
97 |         return module(*inputs[0], **module_kwargs[0])
98 |     used_device_ids = device_ids[:len(inputs)]
99 |     replicas = replicate(module, used_device_ids)
100 |     outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
101 |     if gather:
102 |         return torch.nn.parallel.gather(outputs, output_device, dim)  # the boolean `gather` argument shadows the imported gather helper, so call it via its module path
103 |     else:
104 |         return outputs
105 | 
106 | 
107 | 
108 | class MyDataParallel(Module):
109 |     """
110 |     Implements data parallelism at the module level.
111 |     This container parallelizes the application of the given module by
112 |     splitting the input across the specified devices by chunking in the batch
113 |     dimension. In the forward pass, the module is replicated on each device,
114 |     and each replica handles a portion of the input. During the backwards
115 |     pass, gradients from each replica are summed into the original module.
116 |     The batch size should be larger than the number of GPUs used.
117 |     See also: :ref:`cuda-nn-dataparallel-instead`
118 |     Arbitrary positional and keyword inputs are allowed to be passed into
119 |     DataParallel EXCEPT Tensors. All tensors will be scattered on dim
120 |     specified (default 0). Primitive types will be broadcasted, but all
121 |     other types will be a shallow copy and can be corrupted if written to in
122 |     the model's forward pass.
123 |     .. warning::
124 |         Forward and backward hooks defined on :attr:`module` and its submodules
125 |         will be invoked ``len(device_ids)`` times, each with inputs located on
126 |         a particular device. Particularly, the hooks are only guaranteed to be
127 |         executed in correct order with respect to operations on corresponding
128 |         devices.
For example, it is not guaranteed that hooks set via 129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before 130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but 131 | that each such hook be executed before the corresponding 132 | :meth:`~torch.nn.Module.forward` call of that device. 133 | .. warning:: 134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in 135 | :func:`forward`, this wrapper will return a vector of length equal to 136 | number of devices used in data parallelism, containing the result from 137 | each device. 138 | .. note:: 139 | There is a subtlety in using the 140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a 141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. 142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for 143 | details. 144 | Args: 145 | module: module to be parallelized 146 | device_ids: CUDA devices (default: all devices) 147 | output_device: device location of output (default: device_ids[0]) 148 | Attributes: 149 | module (Module): the module to be parallelized 150 | Example:: 151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 152 | >>> output = net(input_var) 153 | """ 154 | 155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well 156 | 157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True): 158 | super(MyDataParallel, self).__init__() 159 | 160 | if not torch.cuda.is_available(): 161 | self.module = module 162 | self.device_ids = [] 163 | return 164 | 165 | if device_ids is None: 166 | device_ids = list(range(torch.cuda.device_count())) 167 | if output_device is None: 168 | output_device = device_ids[0] 169 | self.dim = dim 170 | self.module = module 171 | self.device_ids = device_ids 172 | self.output_device = output_device 173 | self.gather_bool = gather 174 | 175 | _check_balance(self.device_ids) 176 | 177 | if len(self.device_ids) == 1: 178 | self.module.cuda(device_ids[0]) 179 | 180 | def forward(self, *inputs, **kwargs): 181 | if not self.device_ids: 182 | return self.module(*inputs, **kwargs) 183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 184 | if len(self.device_ids) == 1: 185 | return [self.module(*inputs[0], **kwargs[0])] 186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 187 | outputs = self.parallel_apply(replicas, inputs, kwargs) 188 | if self.gather_bool: 189 | return self.gather(outputs, self.output_device) 190 | else: 191 | return outputs 192 | 193 | def replicate(self, module, device_ids): 194 | return replicate(module, device_ids) 195 | 196 | def scatter(self, inputs, kwargs, device_ids): 197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 198 | 199 | def parallel_apply(self, replicas, inputs, kwargs): 200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 201 | 202 | def gather(self, outputs, output_device): 203 | return gather(outputs, output_device, dim=self.dim) 204 | 205 | -------------------------------------------------------------------------------- /datasets/kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | KITTI Dataset Loader 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | from PIL import Image 9 | from torch.utils import data 10 | import logging 11 | import datasets.uniform as uniform 12 | import datasets.cityscapes_labels as cityscapes_labels 13 | import json 14 | 
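# Dataset root paths (KITTI_DIR, KITTI_AUG_DIR) are read from the global config object; see config.py.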
from config import cfg 15 | 16 | 17 | trainid_to_name = cityscapes_labels.trainId2name 18 | id_to_trainid = cityscapes_labels.label2trainid 19 | num_classes = 19 20 | ignore_label = 255 21 | root = cfg.DATASET.KITTI_DIR 22 | aug_root = cfg.DATASET.KITTI_AUG_DIR 23 | 24 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 25 | 153, 153, 153, 250, 170, 30, 26 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 27 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 28 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 29 | zero_pad = 256 * 3 - len(palette) 30 | for i in range(zero_pad): 31 | palette.append(0) 32 | 33 | def colorize_mask(mask): 34 | # mask: numpy array of the mask 35 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 36 | new_mask.putpalette(palette) 37 | return new_mask 38 | 39 | def get_train_val(cv_split, all_items): 40 | 41 | # 90/10 train/val split, three random splits 42 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198] 43 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197] 44 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199] 45 | 46 | train_set = [] 47 | val_set = [] 48 | 49 | if cv_split == 0: 50 | for i in range(200): 51 | if i in val_0: 52 | val_set.append(all_items[i]) 53 | else: 54 | train_set.append(all_items[i]) 55 | elif cv_split == 1: 56 | for i in range(200): 57 | if i in val_1: 58 | val_set.append(all_items[i]) 59 | else: 60 | train_set.append(all_items[i]) 61 | elif cv_split == 2: 62 | for i in range(200): 63 | if i in val_2: 64 | val_set.append(all_items[i]) 65 | else: 66 | train_set.append(all_items[i]) 67 | else: 68 | logging.info('Unknown cv_split {}'.format(cv_split)) 69 | sys.exit() 70 | 71 | return train_set, val_set 72 | 73 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0): 74 | 75 | items = [] 76 | all_items = [] 77 | aug_items = [] 78 | 79 | assert quality == 'semantic' 80 | assert mode in ['train', 'val', 'trainval'] 81 | # note that train and val are randomly determined, no official split 82 | 83 | img_dir_name = "training" 84 | img_path = os.path.join(root, img_dir_name, 'image_2') 85 | mask_path = os.path.join(root, img_dir_name, 'semantic') 86 | 87 | c_items = os.listdir(img_path) 88 | c_items.sort() 89 | 90 | for it in c_items: 91 | item = (os.path.join(img_path, it), os.path.join(mask_path, it)) 92 | all_items.append(item) 93 | logging.info('KITTI has a total of {} images'.format(len(all_items))) 94 | 95 | # split into train/val 96 | train_set, val_set = get_train_val(cv_split, all_items) 97 | 98 | if mode == 'train': 99 | items = train_set 100 | elif mode == 'val': 101 | items = val_set 102 | elif mode == 'trainval': 103 | items = train_set + val_set 104 | else: 105 | logging.info('Unknown mode {}'.format(mode)) 106 | sys.exit() 107 | 108 | logging.info('KITTI-{}: {} images'.format(mode, len(items))) 109 | 110 | return items, aug_items 111 | 112 | class KITTI(data.Dataset): 113 | 114 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 115 | transform=None, target_transform=None, dump_images=False, 116 | class_uniform_pct=0, class_uniform_tile=0, test=False, 117 | cv_split=None, scf=None, hardnm=0): 118 | 119 | self.quality = quality 120 | self.mode = mode 121 | self.maxSkip = maxSkip 122 | self.joint_transform_list = joint_transform_list 123 | self.transform = transform 124 | self.target_transform = target_transform 125 | self.dump_images = dump_images 126 | 
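        # Class-uniform sampling settings: class_uniform_pct is the fraction of each epoch
        # drawn from per-class centroid crops, and class_uniform_tile is the tile size used
        # when computing those centroids (see datasets/uniform.py).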
self.class_uniform_pct = class_uniform_pct 127 | self.class_uniform_tile = class_uniform_tile 128 | self.scf = scf 129 | self.hardnm = hardnm 130 | 131 | if cv_split: 132 | self.cv_split = cv_split 133 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 134 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 135 | cv_split, cfg.DATASET.CV_SPLITS) 136 | else: 137 | self.cv_split = 0 138 | 139 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 140 | assert len(self.imgs), 'Found 0 images, please check the data set' 141 | # self.cal_shape(self.imgs) 142 | 143 | # Centroids for GT data 144 | if self.class_uniform_pct > 0: 145 | if self.scf: 146 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split) 147 | else: 148 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm) 149 | if os.path.isfile(json_fn): 150 | with open(json_fn, 'r') as json_data: 151 | centroids = json.load(json_data) 152 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 153 | else: 154 | if self.scf: 155 | self.centroids = kitti_uniform.class_centroids_all( 156 | self.imgs, 157 | num_classes, 158 | id2trainid=id_to_trainid, 159 | tile_size=class_uniform_tile) 160 | else: 161 | self.centroids = uniform.class_centroids_all( 162 | self.imgs, 163 | num_classes, 164 | id2trainid=id_to_trainid, 165 | tile_size=class_uniform_tile) 166 | with open(json_fn, 'w') as outfile: 167 | json.dump(self.centroids, outfile, indent=4) 168 | 169 | self.build_epoch() 170 | 171 | 172 | def cal_shape(self, imgs): 173 | 174 | for i in imgs: 175 | img_path, mask_path = i 176 | img = Image.open(img_path).convert('RGB') 177 | print(img.size) 178 | 179 | def build_epoch(self, cut=False): 180 | if self.class_uniform_pct > 0: 181 | self.imgs_uniform = uniform.build_epoch(self.imgs, 182 | self.centroids, 183 | num_classes, 184 | cfg.CLASS_UNIFORM_PCT) 185 | else: 186 | self.imgs_uniform = self.imgs 187 | 188 | def __getitem__(self, index): 189 | elem = self.imgs_uniform[index] 190 | centroid = None 191 | if len(elem) == 4: 192 | img_path, mask_path, centroid, class_id = elem 193 | else: 194 | img_path, mask_path = elem 195 | 196 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 197 | img_name = os.path.splitext(os.path.basename(img_path))[0] 198 | 199 | # kitti scale correction factor 200 | if self.mode == 'train' or self.mode == 'trainval': 201 | if self.scf: 202 | width, height = img.size 203 | img = img.resize((width*2, height*2), Image.BICUBIC) 204 | mask = mask.resize((width*2, height*2), Image.NEAREST) 205 | elif self.mode == 'val': 206 | width, height = 1242, 376 207 | img = img.resize((width, height), Image.BICUBIC) 208 | mask = mask.resize((width, height), Image.NEAREST) 209 | else: 210 | logging.info('Unknown mode {}'.format(mode)) 211 | sys.exit() 212 | 213 | mask = np.array(mask) 214 | mask_copy = mask.copy() 215 | 216 | for k, v in id_to_trainid.items(): 217 | mask_copy[mask == k] = v 218 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 219 | 220 | # Image Transformations 221 | if self.joint_transform_list is not None: 222 | for idx, xform in enumerate(self.joint_transform_list): 223 | if idx == 0 and centroid is not None: 224 | # HACK 225 | # We assume that the first transform is capable of taking 226 | # in a centroid 227 | img, mask = xform(img, mask, centroid) 228 | else: 229 | img, mask = xform(img, mask) 230 | 231 | # Debug 232 | if self.dump_images and 
centroid is not None: 233 | outdir = './dump_imgs_{}'.format(self.mode) 234 | os.makedirs(outdir, exist_ok=True) 235 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 236 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 237 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 238 | mask_img = colorize_mask(np.array(mask)) 239 | img.save(out_img_fn) 240 | mask_img.save(out_msk_fn) 241 | 242 | if self.transform is not None: 243 | img = self.transform(img) 244 | if self.target_transform is not None: 245 | mask = self.target_transform(mask) 246 | 247 | return img, mask, img_name 248 | 249 | def __len__(self): 250 | return len(self.imgs_uniform) 251 | 252 | -------------------------------------------------------------------------------- /datasets/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uniform sampling of classes. 3 | For all images, for all classes, generate centroids around which to sample. 4 | 5 | All images are divided into tiles. 6 | For each tile, a class can be present or not. If it is 7 | present, calculate the centroid of the class and record it. 8 | 9 | We would like to thank Peter Kontschieder for the inspiration of this idea. 10 | """ 11 | 12 | import logging 13 | from collections import defaultdict 14 | from PIL import Image 15 | import numpy as np 16 | from scipy import ndimage 17 | from tqdm import tqdm 18 | 19 | pbar = None 20 | 21 | class Point(): 22 | """ 23 | Point Class For X and Y Location 24 | """ 25 | def __init__(self, x, y): 26 | self.x = x 27 | self.y = y 28 | 29 | 30 | def calc_tile_locations(tile_size, image_size): 31 | """ 32 | Divide an image into tiles to help us cover classes that are spread out. 33 | tile_size: size of tile to distribute 34 | image_size: original image size 35 | return: locations of the tiles 36 | """ 37 | image_size_y, image_size_x = image_size 38 | locations = [] 39 | for y in range(image_size_y // tile_size): 40 | for x in range(image_size_x // tile_size): 41 | x_offs = x * tile_size 42 | y_offs = y * tile_size 43 | locations.append((x_offs, y_offs)) 44 | return locations 45 | 46 | 47 | def class_centroids_image(item, tile_size, num_classes, id2trainid): 48 | """ 49 | For one image, calculate centroids for all classes present in image. 50 | item: image, image_name 51 | tile_size: 52 | num_classes: 53 | id2trainid: mapping from original id to training ids 54 | return: Centroids are calculated for each tile. 
55 | """ 56 | image_fn, label_fn = item 57 | centroids = defaultdict(list) 58 | mask = np.array(Image.open(label_fn)) 59 | image_size = mask.shape 60 | tile_locations = calc_tile_locations(tile_size, image_size) 61 | 62 | mask_copy = mask.copy() 63 | if id2trainid: 64 | for k, v in id2trainid.items(): 65 | mask[mask_copy == k] = v 66 | 67 | for x_offs, y_offs in tile_locations: 68 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 69 | for class_id in range(num_classes): 70 | if class_id in patch: 71 | patch_class = (patch == class_id).astype(int) 72 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 73 | centroid_y = int(centroid_y) + y_offs 74 | centroid_x = int(centroid_x) + x_offs 75 | centroid = (centroid_x, centroid_y) 76 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 77 | pbar.update(1) 78 | return centroids 79 | 80 | import scipy.misc as m 81 | 82 | def class_centroids_image_from_color(item, tile_size, num_classes, id2trainid): 83 | """ 84 | For one image, calculate centroids for all classes present in image. 85 | item: image, image_name 86 | tile_size: 87 | num_classes: 88 | id2trainid: mapping from original id to training ids 89 | return: Centroids are calculated for each tile. 90 | """ 91 | image_fn, label_fn = item 92 | centroids = defaultdict(list) 93 | mask = m.imread(label_fn) 94 | image_size = mask[:,:,0].shape 95 | tile_locations = calc_tile_locations(tile_size, image_size) 96 | 97 | # mask = m.imread(label_fn) 98 | # mask_copy = np.full((img.size[1], img.size[0]), 255, dtype=np.uint8) 99 | # for k, v in id2trainid.items(): 100 | # mask_copy[(mask == k)[:,:,0]] = v 101 | # mask = Image.fromarray(mask_copy.astype(np.uint8)) 102 | 103 | # mask_copy = mask.copy() 104 | # mask_copy = mask.copy() 105 | # if id2trainid: 106 | # for k, v in id2trainid.items(): 107 | # mask[mask_copy == k] = v 108 | 109 | mask_copy = np.full(image_size, 255, dtype=np.uint8) 110 | 111 | if id2trainid: 112 | for k, v in id2trainid.items(): 113 | # print("0", mask.shape) 114 | # print("1", ((mask == np.array(k))[:,:,0]).shape) # 1052, 1914 115 | # # print("2", mask == np.array(k)[:,:,0]) 116 | # break 117 | # if v != 255: 118 | # print(v) 119 | # if v == 2: 120 | # print(k, v, "num", np.count_nonzero(mask == np.array(k))) 121 | # break 122 | if v != 255 and v != -1: 123 | mask_copy[(mask == np.array(k))[:,:,0] & (mask == np.array(k))[:,:,1] & (mask == np.array(k))[:,:,2]] = v 124 | mask = mask_copy 125 | 126 | # mask_copy = mask.copy() 127 | # if id2trainid: 128 | # for k, v in id2trainid.items(): 129 | # mask[mask_copy == k] = v 130 | 131 | for x_offs, y_offs in tile_locations: 132 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 133 | for class_id in range(num_classes): 134 | if class_id in patch: 135 | patch_class = (patch == class_id).astype(int) 136 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 137 | centroid_y = int(centroid_y) + y_offs 138 | centroid_x = int(centroid_x) + x_offs 139 | centroid = (centroid_x, centroid_y) 140 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 141 | pbar.update(1) 142 | return centroids 143 | 144 | def pooled_class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024): 145 | """ 146 | Calculate class centroids for all classes for all images for all tiles. 
147 | items: list of (image_fn, label_fn) 148 | tile size: size of tile 149 | returns: dict that contains a list of centroids for each class 150 | """ 151 | from multiprocessing.dummy import Pool 152 | from functools import partial 153 | pool = Pool(32) 154 | global pbar 155 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 156 | class_centroids_item = partial(class_centroids_image_from_color, 157 | num_classes=num_classes, 158 | id2trainid=id2trainid, 159 | tile_size=tile_size) 160 | 161 | centroids = defaultdict(list) 162 | new_centroids = pool.map(class_centroids_item, items) 163 | pool.close() 164 | pool.join() 165 | 166 | # combine each image's items into a single global dict 167 | for image_items in new_centroids: 168 | for class_id in image_items: 169 | centroids[class_id].extend(image_items[class_id]) 170 | return centroids 171 | 172 | 173 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 174 | """ 175 | Calculate class centroids for all classes for all images for all tiles. 176 | items: list of (image_fn, label_fn) 177 | tile size: size of tile 178 | returns: dict that contains a list of centroids for each class 179 | """ 180 | from multiprocessing.dummy import Pool 181 | from functools import partial 182 | pool = Pool(32) 183 | global pbar 184 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 185 | class_centroids_item = partial(class_centroids_image, 186 | num_classes=num_classes, 187 | id2trainid=id2trainid, 188 | tile_size=tile_size) 189 | 190 | centroids = defaultdict(list) 191 | new_centroids = pool.map(class_centroids_item, items) 192 | pool.close() 193 | pool.join() 194 | 195 | # combine each image's items into a single global dict 196 | for image_items in new_centroids: 197 | for class_id in image_items: 198 | centroids[class_id].extend(image_items[class_id]) 199 | return centroids 200 | 201 | 202 | def unpooled_class_centroids_all(items, num_classes, tile_size=1024): 203 | """ 204 | Calculate class centroids for all classes for all images for all tiles. 
205 | items: list of (image_fn, label_fn) 206 | tile size: size of tile 207 | returns: dict that contains a list of centroids for each class 208 | """ 209 | centroids = defaultdict(list) 210 | global pbar 211 | pbar = tqdm(total=len(items), desc='centroid extraction') 212 | for image, label in items: 213 | new_centroids = class_centroids_image((image, label), 214 | tile_size, 215 | num_classes) 216 | for class_id in new_centroids: 217 | centroids[class_id].extend(new_centroids[class_id]) 218 | 219 | return centroids 220 | 221 | 222 | def class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024): 223 | """ 224 | intermediate function to call pooled_class_centroid 225 | """ 226 | 227 | pooled_centroids = pooled_class_centroids_all_from_color(items, num_classes, 228 | id2trainid, tile_size) 229 | return pooled_centroids 230 | 231 | 232 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 233 | """ 234 | intermediate function to call pooled_class_centroid 235 | """ 236 | 237 | pooled_centroids = pooled_class_centroids_all(items, num_classes, 238 | id2trainid, tile_size) 239 | return pooled_centroids 240 | 241 | 242 | def random_sampling(alist, num): 243 | """ 244 | Randomly sample num items from the list 245 | alist: list of centroids to sample from 246 | num: can be larger than the list and if so, then wrap around 247 | return: class uniform samples from the list 248 | """ 249 | sampling = [] 250 | len_list = len(alist) 251 | assert len_list, 'len_list is zero!' 252 | indices = np.arange(len_list) 253 | np.random.shuffle(indices) 254 | 255 | for i in range(num): 256 | item = alist[indices[i % len_list]] 257 | sampling.append(item) 258 | return sampling 259 | 260 | 261 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct): 262 | """ 263 | Generate an epochs-worth of crops using uniform sampling. 
Needs to be called every 264 | imgs: list of imgs 265 | centroids: 266 | num_classes: 267 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch ) 268 | """ 269 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct)) 270 | num_epoch = int(len(imgs)) 271 | 272 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch)) 273 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes) 274 | num_rand = num_epoch - num_per_class * num_classes 275 | # create random crops 276 | imgs_uniform = random_sampling(imgs, num_rand) 277 | 278 | # now add uniform sampling 279 | for class_id in range(num_classes): 280 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id])) 281 | logging.info(string_format) 282 | for class_id in range(num_classes): 283 | centroid_len = len(centroids[class_id]) 284 | if centroid_len == 0: 285 | pass 286 | else: 287 | class_centroids = random_sampling(centroids[class_id], num_per_class) 288 | imgs_uniform.extend(class_centroids) 289 | 290 | return imgs_uniform 291 | -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code borrowded from: 3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 4 | # 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2017 ZijunDeng 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 
27 | 28 | """ 29 | 30 | """ 31 | Standard Transform 32 | """ 33 | 34 | import random 35 | import numpy as np 36 | from skimage.filters import gaussian 37 | from skimage.restoration import denoise_bilateral 38 | import torch 39 | from PIL import Image, ImageEnhance 40 | import torchvision.transforms as torch_tr 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | 44 | from skimage.segmentation import find_boundaries 45 | 46 | try: 47 | import accimage 48 | except ImportError: 49 | accimage = None 50 | 51 | 52 | class RandomVerticalFlip(object): 53 | def __call__(self, img): 54 | if random.random() < 0.5: 55 | return img.transpose(Image.FLIP_TOP_BOTTOM) 56 | return img 57 | 58 | 59 | class DeNormalize(object): 60 | def __init__(self, mean, std): 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, tensor): 65 | for t, m, s in zip(tensor, self.mean, self.std): 66 | t.mul_(s).add_(m) 67 | return tensor 68 | 69 | 70 | class MaskToTensor(object): 71 | def __call__(self, img): 72 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 73 | 74 | class RelaxedBoundaryLossToTensor(object): 75 | """ 76 | Boundary Relaxation 77 | """ 78 | def __init__(self,ignore_id, num_classes): 79 | self.ignore_id=ignore_id 80 | self.num_classes= num_classes 81 | 82 | 83 | def new_one_hot_converter(self,a): 84 | ncols = self.num_classes+1 85 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 86 | out[np.arange(a.size),a.ravel()] = 1 87 | out.shape = a.shape + (ncols,) 88 | return out 89 | 90 | def __call__(self,img): 91 | 92 | img_arr = np.array(img) 93 | img_arr[img_arr==self.ignore_id]=self.num_classes 94 | 95 | if cfg.STRICTBORDERCLASS != None: 96 | one_hot_orig = self.new_one_hot_converter(img_arr) 97 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 98 | for cls in cfg.STRICTBORDERCLASS: 99 | mask = np.logical_or(mask,(img_arr == cls)) 100 | one_hot = 0 101 | 102 | border = cfg.BORDER_WINDOW 103 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 104 | border = border // 2 105 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 106 | 107 | for i in range(-border,border+1): 108 | for j in range(-border, border+1): 109 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 110 | one_hot += self.new_one_hot_converter(shifted) 111 | 112 | one_hot[one_hot>1] = 1 113 | 114 | if cfg.STRICTBORDERCLASS != None: 115 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 116 | 117 | one_hot = np.moveaxis(one_hot,-1,0) 118 | 119 | 120 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 121 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 122 | # print(one_hot.shape) 123 | return torch.from_numpy(one_hot).byte() 124 | 125 | class ResizeHeight(object): 126 | def __init__(self, size, interpolation=Image.BILINEAR): 127 | self.target_h = size 128 | self.interpolation = interpolation 129 | 130 | def __call__(self, img): 131 | w, h = img.size 132 | target_w = int(w / h * self.target_h) 133 | return img.resize((target_w, self.target_h), self.interpolation) 134 | 135 | 136 | class FreeScale(object): 137 | def __init__(self, size, interpolation=Image.BILINEAR): 138 | self.size = tuple(reversed(size)) # size: (h, w) 139 | self.interpolation = interpolation 140 | 141 | def __call__(self, img): 142 | return img.resize(self.size, self.interpolation) 143 | 144 | 145 | class FlipChannels(object): 146 | """ 147 | Flip around the x-axis 148 | """ 149 | def __call__(self, img): 150 | img 
= np.array(img)[:, :, ::-1] 151 | return Image.fromarray(img.astype(np.uint8)) 152 | 153 | 154 | class RandomGaussianBlur(object): 155 | """ 156 | Apply Gaussian Blur 157 | """ 158 | def __call__(self, img): 159 | sigma = 0.15 + random.random() * 1.15 160 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 161 | blurred_img *= 255 162 | return Image.fromarray(blurred_img.astype(np.uint8)) 163 | 164 | 165 | class RandomBilateralBlur(object): 166 | """ 167 | Apply Bilateral Filtering 168 | 169 | """ 170 | def __call__(self, img): 171 | sigma = random.uniform(0.05,0.75) 172 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 173 | blurred_img *= 255 174 | return Image.fromarray(blurred_img.astype(np.uint8)) 175 | 176 | def _is_pil_image(img): 177 | if accimage is not None: 178 | return isinstance(img, (Image.Image, accimage.Image)) 179 | else: 180 | return isinstance(img, Image.Image) 181 | 182 | 183 | def adjust_brightness(img, brightness_factor): 184 | """Adjust brightness of an Image. 185 | 186 | Args: 187 | img (PIL Image): PIL Image to be adjusted. 188 | brightness_factor (float): How much to adjust the brightness. Can be 189 | any non negative number. 0 gives a black image, 1 gives the 190 | original image while 2 increases the brightness by a factor of 2. 191 | 192 | Returns: 193 | PIL Image: Brightness adjusted image. 194 | """ 195 | if not _is_pil_image(img): 196 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 197 | 198 | enhancer = ImageEnhance.Brightness(img) 199 | img = enhancer.enhance(brightness_factor) 200 | return img 201 | 202 | 203 | def adjust_contrast(img, contrast_factor): 204 | """Adjust contrast of an Image. 205 | 206 | Args: 207 | img (PIL Image): PIL Image to be adjusted. 208 | contrast_factor (float): How much to adjust the contrast. Can be any 209 | non negative number. 0 gives a solid gray image, 1 gives the 210 | original image while 2 increases the contrast by a factor of 2. 211 | 212 | Returns: 213 | PIL Image: Contrast adjusted image. 214 | """ 215 | if not _is_pil_image(img): 216 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 217 | 218 | enhancer = ImageEnhance.Contrast(img) 219 | img = enhancer.enhance(contrast_factor) 220 | return img 221 | 222 | 223 | def adjust_saturation(img, saturation_factor): 224 | """Adjust color saturation of an image. 225 | 226 | Args: 227 | img (PIL Image): PIL Image to be adjusted. 228 | saturation_factor (float): How much to adjust the saturation. 0 will 229 | give a black and white image, 1 will give the original image while 230 | 2 will enhance the saturation by a factor of 2. 231 | 232 | Returns: 233 | PIL Image: Saturation adjusted image. 234 | """ 235 | if not _is_pil_image(img): 236 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 237 | 238 | enhancer = ImageEnhance.Color(img) 239 | img = enhancer.enhance(saturation_factor) 240 | return img 241 | 242 | 243 | def adjust_hue(img, hue_factor): 244 | """Adjust hue of an image. 245 | 246 | The image hue is adjusted by converting the image to HSV and 247 | cyclically shifting the intensities in the hue channel (H). 248 | The image is then converted back to original image mode. 249 | 250 | `hue_factor` is the amount of shift in H channel and must be in the 251 | interval `[-0.5, 0.5]`. 252 | 253 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 254 | 255 | Args: 256 | img (PIL Image): PIL Image to be adjusted. 
257 | hue_factor (float): How much to shift the hue channel. Should be in 258 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 259 | HSV space in positive and negative direction respectively. 260 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 261 | with complementary colors while 0 gives the original image. 262 | 263 | Returns: 264 | PIL Image: Hue adjusted image. 265 | """ 266 | if not(-0.5 <= hue_factor <= 0.5): 267 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 268 | 269 | if not _is_pil_image(img): 270 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 271 | 272 | input_mode = img.mode 273 | if input_mode in {'L', '1', 'I', 'F'}: 274 | return img 275 | 276 | h, s, v = img.convert('HSV').split() 277 | 278 | np_h = np.array(h, dtype=np.uint8) 279 | # uint8 addition take cares of rotation across boundaries 280 | with np.errstate(over='ignore'): 281 | np_h += np.uint8(hue_factor * 255) 282 | h = Image.fromarray(np_h, 'L') 283 | 284 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 285 | return img 286 | 287 | 288 | class ColorJitter(object): 289 | """Randomly change the brightness, contrast and saturation of an image. 290 | 291 | Args: 292 | brightness (float): How much to jitter brightness. brightness_factor 293 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 294 | contrast (float): How much to jitter contrast. contrast_factor 295 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 296 | saturation (float): How much to jitter saturation. saturation_factor 297 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 298 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 299 | [-hue, hue]. Should be >=0 and <= 0.5. 300 | """ 301 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 302 | self.brightness = brightness 303 | self.contrast = contrast 304 | self.saturation = saturation 305 | self.hue = hue 306 | 307 | @staticmethod 308 | def get_params(brightness, contrast, saturation, hue): 309 | """Get a randomized transform to be applied on image. 310 | 311 | Arguments are same as that of __init__. 312 | 313 | Returns: 314 | Transform which randomly adjusts brightness, contrast and 315 | saturation in a random order. 316 | """ 317 | transforms = [] 318 | if brightness > 0: 319 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 320 | transforms.append( 321 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 322 | 323 | if contrast > 0: 324 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 325 | transforms.append( 326 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 327 | 328 | if saturation > 0: 329 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 330 | transforms.append( 331 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 332 | 333 | if hue > 0: 334 | hue_factor = np.random.uniform(-hue, hue) 335 | transforms.append( 336 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 337 | 338 | np.random.shuffle(transforms) 339 | transform = torch_tr.Compose(transforms) 340 | 341 | return transform 342 | 343 | def __call__(self, img): 344 | """ 345 | Args: 346 | img (PIL Image): Input image. 347 | 348 | Returns: 349 | PIL Image: Color jittered image. 
350 | """ 351 | transform = self.get_params(self.brightness, self.contrast, 352 | self.saturation, self.hue) 353 | return transform(img) 354 | -------------------------------------------------------------------------------- /datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | # File taken from https://github.com/mcordts/cityscapesScripts/ 3 | # License File Available at: 4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 5 | 6 | # ---------------------- 7 | # The Cityscapes Dataset 8 | # ---------------------- 9 | # 10 | # 11 | # License agreement 12 | # ----------------- 13 | # 14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 15 | # 16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 21 | # 22 | # 23 | # Contact 24 | # ------- 25 | # 26 | # Marius Cordts, Mohamed Omran 27 | # www.cityscapes-dataset.net 28 | 29 | """ 30 | from collections import namedtuple 31 | 32 | 33 | #-------------------------------------------------------------------------------- 34 | # Definitions 35 | #-------------------------------------------------------------------------------- 36 | 37 | # a label and all meta information 38 | Label = namedtuple( 'Label' , [ 39 | 40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 41 | # We use them to uniquely name a class 42 | 43 | 'id' , # An integer ID that is associated with this label. 44 | # The IDs are used to represent the label in ground truth images 45 | # An ID of -1 means that this label does not have an ID and thus 46 | # is ignored when creating ground truth images (e.g. license plate). 47 | # Do not modify these IDs, since exactly these IDs are expected by the 48 | # evaluation server. 49 | 50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 51 | # ground truth images with train IDs, using the tools provided in the 52 | # 'preparation' folder. However, make sure to validate or submit results 53 | # to our evaluation server using the regular IDs above! 54 | # For trainIds, multiple labels might have the same ID. 
Then, these labels 55 | # are mapped to the same class in the ground truth images. For the inverse 56 | # mapping, we use the label that is defined first in the list below. 57 | # For example, mapping all void-type classes to the same ID in training, 58 | # might make sense for some approaches. 59 | # Max value is 255! 60 | 61 | 'category' , # The name of the category that this label belongs to 62 | 63 | 'categoryId' , # The ID of this category. Used to create ground truth images 64 | # on category level. 65 | 66 | 'hasInstances', # Whether this label distinguishes between single instances or not 67 | 68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 69 | # during evaluations or not 70 | 71 | 'color' , # The color of this label 72 | ] ) 73 | 74 | 75 | #-------------------------------------------------------------------------------- 76 | # A list of all labels 77 | #-------------------------------------------------------------------------------- 78 | 79 | # Please adapt the train IDs as appropriate for you approach. 80 | # Note that you might want to ignore labels with ID 255 during training. 81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 82 | # Make sure to provide your results using the original IDs and not the training IDs. 83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 84 | 85 | labels = [ 86 | # name id trainId category catId hasInstances ignoreInEval color 87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,154) ), # (153,153,153) 106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 111 | Label( 'person' , 24 , 11 , 
'human' , 6 , True , False , (220, 20, 60) ), 112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,143) ), # ( 0, 0,142) 122 | ] 123 | 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Create dictionaries for a fast lookup 127 | #-------------------------------------------------------------------------------- 128 | 129 | # Please refer to the main method below for example usages! 130 | 131 | # name to label object 132 | name2label = { label.name : label for label in labels } 133 | # id to label object 134 | id2label = { label.id : label for label in labels } 135 | # trainId to label object 136 | trainId2label = { label.trainId : label for label in reversed(labels) } 137 | # label2trainid 138 | label2trainid = { label.id : label.trainId for label in labels } 139 | # trainId to label object 140 | trainId2name = { label.trainId : label.name for label in labels } 141 | trainId2color = { label.trainId : label.color for label in labels } 142 | 143 | color2trainId = { label.color : label.trainId for label in labels } 144 | 145 | trainId2trainId = { label.trainId : label.trainId for label in labels } 146 | 147 | # category to list of label objects 148 | category2labels = {} 149 | for label in labels: 150 | category = label.category 151 | if category in category2labels: 152 | category2labels[category].append(label) 153 | else: 154 | category2labels[category] = [label] 155 | 156 | #-------------------------------------------------------------------------------- 157 | # Assure single instance name 158 | #-------------------------------------------------------------------------------- 159 | 160 | # returns the label name that describes a single instance (if possible) 161 | # e.g. 
input | output 162 | # ---------------------- 163 | # car | car 164 | # cargroup | car 165 | # foo | None 166 | # foogroup | None 167 | # skygroup | None 168 | def assureSingleInstanceName( name ): 169 | # if the name is known, it is not a group 170 | if name in name2label: 171 | return name 172 | # test if the name actually denotes a group 173 | if not name.endswith("group"): 174 | return None 175 | # remove group 176 | name = name[:-len("group")] 177 | # test if the new name exists 178 | if not name in name2label: 179 | return None 180 | # test if the new name denotes a label that actually has instances 181 | if not name2label[name].hasInstances: 182 | return None 183 | # all good then 184 | return name 185 | 186 | #-------------------------------------------------------------------------------- 187 | # Main for testing 188 | #-------------------------------------------------------------------------------- 189 | 190 | # just a dummy main 191 | if __name__ == "__main__": 192 | # Print all the labels 193 | print("List of cityscapes labels:") 194 | print("") 195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 196 | print((" " + ('-' * 98))) 197 | for label in labels: 198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 199 | print("") 200 | 201 | print("Example usages:") 202 | 203 | # Map from name to label 204 | name = 'car' 205 | id = name2label[name].id 206 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 207 | 208 | # Map from ID to label 209 | category = id2label[id].category 210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 211 | 212 | # Map from trainID to label 213 | trainId = 0 214 | name = trainId2label[trainId].name 215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 216 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellanous Functions 3 | """ 4 | 5 | import sys 6 | import re 7 | import os 8 | import shutil 9 | import torch 10 | from datetime import datetime 11 | import logging 12 | from subprocess import call 13 | import shlex 14 | from tensorboardX import SummaryWriter 15 | import numpy as np 16 | import torchvision.transforms as standard_transforms 17 | import torchvision.utils as vutils 18 | 19 | 20 | # Create unique output dir name based on non-default command line args 21 | def make_exp_name(args, parser): 22 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:]) 23 | dict_args = vars(args) 24 | 25 | # sort so that we get a consistent directory name 26 | argnames = sorted(dict_args) 27 | ignorelist = ['date', 'exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch', 28 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt', 'coarse_boost_classes', 29 | 'crop_size', 'dist_url', 'syncbn', 'max_iter', 'color_aug', 'scale_max', 'scale_min', 'bs_mult', 30 | 'hanet_lr', 'class_uniform_pct', 'class_uniform_tile', 'hanet', 'hanet_set', 'hanet_pos'] 31 | # build experiment name with non-default args 32 | for argname in argnames: 33 | if dict_args[argname] != parser.get_default(argname): 34 | if argname in ignorelist: 35 | 
continue 36 | if argname == 'snapshot': 37 | arg_str = 'PT' 38 | argname = '' 39 | elif argname == 'nosave': 40 | arg_str = '' 41 | argname='' 42 | elif argname == 'freeze_trunk': 43 | argname = '' 44 | arg_str = 'ft' 45 | elif argname == 'syncbn': 46 | argname = '' 47 | arg_str = 'sbn' 48 | elif argname == 'jointwtborder': 49 | argname = '' 50 | arg_str = 'rlx_loss' 51 | elif isinstance(dict_args[argname], bool): 52 | arg_str = 'T' if dict_args[argname] else 'F' 53 | else: 54 | arg_str = str(dict_args[argname])[:7] 55 | if argname is not '': 56 | exp_name += '_{}_{}'.format(str(argname), arg_str) 57 | else: 58 | exp_name += '_{}'.format(arg_str) 59 | # clean special chars out exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name) 60 | return exp_name 61 | 62 | def fast_hist(label_pred, label_true, num_classes): 63 | mask = (label_true >= 0) & (label_true < num_classes) 64 | hist = np.bincount( 65 | num_classes * label_true[mask].astype(int) + 66 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 67 | return hist 68 | 69 | def per_class_iu(hist): 70 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 71 | 72 | def save_log(prefix, output_dir, date_str, rank=0): 73 | fmt = '%(asctime)s.%(msecs)03d %(message)s' 74 | date_fmt = '%m-%d %H:%M:%S' 75 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log') 76 | print("Logging :", filename) 77 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt, 78 | filename=filename, filemode='w') 79 | console = logging.StreamHandler() 80 | console.setLevel(logging.INFO) 81 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt) 82 | console.setFormatter(formatter) 83 | if rank == 0: 84 | logging.getLogger('').addHandler(console) 85 | else: 86 | fh = logging.FileHandler(filename) 87 | logging.getLogger('').addHandler(fh) 88 | 89 | 90 | 91 | def prep_experiment(args, parser): 92 | """ 93 | Make output directories, setup logging, Tensorboard, snapshot code. 
94 | """ 95 | ckpt_path = args.ckpt 96 | tb_path = args.tb_path 97 | exp_name = make_exp_name(args, parser) 98 | args.exp_path = os.path.join(ckpt_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H'))) 99 | args.tb_exp_path = os.path.join(tb_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H'))) 100 | args.ngpu = torch.cuda.device_count() 101 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S')) 102 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 103 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 104 | args.last_record = {} 105 | if args.local_rank == 0: 106 | os.makedirs(args.exp_path, exist_ok=True) 107 | os.makedirs(args.tb_exp_path, exist_ok=True) 108 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank) 109 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write( 110 | str(args) + '\n\n') 111 | writer = SummaryWriter(log_dir=args.tb_exp_path, comment=args.tb_tag) 112 | return writer 113 | return None 114 | 115 | def evaluate_eval_for_inference(hist, dataset=None): 116 | """ 117 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 118 | large dataset) Only applies to eval/eval.py 119 | """ 120 | # axis 0: gt, axis 1: prediction 121 | acc = np.diag(hist).sum() / hist.sum() 122 | acc_cls = np.diag(hist) / hist.sum(axis=1) 123 | acc_cls = np.nanmean(acc_cls) 124 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 125 | 126 | print_evaluate_results(hist, iu, dataset=dataset) 127 | freq = hist.sum(axis=1) / hist.sum() 128 | mean_iu = np.nanmean(iu) 129 | logging.info('mean {}'.format(mean_iu)) 130 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 131 | return acc, acc_cls, mean_iu, fwavacc 132 | 133 | 134 | 135 | def evaluate_eval(args, net, optimizer, scheduler, val_loss, hist, dump_images, writer, epoch=0, dataset=None, curr_iter=0, optimizer_at=None, scheduler_at=None): 136 | """ 137 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 138 | large dataset) Only applies to eval/eval.py 139 | """ 140 | # axis 0: gt, axis 1: prediction 141 | acc = np.diag(hist).sum() / hist.sum() 142 | acc_cls = np.diag(hist) / hist.sum(axis=1) 143 | acc_cls = np.nanmean(acc_cls) 144 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 145 | 146 | print_evaluate_results(hist, iu, dataset) 147 | freq = hist.sum(axis=1) / hist.sum() 148 | mean_iu = np.nanmean(iu) 149 | logging.info('mean {}'.format(mean_iu)) 150 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 151 | 152 | # update latest snapshot 153 | if 'mean_iu' in args.last_record: 154 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format( 155 | args.last_record['epoch'], args.last_record['mean_iu']) 156 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 157 | try: 158 | os.remove(last_snapshot) 159 | except OSError: 160 | pass 161 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(epoch, mean_iu) 162 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 163 | args.last_record['mean_iu'] = mean_iu 164 | args.last_record['epoch'] = epoch 165 | 166 | torch.cuda.synchronize() 167 | 168 | if optimizer_at is not None: 169 | torch.save({ 170 | 'state_dict': net.state_dict(), 171 | 'optimizer': optimizer.state_dict(), 172 | 'optimizer_at': optimizer_at.state_dict(), 173 | 'scheduler': scheduler.state_dict(), 174 | 'scheduler_at': scheduler_at.state_dict(), 175 | 'epoch': epoch, 176 | 'mean_iu': mean_iu, 177 | 
'command': ' '.join(sys.argv[1:]) 178 | }, last_snapshot) 179 | else: 180 | torch.save({ 181 | 'state_dict': net.state_dict(), 182 | 'optimizer': optimizer.state_dict(), 183 | 'scheduler': scheduler.state_dict(), 184 | 'epoch': epoch, 185 | 'mean_iu': mean_iu, 186 | 'command': ' '.join(sys.argv[1:]) 187 | }, last_snapshot) 188 | 189 | # update best snapshot 190 | if mean_iu > args.best_record['mean_iu'] : 191 | # remove old best snapshot 192 | if args.best_record['epoch'] != -1: 193 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format( 194 | args.best_record['epoch'], args.best_record['mean_iu']) 195 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 196 | assert os.path.exists(best_snapshot), \ 197 | 'cant find old snapshot {}'.format(best_snapshot) 198 | os.remove(best_snapshot) 199 | 200 | 201 | # save new best 202 | args.best_record['val_loss'] = val_loss.avg 203 | args.best_record['epoch'] = epoch 204 | args.best_record['acc'] = acc 205 | args.best_record['acc_cls'] = acc_cls 206 | args.best_record['mean_iu'] = mean_iu 207 | args.best_record['fwavacc'] = fwavacc 208 | 209 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format( 210 | args.best_record['epoch'], args.best_record['mean_iu']) 211 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 212 | shutil.copyfile(last_snapshot, best_snapshot) 213 | 214 | 215 | # to_save_dir = os.path.join(args.exp_path, 'best_images') 216 | # os.makedirs(to_save_dir, exist_ok=True) 217 | # val_visual = [] 218 | 219 | # idx = 0 220 | 221 | # visualize = standard_transforms.Compose([ 222 | # standard_transforms.Scale(384), 223 | # standard_transforms.ToTensor() 224 | # ]) 225 | # for bs_idx, bs_data in enumerate(dump_images): 226 | # for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])): 227 | # gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy()) 228 | # predictions_pil = args.dataset_cls.colorize_mask(data[1].cpu().numpy()) 229 | # img_name = data[2] 230 | 231 | # prediction_fn = '{}_prediction.png'.format(img_name) 232 | # predictions_pil.save(os.path.join(to_save_dir, prediction_fn)) 233 | # gt_fn = '{}_gt.png'.format(img_name) 234 | # gt_pil.save(os.path.join(to_save_dir, gt_fn)) 235 | # val_visual.extend([visualize(gt_pil.convert('RGB')), 236 | # visualize(predictions_pil.convert('RGB'))]) 237 | # if local_idx >= 9: 238 | # break 239 | # val_visual = torch.stack(val_visual, 0) 240 | # val_visual = vutils.make_grid(val_visual, nrow=10, padding=5) 241 | # writer.add_image('imgs', val_visual, curr_iter ) 242 | 243 | logging.info('-' * 107) 244 | fmt_str = '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 245 | '[mean_iu %.5f], [fwavacc %.5f]' 246 | logging.info(fmt_str % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 247 | fmt_str = 'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 248 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], ' 249 | logging.info(fmt_str % (args.best_record['val_loss'], args.best_record['acc'], 250 | args.best_record['acc_cls'], args.best_record['mean_iu'], 251 | args.best_record['fwavacc'], args.best_record['epoch'])) 252 | logging.info('-' * 107) 253 | 254 | # tensorboard logging of validation phase metrics 255 | 256 | writer.add_scalar('training/acc', acc, curr_iter) 257 | writer.add_scalar('training/acc_cls', acc_cls, curr_iter) 258 | writer.add_scalar('training/mean_iu', mean_iu, curr_iter) 259 | writer.add_scalar('training/val_loss', val_loss.avg, curr_iter) 260 | 261 | 262 | 263 | 264 | 265 | def print_evaluate_results(hist, 
iu, dataset=None): 266 | # fixme: Need to refactor this dict 267 | try: 268 | id2cat = dataset.id2cat 269 | except: 270 | id2cat = {i: i for i in range(dataset.num_classes)} 271 | iu_false_positive = hist.sum(axis=1) - np.diag(hist) 272 | iu_false_negative = hist.sum(axis=0) - np.diag(hist) 273 | iu_true_positive = np.diag(hist) 274 | 275 | logging.info('IoU:') 276 | logging.info('label_id label iU Precision Recall TP FP FN') 277 | for idx, i in enumerate(iu): 278 | # Format all of the strings: 279 | idx_string = "{:2d}".format(idx) 280 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else '' 281 | iu_string = '{:5.1f}'.format(i * 100) 282 | total_pixels = hist.sum() 283 | tp = '{:5.1f}'.format(100 * iu_true_positive[idx] / total_pixels) 284 | fp = '{:5.1f}'.format( 285 | iu_false_positive[idx] / iu_true_positive[idx]) 286 | fn = '{:5.1f}'.format(iu_false_negative[idx] / iu_true_positive[idx]) 287 | precision = '{:5.1f}'.format( 288 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx])) 289 | recall = '{:5.1f}'.format( 290 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx])) 291 | logging.info('{} {} {} {} {} {} {} {}'.format( 292 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn)) 293 | 294 | 295 | 296 | 297 | class AverageMeter(object): 298 | 299 | def __init__(self): 300 | self.reset() 301 | 302 | def reset(self): 303 | self.val = 0 304 | self.avg = 0 305 | self.sum = 0 306 | self.count = 0 307 | 308 | def update(self, val, n=1): 309 | self.val = val 310 | self.sum += val * n 311 | self.count += n 312 | self.avg = self.sum / self.count 313 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss.py 3 | """ 4 | 5 | import logging 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from config import cfg 11 | 12 | 13 | def get_loss(args): 14 | """ 15 | Get the criterion based on the loss function 16 | args: commandline arguments 17 | return: criterion, criterion_val 18 | """ 19 | if args.cls_wt_loss: 20 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 21 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 22 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 23 | else: 24 | ce_weight = None 25 | 26 | if args.img_wt_loss: 27 | criterion = ImageBasedCrossEntropyLoss2d( 28 | classes=args.dataset_cls.num_classes, size_average=True, 29 | ignore_index=args.dataset_cls.ignore_label, 30 | upper_bound=args.wt_bound).cuda() 31 | elif args.jointwtborder: 32 | criterion = ImgWtLossSoftNLL(classes=args.dataset_cls.num_classes, 33 | ignore_index=args.dataset_cls.ignore_label, 34 | upper_bound=args.wt_bound).cuda() 35 | else: 36 | print("standard cross entropy") 37 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean', 38 | ignore_index=args.dataset_cls.ignore_label).cuda() 39 | 40 | criterion_val = nn.CrossEntropyLoss(reduction='mean', 41 | ignore_index=args.dataset_cls.ignore_label).cuda() 42 | return criterion, criterion_val 43 | 44 | def get_loss_by_epoch(args): 45 | """ 46 | Get the criterion based on the loss function 47 | args: commandline arguments 48 | return: criterion, criterion_val 49 | """ 50 | 51 | if args.img_wt_loss: 52 | criterion = ImageBasedCrossEntropyLoss2d( 53 | classes=args.dataset_cls.num_classes, size_average=True, 54 | ignore_index=args.dataset_cls.ignore_label, 55 | 
upper_bound=args.wt_bound).cuda() 56 | elif args.jointwtborder: 57 | criterion = ImgWtLossSoftNLL_by_epoch(classes=args.dataset_cls.num_classes, 58 | ignore_index=args.dataset_cls.ignore_label, 59 | upper_bound=args.wt_bound).cuda() 60 | else: 61 | criterion = CrossEntropyLoss2d(size_average=True, 62 | ignore_index=args.dataset_cls.ignore_label).cuda() 63 | 64 | criterion_val = CrossEntropyLoss2d(size_average=True, 65 | weight=None, 66 | ignore_index=args.dataset_cls.ignore_label).cuda() 67 | return criterion, criterion_val 68 | 69 | 70 | def get_loss_aux(args): 71 | """ 72 | Get the criterion based on the loss function 73 | args: commandline arguments 74 | return: criterion, criterion_val 75 | """ 76 | if args.cls_wt_loss: 77 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 78 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 79 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 80 | else: 81 | ce_weight = None 82 | 83 | print("standard cross entropy") 84 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean', 85 | ignore_index=args.dataset_cls.ignore_label).cuda() 86 | 87 | return criterion 88 | 89 | def get_loss_bcelogit(args): 90 | if args.cls_wt_loss: 91 | pos_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 92 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 93 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 94 | else: 95 | pos_weight = None 96 | print("standard bce with logit cross entropy") 97 | criterion = nn.BCEWithLogitsLoss(reduction='mean').cuda() 98 | 99 | return criterion 100 | 101 | def weighted_binary_cross_entropy(output, target): 102 | 103 | weights = torch.Tensor([0.1, 0.9]) 104 | 105 | loss = weights[1] * (target * torch.log(output)) + \ 106 | weights[0] * ((1 - target) * torch.log(1 - output)) 107 | 108 | return torch.neg(torch.mean(loss)) 109 | 110 | class ImageBasedCrossEntropyLoss2d(nn.Module): 111 | """ 112 | Image Weighted Cross Entropy Loss 113 | """ 114 | 115 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255, 116 | norm=False, upper_bound=1.0): 117 | super(ImageBasedCrossEntropyLoss2d, self).__init__() 118 | logging.info("Using Per Image based weighted loss") 119 | self.num_classes = classes 120 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index) 121 | self.norm = norm 122 | self.upper_bound = upper_bound 123 | self.batch_weights = cfg.BATCH_WEIGHTING 124 | self.logsoftmax = nn.LogSoftmax(dim=1) 125 | 126 | def calculate_weights(self, target): 127 | """ 128 | Calculate weights of classes based on the training crop 129 | """ 130 | hist = np.histogram(target.flatten(), range( 131 | self.num_classes + 1), normed=True)[0] 132 | if self.norm: 133 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 134 | else: 135 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 136 | return hist 137 | 138 | def forward(self, inputs, targets): 139 | 140 | target_cpu = targets.data.cpu().numpy() 141 | if self.batch_weights: 142 | weights = self.calculate_weights(target_cpu) 143 | self.nll_loss.weight = torch.Tensor(weights).cuda() 144 | 145 | loss = 0.0 146 | for i in range(0, inputs.shape[0]): 147 | if not self.batch_weights: 148 | weights = self.calculate_weights(target_cpu[i]) 149 | self.nll_loss.weight = torch.Tensor(weights).cuda() 150 | 151 | loss += self.nll_loss(self.logsoftmax(inputs[i].unsqueeze(0)), 152 | targets[i].unsqueeze(0)) 153 | return loss 154 | 155 | 156 | 157 | class CrossEntropyLoss2d(nn.Module): 158 
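# [Editor's note] A hedged aside on ImageBasedCrossEntropyLoss2d.calculate_weights above:
# np.histogram(..., normed=True) uses an argument that was deprecated and later removed from
# NumPy; on recent versions the equivalent call for these unit-width class bins is
# density=True, e.g. (sketch only, same behaviour assumed):
#
#     hist = np.histogram(target.flatten(),
#                         range(self.num_classes + 1),
#                         density=True)[0]
#
# Separately, weighted_binary_cross_entropy above builds its weight tensor on the CPU, so it
# would need weights.to(output.device) before being mixed with CUDA tensors.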
| """ 159 | Cross Entroply NLL Loss 160 | """ 161 | 162 | def __init__(self, weight=None, size_average=True, ignore_index=255): 163 | super(CrossEntropyLoss2d, self).__init__() 164 | logging.info("Using Cross Entropy Loss") 165 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index) 166 | self.logsoftmax = nn.LogSoftmax(dim=1) 167 | # self.weight = weight 168 | 169 | def forward(self, inputs, targets): 170 | return self.nll_loss(self.logsoftmax(inputs), targets) 171 | 172 | def customsoftmax(inp, multihotmask): 173 | """ 174 | Custom Softmax 175 | """ 176 | soft = F.softmax(inp, dim=1) 177 | # This takes the mask * softmax ( sums it up hence summing up the classes in border 178 | # then takes of summed up version vs no summed version 179 | return torch.log( 180 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True))) 181 | ) 182 | 183 | class ImgWtLossSoftNLL(nn.Module): 184 | """ 185 | Relax Loss 186 | """ 187 | 188 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 189 | norm=False): 190 | super(ImgWtLossSoftNLL, self).__init__() 191 | self.weights = weights 192 | self.num_classes = classes 193 | self.ignore_index = ignore_index 194 | self.upper_bound = upper_bound 195 | self.norm = norm 196 | self.batch_weights = cfg.BATCH_WEIGHTING 197 | 198 | def calculate_weights(self, target): 199 | """ 200 | Calculate weights of the classes based on training crop 201 | """ 202 | if len(target.shape) == 3: 203 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 204 | else: 205 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 206 | if self.norm: 207 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 208 | else: 209 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 210 | return hist[:-1] 211 | 212 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 213 | """ 214 | NLL Relaxed Loss Implementation 215 | """ 216 | if (cfg.REDUCE_BORDER_ITER != -1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 217 | border_weights = 1 / border_weights 218 | target[target > 1] = 1 219 | 220 | loss_matrix = (-1 / border_weights * 221 | (target[:, :-1, :, :].float() * 222 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 223 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 224 | (1. 
- mask.float()) 225 | 226 | # loss_matrix[border_weights > 1] = 0 227 | loss = loss_matrix.sum() 228 | 229 | # +1 to prevent division by 0 230 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 231 | return loss 232 | 233 | def forward(self, inputs, target): 234 | weights = target[:, :-1, :, :].sum(1).float() 235 | ignore_mask = (weights == 0) 236 | weights[ignore_mask] = 1 237 | 238 | loss = 0 239 | target_cpu = target.data.cpu().numpy() 240 | 241 | if self.batch_weights: 242 | class_weights = self.calculate_weights(target_cpu) 243 | 244 | for i in range(0, inputs.shape[0]): 245 | if not self.batch_weights: 246 | class_weights = self.calculate_weights(target_cpu[i]) 247 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 248 | target[i].unsqueeze(0), 249 | class_weights=torch.Tensor(class_weights).cuda(), 250 | border_weights=weights[i], mask=ignore_mask[i]) 251 | 252 | loss = loss / inputs.shape[0] 253 | return loss 254 | 255 | class ImgWtLossSoftNLL_by_epoch(nn.Module): 256 | """ 257 | Relax Loss 258 | """ 259 | 260 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 261 | norm=False): 262 | super(ImgWtLossSoftNLL_by_epoch, self).__init__() 263 | self.weights = weights 264 | self.num_classes = classes 265 | self.ignore_index = ignore_index 266 | self.upper_bound = upper_bound 267 | self.norm = norm 268 | self.batch_weights = cfg.BATCH_WEIGHTING 269 | self.fp16 = False 270 | 271 | 272 | def calculate_weights(self, target): 273 | """ 274 | Calculate weights of the classes based on training crop 275 | """ 276 | if len(target.shape) == 3: 277 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 278 | else: 279 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 280 | if self.norm: 281 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 282 | else: 283 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 284 | return hist[:-1] 285 | 286 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 287 | """ 288 | NLL Relaxed Loss Implementation 289 | """ 290 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 291 | border_weights = 1 / border_weights 292 | target[target > 1] = 1 293 | if self.fp16: 294 | loss_matrix = (-1 / border_weights * 295 | (target[:, :-1, :, :].half() * 296 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 297 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \ 298 | (1. - mask.half()) 299 | else: 300 | loss_matrix = (-1 / border_weights * 301 | (target[:, :-1, :, :].float() * 302 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 303 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 304 | (1. 
- mask.float()) 305 | 306 | # loss_matrix[border_weights > 1] = 0 307 | loss = loss_matrix.sum() 308 | 309 | # +1 to prevent division by 0 310 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 311 | return loss 312 | 313 | def forward(self, inputs, target): 314 | if self.fp16: 315 | weights = target[:, :-1, :, :].sum(1).half() 316 | else: 317 | weights = target[:, :-1, :, :].sum(1).float() 318 | ignore_mask = (weights == 0) 319 | weights[ignore_mask] = 1 320 | 321 | loss = 0 322 | target_cpu = target.data.cpu().numpy() 323 | 324 | if self.batch_weights: 325 | class_weights = self.calculate_weights(target_cpu) 326 | 327 | for i in range(0, inputs.shape[0]): 328 | if not self.batch_weights: 329 | class_weights = self.calculate_weights(target_cpu[i]) 330 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 331 | target[i].unsqueeze(0), 332 | class_weights=torch.Tensor(class_weights).cuda(), 333 | border_weights=weights, mask=ignore_mask[i]) 334 | 335 | return loss 336 | -------------------------------------------------------------------------------- /network/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
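# [Editor's note] A practical caveat for this module, based only on the factory functions
# below: resnet18/34/50/152 fetch their ImageNet weights through model_zoo, but resnet101()
# loads a local checkpoint from ./pretrained/resnet101-imagenet.pth, so that file has to be
# placed there before calling resnet101(pretrained=True).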
34 | """ 35 | 36 | import torch 37 | import torch.nn as nn 38 | import torch.utils.model_zoo as model_zoo 39 | import network.mynn as mynn 40 | 41 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 42 | 'resnet152'] 43 | 44 | 45 | model_urls = { 46 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 47 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 48 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 49 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 50 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 51 | } 52 | 53 | 54 | def conv3x3(in_planes, out_planes, stride=1): 55 | """3x3 convolution with padding""" 56 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 57 | padding=1, bias=False) 58 | 59 | 60 | class BasicBlock(nn.Module): 61 | """ 62 | Basic Block for Resnet 63 | """ 64 | expansion = 1 65 | 66 | def __init__(self, inplanes, planes, stride=1, downsample=None): 67 | super(BasicBlock, self).__init__() 68 | self.conv1 = conv3x3(inplanes, planes, stride) 69 | self.bn1 = mynn.Norm2d(planes) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.conv2 = conv3x3(planes, planes) 72 | self.bn2 = mynn.Norm2d(planes) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class Bottleneck(nn.Module): 96 | """ 97 | Bottleneck Layer for Resnet 98 | """ 99 | expansion = 4 100 | 101 | def __init__(self, inplanes, planes, stride=1, downsample=None): 102 | super(Bottleneck, self).__init__() 103 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 104 | self.bn1 = mynn.Norm2d(planes) 105 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 106 | padding=1, bias=False) 107 | self.bn2 = mynn.Norm2d(planes) 108 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 109 | self.bn3 = mynn.Norm2d(planes * self.expansion) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.downsample = downsample 112 | self.stride = stride 113 | 114 | def forward(self, x): 115 | residual = x 116 | 117 | out = self.conv1(x) 118 | out = self.bn1(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv2(out) 122 | out = self.bn2(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv3(out) 126 | out = self.bn3(out) 127 | 128 | if self.downsample is not None: 129 | residual = self.downsample(x) 130 | 131 | out += residual 132 | out = self.relu(out) 133 | 134 | return out 135 | 136 | 137 | class ResNet3X3(nn.Module): 138 | """ 139 | Resnet Global Module for Initialization 140 | """ 141 | def __init__(self, block, layers, num_classes=1000): 142 | # self.inplanes = 64 143 | self.inplanes = 128 144 | super(ResNet3X3, self).__init__() 145 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 146 | # bias=False) 147 | # self.bn1 = mynn.Norm2d(64) 148 | # self.relu = nn.ReLU(inplace=True) 149 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 150 | bias=False) 151 | self.bn1 = mynn.Norm2d(64) 152 | self.relu1 = nn.ReLU(inplace=True) 153 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, 154 | 
bias=False) 155 | self.bn2 = mynn.Norm2d(64) 156 | self.relu2 = nn.ReLU(inplace=True) 157 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, 158 | bias=False) 159 | self.bn3 = mynn.Norm2d(self.inplanes) 160 | self.relu3 = nn.ReLU(inplace=True) 161 | 162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 163 | self.layer1 = self._make_layer(block, 64, layers[0]) 164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 165 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 166 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 167 | self.avgpool = nn.AvgPool2d(7, stride=1) 168 | self.fc = nn.Linear(512 * block.expansion, num_classes) 169 | 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv2d): 172 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 173 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.SyncBatchNorm): 174 | nn.init.constant_(m.weight, 1) 175 | nn.init.constant_(m.bias, 0) 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1): 178 | downsample = None 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | nn.Conv2d(self.inplanes, planes * block.expansion, 182 | kernel_size=1, stride=stride, bias=False), 183 | mynn.Norm2d(planes * block.expansion), 184 | ) 185 | 186 | layers = [] 187 | layers.append(block(self.inplanes, planes, stride, downsample)) 188 | self.inplanes = planes * block.expansion 189 | for index in range(1, blocks): 190 | layers.append(block(self.inplanes, planes)) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x): 195 | # x = self.conv1(x) 196 | # x = self.bn1(x) 197 | # x = self.relu(x) 198 | x = self.conv1(input) 199 | x = self.bn1(x) 200 | x = self.relu1(x) 201 | x = self.conv2(x) 202 | x = self.bn2(x) 203 | x = self.relu2(x) 204 | x = self.conv3(x) 205 | x = self.bn3(x) 206 | x = self.relu3(x) 207 | 208 | x = self.maxpool(x) 209 | 210 | x = self.layer1(x) 211 | x = self.layer2(x) 212 | x = self.layer3(x) 213 | x = self.layer4(x) 214 | 215 | x = self.avgpool(x) 216 | x = x.view(x.size(0), -1) 217 | x = self.fc(x) 218 | 219 | return x 220 | 221 | class ResNet(nn.Module): 222 | """ 223 | Resnet Global Module for Initialization 224 | """ 225 | def __init__(self, block, layers, num_classes=1000): 226 | self.inplanes = 64 227 | # self.inplanes = 128 228 | super(ResNet, self).__init__() 229 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 230 | bias=False) 231 | self.bn1 = mynn.Norm2d(64) 232 | self.relu = nn.ReLU(inplace=True) 233 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, 234 | # bias=False) 235 | # self.bn1 = mynn.Norm2d(64) 236 | # self.relu1 = nn.ReLU(inplace=True) 237 | # self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, 238 | # bias=False) 239 | # self.bn2 = mynn.Norm2d(64) 240 | # self.relu2 = nn.ReLU(inplace=True) 241 | # self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, 242 | # bias=False) 243 | # self.bn3 = mynn.Norm2d(self.inplanes) 244 | # self.relu3 = nn.ReLU(inplace=True) 245 | 246 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 247 | self.layer1 = self._make_layer(block, 64, layers[0]) 248 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 249 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 250 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 251 | self.avgpool = nn.AvgPool2d(7, stride=1) 252 | self.fc 
= nn.Linear(512 * block.expansion, num_classes) 253 | 254 | for m in self.modules(): 255 | if isinstance(m, nn.Conv2d): 256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 257 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.SyncBatchNorm): 258 | nn.init.constant_(m.weight, 1) 259 | nn.init.constant_(m.bias, 0) 260 | 261 | def _make_layer(self, block, planes, blocks, stride=1): 262 | downsample = None 263 | if stride != 1 or self.inplanes != planes * block.expansion: 264 | downsample = nn.Sequential( 265 | nn.Conv2d(self.inplanes, planes * block.expansion, 266 | kernel_size=1, stride=stride, bias=False), 267 | mynn.Norm2d(planes * block.expansion), 268 | ) 269 | 270 | layers = [] 271 | layers.append(block(self.inplanes, planes, stride, downsample)) 272 | self.inplanes = planes * block.expansion 273 | for index in range(1, blocks): 274 | layers.append(block(self.inplanes, planes)) 275 | 276 | return nn.Sequential(*layers) 277 | 278 | def forward(self, x): 279 | x = self.conv1(x) 280 | x = self.bn1(x) 281 | x = self.relu(x) 282 | # x = self.conv1(input) 283 | # x = self.bn1(x) 284 | # x = self.relu1(x) 285 | # x = self.conv2(x) 286 | # x = self.bn2(x) 287 | # x = self.relu2(x) 288 | # x = self.conv3(x) 289 | # x = self.bn3(x) 290 | # x = self.relu3(x) 291 | 292 | x = self.maxpool(x) 293 | 294 | x = self.layer1(x) 295 | x = self.layer2(x) 296 | x = self.layer3(x) 297 | x = self.layer4(x) 298 | 299 | x = self.avgpool(x) 300 | x = x.view(x.size(0), -1) 301 | x = self.fc(x) 302 | 303 | return x 304 | 305 | def resnet18(pretrained=True, **kwargs): 306 | """Constructs a ResNet-18 model. 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | """ 311 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 312 | if pretrained: 313 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 314 | return model 315 | 316 | 317 | def resnet34(pretrained=True, **kwargs): 318 | """Constructs a ResNet-34 model. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | """ 323 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 324 | if pretrained: 325 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 326 | return model 327 | 328 | 329 | def resnet50(pretrained=True, **kwargs): 330 | """Constructs a ResNet-50 model. 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | """ 335 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 336 | if pretrained: 337 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 338 | return model 339 | 340 | 341 | def resnet101(pretrained=True, **kwargs): 342 | """Constructs a ResNet-101 model. 343 | 344 | Args: 345 | pretrained (bool): If True, returns a model pre-trained on ImageNet 346 | """ 347 | model = ResNet3X3(Bottleneck, [3, 4, 23, 3], **kwargs) 348 | if pretrained: 349 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 350 | print("########### pretrained ##############") 351 | model.load_state_dict(torch.load('./pretrained/resnet101-imagenet.pth', map_location="cpu")) 352 | return model 353 | 354 | 355 | def resnet152(pretrained=True, **kwargs): 356 | """Constructs a ResNet-152 model. 
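# [Editor's note] A correction for ResNet3X3.forward above: its first statement reads
#
#     x = self.conv1(input)
#
# which silently refers to the Python builtin `input` rather than the method argument, and
# would fail at runtime. The intended statement is almost certainly
#
#     x = self.conv1(x)
#
# as in the plain ResNet.forward variant, which uses the argument correctly.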
357 | 358 | Args: 359 | pretrained (bool): If True, returns a model pre-trained on ImageNet 360 | """ 361 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 362 | if pretrained: 363 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 364 | return model 365 | -------------------------------------------------------------------------------- /network/wider_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/mapillary/inplace_abn/ 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, mapillary 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | import logging 36 | import sys 37 | from collections import OrderedDict 38 | from functools import partial 39 | import torch.nn as nn 40 | import torch 41 | import network.mynn as mynn 42 | 43 | def bnrelu(channels): 44 | """ 45 | Single Layer BN and Relui 46 | """ 47 | return nn.Sequential(mynn.Norm2d(channels), 48 | nn.ReLU(inplace=True)) 49 | 50 | class GlobalAvgPool2d(nn.Module): 51 | """ 52 | Global average pooling over the input's spatial dimensions 53 | """ 54 | 55 | def __init__(self): 56 | super(GlobalAvgPool2d, self).__init__() 57 | logging.info("Global Average Pooling Initialized") 58 | 59 | def forward(self, inputs): 60 | in_size = inputs.size() 61 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 62 | 63 | 64 | class IdentityResidualBlock(nn.Module): 65 | """ 66 | Identity Residual Block for WideResnet 67 | """ 68 | def __init__(self, 69 | in_channels, 70 | channels, 71 | stride=1, 72 | dilation=1, 73 | groups=1, 74 | norm_act=bnrelu, 75 | dropout=None, 76 | dist_bn=False 77 | ): 78 | """Configurable identity-mapping residual block 79 | 80 | Parameters 81 | ---------- 82 | in_channels : int 83 | Number of input channels. 84 | channels : list of int 85 | Number of channels in the internal feature maps. 
86 | Can either have two or three elements: if three construct 87 | a residual block with two `3 x 3` convolutions, 88 | otherwise construct a bottleneck block with `1 x 1`, then 89 | `3 x 3` then `1 x 1` convolutions. 90 | stride : int 91 | Stride of the first `3 x 3` convolution 92 | dilation : int 93 | Dilation to apply to the `3 x 3` convolutions. 94 | groups : int 95 | Number of convolution groups. 96 | This is used to create ResNeXt-style blocks and is only compatible with 97 | bottleneck blocks. 98 | norm_act : callable 99 | Function to create normalization / activation Module. 100 | dropout: callable 101 | Function to create Dropout Module. 102 | dist_bn: Boolean 103 | A variable to enable or disable use of distributed BN 104 | """ 105 | super(IdentityResidualBlock, self).__init__() 106 | self.dist_bn = dist_bn 107 | 108 | # Check if we are using distributed BN and use the nn from encoding.nn 109 | # library rather than using standard pytorch.nn 110 | 111 | 112 | # Check parameters for inconsistencies 113 | if len(channels) != 2 and len(channels) != 3: 114 | raise ValueError("channels must contain either two or three values") 115 | if len(channels) == 2 and groups != 1: 116 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 117 | 118 | is_bottleneck = len(channels) == 3 119 | need_proj_conv = stride != 1 or in_channels != channels[-1] 120 | 121 | self.bn1 = norm_act(in_channels) 122 | if not is_bottleneck: 123 | layers = [ 124 | ("conv1", nn.Conv2d(in_channels, 125 | channels[0], 126 | 3, 127 | stride=stride, 128 | padding=dilation, 129 | bias=False, 130 | dilation=dilation)), 131 | ("bn2", norm_act(channels[0])), 132 | ("conv2", nn.Conv2d(channels[0], channels[1], 133 | 3, 134 | stride=1, 135 | padding=dilation, 136 | bias=False, 137 | dilation=dilation)) 138 | ] 139 | if dropout is not None: 140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 141 | else: 142 | layers = [ 143 | ("conv1", 144 | nn.Conv2d(in_channels, 145 | channels[0], 146 | 1, 147 | stride=stride, 148 | padding=0, 149 | bias=False)), 150 | ("bn2", norm_act(channels[0])), 151 | ("conv2", nn.Conv2d(channels[0], 152 | channels[1], 153 | 3, stride=1, 154 | padding=dilation, bias=False, 155 | groups=groups, 156 | dilation=dilation)), 157 | ("bn3", norm_act(channels[1])), 158 | ("conv3", nn.Conv2d(channels[1], channels[2], 159 | 1, stride=1, padding=0, bias=False)) 160 | ] 161 | if dropout is not None: 162 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 163 | self.convs = nn.Sequential(OrderedDict(layers)) 164 | 165 | if need_proj_conv: 166 | self.proj_conv = nn.Conv2d( 167 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 168 | 169 | def forward(self, x): 170 | """ 171 | This is the standard forward function for non-distributed batch norm 172 | """ 173 | if hasattr(self, "proj_conv"): 174 | bn1 = self.bn1(x) 175 | shortcut = self.proj_conv(bn1) 176 | else: 177 | shortcut = x.clone() 178 | bn1 = self.bn1(x) 179 | 180 | out = self.convs(bn1) 181 | out.add_(shortcut) 182 | return out 183 | 184 | 185 | 186 | 187 | class WiderResNet(nn.Module): 188 | """ 189 | WideResnet Global Module for Initialization 190 | """ 191 | def __init__(self, 192 | structure, 193 | norm_act=bnrelu, 194 | classes=0 195 | ): 196 | """Wider ResNet with pre-activation (identity mapping) blocks 197 | 198 | Parameters 199 | ---------- 200 | structure : list of int 201 | Number of residual blocks in each of the six modules of the network. 
202 | norm_act : callable 203 | Function to create normalization / activation Module. 204 | classes : int 205 | If not `0` also include global average pooling and \ 206 | a fully-connected layer with `classes` outputs at the end 207 | of the network. 208 | """ 209 | super(WiderResNet, self).__init__() 210 | self.structure = structure 211 | 212 | if len(structure) != 6: 213 | raise ValueError("Expected a structure with six values") 214 | 215 | # Initial layers 216 | self.mod1 = nn.Sequential(OrderedDict([ 217 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 218 | ])) 219 | 220 | # Groups of residual blocks 221 | in_channels = 64 222 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), 223 | (512, 1024, 2048), (1024, 2048, 4096)] 224 | for mod_id, num in enumerate(structure): 225 | # Create blocks for module 226 | blocks = [] 227 | for block_id in range(num): 228 | blocks.append(( 229 | "block%d" % (block_id + 1), 230 | IdentityResidualBlock(in_channels, channels[mod_id], 231 | norm_act=norm_act) 232 | )) 233 | 234 | # Update channels and p_keep 235 | in_channels = channels[mod_id][-1] 236 | 237 | # Create module 238 | if mod_id <= 4: 239 | self.add_module("pool%d" % 240 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 241 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 242 | 243 | # Pooling and predictor 244 | self.bn_out = norm_act(in_channels) 245 | if classes != 0: 246 | self.classifier = nn.Sequential(OrderedDict([ 247 | ("avg_pool", GlobalAvgPool2d()), 248 | ("fc", nn.Linear(in_channels, classes)) 249 | ])) 250 | 251 | def forward(self, img): 252 | out = self.mod1(img) 253 | out = self.mod2(self.pool2(out)) 254 | out = self.mod3(self.pool3(out)) 255 | out = self.mod4(self.pool4(out)) 256 | out = self.mod5(self.pool5(out)) 257 | out = self.mod6(self.pool6(out)) 258 | out = self.mod7(out) 259 | out = self.bn_out(out) 260 | 261 | if hasattr(self, "classifier"): 262 | out = self.classifier(out) 263 | 264 | return out 265 | 266 | 267 | class WiderResNetA2(nn.Module): 268 | """ 269 | Wider ResNet with pre-activation (identity mapping) blocks 270 | 271 | This variant uses down-sampling by max-pooling in the first two blocks and 272 | by strided convolution in the others. 273 | 274 | Parameters 275 | ---------- 276 | structure : list of int 277 | Number of residual blocks in each of the six modules of the network. 278 | norm_act : callable 279 | Function to create normalization / activation Module. 280 | classes : int 281 | If not `0` also include global average pooling and a fully-connected layer 282 | with `classes` outputs at the end 283 | of the network. 284 | dilation : bool 285 | If `True` apply dilation to the last three modules and change the 286 | down-sampling factor from 32 to 8. 
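# [Editor's note] A minimal usage sketch, assuming only the factories registered at the end
# of this file (wider_resnet38_a2 is partial(WiderResNetA2, structure=[3, 3, 6, 3, 1, 1])):
#
#     from network.wider_resnet import wider_resnet38_a2
#     trunk = wider_resnet38_a2(classes=0, dilation=True)   # no classifier head
#     features = trunk(img)   # img: NCHW tensor; output stride 8 instead of 32
#
# With dilation=True only the first block of mod4 is strided; mod5 uses dilation 2 and
# mod6/mod7 use dilation 4, exactly as set up in __init__ below.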
287 | """ 288 | def __init__(self, 289 | structure, 290 | norm_act=bnrelu, 291 | classes=0, 292 | dilation=False, 293 | dist_bn=False 294 | ): 295 | super(WiderResNetA2, self).__init__() 296 | self.dist_bn = dist_bn 297 | 298 | # If using distributed batch norm, use the encoding.nn as oppose to torch.nn 299 | 300 | 301 | nn.Dropout = nn.Dropout2d 302 | norm_act = bnrelu 303 | self.structure = structure 304 | self.dilation = dilation 305 | 306 | if len(structure) != 6: 307 | raise ValueError("Expected a structure with six values") 308 | 309 | # Initial layers 310 | self.mod1 = torch.nn.Sequential(OrderedDict([ 311 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 312 | ])) 313 | 314 | # Groups of residual blocks 315 | in_channels = 64 316 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 317 | (1024, 2048, 4096)] 318 | for mod_id, num in enumerate(structure): 319 | # Create blocks for module 320 | blocks = [] 321 | for block_id in range(num): 322 | if not dilation: 323 | dil = 1 324 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1 325 | else: 326 | if mod_id == 3: 327 | dil = 2 328 | elif mod_id > 3: 329 | dil = 4 330 | else: 331 | dil = 1 332 | stride = 2 if block_id == 0 and mod_id == 2 else 1 333 | 334 | if mod_id == 4: 335 | drop = partial(nn.Dropout, p=0.3) 336 | elif mod_id == 5: 337 | drop = partial(nn.Dropout, p=0.5) 338 | else: 339 | drop = None 340 | 341 | blocks.append(( 342 | "block%d" % (block_id + 1), 343 | IdentityResidualBlock(in_channels, 344 | channels[mod_id], norm_act=norm_act, 345 | stride=stride, dilation=dil, 346 | dropout=drop, dist_bn=self.dist_bn) 347 | )) 348 | 349 | # Update channels and p_keep 350 | in_channels = channels[mod_id][-1] 351 | 352 | # Create module 353 | if mod_id < 2: 354 | self.add_module("pool%d" % 355 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 356 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 357 | 358 | # Pooling and predictor 359 | self.bn_out = norm_act(in_channels) 360 | if classes != 0: 361 | self.classifier = nn.Sequential(OrderedDict([ 362 | ("avg_pool", GlobalAvgPool2d()), 363 | ("fc", nn.Linear(in_channels, classes)) 364 | ])) 365 | 366 | def forward(self, img): 367 | out = self.mod1(img) 368 | out = self.mod2(self.pool2(out)) 369 | out = self.mod3(self.pool3(out)) 370 | out = self.mod4(out) 371 | out = self.mod5(out) 372 | out = self.mod6(out) 373 | out = self.mod7(out) 374 | out = self.bn_out(out) 375 | 376 | if hasattr(self, "classifier"): 377 | return self.classifier(out) 378 | return out 379 | 380 | 381 | _NETS = { 382 | "16": {"structure": [1, 1, 1, 1, 1, 1]}, 383 | "20": {"structure": [1, 1, 1, 3, 1, 1]}, 384 | "38": {"structure": [3, 3, 6, 3, 1, 1]}, 385 | } 386 | 387 | __all__ = [] 388 | for name, params in _NETS.items(): 389 | net_name = "wider_resnet" + name 390 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params)) 391 | __all__.append(net_name) 392 | for name, params in _NETS.items(): 393 | net_name = "wider_resnet" + name + "_a2" 394 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params)) 395 | __all__.append(net_name) 396 | -------------------------------------------------------------------------------- /network/SEresnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/Cadene/pretrained-models.pytorch 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, Remi Cadene 8 | # All 
rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | import logging 36 | from collections import OrderedDict 37 | import math 38 | import torch.nn as nn 39 | from torch.utils import model_zoo 40 | import network.mynn as mynn 41 | 42 | __all__ = ['SENet', 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 43 | 44 | pretrained_settings = { 45 | 'se_resnext50_32x4d': { 46 | 'imagenet': { 47 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 48 | 'input_space': 'RGB', 49 | 'input_size': [3, 224, 224], 50 | 'input_range': [0, 1], 51 | 'mean': [0.485, 0.456, 0.406], 52 | 'std': [0.229, 0.224, 0.225], 53 | 'num_classes': 1000 54 | } 55 | }, 56 | 'se_resnext101_32x4d': { 57 | 'imagenet': { 58 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 59 | 'input_space': 'RGB', 60 | 'input_size': [3, 224, 224], 61 | 'input_range': [0, 1], 62 | 'mean': [0.485, 0.456, 0.406], 63 | 'std': [0.229, 0.224, 0.225], 64 | 'num_classes': 1000 65 | } 66 | }, 67 | } 68 | 69 | 70 | class SEModule(nn.Module): 71 | """ 72 | Sequeeze Excitation Module 73 | """ 74 | def __init__(self, channels, reduction): 75 | super(SEModule, self).__init__() 76 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 77 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 78 | padding=0) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 81 | padding=0) 82 | self.sigmoid = nn.Sigmoid() 83 | 84 | def forward(self, x): 85 | module_input = x 86 | x = self.avg_pool(x) 87 | x = self.fc1(x) 88 | x = self.relu(x) 89 | x = self.fc2(x) 90 | x = self.sigmoid(x) 91 | return module_input * x 92 | 93 | 94 | class Bottleneck(nn.Module): 95 | """ 96 | Base class for bottlenecks that implements `forward()` method. 
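# [Editor's note] A short aside on SEModule above, sketch only: it is the standard
# squeeze-and-excitation gate, i.e. global average pooling to one value per channel, a
# bottleneck of two 1x1 convolutions (channels -> channels // reduction -> channels) with a
# ReLU in between, and a sigmoid whose output rescales the input channel-wise.
#
#     import torch
#     se = SEModule(channels=256, reduction=16)
#     x = torch.randn(2, 256, 32, 32)
#     y = se(x)   # same shape as x; each channel scaled by a value in (0, 1)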
97 | """ 98 | def forward(self, x): 99 | residual = x 100 | 101 | out = self.conv1(x) 102 | out = self.bn1(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv2(out) 106 | out = self.bn2(out) 107 | out = self.relu(out) 108 | 109 | out = self.conv3(out) 110 | out = self.bn3(out) 111 | 112 | if self.downsample is not None: 113 | residual = self.downsample(x) 114 | 115 | out = self.se_module(out) + residual 116 | out = self.relu(out) 117 | 118 | return out 119 | 120 | 121 | class SEBottleneck(Bottleneck): 122 | """ 123 | Bottleneck for SENet154. 124 | """ 125 | expansion = 4 126 | 127 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 128 | downsample=None): 129 | super(SEBottleneck, self).__init__() 130 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 131 | self.bn1 = mynn.Norm2d(planes * 2) 132 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 133 | stride=stride, padding=1, groups=groups, 134 | bias=False) 135 | self.bn2 = mynn.Norm2d(planes * 4) 136 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 137 | bias=False) 138 | self.bn3 = mynn.Norm2d(planes * 4) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.se_module = SEModule(planes * 4, reduction=reduction) 141 | self.downsample = downsample 142 | self.stride = stride 143 | 144 | 145 | class SEResNetBottleneck(Bottleneck): 146 | """ 147 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 148 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 149 | (the latter is used in the torchvision implementation of ResNet). 150 | """ 151 | expansion = 4 152 | 153 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 154 | downsample=None): 155 | super(SEResNetBottleneck, self).__init__() 156 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 157 | stride=stride) 158 | self.bn1 = mynn.Norm2d(planes) 159 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 160 | groups=groups, bias=False) 161 | self.bn2 = mynn.Norm2d(planes) 162 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 163 | self.bn3 = mynn.Norm2d(planes * 4) 164 | self.relu = nn.ReLU(inplace=True) 165 | self.se_module = SEModule(planes * 4, reduction=reduction) 166 | self.downsample = downsample 167 | self.stride = stride 168 | 169 | 170 | class SEResNeXtBottleneck(Bottleneck): 171 | """ 172 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
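# [Editor's note] The width computation in __init__ below,
# width = math.floor(planes * (base_width / 64)) * groups, is what makes this the "32x4d"
# block. Worked example for the first stage of se_resnext50_32x4d / se_resnext101_32x4d
# (planes=64, base_width=4, groups=32):
#
#     width = floor(64 * 4 / 64) * 32 = 4 * 32 = 128
#
# so the grouped 3x3 convolution operates on 128 channels before the final 1x1 expands back
# to planes * 4 = 256.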
173 | """ 174 | expansion = 4 175 | 176 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 177 | downsample=None, base_width=4): 178 | super(SEResNeXtBottleneck, self).__init__() 179 | width = math.floor(planes * (base_width / 64)) * groups 180 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 181 | stride=1) 182 | self.bn1 = mynn.Norm2d(width) 183 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 184 | padding=1, groups=groups, bias=False) 185 | self.bn2 = mynn.Norm2d(width) 186 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 187 | self.bn3 = mynn.Norm2d(planes * 4) 188 | self.relu = nn.ReLU(inplace=True) 189 | self.se_module = SEModule(planes * 4, reduction=reduction) 190 | self.downsample = downsample 191 | self.stride = stride 192 | 193 | 194 | class SENet(nn.Module): 195 | """ 196 | Main Squeeze Excitation Network Module 197 | """ 198 | 199 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 200 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 201 | downsample_padding=1, num_classes=1000): 202 | """ 203 | Parameters 204 | ---------- 205 | block (nn.Module): Bottleneck class. 206 | - For SENet154: SEBottleneck 207 | - For SE-ResNet models: SEResNetBottleneck 208 | - For SE-ResNeXt models: SEResNeXtBottleneck 209 | layers (list of ints): Number of residual blocks for 4 layers of the 210 | network (layer1...layer4). 211 | groups (int): Number of groups for the 3x3 convolution in each 212 | bottleneck block. 213 | - For SENet154: 64 214 | - For SE-ResNet models: 1 215 | - For SE-ResNeXt models: 32 216 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 217 | - For all models: 16 218 | dropout_p (float or None): Drop probability for the Dropout layer. 219 | If `None` the Dropout layer is not used. 220 | - For SENet154: 0.2 221 | - For SE-ResNet models: None 222 | - For SE-ResNeXt models: None 223 | inplanes (int): Number of input channels for layer1. 224 | - For SENet154: 128 225 | - For SE-ResNet models: 64 226 | - For SE-ResNeXt models: 64 227 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 228 | a single 7x7 convolution in layer0. 229 | - For SENet154: True 230 | - For SE-ResNet models: False 231 | - For SE-ResNeXt models: False 232 | downsample_kernel_size (int): Kernel size for downsampling convolutions 233 | in layer2, layer3 and layer4. 234 | - For SENet154: 3 235 | - For SE-ResNet models: 1 236 | - For SE-ResNeXt models: 1 237 | downsample_padding (int): Padding for downsampling convolutions in 238 | layer2, layer3 and layer4. 239 | - For SENet154: 1 240 | - For SE-ResNet models: 0 241 | - For SE-ResNeXt models: 0 242 | num_classes (int): Number of outputs in `last_linear` layer. 
243 | - For all models: 1000 244 | """ 245 | super(SENet, self).__init__() 246 | self.inplanes = inplanes 247 | if input_3x3: 248 | layer0_modules = [ 249 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 250 | bias=False)), 251 | ('bn1', mynn.Norm2d(64)), 252 | ('relu1', nn.ReLU(inplace=True)), 253 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 254 | bias=False)), 255 | ('bn2', mynn.Norm2d(64)), 256 | ('relu2', nn.ReLU(inplace=True)), 257 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 258 | bias=False)), 259 | ('bn3', mynn.Norm2d(inplanes)), 260 | ('relu3', nn.ReLU(inplace=True)), 261 | ] 262 | else: 263 | layer0_modules = [ 264 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 265 | padding=3, bias=False)), 266 | ('bn1', mynn.Norm2d(inplanes)), 267 | ('relu1', nn.ReLU(inplace=True)), 268 | ] 269 | # To preserve compatibility with Caffe weights `ceil_mode=True` 270 | # is used instead of `padding=1`. 271 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 272 | ceil_mode=True))) 273 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 274 | self.layer1 = self._make_layer( 275 | block, 276 | planes=64, 277 | blocks=layers[0], 278 | groups=groups, 279 | reduction=reduction, 280 | downsample_kernel_size=1, 281 | downsample_padding=0 282 | ) 283 | self.layer2 = self._make_layer( 284 | block, 285 | planes=128, 286 | blocks=layers[1], 287 | stride=2, 288 | groups=groups, 289 | reduction=reduction, 290 | downsample_kernel_size=downsample_kernel_size, 291 | downsample_padding=downsample_padding 292 | ) 293 | self.layer3 = self._make_layer( 294 | block, 295 | planes=256, 296 | blocks=layers[2], 297 | stride=1, 298 | groups=groups, 299 | reduction=reduction, 300 | downsample_kernel_size=downsample_kernel_size, 301 | downsample_padding=downsample_padding 302 | ) 303 | self.layer4 = self._make_layer( 304 | block, 305 | planes=512, 306 | blocks=layers[3], 307 | stride=1, 308 | groups=groups, 309 | reduction=reduction, 310 | downsample_kernel_size=downsample_kernel_size, 311 | downsample_padding=downsample_padding 312 | ) 313 | self.avg_pool = nn.AvgPool2d(7, stride=1) 314 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 315 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 316 | 317 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 318 | downsample_kernel_size=1, downsample_padding=0): 319 | downsample = None 320 | if stride != 1 or self.inplanes != planes * block.expansion: 321 | downsample = nn.Sequential( 322 | nn.Conv2d(self.inplanes, planes * block.expansion, 323 | kernel_size=downsample_kernel_size, stride=stride, 324 | padding=downsample_padding, bias=False), 325 | mynn.Norm2d(planes * block.expansion), 326 | ) 327 | 328 | layers = [] 329 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 330 | downsample)) 331 | self.inplanes = planes * block.expansion 332 | for index in range(1, blocks): 333 | layers.append(block(self.inplanes, planes, groups, reduction)) 334 | 335 | return nn.Sequential(*layers) 336 | 337 | def features(self, x): 338 | """ 339 | Forward Pass through the each layer of SE network 340 | """ 341 | x = self.layer0(x) 342 | x = self.layer1(x) 343 | x = self.layer2(x) 344 | x = self.layer3(x) 345 | x = self.layer4(x) 346 | return x 347 | 348 | def logits(self, x): 349 | """ 350 | AvgPool and Linear Layer 351 | """ 352 | x = self.avg_pool(x) 353 | if self.dropout is not None: 354 | x = self.dropout(x) 355 | x = x.view(x.size(0), -1) 356 | x = 
self.last_linear(x) 357 | return x 358 | 359 | def forward(self, x): 360 | x = self.features(x) 361 | x = self.logits(x) 362 | return x 363 | 364 | 365 | def initialize_pretrained_model(model, num_classes, settings): 366 | """ 367 | Initialize Pretrain Model Information, 368 | Dowload weights, load weights, set variables 369 | """ 370 | assert num_classes == settings['num_classes'], \ 371 | 'num_classes should be {}, but is {}'.format( 372 | settings['num_classes'], num_classes) 373 | weights = model_zoo.load_url(settings['url']) 374 | model.load_state_dict(weights) 375 | model.input_space = settings['input_space'] 376 | model.input_size = settings['input_size'] 377 | model.input_range = settings['input_range'] 378 | model.mean = settings['mean'] 379 | model.std = settings['std'] 380 | 381 | 382 | 383 | def se_resnext50_32x4d(num_classes=1000): 384 | """ 385 | Defination For SE Resnext50 386 | """ 387 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 388 | dropout_p=None, inplanes=64, input_3x3=False, 389 | downsample_kernel_size=1, downsample_padding=0, 390 | num_classes=num_classes) 391 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet'] 392 | initialize_pretrained_model(model, num_classes, settings) 393 | return model 394 | 395 | 396 | def se_resnext101_32x4d(num_classes=1000): 397 | """ 398 | Defination For SE Resnext101 399 | """ 400 | 401 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 402 | dropout_p=None, inplanes=64, input_3x3=False, 403 | downsample_kernel_size=1, downsample_padding=0, 404 | num_classes=num_classes) 405 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet'] 406 | initialize_pretrained_model(model, num_classes, settings) 407 | return model 408 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pytorch Optimizer and Scheduler Related Task 3 | """ 4 | import math 5 | import logging 6 | import torch 7 | from torch import optim 8 | from config import cfg 9 | 10 | 11 | def get_optimizer(args, net): 12 | """ 13 | Decide Optimizer (Adam or SGD) 14 | """ 15 | if args.backbone_lr > 0.0: 16 | base_params = [] 17 | resnet_params = [] 18 | resnet_name = [] 19 | resnet_name.append('layer0') 20 | resnet_name.append('layer1') 21 | #resnet_name.append('layer2') 22 | #resnet_name.append('layer3') 23 | #resnet_name.append('layer4') 24 | len_resnet = len(resnet_name) 25 | else: 26 | param_groups = net.parameters() 27 | 28 | if args.backbone_lr > 0.0: 29 | for name, param in net.named_parameters(): 30 | is_resnet = False 31 | for i in range(len_resnet): 32 | if resnet_name[i] in name: 33 | resnet_params.append(param) 34 | # param.requires_grad=False 35 | print("resnet_name", name) 36 | is_resnet = True 37 | break 38 | if not is_resnet: 39 | base_params.append(param) 40 | 41 | if args.sgd: 42 | if args.backbone_lr > 0.0: 43 | optimizer = optim.SGD([ 44 | {'params': base_params}, 45 | {'params': resnet_params, 'lr':args.backbone_lr} 46 | ], 47 | lr=args.lr, 48 | weight_decay=5e-4, #args.weight_decay, 49 | momentum=args.momentum, 50 | nesterov=False) 51 | else: 52 | optimizer = optim.SGD(param_groups, 53 | lr=args.lr, 54 | weight_decay=5e-4, #args.weight_decay, 55 | momentum=args.momentum, 56 | nesterov=False) 57 | else: 58 | raise ValueError('Not a valid optimizer') 59 | 60 | if args.lr_schedule == 'scl-poly': 61 | if cfg.REDUCE_BORDER_ITER == -1: 62 | raise 
ValueError('ERROR Cannot Do Scale Poly') 63 | 64 | rescale_thresh = cfg.REDUCE_BORDER_ITER 65 | scale_value = args.rescale 66 | lambda1 = lambda iteration: \ 67 | math.pow(1 - iteration / args.max_iter, 68 | args.poly_exp) if iteration < rescale_thresh else scale_value * math.pow( 69 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh), 70 | args.repoly) 71 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 72 | elif args.lr_schedule == 'poly': 73 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp) 74 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 75 | else: 76 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 77 | 78 | return optimizer, scheduler 79 | 80 | 81 | def get_optimizer_attention(args, net): 82 | """ 83 | Decide Optimizer (Adam or SGD) 84 | """ 85 | attention_params = [] 86 | base_params = [] 87 | hanet_name = [] 88 | 89 | if args.backbone_lr > 0.0: 90 | resnet_params = [] 91 | resnet_name = [] 92 | resnet_name.append('layer0') 93 | resnet_name.append('layer1') 94 | #resnet_name.append('layer2') 95 | #resnet_name.append('layer3') 96 | #resnet_name.append('layer4') 97 | len_resnet = len(resnet_name) 98 | 99 | for i in range(5): 100 | if args.hanet[i] > 0: # HANet_Diff 101 | hanet_name.append('hanet' + str(i)) 102 | 103 | len_hanet = len(hanet_name) 104 | 105 | for name, param in net.named_parameters(): 106 | is_hanet = False 107 | is_resnet = False 108 | if args.backbone_lr > 0.0: 109 | for i in range(len_resnet): 110 | if resnet_name[i] in name: 111 | resnet_params.append(param) 112 | # param.requires_grad=False 113 | print("resnet_name", name) 114 | is_resnet = True 115 | break 116 | if not is_resnet: 117 | for i in range(len_hanet): 118 | if hanet_name[i] in name: 119 | attention_params.append(param) 120 | #print("hanet_name", name) 121 | is_hanet = True 122 | break 123 | if not is_hanet and not is_resnet: 124 | base_params.append(param) 125 | #print("base", name) 126 | 127 | if args.sgd: 128 | if args.backbone_lr > 0.0: 129 | optimizer = optim.SGD([ 130 | {'params': base_params}, 131 | {'params': resnet_params, 'lr':args.backbone_lr} 132 | ], 133 | lr=args.lr, 134 | weight_decay=5e-4, #args.weight_decay, 135 | momentum=args.momentum, 136 | nesterov=False) 137 | else: 138 | optimizer = optim.SGD(base_params, 139 | lr=args.lr, 140 | weight_decay=5e-4, #args.weight_decay, 141 | momentum=args.momentum, 142 | nesterov=False) 143 | else: 144 | raise ValueError('Not a valid optimizer') 145 | 146 | print(" ############# HANet Number", len_hanet) 147 | optimizer_at = optim.SGD(attention_params, 148 | lr=args.hanet_lr, 149 | weight_decay=args.hanet_wd, 150 | momentum=args.momentum, 151 | nesterov=False) 152 | 153 | 154 | 155 | if args.lr_schedule == 'scl-poly': 156 | if cfg.REDUCE_BORDER_ITER == -1: 157 | raise ValueError('ERROR Cannot Do Scale Poly') 158 | 159 | rescale_thresh = cfg.REDUCE_BORDER_ITER 160 | scale_value = args.rescale 161 | lambda1 = lambda iteration: \ 162 | math.pow(1 - iteration / args.max_iter, 163 | args.poly_exp) if iteration <= rescale_thresh else scale_value * math.pow( 164 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh), 165 | args.repoly) 166 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 167 | 168 | if args.hanet_poly_exp > 0.0: 169 | lambda2 = lambda iteration: \ 170 | math.pow(1 - iteration / args.max_iter, 171 | args.hanet_poly_exp) if iteration <= rescale_thresh else scale_value * math.pow( 172 
| 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh), 173 | args.repoly) 174 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2) 175 | else: 176 | lambda2 = lambda iteration: \ 177 | math.pow(1 - iteration / args.max_iter, 178 | args.poly_exp) if iteration <= rescale_thresh else scale_value * math.pow( 179 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh), 180 | args.repoly) 181 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2) 182 | 183 | elif args.lr_schedule == 'poly': 184 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp) 185 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 186 | 187 | # for attention module 188 | if args.hanet_poly_exp > 0.0: 189 | lambda2 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.hanet_poly_exp) 190 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2) 191 | else: 192 | lambda2 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp) 193 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2) 194 | else: 195 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 196 | 197 | return optimizer, scheduler, optimizer_at, scheduler_at 198 | 199 | 200 | def get_optimizer_by_epoch(args, net): 201 | """ 202 | Decide Optimizer (Adam or SGD) 203 | """ 204 | param_groups = net.parameters() 205 | 206 | if args.sgd: 207 | optimizer = optim.SGD(param_groups, 208 | lr=args.lr, 209 | weight_decay=args.weight_decay, 210 | momentum=args.momentum, 211 | nesterov=False) 212 | elif args.adam: 213 | amsgrad = False 214 | if args.amsgrad: 215 | amsgrad = True 216 | optimizer = optim.Adam(param_groups, 217 | lr=args.lr, 218 | weight_decay=args.weight_decay, 219 | amsgrad=amsgrad 220 | ) 221 | else: 222 | raise ValueError('Not a valid optimizer') 223 | 224 | if args.lr_schedule == 'scl-poly': 225 | if cfg.REDUCE_BORDER_EPOCH == -1: 226 | raise ValueError('ERROR Cannot Do Scale Poly') 227 | 228 | rescale_thresh = cfg.REDUCE_BORDER_EPOCH 229 | scale_value = args.rescale 230 | lambda1 = lambda epoch: \ 231 | math.pow(1 - epoch / args.max_epoch, 232 | args.poly_exp) if epoch < rescale_thresh else scale_value * math.pow( 233 | 1 - (epoch - rescale_thresh) / (args.max_epoch - rescale_thresh), 234 | args.repoly) 235 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 236 | elif args.lr_schedule == 'poly': 237 | lambda1 = lambda epoch: math.pow(1 - epoch / args.max_epoch, args.poly_exp) 238 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 239 | else: 240 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 241 | 242 | return optimizer, scheduler 243 | 244 | def load_weights_hanet(net, optimizer, optimizer_at, scheduler, scheduler_at, snapshot_file, restore_optimizer_bool=False): 245 | """ 246 | Load weights from snapshot file 247 | """ 248 | logging.info("Loading weights from model %s", snapshot_file) 249 | net, optimizer, optimizer_at, scheduler, scheduler_at, epoch, mean_iu = restore_snapshot_hanet(net, optimizer, 250 | optimizer_at, scheduler, scheduler_at, snapshot_file, restore_optimizer_bool) 251 | return epoch, mean_iu 252 | 253 | def load_weights_pe(net, snapshot_file): 254 | """ 255 | Load weights from snapshot file 256 | """ 257 | logging.info("Loading weights from model %s", snapshot_file) 258 | net = restore_snapshot_pe(net, snapshot_file) 259 | 260 | 261 | def load_weights(net, 
optimizer, scheduler, snapshot_file, restore_optimizer_bool=False): 262 | """ 263 | Load weights from snapshot file 264 | """ 265 | logging.info("Loading weights from model %s", snapshot_file) 266 | net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file, 267 | restore_optimizer_bool) 268 | return epoch, mean_iu 269 | 270 | def restore_snapshot_pe(net, snapshot): 271 | """ 272 | Restore weights and optimizer (if needed) for resuming a job. 273 | """ 274 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 275 | logging.info("Checkpoint PE Load Complete") 276 | 277 | if 'state_dict' in checkpoint: 278 | net = forgiving_state_restore_only_pe(net, checkpoint['state_dict']) 279 | else: 280 | net = forgiving_state_restore_only_pe(net, checkpoint) 281 | 282 | return net 283 | 284 | def forgiving_state_restore_only_pe(net, loaded_dict): 285 | """ 286 | Handle partial loading when some tensors don't match up in size. 287 | Only the position-embedding ('pos_emb1d') tensors are restored here; 288 | all other weights are left untouched. 289 | """ 290 | net_state_dict = net.state_dict() 291 | new_loaded_dict = {} 292 | for k in net_state_dict: 293 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 294 | if 'pos_emb1d' in k: 295 | print("matched loading parameter", k) 296 | new_loaded_dict[k] = loaded_dict[k] 297 | # else: 298 | # print("Skipped loading parameter", k) 299 | # logging.info("Skipped loading parameter %s", k) 300 | net_state_dict.update(new_loaded_dict) 301 | net.load_state_dict(net_state_dict) 302 | return net 303 | 304 | def freeze_pe(net): 305 | for name, param in net.named_parameters(): 306 | if 'pos_emb1d' in name: 307 | print("freeze parameter", name) 308 | param.requires_grad = False 309 | 310 | def restore_snapshot_hanet(net, optimizer, optimizer_at, scheduler, scheduler_at, snapshot, restore_optimizer_bool): 311 | """ 312 | Restore weights and optimizer (if needed) for resuming a job. 313 | """ 314 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 315 | logging.info("Checkpoint Load Complete") 316 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 317 | optimizer.load_state_dict(checkpoint['optimizer']) 318 | 319 | if optimizer_at is not None and 'optimizer_at' in checkpoint and restore_optimizer_bool: 320 | optimizer_at.load_state_dict(checkpoint['optimizer_at']) 321 | 322 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool: 323 | scheduler.load_state_dict(checkpoint['scheduler']) 324 | 325 | if scheduler_at is not None and 'scheduler_at' in checkpoint and restore_optimizer_bool: 326 | scheduler_at.load_state_dict(checkpoint['scheduler_at']) 327 | 328 | if 'state_dict' in checkpoint: 329 | net = forgiving_state_restore(net, checkpoint['state_dict']) 330 | else: 331 | net = forgiving_state_restore(net, checkpoint) 332 | 333 | return net, optimizer, optimizer_at, scheduler, scheduler_at, checkpoint['epoch'], checkpoint['mean_iu'] 334 | 335 | def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool): 336 | """ 337 | Restore weights and optimizer (if needed) for resuming a job. 
338 | """ 339 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 340 | logging.info("Checkpoint Load Compelete") 341 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 342 | optimizer.load_state_dict(checkpoint['optimizer']) 343 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool: 344 | scheduler.load_state_dict(checkpoint['scheduler']) 345 | 346 | if 'state_dict' in checkpoint: 347 | net = forgiving_state_restore(net, checkpoint['state_dict']) 348 | else: 349 | net = forgiving_state_restore(net, checkpoint) 350 | 351 | return net, optimizer, scheduler, checkpoint['epoch'], checkpoint['mean_iu'] 352 | 353 | 354 | def forgiving_state_restore(net, loaded_dict): 355 | """ 356 | Handle partial loading when some tensors don't match up in size. 357 | Because we want to use models that were trained off a different 358 | number of classes. 359 | """ 360 | net_state_dict = net.state_dict() 361 | new_loaded_dict = {} 362 | for k in net_state_dict: 363 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 364 | new_loaded_dict[k] = loaded_dict[k] 365 | else: 366 | print("Skipped loading parameter", k) 367 | # logging.info("Skipped loading parameter %s", k) 368 | net_state_dict.update(new_loaded_dict) 369 | net.load_state_dict(net_state_dict) 370 | return net 371 | 372 | def forgiving_state_copy(target_net, source_net): 373 | """ 374 | Handle partial loading when some tensors don't match up in size. 375 | Because we want to use models that were trained off a different 376 | number of classes. 377 | """ 378 | net_state_dict = target_net.state_dict() 379 | loaded_dict = source_net.state_dict() 380 | new_loaded_dict = {} 381 | for k in net_state_dict: 382 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 383 | new_loaded_dict[k] = loaded_dict[k] 384 | print("Matched", k) 385 | else: 386 | print("Skipped loading parameter ", k) 387 | # logging.info("Skipped loading parameter %s", k) 388 | net_state_dict.update(new_loaded_dict) 389 | target_net.load_state_dict(net_state_dict) 390 | return target_net 391 | -------------------------------------------------------------------------------- /datasets/camvid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Camvid Dataset Loader 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | from PIL import Image 9 | from torch.utils import data 10 | import logging 11 | import datasets.uniform as uniform 12 | import json 13 | from config import cfg 14 | import copy 15 | import torch 16 | 17 | 18 | # trainid_to_name = cityscapes_labels.trainId2name 19 | # id_to_trainid = cityscapes_labels.label2trainid 20 | num_classes = 11 21 | ignore_label = 11 22 | root = cfg.DATASET.CAMVID_DIR 23 | 24 | palette = [128, 128, 128, 25 | 128, 0, 0, 26 | 192, 192, 128, 27 | 128, 64, 128, 28 | 0, 0, 192, 29 | 128, 128, 0, 30 | 192, 128, 128, 31 | 64, 64, 128, 32 | 64, 0, 128, 33 | 64, 64, 0, 34 | 0, 128, 192] 35 | 36 | 37 | CAMVID_CLASSES = ['Sky', 38 | 'Building', 39 | 'Column-Pole', 40 | 'Road', 41 | 'Sidewalk', 42 | 'Tree', 43 | 'Sign-Symbol', 44 | 'Fence', 45 | 'Car', 46 | 'Pedestrain', 47 | 'Bicyclist', 48 | 'Void'] 49 | 50 | CAMVID_CLASS_COLORS = [ 51 | (128, 128, 128), 52 | (128, 0, 0), 53 | (192, 192, 128), 54 | (128, 64, 128), 55 | (0, 0, 192), 56 | (128, 128, 0), 57 | (192, 128, 128), 58 | (64, 64, 128), 59 | (64, 0, 128), 60 | (64, 64, 0), 61 | (0, 128, 192), 62 | (0, 0, 0), 63 | ] 64 | 65 
| zero_pad = 256 * 3 - len(palette) 66 | for i in range(zero_pad): 67 | palette.append(0) 68 | 69 | def colorize_mask(mask): 70 | # mask: numpy array of the mask 71 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 72 | new_mask.putpalette(palette) 73 | return new_mask 74 | 75 | def add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip): 76 | 77 | c_items = os.listdir(img_path) 78 | c_items.sort() 79 | items = [] 80 | aug_items = [] 81 | 82 | for it in c_items: 83 | item = (os.path.join(img_path, it), os.path.join(mask_path, it)) 84 | items.append(item) 85 | if mode != 'test' and maxSkip > 0: 86 | seq_info = it.split("_") 87 | cur_seq_id = seq_info[-1][:-4] 88 | 89 | if seq_info[0] == "0001TP": 90 | prev_seq_id = "%06d" % (int(cur_seq_id) - maxSkip) 91 | next_seq_id = "%06d" % (int(cur_seq_id) + maxSkip) 92 | elif seq_info[0] == "0006R0": 93 | prev_seq_id = "f%05d" % (int(cur_seq_id[1:]) - maxSkip) 94 | next_seq_id = "f%05d" % (int(cur_seq_id[1:]) + maxSkip) 95 | else: 96 | prev_seq_id = "%05d" % (int(cur_seq_id) - maxSkip) 97 | next_seq_id = "%05d" % (int(cur_seq_id) + maxSkip) 98 | 99 | prev_it = seq_info[0] + "_" + prev_seq_id + '.png' 100 | next_it = seq_info[0] + "_" + next_seq_id + '.png' 101 | 102 | prev_item = (os.path.join(aug_img_path, prev_it), os.path.join(aug_mask_path, prev_it)) 103 | next_item = (os.path.join(aug_img_path, next_it), os.path.join(aug_mask_path, next_it)) 104 | if os.path.isfile(prev_item[0]) and os.path.isfile(prev_item[1]): 105 | aug_items.append(prev_item) 106 | if os.path.isfile(next_item[0]) and os.path.isfile(next_item[1]): 107 | aug_items.append(next_item) 108 | return items, aug_items 109 | 110 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0): 111 | 112 | items = [] 113 | aug_items = [] 114 | assert quality == 'semantic' 115 | assert mode in ['train', 'val', 'trainval', 'test'] 116 | 117 | # img_dir_name = "SegNet/CamVid" 118 | original_img_dir = "" 119 | augmented_img_dir = "camvid_aug3/CamVid" 120 | 121 | img_path = os.path.join(root, original_img_dir, 'train') 122 | mask_path = os.path.join(root, original_img_dir, 'trainannot') 123 | aug_img_path = os.path.join(root, augmented_img_dir, 'train') 124 | aug_mask_path = os.path.join(root, augmented_img_dir, 'trainannot') 125 | 126 | train_items, train_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip) 127 | logging.info('Camvid has a total of {} train images'.format(len(train_items))) 128 | 129 | img_path = os.path.join(root, original_img_dir, 'val') 130 | mask_path = os.path.join(root, original_img_dir, 'valannot') 131 | aug_img_path = os.path.join(root, augmented_img_dir, 'val') 132 | aug_mask_path = os.path.join(root, augmented_img_dir, 'valannot') 133 | 134 | val_items, val_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip) 135 | logging.info('Camvid has a total of {} validation images'.format(len(val_items))) 136 | 137 | if mode == 'test': 138 | img_path = os.path.join(root, original_img_dir, 'test') 139 | mask_path = os.path.join(root, original_img_dir, 'testannot') 140 | test_items, test_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip) 141 | logging.info('Camvid has a total of {} test images'.format(len(test_items))) 142 | 143 | if mode == 'train': 144 | items = train_items 145 | elif mode == 'val': 146 | items = val_items 147 | elif mode == 'trainval': 148 | items = train_items + val_items 149 | aug_items = []#train_aug_items + 
val_aug_items 150 | elif mode == 'test': 151 | items = test_items 152 | aug_items = [] 153 | else: 154 | logging.info('Unknown mode {}'.format(mode)) 155 | sys.exit() 156 | 157 | logging.info('Camvid-{}: {} images'.format(mode, len(items))) 158 | 159 | return items, aug_items 160 | 161 | class CAMVID(data.Dataset): 162 | 163 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 164 | transform=None, target_transform=None, dump_images=False, 165 | class_uniform_pct=0, class_uniform_tile=0, test=False, 166 | cv_split=None, scf=None, hardnm=0): 167 | 168 | self.quality = quality 169 | self.mode = mode 170 | self.maxSkip = maxSkip 171 | self.joint_transform_list = joint_transform_list 172 | self.transform = transform 173 | self.target_transform = target_transform 174 | self.dump_images = dump_images 175 | self.class_uniform_pct = class_uniform_pct 176 | self.class_uniform_tile = class_uniform_tile 177 | self.scf = scf 178 | self.hardnm = hardnm 179 | self.cv_split = cv_split 180 | self.centroids = [] 181 | 182 | self.imgs, self.aug_imgs = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 183 | assert len(self.imgs), 'Found 0 images, please check the data set' 184 | 185 | # Centroids for GT data 186 | if self.class_uniform_pct > 0: 187 | json_fn = 'camvid_tile{}_cv{}_{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode) 188 | 189 | if os.path.isfile(json_fn): 190 | with open(json_fn, 'r') as json_data: 191 | centroids = json.load(json_data) 192 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 193 | else: 194 | self.centroids = uniform.class_centroids_all( 195 | self.imgs, 196 | num_classes, 197 | id2trainid=None, 198 | tile_size=class_uniform_tile) 199 | with open(json_fn, 'w') as outfile: 200 | json.dump(self.centroids, outfile, indent=4) 201 | 202 | self.fine_centroids = copy.deepcopy(self.centroids) 203 | 204 | # if self.maxSkip > 0: 205 | # json_fn = 'camvid_tile{}_cv{}_{}_skip{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.maxSkip) 206 | # if os.path.isfile(json_fn): 207 | # with open(json_fn, 'r') as json_data: 208 | # centroids = json.load(json_data) 209 | # self.aug_centroids = {int(idx): centroids[idx] for idx in centroids} 210 | # else: 211 | # self.aug_centroids = uniform.class_centroids_all( 212 | # self.aug_imgs, 213 | # num_classes, 214 | # id2trainid=None, 215 | # tile_size=class_uniform_tile) 216 | # with open(json_fn, 'w') as outfile: 217 | # json.dump(self.aug_centroids, outfile, indent=4) 218 | 219 | # for class_id in range(num_classes): 220 | # self.centroids[class_id].extend(self.aug_centroids[class_id]) 221 | 222 | self.build_epoch() 223 | 224 | def build_epoch(self, cut=False): 225 | 226 | if self.class_uniform_pct > 0: 227 | if cut: 228 | self.imgs_uniform = uniform.build_epoch(self.imgs, 229 | self.fine_centroids, 230 | num_classes, 231 | cfg.CLASS_UNIFORM_PCT) 232 | else: 233 | self.imgs_uniform = uniform.build_epoch(self.imgs, 234 | self.centroids, 235 | num_classes, 236 | cfg.CLASS_UNIFORM_PCT) 237 | else: 238 | self.imgs_uniform = self.imgs 239 | 240 | 241 | def __getitem__(self, index): 242 | elem = self.imgs_uniform[index] 243 | centroid = None 244 | if len(elem) == 4: 245 | img_path, mask_path, centroid, class_id = elem 246 | else: 247 | img_path, mask_path = elem 248 | 249 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 250 | img_name = os.path.splitext(os.path.basename(img_path))[0] 251 | 252 | # Image Transformations 253 | if 
self.joint_transform_list is not None: 254 | for idx, xform in enumerate(self.joint_transform_list): 255 | if idx == 0 and centroid is not None: 256 | # HACK 257 | # We assume that the first transform is capable of taking 258 | # in a centroid 259 | img, mask = xform(img, mask, centroid) 260 | else: 261 | img, mask = xform(img, mask) 262 | 263 | # Debug 264 | if self.dump_images and centroid is not None: 265 | outdir = './dump_imgs_{}'.format(self.mode) 266 | os.makedirs(outdir, exist_ok=True) 267 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 268 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 269 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 270 | mask_img = colorize_mask(np.array(mask)) 271 | img.save(out_img_fn) 272 | mask_img.save(out_msk_fn) 273 | 274 | if self.transform is not None: 275 | img = self.transform(img) 276 | if self.target_transform is not None: 277 | mask = self.target_transform(mask) 278 | 279 | return img, mask, img_name 280 | 281 | def __len__(self): 282 | return len(self.imgs_uniform) 283 | 284 | 285 | class CAMVIDWithPos(data.Dataset): 286 | 287 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 288 | transform=None, target_transform=None, target_aux_transform=None, dump_images=False, 289 | class_uniform_pct=0, class_uniform_tile=0, test=False, 290 | cv_split=None, scf=None, hardnm=0, pos_rfactor=8): 291 | 292 | self.quality = quality 293 | self.mode = mode 294 | self.maxSkip = maxSkip 295 | self.joint_transform_list = joint_transform_list 296 | self.transform = transform 297 | self.target_transform = target_transform 298 | self.target_aux_transform = target_aux_transform 299 | self.dump_images = dump_images 300 | self.class_uniform_pct = class_uniform_pct 301 | self.class_uniform_tile = class_uniform_tile 302 | self.scf = scf 303 | self.hardnm = hardnm 304 | self.cv_split = cv_split 305 | self.centroids = [] 306 | self.pos_rfactor = pos_rfactor 307 | 308 | # position information 309 | self.pos_h = torch.arange(0, 1024).unsqueeze(0).unsqueeze(2).expand(-1,-1,2048)//8 310 | self.pos_w = torch.arange(0, 2048).unsqueeze(0).unsqueeze(1).expand(-1,1024,-1)//16 311 | self.pos_h = self.pos_h[0].byte().numpy() 312 | self.pos_w = self.pos_w[0].byte().numpy() 313 | # pos index to image 314 | self.pos_h = Image.fromarray(self.pos_h, mode="L") 315 | self.pos_w = Image.fromarray(self.pos_w, mode="L") 316 | # position information 317 | 318 | 319 | self.imgs, self.aug_imgs = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 320 | assert len(self.imgs), 'Found 0 images, please check the data set' 321 | 322 | # Centroids for GT data 323 | if self.class_uniform_pct > 0: 324 | json_fn = 'camvid_tile{}_cv{}_{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode) 325 | 326 | if os.path.isfile(json_fn): 327 | with open(json_fn, 'r') as json_data: 328 | centroids = json.load(json_data) 329 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 330 | else: 331 | self.centroids = uniform.class_centroids_all( 332 | self.imgs, 333 | num_classes, 334 | id2trainid=None, 335 | tile_size=class_uniform_tile) 336 | with open(json_fn, 'w') as outfile: 337 | json.dump(self.centroids, outfile, indent=4) 338 | 339 | self.fine_centroids = copy.deepcopy(self.centroids) 340 | 341 | if self.maxSkip > 0: 342 | json_fn = 'camvid_tile{}_cv{}_{}_skip{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.maxSkip) 343 | if os.path.isfile(json_fn): 344 | with 
open(json_fn, 'r') as json_data: 345 | centroids = json.load(json_data) 346 | self.aug_centroids = {int(idx): centroids[idx] for idx in centroids} 347 | else: 348 | self.aug_centroids = uniform.class_centroids_all( 349 | self.aug_imgs, 350 | num_classes, 351 | id2trainid=None, 352 | tile_size=class_uniform_tile) 353 | with open(json_fn, 'w') as outfile: 354 | json.dump(self.aug_centroids, outfile, indent=4) 355 | 356 | for class_id in range(num_classes): 357 | self.centroids[class_id].extend(self.aug_centroids[class_id]) 358 | 359 | self.build_epoch() 360 | 361 | def build_epoch(self, cut=False): 362 | 363 | if self.class_uniform_pct > 0: 364 | if cut: 365 | self.imgs_uniform = uniform.build_epoch(self.imgs, 366 | self.fine_centroids, 367 | num_classes, 368 | cfg.CLASS_UNIFORM_PCT) 369 | else: 370 | self.imgs_uniform = uniform.build_epoch(self.imgs, 371 | self.centroids, 372 | num_classes, 373 | cfg.CLASS_UNIFORM_PCT) 374 | else: 375 | self.imgs_uniform = self.imgs 376 | 377 | 378 | def __getitem__(self, index): 379 | elem = self.imgs_uniform[index] 380 | centroid = None 381 | if len(elem) == 4: 382 | img_path, mask_path, centroid, class_id = elem 383 | else: 384 | img_path, mask_path = elem 385 | 386 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 387 | img_name = os.path.splitext(os.path.basename(img_path))[0] 388 | 389 | # position information 390 | pos_h = self.pos_h 391 | pos_w = self.pos_w 392 | # position information 393 | 394 | # Image Transformations 395 | if self.joint_transform_list is not None: 396 | for idx, xform in enumerate(self.joint_transform_list): 397 | if idx == 0 and centroid is not None: 398 | # HACK 399 | # We assume that the first transform is capable of taking 400 | # in a centroid 401 | # img, mask = xform(img, mask, centroid) 402 | img, mask, (pos_h, pos_w) = xform(img, mask, centroid, pos=(pos_h, pos_w)) 403 | else: 404 | # img, mask = xform(img, mask) 405 | img, mask, (pos_h, pos_w) = xform(img, mask, pos=(pos_h, pos_w)) 406 | 407 | # Debug 408 | if self.dump_images and centroid is not None: 409 | outdir = './dump_imgs_{}'.format(self.mode) 410 | os.makedirs(outdir, exist_ok=True) 411 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 412 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 413 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 414 | mask_img = colorize_mask(np.array(mask)) 415 | img.save(out_img_fn) 416 | mask_img.save(out_msk_fn) 417 | 418 | if self.transform is not None: 419 | img = self.transform(img) 420 | if self.target_aux_transform is not None: 421 | mask_aux = self.target_aux_transform(mask) 422 | else: 423 | mask_aux = torch.tensor([0]) 424 | if self.target_transform is not None: 425 | mask = self.target_transform(mask) 426 | 427 | pos_h = torch.from_numpy(np.array(pos_h, dtype=np.uint8))# // self.pos_rfactor 428 | pos_w = torch.from_numpy(np.array(pos_w, dtype=np.uint8))# // self.pos_rfactor 429 | 430 | return img, mask, img_name, mask_aux, (pos_h, pos_w) 431 | 432 | def __len__(self): 433 | return len(self.imgs_uniform) 434 | 435 | 436 | --------------------------------------------------------------------------------
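Usage sketch (not part of the repository): the snippet below illustrates how the helpers defined in optimizer.py above are typically wired together, i.e. one SGD optimizer plus poly LambdaLR scheduler for the main network and a second pair for the HANet attention parameters. The DummyHANetModel class is a hypothetical stand-in for the real network, and the Namespace fields are illustrative values modelled on the flags used in the training scripts; the actual wiring lives in the repository's train.py.

from argparse import Namespace

import torch.nn as nn

from optimizer import get_optimizer_attention

class DummyHANetModel(nn.Module):
    """Hypothetical stand-in for the real DeepV3Plus + HANet network."""
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(3, 8, 3)   # would be grouped as a backbone param if backbone_lr > 0
        self.hanet0 = nn.Linear(8, 8)      # name contains 'hanet0', so it is routed to optimizer_at
        self.head = nn.Conv2d(8, 19, 1)    # everything else lands in base_params

# Assumed flag values, mirroring the training scripts in scripts/
args = Namespace(sgd=True, lr=0.01, momentum=0.9, backbone_lr=0.0,
                 hanet=[1, 1, 1, 1, 0], hanet_lr=0.01, hanet_wd=1e-6,
                 lr_schedule='poly', poly_exp=0.9, hanet_poly_exp=0.9,
                 max_iter=1000)

net = DummyHANetModel()
optimizer, scheduler, optimizer_at, scheduler_at = get_optimizer_attention(args, net)

for iteration in range(args.max_iter):
    # forward pass, loss computation and loss.backward() would go here
    optimizer.step()
    optimizer_at.step()
    scheduler.step()      # both LambdaLR schedulers decay per iteration under the poly schedule
    scheduler_at.step()
    optimizer.zero_grad()
    optimizer_at.zero_grad()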