├── utils
│   ├── __init__.py
│   ├── attr_dict.py
│   ├── my_data_parallel.py
│   └── misc.py
├── transforms
│   ├── __init__.py
│   └── transforms.py
├── assets
│   ├── 6529-teaser.gif
│   ├── 6529-architecture.png
│   └── youtube_capture_p.png
├── scripts
│   ├── eval_r101_os8.sh
│   ├── eval_resnext_os4.sh
│   ├── eval_resnext_os8.sh
│   ├── submit_r101_os8.sh
│   ├── submit_cityscapes_resnext.sh
│   ├── train_m_os16_hanet.sh
│   ├── train_r101_os16_hanet.sh
│   ├── train_r101_os8_hanet.sh
│   ├── train_resnext_pretrain.sh
│   ├── train_r101_os8_hanet_best.sh
│   └── train_resnext_fr_pretrain.sh
├── datasets
│   ├── nullloader.py
│   ├── sampler.py
│   ├── mapillary.py
│   ├── kitti.py
│   ├── uniform.py
│   ├── cityscapes_labels.py
│   └── camvid.py
├── network
│   ├── __init__.py
│   ├── mynn.py
│   ├── HANet.py
│   ├── PosEmbedding.py
│   ├── Resnet.py
│   ├── wider_resnet.py
│   └── SEresnext.py
├── config.py
├── README.md
├── loss.py
└── optimizer.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/transforms/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/6529-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shachoi/HANet/HEAD/assets/6529-teaser.gif
--------------------------------------------------------------------------------
/assets/6529-architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shachoi/HANet/HEAD/assets/6529-architecture.png
--------------------------------------------------------------------------------
/assets/youtube_capture_p.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shachoi/HANet/HEAD/assets/youtube_capture_p.png
--------------------------------------------------------------------------------
/scripts/eval_r101_os8.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \
7 | --inference_mode sliding \
8 | --scales 0.5,1.0,2.0 \
9 | --split val \
10 | --cv_split 0 \
11 | --ckpt_path ${2} \
12 | --snapshot ${1} \
13 | --pos_rfactor 8 \
14 | --hanet 1 1 1 1 0 \
15 | --hanet_set 3 64 3 \
16 | --hanet_pos 2 1 \
17 | --dropout 0.1 \
18 | --aux_loss \
19 |
--------------------------------------------------------------------------------
/scripts/eval_resnext_os4.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS4 \
7 | --inference_mode sliding \
8 | --scales 0.5,1.0,2.0 \
9 | --split val \
10 | --cv_split 0 \
11 | --ckpt_path ${2} \
12 | --snapshot ${1} \
13 | --pos_rfactor 8 \
14 | --hanet 1 1 1 1 0 \
15 | --hanet_set 3 64 3 \
16 | --hanet_pos 2 1 \
17 | --dropout 0.1 \
18 | --aux_loss \
19 |
--------------------------------------------------------------------------------
/scripts/eval_resnext_os8.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \
7 | --inference_mode sliding \
8 | --scales 0.5,1.0,2.0 \
9 | --split val \
10 | --cv_split 0 \
11 | --ckpt_path ${2} \
12 | --snapshot ${1} \
13 | --pos_rfactor 8 \
14 | --hanet 1 1 1 1 0 \
15 | --hanet_set 3 64 3 \
16 | --hanet_pos 2 1 \
17 | --dropout 0.1 \
18 | --aux_loss \
19 |
--------------------------------------------------------------------------------
/scripts/submit_r101_os8.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \
7 | --inference_mode sliding \
8 | --scales 0.5,1.0,2.0 \
9 | --split test \
10 | --cv_split 0 \
11 | --ckpt_path ${2} \
12 | --snapshot ${1} \
13 | --pos_rfactor 8 \
14 | --hanet 1 1 1 1 0 \
15 | --hanet_set 3 64 3 \
16 | --hanet_pos 2 1 \
17 | --dropout 0.3 \
18 | --aux_loss \
19 | --dump_images \
20 |
--------------------------------------------------------------------------------
/scripts/submit_cityscapes_resnext.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | python -m torch.distributed.launch --nproc_per_node=1 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \
7 | --inference_mode sliding \
8 | --scales 0.5,1.0,2.0 \
9 | --split test \
10 | --cv_split 0 \
11 | --ckpt_path ${2} \
12 | --snapshot ${1} \
13 | --pos_rfactor 8 \
14 | --hanet 1 1 1 1 0 \
15 | --hanet_set 3 64 3 \
16 | --hanet_pos 2 1 \
17 | --pos_noise 0.0 \
18 | --dropout 0.1 \
19 | --aux_loss \
20 | --dump_images \
21 |
--------------------------------------------------------------------------------
/datasets/nullloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Null Loader
3 | """
4 | import numpy as np
5 | import torch
6 | from torch.utils import data
7 |
8 | num_classes = 19
9 | ignore_label = 255
10 |
11 | class NullLoader(data.Dataset):
12 | """
13 | Null Dataset for Performance
14 | """
15 | def __init__(self,crop_size):
16 | self.imgs = range(200)
17 | self.crop_size = crop_size
18 |
19 | def __getitem__(self, index):
20 | #Return img, mask, name
21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index)
22 |
23 | def __len__(self):
24 | return len(self.imgs)
--------------------------------------------------------------------------------
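Note: a minimal sketch of how NullLoader can be used to benchmark the input pipeline and training loop without touching the disk; the batch size and crop size below are arbitrary assumptions.

from torch.utils.data import DataLoader
from datasets.nullloader import NullLoader

# Synthetic dataset: 200 all-zero image/mask pairs of the requested crop size.
loader = DataLoader(NullLoader(crop_size=768), batch_size=4, shuffle=False)
img, mask, names = next(iter(loader))
print(img.shape, mask.shape, names[0])  # torch.Size([4, 3, 768, 768]) torch.Size([4, 768, 768]) img0
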
/scripts/train_m_os16_hanet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Example on Cityscapes
3 | python -m torch.distributed.launch --nproc_per_node=1 train.py \
4 | --dataset cityscapes \
5 | --arch network.deepv3.DeepMobileNetV3PlusD_HANet \
6 | --city_mode 'train' \
7 | --lr_schedule poly \
8 | --lr 0.01 \
9 | --poly_exp 0.9 \
10 | --hanet_lr 0.01 \
11 | --hanet_poly_exp 0.9 \
12 | --max_cu_epoch 10000 \
13 | --class_uniform_pct 0.5 \
14 | --class_uniform_tile 1024 \
15 | --syncbn \
16 | --sgd \
17 | --crop_size 768 \
18 | --scale_min 0.5 \
19 | --scale_max 2.0 \
20 | --rrotate 0 \
21 | --color_aug 0.25 \
22 | --gblur \
23 | --max_iter 40000 \
24 | --bs_mult 8 \
25 | --hanet 1 1 1 1 0 \
26 | --hanet_set 3 16 3 \
27 | --hanet_pos 2 1 \
28 | --pos_rfactor 8 \
29 | --dropout 0.1 \
30 | --pos_noise 0.3 \
31 | --aux_loss \
32 | --date 0101 \
33 | --exp m_os16_hanet \
34 | --ckpt ./logs/ \
35 | --tb_path ./logs/
36 |
--------------------------------------------------------------------------------
/scripts/train_r101_os16_hanet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Example on Cityscapes
3 | python -m torch.distributed.launch --nproc_per_node=2 train.py \
4 | --dataset cityscapes \
5 | --arch network.deepv3.DeepR101V3PlusD_HANet \
6 | --city_mode 'train' \
7 | --lr_schedule poly \
8 | --lr 0.01 \
9 | --poly_exp 0.9 \
10 | --hanet_lr 0.01 \
11 | --hanet_poly_exp 0.9 \
12 | --max_cu_epoch 10000 \
13 | --class_uniform_pct 0.5 \
14 | --class_uniform_tile 1024 \
15 | --syncbn \
16 | --sgd \
17 | --crop_size 768 \
18 | --scale_min 0.5 \
19 | --scale_max 2.0 \
20 | --rrotate 0 \
21 | --color_aug 0.25 \
22 | --gblur \
23 | --max_iter 40000 \
24 | --bs_mult 4 \
25 | --hanet 1 1 1 1 0 \
26 | --hanet_set 3 32 3 \
27 | --hanet_pos 2 1 \
28 | --pos_rfactor 8 \
29 | --dropout 0.1 \
30 | --pos_noise 0.5 \
31 | --aux_loss \
32 | --date 0101 \
33 | --exp r101_os16_hanet \
34 | --ckpt ./logs/ \
35 | --tb_path ./logs/
36 |
--------------------------------------------------------------------------------
/scripts/train_r101_os8_hanet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Example on Cityscapes
3 | python -m torch.distributed.launch --nproc_per_node=2 train.py \
4 | --dataset cityscapes \
5 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \
6 | --city_mode 'train' \
7 | --lr_schedule poly \
8 | --lr 0.01 \
9 | --poly_exp 0.9 \
10 | --hanet_lr 0.01 \
11 | --hanet_poly_exp 0.9 \
12 | --max_cu_epoch 10000 \
13 | --class_uniform_pct 0.5 \
14 | --class_uniform_tile 1024 \
15 | --syncbn \
16 | --sgd \
17 | --crop_size 768 \
18 | --scale_min 0.5 \
19 | --scale_max 2.0 \
20 | --rrotate 0 \
21 | --color_aug 0.25 \
22 | --gblur \
23 | --max_iter 40000 \
24 | --bs_mult 4 \
25 | --hanet 1 1 1 1 0 \
26 | --hanet_set 3 64 3 \
27 | --hanet_pos 2 1 \
28 | --pos_rfactor 8 \
29 | --dropout 0.1 \
30 | --pos_noise 0.5 \
31 | --aux_loss \
32 | --date 0104 \
33 | --exp r101_os8_hanet_64_01 \
34 | --ckpt ./logs/ \
35 | --tb_path ./logs/
36 |
--------------------------------------------------------------------------------
/scripts/train_resnext_pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Mapillary pretraining example
3 | python -m torch.distributed.launch --nproc_per_node=6 train.py \
4 | --dataset mapillary \
5 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \
6 | --city_mode 'train' \
7 | --lr_schedule poly \
8 | --lr 0.01 \
9 | --poly_exp 0.9 \
10 | --hanet_lr 0.01 \
11 | --hanet_poly_exp 0.9 \
12 | --max_cu_epoch 10000 \
13 | --class_uniform_pct 0.5 \
14 | --class_uniform_tile 1024 \
15 | --syncbn \
16 | --sgd \
17 | --crop_size 864 \
18 | --scale_min 0.5 \
19 | --scale_max 2.0 \
20 | --rrotate 0 \
21 | --color_aug 0.25 \
22 | --gblur \
23 | --max_iter 200000 \
24 | --bs_mult 2 \
25 | --hanet 0 0 0 0 0 \
26 | --hanet_set 3 64 3 \
27 | --hanet_pos 2 1 \
28 | --no_pos_dataset \
29 | --dropout 0.1 \
30 | --pos_noise 0.5 \
31 | --img_wt_loss \
32 | --date 0101 \
33 | --exp resnext_pretrain \
34 | --ckpt ./logs/ \
35 | --tb_path ./logs/
36 |
--------------------------------------------------------------------------------
/scripts/train_r101_os8_hanet_best.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Example on Cityscapes
3 | python -m torch.distributed.launch --nproc_per_node=4 train.py \
4 | --dataset cityscapes \
5 | --arch network.deepv3.DeepR101V3PlusD_HANet_OS8 \
6 | --city_mode 'trainval' \
7 | --lr_schedule poly \
8 | --lr 0.01 \
9 | --poly_exp 0.9 \
10 | --hanet_lr 0.01 \
11 | --hanet_poly_exp 0.9 \
12 | --max_cu_epoch 10000 \
13 | --class_uniform_pct 0.5 \
14 | --class_uniform_tile 1024 \
15 | --syncbn \
16 | --sgd \
17 | --crop_size 864 \
18 | --scale_min 0.5 \
19 | --scale_max 2.0 \
20 | --rrotate 0 \
21 | --color_aug 0.25 \
22 | --gblur \
23 | --max_iter 90000 \
24 | --bs_mult 3 \
25 | --hanet 1 1 1 1 0 \
26 | --hanet_set 3 64 3 \
27 | --hanet_pos 2 1 \
28 | --pos_rfactor 8 \
29 | --dropout 0.1 \
30 | --pos_noise 0.5 \
31 | --aux_loss \
32 | --cls_wt_loss \
33 | --date 0101 \
34 | --exp r101_os8_hanet_best \
35 | --ckpt ./logs/ \
36 | --tb_path ./logs/
37 |
--------------------------------------------------------------------------------
/scripts/train_resnext_fr_pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Example on Cityscapes
3 | python -m torch.distributed.launch --nproc_per_node=6 train.py \
4 | --dataset cityscapes \
5 | --arch network.deepv3.DeepResNext101V3PlusD_HANet_OS8 \
6 | --snapshot pretrained/resnext_mapillary_0.47475.pth \
7 | --city_mode 'trainval' \
8 | --lr_schedule poly \
9 | --lr 0.01 \
10 | --poly_exp 0.9 \
11 | --hanet_lr 0.01 \
12 | --hanet_poly_exp 0.9 \
13 | --max_cu_epoch 10000 \
14 | --class_uniform_pct 0.5 \
15 | --class_uniform_tile 1024 \
16 | --coarse_boost_classes 14,15,16,3,12,17,4 \
17 | --syncbn \
18 | --sgd \
19 | --crop_size 864 \
20 | --scale_min 0.5 \
21 | --scale_max 2.0 \
22 | --rrotate 0 \
23 | --color_aug 0.25 \
24 | --gblur \
25 | --max_iter 90000 \
26 | --bs_mult 2 \
27 | --hanet 1 1 1 1 0 \
28 | --hanet_set 3 64 3 \
29 | --hanet_pos 2 1 \
30 | --pos_rfactor 8 \
31 | --dropout 0.1 \
32 | --pos_noise 0.5 \
33 | --aux_loss \
34 | --cls_wt_loss \
35 | --date 0101 \
36 | --exp resnext_from_pretrain \
37 | --ckpt ./logs/ \
38 | --tb_path ./logs/
39 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Network Initializations
3 | """
4 |
5 | import logging
6 | import importlib
7 | import torch
8 |
9 |
10 |
11 | def get_net(args, criterion, criterion_aux=None):
12 | """
13 | Get Network Architecture based on arguments provided
14 | """
15 | net = get_model(args=args, num_classes=args.dataset_cls.num_classes,
16 | criterion=criterion, criterion_aux=criterion_aux)
17 | num_params = sum([param.nelement() for param in net.parameters()])
18 | logging.info('Model params = {:2.3f}M'.format(num_params / 1000000))
19 |
20 | net = net.cuda()
21 | return net
22 |
23 |
24 | def warp_network_in_dataparallel(net, gpuid):
25 | """
26 | Wrap the network in Dataparallel
27 | """
28 | # torch.cuda.set_device(gpuid)
29 | # net.cuda(gpuid)
30 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid], find_unused_parameters=True)
31 | # net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid])#, find_unused_parameters=True)
32 | return net
33 |
34 |
35 | def get_model(args, num_classes, criterion, criterion_aux=None):
36 | """
37 | Fetch Network Function Pointer
38 | """
39 | network = args.arch
40 | module = network[:network.rfind('.')]
41 | model = network[network.rfind('.') + 1:]
42 | mod = importlib.import_module(module)
43 | net_func = getattr(mod, model)
44 | net = net_func(args=args, num_classes=num_classes, criterion=criterion, criterion_aux=criterion_aux)
45 | return net
46 |
--------------------------------------------------------------------------------
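Note: a small sketch of how get_model resolves an --arch string into a constructor via importlib; the arch string below is one used by the provided scripts, and it assumes the repository's network.deepv3 module is importable. Actually instantiating the model additionally requires args/criterion, as done in get_net.

import importlib

arch = 'network.deepv3.DeepR101V3PlusD_HANet_OS8'   # as passed via --arch in the scripts
module_name = arch[:arch.rfind('.')]                # 'network.deepv3'
model_name = arch[arch.rfind('.') + 1:]             # 'DeepR101V3PlusD_HANet_OS8'
net_func = getattr(importlib.import_module(module_name), model_name)
print(net_func.__name__)
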
/utils/attr_dict.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 |
30 | class AttrDict(dict):
31 |
32 | IMMUTABLE = '__immutable__'
33 |
34 | def __init__(self, *args, **kwargs):
35 | super(AttrDict, self).__init__(*args, **kwargs)
36 | self.__dict__[AttrDict.IMMUTABLE] = False
37 |
38 | def __getattr__(self, name):
39 | if name in self.__dict__:
40 | return self.__dict__[name]
41 | elif name in self:
42 | return self[name]
43 | else:
44 | raise AttributeError(name)
45 |
46 | def __setattr__(self, name, value):
47 | if not self.__dict__[AttrDict.IMMUTABLE]:
48 | if name in self.__dict__:
49 | self.__dict__[name] = value
50 | else:
51 | self[name] = value
52 | else:
53 | raise AttributeError(
54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
55 | format(name, value)
56 | )
57 |
58 | def immutable(self, is_immutable):
59 | """Set immutability to is_immutable and recursively apply the setting
60 | to all nested AttrDicts.
61 | """
62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
63 | # Recursively set immutable state
64 | for v in self.__dict__.values():
65 | if isinstance(v, AttrDict):
66 | v.immutable(is_immutable)
67 | for v in self.values():
68 | if isinstance(v, AttrDict):
69 | v.immutable(is_immutable)
70 |
71 | def is_immutable(self):
72 | return self.__dict__[AttrDict.IMMUTABLE]
73 |
--------------------------------------------------------------------------------
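Note: a minimal sketch of AttrDict behaviour (hypothetical keys, illustration only): entries can be read and written as attributes, and immutable(True) freezes the dict recursively.

from utils.attr_dict import AttrDict

c = AttrDict()
c.LR = 0.01                      # attribute-style assignment stores a regular dict entry
c.MODEL = AttrDict()
c.MODEL.ARCH = 'deepv3'
print(c['LR'], c.MODEL.ARCH)     # 0.01 deepv3

c.immutable(True)                # recursively freezes nested AttrDicts as well
try:
    c.LR = 0.1
except AttributeError as err:
    print(err)                   # Attempted to set "LR" to "0.1", but AttrDict is immutable
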
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 | ##############################################################################
30 | #Config
31 | ##############################################################################
32 |
33 |
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 | from __future__ import unicode_literals
38 |
39 |
40 | import torch
41 |
42 |
43 | from utils.attr_dict import AttrDict
44 |
45 |
46 | __C = AttrDict()
47 | cfg = __C
48 | __C.ITER = 0
49 | __C.EPOCH = 0
50 | # Use Class Uniform Sampling to give each class proper sampling
51 | __C.CLASS_UNIFORM_PCT = 0.0
52 |
53 | # Use class-weighted loss per batch to increase the loss for classes with low pixel counts
54 | __C.BATCH_WEIGHTING = False
55 |
56 | # Border Relaxation Count
57 | __C.BORDER_WINDOW = 1
58 | # Number of epochs to use before turning off border restriction
59 | __C.REDUCE_BORDER_ITER = -1
60 | __C.REDUCE_BORDER_EPOCH = -1
61 | # Comma-separated list of class ids to relax
62 | __C.STRICTBORDERCLASS = None
63 |
64 |
65 |
66 | #Attribute Dictionary for Dataset
67 | __C.DATASET = AttrDict()
68 | #Cityscapes Dir Location
69 | __C.DATASET.CITYSCAPES_DIR = '/home/nas_datasets/segmentation/cityscapes'
70 | #SDC Augmented Cityscapes Dir Location
71 | __C.DATASET.CITYSCAPES_AUG_DIR = ''
72 | #Mapillary Dataset Dir Location
73 | __C.DATASET.MAPILLARY_DIR = '/home/nas_datasets/segmentation/mapillary'
74 | #GTAV, BDD100K Dataset Dir Location
75 | __C.DATASET.GTAV_DIR = '/home/nas_datasets/segmentation/gtav'
76 | __C.DATASET.BDD_DIR = '/home/nas_datasets/segmentation/bdd100k/seg'
77 | #Kitti Dataset Dir Location
78 | __C.DATASET.KITTI_DIR = ''
79 | #SDC Augmented Kitti Dataset Dir Location
80 | __C.DATASET.KITTI_AUG_DIR = ''
81 | #Camvid Dataset Dir Location
82 | __C.DATASET.CAMVID_DIR = '/home/nas_datasets/segmentation/SegNet-Tutorial/CamVid'
83 | #Number of splits to support
84 | __C.DATASET.CV_SPLITS = 3
85 |
86 |
87 | __C.MODEL = AttrDict()
88 | __C.MODEL.BN = 'pytorch-syncnorm'
89 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm
90 |
91 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True):
92 | """Call this function in your script after you have finished setting all cfg
93 | values that are necessary (e.g., merging a config from a file, merging
94 | command line config options, etc.). By default, this function will also
95 | mark the global cfg as immutable to prevent changing the global cfg settings
96 | during script execution (which can lead to hard to debug errors or code
97 | that's harder to understand than is necessary).
98 | """
99 |
100 | if hasattr(args, 'syncbn') and args.syncbn:
101 | __C.MODEL.BN = 'pytorch-syncnorm'
102 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm
103 | print('Using pytorch sync batch norm')
104 | else:
105 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d
106 | print('Using regular batch norm')
107 |
108 | if not train_mode:
109 | cfg.immutable(True)
110 | return
111 | if args.class_uniform_pct:
112 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct
113 |
114 | if args.batch_weighting:
115 | __C.BATCH_WEIGHTING = True
116 |
117 | if args.jointwtborder:
118 | if args.strict_bdr_cls != '':
119 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")]
120 | if args.rlx_off_iter > -1:
121 | __C.REDUCE_BORDER_ITER = args.rlx_off_iter
122 |
123 | if make_immutable:
124 | cfg.immutable(True)
125 |
--------------------------------------------------------------------------------
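Note: a minimal sketch of driving assert_and_infer_cfg by hand; train.py normally supplies these flags through argparse, and the attribute names below simply mirror the ones this function reads.

from types import SimpleNamespace
from config import cfg, assert_and_infer_cfg

args = SimpleNamespace(syncbn=False, class_uniform_pct=0.5, batch_weighting=False,
                       jointwtborder=False, strict_bdr_cls='', rlx_off_iter=-1)
assert_and_infer_cfg(args, make_immutable=False, train_mode=True)
print(cfg.CLASS_UNIFORM_PCT)   # 0.5
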
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 |
37 |
38 | import math
39 | import torch
40 | from torch.distributed import get_world_size, get_rank
41 | from torch.utils.data import Sampler
42 |
43 | class DistributedSampler(Sampler):
44 | """Sampler that restricts data loading to a subset of the dataset.
45 |
46 | It is especially useful in conjunction with
47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
48 | process can pass a DistributedSampler instance as a DataLoader sampler,
49 | and load a subset of the original dataset that is exclusive to it.
50 |
51 | .. note::
52 | Dataset is assumed to be of constant size.
53 |
54 | Arguments:
55 | dataset: Dataset used for sampling.
56 | num_replicas (optional): Number of processes participating in
57 | distributed training.
58 | rank (optional): Rank of the current process within num_replicas.
59 | """
60 |
61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None):
62 | if num_replicas is None:
63 | num_replicas = get_world_size()
64 | if rank is None:
65 | rank = get_rank()
66 | self.dataset = dataset
67 | self.num_replicas = num_replicas
68 | self.rank = rank
69 | self.epoch = 0
70 | self.consecutive_sample = consecutive_sample
71 | self.permutation = permutation
72 | if pad:
73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
74 | else:
75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
76 | self.total_size = self.num_samples * self.num_replicas
77 |
78 | def __iter__(self):
79 | # deterministically shuffle based on epoch
80 | g = torch.Generator()
81 | g.manual_seed(self.epoch)
82 |
83 | if self.permutation:
84 | indices = list(torch.randperm(len(self.dataset), generator=g))
85 | else:
86 | indices = list([x for x in range(len(self.dataset))])
87 |
88 | # add extra samples to make it evenly divisible
89 | if self.total_size > len(indices):
90 | indices += indices[:(self.total_size - len(indices))]
91 |
92 | # subsample
93 | if self.consecutive_sample:
94 | offset = self.num_samples * self.rank
95 | indices = indices[offset:offset + self.num_samples]
96 | else:
97 | indices = indices[self.rank:self.total_size:self.num_replicas]
98 | assert len(indices) == self.num_samples
99 |
100 | return iter(indices)
101 |
102 | def __len__(self):
103 | return self.num_samples
104 |
105 | def set_epoch(self, epoch):
106 | self.epoch = epoch
107 |
108 | def set_num_samples(self):
109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
110 | self.total_size = self.num_samples * self.num_replicas
--------------------------------------------------------------------------------
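Note: a minimal sketch (toy dataset, assumed values) showing how this sampler shards indices per process; passing num_replicas/rank explicitly avoids having to initialize torch.distributed, and set_epoch re-seeds the shuffle each epoch.

import torch
from torch.utils.data import DataLoader, TensorDataset
from datasets.sampler import DistributedSampler

dataset = TensorDataset(torch.arange(10).float())
sampler = DistributedSampler(dataset, pad=True, permutation=True, num_replicas=2, rank=0)

for epoch in range(2):
    sampler.set_epoch(epoch)                     # re-seed the deterministic shuffle
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    values = [int(v) for (batch,) in loader for v in batch]
    print(epoch, values)                         # this rank sees 5 of the 10 samples
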
/network/mynn.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization
3 | """
4 | import torch.nn as nn
5 | import torch
6 | from config import cfg
7 |
8 | def Norm2d(in_channels):
9 | """
10 | Custom Norm Function to allow flexible switching
11 | """
12 | layer = getattr(cfg.MODEL, 'BNFUNC')
13 | normalization_layer = layer(in_channels)
14 | return normalization_layer
15 |
16 |
17 | def freeze_weights(*models):
18 | for model in models:
19 | for k in model.parameters():
20 | k.requires_grad = False
21 |
22 | def unfreeze_weights(*models):
23 | for model in models:
24 | for k in model.parameters():
25 | k.requires_grad = True
26 |
27 | def initialize_weights(*models):
28 | """
29 | Initialize Model Weights
30 | """
31 | for model in models:
32 | for module in model.modules():
33 | if isinstance(module, (nn.Conv2d, nn.Linear)):
34 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
35 | if module.bias is not None:
36 | module.bias.data.zero_()
37 | elif isinstance(module, nn.Conv1d):
38 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
39 | if module.bias is not None:
40 | module.bias.data.zero_()
41 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or \
42 | isinstance(module, nn.GroupNorm) or isinstance(module, nn.SyncBatchNorm):
43 | module.weight.data.fill_(1)
44 | module.bias.data.zero_()
45 |
46 | def initialize_embedding(*models):
47 | """
48 | Initialize Model Weights
49 | """
50 | for model in models:
51 | for module in model.modules():
52 | if isinstance(module, nn.Embedding):
53 | module.weight.data.zero_() #original
54 |
55 |
56 |
57 | def Upsample(x, size):
58 | """
59 | Wrapper Around the Upsample Call
60 | """
61 | return nn.functional.interpolate(x, size=size, mode='bilinear',
62 | align_corners=True)
63 |
64 | def Zero_Masking(input_tensor, mask_org):
65 | output = input_tensor.clone()
66 | output.mul_(mask_org)
67 | return output
68 |
69 | def RandomPosZero_Masking(input_tensor, p=0.5):
70 | output = input_tensor.clone()
71 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3))
72 |     noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), input_tensor.size(3))  # note: allocated but unused in this function
73 | noise_b.bernoulli_(1 - p)
74 | noise_b = noise_b.expand_as(input_tensor)
75 | output.mul_(noise_b)
76 | return output
77 |
78 | def RandomVal_Masking(input_tensor, mask_org):
79 | output = input_tensor.clone()
80 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), input_tensor.size(3))
81 | mask = (mask_org==0).type(input_tensor.type())
82 | mask = mask.expand_as(input_tensor)
83 | mask = torch.mul(mask, noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item()))
84 | mask_org = mask_org.expand_as(input_tensor)
85 | output.mul_(mask_org)
86 | output.add_(mask)
87 | return output
88 |
89 | def RandomPosVal_Masking(input_tensor, p=0.5):
90 | output = input_tensor.clone()
91 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3))
92 | noise_u = input_tensor.new().resize_(input_tensor.size(0), input_tensor.size(1), input_tensor.size(2), input_tensor.size(3))
93 | mask = noise_b.bernoulli_(1 - p)
94 | mask = (mask==0).type(input_tensor.type())
95 | mask = mask.expand_as(input_tensor)
96 | mask = torch.mul(mask, noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item()))
97 | noise_b = noise_b.expand_as(input_tensor)
98 | output.mul_(noise_b)
99 | output.add_(mask)
100 | return output
101 |
102 | def masking(input_tensor, p=0.5):
103 | output = input_tensor.clone()
104 | noise_b = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3))
105 | noise_u = input_tensor.new().resize_(input_tensor.size(0), 1, input_tensor.size(2), input_tensor.size(3))
106 | mask = noise_b.bernoulli_(1 - p)
107 | mask = (mask==0).type(input_tensor.type())
108 | mask.mul_(noise_u.uniform_(torch.min(input_tensor).item(), torch.max(input_tensor).item()))
109 | # mask.mul_(noise_u.uniform_(5, 10))
110 | noise_b = noise_b.expand_as(input_tensor)
111 | mask = mask.expand_as(input_tensor)
112 | output.mul_(noise_b)
113 | output.add_(mask)
114 | return output
115 |
--------------------------------------------------------------------------------
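Note: a quick sketch of the masking helper on a dummy tensor (arbitrary shapes); roughly a fraction p of the spatial positions are replaced by random values drawn from the tensor's own value range.

import torch
from network.mynn import masking

t = torch.randn(2, 4, 8, 8)
out = masking(t, p=0.5)
print(out.shape)                  # torch.Size([2, 4, 8, 8])
print((out != t).float().mean())  # roughly 0.5 of the entries were replaced
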
/network/HANet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from network.mynn import Norm2d, Upsample
6 | from network.PosEmbedding import PosEmbedding1D, PosEncoding1D
7 |
8 |
9 | class HANet_Conv(nn.Module):
10 |
11 | def __init__(self, in_channel, out_channel, kernel_size=3, r_factor=64, layer=3, pos_injection=2, is_encoding=1,
12 | pos_rfactor=8, pooling='mean', dropout_prob=0.0, pos_noise=0.0):
13 | super(HANet_Conv, self).__init__()
14 |
15 | self.pooling = pooling
16 | self.pos_injection = pos_injection
17 | self.layer = layer
18 | self.dropout_prob = dropout_prob
19 | self.sigmoid = nn.Sigmoid()
20 |
21 | if r_factor > 0:
22 | mid_1_channel = math.ceil(in_channel / r_factor)
23 | elif r_factor < 0:
24 | r_factor = r_factor * -1
25 | mid_1_channel = in_channel * r_factor
26 |
27 | if self.dropout_prob > 0:
28 | self.dropout = nn.Dropout2d(self.dropout_prob)
29 |
30 | self.attention_first = nn.Sequential(
31 | nn.Conv1d(in_channels=in_channel, out_channels=mid_1_channel,
32 | kernel_size=1, stride=1, padding=0, bias=False),
33 | Norm2d(mid_1_channel),
34 | nn.ReLU(inplace=True))
35 |
36 | if layer == 2:
37 | self.attention_second = nn.Sequential(
38 | nn.Conv1d(in_channels=mid_1_channel, out_channels=out_channel,
39 | kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True))
40 | elif layer == 3:
41 | mid_2_channel = (mid_1_channel * 2)
42 | self.attention_second = nn.Sequential(
43 | nn.Conv1d(in_channels=mid_1_channel, out_channels=mid_2_channel,
44 | kernel_size=3, stride=1, padding=1, bias=True),
45 | Norm2d(mid_2_channel),
46 | nn.ReLU(inplace=True))
47 | self.attention_third = nn.Sequential(
48 | nn.Conv1d(in_channels=mid_2_channel, out_channels=out_channel,
49 | kernel_size=kernel_size, stride=1, padding=kernel_size//2, bias=True))
50 |
51 | if self.pooling == 'mean':
52 | #print("##### average pooling")
53 | self.rowpool = nn.AdaptiveAvgPool2d((128//pos_rfactor,1))
54 | else:
55 | #print("##### max pooling")
56 | self.rowpool = nn.AdaptiveMaxPool2d((128//pos_rfactor,1))
57 |
58 | if pos_rfactor > 0:
59 | if is_encoding == 0:
60 | if self.pos_injection == 1:
61 | self.pos_emb1d_1st = PosEmbedding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise)
62 | elif self.pos_injection == 2:
63 | self.pos_emb1d_2nd = PosEmbedding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise)
64 | elif is_encoding == 1:
65 | if self.pos_injection == 1:
66 | self.pos_emb1d_1st = PosEncoding1D(pos_rfactor, dim=in_channel, pos_noise=pos_noise)
67 | elif self.pos_injection == 2:
68 | self.pos_emb1d_2nd = PosEncoding1D(pos_rfactor, dim=mid_1_channel, pos_noise=pos_noise)
69 | else:
70 | print("Not supported position encoding")
71 | exit()
72 |
73 |
74 | def forward(self, x, out, pos=None, return_attention=False, return_posmap=False, attention_loss=False):
75 | """
76 | inputs :
77 | x : input feature maps( B X C X W X H)
78 | returns :
79 | out : self attention value + input feature
80 | attention: B X N X N (N is Width*Height)
81 | """
82 | H = out.size(2)
83 | x1d = self.rowpool(x).squeeze(3)
84 |
85 | if pos is not None and self.pos_injection == 1:
86 | if return_posmap:
87 | x1d, pos_map1 = self.pos_emb1d_1st(x1d, pos, True)
88 | else:
89 | x1d = self.pos_emb1d_1st(x1d, pos)
90 |
91 | if self.dropout_prob > 0:
92 | x1d = self.dropout(x1d)
93 | x1d = self.attention_first(x1d)
94 |
95 | if pos is not None and self.pos_injection == 2:
96 | if return_posmap:
97 | x1d, pos_map2 = self.pos_emb1d_2nd(x1d, pos, True)
98 | else:
99 | x1d = self.pos_emb1d_2nd(x1d, pos)
100 |
101 | x1d = self.attention_second(x1d)
102 |
103 | if self.layer == 3:
104 | x1d = self.attention_third(x1d)
105 | if attention_loss:
106 | last_attention = x1d
107 | x1d = self.sigmoid(x1d)
108 | else:
109 | if attention_loss:
110 | last_attention = x1d
111 | x1d = self.sigmoid(x1d)
112 |
113 | x1d = F.interpolate(x1d, size=H, mode='linear')
114 | out = torch.mul(out, x1d.unsqueeze(3))
115 |
116 | if return_attention:
117 | if return_posmap:
118 | if self.pos_injection == 1:
119 | pos_map = (pos_map1)
120 | elif self.pos_injection == 2:
121 | pos_map = (pos_map2)
122 | return out, x1d, pos_map
123 | else:
124 | return out, x1d
125 | else:
126 | if attention_loss:
127 | return out, last_attention
128 | else:
129 | return out
130 |
--------------------------------------------------------------------------------
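Note: a minimal shape-check sketch for HANet_Conv with assumed channel sizes; it requires a CUDA device because the positional modules call .cuda() at construction time, and .eval() keeps the default SyncBatchNorm from needing an initialized process group. Passing pos=None skips the positional injection.

import torch
from network.HANet import HANet_Conv

hanet = HANet_Conv(in_channel=2048, out_channel=256, pos_rfactor=8).cuda().eval()
x = torch.randn(2, 2048, 12, 24).cuda()     # feature map the attention is computed from
out = torch.randn(2, 256, 96, 192).cuda()   # feature map to be re-weighted per height
with torch.no_grad():
    out = hanet(x, out, pos=None)           # pos=None -> no positional injection
print(out.shape)                            # torch.Size([2, 256, 96, 192])
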
/network/PosEmbedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from network.mynn import Norm2d, Upsample, initialize_embedding
5 | import numpy as np
6 |
7 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
8 | ''' Sinusoid position encoding table '''
9 | def cal_angle(position, hid_idx):
10 | if d_hid > 50:
11 | cycle = 10
12 | elif d_hid > 5:
13 | cycle = 100
14 | else:
15 | cycle = 10000
16 |         cycle = 10 if d_hid > 50 else 100  # note: overrides the branch above, so cycle is always 10 or 100
17 | return position / np.power(cycle, 2 * (hid_idx // 2) / d_hid)
18 | def get_posi_angle_vec(position):
19 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
20 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
21 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
22 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
23 | if padding_idx is not None:
24 | # zero vector for padding dimension
25 | sinusoid_table[padding_idx] = 0.
26 | return torch.FloatTensor(sinusoid_table)
27 |
28 | class PosEmbedding2D(nn.Module):
29 |
30 | def __init__(self, pos_rfactor, dim):
31 | super(PosEmbedding2D, self).__init__()
32 |
33 | self.pos_layer_h = nn.Embedding((128//pos_rfactor)+1, dim)
34 | self.pos_layer_w = nn.Embedding((128//pos_rfactor)+1, dim)
35 | initialize_embedding(self.pos_layer_h)
36 | initialize_embedding(self.pos_layer_w)
37 |
38 | def forward(self, x, pos):
39 | pos_h, pos_w = pos
40 | pos_h = pos_h.unsqueeze(1)
41 | pos_w = pos_w.unsqueeze(1)
42 | pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2:], mode='nearest').long() # B X 1 X H X W
43 | pos_w = nn.functional.interpolate(pos_w.float(), size=x.shape[2:], mode='nearest').long() # B X 1 X H X W
44 | pos_h = self.pos_layer_h(pos_h).transpose(1,4).squeeze(4) # B X 1 X H X W X C
45 | pos_w = self.pos_layer_w(pos_w).transpose(1,4).squeeze(4) # B X 1 X H X W X C
46 | x = x + pos_h + pos_w
47 | return x
48 |
49 | class PosEncoding1D(nn.Module):
50 |
51 | def __init__(self, pos_rfactor, dim, pos_noise=0.0):
52 | super(PosEncoding1D, self).__init__()
53 | print("use PosEncoding1D")
54 | self.sel_index = torch.tensor([0]).cuda()
55 | pos_enc = (get_sinusoid_encoding_table((128//pos_rfactor)+1, dim) + 1)
56 | self.pos_layer = nn.Embedding.from_pretrained(embeddings=pos_enc, freeze=True)
57 | self.pos_noise = pos_noise
58 | self.noise_clamp = 16 // pos_rfactor # 4: 4, 8: 2, 16: 1
59 |
60 | self.pos_rfactor = pos_rfactor
61 | if pos_noise > 0.0:
62 | self.min = 0.0 #torch.tensor([0]).cuda()
63 | self.max = 128//pos_rfactor #torch.tensor([128//pos_rfactor]).cuda()
64 | self.noise = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([pos_noise]))
65 |
66 | def forward(self, x, pos, return_posmap=False):
67 | pos_h, _ = pos # B X H X W
68 | pos_h = pos_h//self.pos_rfactor
69 | pos_h = pos_h.index_select(2, self.sel_index).unsqueeze(1).squeeze(3) # B X 1 X H
70 | pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2], mode='nearest').long() # B X 1 X 48
71 |
72 | if self.training is True and self.pos_noise > 0.0:
73 | #pos_h = pos_h + (self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long()
74 | pos_h = pos_h + torch.clamp((self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long(),
75 | min=-self.noise_clamp, max=self.noise_clamp)
76 | pos_h = torch.clamp(pos_h, min=self.min, max=self.max)
77 | #pos_h = torch.where(pos_h < self.min_tensor, self.min_tensor, pos_h)
78 | #pos_h = torch.where(pos_h > self.max_tensor, self.max_tensor, pos_h)
79 |
80 | pos_h = self.pos_layer(pos_h).transpose(1,3).squeeze(3) # B X 1 X 48 X 80 > B X 80 X 48 X 1
81 | x = x + pos_h
82 | if return_posmap:
83 | return x, self.pos_layer.weight # 33 X 80
84 | return x
85 |
86 | class PosEmbedding1D(nn.Module):
87 |
88 | def __init__(self, pos_rfactor, dim, pos_noise=0.0):
89 | super(PosEmbedding1D, self).__init__()
90 | print("use PosEmbedding1D")
91 | self.sel_index = torch.tensor([0]).cuda()
92 | self.pos_layer = nn.Embedding((128//pos_rfactor)+1, dim)
93 | initialize_embedding(self.pos_layer)
94 | self.pos_noise = pos_noise
95 | self.pos_rfactor = pos_rfactor
96 | self.noise_clamp = 16 // pos_rfactor # 4: 4, 8: 2, 16: 1
97 |
98 | if pos_noise > 0.0:
99 | self.min = 0.0 #torch.tensor([0]).cuda()
100 | self.max = 128//pos_rfactor #torch.tensor([128//pos_rfactor]).cuda()
101 | self.noise = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([pos_noise]))
102 |
103 | def forward(self, x, pos, return_posmap=False):
104 | pos_h, _ = pos # B X H X W
105 | pos_h = pos_h//self.pos_rfactor
106 | pos_h = pos_h.index_select(2, self.sel_index).unsqueeze(1).squeeze(3) # B X 1 X H
107 | pos_h = nn.functional.interpolate(pos_h.float(), size=x.shape[2], mode='nearest').long() # B X 1 X 48
108 |
109 | if self.training is True and self.pos_noise > 0.0:
110 | #pos_h = pos_h + (self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long()
111 | pos_h = pos_h + torch.clamp((self.noise.sample(pos_h.shape).squeeze(3).cuda()//1).long(),
112 | min=-self.noise_clamp, max=self.noise_clamp)
113 | pos_h = torch.clamp(pos_h, min=self.min, max=self.max)
114 |
115 | pos_h = self.pos_layer(pos_h).transpose(1,3).squeeze(3) # B X 1 X 48 X 80 > B X 80 X 48 X 1
116 | x = x + pos_h
117 | if return_posmap:
118 | return x, self.pos_layer.weight # 33 X 80
119 | return x
120 |
--------------------------------------------------------------------------------
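Note: a quick sanity check of the sinusoid table behind PosEncoding1D, with sizes assumed for pos_rfactor=8 (17 vertical positions) and an 80-dimensional channel axis as in the comments above; values lie in [-1, 1] before the +1 shift applied in PosEncoding1D.

import torch
from network.PosEmbedding import get_sinusoid_encoding_table

table = get_sinusoid_encoding_table(n_position=(128 // 8) + 1, d_hid=80)
print(table.shape)                             # torch.Size([17, 80])
print(float(table.min()), float(table.max()))  # both within [-1, 1]
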
/datasets/mapillary.py:
--------------------------------------------------------------------------------
1 | """
2 | Mapillary Dataset Loader
3 | """
4 | from PIL import Image
5 | from torch.utils import data
6 | import os
7 | import numpy as np
8 | import json
9 | import datasets.uniform as uniform
10 | from config import cfg
11 |
12 | num_classes = 65
13 | ignore_label = 65
14 | root = cfg.DATASET.MAPILLARY_DIR
15 | config_fn = os.path.join(root, 'config.json')
16 | id_to_ignore_or_group = {}
17 | color_mapping = []
18 | id_to_trainid = {}
19 |
20 |
21 | def colorize_mask(image_array):
22 | """
23 | Colorize a segmentation mask
24 | """
25 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P')
26 | new_mask.putpalette(color_mapping)
27 | return new_mask
28 |
29 |
30 | def make_dataset(quality, mode):
31 | """
32 | Create File List
33 | """
34 | assert (quality == 'semantic' and mode in ['train', 'val'])
35 | img_dir_name = None
36 | if quality == 'semantic':
37 | if mode == 'train':
38 | img_dir_name = 'training'
39 | if mode == 'val':
40 | img_dir_name = 'validation'
41 | mask_path = os.path.join(root, img_dir_name, 'labels')
42 | else:
43 |         raise BaseException("Instance segmentation not supported")
44 |
45 | img_path = os.path.join(root, img_dir_name, 'images')
46 | print(img_path)
47 | if quality != 'video':
48 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)])
49 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)])
50 | assert imgs == msks
51 |
52 | items = []
53 | c_items = os.listdir(img_path)
54 | if '.DS_Store' in c_items:
55 | c_items.remove('.DS_Store')
56 |
57 | for it in c_items:
58 | if quality == 'video':
59 | item = (os.path.join(img_path, it), os.path.join(img_path, it))
60 | else:
61 | item = (os.path.join(img_path, it),
62 | os.path.join(mask_path, it.replace(".jpg", ".png")))
63 | items.append(item)
64 | return items
65 |
66 |
67 | def gen_colormap():
68 | """
69 | Get Color Map from file
70 | """
71 | global color_mapping
72 |
73 | # load mapillary config
74 | with open(config_fn) as config_file:
75 | config = json.load(config_file)
76 | config_labels = config['labels']
77 |
78 | # calculate label color mapping
79 | colormap = []
80 | id2name = {}
81 | for i in range(0, len(config_labels)):
82 | colormap = colormap + config_labels[i]['color']
83 | id2name[i] = config_labels[i]['readable']
84 | color_mapping = colormap
85 | return id2name
86 |
87 |
88 | class Mapillary(data.Dataset):
89 | def __init__(self, quality, mode, joint_transform_list=None,
90 | transform=None, target_transform=None, dump_images=False,
91 | class_uniform_pct=0, class_uniform_tile=768, test=False):
92 | """
93 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform.
94 | 0.0 means fully random.
95 | class_uniform_tile_size = Class uniform tile size
96 | """
97 | self.quality = quality
98 | self.mode = mode
99 | self.joint_transform_list = joint_transform_list
100 | self.transform = transform
101 | self.target_transform = target_transform
102 | self.dump_images = dump_images
103 | self.class_uniform_pct = class_uniform_pct
104 | self.class_uniform_tile = class_uniform_tile
105 | self.id2name = gen_colormap()
106 | self.imgs_uniform = None
107 | for i in range(num_classes):
108 | id_to_trainid[i] = i
109 |
110 | # find all images
111 | self.imgs = make_dataset(quality, mode)
112 | if len(self.imgs) == 0:
113 | raise RuntimeError('Found 0 images, please check the data set')
114 | if test:
115 | np.random.shuffle(self.imgs)
116 | self.imgs = self.imgs[:200]
117 |
118 | if self.class_uniform_pct:
119 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile)
120 | if os.path.isfile(json_fn):
121 | with open(json_fn, 'r') as json_data:
122 | centroids = json.load(json_data)
123 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
124 | else:
125 | # centroids is a dict (indexed by class) of lists of centroids
126 | self.centroids = uniform.class_centroids_all(
127 | self.imgs,
128 | num_classes,
129 | id2trainid=None,
130 | tile_size=self.class_uniform_tile)
131 | with open(json_fn, 'w') as outfile:
132 | json.dump(self.centroids, outfile, indent=4)
133 | else:
134 | self.centroids = []
135 | self.build_epoch()
136 |
137 | def build_epoch(self):
138 | if self.class_uniform_pct != 0:
139 | self.imgs_uniform = uniform.build_epoch(self.imgs,
140 | self.centroids,
141 | num_classes,
142 | self.class_uniform_pct)
143 | else:
144 | self.imgs_uniform = self.imgs
145 |
146 | def __getitem__(self, index):
147 | if len(self.imgs_uniform[index]) == 2:
148 | img_path, mask_path = self.imgs_uniform[index]
149 | centroid = None
150 | class_id = None
151 | else:
152 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index]
153 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
154 | img_name = os.path.splitext(os.path.basename(img_path))[0]
155 |
156 | mask = np.array(mask)
157 | mask_copy = mask.copy()
158 | for k, v in id_to_ignore_or_group.items():
159 | mask_copy[mask == k] = v
160 | mask = Image.fromarray(mask_copy.astype(np.uint8))
161 |
162 | # Image Transformations
163 | if self.joint_transform_list is not None:
164 | for idx, xform in enumerate(self.joint_transform_list):
165 | if idx == 0 and centroid is not None:
166 | # HACK! Assume the first transform accepts a centroid
167 | img, mask = xform(img, mask, centroid)
168 | else:
169 | img, mask = xform(img, mask)
170 |
171 | if self.dump_images:
172 | outdir = 'dump_imgs_{}'.format(self.mode)
173 | os.makedirs(outdir, exist_ok=True)
174 | if centroid is not None:
175 | dump_img_name = self.id2name[class_id] + '_' + img_name
176 | else:
177 | dump_img_name = img_name
178 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
179 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
180 | mask_img = colorize_mask(np.array(mask))
181 | img.save(out_img_fn)
182 | mask_img.save(out_msk_fn)
183 |
184 | if self.transform is not None:
185 | img = self.transform(img)
186 | if self.target_transform is not None:
187 | mask = self.target_transform(mask)
188 | return img, mask, img_name
189 |
190 | def __len__(self):
191 | return len(self.imgs_uniform)
192 |
193 | def calculate_weights(self):
194 | raise BaseException("not supported yet")
195 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## HANet: Official Project Webpage
2 | HANet is an add-on module for urban-scene segmentation that exploits the structural priors of urban scenes. It is effective and widely applicable!
3 |
4 |
5 |
6 |
7 |
8 | This repository provides the official PyTorch implementation of the following paper:
9 | > **Cars Can’t Fly up in the Sky:** Improving Urban-Scene Segmentation via Height-driven Attention Networks
10 | > Sungha Choi (LGE, Korea Univ.), Joanne T. Kim (LLNL, Korea Univ.), Jaegul Choo (KAIST)
11 | > In CVPR 2020
12 |
13 | > Paper : [[pdf]](http://openaccess.thecvf.com/content_CVPR_2020/papers/Choi_Cars_Cant_Fly_Up_in_the_Sky_Improving_Urban-Scene_Segmentation_CVPR_2020_paper.pdf) [[supp]](http://openaccess.thecvf.com/content_CVPR_2020/supplemental/Choi_Cars_Cant_Fly_CVPR_2020_supplemental.pdf)
14 |
15 |
16 |
17 |
18 |
19 | > **Abstract:** *This paper exploits the intrinsic features of urban-scene images and proposes a general add-on module, called height driven attention networks (HANet), for improving semantic segmentation for urban-scene images. It emphasizes informative features or classes selectively according to the vertical position of a pixel. The pixel-wise class distributions are significantly different from each other among horizontally segmented sections in the urban-scene images. Likewise, urban-scene images have their own distinct characteristics, but most semantic segmentation networks do not reflect such unique attributes in the architecture. The proposed network architecture incorporates the capability exploiting the attributes to handle the urban scene dataset effectively. We validate the consistent performance (mIoU) increase of various semantic segmentation models on two datasets when HANet is adopted. This extensive quantitative analysis demonstrates that adding our module to existing models is easy and cost-effective. Our method achieves a new state-of-the-art performance on the Cityscapes benchmark with a large margin among ResNet-101 based segmentation models. Also, we show that the proposed model is coherent with the facts observed in the urban scene by visualizing and interpreting the attention map*
20 |
21 | ## Concept Video
22 | Click the figure to watch the YouTube video of our paper!
23 |
24 |
25 | 
26 |
27 |
28 | ## Cityscapes Benchmark
29 | | Models | Data | Crop Size | Batch Size | Output Stride | mIoU | External Link |
30 | |:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|:--------:|
31 | | HANet (ResNext-101) | Fine train/val + Coarse | 864X864 | 12 | 8 | 83.2% | [Benchmark](https://www.cityscapes-dataset.com/anonymous-results/?id=9a8b7333dcb66360b4f38ba00db7c84e7997f7203084bf6e92ca9bbabbc34640) |
32 | | HANet (ResNet-101) | Fine train/val | 864X864 | 12 | 8 | 82.1% | [Benchmark](https://www.cityscapes-dataset.com/anonymous-results/?id=f96818d678c67c82449323203d144e530fb66102a5b5a101f599a96cc62458e7) |
33 | | HANet (ResNet-101) | Fine train | 768X768 | 8 | 8 | 80.9% | [Benchmark](https://www.cityscapes-dataset.com/anonymous-results/?id=1e5e85818e439332fdae01037259706d9091be2b9fca850eb4a851805f5ed44d) |
34 |
35 |
36 | ## PyTorch Implementation
37 | ### Installation
38 | Clone this repository.
39 | ```
40 | git clone https://github.com/shachoi/HANet.git
41 | cd HANet
42 | ```
43 | Install the following packages.
44 | ```
45 | conda create --name hanet python=3.6
46 | conda activate hanet
47 | conda install -y pytorch=1.4.0 torchvision=0.5.0 cudatoolkit=10.1 -c pytorch
48 | conda install scipy==1.4.1
49 | conda install tqdm==4.46.0
50 | conda install scikit-image==0.16.2
51 | pip install tensorboardX==2.0
52 | pip install thop
53 | ```
54 |
55 | ### Datasets
56 | We evaluated HANet on [Cityscapes](https://www.cityscapes-dataset.com/) and [BDD-100K](https://bair.berkeley.edu/blog/2018/05/30/bdd/).
57 |
58 | For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/
59 | Unzip the files and arrange the directory structure as follows.
60 | ```
61 | cityscapes
62 | └ leftImg8bit_trainvaltest
63 | └ leftImg8bit
64 | └ train
65 | └ val
66 | └ test
67 | └ gtFine_trainvaltest
68 | └ gtFine
69 | └ train
70 | └ val
71 | └ test
72 | ```
73 | You should modify the path in **"/config.py"** to point to your Cityscapes dataset.
74 |
75 | ```
76 | #Cityscapes Dir Location
77 | __C.DATASET.CITYSCAPES_DIR =
78 | ```
79 |
80 | Additionally, you can use the Cityscapes coarse dataset to get the best mIoU score.
81 |
82 | Please refer to the training script **"/scripts/train_resnext_fr_pretrain.sh"**.
83 |
84 | The other training scripts don't use the Cityscapes coarse dataset.
85 |
86 | ### Pretrained Models
87 | #### All models trained for our paper
88 | You can download all models evaluated in our paper at [Google Drive](https://drive.google.com/drive/folders/1PfrG1d3fq4T9c96FmfOIkfKPShTTRz2G?usp=sharing)
89 |
90 | #### ImageNet pretrained ResNet-101 which has three 3×3 convolutions in the first layer
91 | To train ResNet-101 based HANet, you should download ImageNet pretrained ResNet-101 from [this link](https://drive.google.com/file/d/1Sx1Clf9Q9BsXKklZuUIqSJJhjMNF3jAa/view?usp=sharing). Put it into the following directory.
92 | ```
93 | /pretrained/resnet101-imagenet.pth
94 | ```
95 | This pretrained model is from the [MIT CSAIL Computer Vision Group](http://sceneparsing.csail.mit.edu/).
96 |
97 | #### Mapillary pretrained ResNext-101
98 | You can fine-tune HANet from the Mapillary pretrained ResNext-101 using the training script **"/scripts/train_resnext_fr_pretrain.sh"**.
99 | Download it from [this link](https://drive.google.com/file/d/1GJ4VOSiLwNuyqOgRqQoe9FbvnklI2TYe/view?usp=sharing) and put it into the following directory.
100 | ```
101 | /pretrained/resnext_mapillary_0.47475.pth
102 | ```
103 | ### Training Networks
104 | You may need to modify the training script according to the specification of your GPU system.
105 | ```
106 | python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE train.py \
107 | ...
108 | --bs_mult NUM_BATCH_PER_SINGLE_GPU \
109 | ...
110 | ```
111 | You can train HANet (based on ResNet-101) using the **finely annotated training and validation sets** with the following command.
112 | ```
113 | $ CUDA_VISIBLE_DEVICES=0,1,2,3 ./scripts/train_r101_os8_hanet_best.sh
114 | ```
115 | Otherwise, you can train HANet (based on ResNet-101) using only the **finely annotated training set** with the following command.
116 | ```
117 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/train_r101_os8_hanet.sh
118 | ```
119 | To run the script "train_r101_os8_hanet.sh", two Titan RTX GPUs (2 X 24GB GPU Memory) are required.
120 |
121 | Additionally, we provide various training scripts, such as one for MobileNet-based HANet.
122 |
123 | The results will be stored in **"/logs/"**.
124 | ### Inference
125 | ```
126 | $ CUDA_VISIBLE_DEVICES=0 ./scripts/eval_r101_os8.sh
127 | ```
128 | ### Submit to Cityscapes benchmark server
129 | ```
130 | $ CUDA_VISIBLE_DEVICES=0 ./scripts/submit_r101_os8.sh
131 | ```
132 |
133 | ## Citation
134 | If you find this work useful for your research, please cite our paper:
135 | ```
136 | @InProceedings{Choi_2020_CVPR,
137 | author = {Choi, Sungha and Kim, Joanne T. and Choo, Jaegul},
138 | title = {Cars Can't Fly Up in the Sky: Improving Urban-Scene Segmentation via Height-Driven Attention Networks},
139 | booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
140 | month = {June},
141 | year = {2020}
142 | }
143 | ```
144 |
145 | ## Acknowledgments
146 | Our PyTorch implementation is heavily derived from [NVIDIA segmentation](https://github.com/NVIDIA/semantic-segmentation).
147 | Thanks to NVIDIA for their implementation.
148 |
--------------------------------------------------------------------------------
/utils/my_data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | # Code adapted from:
4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
5 | #
6 | # BSD 3-Clause License
7 | #
8 | # Copyright (c) 2017,
9 | # All rights reserved.
10 | #
11 | # Redistribution and use in source and binary forms, with or without
12 | # modification, are permitted provided that the following conditions are met:
13 | #
14 | # * Redistributions of source code must retain the above copyright notice, this
15 | # list of conditions and the following disclaimer.
16 | #
17 | # * Redistributions in binary form must reproduce the above copyright notice,
18 | # this list of conditions and the following disclaimer in the documentation
19 | # and/or other materials provided with the distribution.
20 | #
21 | # * Neither the name of the copyright holder nor the names of its
22 | # contributors may be used to endorse or promote products derived from
23 | # this software without specific prior written permission.
24 | #
25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35 | """
36 |
37 |
38 | import operator
39 | import torch
40 | import warnings
41 | from torch.nn.modules import Module
42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
43 | from torch.nn.parallel.replicate import replicate
44 | from torch.nn.parallel.parallel_apply import parallel_apply
45 |
46 |
47 | def _check_balance(device_ids):
48 | imbalance_warn = """
49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which
50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting
51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
52 | environment variable."""
53 |
54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
55 |
56 | def warn_imbalance(get_prop):
57 | values = [get_prop(props) for props in dev_props]
58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
60 | if min_val / max_val < 0.75:
61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
62 | return True
63 | return False
64 |
65 | if warn_imbalance(lambda props: props.total_memory):
66 | return
67 | if warn_imbalance(lambda props: props.multi_processor_count):
68 | return
69 |
70 |
71 |
72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather=True):
73 | """
74 | Evaluates module(input) in parallel across the GPUs given in device_ids.
75 | This is the functional version of the DataParallel module.
76 | Args:
77 | module: the module to evaluate in parallel
78 | inputs: inputs to the module
79 | device_ids: GPU ids on which to replicate module
80 |         output_device: GPU location of the output. Use -1 to indicate the CPU.
81 | (default: device_ids[0])
82 | Returns:
83 | a Tensor containing the result of module(input) located on
84 | output_device
85 | """
86 | if not isinstance(inputs, tuple):
87 | inputs = (inputs,)
88 |
89 | if device_ids is None:
90 | device_ids = list(range(torch.cuda.device_count()))
91 |
92 | if output_device is None:
93 | output_device = device_ids[0]
94 |
95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
96 | if len(device_ids) == 1:
97 | return module(*inputs[0], **module_kwargs[0])
98 | used_device_ids = device_ids[:len(inputs)]
99 | replicas = replicate(module, used_device_ids)
100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
101 |     if gather:  # note: this boolean argument shadows the imported gather(), so call the function via its full module path
102 |         return torch.nn.parallel.scatter_gather.gather(outputs, output_device, dim)
103 | else:
104 | return outputs
105 |
106 |
107 |
108 | class MyDataParallel(Module):
109 | """
110 | Implements data parallelism at the module level.
111 | This container parallelizes the application of the given module by
112 | splitting the input across the specified devices by chunking in the batch
113 | dimension. In the forward pass, the module is replicated on each device,
114 | and each replica handles a portion of the input. During the backwards
115 | pass, gradients from each replica are summed into the original module.
116 | The batch size should be larger than the number of GPUs used.
117 | See also: :ref:`cuda-nn-dataparallel-instead`
118 | Arbitrary positional and keyword inputs are allowed to be passed into
119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim
120 | specified (default 0). Primitive types will be broadcasted, but all
121 | other types will be a shallow copy and can be corrupted if written to in
122 | the model's forward pass.
123 | .. warning::
124 | Forward and backward hooks defined on :attr:`module` and its submodules
125 | will be invoked ``len(device_ids)`` times, each with inputs located on
126 | a particular device. Particularly, the hooks are only guaranteed to be
127 | executed in correct order with respect to operations on corresponding
128 | devices. For example, it is not guaranteed that hooks set via
129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
131 | that each such hook be executed before the corresponding
132 | :meth:`~torch.nn.Module.forward` call of that device.
133 | .. warning::
134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
135 | :func:`forward`, this wrapper will return a vector of length equal to
136 | number of devices used in data parallelism, containing the result from
137 | each device.
138 | .. note::
139 | There is a subtlety in using the
140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
143 | details.
144 | Args:
145 | module: module to be parallelized
146 | device_ids: CUDA devices (default: all devices)
147 | output_device: device location of output (default: device_ids[0])
148 | Attributes:
149 | module (Module): the module to be parallelized
150 | Example::
151 |         >>> net = MyDataParallel(model, device_ids=[0, 1, 2])
152 | >>> output = net(input_var)
153 | """
154 |
155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
156 |
157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True):
158 | super(MyDataParallel, self).__init__()
159 |
160 | if not torch.cuda.is_available():
161 | self.module = module
162 | self.device_ids = []
163 | return
164 |
165 | if device_ids is None:
166 | device_ids = list(range(torch.cuda.device_count()))
167 | if output_device is None:
168 | output_device = device_ids[0]
169 | self.dim = dim
170 | self.module = module
171 | self.device_ids = device_ids
172 | self.output_device = output_device
173 | self.gather_bool = gather
174 |
175 | _check_balance(self.device_ids)
176 |
177 | if len(self.device_ids) == 1:
178 | self.module.cuda(device_ids[0])
179 |
180 | def forward(self, *inputs, **kwargs):
181 | if not self.device_ids:
182 | return self.module(*inputs, **kwargs)
183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
184 | if len(self.device_ids) == 1:
185 | return [self.module(*inputs[0], **kwargs[0])]
186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
187 | outputs = self.parallel_apply(replicas, inputs, kwargs)
188 | if self.gather_bool:
189 | return self.gather(outputs, self.output_device)
190 | else:
191 | return outputs
192 |
193 | def replicate(self, module, device_ids):
194 | return replicate(module, device_ids)
195 |
196 | def scatter(self, inputs, kwargs, device_ids):
197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
198 |
199 | def parallel_apply(self, replicas, inputs, kwargs):
200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
201 |
202 | def gather(self, outputs, output_device):
203 | return gather(outputs, output_device, dim=self.dim)
204 |
205 |
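# A minimal usage sketch, assuming at least two visible CUDA devices; the toy
# model and tensor sizes below are illustrative. With gather=False, the wrapper
# returns the raw list of per-GPU outputs instead of gathering them onto one device.
if __name__ == "__main__":
    if torch.cuda.device_count() >= 2:
        model = torch.nn.Linear(8, 2).cuda()
        net = MyDataParallel(model, device_ids=[0, 1], gather=False)
        outputs = net(torch.randn(16, 8).cuda())
        print([tuple(o.shape) for o in outputs])  # one (8, 2) output per GPU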
--------------------------------------------------------------------------------
/datasets/kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | KITTI Dataset Loader
3 | """
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils import data
10 | import logging
11 | import datasets.uniform as uniform
12 | import datasets.cityscapes_labels as cityscapes_labels
13 | import json
14 | from config import cfg
15 |
16 |
17 | trainid_to_name = cityscapes_labels.trainId2name
18 | id_to_trainid = cityscapes_labels.label2trainid
19 | num_classes = 19
20 | ignore_label = 255
21 | root = cfg.DATASET.KITTI_DIR
22 | aug_root = cfg.DATASET.KITTI_AUG_DIR
23 |
24 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
25 | 153, 153, 153, 250, 170, 30,
26 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
27 | 255, 0, 0, 0, 0, 142, 0, 0, 70,
28 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
29 | zero_pad = 256 * 3 - len(palette)
30 | for i in range(zero_pad):
31 | palette.append(0)
32 |
33 | def colorize_mask(mask):
34 | # mask: numpy array of the mask
35 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
36 | new_mask.putpalette(palette)
37 | return new_mask
38 |
39 | def get_train_val(cv_split, all_items):
40 |
41 | # 90/10 train/val split, three random splits
42 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198]
43 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197]
44 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199]
45 |
46 | train_set = []
47 | val_set = []
48 |
49 | if cv_split == 0:
50 | for i in range(200):
51 | if i in val_0:
52 | val_set.append(all_items[i])
53 | else:
54 | train_set.append(all_items[i])
55 | elif cv_split == 1:
56 | for i in range(200):
57 | if i in val_1:
58 | val_set.append(all_items[i])
59 | else:
60 | train_set.append(all_items[i])
61 | elif cv_split == 2:
62 | for i in range(200):
63 | if i in val_2:
64 | val_set.append(all_items[i])
65 | else:
66 | train_set.append(all_items[i])
67 | else:
68 | logging.info('Unknown cv_split {}'.format(cv_split))
69 | sys.exit()
70 |
71 | return train_set, val_set
72 |
73 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
74 |
75 | items = []
76 | all_items = []
77 | aug_items = []
78 |
79 | assert quality == 'semantic'
80 | assert mode in ['train', 'val', 'trainval']
81 | # note that train and val are randomly determined, no official split
82 |
83 | img_dir_name = "training"
84 | img_path = os.path.join(root, img_dir_name, 'image_2')
85 | mask_path = os.path.join(root, img_dir_name, 'semantic')
86 |
87 | c_items = os.listdir(img_path)
88 | c_items.sort()
89 |
90 | for it in c_items:
91 | item = (os.path.join(img_path, it), os.path.join(mask_path, it))
92 | all_items.append(item)
93 | logging.info('KITTI has a total of {} images'.format(len(all_items)))
94 |
95 | # split into train/val
96 | train_set, val_set = get_train_val(cv_split, all_items)
97 |
98 | if mode == 'train':
99 | items = train_set
100 | elif mode == 'val':
101 | items = val_set
102 | elif mode == 'trainval':
103 | items = train_set + val_set
104 | else:
105 | logging.info('Unknown mode {}'.format(mode))
106 | sys.exit()
107 |
108 | logging.info('KITTI-{}: {} images'.format(mode, len(items)))
109 |
110 | return items, aug_items
111 |
112 | class KITTI(data.Dataset):
113 |
114 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
115 | transform=None, target_transform=None, dump_images=False,
116 | class_uniform_pct=0, class_uniform_tile=0, test=False,
117 | cv_split=None, scf=None, hardnm=0):
118 |
119 | self.quality = quality
120 | self.mode = mode
121 | self.maxSkip = maxSkip
122 | self.joint_transform_list = joint_transform_list
123 | self.transform = transform
124 | self.target_transform = target_transform
125 | self.dump_images = dump_images
126 | self.class_uniform_pct = class_uniform_pct
127 | self.class_uniform_tile = class_uniform_tile
128 | self.scf = scf
129 | self.hardnm = hardnm
130 |
131 | if cv_split:
132 | self.cv_split = cv_split
133 | assert cv_split < cfg.DATASET.CV_SPLITS, \
134 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
135 | cv_split, cfg.DATASET.CV_SPLITS)
136 | else:
137 | self.cv_split = 0
138 |
139 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
140 | assert len(self.imgs), 'Found 0 images, please check the data set'
141 | # self.cal_shape(self.imgs)
142 |
143 | # Centroids for GT data
144 | if self.class_uniform_pct > 0:
145 | if self.scf:
146 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split)
147 | else:
148 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm)
149 | if os.path.isfile(json_fn):
150 | with open(json_fn, 'r') as json_data:
151 | centroids = json.load(json_data)
152 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
153 | else:
154 |             if self.scf:  # note: `kitti_uniform` is not provided in this repo, so fall back to the shared uniform module
155 |                 self.centroids = uniform.class_centroids_all(
156 | self.imgs,
157 | num_classes,
158 | id2trainid=id_to_trainid,
159 | tile_size=class_uniform_tile)
160 | else:
161 | self.centroids = uniform.class_centroids_all(
162 | self.imgs,
163 | num_classes,
164 | id2trainid=id_to_trainid,
165 | tile_size=class_uniform_tile)
166 | with open(json_fn, 'w') as outfile:
167 | json.dump(self.centroids, outfile, indent=4)
168 |
169 | self.build_epoch()
170 |
171 |
172 | def cal_shape(self, imgs):
173 |
174 | for i in imgs:
175 | img_path, mask_path = i
176 | img = Image.open(img_path).convert('RGB')
177 | print(img.size)
178 |
179 | def build_epoch(self, cut=False):
180 | if self.class_uniform_pct > 0:
181 | self.imgs_uniform = uniform.build_epoch(self.imgs,
182 | self.centroids,
183 | num_classes,
184 | cfg.CLASS_UNIFORM_PCT)
185 | else:
186 | self.imgs_uniform = self.imgs
187 |
188 | def __getitem__(self, index):
189 | elem = self.imgs_uniform[index]
190 | centroid = None
191 | if len(elem) == 4:
192 | img_path, mask_path, centroid, class_id = elem
193 | else:
194 | img_path, mask_path = elem
195 |
196 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
197 | img_name = os.path.splitext(os.path.basename(img_path))[0]
198 |
199 | # kitti scale correction factor
200 | if self.mode == 'train' or self.mode == 'trainval':
201 | if self.scf:
202 | width, height = img.size
203 | img = img.resize((width*2, height*2), Image.BICUBIC)
204 | mask = mask.resize((width*2, height*2), Image.NEAREST)
205 | elif self.mode == 'val':
206 | width, height = 1242, 376
207 | img = img.resize((width, height), Image.BICUBIC)
208 | mask = mask.resize((width, height), Image.NEAREST)
209 | else:
210 |             logging.info('Unknown mode {}'.format(self.mode))
211 | sys.exit()
212 |
213 | mask = np.array(mask)
214 | mask_copy = mask.copy()
215 |
216 | for k, v in id_to_trainid.items():
217 | mask_copy[mask == k] = v
218 | mask = Image.fromarray(mask_copy.astype(np.uint8))
219 |
220 | # Image Transformations
221 | if self.joint_transform_list is not None:
222 | for idx, xform in enumerate(self.joint_transform_list):
223 | if idx == 0 and centroid is not None:
224 | # HACK
225 | # We assume that the first transform is capable of taking
226 | # in a centroid
227 | img, mask = xform(img, mask, centroid)
228 | else:
229 | img, mask = xform(img, mask)
230 |
231 | # Debug
232 | if self.dump_images and centroid is not None:
233 | outdir = './dump_imgs_{}'.format(self.mode)
234 | os.makedirs(outdir, exist_ok=True)
235 | dump_img_name = trainid_to_name[class_id] + '_' + img_name
236 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
237 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
238 | mask_img = colorize_mask(np.array(mask))
239 | img.save(out_img_fn)
240 | mask_img.save(out_msk_fn)
241 |
242 | if self.transform is not None:
243 | img = self.transform(img)
244 | if self.target_transform is not None:
245 | mask = self.target_transform(mask)
246 |
247 | return img, mask, img_name
248 |
249 | def __len__(self):
250 | return len(self.imgs_uniform)
251 |
252 |
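# A minimal usage sketch, assuming cfg.DATASET.KITTI_DIR points at the KITTI
# semantic segmentation training data; the split and fold below are illustrative.
if __name__ == "__main__":
    dataset = KITTI('semantic', 'train', cv_split=0)
    img, mask, name = dataset[0]  # PIL image, PIL label mask, and the file stem
    print(len(dataset), img.size, name)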
--------------------------------------------------------------------------------
/datasets/uniform.py:
--------------------------------------------------------------------------------
1 | """
2 | Uniform sampling of classes.
3 | For all images, for all classes, generate centroids around which to sample.
4 |
5 | All images are divided into tiles.
6 | For each tile, a class can be present or not. If it is
7 | present, calculate the centroid of the class and record it.
8 |
9 | We would like to thank Peter Kontschieder for the inspiration of this idea.
10 | """
11 |
12 | import logging
13 | from collections import defaultdict
14 | from PIL import Image
15 | import numpy as np
16 | from scipy import ndimage
17 | from tqdm import tqdm
18 |
19 | pbar = None
20 |
21 | class Point():
22 | """
23 | Point Class For X and Y Location
24 | """
25 | def __init__(self, x, y):
26 | self.x = x
27 | self.y = y
28 |
29 |
30 | def calc_tile_locations(tile_size, image_size):
31 | """
32 | Divide an image into tiles to help us cover classes that are spread out.
33 | tile_size: size of tile to distribute
34 | image_size: original image size
35 | return: locations of the tiles
36 | """
37 | image_size_y, image_size_x = image_size
38 | locations = []
39 | for y in range(image_size_y // tile_size):
40 | for x in range(image_size_x // tile_size):
41 | x_offs = x * tile_size
42 | y_offs = y * tile_size
43 | locations.append((x_offs, y_offs))
44 | return locations
45 |
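# Illustrative example (assuming a Cityscapes-sized 1024 x 2048 label map):
#   >>> calc_tile_locations(tile_size=1024, image_size=(1024, 2048))
#   [(0, 0), (1024, 0)]
# i.e. the (x_offs, y_offs) top-left corner of each tile covering the image.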
46 |
47 | def class_centroids_image(item, tile_size, num_classes, id2trainid):
48 | """
49 | For one image, calculate centroids for all classes present in image.
50 | item: image, image_name
51 | tile_size:
52 | num_classes:
53 | id2trainid: mapping from original id to training ids
54 | return: Centroids are calculated for each tile.
55 | """
56 | image_fn, label_fn = item
57 | centroids = defaultdict(list)
58 | mask = np.array(Image.open(label_fn))
59 | image_size = mask.shape
60 | tile_locations = calc_tile_locations(tile_size, image_size)
61 |
62 | mask_copy = mask.copy()
63 | if id2trainid:
64 | for k, v in id2trainid.items():
65 | mask[mask_copy == k] = v
66 |
67 | for x_offs, y_offs in tile_locations:
68 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
69 | for class_id in range(num_classes):
70 | if class_id in patch:
71 | patch_class = (patch == class_id).astype(int)
72 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
73 | centroid_y = int(centroid_y) + y_offs
74 | centroid_x = int(centroid_x) + x_offs
75 | centroid = (centroid_x, centroid_y)
76 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
77 | pbar.update(1)
78 | return centroids
79 |
80 | import scipy.misc as m
81 |
82 | def class_centroids_image_from_color(item, tile_size, num_classes, id2trainid):
83 | """
84 | For one image, calculate centroids for all classes present in image.
85 | item: image, image_name
86 | tile_size:
87 | num_classes:
88 | id2trainid: mapping from original id to training ids
89 | return: Centroids are calculated for each tile.
90 | """
91 | image_fn, label_fn = item
92 | centroids = defaultdict(list)
93 | mask = m.imread(label_fn)
94 | image_size = mask[:,:,0].shape
95 | tile_locations = calc_tile_locations(tile_size, image_size)
96 |
97 | # mask = m.imread(label_fn)
98 | # mask_copy = np.full((img.size[1], img.size[0]), 255, dtype=np.uint8)
99 | # for k, v in id2trainid.items():
100 | # mask_copy[(mask == k)[:,:,0]] = v
101 | # mask = Image.fromarray(mask_copy.astype(np.uint8))
102 |
103 | # mask_copy = mask.copy()
104 | # mask_copy = mask.copy()
105 | # if id2trainid:
106 | # for k, v in id2trainid.items():
107 | # mask[mask_copy == k] = v
108 |
109 | mask_copy = np.full(image_size, 255, dtype=np.uint8)
110 |
111 | if id2trainid:
112 | for k, v in id2trainid.items():
113 | # print("0", mask.shape)
114 | # print("1", ((mask == np.array(k))[:,:,0]).shape) # 1052, 1914
115 | # # print("2", mask == np.array(k)[:,:,0])
116 | # break
117 | # if v != 255:
118 | # print(v)
119 | # if v == 2:
120 | # print(k, v, "num", np.count_nonzero(mask == np.array(k)))
121 | # break
122 | if v != 255 and v != -1:
123 | mask_copy[(mask == np.array(k))[:,:,0] & (mask == np.array(k))[:,:,1] & (mask == np.array(k))[:,:,2]] = v
124 | mask = mask_copy
125 |
126 | # mask_copy = mask.copy()
127 | # if id2trainid:
128 | # for k, v in id2trainid.items():
129 | # mask[mask_copy == k] = v
130 |
131 | for x_offs, y_offs in tile_locations:
132 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
133 | for class_id in range(num_classes):
134 | if class_id in patch:
135 | patch_class = (patch == class_id).astype(int)
136 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
137 | centroid_y = int(centroid_y) + y_offs
138 | centroid_x = int(centroid_x) + x_offs
139 | centroid = (centroid_x, centroid_y)
140 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
141 | pbar.update(1)
142 | return centroids
143 |
144 | def pooled_class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024):
145 | """
146 | Calculate class centroids for all classes for all images for all tiles.
147 | items: list of (image_fn, label_fn)
148 | tile size: size of tile
149 | returns: dict that contains a list of centroids for each class
150 | """
151 | from multiprocessing.dummy import Pool
152 | from functools import partial
153 | pool = Pool(32)
154 | global pbar
155 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
156 | class_centroids_item = partial(class_centroids_image_from_color,
157 | num_classes=num_classes,
158 | id2trainid=id2trainid,
159 | tile_size=tile_size)
160 |
161 | centroids = defaultdict(list)
162 | new_centroids = pool.map(class_centroids_item, items)
163 | pool.close()
164 | pool.join()
165 |
166 | # combine each image's items into a single global dict
167 | for image_items in new_centroids:
168 | for class_id in image_items:
169 | centroids[class_id].extend(image_items[class_id])
170 | return centroids
171 |
172 |
173 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
174 | """
175 | Calculate class centroids for all classes for all images for all tiles.
176 | items: list of (image_fn, label_fn)
177 | tile size: size of tile
178 | returns: dict that contains a list of centroids for each class
179 | """
180 | from multiprocessing.dummy import Pool
181 | from functools import partial
182 | pool = Pool(32)
183 | global pbar
184 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
185 | class_centroids_item = partial(class_centroids_image,
186 | num_classes=num_classes,
187 | id2trainid=id2trainid,
188 | tile_size=tile_size)
189 |
190 | centroids = defaultdict(list)
191 | new_centroids = pool.map(class_centroids_item, items)
192 | pool.close()
193 | pool.join()
194 |
195 | # combine each image's items into a single global dict
196 | for image_items in new_centroids:
197 | for class_id in image_items:
198 | centroids[class_id].extend(image_items[class_id])
199 | return centroids
200 |
201 |
202 | def unpooled_class_centroids_all(items, num_classes, id2trainid=None, tile_size=1024):
203 | """
204 | Calculate class centroids for all classes for all images for all tiles.
205 | items: list of (image_fn, label_fn)
206 | tile size: size of tile
207 | returns: dict that contains a list of centroids for each class
208 | """
209 | centroids = defaultdict(list)
210 | global pbar
211 | pbar = tqdm(total=len(items), desc='centroid extraction')
212 | for image, label in items:
213 |         new_centroids = class_centroids_image((image, label),
214 |                                               tile_size,
215 |                                               num_classes, id2trainid)
216 | for class_id in new_centroids:
217 | centroids[class_id].extend(new_centroids[class_id])
218 |
219 | return centroids
220 |
221 |
222 | def class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024):
223 | """
224 | intermediate function to call pooled_class_centroid
225 | """
226 |
227 | pooled_centroids = pooled_class_centroids_all_from_color(items, num_classes,
228 | id2trainid, tile_size)
229 | return pooled_centroids
230 |
231 |
232 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
233 | """
234 | intermediate function to call pooled_class_centroid
235 | """
236 |
237 | pooled_centroids = pooled_class_centroids_all(items, num_classes,
238 | id2trainid, tile_size)
239 | return pooled_centroids
240 |
241 |
242 | def random_sampling(alist, num):
243 | """
244 | Randomly sample num items from the list
245 | alist: list of centroids to sample from
246 | num: can be larger than the list and if so, then wrap around
247 | return: class uniform samples from the list
248 | """
249 | sampling = []
250 | len_list = len(alist)
251 | assert len_list, 'len_list is zero!'
252 | indices = np.arange(len_list)
253 | np.random.shuffle(indices)
254 |
255 | for i in range(num):
256 | item = alist[indices[i % len_list]]
257 | sampling.append(item)
258 | return sampling
259 |
260 |
261 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct):
262 | """
263 |     Generate an epoch's worth of crops using uniform sampling. Needs to be called every epoch.
264 | imgs: list of imgs
265 | centroids:
266 | num_classes:
267 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch )
268 | """
269 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct))
270 | num_epoch = int(len(imgs))
271 |
272 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch))
273 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes)
274 | num_rand = num_epoch - num_per_class * num_classes
275 | # create random crops
276 | imgs_uniform = random_sampling(imgs, num_rand)
277 |
278 | # now add uniform sampling
279 | for class_id in range(num_classes):
280 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id]))
281 | logging.info(string_format)
282 | for class_id in range(num_classes):
283 | centroid_len = len(centroids[class_id])
284 | if centroid_len == 0:
285 | pass
286 | else:
287 | class_centroids = random_sampling(centroids[class_id], num_per_class)
288 | imgs_uniform.extend(class_centroids)
289 |
290 | return imgs_uniform
291 |
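# Worked example (the numbers are illustrative): with 2975 training images,
# 19 classes and class_uniform_pct=0.5,
#   num_per_class = int((2975 * 0.5) / 19) = 78 centroid-based crops per class
#   num_rand      = 2975 - 78 * 19        = 1493 randomly sampled images
# so each epoch mixes 1493 random samples with 78 class-uniform samples per class.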
--------------------------------------------------------------------------------
/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code borrowed from:
3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py
4 | #
5 | #
6 | # MIT License
7 | #
8 | # Copyright (c) 2017 ZijunDeng
9 | #
10 | # Permission is hereby granted, free of charge, to any person obtaining a copy
11 | # of this software and associated documentation files (the "Software"), to deal
12 | # in the Software without restriction, including without limitation the rights
13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 | # copies of the Software, and to permit persons to whom the Software is
15 | # furnished to do so, subject to the following conditions:
16 | #
17 | # The above copyright notice and this permission notice shall be included in all
18 | # copies or substantial portions of the Software.
19 | #
20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 | # SOFTWARE.
27 |
28 | """
29 |
30 | """
31 | Standard Transform
32 | """
33 |
34 | import random
35 | import numpy as np
36 | from skimage.filters import gaussian
37 | from skimage.restoration import denoise_bilateral
38 | import torch
39 | from PIL import Image, ImageEnhance
40 | import torchvision.transforms as torch_tr
41 | from config import cfg
42 | from scipy.ndimage.interpolation import shift
43 |
44 | from skimage.segmentation import find_boundaries
45 |
46 | try:
47 | import accimage
48 | except ImportError:
49 | accimage = None
50 |
51 |
52 | class RandomVerticalFlip(object):
53 | def __call__(self, img):
54 | if random.random() < 0.5:
55 | return img.transpose(Image.FLIP_TOP_BOTTOM)
56 | return img
57 |
58 |
59 | class DeNormalize(object):
60 | def __init__(self, mean, std):
61 | self.mean = mean
62 | self.std = std
63 |
64 | def __call__(self, tensor):
65 | for t, m, s in zip(tensor, self.mean, self.std):
66 | t.mul_(s).add_(m)
67 | return tensor
68 |
69 |
70 | class MaskToTensor(object):
71 | def __call__(self, img):
72 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
73 |
74 | class RelaxedBoundaryLossToTensor(object):
75 | """
76 | Boundary Relaxation
77 | """
78 | def __init__(self,ignore_id, num_classes):
79 | self.ignore_id=ignore_id
80 | self.num_classes= num_classes
81 |
82 |
83 | def new_one_hot_converter(self,a):
84 | ncols = self.num_classes+1
85 | out = np.zeros( (a.size,ncols), dtype=np.uint8)
86 | out[np.arange(a.size),a.ravel()] = 1
87 | out.shape = a.shape + (ncols,)
88 | return out
89 |
90 | def __call__(self,img):
91 |
92 | img_arr = np.array(img)
93 | img_arr[img_arr==self.ignore_id]=self.num_classes
94 |
95 |         if cfg.STRICTBORDERCLASS is not None:
96 | one_hot_orig = self.new_one_hot_converter(img_arr)
97 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1]))
98 | for cls in cfg.STRICTBORDERCLASS:
99 | mask = np.logical_or(mask,(img_arr == cls))
100 | one_hot = 0
101 |
102 | border = cfg.BORDER_WINDOW
103 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
104 | border = border // 2
105 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8)
106 |
107 | for i in range(-border,border+1):
108 | for j in range(-border, border+1):
109 | shifted= shift(img_arr,(i,j), cval=self.num_classes)
110 | one_hot += self.new_one_hot_converter(shifted)
111 |
112 | one_hot[one_hot>1] = 1
113 |
114 |         if cfg.STRICTBORDERCLASS is not None:
115 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot)
116 |
117 | one_hot = np.moveaxis(one_hot,-1,0)
118 |
119 |
120 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
121 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot)
122 | # print(one_hot.shape)
123 | return torch.from_numpy(one_hot).byte()
124 |
125 | class ResizeHeight(object):
126 | def __init__(self, size, interpolation=Image.BILINEAR):
127 | self.target_h = size
128 | self.interpolation = interpolation
129 |
130 | def __call__(self, img):
131 | w, h = img.size
132 | target_w = int(w / h * self.target_h)
133 | return img.resize((target_w, self.target_h), self.interpolation)
134 |
135 |
136 | class FreeScale(object):
137 | def __init__(self, size, interpolation=Image.BILINEAR):
138 | self.size = tuple(reversed(size)) # size: (h, w)
139 | self.interpolation = interpolation
140 |
141 | def __call__(self, img):
142 | return img.resize(self.size, self.interpolation)
143 |
144 |
145 | class FlipChannels(object):
146 | """
147 |     Flip the channel order (e.g. RGB -> BGR)
148 | """
149 | def __call__(self, img):
150 | img = np.array(img)[:, :, ::-1]
151 | return Image.fromarray(img.astype(np.uint8))
152 |
153 |
154 | class RandomGaussianBlur(object):
155 | """
156 | Apply Gaussian Blur
157 | """
158 | def __call__(self, img):
159 | sigma = 0.15 + random.random() * 1.15
160 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
161 | blurred_img *= 255
162 | return Image.fromarray(blurred_img.astype(np.uint8))
163 |
164 |
165 | class RandomBilateralBlur(object):
166 | """
167 | Apply Bilateral Filtering
168 |
169 | """
170 | def __call__(self, img):
171 | sigma = random.uniform(0.05,0.75)
172 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True)
173 | blurred_img *= 255
174 | return Image.fromarray(blurred_img.astype(np.uint8))
175 |
176 | def _is_pil_image(img):
177 | if accimage is not None:
178 | return isinstance(img, (Image.Image, accimage.Image))
179 | else:
180 | return isinstance(img, Image.Image)
181 |
182 |
183 | def adjust_brightness(img, brightness_factor):
184 | """Adjust brightness of an Image.
185 |
186 | Args:
187 | img (PIL Image): PIL Image to be adjusted.
188 | brightness_factor (float): How much to adjust the brightness. Can be
189 | any non negative number. 0 gives a black image, 1 gives the
190 | original image while 2 increases the brightness by a factor of 2.
191 |
192 | Returns:
193 | PIL Image: Brightness adjusted image.
194 | """
195 | if not _is_pil_image(img):
196 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
197 |
198 | enhancer = ImageEnhance.Brightness(img)
199 | img = enhancer.enhance(brightness_factor)
200 | return img
201 |
202 |
203 | def adjust_contrast(img, contrast_factor):
204 | """Adjust contrast of an Image.
205 |
206 | Args:
207 | img (PIL Image): PIL Image to be adjusted.
208 | contrast_factor (float): How much to adjust the contrast. Can be any
209 | non negative number. 0 gives a solid gray image, 1 gives the
210 | original image while 2 increases the contrast by a factor of 2.
211 |
212 | Returns:
213 | PIL Image: Contrast adjusted image.
214 | """
215 | if not _is_pil_image(img):
216 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
217 |
218 | enhancer = ImageEnhance.Contrast(img)
219 | img = enhancer.enhance(contrast_factor)
220 | return img
221 |
222 |
223 | def adjust_saturation(img, saturation_factor):
224 | """Adjust color saturation of an image.
225 |
226 | Args:
227 | img (PIL Image): PIL Image to be adjusted.
228 | saturation_factor (float): How much to adjust the saturation. 0 will
229 | give a black and white image, 1 will give the original image while
230 | 2 will enhance the saturation by a factor of 2.
231 |
232 | Returns:
233 | PIL Image: Saturation adjusted image.
234 | """
235 | if not _is_pil_image(img):
236 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
237 |
238 | enhancer = ImageEnhance.Color(img)
239 | img = enhancer.enhance(saturation_factor)
240 | return img
241 |
242 |
243 | def adjust_hue(img, hue_factor):
244 | """Adjust hue of an image.
245 |
246 | The image hue is adjusted by converting the image to HSV and
247 | cyclically shifting the intensities in the hue channel (H).
248 | The image is then converted back to original image mode.
249 |
250 | `hue_factor` is the amount of shift in H channel and must be in the
251 | interval `[-0.5, 0.5]`.
252 |
253 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
254 |
255 | Args:
256 | img (PIL Image): PIL Image to be adjusted.
257 | hue_factor (float): How much to shift the hue channel. Should be in
258 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
259 | HSV space in positive and negative direction respectively.
260 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
261 | with complementary colors while 0 gives the original image.
262 |
263 | Returns:
264 | PIL Image: Hue adjusted image.
265 | """
266 | if not(-0.5 <= hue_factor <= 0.5):
267 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
268 |
269 | if not _is_pil_image(img):
270 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
271 |
272 | input_mode = img.mode
273 | if input_mode in {'L', '1', 'I', 'F'}:
274 | return img
275 |
276 | h, s, v = img.convert('HSV').split()
277 |
278 | np_h = np.array(h, dtype=np.uint8)
279 |     # uint8 addition takes care of rotation across boundaries
280 | with np.errstate(over='ignore'):
281 | np_h += np.uint8(hue_factor * 255)
282 | h = Image.fromarray(np_h, 'L')
283 |
284 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
285 | return img
286 |
287 |
288 | class ColorJitter(object):
289 | """Randomly change the brightness, contrast and saturation of an image.
290 |
291 | Args:
292 | brightness (float): How much to jitter brightness. brightness_factor
293 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
294 | contrast (float): How much to jitter contrast. contrast_factor
295 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
296 | saturation (float): How much to jitter saturation. saturation_factor
297 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
298 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
299 | [-hue, hue]. Should be >=0 and <= 0.5.
300 | """
301 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
302 | self.brightness = brightness
303 | self.contrast = contrast
304 | self.saturation = saturation
305 | self.hue = hue
306 |
307 | @staticmethod
308 | def get_params(brightness, contrast, saturation, hue):
309 | """Get a randomized transform to be applied on image.
310 |
311 | Arguments are same as that of __init__.
312 |
313 | Returns:
314 | Transform which randomly adjusts brightness, contrast and
315 | saturation in a random order.
316 | """
317 | transforms = []
318 | if brightness > 0:
319 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
320 | transforms.append(
321 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
322 |
323 | if contrast > 0:
324 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
325 | transforms.append(
326 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
327 |
328 | if saturation > 0:
329 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
330 | transforms.append(
331 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
332 |
333 | if hue > 0:
334 | hue_factor = np.random.uniform(-hue, hue)
335 | transforms.append(
336 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
337 |
338 | np.random.shuffle(transforms)
339 | transform = torch_tr.Compose(transforms)
340 |
341 | return transform
342 |
343 | def __call__(self, img):
344 | """
345 | Args:
346 | img (PIL Image): Input image.
347 |
348 | Returns:
349 | PIL Image: Color jittered image.
350 | """
351 | transform = self.get_params(self.brightness, self.contrast,
352 | self.saturation, self.hue)
353 | return transform(img)
354 |
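# A minimal usage sketch: apply the extended ColorJitter followed by
# RandomGaussianBlur to a synthetic RGB image; the jitter strengths below are
# illustrative, not the values used by the training configuration.
if __name__ == "__main__":
    dummy = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
    jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
    augmented = RandomGaussianBlur()(jitter(dummy))
    print(augmented.size, augmented.mode)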
--------------------------------------------------------------------------------
/datasets/cityscapes_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | # File taken from https://github.com/mcordts/cityscapesScripts/
3 | # License File Available at:
4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt
5 |
6 | # ----------------------
7 | # The Cityscapes Dataset
8 | # ----------------------
9 | #
10 | #
11 | # License agreement
12 | # -----------------
13 | #
14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
15 | #
16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).
21 | #
22 | #
23 | # Contact
24 | # -------
25 | #
26 | # Marius Cordts, Mohamed Omran
27 | # www.cityscapes-dataset.net
28 |
29 | """
30 | from collections import namedtuple
31 |
32 |
33 | #--------------------------------------------------------------------------------
34 | # Definitions
35 | #--------------------------------------------------------------------------------
36 |
37 | # a label and all meta information
38 | Label = namedtuple( 'Label' , [
39 |
40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
41 | # We use them to uniquely name a class
42 |
43 | 'id' , # An integer ID that is associated with this label.
44 | # The IDs are used to represent the label in ground truth images
45 | # An ID of -1 means that this label does not have an ID and thus
46 | # is ignored when creating ground truth images (e.g. license plate).
47 | # Do not modify these IDs, since exactly these IDs are expected by the
48 | # evaluation server.
49 |
50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
51 | # ground truth images with train IDs, using the tools provided in the
52 | # 'preparation' folder. However, make sure to validate or submit results
53 | # to our evaluation server using the regular IDs above!
54 | # For trainIds, multiple labels might have the same ID. Then, these labels
55 | # are mapped to the same class in the ground truth images. For the inverse
56 | # mapping, we use the label that is defined first in the list below.
57 | # For example, mapping all void-type classes to the same ID in training,
58 | # might make sense for some approaches.
59 | # Max value is 255!
60 |
61 | 'category' , # The name of the category that this label belongs to
62 |
63 | 'categoryId' , # The ID of this category. Used to create ground truth images
64 | # on category level.
65 |
66 | 'hasInstances', # Whether this label distinguishes between single instances or not
67 |
68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
69 | # during evaluations or not
70 |
71 | 'color' , # The color of this label
72 | ] )
73 |
74 |
75 | #--------------------------------------------------------------------------------
76 | # A list of all labels
77 | #--------------------------------------------------------------------------------
78 |
79 | # Please adapt the train IDs as appropriate for your approach.
80 | # Note that you might want to ignore labels with ID 255 during training.
81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
82 | # Make sure to provide your results using the original IDs and not the training IDs.
83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
84 |
85 | labels = [
86 | # name id trainId category catId hasInstances ignoreInEval color
87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,154) ), # (153,153,153)
106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,143) ), # ( 0, 0,142)
122 | ]
123 |
124 |
125 | #--------------------------------------------------------------------------------
126 | # Create dictionaries for a fast lookup
127 | #--------------------------------------------------------------------------------
128 |
129 | # Please refer to the main method below for example usages!
130 |
131 | # name to label object
132 | name2label = { label.name : label for label in labels }
133 | # id to label object
134 | id2label = { label.id : label for label in labels }
135 | # trainId to label object
136 | trainId2label = { label.trainId : label for label in reversed(labels) }
137 | # label2trainid
138 | label2trainid = { label.id : label.trainId for label in labels }
139 | # trainId to label object
140 | trainId2name = { label.trainId : label.name for label in labels }
141 | trainId2color = { label.trainId : label.color for label in labels }
142 |
143 | color2trainId = { label.color : label.trainId for label in labels }
144 |
145 | trainId2trainId = { label.trainId : label.trainId for label in labels }
146 |
147 | # category to list of label objects
148 | category2labels = {}
149 | for label in labels:
150 | category = label.category
151 | if category in category2labels:
152 | category2labels[category].append(label)
153 | else:
154 | category2labels[category] = [label]
155 |
156 | #--------------------------------------------------------------------------------
157 | # Assure single instance name
158 | #--------------------------------------------------------------------------------
159 |
160 | # returns the label name that describes a single instance (if possible)
161 | # e.g. input | output
162 | # ----------------------
163 | # car | car
164 | # cargroup | car
165 | # foo | None
166 | # foogroup | None
167 | # skygroup | None
168 | def assureSingleInstanceName( name ):
169 | # if the name is known, it is not a group
170 | if name in name2label:
171 | return name
172 | # test if the name actually denotes a group
173 | if not name.endswith("group"):
174 | return None
175 | # remove group
176 | name = name[:-len("group")]
177 | # test if the new name exists
178 | if not name in name2label:
179 | return None
180 | # test if the new name denotes a label that actually has instances
181 | if not name2label[name].hasInstances:
182 | return None
183 | # all good then
184 | return name
185 |
186 | #--------------------------------------------------------------------------------
187 | # Main for testing
188 | #--------------------------------------------------------------------------------
189 |
190 | # just a dummy main
191 | if __name__ == "__main__":
192 | # Print all the labels
193 | print("List of cityscapes labels:")
194 | print("")
195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )))
196 | print((" " + ('-' * 98)))
197 | for label in labels:
198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )))
199 | print("")
200 |
201 | print("Example usages:")
202 |
203 | # Map from name to label
204 | name = 'car'
205 | id = name2label[name].id
206 | print(("ID of label '{name}': {id}".format( name=name, id=id )))
207 |
208 | # Map from ID to label
209 | category = id2label[id].category
210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category )))
211 |
212 | # Map from trainID to label
213 | trainId = 0
214 | name = trainId2label[trainId].name
215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )))
216 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellaneous Functions
3 | """
4 |
5 | import sys
6 | import re
7 | import os
8 | import shutil
9 | import torch
10 | from datetime import datetime
11 | import logging
12 | from subprocess import call
13 | import shlex
14 | from tensorboardX import SummaryWriter
15 | import numpy as np
16 | import torchvision.transforms as standard_transforms
17 | import torchvision.utils as vutils
18 |
19 |
20 | # Create unique output dir name based on non-default command line args
21 | def make_exp_name(args, parser):
22 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:])
23 | dict_args = vars(args)
24 |
25 | # sort so that we get a consistent directory name
26 | argnames = sorted(dict_args)
27 | ignorelist = ['date', 'exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch',
28 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt', 'coarse_boost_classes',
29 | 'crop_size', 'dist_url', 'syncbn', 'max_iter', 'color_aug', 'scale_max', 'scale_min', 'bs_mult',
30 | 'hanet_lr', 'class_uniform_pct', 'class_uniform_tile', 'hanet', 'hanet_set', 'hanet_pos']
31 | # build experiment name with non-default args
32 | for argname in argnames:
33 | if dict_args[argname] != parser.get_default(argname):
34 | if argname in ignorelist:
35 | continue
36 | if argname == 'snapshot':
37 | arg_str = 'PT'
38 | argname = ''
39 | elif argname == 'nosave':
40 | arg_str = ''
41 | argname=''
42 | elif argname == 'freeze_trunk':
43 | argname = ''
44 | arg_str = 'ft'
45 | elif argname == 'syncbn':
46 | argname = ''
47 | arg_str = 'sbn'
48 | elif argname == 'jointwtborder':
49 | argname = ''
50 | arg_str = 'rlx_loss'
51 | elif isinstance(dict_args[argname], bool):
52 | arg_str = 'T' if dict_args[argname] else 'F'
53 | else:
54 | arg_str = str(dict_args[argname])[:7]
55 |             if argname != '':
56 | exp_name += '_{}_{}'.format(str(argname), arg_str)
57 | else:
58 | exp_name += '_{}'.format(arg_str)
59 |     # clean special chars out: exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name)
60 | return exp_name
61 |
62 | def fast_hist(label_pred, label_true, num_classes):
63 | mask = (label_true >= 0) & (label_true < num_classes)
64 | hist = np.bincount(
65 | num_classes * label_true[mask].astype(int) +
66 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
67 | return hist
68 |
69 | def per_class_iu(hist):
70 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
71 |
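# A minimal usage sketch (the names below are illustrative): accumulate one
# confusion matrix over the whole validation set, then derive the IoU metrics.
#   hist = np.zeros((num_classes, num_classes))
#   for pred, gt in zip(predictions, ground_truths):  # flattened label arrays
#       hist += fast_hist(pred.flatten(), gt.flatten(), num_classes)
#   ious = per_class_iu(hist)   # per-class IoU
#   miou = np.nanmean(ious)     # mean IoU over classes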
72 | def save_log(prefix, output_dir, date_str, rank=0):
73 | fmt = '%(asctime)s.%(msecs)03d %(message)s'
74 | date_fmt = '%m-%d %H:%M:%S'
75 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log')
76 | print("Logging :", filename)
77 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt,
78 | filename=filename, filemode='w')
79 | console = logging.StreamHandler()
80 | console.setLevel(logging.INFO)
81 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt)
82 | console.setFormatter(formatter)
83 | if rank == 0:
84 | logging.getLogger('').addHandler(console)
85 | else:
86 | fh = logging.FileHandler(filename)
87 | logging.getLogger('').addHandler(fh)
88 |
89 |
90 |
91 | def prep_experiment(args, parser):
92 | """
93 | Make output directories, setup logging, Tensorboard, snapshot code.
94 | """
95 | ckpt_path = args.ckpt
96 | tb_path = args.tb_path
97 | exp_name = make_exp_name(args, parser)
98 | args.exp_path = os.path.join(ckpt_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H')))
99 | args.tb_exp_path = os.path.join(tb_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H')))
100 | args.ngpu = torch.cuda.device_count()
101 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
102 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
103 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
104 | args.last_record = {}
105 | if args.local_rank == 0:
106 | os.makedirs(args.exp_path, exist_ok=True)
107 | os.makedirs(args.tb_exp_path, exist_ok=True)
108 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank)
109 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write(
110 | str(args) + '\n\n')
111 | writer = SummaryWriter(log_dir=args.tb_exp_path, comment=args.tb_tag)
112 | return writer
113 | return None
114 |
115 | def evaluate_eval_for_inference(hist, dataset=None):
116 | """
117 |     Modified IoU mechanism for on-the-fly IoU calculation (prevents memory overflow
118 |     for large datasets). Only applies to eval/eval.py.
119 | """
120 | # axis 0: gt, axis 1: prediction
121 | acc = np.diag(hist).sum() / hist.sum()
122 | acc_cls = np.diag(hist) / hist.sum(axis=1)
123 | acc_cls = np.nanmean(acc_cls)
124 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
125 |
126 | print_evaluate_results(hist, iu, dataset=dataset)
127 | freq = hist.sum(axis=1) / hist.sum()
128 | mean_iu = np.nanmean(iu)
129 | logging.info('mean {}'.format(mean_iu))
130 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
131 | return acc, acc_cls, mean_iu, fwavacc
132 |
133 |
134 |
135 | def evaluate_eval(args, net, optimizer, scheduler, val_loss, hist, dump_images, writer, epoch=0, dataset=None, curr_iter=0, optimizer_at=None, scheduler_at=None):
136 | """
137 |     Modified IoU mechanism for on-the-fly IoU calculation (prevents memory overflow
138 |     for large datasets). Only applies to eval/eval.py.
139 | """
140 | # axis 0: gt, axis 1: prediction
141 | acc = np.diag(hist).sum() / hist.sum()
142 | acc_cls = np.diag(hist) / hist.sum(axis=1)
143 | acc_cls = np.nanmean(acc_cls)
144 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
145 |
146 | print_evaluate_results(hist, iu, dataset)
147 | freq = hist.sum(axis=1) / hist.sum()
148 | mean_iu = np.nanmean(iu)
149 | logging.info('mean {}'.format(mean_iu))
150 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
151 |
152 | # update latest snapshot
153 | if 'mean_iu' in args.last_record:
154 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(
155 | args.last_record['epoch'], args.last_record['mean_iu'])
156 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
157 | try:
158 | os.remove(last_snapshot)
159 | except OSError:
160 | pass
161 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(epoch, mean_iu)
162 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
163 | args.last_record['mean_iu'] = mean_iu
164 | args.last_record['epoch'] = epoch
165 |
166 | torch.cuda.synchronize()
167 |
168 | if optimizer_at is not None:
169 | torch.save({
170 | 'state_dict': net.state_dict(),
171 | 'optimizer': optimizer.state_dict(),
172 | 'optimizer_at': optimizer_at.state_dict(),
173 | 'scheduler': scheduler.state_dict(),
174 | 'scheduler_at': scheduler_at.state_dict(),
175 | 'epoch': epoch,
176 | 'mean_iu': mean_iu,
177 | 'command': ' '.join(sys.argv[1:])
178 | }, last_snapshot)
179 | else:
180 | torch.save({
181 | 'state_dict': net.state_dict(),
182 | 'optimizer': optimizer.state_dict(),
183 | 'scheduler': scheduler.state_dict(),
184 | 'epoch': epoch,
185 | 'mean_iu': mean_iu,
186 | 'command': ' '.join(sys.argv[1:])
187 | }, last_snapshot)
188 |
189 | # update best snapshot
190 | if mean_iu > args.best_record['mean_iu'] :
191 | # remove old best snapshot
192 | if args.best_record['epoch'] != -1:
193 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format(
194 | args.best_record['epoch'], args.best_record['mean_iu'])
195 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
196 | assert os.path.exists(best_snapshot), \
197 |                 'cannot find old snapshot {}'.format(best_snapshot)
198 | os.remove(best_snapshot)
199 |
200 |
201 | # save new best
202 | args.best_record['val_loss'] = val_loss.avg
203 | args.best_record['epoch'] = epoch
204 | args.best_record['acc'] = acc
205 | args.best_record['acc_cls'] = acc_cls
206 | args.best_record['mean_iu'] = mean_iu
207 | args.best_record['fwavacc'] = fwavacc
208 |
209 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format(
210 | args.best_record['epoch'], args.best_record['mean_iu'])
211 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
212 | shutil.copyfile(last_snapshot, best_snapshot)
213 |
214 |
215 | # to_save_dir = os.path.join(args.exp_path, 'best_images')
216 | # os.makedirs(to_save_dir, exist_ok=True)
217 | # val_visual = []
218 |
219 | # idx = 0
220 |
221 | # visualize = standard_transforms.Compose([
222 | # standard_transforms.Scale(384),
223 | # standard_transforms.ToTensor()
224 | # ])
225 | # for bs_idx, bs_data in enumerate(dump_images):
226 | # for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])):
227 | # gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy())
228 | # predictions_pil = args.dataset_cls.colorize_mask(data[1].cpu().numpy())
229 | # img_name = data[2]
230 |
231 | # prediction_fn = '{}_prediction.png'.format(img_name)
232 | # predictions_pil.save(os.path.join(to_save_dir, prediction_fn))
233 | # gt_fn = '{}_gt.png'.format(img_name)
234 | # gt_pil.save(os.path.join(to_save_dir, gt_fn))
235 | # val_visual.extend([visualize(gt_pil.convert('RGB')),
236 | # visualize(predictions_pil.convert('RGB'))])
237 | # if local_idx >= 9:
238 | # break
239 | # val_visual = torch.stack(val_visual, 0)
240 | # val_visual = vutils.make_grid(val_visual, nrow=10, padding=5)
241 | # writer.add_image('imgs', val_visual, curr_iter )
242 |
243 | logging.info('-' * 107)
244 | fmt_str = '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
245 | '[mean_iu %.5f], [fwavacc %.5f]'
246 | logging.info(fmt_str % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))
247 | fmt_str = 'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
248 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], '
249 | logging.info(fmt_str % (args.best_record['val_loss'], args.best_record['acc'],
250 | args.best_record['acc_cls'], args.best_record['mean_iu'],
251 | args.best_record['fwavacc'], args.best_record['epoch']))
252 | logging.info('-' * 107)
253 |
254 | # tensorboard logging of validation phase metrics
255 |
256 | writer.add_scalar('training/acc', acc, curr_iter)
257 | writer.add_scalar('training/acc_cls', acc_cls, curr_iter)
258 | writer.add_scalar('training/mean_iu', mean_iu, curr_iter)
259 | writer.add_scalar('training/val_loss', val_loss.avg, curr_iter)
260 |
261 |
262 |
263 |
264 |
265 | def print_evaluate_results(hist, iu, dataset=None):
266 | # fixme: Need to refactor this dict
267 | try:
268 | id2cat = dataset.id2cat
269 |     except AttributeError:
270 | id2cat = {i: i for i in range(dataset.num_classes)}
271 | iu_false_positive = hist.sum(axis=1) - np.diag(hist)
272 | iu_false_negative = hist.sum(axis=0) - np.diag(hist)
273 | iu_true_positive = np.diag(hist)
274 |
275 | logging.info('IoU:')
276 | logging.info('label_id label iU Precision Recall TP FP FN')
277 | for idx, i in enumerate(iu):
278 | # Format all of the strings:
279 | idx_string = "{:2d}".format(idx)
280 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
281 | iu_string = '{:5.1f}'.format(i * 100)
282 | total_pixels = hist.sum()
283 | tp = '{:5.1f}'.format(100 * iu_true_positive[idx] / total_pixels)
284 | fp = '{:5.1f}'.format(
285 | iu_false_positive[idx] / iu_true_positive[idx])
286 | fn = '{:5.1f}'.format(iu_false_negative[idx] / iu_true_positive[idx])
287 | precision = '{:5.1f}'.format(
288 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx]))
289 | recall = '{:5.1f}'.format(
290 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx]))
291 | logging.info('{} {} {} {} {} {} {} {}'.format(
292 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn))
293 |
294 |
295 |
296 |
297 | class AverageMeter(object):
298 |
299 | def __init__(self):
300 | self.reset()
301 |
302 | def reset(self):
303 | self.val = 0
304 | self.avg = 0
305 | self.sum = 0
306 | self.count = 0
307 |
308 | def update(self, val, n=1):
309 | self.val = val
310 | self.sum += val * n
311 | self.count += n
312 | self.avg = self.sum / self.count
313 |
--------------------------------------------------------------------------------
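The IoU math above consumes a class-confusion histogram `hist` with ground truth along axis 0 and predictions along axis 1. A minimal, self-contained sketch of how such a histogram is typically accumulated and reduced to mean IoU (the `fast_hist` helper below is illustrative and not part of misc.py):

import numpy as np

def fast_hist(label_true, label_pred, num_classes):
    # Count (gt, pred) pairs for valid labels into a num_classes x num_classes matrix.
    mask = (label_true >= 0) & (label_true < num_classes)
    return np.bincount(
        num_classes * label_true[mask].astype(int) + label_pred[mask],
        minlength=num_classes ** 2,
    ).reshape(num_classes, num_classes)

num_classes = 19  # Cityscapes trainIds
gt = np.random.randint(0, num_classes, size=(2, 64, 64))
pred = np.random.randint(0, num_classes, size=(2, 64, 64))
hist = fast_hist(gt.flatten(), pred.flatten(), num_classes)

iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
print('mean IoU:', np.nanmean(iu))

The same `hist` could then be passed to evaluate_eval_for_inference above (together with a dataset object exposing num_classes or id2cat), while a running validation loss would be tracked with AverageMeter via meter.update(loss.item(), n=batch_size).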
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss.py
3 | """
4 |
5 | import logging
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from config import cfg
11 |
12 |
13 | def get_loss(args):
14 | """
15 | Get the criterion based on the loss function
16 | args: commandline arguments
17 | return: criterion, criterion_val
18 | """
19 | if args.cls_wt_loss:
20 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
21 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
22 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
23 | else:
24 | ce_weight = None
25 |
26 | if args.img_wt_loss:
27 | criterion = ImageBasedCrossEntropyLoss2d(
28 | classes=args.dataset_cls.num_classes, size_average=True,
29 | ignore_index=args.dataset_cls.ignore_label,
30 | upper_bound=args.wt_bound).cuda()
31 | elif args.jointwtborder:
32 | criterion = ImgWtLossSoftNLL(classes=args.dataset_cls.num_classes,
33 | ignore_index=args.dataset_cls.ignore_label,
34 | upper_bound=args.wt_bound).cuda()
35 | else:
36 | print("standard cross entropy")
37 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean',
38 | ignore_index=args.dataset_cls.ignore_label).cuda()
39 |
40 | criterion_val = nn.CrossEntropyLoss(reduction='mean',
41 | ignore_index=args.dataset_cls.ignore_label).cuda()
42 | return criterion, criterion_val
43 |
44 | def get_loss_by_epoch(args):
45 | """
46 | Get the criterion based on the loss function
47 | args: commandline arguments
48 | return: criterion, criterion_val
49 | """
50 |
51 | if args.img_wt_loss:
52 | criterion = ImageBasedCrossEntropyLoss2d(
53 | classes=args.dataset_cls.num_classes, size_average=True,
54 | ignore_index=args.dataset_cls.ignore_label,
55 | upper_bound=args.wt_bound).cuda()
56 | elif args.jointwtborder:
57 | criterion = ImgWtLossSoftNLL_by_epoch(classes=args.dataset_cls.num_classes,
58 | ignore_index=args.dataset_cls.ignore_label,
59 | upper_bound=args.wt_bound).cuda()
60 | else:
61 | criterion = CrossEntropyLoss2d(size_average=True,
62 | ignore_index=args.dataset_cls.ignore_label).cuda()
63 |
64 | criterion_val = CrossEntropyLoss2d(size_average=True,
65 | weight=None,
66 | ignore_index=args.dataset_cls.ignore_label).cuda()
67 | return criterion, criterion_val
68 |
69 |
70 | def get_loss_aux(args):
71 | """
72 | Get the criterion based on the loss function
73 | args: commandline arguments
74 | return: criterion, criterion_val
75 | """
76 | if args.cls_wt_loss:
77 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
78 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
79 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
80 | else:
81 | ce_weight = None
82 |
83 | print("standard cross entropy")
84 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean',
85 | ignore_index=args.dataset_cls.ignore_label).cuda()
86 |
87 | return criterion
88 |
89 | def get_loss_bcelogit(args):
90 | if args.cls_wt_loss:
91 | pos_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
92 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
93 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
94 | else:
95 | pos_weight = None
96 | print("standard bce with logit cross entropy")
97 | criterion = nn.BCEWithLogitsLoss(reduction='mean').cuda()
98 |
99 | return criterion
100 |
101 | def weighted_binary_cross_entropy(output, target):
102 |
103 | weights = torch.Tensor([0.1, 0.9])
104 |
105 | loss = weights[1] * (target * torch.log(output)) + \
106 | weights[0] * ((1 - target) * torch.log(1 - output))
107 |
108 | return torch.neg(torch.mean(loss))
109 |
110 | class ImageBasedCrossEntropyLoss2d(nn.Module):
111 | """
112 | Image Weighted Cross Entropy Loss
113 | """
114 |
115 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255,
116 | norm=False, upper_bound=1.0):
117 | super(ImageBasedCrossEntropyLoss2d, self).__init__()
118 | logging.info("Using Per Image based weighted loss")
119 | self.num_classes = classes
120 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index)
121 | self.norm = norm
122 | self.upper_bound = upper_bound
123 | self.batch_weights = cfg.BATCH_WEIGHTING
124 | self.logsoftmax = nn.LogSoftmax(dim=1)
125 |
126 | def calculate_weights(self, target):
127 | """
128 | Calculate weights of classes based on the training crop
129 | """
130 | hist = np.histogram(target.flatten(), range(
131 |             self.num_classes + 1), density=True)[0]  # `normed` was removed in NumPy 1.24; `density` is equivalent for unit-width bins
132 | if self.norm:
133 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
134 | else:
135 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
136 | return hist
137 |
138 | def forward(self, inputs, targets):
139 |
140 | target_cpu = targets.data.cpu().numpy()
141 | if self.batch_weights:
142 | weights = self.calculate_weights(target_cpu)
143 | self.nll_loss.weight = torch.Tensor(weights).cuda()
144 |
145 | loss = 0.0
146 | for i in range(0, inputs.shape[0]):
147 | if not self.batch_weights:
148 | weights = self.calculate_weights(target_cpu[i])
149 | self.nll_loss.weight = torch.Tensor(weights).cuda()
150 |
151 | loss += self.nll_loss(self.logsoftmax(inputs[i].unsqueeze(0)),
152 | targets[i].unsqueeze(0))
153 | return loss
154 |
155 |
156 |
157 | class CrossEntropyLoss2d(nn.Module):
158 | """
159 |     Cross Entropy NLL Loss
160 | """
161 |
162 | def __init__(self, weight=None, size_average=True, ignore_index=255):
163 | super(CrossEntropyLoss2d, self).__init__()
164 | logging.info("Using Cross Entropy Loss")
165 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index)
166 | self.logsoftmax = nn.LogSoftmax(dim=1)
167 | # self.weight = weight
168 |
169 | def forward(self, inputs, targets):
170 | return self.nll_loss(self.logsoftmax(inputs), targets)
171 |
172 | def customsoftmax(inp, multihotmask):
173 | """
174 | Custom Softmax
175 | """
176 | soft = F.softmax(inp, dim=1)
177 |     # Sum the softmax over the classes allowed by the multi-hot border mask, broadcast that
178 |     # sum back onto the allowed classes, and take the element-wise max with the plain softmax.
179 | return torch.log(
180 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True)))
181 | )
182 |
183 | class ImgWtLossSoftNLL(nn.Module):
184 | """
185 | Relax Loss
186 | """
187 |
188 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
189 | norm=False):
190 | super(ImgWtLossSoftNLL, self).__init__()
191 | self.weights = weights
192 | self.num_classes = classes
193 | self.ignore_index = ignore_index
194 | self.upper_bound = upper_bound
195 | self.norm = norm
196 | self.batch_weights = cfg.BATCH_WEIGHTING
197 |
198 | def calculate_weights(self, target):
199 | """
200 | Calculate weights of the classes based on training crop
201 | """
202 | if len(target.shape) == 3:
203 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
204 | else:
205 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
206 | if self.norm:
207 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
208 | else:
209 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
210 | return hist[:-1]
211 |
212 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
213 | """
214 | NLL Relaxed Loss Implementation
215 | """
216 | if (cfg.REDUCE_BORDER_ITER != -1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
217 | border_weights = 1 / border_weights
218 | target[target > 1] = 1
219 |
220 | loss_matrix = (-1 / border_weights *
221 | (target[:, :-1, :, :].float() *
222 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
223 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
224 | (1. - mask.float())
225 |
226 | # loss_matrix[border_weights > 1] = 0
227 | loss = loss_matrix.sum()
228 |
229 | # +1 to prevent division by 0
230 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
231 | return loss
232 |
233 | def forward(self, inputs, target):
234 | weights = target[:, :-1, :, :].sum(1).float()
235 | ignore_mask = (weights == 0)
236 | weights[ignore_mask] = 1
237 |
238 | loss = 0
239 | target_cpu = target.data.cpu().numpy()
240 |
241 | if self.batch_weights:
242 | class_weights = self.calculate_weights(target_cpu)
243 |
244 | for i in range(0, inputs.shape[0]):
245 | if not self.batch_weights:
246 | class_weights = self.calculate_weights(target_cpu[i])
247 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
248 | target[i].unsqueeze(0),
249 | class_weights=torch.Tensor(class_weights).cuda(),
250 | border_weights=weights[i], mask=ignore_mask[i])
251 |
252 | loss = loss / inputs.shape[0]
253 | return loss
254 |
255 | class ImgWtLossSoftNLL_by_epoch(nn.Module):
256 | """
257 | Relax Loss
258 | """
259 |
260 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
261 | norm=False):
262 | super(ImgWtLossSoftNLL_by_epoch, self).__init__()
263 | self.weights = weights
264 | self.num_classes = classes
265 | self.ignore_index = ignore_index
266 | self.upper_bound = upper_bound
267 | self.norm = norm
268 | self.batch_weights = cfg.BATCH_WEIGHTING
269 | self.fp16 = False
270 |
271 |
272 | def calculate_weights(self, target):
273 | """
274 | Calculate weights of the classes based on training crop
275 | """
276 | if len(target.shape) == 3:
277 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
278 | else:
279 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
280 | if self.norm:
281 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
282 | else:
283 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
284 | return hist[:-1]
285 |
286 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
287 | """
288 | NLL Relaxed Loss Implementation
289 | """
290 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
291 | border_weights = 1 / border_weights
292 | target[target > 1] = 1
293 | if self.fp16:
294 | loss_matrix = (-1 / border_weights *
295 | (target[:, :-1, :, :].half() *
296 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
297 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \
298 | (1. - mask.half())
299 | else:
300 | loss_matrix = (-1 / border_weights *
301 | (target[:, :-1, :, :].float() *
302 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
303 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
304 | (1. - mask.float())
305 |
306 | # loss_matrix[border_weights > 1] = 0
307 | loss = loss_matrix.sum()
308 |
309 | # +1 to prevent division by 0
310 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
311 | return loss
312 |
313 | def forward(self, inputs, target):
314 | if self.fp16:
315 | weights = target[:, :-1, :, :].sum(1).half()
316 | else:
317 | weights = target[:, :-1, :, :].sum(1).float()
318 | ignore_mask = (weights == 0)
319 | weights[ignore_mask] = 1
320 |
321 | loss = 0
322 | target_cpu = target.data.cpu().numpy()
323 |
324 | if self.batch_weights:
325 | class_weights = self.calculate_weights(target_cpu)
326 |
327 | for i in range(0, inputs.shape[0]):
328 | if not self.batch_weights:
329 | class_weights = self.calculate_weights(target_cpu[i])
330 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
331 | target[i].unsqueeze(0),
332 | class_weights=torch.Tensor(class_weights).cuda(),
333 | border_weights=weights, mask=ignore_mask[i])
334 |
335 | return loss
336 |
--------------------------------------------------------------------------------
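To make the relaxed-border behaviour of `customsoftmax` concrete, a tiny stand-alone sketch (the function is copied verbatim from loss.py so the snippet runs without the project's config module; the tensor values are made up):

import torch
import torch.nn.functional as F

def customsoftmax(inp, multihotmask):
    soft = F.softmax(inp, dim=1)
    return torch.log(
        torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True)))
    )

logits = torch.randn(1, 3, 2, 2)   # N x C x H x W
mask = torch.zeros(1, 3, 2, 2)
mask[:, :2, 0, 0] = 1              # border pixel (0, 0) may be class 0 or class 1

out = customsoftmax(logits, mask)
# At (0, 0), classes 0 and 1 both get log(p0 + p1); class 2 keeps its ordinary log-softmax.
# At every other pixel the mask is all-zero, so the result is the plain log-softmax.
print(out[0, :, 0, 0])
print(F.log_softmax(logits, dim=1)[0, :, 1, 1], out[0, :, 1, 1])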
/network/Resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code Adapted from:
3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 | import torch
37 | import torch.nn as nn
38 | import torch.utils.model_zoo as model_zoo
39 | import network.mynn as mynn
40 |
41 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
42 | 'resnet152']
43 |
44 |
45 | model_urls = {
46 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
47 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
48 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
49 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
50 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
51 | }
52 |
53 |
54 | def conv3x3(in_planes, out_planes, stride=1):
55 | """3x3 convolution with padding"""
56 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
57 | padding=1, bias=False)
58 |
59 |
60 | class BasicBlock(nn.Module):
61 | """
62 | Basic Block for Resnet
63 | """
64 | expansion = 1
65 |
66 | def __init__(self, inplanes, planes, stride=1, downsample=None):
67 | super(BasicBlock, self).__init__()
68 | self.conv1 = conv3x3(inplanes, planes, stride)
69 | self.bn1 = mynn.Norm2d(planes)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.conv2 = conv3x3(planes, planes)
72 | self.bn2 = mynn.Norm2d(planes)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | residual = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | out = self.relu(out)
91 |
92 | return out
93 |
94 |
95 | class Bottleneck(nn.Module):
96 | """
97 | Bottleneck Layer for Resnet
98 | """
99 | expansion = 4
100 |
101 | def __init__(self, inplanes, planes, stride=1, downsample=None):
102 | super(Bottleneck, self).__init__()
103 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
104 | self.bn1 = mynn.Norm2d(planes)
105 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
106 | padding=1, bias=False)
107 | self.bn2 = mynn.Norm2d(planes)
108 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
109 | self.bn3 = mynn.Norm2d(planes * self.expansion)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.downsample = downsample
112 | self.stride = stride
113 |
114 | def forward(self, x):
115 | residual = x
116 |
117 | out = self.conv1(x)
118 | out = self.bn1(out)
119 | out = self.relu(out)
120 |
121 | out = self.conv2(out)
122 | out = self.bn2(out)
123 | out = self.relu(out)
124 |
125 | out = self.conv3(out)
126 | out = self.bn3(out)
127 |
128 | if self.downsample is not None:
129 | residual = self.downsample(x)
130 |
131 | out += residual
132 | out = self.relu(out)
133 |
134 | return out
135 |
136 |
137 | class ResNet3X3(nn.Module):
138 | """
139 | Resnet Global Module for Initialization
140 | """
141 | def __init__(self, block, layers, num_classes=1000):
142 | # self.inplanes = 64
143 | self.inplanes = 128
144 | super(ResNet3X3, self).__init__()
145 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
146 | # bias=False)
147 | # self.bn1 = mynn.Norm2d(64)
148 | # self.relu = nn.ReLU(inplace=True)
149 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
150 | bias=False)
151 | self.bn1 = mynn.Norm2d(64)
152 | self.relu1 = nn.ReLU(inplace=True)
153 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1,
154 | bias=False)
155 | self.bn2 = mynn.Norm2d(64)
156 | self.relu2 = nn.ReLU(inplace=True)
157 | self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1,
158 | bias=False)
159 | self.bn3 = mynn.Norm2d(self.inplanes)
160 | self.relu3 = nn.ReLU(inplace=True)
161 |
162 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
163 | self.layer1 = self._make_layer(block, 64, layers[0])
164 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
165 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
166 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
167 | self.avgpool = nn.AvgPool2d(7, stride=1)
168 | self.fc = nn.Linear(512 * block.expansion, num_classes)
169 |
170 | for m in self.modules():
171 | if isinstance(m, nn.Conv2d):
172 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
173 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.SyncBatchNorm):
174 | nn.init.constant_(m.weight, 1)
175 | nn.init.constant_(m.bias, 0)
176 |
177 | def _make_layer(self, block, planes, blocks, stride=1):
178 | downsample = None
179 | if stride != 1 or self.inplanes != planes * block.expansion:
180 | downsample = nn.Sequential(
181 | nn.Conv2d(self.inplanes, planes * block.expansion,
182 | kernel_size=1, stride=stride, bias=False),
183 | mynn.Norm2d(planes * block.expansion),
184 | )
185 |
186 | layers = []
187 | layers.append(block(self.inplanes, planes, stride, downsample))
188 | self.inplanes = planes * block.expansion
189 | for index in range(1, blocks):
190 | layers.append(block(self.inplanes, planes))
191 |
192 | return nn.Sequential(*layers)
193 |
194 | def forward(self, x):
195 | # x = self.conv1(x)
196 | # x = self.bn1(x)
197 | # x = self.relu(x)
198 |         x = self.conv1(x)
199 | x = self.bn1(x)
200 | x = self.relu1(x)
201 | x = self.conv2(x)
202 | x = self.bn2(x)
203 | x = self.relu2(x)
204 | x = self.conv3(x)
205 | x = self.bn3(x)
206 | x = self.relu3(x)
207 |
208 | x = self.maxpool(x)
209 |
210 | x = self.layer1(x)
211 | x = self.layer2(x)
212 | x = self.layer3(x)
213 | x = self.layer4(x)
214 |
215 | x = self.avgpool(x)
216 | x = x.view(x.size(0), -1)
217 | x = self.fc(x)
218 |
219 | return x
220 |
221 | class ResNet(nn.Module):
222 | """
223 | Resnet Global Module for Initialization
224 | """
225 | def __init__(self, block, layers, num_classes=1000):
226 | self.inplanes = 64
227 | # self.inplanes = 128
228 | super(ResNet, self).__init__()
229 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
230 | bias=False)
231 | self.bn1 = mynn.Norm2d(64)
232 | self.relu = nn.ReLU(inplace=True)
233 | # self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
234 | # bias=False)
235 | # self.bn1 = mynn.Norm2d(64)
236 | # self.relu1 = nn.ReLU(inplace=True)
237 | # self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1,
238 | # bias=False)
239 | # self.bn2 = mynn.Norm2d(64)
240 | # self.relu2 = nn.ReLU(inplace=True)
241 | # self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1,
242 | # bias=False)
243 | # self.bn3 = mynn.Norm2d(self.inplanes)
244 | # self.relu3 = nn.ReLU(inplace=True)
245 |
246 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
247 | self.layer1 = self._make_layer(block, 64, layers[0])
248 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
249 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
250 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
251 | self.avgpool = nn.AvgPool2d(7, stride=1)
252 | self.fc = nn.Linear(512 * block.expansion, num_classes)
253 |
254 | for m in self.modules():
255 | if isinstance(m, nn.Conv2d):
256 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
257 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.SyncBatchNorm):
258 | nn.init.constant_(m.weight, 1)
259 | nn.init.constant_(m.bias, 0)
260 |
261 | def _make_layer(self, block, planes, blocks, stride=1):
262 | downsample = None
263 | if stride != 1 or self.inplanes != planes * block.expansion:
264 | downsample = nn.Sequential(
265 | nn.Conv2d(self.inplanes, planes * block.expansion,
266 | kernel_size=1, stride=stride, bias=False),
267 | mynn.Norm2d(planes * block.expansion),
268 | )
269 |
270 | layers = []
271 | layers.append(block(self.inplanes, planes, stride, downsample))
272 | self.inplanes = planes * block.expansion
273 | for index in range(1, blocks):
274 | layers.append(block(self.inplanes, planes))
275 |
276 | return nn.Sequential(*layers)
277 |
278 | def forward(self, x):
279 | x = self.conv1(x)
280 | x = self.bn1(x)
281 | x = self.relu(x)
282 | # x = self.conv1(input)
283 | # x = self.bn1(x)
284 | # x = self.relu1(x)
285 | # x = self.conv2(x)
286 | # x = self.bn2(x)
287 | # x = self.relu2(x)
288 | # x = self.conv3(x)
289 | # x = self.bn3(x)
290 | # x = self.relu3(x)
291 |
292 | x = self.maxpool(x)
293 |
294 | x = self.layer1(x)
295 | x = self.layer2(x)
296 | x = self.layer3(x)
297 | x = self.layer4(x)
298 |
299 | x = self.avgpool(x)
300 | x = x.view(x.size(0), -1)
301 | x = self.fc(x)
302 |
303 | return x
304 |
305 | def resnet18(pretrained=True, **kwargs):
306 | """Constructs a ResNet-18 model.
307 |
308 | Args:
309 | pretrained (bool): If True, returns a model pre-trained on ImageNet
310 | """
311 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
312 | if pretrained:
313 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
314 | return model
315 |
316 |
317 | def resnet34(pretrained=True, **kwargs):
318 | """Constructs a ResNet-34 model.
319 |
320 | Args:
321 | pretrained (bool): If True, returns a model pre-trained on ImageNet
322 | """
323 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
324 | if pretrained:
325 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
326 | return model
327 |
328 |
329 | def resnet50(pretrained=True, **kwargs):
330 | """Constructs a ResNet-50 model.
331 |
332 | Args:
333 | pretrained (bool): If True, returns a model pre-trained on ImageNet
334 | """
335 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
336 | if pretrained:
337 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
338 | return model
339 |
340 |
341 | def resnet101(pretrained=True, **kwargs):
342 | """Constructs a ResNet-101 model.
343 |
344 | Args:
345 | pretrained (bool): If True, returns a model pre-trained on ImageNet
346 | """
347 | model = ResNet3X3(Bottleneck, [3, 4, 23, 3], **kwargs)
348 | if pretrained:
349 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
350 | print("########### pretrained ##############")
351 | model.load_state_dict(torch.load('./pretrained/resnet101-imagenet.pth', map_location="cpu"))
352 | return model
353 |
354 |
355 | def resnet152(pretrained=True, **kwargs):
356 | """Constructs a ResNet-152 model.
357 |
358 | Args:
359 | pretrained (bool): If True, returns a model pre-trained on ImageNet
360 | """
361 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
362 | if pretrained:
363 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
364 | return model
365 |
--------------------------------------------------------------------------------
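A quick sketch of instantiating the trunks defined above. Note that `resnet101` builds the three-3x3-stem `ResNet3X3` variant and, with `pretrained=True`, expects a local `./pretrained/resnet101-imagenet.pth` checkpoint rather than a torchvision URL; the sketch also assumes `network.mynn.Norm2d` resolves to a standard BatchNorm (the project config may need to be initialised first):

import torch
from network.Resnet import resnet50, resnet101

trunk50 = resnet50(pretrained=False)    # torchvision-style 7x7 stem
trunk101 = resnet101(pretrained=False)  # three 3x3 convs in the stem

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(trunk50(x).shape, trunk101(x).shape)  # both torch.Size([1, 1000])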
/network/wider_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/mapillary/inplace_abn/
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017, mapillary
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 | import logging
36 | import sys
37 | from collections import OrderedDict
38 | from functools import partial
39 | import torch.nn as nn
40 | import torch
41 | import network.mynn as mynn
42 |
43 | def bnrelu(channels):
44 | """
45 |     Single layer: BN followed by ReLU
46 | """
47 | return nn.Sequential(mynn.Norm2d(channels),
48 | nn.ReLU(inplace=True))
49 |
50 | class GlobalAvgPool2d(nn.Module):
51 | """
52 | Global average pooling over the input's spatial dimensions
53 | """
54 |
55 | def __init__(self):
56 | super(GlobalAvgPool2d, self).__init__()
57 | logging.info("Global Average Pooling Initialized")
58 |
59 | def forward(self, inputs):
60 | in_size = inputs.size()
61 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
62 |
63 |
64 | class IdentityResidualBlock(nn.Module):
65 | """
66 | Identity Residual Block for WideResnet
67 | """
68 | def __init__(self,
69 | in_channels,
70 | channels,
71 | stride=1,
72 | dilation=1,
73 | groups=1,
74 | norm_act=bnrelu,
75 | dropout=None,
76 | dist_bn=False
77 | ):
78 | """Configurable identity-mapping residual block
79 |
80 | Parameters
81 | ----------
82 | in_channels : int
83 | Number of input channels.
84 | channels : list of int
85 | Number of channels in the internal feature maps.
86 | Can either have two or three elements: if three construct
87 | a residual block with two `3 x 3` convolutions,
88 | otherwise construct a bottleneck block with `1 x 1`, then
89 | `3 x 3` then `1 x 1` convolutions.
90 | stride : int
91 | Stride of the first `3 x 3` convolution
92 | dilation : int
93 | Dilation to apply to the `3 x 3` convolutions.
94 | groups : int
95 | Number of convolution groups.
96 | This is used to create ResNeXt-style blocks and is only compatible with
97 | bottleneck blocks.
98 | norm_act : callable
99 | Function to create normalization / activation Module.
100 | dropout: callable
101 | Function to create Dropout Module.
102 | dist_bn: Boolean
103 | A variable to enable or disable use of distributed BN
104 | """
105 | super(IdentityResidualBlock, self).__init__()
106 | self.dist_bn = dist_bn
107 |
108 | # Check if we are using distributed BN and use the nn from encoding.nn
109 | # library rather than using standard pytorch.nn
110 |
111 |
112 | # Check parameters for inconsistencies
113 | if len(channels) != 2 and len(channels) != 3:
114 | raise ValueError("channels must contain either two or three values")
115 | if len(channels) == 2 and groups != 1:
116 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
117 |
118 | is_bottleneck = len(channels) == 3
119 | need_proj_conv = stride != 1 or in_channels != channels[-1]
120 |
121 | self.bn1 = norm_act(in_channels)
122 | if not is_bottleneck:
123 | layers = [
124 | ("conv1", nn.Conv2d(in_channels,
125 | channels[0],
126 | 3,
127 | stride=stride,
128 | padding=dilation,
129 | bias=False,
130 | dilation=dilation)),
131 | ("bn2", norm_act(channels[0])),
132 | ("conv2", nn.Conv2d(channels[0], channels[1],
133 | 3,
134 | stride=1,
135 | padding=dilation,
136 | bias=False,
137 | dilation=dilation))
138 | ]
139 | if dropout is not None:
140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
141 | else:
142 | layers = [
143 | ("conv1",
144 | nn.Conv2d(in_channels,
145 | channels[0],
146 | 1,
147 | stride=stride,
148 | padding=0,
149 | bias=False)),
150 | ("bn2", norm_act(channels[0])),
151 | ("conv2", nn.Conv2d(channels[0],
152 | channels[1],
153 | 3, stride=1,
154 | padding=dilation, bias=False,
155 | groups=groups,
156 | dilation=dilation)),
157 | ("bn3", norm_act(channels[1])),
158 | ("conv3", nn.Conv2d(channels[1], channels[2],
159 | 1, stride=1, padding=0, bias=False))
160 | ]
161 | if dropout is not None:
162 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
163 | self.convs = nn.Sequential(OrderedDict(layers))
164 |
165 | if need_proj_conv:
166 | self.proj_conv = nn.Conv2d(
167 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
168 |
169 | def forward(self, x):
170 | """
171 | This is the standard forward function for non-distributed batch norm
172 | """
173 | if hasattr(self, "proj_conv"):
174 | bn1 = self.bn1(x)
175 | shortcut = self.proj_conv(bn1)
176 | else:
177 | shortcut = x.clone()
178 | bn1 = self.bn1(x)
179 |
180 | out = self.convs(bn1)
181 | out.add_(shortcut)
182 | return out
183 |
184 |
185 |
186 |
187 | class WiderResNet(nn.Module):
188 | """
189 | WideResnet Global Module for Initialization
190 | """
191 | def __init__(self,
192 | structure,
193 | norm_act=bnrelu,
194 | classes=0
195 | ):
196 | """Wider ResNet with pre-activation (identity mapping) blocks
197 |
198 | Parameters
199 | ----------
200 | structure : list of int
201 | Number of residual blocks in each of the six modules of the network.
202 | norm_act : callable
203 | Function to create normalization / activation Module.
204 | classes : int
205 |         If not `0` also include global average pooling and
206 | a fully-connected layer with `classes` outputs at the end
207 | of the network.
208 | """
209 | super(WiderResNet, self).__init__()
210 | self.structure = structure
211 |
212 | if len(structure) != 6:
213 | raise ValueError("Expected a structure with six values")
214 |
215 | # Initial layers
216 | self.mod1 = nn.Sequential(OrderedDict([
217 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))
218 | ]))
219 |
220 | # Groups of residual blocks
221 | in_channels = 64
222 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024),
223 | (512, 1024, 2048), (1024, 2048, 4096)]
224 | for mod_id, num in enumerate(structure):
225 | # Create blocks for module
226 | blocks = []
227 | for block_id in range(num):
228 | blocks.append((
229 | "block%d" % (block_id + 1),
230 | IdentityResidualBlock(in_channels, channels[mod_id],
231 | norm_act=norm_act)
232 | ))
233 |
234 | # Update channels and p_keep
235 | in_channels = channels[mod_id][-1]
236 |
237 | # Create module
238 | if mod_id <= 4:
239 | self.add_module("pool%d" %
240 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1))
241 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
242 |
243 | # Pooling and predictor
244 | self.bn_out = norm_act(in_channels)
245 | if classes != 0:
246 | self.classifier = nn.Sequential(OrderedDict([
247 | ("avg_pool", GlobalAvgPool2d()),
248 | ("fc", nn.Linear(in_channels, classes))
249 | ]))
250 |
251 | def forward(self, img):
252 | out = self.mod1(img)
253 | out = self.mod2(self.pool2(out))
254 | out = self.mod3(self.pool3(out))
255 | out = self.mod4(self.pool4(out))
256 | out = self.mod5(self.pool5(out))
257 | out = self.mod6(self.pool6(out))
258 | out = self.mod7(out)
259 | out = self.bn_out(out)
260 |
261 | if hasattr(self, "classifier"):
262 | out = self.classifier(out)
263 |
264 | return out
265 |
266 |
267 | class WiderResNetA2(nn.Module):
268 | """
269 | Wider ResNet with pre-activation (identity mapping) blocks
270 |
271 | This variant uses down-sampling by max-pooling in the first two blocks and
272 | by strided convolution in the others.
273 |
274 | Parameters
275 | ----------
276 | structure : list of int
277 | Number of residual blocks in each of the six modules of the network.
278 | norm_act : callable
279 | Function to create normalization / activation Module.
280 | classes : int
281 | If not `0` also include global average pooling and a fully-connected layer
282 | with `classes` outputs at the end
283 | of the network.
284 | dilation : bool
285 | If `True` apply dilation to the last three modules and change the
286 | down-sampling factor from 32 to 8.
287 | """
288 | def __init__(self,
289 | structure,
290 | norm_act=bnrelu,
291 | classes=0,
292 | dilation=False,
293 | dist_bn=False
294 | ):
295 | super(WiderResNetA2, self).__init__()
296 | self.dist_bn = dist_bn
297 |
298 |         # If using distributed batch norm, use encoding.nn as opposed to torch.nn
299 |
300 |
301 | nn.Dropout = nn.Dropout2d
302 | norm_act = bnrelu
303 | self.structure = structure
304 | self.dilation = dilation
305 |
306 | if len(structure) != 6:
307 | raise ValueError("Expected a structure with six values")
308 |
309 | # Initial layers
310 | self.mod1 = torch.nn.Sequential(OrderedDict([
311 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))
312 | ]))
313 |
314 | # Groups of residual blocks
315 | in_channels = 64
316 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048),
317 | (1024, 2048, 4096)]
318 | for mod_id, num in enumerate(structure):
319 | # Create blocks for module
320 | blocks = []
321 | for block_id in range(num):
322 | if not dilation:
323 | dil = 1
324 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1
325 | else:
326 | if mod_id == 3:
327 | dil = 2
328 | elif mod_id > 3:
329 | dil = 4
330 | else:
331 | dil = 1
332 | stride = 2 if block_id == 0 and mod_id == 2 else 1
333 |
334 | if mod_id == 4:
335 | drop = partial(nn.Dropout, p=0.3)
336 | elif mod_id == 5:
337 | drop = partial(nn.Dropout, p=0.5)
338 | else:
339 | drop = None
340 |
341 | blocks.append((
342 | "block%d" % (block_id + 1),
343 | IdentityResidualBlock(in_channels,
344 | channels[mod_id], norm_act=norm_act,
345 | stride=stride, dilation=dil,
346 | dropout=drop, dist_bn=self.dist_bn)
347 | ))
348 |
349 | # Update channels and p_keep
350 | in_channels = channels[mod_id][-1]
351 |
352 | # Create module
353 | if mod_id < 2:
354 | self.add_module("pool%d" %
355 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1))
356 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
357 |
358 | # Pooling and predictor
359 | self.bn_out = norm_act(in_channels)
360 | if classes != 0:
361 | self.classifier = nn.Sequential(OrderedDict([
362 | ("avg_pool", GlobalAvgPool2d()),
363 | ("fc", nn.Linear(in_channels, classes))
364 | ]))
365 |
366 | def forward(self, img):
367 | out = self.mod1(img)
368 | out = self.mod2(self.pool2(out))
369 | out = self.mod3(self.pool3(out))
370 | out = self.mod4(out)
371 | out = self.mod5(out)
372 | out = self.mod6(out)
373 | out = self.mod7(out)
374 | out = self.bn_out(out)
375 |
376 | if hasattr(self, "classifier"):
377 | return self.classifier(out)
378 | return out
379 |
380 |
381 | _NETS = {
382 | "16": {"structure": [1, 1, 1, 1, 1, 1]},
383 | "20": {"structure": [1, 1, 1, 3, 1, 1]},
384 | "38": {"structure": [3, 3, 6, 3, 1, 1]},
385 | }
386 |
387 | __all__ = []
388 | for name, params in _NETS.items():
389 | net_name = "wider_resnet" + name
390 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params))
391 | __all__.append(net_name)
392 | for name, params in _NETS.items():
393 | net_name = "wider_resnet" + name + "_a2"
394 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params))
395 | __all__.append(net_name)
396 |
--------------------------------------------------------------------------------
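The loop at the bottom of wider_resnet.py uses `functools.partial` to turn each `_NETS` entry into a module-level constructor (`wider_resnet16`, `wider_resnet38_a2`, and so on). A short usage sketch, again assuming `network.mynn.Norm2d` resolves to a standard BatchNorm:

import torch
from network.wider_resnet import wider_resnet38_a2

# WiderResNet38 (A2 variant), dilated to output stride 8, no classification head.
net = wider_resnet38_a2(classes=0, dilation=True)
with torch.no_grad():
    feats = net(torch.randn(1, 3, 256, 256))
print(feats.shape)  # roughly [1, 4096, 32, 32] for a 256x256 input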
/network/SEresnext.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/Cadene/pretrained-models.pytorch
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017, Remi Cadene
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 | import logging
36 | from collections import OrderedDict
37 | import math
38 | import torch.nn as nn
39 | from torch.utils import model_zoo
40 | import network.mynn as mynn
41 |
42 | __all__ = ['SENet', 'se_resnext50_32x4d', 'se_resnext101_32x4d']
43 |
44 | pretrained_settings = {
45 | 'se_resnext50_32x4d': {
46 | 'imagenet': {
47 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
48 | 'input_space': 'RGB',
49 | 'input_size': [3, 224, 224],
50 | 'input_range': [0, 1],
51 | 'mean': [0.485, 0.456, 0.406],
52 | 'std': [0.229, 0.224, 0.225],
53 | 'num_classes': 1000
54 | }
55 | },
56 | 'se_resnext101_32x4d': {
57 | 'imagenet': {
58 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
59 | 'input_space': 'RGB',
60 | 'input_size': [3, 224, 224],
61 | 'input_range': [0, 1],
62 | 'mean': [0.485, 0.456, 0.406],
63 | 'std': [0.229, 0.224, 0.225],
64 | 'num_classes': 1000
65 | }
66 | },
67 | }
68 |
69 |
70 | class SEModule(nn.Module):
71 | """
72 |     Squeeze-and-Excitation Module
73 | """
74 | def __init__(self, channels, reduction):
75 | super(SEModule, self).__init__()
76 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
77 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
78 | padding=0)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
81 | padding=0)
82 | self.sigmoid = nn.Sigmoid()
83 |
84 | def forward(self, x):
85 | module_input = x
86 | x = self.avg_pool(x)
87 | x = self.fc1(x)
88 | x = self.relu(x)
89 | x = self.fc2(x)
90 | x = self.sigmoid(x)
91 | return module_input * x
92 |
93 |
94 | class Bottleneck(nn.Module):
95 | """
96 | Base class for bottlenecks that implements `forward()` method.
97 | """
98 | def forward(self, x):
99 | residual = x
100 |
101 | out = self.conv1(x)
102 | out = self.bn1(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv2(out)
106 | out = self.bn2(out)
107 | out = self.relu(out)
108 |
109 | out = self.conv3(out)
110 | out = self.bn3(out)
111 |
112 | if self.downsample is not None:
113 | residual = self.downsample(x)
114 |
115 | out = self.se_module(out) + residual
116 | out = self.relu(out)
117 |
118 | return out
119 |
120 |
121 | class SEBottleneck(Bottleneck):
122 | """
123 | Bottleneck for SENet154.
124 | """
125 | expansion = 4
126 |
127 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
128 | downsample=None):
129 | super(SEBottleneck, self).__init__()
130 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
131 | self.bn1 = mynn.Norm2d(planes * 2)
132 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
133 | stride=stride, padding=1, groups=groups,
134 | bias=False)
135 | self.bn2 = mynn.Norm2d(planes * 4)
136 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
137 | bias=False)
138 | self.bn3 = mynn.Norm2d(planes * 4)
139 | self.relu = nn.ReLU(inplace=True)
140 | self.se_module = SEModule(planes * 4, reduction=reduction)
141 | self.downsample = downsample
142 | self.stride = stride
143 |
144 |
145 | class SEResNetBottleneck(Bottleneck):
146 | """
147 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
148 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
149 | (the latter is used in the torchvision implementation of ResNet).
150 | """
151 | expansion = 4
152 |
153 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
154 | downsample=None):
155 | super(SEResNetBottleneck, self).__init__()
156 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
157 | stride=stride)
158 | self.bn1 = mynn.Norm2d(planes)
159 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
160 | groups=groups, bias=False)
161 | self.bn2 = mynn.Norm2d(planes)
162 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
163 | self.bn3 = mynn.Norm2d(planes * 4)
164 | self.relu = nn.ReLU(inplace=True)
165 | self.se_module = SEModule(planes * 4, reduction=reduction)
166 | self.downsample = downsample
167 | self.stride = stride
168 |
169 |
170 | class SEResNeXtBottleneck(Bottleneck):
171 | """
172 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
173 | """
174 | expansion = 4
175 |
176 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
177 | downsample=None, base_width=4):
178 | super(SEResNeXtBottleneck, self).__init__()
179 | width = math.floor(planes * (base_width / 64)) * groups
180 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
181 | stride=1)
182 | self.bn1 = mynn.Norm2d(width)
183 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
184 | padding=1, groups=groups, bias=False)
185 | self.bn2 = mynn.Norm2d(width)
186 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
187 | self.bn3 = mynn.Norm2d(planes * 4)
188 | self.relu = nn.ReLU(inplace=True)
189 | self.se_module = SEModule(planes * 4, reduction=reduction)
190 | self.downsample = downsample
191 | self.stride = stride
192 |
193 |
194 | class SENet(nn.Module):
195 | """
196 | Main Squeeze Excitation Network Module
197 | """
198 |
199 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
200 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
201 | downsample_padding=1, num_classes=1000):
202 | """
203 | Parameters
204 | ----------
205 | block (nn.Module): Bottleneck class.
206 | - For SENet154: SEBottleneck
207 | - For SE-ResNet models: SEResNetBottleneck
208 | - For SE-ResNeXt models: SEResNeXtBottleneck
209 | layers (list of ints): Number of residual blocks for 4 layers of the
210 | network (layer1...layer4).
211 | groups (int): Number of groups for the 3x3 convolution in each
212 | bottleneck block.
213 | - For SENet154: 64
214 | - For SE-ResNet models: 1
215 | - For SE-ResNeXt models: 32
216 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
217 | - For all models: 16
218 | dropout_p (float or None): Drop probability for the Dropout layer.
219 | If `None` the Dropout layer is not used.
220 | - For SENet154: 0.2
221 | - For SE-ResNet models: None
222 | - For SE-ResNeXt models: None
223 | inplanes (int): Number of input channels for layer1.
224 | - For SENet154: 128
225 | - For SE-ResNet models: 64
226 | - For SE-ResNeXt models: 64
227 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
228 | a single 7x7 convolution in layer0.
229 | - For SENet154: True
230 | - For SE-ResNet models: False
231 | - For SE-ResNeXt models: False
232 | downsample_kernel_size (int): Kernel size for downsampling convolutions
233 | in layer2, layer3 and layer4.
234 | - For SENet154: 3
235 | - For SE-ResNet models: 1
236 | - For SE-ResNeXt models: 1
237 | downsample_padding (int): Padding for downsampling convolutions in
238 | layer2, layer3 and layer4.
239 | - For SENet154: 1
240 | - For SE-ResNet models: 0
241 | - For SE-ResNeXt models: 0
242 | num_classes (int): Number of outputs in `last_linear` layer.
243 | - For all models: 1000
244 | """
245 | super(SENet, self).__init__()
246 | self.inplanes = inplanes
247 | if input_3x3:
248 | layer0_modules = [
249 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
250 | bias=False)),
251 | ('bn1', mynn.Norm2d(64)),
252 | ('relu1', nn.ReLU(inplace=True)),
253 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
254 | bias=False)),
255 | ('bn2', mynn.Norm2d(64)),
256 | ('relu2', nn.ReLU(inplace=True)),
257 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
258 | bias=False)),
259 | ('bn3', mynn.Norm2d(inplanes)),
260 | ('relu3', nn.ReLU(inplace=True)),
261 | ]
262 | else:
263 | layer0_modules = [
264 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
265 | padding=3, bias=False)),
266 | ('bn1', mynn.Norm2d(inplanes)),
267 | ('relu1', nn.ReLU(inplace=True)),
268 | ]
269 | # To preserve compatibility with Caffe weights `ceil_mode=True`
270 | # is used instead of `padding=1`.
271 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
272 | ceil_mode=True)))
273 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
274 | self.layer1 = self._make_layer(
275 | block,
276 | planes=64,
277 | blocks=layers[0],
278 | groups=groups,
279 | reduction=reduction,
280 | downsample_kernel_size=1,
281 | downsample_padding=0
282 | )
283 | self.layer2 = self._make_layer(
284 | block,
285 | planes=128,
286 | blocks=layers[1],
287 | stride=2,
288 | groups=groups,
289 | reduction=reduction,
290 | downsample_kernel_size=downsample_kernel_size,
291 | downsample_padding=downsample_padding
292 | )
293 | self.layer3 = self._make_layer(
294 | block,
295 | planes=256,
296 | blocks=layers[2],
297 | stride=1,
298 | groups=groups,
299 | reduction=reduction,
300 | downsample_kernel_size=downsample_kernel_size,
301 | downsample_padding=downsample_padding
302 | )
303 | self.layer4 = self._make_layer(
304 | block,
305 | planes=512,
306 | blocks=layers[3],
307 | stride=1,
308 | groups=groups,
309 | reduction=reduction,
310 | downsample_kernel_size=downsample_kernel_size,
311 | downsample_padding=downsample_padding
312 | )
313 | self.avg_pool = nn.AvgPool2d(7, stride=1)
314 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
315 | self.last_linear = nn.Linear(512 * block.expansion, num_classes)
316 |
317 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
318 | downsample_kernel_size=1, downsample_padding=0):
319 | downsample = None
320 | if stride != 1 or self.inplanes != planes * block.expansion:
321 | downsample = nn.Sequential(
322 | nn.Conv2d(self.inplanes, planes * block.expansion,
323 | kernel_size=downsample_kernel_size, stride=stride,
324 | padding=downsample_padding, bias=False),
325 | mynn.Norm2d(planes * block.expansion),
326 | )
327 |
328 | layers = []
329 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
330 | downsample))
331 | self.inplanes = planes * block.expansion
332 | for index in range(1, blocks):
333 | layers.append(block(self.inplanes, planes, groups, reduction))
334 |
335 | return nn.Sequential(*layers)
336 |
337 | def features(self, x):
338 | """
339 |         Forward pass through each layer of the SE network
340 | """
341 | x = self.layer0(x)
342 | x = self.layer1(x)
343 | x = self.layer2(x)
344 | x = self.layer3(x)
345 | x = self.layer4(x)
346 | return x
347 |
348 | def logits(self, x):
349 | """
350 | AvgPool and Linear Layer
351 | """
352 | x = self.avg_pool(x)
353 | if self.dropout is not None:
354 | x = self.dropout(x)
355 | x = x.view(x.size(0), -1)
356 | x = self.last_linear(x)
357 | return x
358 |
359 | def forward(self, x):
360 | x = self.features(x)
361 | x = self.logits(x)
362 | return x
363 |
364 |
365 | def initialize_pretrained_model(model, num_classes, settings):
366 | """
367 |     Initialize pretrained model information:
368 |     download weights, load weights, set input metadata
369 | """
370 | assert num_classes == settings['num_classes'], \
371 | 'num_classes should be {}, but is {}'.format(
372 | settings['num_classes'], num_classes)
373 | weights = model_zoo.load_url(settings['url'])
374 | model.load_state_dict(weights)
375 | model.input_space = settings['input_space']
376 | model.input_size = settings['input_size']
377 | model.input_range = settings['input_range']
378 | model.mean = settings['mean']
379 | model.std = settings['std']
380 |
381 |
382 |
383 | def se_resnext50_32x4d(num_classes=1000):
384 | """
385 |     Definition of SE-ResNeXt50 (32x4d)
386 | """
387 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
388 | dropout_p=None, inplanes=64, input_3x3=False,
389 | downsample_kernel_size=1, downsample_padding=0,
390 | num_classes=num_classes)
391 | settings = pretrained_settings['se_resnext50_32x4d']['imagenet']
392 | initialize_pretrained_model(model, num_classes, settings)
393 | return model
394 |
395 |
396 | def se_resnext101_32x4d(num_classes=1000):
397 | """
398 |     Definition of SE-ResNeXt101 (32x4d)
399 | """
400 |
401 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
402 | dropout_p=None, inplanes=64, input_3x3=False,
403 | downsample_kernel_size=1, downsample_padding=0,
404 | num_classes=num_classes)
405 | settings = pretrained_settings['se_resnext101_32x4d']['imagenet']
406 | initialize_pretrained_model(model, num_classes, settings)
407 | return model
408 |
--------------------------------------------------------------------------------
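The SE block above rescales each channel by a gate computed from globally pooled features. A tiny self-contained sketch of that recalibration, mirroring `SEModule` without importing it:

import torch
import torch.nn as nn

class TinySE(nn.Module):
    # Channel gate: global average pool -> bottleneck 1x1 convs -> sigmoid -> rescale.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)

    def forward(self, x):
        gate = torch.sigmoid(self.fc2(torch.relu(self.fc1(self.pool(x)))))
        return x * gate  # gate broadcasts over H and W

x = torch.randn(2, 64, 16, 16)
print(TinySE(64)(x).shape)  # torch.Size([2, 64, 16, 16])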
/optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Pytorch Optimizer and Scheduler Related Task
3 | """
4 | import math
5 | import logging
6 | import torch
7 | from torch import optim
8 | from config import cfg
9 |
10 |
11 | def get_optimizer(args, net):
12 | """
13 | Set up the optimizer (SGD only; optionally a separate backbone LR group) and its LR scheduler
14 | """
15 | if args.backbone_lr > 0.0:
16 | base_params = []
17 | resnet_params = []
18 | resnet_name = []
19 | resnet_name.append('layer0')
20 | resnet_name.append('layer1')
21 | #resnet_name.append('layer2')
22 | #resnet_name.append('layer3')
23 | #resnet_name.append('layer4')
24 | len_resnet = len(resnet_name)
25 | else:
26 | param_groups = net.parameters()
27 |
28 | if args.backbone_lr > 0.0:
29 | for name, param in net.named_parameters():
30 | is_resnet = False
31 | for i in range(len_resnet):
32 | if resnet_name[i] in name:
33 | resnet_params.append(param)
34 | # param.requires_grad=False
35 | print("resnet_name", name)
36 | is_resnet = True
37 | break
38 | if not is_resnet:
39 | base_params.append(param)
40 |
41 | if args.sgd:
42 | if args.backbone_lr > 0.0:
43 | optimizer = optim.SGD([
44 | {'params': base_params},
45 | {'params': resnet_params, 'lr':args.backbone_lr}
46 | ],
47 | lr=args.lr,
48 | weight_decay=5e-4, #args.weight_decay,
49 | momentum=args.momentum,
50 | nesterov=False)
51 | else:
52 | optimizer = optim.SGD(param_groups,
53 | lr=args.lr,
54 | weight_decay=5e-4, #args.weight_decay,
55 | momentum=args.momentum,
56 | nesterov=False)
57 | else:
58 | raise ValueError('Not a valid optimizer')
59 |
60 | if args.lr_schedule == 'scl-poly':
61 | if cfg.REDUCE_BORDER_ITER == -1:
62 | raise ValueError('ERROR Cannot Do Scale Poly')
63 |
64 | rescale_thresh = cfg.REDUCE_BORDER_ITER
65 | scale_value = args.rescale
66 | lambda1 = lambda iteration: \
67 | math.pow(1 - iteration / args.max_iter,
68 | args.poly_exp) if iteration < rescale_thresh else scale_value * math.pow(
69 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh),
70 | args.repoly)
71 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
72 | elif args.lr_schedule == 'poly':
73 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp)
74 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
75 | else:
76 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
77 |
78 | return optimizer, scheduler
79 |
80 |
81 | def get_optimizer_attention(args, net):
82 | """
83 | Set up separate SGD optimizers for the main network and the HANet attention modules, each with its own LR scheduler
84 | """
85 | attention_params = []
86 | base_params = []
87 | hanet_name = []
88 |
89 | if args.backbone_lr > 0.0:
90 | resnet_params = []
91 | resnet_name = []
92 | resnet_name.append('layer0')
93 | resnet_name.append('layer1')
94 | #resnet_name.append('layer2')
95 | #resnet_name.append('layer3')
96 | #resnet_name.append('layer4')
97 | len_resnet = len(resnet_name)
98 |
99 | for i in range(5):
100 | if args.hanet[i] > 0: # HANet_Diff
101 | hanet_name.append('hanet' + str(i))
102 |
103 | len_hanet = len(hanet_name)
104 |
105 | for name, param in net.named_parameters():
106 | is_hanet = False
107 | is_resnet = False
108 | if args.backbone_lr > 0.0:
109 | for i in range(len_resnet):
110 | if resnet_name[i] in name:
111 | resnet_params.append(param)
112 | # param.requires_grad=False
113 | print("resnet_name", name)
114 | is_resnet = True
115 | break
116 | if not is_resnet:
117 | for i in range(len_hanet):
118 | if hanet_name[i] in name:
119 | attention_params.append(param)
120 | #print("hanet_name", name)
121 | is_hanet = True
122 | break
123 | if not is_hanet and not is_resnet:
124 | base_params.append(param)
125 | #print("base", name)
126 |
127 | if args.sgd:
128 | if args.backbone_lr > 0.0:
129 | optimizer = optim.SGD([
130 | {'params': base_params},
131 | {'params': resnet_params, 'lr':args.backbone_lr}
132 | ],
133 | lr=args.lr,
134 | weight_decay=5e-4, #args.weight_decay,
135 | momentum=args.momentum,
136 | nesterov=False)
137 | else:
138 | optimizer = optim.SGD(base_params,
139 | lr=args.lr,
140 | weight_decay=5e-4, #args.weight_decay,
141 | momentum=args.momentum,
142 | nesterov=False)
143 | else:
144 | raise ValueError('Not a valid optimizer')
145 |
146 | print(" ############# HANet Number", len_hanet)
147 | optimizer_at = optim.SGD(attention_params,
148 | lr=args.hanet_lr,
149 | weight_decay=args.hanet_wd,
150 | momentum=args.momentum,
151 | nesterov=False)
152 |
153 |
154 |
155 | if args.lr_schedule == 'scl-poly':
156 | if cfg.REDUCE_BORDER_ITER == -1:
157 | raise ValueError('ERROR Cannot Do Scale Poly')
158 |
159 | rescale_thresh = cfg.REDUCE_BORDER_ITER
160 | scale_value = args.rescale
161 | lambda1 = lambda iteration: \
162 | math.pow(1 - iteration / args.max_iter,
163 | args.poly_exp) if iteration <= rescale_thresh else scale_value * math.pow(
164 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh),
165 | args.repoly)
166 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
167 |
168 | if args.hanet_poly_exp > 0.0:
169 | lambda2 = lambda iteration: \
170 | math.pow(1 - iteration / args.max_iter,
171 | args.hanet_poly_exp) if iteration <= rescale_thresh else scale_value * math.pow(
172 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh),
173 | args.repoly)
174 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2)
175 | else:
176 | lambda2 = lambda iteration: \
177 | math.pow(1 - iteration / args.max_iter,
178 | args.poly_exp) if iteration <= rescale_thresh else scale_value * math.pow(
179 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh),
180 | args.repoly)
181 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2)
182 |
183 | elif args.lr_schedule == 'poly':
184 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp)
185 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
186 |
187 | # for attention module
188 | if args.hanet_poly_exp > 0.0:
189 | lambda2 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.hanet_poly_exp)
190 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2)
191 | else:
192 | lambda2 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp)
193 | scheduler_at = optim.lr_scheduler.LambdaLR(optimizer_at, lr_lambda=lambda2)
194 | else:
195 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
196 |
197 | return optimizer, scheduler, optimizer_at, scheduler_at
198 |
199 |
200 | def get_optimizer_by_epoch(args, net):
201 | """
202 | Decide the optimizer (Adam or SGD) and build an epoch-based polynomial LR scheduler
203 | """
204 | param_groups = net.parameters()
205 |
206 | if args.sgd:
207 | optimizer = optim.SGD(param_groups,
208 | lr=args.lr,
209 | weight_decay=args.weight_decay,
210 | momentum=args.momentum,
211 | nesterov=False)
212 | elif args.adam:
213 | amsgrad = False
214 | if args.amsgrad:
215 | amsgrad = True
216 | optimizer = optim.Adam(param_groups,
217 | lr=args.lr,
218 | weight_decay=args.weight_decay,
219 | amsgrad=amsgrad
220 | )
221 | else:
222 | raise ValueError('Not a valid optimizer')
223 |
224 | if args.lr_schedule == 'scl-poly':
225 | if cfg.REDUCE_BORDER_EPOCH == -1:
226 | raise ValueError('ERROR Cannot Do Scale Poly')
227 |
228 | rescale_thresh = cfg.REDUCE_BORDER_EPOCH
229 | scale_value = args.rescale
230 | lambda1 = lambda epoch: \
231 | math.pow(1 - epoch / args.max_epoch,
232 | args.poly_exp) if epoch < rescale_thresh else scale_value * math.pow(
233 | 1 - (epoch - rescale_thresh) / (args.max_epoch - rescale_thresh),
234 | args.repoly)
235 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
236 | elif args.lr_schedule == 'poly':
237 | lambda1 = lambda epoch: math.pow(1 - epoch / args.max_epoch, args.poly_exp)
238 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
239 | else:
240 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
241 |
242 | return optimizer, scheduler
243 |
244 | def load_weights_hanet(net, optimizer, optimizer_at, scheduler, scheduler_at, snapshot_file, restore_optimizer_bool=False):
245 | """
246 | Load weights from snapshot file
247 | """
248 | logging.info("Loading weights from model %s", snapshot_file)
249 | net, optimizer, optimizer_at, scheduler, scheduler_at, epoch, mean_iu = restore_snapshot_hanet(net, optimizer,
250 | optimizer_at, scheduler, scheduler_at, snapshot_file, restore_optimizer_bool)
251 | return epoch, mean_iu
252 |
253 | def load_weights_pe(net, snapshot_file):
254 | """
255 | Load weights from snapshot file
256 | """
257 | logging.info("Loading weights from model %s", snapshot_file)
258 | net = restore_snapshot_pe(net, snapshot_file)
259 |
260 |
261 | def load_weights(net, optimizer, scheduler, snapshot_file, restore_optimizer_bool=False):
262 | """
263 | Load weights from snapshot file
264 | """
265 | logging.info("Loading weights from model %s", snapshot_file)
266 | net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file,
267 | restore_optimizer_bool)
268 | return epoch, mean_iu
269 |
270 | def restore_snapshot_pe(net, snapshot):
271 | """
272 | Restore only the positional-embedding weights from a snapshot.
273 | """
274 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
275 | logging.info("Checkpoint PE Load Complete")
276 |
277 | if 'state_dict' in checkpoint:
278 | net = forgiving_state_restore_only_pe(net, checkpoint['state_dict'])
279 | else:
280 | net = forgiving_state_restore_only_pe(net, checkpoint)
281 |
282 | return net
283 |
284 | def forgiving_state_restore_only_pe(net, loaded_dict):
285 | """
286 | Load only matching positional-embedding ('pos_emb1d') parameters
287 | from loaded_dict, skipping tensors that are missing or whose
288 | sizes do not match.
289 | """
290 | net_state_dict = net.state_dict()
291 | new_loaded_dict = {}
292 | for k in net_state_dict:
293 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
294 | if 'pos_emb1d' in k:
295 | print("matched loading parameter", k)
296 | new_loaded_dict[k] = loaded_dict[k]
297 | # else:
298 | # print("Skipped loading parameter", k)
299 | # logging.info("Skipped loading parameter %s", k)
300 | net_state_dict.update(new_loaded_dict)
301 | net.load_state_dict(net_state_dict)
302 | return net
303 |
304 | def freeze_pe(net):
305 | for name, param in net.named_parameters():
306 | if 'pos_emb1d' in name:
307 | print("freeze parameter", name)
308 | param.requires_grad = False
309 |
310 | def restore_snapshot_hanet(net, optimizer, optimizer_at, scheduler, scheduler_at, snapshot, restore_optimizer_bool):
311 | """
312 | Restore weights, optimizers, and schedulers (if requested) to resume a job.
313 | """
314 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
315 | logging.info("Checkpoint Load Complete")
316 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool:
317 | optimizer.load_state_dict(checkpoint['optimizer'])
318 |
319 | if optimizer_at is not None and 'optimizer_at' in checkpoint and restore_optimizer_bool:
320 | optimizer_at.load_state_dict(checkpoint['optimizer_at'])
321 |
322 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool:
323 | scheduler.load_state_dict(checkpoint['scheduler'])
324 |
325 | if scheduler_at is not None and 'scheduler_at' in checkpoint and restore_optimizer_bool:
326 | scheduler_at.load_state_dict(checkpoint['scheduler_at'])
327 |
328 | if 'state_dict' in checkpoint:
329 | net = forgiving_state_restore(net, checkpoint['state_dict'])
330 | else:
331 | net = forgiving_state_restore(net, checkpoint)
332 |
333 | return net, optimizer, optimizer_at, scheduler, scheduler_at, checkpoint['epoch'], checkpoint['mean_iu']
334 |
335 | def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool):
336 | """
337 | Restore weights, optimizer, and scheduler (if requested) to resume a job.
338 | """
339 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
340 | logging.info("Checkpoint Load Complete")
341 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool:
342 | optimizer.load_state_dict(checkpoint['optimizer'])
343 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool:
344 | scheduler.load_state_dict(checkpoint['scheduler'])
345 |
346 | if 'state_dict' in checkpoint:
347 | net = forgiving_state_restore(net, checkpoint['state_dict'])
348 | else:
349 | net = forgiving_state_restore(net, checkpoint)
350 |
351 | return net, optimizer, scheduler, checkpoint['epoch'], checkpoint['mean_iu']
352 |
353 |
354 | def forgiving_state_restore(net, loaded_dict):
355 | """
356 | Handle partial loading when some tensors don't match up in size,
357 | e.g. when the loaded model was trained with a different number
358 | of classes; mismatched tensors are skipped.
359 | """
360 | net_state_dict = net.state_dict()
361 | new_loaded_dict = {}
362 | for k in net_state_dict:
363 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
364 | new_loaded_dict[k] = loaded_dict[k]
365 | else:
366 | print("Skipped loading parameter", k)
367 | # logging.info("Skipped loading parameter %s", k)
368 | net_state_dict.update(new_loaded_dict)
369 | net.load_state_dict(net_state_dict)
370 | return net
371 |
372 | def forgiving_state_copy(target_net, source_net):
373 | """
374 | Copy parameters from source_net into target_net, skipping tensors
375 | that are missing or whose sizes don't match, e.g. when the source
376 | model was trained with a different number of classes.
377 | """
378 | net_state_dict = target_net.state_dict()
379 | loaded_dict = source_net.state_dict()
380 | new_loaded_dict = {}
381 | for k in net_state_dict:
382 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
383 | new_loaded_dict[k] = loaded_dict[k]
384 | print("Matched", k)
385 | else:
386 | print("Skipped loading parameter ", k)
387 | # logging.info("Skipped loading parameter %s", k)
388 | net_state_dict.update(new_loaded_dict)
389 | target_net.load_state_dict(net_state_dict)
390 | return target_net
391 |
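The 'poly' schedules above multiply the base learning rate by (1 - iteration / max_iter) ** poly_exp at every step, and get_optimizer_attention returns a second optimizer/scheduler pair so the HANet attention modules can decay at their own rate (args.hanet_lr, args.hanet_poly_exp). A rough sketch of the decay curve and of how a training loop might step both pairs; args, net, train_loader, and compute_loss are placeholders rather than objects defined in this repository:

import math

def poly_lr(base_lr, iteration, max_iter, poly_exp):
    # same factor the LambdaLR lambdas apply to the base learning rate
    return base_lr * math.pow(1 - iteration / max_iter, poly_exp)

print(poly_lr(0.01, 0, 40000, 0.9))       # 0.01 at the start
print(poly_lr(0.01, 20000, 40000, 0.9))   # ~0.0054 halfway through
print(poly_lr(0.01, 39999, 40000, 0.9))   # ~0 near max_iter

# optimizer, scheduler, optimizer_at, scheduler_at = get_optimizer_attention(args, net)
# for curr_iter, batch in enumerate(train_loader):
#     optimizer.zero_grad()
#     optimizer_at.zero_grad()
#     loss = compute_loss(net, batch)
#     loss.backward()
#     optimizer.step()
#     optimizer_at.step()
#     scheduler.step()       # LambdaLR is advanced once per iteration here
#     scheduler_at.step()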
--------------------------------------------------------------------------------
/datasets/camvid.py:
--------------------------------------------------------------------------------
1 | """
2 | Camvid Dataset Loader
3 | """
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils import data
10 | import logging
11 | import datasets.uniform as uniform
12 | import json
13 | from config import cfg
14 | import copy
15 | import torch
16 |
17 |
18 | # trainid_to_name = cityscapes_labels.trainId2name
19 | # id_to_trainid = cityscapes_labels.label2trainid
20 | num_classes = 11
21 | ignore_label = 11
22 | root = cfg.DATASET.CAMVID_DIR
23 |
24 | palette = [128, 128, 128,
25 | 128, 0, 0,
26 | 192, 192, 128,
27 | 128, 64, 128,
28 | 0, 0, 192,
29 | 128, 128, 0,
30 | 192, 128, 128,
31 | 64, 64, 128,
32 | 64, 0, 128,
33 | 64, 64, 0,
34 | 0, 128, 192]
35 |
36 |
37 | CAMVID_CLASSES = ['Sky',
38 | 'Building',
39 | 'Column-Pole',
40 | 'Road',
41 | 'Sidewalk',
42 | 'Tree',
43 | 'Sign-Symbol',
44 | 'Fence',
45 | 'Car',
46 | 'Pedestrian',
47 | 'Bicyclist',
48 | 'Void']
49 |
50 | CAMVID_CLASS_COLORS = [
51 | (128, 128, 128),
52 | (128, 0, 0),
53 | (192, 192, 128),
54 | (128, 64, 128),
55 | (0, 0, 192),
56 | (128, 128, 0),
57 | (192, 128, 128),
58 | (64, 64, 128),
59 | (64, 0, 128),
60 | (64, 64, 0),
61 | (0, 128, 192),
62 | (0, 0, 0),
63 | ]
64 |
65 | zero_pad = 256 * 3 - len(palette)
66 | for i in range(zero_pad):
67 | palette.append(0)
68 |
69 | def colorize_mask(mask):
70 | # mask: numpy array of the mask
71 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
72 | new_mask.putpalette(palette)
73 | return new_mask
74 |
75 | def add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip):
76 |
77 | c_items = os.listdir(img_path)
78 | c_items.sort()
79 | items = []
80 | aug_items = []
81 |
82 | for it in c_items:
83 | item = (os.path.join(img_path, it), os.path.join(mask_path, it))
84 | items.append(item)
85 | if mode != 'test' and maxSkip > 0:
86 | seq_info = it.split("_")
87 | cur_seq_id = seq_info[-1][:-4]
88 |
89 | if seq_info[0] == "0001TP":
90 | prev_seq_id = "%06d" % (int(cur_seq_id) - maxSkip)
91 | next_seq_id = "%06d" % (int(cur_seq_id) + maxSkip)
92 | elif seq_info[0] == "0006R0":
93 | prev_seq_id = "f%05d" % (int(cur_seq_id[1:]) - maxSkip)
94 | next_seq_id = "f%05d" % (int(cur_seq_id[1:]) + maxSkip)
95 | else:
96 | prev_seq_id = "%05d" % (int(cur_seq_id) - maxSkip)
97 | next_seq_id = "%05d" % (int(cur_seq_id) + maxSkip)
98 |
99 | prev_it = seq_info[0] + "_" + prev_seq_id + '.png'
100 | next_it = seq_info[0] + "_" + next_seq_id + '.png'
101 |
102 | prev_item = (os.path.join(aug_img_path, prev_it), os.path.join(aug_mask_path, prev_it))
103 | next_item = (os.path.join(aug_img_path, next_it), os.path.join(aug_mask_path, next_it))
104 | if os.path.isfile(prev_item[0]) and os.path.isfile(prev_item[1]):
105 | aug_items.append(prev_item)
106 | if os.path.isfile(next_item[0]) and os.path.isfile(next_item[1]):
107 | aug_items.append(next_item)
108 | return items, aug_items
109 |
110 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
111 |
112 | items = []
113 | aug_items = []
114 | assert quality == 'semantic'
115 | assert mode in ['train', 'val', 'trainval', 'test']
116 |
117 | # img_dir_name = "SegNet/CamVid"
118 | original_img_dir = ""
119 | augmented_img_dir = "camvid_aug3/CamVid"
120 |
121 | img_path = os.path.join(root, original_img_dir, 'train')
122 | mask_path = os.path.join(root, original_img_dir, 'trainannot')
123 | aug_img_path = os.path.join(root, augmented_img_dir, 'train')
124 | aug_mask_path = os.path.join(root, augmented_img_dir, 'trainannot')
125 |
126 | train_items, train_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip)
127 | logging.info('Camvid has a total of {} train images'.format(len(train_items)))
128 |
129 | img_path = os.path.join(root, original_img_dir, 'val')
130 | mask_path = os.path.join(root, original_img_dir, 'valannot')
131 | aug_img_path = os.path.join(root, augmented_img_dir, 'val')
132 | aug_mask_path = os.path.join(root, augmented_img_dir, 'valannot')
133 |
134 | val_items, val_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip)
135 | logging.info('Camvid has a total of {} validation images'.format(len(val_items)))
136 |
137 | if mode == 'test':
138 | img_path = os.path.join(root, original_img_dir, 'test')
139 | mask_path = os.path.join(root, original_img_dir, 'testannot')
140 | test_items, test_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip)
141 | logging.info('Camvid has a total of {} test images'.format(len(test_items)))
142 |
143 | if mode == 'train':
144 | items = train_items
145 | elif mode == 'val':
146 | items = val_items
147 | elif mode == 'trainval':
148 | items = train_items + val_items
149 | aug_items = []#train_aug_items + val_aug_items
150 | elif mode == 'test':
151 | items = test_items
152 | aug_items = []
153 | else:
154 | logging.info('Unknown mode {}'.format(mode))
155 | sys.exit()
156 |
157 | logging.info('Camvid-{}: {} images'.format(mode, len(items)))
158 |
159 | return items, aug_items
160 |
161 | class CAMVID(data.Dataset):
162 |
163 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
164 | transform=None, target_transform=None, dump_images=False,
165 | class_uniform_pct=0, class_uniform_tile=0, test=False,
166 | cv_split=None, scf=None, hardnm=0):
167 |
168 | self.quality = quality
169 | self.mode = mode
170 | self.maxSkip = maxSkip
171 | self.joint_transform_list = joint_transform_list
172 | self.transform = transform
173 | self.target_transform = target_transform
174 | self.dump_images = dump_images
175 | self.class_uniform_pct = class_uniform_pct
176 | self.class_uniform_tile = class_uniform_tile
177 | self.scf = scf
178 | self.hardnm = hardnm
179 | self.cv_split = cv_split
180 | self.centroids = []
181 |
182 | self.imgs, self.aug_imgs = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
183 | assert len(self.imgs), 'Found 0 images, please check the data set'
184 |
185 | # Centroids for GT data
186 | if self.class_uniform_pct > 0:
187 | json_fn = 'camvid_tile{}_cv{}_{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode)
188 |
189 | if os.path.isfile(json_fn):
190 | with open(json_fn, 'r') as json_data:
191 | centroids = json.load(json_data)
192 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
193 | else:
194 | self.centroids = uniform.class_centroids_all(
195 | self.imgs,
196 | num_classes,
197 | id2trainid=None,
198 | tile_size=class_uniform_tile)
199 | with open(json_fn, 'w') as outfile:
200 | json.dump(self.centroids, outfile, indent=4)
201 |
202 | self.fine_centroids = copy.deepcopy(self.centroids)
203 |
204 | # if self.maxSkip > 0:
205 | # json_fn = 'camvid_tile{}_cv{}_{}_skip{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.maxSkip)
206 | # if os.path.isfile(json_fn):
207 | # with open(json_fn, 'r') as json_data:
208 | # centroids = json.load(json_data)
209 | # self.aug_centroids = {int(idx): centroids[idx] for idx in centroids}
210 | # else:
211 | # self.aug_centroids = uniform.class_centroids_all(
212 | # self.aug_imgs,
213 | # num_classes,
214 | # id2trainid=None,
215 | # tile_size=class_uniform_tile)
216 | # with open(json_fn, 'w') as outfile:
217 | # json.dump(self.aug_centroids, outfile, indent=4)
218 |
219 | # for class_id in range(num_classes):
220 | # self.centroids[class_id].extend(self.aug_centroids[class_id])
221 |
222 | self.build_epoch()
223 |
224 | def build_epoch(self, cut=False):
225 |
226 | if self.class_uniform_pct > 0:
227 | if cut:
228 | self.imgs_uniform = uniform.build_epoch(self.imgs,
229 | self.fine_centroids,
230 | num_classes,
231 | cfg.CLASS_UNIFORM_PCT)
232 | else:
233 | self.imgs_uniform = uniform.build_epoch(self.imgs,
234 | self.centroids,
235 | num_classes,
236 | cfg.CLASS_UNIFORM_PCT)
237 | else:
238 | self.imgs_uniform = self.imgs
239 |
240 |
241 | def __getitem__(self, index):
242 | elem = self.imgs_uniform[index]
243 | centroid = None
244 | if len(elem) == 4:
245 | img_path, mask_path, centroid, class_id = elem
246 | else:
247 | img_path, mask_path = elem
248 |
249 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
250 | img_name = os.path.splitext(os.path.basename(img_path))[0]
251 |
252 | # Image Transformations
253 | if self.joint_transform_list is not None:
254 | for idx, xform in enumerate(self.joint_transform_list):
255 | if idx == 0 and centroid is not None:
256 | # HACK
257 | # We assume that the first transform is capable of taking
258 | # in a centroid
259 | img, mask = xform(img, mask, centroid)
260 | else:
261 | img, mask = xform(img, mask)
262 |
263 | # Debug
264 | if self.dump_images and centroid is not None:
265 | outdir = './dump_imgs_{}'.format(self.mode)
266 | os.makedirs(outdir, exist_ok=True)
267 | dump_img_name = CAMVID_CLASSES[class_id] + '_' + img_name  # trainid_to_name is not defined for CamVid
268 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
269 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
270 | mask_img = colorize_mask(np.array(mask))
271 | img.save(out_img_fn)
272 | mask_img.save(out_msk_fn)
273 |
274 | if self.transform is not None:
275 | img = self.transform(img)
276 | if self.target_transform is not None:
277 | mask = self.target_transform(mask)
278 |
279 | return img, mask, img_name
280 |
281 | def __len__(self):
282 | return len(self.imgs_uniform)
283 |
284 |
285 | class CAMVIDWithPos(data.Dataset):
286 |
287 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
288 | transform=None, target_transform=None, target_aux_transform=None, dump_images=False,
289 | class_uniform_pct=0, class_uniform_tile=0, test=False,
290 | cv_split=None, scf=None, hardnm=0, pos_rfactor=8):
291 |
292 | self.quality = quality
293 | self.mode = mode
294 | self.maxSkip = maxSkip
295 | self.joint_transform_list = joint_transform_list
296 | self.transform = transform
297 | self.target_transform = target_transform
298 | self.target_aux_transform = target_aux_transform
299 | self.dump_images = dump_images
300 | self.class_uniform_pct = class_uniform_pct
301 | self.class_uniform_tile = class_uniform_tile
302 | self.scf = scf
303 | self.hardnm = hardnm
304 | self.cv_split = cv_split
305 | self.centroids = []
306 | self.pos_rfactor = pos_rfactor
307 |
308 | # position information: 1024x2048 row/column index maps, scaled down so the indices fit in uint8
309 | self.pos_h = torch.arange(0, 1024).unsqueeze(0).unsqueeze(2).expand(-1,-1,2048)//8
310 | self.pos_w = torch.arange(0, 2048).unsqueeze(0).unsqueeze(1).expand(-1,1024,-1)//16
311 | self.pos_h = self.pos_h[0].byte().numpy()
312 | self.pos_w = self.pos_w[0].byte().numpy()
313 | # pos index to image
314 | self.pos_h = Image.fromarray(self.pos_h, mode="L")
315 | self.pos_w = Image.fromarray(self.pos_w, mode="L")
316 | # position information
317 |
318 |
319 | self.imgs, self.aug_imgs = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
320 | assert len(self.imgs), 'Found 0 images, please check the data set'
321 |
322 | # Centroids for GT data
323 | if self.class_uniform_pct > 0:
324 | json_fn = 'camvid_tile{}_cv{}_{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode)
325 |
326 | if os.path.isfile(json_fn):
327 | with open(json_fn, 'r') as json_data:
328 | centroids = json.load(json_data)
329 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
330 | else:
331 | self.centroids = uniform.class_centroids_all(
332 | self.imgs,
333 | num_classes,
334 | id2trainid=None,
335 | tile_size=class_uniform_tile)
336 | with open(json_fn, 'w') as outfile:
337 | json.dump(self.centroids, outfile, indent=4)
338 |
339 | self.fine_centroids = copy.deepcopy(self.centroids)
340 |
341 | if self.maxSkip > 0:
342 | json_fn = 'camvid_tile{}_cv{}_{}_skip{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.maxSkip)
343 | if os.path.isfile(json_fn):
344 | with open(json_fn, 'r') as json_data:
345 | centroids = json.load(json_data)
346 | self.aug_centroids = {int(idx): centroids[idx] for idx in centroids}
347 | else:
348 | self.aug_centroids = uniform.class_centroids_all(
349 | self.aug_imgs,
350 | num_classes,
351 | id2trainid=None,
352 | tile_size=class_uniform_tile)
353 | with open(json_fn, 'w') as outfile:
354 | json.dump(self.aug_centroids, outfile, indent=4)
355 |
356 | for class_id in range(num_classes):
357 | self.centroids[class_id].extend(self.aug_centroids[class_id])
358 |
359 | self.build_epoch()
360 |
361 | def build_epoch(self, cut=False):
362 |
363 | if self.class_uniform_pct > 0:
364 | if cut:
365 | self.imgs_uniform = uniform.build_epoch(self.imgs,
366 | self.fine_centroids,
367 | num_classes,
368 | cfg.CLASS_UNIFORM_PCT)
369 | else:
370 | self.imgs_uniform = uniform.build_epoch(self.imgs,
371 | self.centroids,
372 | num_classes,
373 | cfg.CLASS_UNIFORM_PCT)
374 | else:
375 | self.imgs_uniform = self.imgs
376 |
377 |
378 | def __getitem__(self, index):
379 | elem = self.imgs_uniform[index]
380 | centroid = None
381 | if len(elem) == 4:
382 | img_path, mask_path, centroid, class_id = elem
383 | else:
384 | img_path, mask_path = elem
385 |
386 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
387 | img_name = os.path.splitext(os.path.basename(img_path))[0]
388 |
389 | # position information
390 | pos_h = self.pos_h
391 | pos_w = self.pos_w
392 | # position information
393 |
394 | # Image Transformations
395 | if self.joint_transform_list is not None:
396 | for idx, xform in enumerate(self.joint_transform_list):
397 | if idx == 0 and centroid is not None:
398 | # HACK
399 | # We assume that the first transform is capable of taking
400 | # in a centroid
401 | # img, mask = xform(img, mask, centroid)
402 | img, mask, (pos_h, pos_w) = xform(img, mask, centroid, pos=(pos_h, pos_w))
403 | else:
404 | # img, mask = xform(img, mask)
405 | img, mask, (pos_h, pos_w) = xform(img, mask, pos=(pos_h, pos_w))
406 |
407 | # Debug
408 | if self.dump_images and centroid is not None:
409 | outdir = './dump_imgs_{}'.format(self.mode)
410 | os.makedirs(outdir, exist_ok=True)
411 | dump_img_name = CAMVID_CLASSES[class_id] + '_' + img_name  # trainid_to_name is not defined for CamVid
412 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
413 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
414 | mask_img = colorize_mask(np.array(mask))
415 | img.save(out_img_fn)
416 | mask_img.save(out_msk_fn)
417 |
418 | if self.transform is not None:
419 | img = self.transform(img)
420 | if self.target_aux_transform is not None:
421 | mask_aux = self.target_aux_transform(mask)
422 | else:
423 | mask_aux = torch.tensor([0])
424 | if self.target_transform is not None:
425 | mask = self.target_transform(mask)
426 |
427 | pos_h = torch.from_numpy(np.array(pos_h, dtype=np.uint8))# // self.pos_rfactor
428 | pos_w = torch.from_numpy(np.array(pos_w, dtype=np.uint8))# // self.pos_rfactor
429 |
430 | return img, mask, img_name, mask_aux, (pos_h, pos_w)
431 |
432 | def __len__(self):
433 | return len(self.imgs_uniform)
434 |
435 |
436 |
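A minimal usage sketch for the CAMVID loader above, assuming cfg.DATASET.CAMVID_DIR points at a CamVid layout with train/trainannot (and val/valannot) directories and that the repository root is on PYTHONPATH; the transforms here are illustrative placeholders rather than the exact ones the training scripts use:

import numpy as np
import torch
import torchvision.transforms as standard_transforms
from torch.utils.data import DataLoader
from datasets.camvid import CAMVID

input_transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# turn the PIL label image into a LongTensor of class ids so the default collate works
mask_transform = lambda m: torch.from_numpy(np.array(m, dtype=np.int64))

train_set = CAMVID('semantic', 'train',
                   transform=input_transform,
                   target_transform=mask_transform,
                   class_uniform_pct=0)   # skip centroid-based uniform sampling
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)

imgs, masks, names = next(iter(train_loader))   # CAMVID.__getitem__ returns (img, mask, img_name)
print(imgs.shape, masks.shape, names[:2])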
--------------------------------------------------------------------------------