├── utils
│   ├── __init__.py
│   ├── logger.py
│   ├── visualize.py
│   ├── score.py
│   ├── lr_scheduler.py
│   ├── loss.py
│   └── distributed.py
├── heatmapAD.png
├── models
│   ├── __init__.py
│   ├── DDRNet_39.py
│   ├── DDRNet_23_slim.py
│   ├── DDRNet_23.py
│   └── DDRNet_23_vis1.py
├── LICENSE
├── README.md
└── eval.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/heatmapAD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlpacaLLaMa643/All-day-CityScapes-segmentation/HEAD/heatmapAD.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .DDRNet_23_slim import get_ddrnet_23_slim
2 | from .DDRNet_39 import get_ddrnet_39
3 | from .DDRNet_23 import get_ddrnet_23
4 | from .DDRNet_23_vis1 import get_ddrnet_23_vis1
5 |
6 | from .DDRNet_23 import get_CA_interact
7 | from .DDRNet_23 import get_CA_merge
8 |
9 | models = {
10 | 'ddrnet_39': get_ddrnet_39,
11 | 'ddrnet_23_slim': get_ddrnet_23_slim,
12 | 'ddrnet_23': get_ddrnet_23,
13 | 'ddrnet_23_vis1': get_ddrnet_23_vis1,
14 | }
15 |
16 | intermodule = {'inter': get_CA_interact}
17 |
18 | mergemodule = {'merge': get_CA_merge}
19 |
20 | def get_segmentation_model(model, **kwargs):
21 |     """Segmentation models"""
22 |     return models[model.lower()](**kwargs)
23 |
24 | def get_inter_model(model, **kwargs):
25 |     """Interaction C-A module"""
26 |     return intermodule[model.lower()](**kwargs)
27 |
28 | def get_merge_model(model, **kwargs):
29 |     """Merge C-A module"""
30 |     return mergemodule[model.lower()](**kwargs)
31 |
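32 | # Usage sketch (kwargs are illustrative; eval.py passes `pretrained` this way):
33 | #   net = get_segmentation_model('ddrnet_23', pretrained=False)
34 | #   inter = get_inter_model('inter')    # interaction C-A module
35 | #   merge = get_merge_model('merge')    # merge C-A module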
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Qi BI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | import logging
3 | import os
4 | import sys
5 |
6 | __all__ = ['setup_logger']
7 |
8 |
9 | # reference from: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/logger.py
10 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'):
11 | logger = logging.getLogger(name)
12 | logger.setLevel(logging.DEBUG)
13 | # don't log results for the non-master process
14 | if distributed_rank > 0:
15 | return logger
16 | ch = logging.StreamHandler(stream=sys.stdout)
17 | ch.setLevel(logging.DEBUG)
18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
19 | ch.setFormatter(formatter)
20 | logger.addHandler(ch)
21 |
22 | if save_dir:
23 | if not os.path.exists(save_dir):
24 | os.makedirs(save_dir)
25 |         fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode)  # 'a+' appends, 'w' overwrites
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
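32 | # Usage sketch (mirrors eval.py; get_rank comes from utils.distributed, the
33 | # save_dir and filename below are illustrative):
34 | #   logger = setup_logger("semantic_segmentation", "./runs/logs", get_rank(),
35 | #                         filename="eval_log.txt", mode='a+')
36 | # Non-master ranks receive a logger with no handlers attached.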
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # All-day-CityScapes-segmentation
2 | All-day Semantic Segmentation & All-day CityScapes dataset
3 |
4 | This is the official implementation of our work entitled ```Interactive Learning of Intrinsic and Extrinsic Properties for All-day Semantic Segmentation```, accepted by ```IEEE Transactions on Image Processing```.
5 |
6 | 
7 |
8 | # Dataset Download
9 |
10 | Please download ```All-day CityScapes``` from: https://isis-data.science.uva.nl/cv/1ADcityscape.zip
11 |
12 | For copyright reasons, we only provide the rendered samples for the training and validation sets of the original ```CityScapes```.
13 |
14 | All sample names and the data folder organization of ```All-day CityScapes``` are the same as in the original ```CityScapes```.
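15 |
16 | A minimal sketch of the expected layout (assuming the standard ```CityScapes``` organization; the city and file names below are illustrative):
17 |
18 | ```
19 | alldaycityscapes/
20 | ├── leftImg8bit/
21 | │   ├── train/aachen/aachen_000000_000019_leftImg8bit.png
22 | │   └── val/frankfurt/...
23 | └── gtFine/
24 |     ├── train/aachen/aachen_000000_000019_gtFine_labelIds.png
25 |     └── val/frankfurt/...
26 | ```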
27 |
28 | # Source Code & Implementation
29 |
30 | The proposed ```interactive intrinsic-extrinsic learning``` can be embedded into a variety of ```CNN``` and ```ViT``` based segmentation models.
31 |
32 | Here we provide the source code implemented on the DDRNet-23 backbone, which is: 1) simple and easy to configure; and 2) the backbone on which most of the experiments in this paper are conducted.
33 | This implementation is largely based on the DDRNet source code; the original DDRNet implementation can be found in its official repository.
34 |
35 | Please follow the steps below to run AO-SegNet (with a DDRNet-23 based backbone).
36 |
37 | ### Step 1: Configuration
38 |
39 | Follow the original DDRNet-23 instructions to prepare all the packages and the data folder.
40 |
41 | ### Step 2: Train the Model
42 |
43 | ```python train.py --data_pth D:/alldaycityscapes --nclass 19```
44 |
45 | ### Step 3: Evaluation
46 |
47 | ```python eval.py```
48 |
49 | # Citation and Reference
50 | If you find this project useful, please cite:
51 | ```
52 | @ARTICLE{Bi2023AD,
53 |   author={Bi, Qi and You, Shaodi and Gevers, Theo},
54 |   journal={IEEE Transactions on Image Processing},
55 |   title={Interactive Learning of Intrinsic and Extrinsic Properties for All-day Semantic Segmentation},
56 |   year={2023},
57 |   volume={32},
58 |   number={},
59 |   pages={3821-3835},
60 |   doi={10.1109/TIP.2023.3290469}}
61 | ```
62 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from train import parse_args
4 | from utils.distributed import synchronize, get_rank, make_data_sampler, make_batch_data_sampler
5 | from utils.logger import setup_logger
6 | from utils.visualize import get_color_pallete
7 | from utils.score import SegmentationMetric
8 | from models import get_segmentation_model
9 | from dataloader.cityscapes import CitySegmentation
10 | from torchvision import transforms
11 | import torch.backends.cudnn as cudnn
12 | import torch.utils.data as data
13 | import torch.nn as nn
14 | import torch
15 | import os
16 | import sys
17 |
18 | cur_path = os.path.abspath(os.path.dirname(__file__))
19 | root_path = os.path.split(cur_path)[0]
20 | sys.path.append(root_path)
21 |
22 |
23 | class Evaluator(object):
24 | def __init__(self, args):
25 | self.args = args
26 | self.args.pretrained = True
27 | self.device = torch.device(args.device)
28 |
29 | # image transform
30 | input_transform = transforms.Compose([
31 | transforms.ToTensor(),
32 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
33 | ])
34 |
35 | # dataset and dataloader
36 | val_dataset = CitySegmentation(
37 | args.data_path, split='val', mode='testval', transform=input_transform)
38 | val_sampler = make_data_sampler(val_dataset, False, args.distributed)
39 | val_batch_sampler = make_batch_data_sampler(
40 | val_sampler, images_per_batch=1)
41 | self.val_loader = data.DataLoader(dataset=val_dataset,
42 | batch_sampler=val_batch_sampler,
43 | num_workers=args.workers,
44 | pin_memory=True)
45 |
46 | # create network
47 | BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
48 | self.model = get_segmentation_model(
49 | model=args.model, pretrained=False).to(self.device)
50 |
51 | if self.args.pretrained:
52 | self.model.load_state_dict(torch.load(
53 | "./trained_models/ddrnet_23_dualresnet_citys_best_model.pth",
54 | map_location=self.args.device))
55 |             logger.info("Model restored successfully!")
56 |
57 | if args.distributed:
58 | self.model = nn.parallel.DistributedDataParallel(self.model,
59 | device_ids=[args.local_rank], output_device=args.local_rank)
60 |
61 | self.model.to(self.device)
62 |
63 | self.metric = SegmentationMetric(val_dataset.num_class)
64 |
65 | def eval(self):
66 | self.metric.reset()
67 | self.model.eval()
68 | if self.args.distributed:
69 | model = self.model.module
70 | else:
71 | model = self.model
72 | logger.info("Start validation, Total sample: {:d}".format(
73 | len(self.val_loader)))
74 | for i, (image, target, filename) in enumerate(self.val_loader):
75 | image = image.to(self.device)
76 | target = target.to(self.device)
77 |
78 | with torch.no_grad():
79 | outputs, _, _ = model(image)
80 | self.metric.update(outputs[0], target)
81 | pixAcc, mIoU = self.metric.get()
82 | logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
83 | i + 1, pixAcc * 100, mIoU * 100))
84 |
85 | if self.args.save_pred:
86 | pred = torch.argmax(outputs[0], 1)
87 | pred = pred.cpu().data.numpy()
88 |
89 | predict = pred.squeeze(0)
90 | mask = get_color_pallete(predict, self.args.dataset)
91 | mask.save(os.path.join(
92 | outdir, os.path.splitext(filename[0])[0] + '.png'))
93 | logger.info("Whole validation set mIoU: {:.3f}".format(mIoU * 100))
94 |
95 | synchronize()
96 |
97 |
98 | if __name__ == '__main__':
99 | args = parse_args()
100 |     num_gpus = int(os.environ.get("WORLD_SIZE", "1"))
101 |     # WORLD_SIZE is set by the distributed launcher when multiple GPUs are used
102 | args.distributed = num_gpus > 1
103 | if not args.no_cuda and torch.cuda.is_available():
104 | cudnn.benchmark = True
105 | args.device = "cuda"
106 | else:
107 | args.distributed = False
108 | args.device = "cpu"
109 | if args.distributed:
110 | torch.cuda.set_device(args.local_rank)
111 | torch.distributed.init_process_group(
112 | backend="nccl", init_method="env://")
113 | synchronize()
114 |
115 | # TODO: optim code
116 | args.save_pred = True
117 | if args.save_pred:
118 | outdir = 'runs/pred_pic/{}_{}_{}'.format(
119 | args.model, args.backbone, args.dataset)
120 | if not os.path.exists(outdir):
121 | os.makedirs(outdir)
122 |
123 | logger = setup_logger("semantic_segmentation", args.log_dir, get_rank(),
124 | filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+')
125 |
126 | evaluator = Evaluator(args)
127 | evaluator.eval()
128 | torch.cuda.empty_cache()
129 |
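130 | # Launch sketch (assumes train.py's parse_args defines --model/--backbone/--dataset
131 | # and the other flags read above; the GPU count is illustrative):
132 | #   single GPU:  python eval.py
133 | #   multi GPU:   python -m torch.distributed.launch --nproc_per_node=4 eval.py
134 | # torch.distributed.launch sets WORLD_SIZE and passes --local_rank, which the
135 | # __main__ block above uses to enable the distributed code path.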
--------------------------------------------------------------------------------
/utils/visualize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | __all__ = ['get_color_pallete', 'print_iou', 'set_img_color',
6 | 'show_prediction', 'show_colorful_images', 'save_colorful_images']
7 |
8 |
9 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False):
10 | n = iu.size
11 | lines = []
12 | for i in range(n):
13 | if class_names is None:
14 | cls = 'Class %d:' % (i + 1)
15 | else:
16 | cls = '%d %s' % (i + 1, class_names[i])
17 | # lines.append('%-8s: %.3f%%' % (cls, iu[i] * 100))
18 | mean_IU = np.nanmean(iu)
19 | mean_IU_no_back = np.nanmean(iu[1:])
20 | if show_no_back:
21 | lines.append('mean_IU: %.3f%% || mean_IU_no_back: %.3f%% || mean_pixel_acc: %.3f%%' % (
22 | mean_IU * 100, mean_IU_no_back * 100, mean_pixel_acc * 100))
23 | else:
24 | lines.append('mean_IU: %.3f%% || mean_pixel_acc: %.3f%%' % (mean_IU * 100, mean_pixel_acc * 100))
25 | lines.append('=================================================')
26 | line = "\n".join(lines)
27 |
28 | print(line)
29 |
30 |
31 | def set_img_color(img, label, colors, background=0, show255=False):
32 | for i in range(len(colors)):
33 | if i != background:
34 | img[np.where(label == i)] = colors[i]
35 | if show255:
36 | img[np.where(label == 255)] = 255
37 |
38 | return img
39 |
40 |
41 | def show_prediction(img, pred, colors, background=0):
42 | im = np.array(img, np.uint8)
43 | set_img_color(im, pred, colors, background)
44 | out = np.array(im)
45 |
46 | return out
47 |
48 |
49 | def show_colorful_images(prediction, palettes):
50 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
51 | im.show()
52 |
53 |
54 | def save_colorful_images(prediction, filename, output_dir, palettes):
55 | '''
56 | :param prediction: [B, H, W, C]
57 | '''
58 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
59 | fn = os.path.join(output_dir, filename)
60 | out_dir = os.path.split(fn)[0]
61 | if not os.path.exists(out_dir):
62 |         os.makedirs(out_dir)
63 | im.save(fn)
64 |
65 |
66 | def get_color_pallete(npimg, dataset='pascal_voc'):
67 | """Visualize image.
68 |
69 | Parameters
70 | ----------
71 |     npimg : numpy.ndarray
72 |         Single-channel class-index map with shape `H, W`.
73 |     dataset : str, default: 'pascal_voc'
74 |         The dataset the model was pretrained on ('pascal_voc', 'pascal_aug', 'ade20k', 'citys').
75 | Returns
76 | -------
77 | out_img : PIL.Image
78 | Image with color pallete
79 | """
80 | # recovery boundary
81 | if dataset in ('pascal_voc', 'pascal_aug'):
82 | npimg[npimg == -1] = 255
83 | # put colormap
84 | if dataset == 'ade20k':
85 | npimg = npimg + 1
86 | out_img = Image.fromarray(npimg.astype('uint8'))
87 | out_img.putpalette(adepallete)
88 | return out_img
89 | elif dataset == 'citys':
90 | out_img = Image.fromarray(npimg.astype('uint8'))
91 | out_img.putpalette(cityspallete)
92 | return out_img
93 | out_img = Image.fromarray(npimg.astype('uint8'))
94 | out_img.putpalette(vocpallete)
95 | return out_img
96 |
97 |
98 | def _getvocpallete(num_cls):
99 | n = num_cls
100 | pallete = [0] * (n * 3)
101 | for j in range(0, n):
102 | lab = j
103 | pallete[j * 3 + 0] = 0
104 | pallete[j * 3 + 1] = 0
105 | pallete[j * 3 + 2] = 0
106 | i = 0
107 | while (lab > 0):
108 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
109 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
110 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
111 | i = i + 1
112 | lab >>= 3
113 | return pallete
114 |
115 |
116 | vocpallete = _getvocpallete(256)
117 |
118 | adepallete = [
119 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204,
120 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82,
121 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255,
122 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6,
123 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255,
124 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15,
125 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255,
126 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163,
127 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255,
128 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0,
129 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0,
130 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255,
131 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255,
132 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0,
133 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0,
134 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0,
135 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0,
136 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255]
137 |
138 | cityspallete = [
139 | 128, 64, 128,
140 | 244, 35, 232,
141 | 70, 70, 70,
142 | 102, 102, 156,
143 | 190, 153, 153,
144 | 153, 153, 153,
145 | 250, 170, 30,
146 | 220, 220, 0,
147 | 107, 142, 35,
148 | 152, 251, 152,
149 | 0, 130, 180,
150 | 220, 20, 60,
151 | 255, 0, 0,
152 | 0, 0, 142,
153 | 0, 0, 70,
154 | 0, 60, 100,
155 | 0, 80, 100,
156 | 0, 0, 230,
157 | 119, 11, 32,
158 | ]
159 |
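160 | # Usage sketch (mirrors eval.py): colorize an H x W array of class indices.
161 | #   pred = torch.argmax(outputs[0], 1).cpu().numpy().squeeze(0)
162 | #   mask = get_color_pallete(pred, 'citys')   # PIL.Image with the palette applied
163 | #   mask.save('pred.png')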
--------------------------------------------------------------------------------
/utils/score.py:
--------------------------------------------------------------------------------
1 | """Evaluation Metrics for Semantic Segmentation"""
2 | import torch
3 | import numpy as np
4 |
5 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union',
6 | 'pixelAccuracy', 'intersectionAndUnion', 'hist_info', 'compute_score']
7 |
8 |
9 | class SegmentationMetric(object):
10 | """Computes pixAcc and mIoU metric scores
11 | """
12 |
13 | def __init__(self, nclass):
14 | super(SegmentationMetric, self).__init__()
15 | self.nclass = nclass
16 | self.reset()
17 |
18 | def update(self, preds, labels):
19 | """Updates the internal evaluation result.
20 |
21 | Parameters
22 | ----------
23 |         preds : torch.Tensor or list of torch.Tensor
24 |             Predicted values.
25 |         labels : torch.Tensor or list of torch.Tensor
26 |             The labels of the data.
27 | """
28 |
29 | def evaluate_worker(self, pred, label):
30 | correct, labeled = batch_pix_accuracy(pred, label)
31 | inter, union = batch_intersection_union(pred, label, self.nclass)
32 |
33 | self.total_correct += correct
34 | self.total_label += labeled
35 | if self.total_inter.device != inter.device:
36 | self.total_inter = self.total_inter.to(inter.device)
37 | self.total_union = self.total_union.to(union.device)
38 | self.total_inter += inter
39 | self.total_union += union
40 |
41 | if isinstance(preds, torch.Tensor):
42 | evaluate_worker(self, preds, labels)
43 | elif isinstance(preds, (list, tuple)):
44 | for (pred, label) in zip(preds, labels):
45 | evaluate_worker(self, pred, label)
46 |
47 | def get(self):
48 | """Gets the current evaluation result.
49 |
50 | Returns
51 | -------
52 | metrics : tuple of float
53 | pixAcc and mIoU
54 | """
55 |         pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label)  # epsilon == np.spacing(1)
56 |         IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
57 | mIoU = IoU.mean().item()
58 | return pixAcc, mIoU
59 |
60 | def reset(self):
61 | """Resets the internal evaluation result to initial state."""
62 | self.total_inter = torch.zeros(self.nclass)
63 | self.total_union = torch.zeros(self.nclass)
64 | self.total_correct = 0
65 | self.total_label = 0
66 |
67 |
68 | # pytorch version
69 | def batch_pix_accuracy(output, target):
70 | """PixAcc"""
71 |     # inputs are torch tensors: output is 4D (NxCxHxW), target is 3D (NxHxW)
72 | predict = torch.argmax(output.long(), 1) + 1
73 | target = target.long() + 1
74 |
75 | pixel_labeled = torch.sum(target > 0).item()
76 | pixel_correct = torch.sum((predict == target) * (target > 0)).item()
77 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
78 | return pixel_correct, pixel_labeled
79 |
80 |
81 | def batch_intersection_union(output, target, nclass):
82 | """mIoU"""
83 |     # inputs are torch tensors: output is 4D (NxCxHxW), target is 3D (NxHxW)
84 | mini = 1
85 | maxi = nclass
86 | nbins = nclass
87 | predict = torch.argmax(output, 1) + 1
88 | target = target.float() + 1
89 |
90 | predict = predict.float() * (target > 0).float()
91 | intersection = predict * (predict == target).float()
92 | # areas of intersection and union
93 |     # zeros in `intersection` mark unmatched or unlabeled pixels; torch.histc with min=1 drops them (the key difference from np.bincount)
94 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi)
95 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi)
96 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi)
97 | area_union = area_pred + area_lab - area_inter
98 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area"
99 | return area_inter.float(), area_union.float()
100 |
101 |
102 | def pixelAccuracy(imPred, imLab):
103 | """
104 | This function takes the prediction and label of a single image, returns pixel-wise accuracy
105 | To compute over many images do:
106 | for i = range(Nimages):
107 | (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = \
108 | pixelAccuracy(imPred[i], imLab[i])
109 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled))
110 | """
111 | # Remove classes from unlabeled pixels in gt image.
112 | # We should not penalize detections in unlabeled portions of the image.
113 | pixel_labeled = np.sum(imLab >= 0)
114 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0))
115 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled
116 | return (pixel_accuracy, pixel_correct, pixel_labeled)
117 |
118 |
119 | def intersectionAndUnion(imPred, imLab, numClass):
120 | """
121 | This function takes the prediction and label of a single image,
122 | returns intersection and union areas for each class
123 | To compute over many images do:
124 | for i in range(Nimages):
125 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i])
126 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1)
127 | """
128 | # Remove classes from unlabeled pixels in gt image.
129 | # We should not penalize detections in unlabeled portions of the image.
130 | imPred = imPred * (imLab >= 0)
131 |
132 | # Compute area intersection:
133 | intersection = imPred * (imPred == imLab)
134 | (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass))
135 |
136 | # Compute area union:
137 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
138 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
139 | area_union = area_pred + area_lab - area_intersection
140 | return (area_intersection, area_union)
141 |
142 |
143 | def hist_info(pred, label, num_cls):
144 | assert pred.shape == label.shape
145 | k = (label >= 0) & (label < num_cls)
146 | labeled = np.sum(k)
147 | correct = np.sum((pred[k] == label[k]))
148 |
149 |     return np.bincount(num_cls * label[k].astype(int) + pred[k],
150 |                        minlength=num_cls ** 2).reshape(num_cls, num_cls), labeled, correct
151 |
152 |
153 | def compute_score(hist, correct, labeled):
154 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
155 | mean_IU = np.nanmean(iu)
156 | mean_IU_no_back = np.nanmean(iu[1:])
157 | freq = hist.sum(1) / hist.sum()
158 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
159 | mean_pixel_acc = correct / labeled
160 |
161 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
162 |
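163 | # Usage sketch (mirrors eval.py): accumulate scores over a validation loader.
164 | #   metric = SegmentationMetric(nclass=19)
165 | #   metric.update(output, target)   # logits NxCxHxW, labels NxHxW
166 | #   pixAcc, mIoU = metric.get()
167 | #   metric.reset()                  # before the next evaluation pass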
--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | """Popular Learning Rate Schedulers"""
2 | from __future__ import division
3 | import math
4 | import torch
5 |
6 | from bisect import bisect_right
7 |
8 | __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR']
9 |
10 |
11 | class LRScheduler(object):
12 | r"""Learning Rate Scheduler
13 |
14 | Parameters
15 | ----------
16 | mode : str
17 | Modes for learning rate scheduler.
18 | Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
19 | base_lr : float
20 | Base learning rate, i.e. the starting learning rate.
21 | target_lr : float
22 | Target learning rate, i.e. the ending learning rate.
23 | With constant mode target_lr is ignored.
24 | niters : int
25 | Number of iterations to be scheduled.
26 | nepochs : int
27 | Number of epochs to be scheduled.
28 | iters_per_epoch : int
29 | Number of iterations in each epoch.
30 | offset : int
31 | Number of iterations before this scheduler.
32 | power : float
33 | Power parameter of poly scheduler.
34 | step_iter : list
35 | A list of iterations to decay the learning rate.
36 | step_epoch : list
37 | A list of epochs to decay the learning rate.
38 | step_factor : float
39 | Learning rate decay factor.
40 | """
41 |
42 | def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0,
43 | offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0):
44 | super(LRScheduler, self).__init__()
45 | assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine'])
46 |
47 | if mode == 'step':
48 | assert (step_iter is not None or step_epoch is not None)
49 | self.niters = niters
50 | self.step = step_iter
51 | epoch_iters = nepochs * iters_per_epoch
52 | if epoch_iters > 0:
53 | self.niters = epoch_iters
54 | if step_epoch is not None:
55 | self.step = [s * iters_per_epoch for s in step_epoch]
56 |
57 | self.step_factor = step_factor
58 | self.base_lr = base_lr
59 | self.target_lr = base_lr if mode == 'constant' else target_lr
60 | self.offset = offset
61 | self.power = power
62 | self.warmup_iters = warmup_epochs * iters_per_epoch
63 | self.mode = mode
64 |
65 | def __call__(self, optimizer, num_update):
66 | self.update(num_update)
67 | assert self.learning_rate >= 0
68 | self._adjust_learning_rate(optimizer, self.learning_rate)
69 |
70 | def update(self, num_update):
71 | N = self.niters - 1
72 | T = num_update - self.offset
73 | T = min(max(0, T), N)
74 |
75 | if self.mode == 'constant':
76 | factor = 0
77 | elif self.mode == 'linear':
78 | factor = 1 - T / N
79 | elif self.mode == 'poly':
80 | factor = pow(1 - T / N, self.power)
81 | elif self.mode == 'cosine':
82 | factor = (1 + math.cos(math.pi * T / N)) / 2
83 | elif self.mode == 'step':
84 | if self.step is not None:
85 | count = sum([1 for s in self.step if s <= T])
86 | factor = pow(self.step_factor, count)
87 | else:
88 | factor = 1
89 | else:
90 | raise NotImplementedError
91 |
92 | # warm up lr schedule
93 | if self.warmup_iters > 0 and T < self.warmup_iters:
94 | factor = factor * 1.0 * T / self.warmup_iters
95 |
96 | if self.mode == 'step':
97 | self.learning_rate = self.base_lr * factor
98 | else:
99 | self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor
100 |
101 | def _adjust_learning_rate(self, optimizer, lr):
102 | optimizer.param_groups[0]['lr'] = lr
103 | # enlarge the lr at the head
104 | for i in range(1, len(optimizer.param_groups)):
105 | optimizer.param_groups[i]['lr'] = lr * 10
106 |
107 |
108 | # ideally the warmup logic would be separated from MultiStepLR,
109 | # but the current LRScheduler design doesn't allow it
110 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py
111 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
112 |     def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
113 |                  warmup_iters=500, warmup_method="linear", last_epoch=-1):
114 |         if not list(milestones) == sorted(milestones):
115 |             raise ValueError(
116 |                 "Milestones should be a list of increasing integers. Got {}".format(milestones))
117 |         if warmup_method not in ("constant", "linear"):
118 |             raise ValueError(
119 |                 "Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method))
120 |         self.milestones = milestones
121 |         self.gamma = gamma
122 |         self.warmup_factor = warmup_factor
123 |         self.warmup_iters = warmup_iters
124 |         self.warmup_method = warmup_method
125 |         # call the base class last: _LRScheduler.__init__ invokes step()/get_lr(),
126 |         # which reads the attributes set above
127 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
128 | def get_lr(self):
129 | warmup_factor = 1
130 | if self.last_epoch < self.warmup_iters:
131 | if self.warmup_method == 'constant':
132 | warmup_factor = self.warmup_factor
133 |             elif self.warmup_method == 'linear':
134 | alpha = float(self.last_epoch) / self.warmup_iters
135 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
136 | return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch)
137 | for base_lr in self.base_lrs]
138 |
139 |
140 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
141 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
142 | warmup_iters=500, warmup_method='linear', last_epoch=-1):
143 | if warmup_method not in ("constant", "linear"):
144 | raise ValueError(
145 | "Only 'constant' or 'linear' warmup_method accepted "
146 | "got {}".format(warmup_method))
147 |
148 | self.target_lr = target_lr
149 | self.max_iters = max_iters
150 | self.power = power
151 | self.warmup_factor = warmup_factor
152 | self.warmup_iters = warmup_iters
153 | self.warmup_method = warmup_method
154 |
155 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch)
156 |
157 | def get_lr(self):
158 | N = self.max_iters - self.warmup_iters
159 | T = self.last_epoch - self.warmup_iters
160 | if self.last_epoch < self.warmup_iters:
161 | if self.warmup_method == 'constant':
162 | warmup_factor = self.warmup_factor
163 | elif self.warmup_method == 'linear':
164 | alpha = float(self.last_epoch) / self.warmup_iters
165 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
166 | else:
167 | raise ValueError("Unknown warmup type.")
168 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs]
169 | factor = pow(1 - T / N, self.power)
170 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs]
171 |
172 |
173 | if __name__ == '__main__':
174 | import torch
175 | import torch.nn as nn
176 |
177 | model = nn.Conv2d(16, 16, 3, 1, 1)
178 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
179 |     lr_scheduler = WarmupPolyLR(optimizer, max_iters=1000)
180 |
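181 |     # minimal sketch: advance the schedule and inspect the resulting lr
182 |     for it in range(5):
183 |         lr_scheduler.step()
184 |         print(it, optimizer.param_groups[0]['lr'])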
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | """Custom losses."""
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 | __all__ = ['MixSoftmaxCrossEntropyLoss', 'MixSoftmaxCrossEntropyOHEMLoss',
9 | 'EncNetLoss', 'ICNetLoss', 'get_segmentation_loss']
10 |
11 |
12 | # TODO: optim function
13 | class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss):
14 | def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs):
15 | super(MixSoftmaxCrossEntropyLoss, self).__init__(
16 | ignore_index=ignore_index)
17 | self.aux = aux
18 | self.aux_weight = aux_weight
19 |
20 | def _aux_forward(self, *inputs, **kwargs):
21 | *preds, target = tuple(inputs)
22 |
23 | loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target)
24 | for i in range(1, len(preds)):
25 | aux_loss = super(MixSoftmaxCrossEntropyLoss,
26 | self).forward(preds[i], target)
27 | loss += self.aux_weight * aux_loss
28 | return loss
29 |
30 | def forward(self, *inputs, **kwargs):
31 | preds, target = tuple(inputs)
32 | inputs = tuple(list(preds) + [target])
33 | if self.aux:
34 | return dict(loss=self._aux_forward(*inputs))
35 | else:
36 | return dict(loss=super(MixSoftmaxCrossEntropyLoss, self).forward(*inputs))
37 |
38 |
39 | # reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/loss.py
40 | class EncNetLoss(nn.CrossEntropyLoss):
41 | """2D Cross Entropy Loss with SE Loss"""
42 |
43 | def __init__(self, se_loss=True, se_weight=0.2, nclass=19, aux=False,
44 | aux_weight=0.4, weight=None, ignore_index=-1, **kwargs):
45 | super(EncNetLoss, self).__init__(weight, None, ignore_index)
46 | self.se_loss = se_loss
47 | self.aux = aux
48 | self.nclass = nclass
49 | self.se_weight = se_weight
50 | self.aux_weight = aux_weight
51 | self.bceloss = nn.BCELoss(weight)
52 |
53 | def forward(self, *inputs):
54 | preds, target = tuple(inputs)
55 | inputs = tuple(list(preds) + [target])
56 | if not self.se_loss and not self.aux:
57 | return super(EncNetLoss, self).forward(*inputs)
58 | elif not self.se_loss:
59 | pred1, pred2, target = tuple(inputs)
60 | loss1 = super(EncNetLoss, self).forward(pred1, target)
61 | loss2 = super(EncNetLoss, self).forward(pred2, target)
62 | return dict(loss=loss1 + self.aux_weight * loss2)
63 | elif not self.aux:
64 | pred, se_pred, target = tuple(inputs)
65 | se_target = self._get_batch_label_vector(
66 | target, nclass=self.nclass).type_as(pred)
67 | loss1 = super(EncNetLoss, self).forward(pred, target)
68 | loss2 = self.bceloss(torch.sigmoid(se_pred), se_target)
69 | return dict(loss=loss1 + self.se_weight * loss2)
70 | else:
71 | pred1, se_pred, pred2, target = tuple(inputs)
72 | se_target = self._get_batch_label_vector(
73 | target, nclass=self.nclass).type_as(pred1)
74 | loss1 = super(EncNetLoss, self).forward(pred1, target)
75 | loss2 = super(EncNetLoss, self).forward(pred2, target)
76 | loss3 = self.bceloss(torch.sigmoid(se_pred), se_target)
77 | return dict(loss=loss1 + self.aux_weight * loss2 + self.se_weight * loss3)
78 |
79 | @staticmethod
80 | def _get_batch_label_vector(target, nclass):
81 | # target is a 3D Variable BxHxW, output is 2D BxnClass
82 | batch = target.size(0)
83 | tvect = Variable(torch.zeros(batch, nclass))
84 | for i in range(batch):
85 | hist = torch.histc(target[i].cpu().data.float(),
86 | bins=nclass, min=0,
87 | max=nclass - 1)
88 | vect = hist > 0
89 | tvect[i] = vect
90 | return tvect
91 |
92 |
93 | # TODO: optim function
94 | class ICNetLoss(nn.CrossEntropyLoss):
95 | """Cross Entropy Loss for ICNet"""
96 |
97 | def __init__(self, nclass, aux_weight=0.4, ignore_index=-1, **kwargs):
98 | super(ICNetLoss, self).__init__(ignore_index=ignore_index)
99 | self.nclass = nclass
100 | self.aux_weight = aux_weight
101 |
102 | def forward(self, *inputs):
103 | preds, target = tuple(inputs)
104 | inputs = tuple(list(preds) + [target])
105 |
106 | pred, pred_sub4, pred_sub8, pred_sub16, target = tuple(inputs)
107 | # [batch, W, H] -> [batch, 1, W, H]
108 | target = target.unsqueeze(1).float()
109 | target_sub4 = F.interpolate(target, pred_sub4.size(
110 | )[2:], mode='bilinear', align_corners=True).squeeze(1).long()
111 | target_sub8 = F.interpolate(target, pred_sub8.size(
112 | )[2:], mode='bilinear', align_corners=True).squeeze(1).long()
113 | target_sub16 = F.interpolate(target, pred_sub16.size()[2:], mode='bilinear', align_corners=True).squeeze(
114 | 1).long()
115 | loss1 = super(ICNetLoss, self).forward(pred_sub4, target_sub4)
116 | loss2 = super(ICNetLoss, self).forward(pred_sub8, target_sub8)
117 | loss3 = super(ICNetLoss, self).forward(pred_sub16, target_sub16)
118 | return dict(loss=loss1 + loss2 * self.aux_weight + loss3 * self.aux_weight)
119 |
120 |
121 | class OhemCrossEntropy2d(nn.Module):
122 | def __init__(self, ignore_index=-1, thresh=0.7, min_kept=100000, use_weight=True, **kwargs):
123 | super(OhemCrossEntropy2d, self).__init__()
124 | self.ignore_index = ignore_index
125 | self.thresh = float(thresh)
126 | self.min_kept = int(min_kept)
127 | if use_weight:
128 | weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
129 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955,
130 | 1.0865, 1.1529, 1.0507])
131 | self.criterion = torch.nn.CrossEntropyLoss(
132 | weight=weight, ignore_index=ignore_index)
133 | else:
134 | self.criterion = torch.nn.CrossEntropyLoss(
135 | ignore_index=ignore_index)
136 |
137 | def forward(self, pred, target):
138 | n, c, h, w = pred.size()
139 | target = target.view(-1)
140 | valid_mask = target.ne(self.ignore_index)
141 | target = target * valid_mask.long()
142 | num_valid = valid_mask.sum()
143 |
144 | prob = F.softmax(pred, dim=1)
145 | prob = prob.transpose(0, 1).reshape(c, -1)
146 |
147 | if self.min_kept > num_valid:
148 |             print("Labels: {}".format(num_valid))
149 | elif num_valid > 0:
150 |             # ~valid_mask replaces the old (1 - valid_mask) byte-tensor idiom
151 | prob = prob.masked_fill_(~valid_mask, 1)
152 | mask_prob = prob[target, torch.arange(
153 | len(target), dtype=torch.long)]
154 | threshold = self.thresh
155 | if self.min_kept > 0:
156 | index = mask_prob.argsort()
157 | threshold_index = index[min(len(index), self.min_kept) - 1]
158 | if mask_prob[threshold_index] > self.thresh:
159 | threshold = mask_prob[threshold_index]
160 | kept_mask = mask_prob.le(threshold)
161 | valid_mask = valid_mask * kept_mask
162 | target = target * kept_mask.long()
163 |
164 |         # ~valid_mask replaces the old (1 - valid_mask) byte-tensor idiom
165 | target = target.masked_fill_(~valid_mask, self.ignore_index)
166 | target = target.view(n, h, w)
167 |
168 | return self.criterion(pred, target)
169 |
170 |
171 | class MixSoftmaxCrossEntropyOHEMLoss(OhemCrossEntropy2d):
172 | def __init__(self, aux=False, aux_weight=0.4, weight=None, ignore_index=-1, **kwargs):
173 | super(MixSoftmaxCrossEntropyOHEMLoss, self).__init__(
174 | ignore_index=ignore_index)
175 | self.aux = aux
176 | self.aux_weight = aux_weight
177 | self.bceloss = nn.BCELoss(weight)
178 |
179 | def _aux_forward(self, *inputs, **kwargs):
180 | *preds, target = tuple(inputs)
181 |
182 | loss = super(MixSoftmaxCrossEntropyOHEMLoss,
183 | self).forward(preds[0], target)
184 | for i in range(1, len(preds)):
185 | aux_loss = super(MixSoftmaxCrossEntropyOHEMLoss,
186 | self).forward(preds[i], target)
187 | loss += self.aux_weight * aux_loss
188 | return loss
189 |
190 | def forward(self, *inputs):
191 | preds, target = tuple(inputs)
192 | inputs = tuple(list(preds) + [target])
193 | if self.aux:
194 | return dict(loss=self._aux_forward(*inputs))
195 | else:
196 | return dict(loss=super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(*inputs))
197 |
198 |
199 | def get_segmentation_loss(model, use_ohem=False, **kwargs):
200 | if use_ohem:
201 | return MixSoftmaxCrossEntropyOHEMLoss(**kwargs)
202 |
203 | model = model.lower()
204 | if model == 'encnet':
205 | return EncNetLoss(**kwargs)
206 | elif model == 'icnet':
207 | return ICNetLoss(**kwargs)
208 | else:
209 | return MixSoftmaxCrossEntropyLoss(**kwargs)
210 |
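211 | # Usage sketch: select a criterion by model name (OHEM takes priority when set).
212 | #   criterion = get_segmentation_loss('ddrnet_23', use_ohem=True, ignore_index=-1)
213 | #   loss_dict = criterion((output,), target)   # tuple of predictions + target
214 | #   loss = loss_dict['loss']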
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains primitives for multi-gpu communication.
3 | This is useful when doing distributed training.
4 | """
5 | import math
6 | import pickle
7 | import torch
8 | import torch.utils.data as data
9 | import torch.distributed as dist
10 |
11 | from torch.utils.data.sampler import Sampler, BatchSampler
12 |
13 | __all__ = ['get_world_size', 'get_rank', 'synchronize', 'is_main_process',
14 | 'all_gather', 'make_data_sampler', 'make_batch_data_sampler',
15 | 'reduce_dict', 'reduce_loss_dict']
16 |
17 |
18 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/comm.py
19 | def get_world_size():
20 | if not dist.is_available():
21 | return 1
22 | if not dist.is_initialized():
23 | return 1
24 | return dist.get_world_size()
25 |
26 |
27 | def get_rank():
28 | if not dist.is_available():
29 | return 0
30 | if not dist.is_initialized():
31 | return 0
32 | return dist.get_rank()
33 |
34 |
35 | def is_main_process():
36 | return get_rank() == 0
37 |
38 |
39 | def synchronize():
40 | """
41 | Helper function to synchronize (barrier) among all processes when
42 | using distributed training
43 | """
44 | if not dist.is_available():
45 | return
46 | if not dist.is_initialized():
47 | return
48 | world_size = dist.get_world_size()
49 | if world_size == 1:
50 | return
51 | dist.barrier()
52 |
53 |
54 | def all_gather(data):
55 | """
56 | Run all_gather on arbitrary picklable data (not necessarily tensors)
57 | Args:
58 | data: any picklable object
59 | Returns:
60 | list[data]: list of data gathered from each rank
61 | """
62 | world_size = get_world_size()
63 | if world_size == 1:
64 | return [data]
65 |
66 | # serialized to a Tensor
67 | buffer = pickle.dumps(data)
68 | storage = torch.ByteStorage.from_buffer(buffer)
69 | tensor = torch.ByteTensor(storage).to("cuda")
70 |
71 | # obtain Tensor size of each rank
72 | local_size = torch.IntTensor([tensor.numel()]).to("cuda")
73 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
74 | dist.all_gather(size_list, local_size)
75 | size_list = [int(size.item()) for size in size_list]
76 | max_size = max(size_list)
77 |
78 | # receiving Tensor from all ranks
79 | # we pad the tensor because torch all_gather does not support
80 | # gathering tensors of different shapes
81 | tensor_list = []
82 | for _ in size_list:
83 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
84 | if local_size != max_size:
85 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
86 | tensor = torch.cat((tensor, padding), dim=0)
87 | dist.all_gather(tensor_list, tensor)
88 |
89 | data_list = []
90 | for size, tensor in zip(size_list, tensor_list):
91 | buffer = tensor.cpu().numpy().tobytes()[:size]
92 | data_list.append(pickle.loads(buffer))
93 |
94 | return data_list
95 |
96 |
97 | def reduce_dict(input_dict, average=True):
98 | """
99 | Args:
100 | input_dict (dict): all the values will be reduced
101 | average (bool): whether to do average or sum
102 | Reduce the values in the dictionary from all processes so that process with rank
103 | 0 has the averaged results. Returns a dict with the same fields as
104 | input_dict, after reduction.
105 | """
106 | world_size = get_world_size()
107 | if world_size < 2:
108 | return input_dict
109 | with torch.no_grad():
110 | names = []
111 | values = []
112 | # sort the keys so that they are consistent across processes
113 | for k in sorted(input_dict.keys()):
114 | names.append(k)
115 | values.append(input_dict[k])
116 | values = torch.stack(values, dim=0)
117 | dist.reduce(values, dst=0)
118 | if dist.get_rank() == 0 and average:
119 | # only main process gets accumulated, so only divide by
120 | # world_size in this case
121 | values /= world_size
122 | reduced_dict = {k: v for k, v in zip(names, values)}
123 | return reduced_dict
124 |
125 |
126 | def reduce_loss_dict(loss_dict):
127 | """
128 | Reduce the loss dictionary from all processes so that process with rank
129 | 0 has the averaged results. Returns a dict with the same fields as
130 | loss_dict, after reduction.
131 | """
132 | world_size = get_world_size()
133 | if world_size < 2:
134 | return loss_dict
135 | with torch.no_grad():
136 | loss_names = []
137 | all_losses = []
138 | for k in sorted(loss_dict.keys()):
139 | loss_names.append(k)
140 | all_losses.append(loss_dict[k])
141 | all_losses = torch.stack(all_losses, dim=0)
142 | dist.reduce(all_losses, dst=0)
143 | if dist.get_rank() == 0:
144 | # only main process gets accumulated, so only divide by
145 | # world_size in this case
146 | all_losses /= world_size
147 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)}
148 | return reduced_losses
149 |
150 |
151 | def make_data_sampler(dataset, shuffle, distributed):
152 | if distributed:
153 | return DistributedSampler(dataset, shuffle=shuffle)
154 | if shuffle:
155 | sampler = data.sampler.RandomSampler(dataset)
156 | else:
157 | sampler = data.sampler.SequentialSampler(dataset)
158 | return sampler
159 |
160 |
161 | def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0):
162 | batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=True)
163 | if num_iters is not None:
164 | batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
165 | return batch_sampler
166 |
167 |
168 | # Code is copy-pasted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/distributed.py
169 | class DistributedSampler(Sampler):
170 | """Sampler that restricts data loading to a subset of the dataset.
171 | It is especially useful in conjunction with
172 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
173 | process can pass a DistributedSampler instance as a DataLoader sampler,
174 | and load a subset of the original dataset that is exclusive to it.
175 | .. note::
176 | Dataset is assumed to be of constant size.
177 | Arguments:
178 | dataset: Dataset used for sampling.
179 | num_replicas (optional): Number of processes participating in
180 | distributed training.
181 | rank (optional): Rank of the current process within num_replicas.
182 | """
183 |
184 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
185 | if num_replicas is None:
186 | if not dist.is_available():
187 | raise RuntimeError("Requires distributed package to be available")
188 | num_replicas = dist.get_world_size()
189 | if rank is None:
190 | if not dist.is_available():
191 | raise RuntimeError("Requires distributed package to be available")
192 | rank = dist.get_rank()
193 | self.dataset = dataset
194 | self.num_replicas = num_replicas
195 | self.rank = rank
196 | self.epoch = 0
197 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
198 | self.total_size = self.num_samples * self.num_replicas
199 | self.shuffle = shuffle
200 |
201 | def __iter__(self):
202 | if self.shuffle:
203 | # deterministically shuffle based on epoch
204 | g = torch.Generator()
205 | g.manual_seed(self.epoch)
206 | indices = torch.randperm(len(self.dataset), generator=g).tolist()
207 | else:
208 | indices = torch.arange(len(self.dataset)).tolist()
209 |
210 | # add extra samples to make it evenly divisible
211 | indices += indices[: (self.total_size - len(indices))]
212 | assert len(indices) == self.total_size
213 |
214 | # subsample
215 | offset = self.num_samples * self.rank
216 | indices = indices[offset: offset + self.num_samples]
217 | assert len(indices) == self.num_samples
218 |
219 | return iter(indices)
220 |
221 | def __len__(self):
222 | return self.num_samples
223 |
224 | def set_epoch(self, epoch):
225 | self.epoch = epoch
226 |
227 |
228 | class IterationBasedBatchSampler(BatchSampler):
229 | """
230 | Wraps a BatchSampler, resampling from it until
231 | a specified number of iterations have been sampled
232 | """
233 |
234 | def __init__(self, batch_sampler, num_iterations, start_iter=0):
235 | self.batch_sampler = batch_sampler
236 | self.num_iterations = num_iterations
237 | self.start_iter = start_iter
238 |
239 | def __iter__(self):
240 | iteration = self.start_iter
241 | while iteration <= self.num_iterations:
242 | # if the underlying sampler has a set_epoch method, like
243 | # DistributedSampler, used for making each process see
244 | # a different split of the dataset, then set it
245 | if hasattr(self.batch_sampler.sampler, "set_epoch"):
246 | self.batch_sampler.sampler.set_epoch(iteration)
247 | for batch in self.batch_sampler:
248 | iteration += 1
249 | if iteration > self.num_iterations:
250 | break
251 | yield batch
252 |
253 | def __len__(self):
254 | return self.num_iterations
255 |
256 |
257 | if __name__ == '__main__':
258 | pass
259 |
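260 | # Usage sketch (mirrors eval.py): build the sampler chain for a DataLoader.
261 | #   sampler = make_data_sampler(val_dataset, shuffle=False, distributed=False)
262 | #   batch_sampler = make_batch_data_sampler(sampler, images_per_batch=1)
263 | #   loader = data.DataLoader(val_dataset, batch_sampler=batch_sampler)
264 | # Pass num_iters to make_batch_data_sampler for iteration-based training.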
--------------------------------------------------------------------------------
/models/DDRNet_39.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn import init
8 | from collections import OrderedDict
9 |
10 | # for single gpu
11 | BatchNorm2d = nn.BatchNorm2d
12 | bn_mom = 0.1
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | """3x3 convolution with padding"""
17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
18 | padding=1, bias=False)
19 |
20 |
21 | class BasicBlock(nn.Module):
22 | expansion = 1
23 |
24 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
25 | super(BasicBlock, self).__init__()
26 | self.conv1 = conv3x3(inplanes, planes, stride)
27 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
28 | self.relu = nn.ReLU(inplace=True)
29 | self.conv2 = conv3x3(planes, planes)
30 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
31 | self.downsample = downsample
32 | self.stride = stride
33 | self.no_relu = no_relu
34 |
35 | def forward(self, x):
36 | residual = x
37 |
38 | out = self.conv1(x)
39 | out = self.bn1(out)
40 | out = self.relu(out)
41 |
42 | out = self.conv2(out)
43 | out = self.bn2(out)
44 |
45 | if self.downsample is not None:
46 | residual = self.downsample(x)
47 |
48 | out += residual
49 |
50 | if self.no_relu:
51 | return out
52 | else:
53 | return self.relu(out)
54 |
55 |
56 | class Bottleneck(nn.Module):
57 | expansion = 2
58 |
59 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
60 | super(Bottleneck, self).__init__()
61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62 | self.bn1 = nn.BatchNorm2d(planes, momentum=bn_mom)
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
64 | padding=1, bias=False)
65 | self.bn2 = nn.BatchNorm2d(planes, momentum=bn_mom)
66 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
67 | bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * self.expansion,
69 | momentum=bn_mom)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 | self.no_relu = no_relu
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv3(out)
87 | out = self.bn3(out)
88 |
89 | if self.downsample is not None:
90 | residual = self.downsample(x)
91 |
92 | out += residual
93 | if self.no_relu:
94 | return out
95 | else:
96 | return self.relu(out)
97 |
98 |
99 | class DAPPM(nn.Module):
100 | def __init__(self, inplanes, branch_planes, outplanes):
101 | super(DAPPM, self).__init__()
102 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
103 | BatchNorm2d(inplanes, momentum=bn_mom),
104 | nn.ReLU(inplace=True),
105 | nn.Conv2d(inplanes, branch_planes,
106 | kernel_size=1, bias=False),
107 | )
108 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
109 | BatchNorm2d(inplanes, momentum=bn_mom),
110 | nn.ReLU(inplace=True),
111 | nn.Conv2d(inplanes, branch_planes,
112 | kernel_size=1, bias=False),
113 | )
114 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
115 | BatchNorm2d(inplanes, momentum=bn_mom),
116 | nn.ReLU(inplace=True),
117 | nn.Conv2d(inplanes, branch_planes,
118 | kernel_size=1, bias=False),
119 | )
120 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
121 | BatchNorm2d(inplanes, momentum=bn_mom),
122 | nn.ReLU(inplace=True),
123 | nn.Conv2d(inplanes, branch_planes,
124 | kernel_size=1, bias=False),
125 | )
126 | self.scale0 = nn.Sequential(
127 | BatchNorm2d(inplanes, momentum=bn_mom),
128 | nn.ReLU(inplace=True),
129 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
130 | )
131 | self.process1 = nn.Sequential(
132 | BatchNorm2d(branch_planes, momentum=bn_mom),
133 | nn.ReLU(inplace=True),
134 | nn.Conv2d(branch_planes, branch_planes,
135 | kernel_size=3, padding=1, bias=False),
136 | )
137 | self.process2 = nn.Sequential(
138 | BatchNorm2d(branch_planes, momentum=bn_mom),
139 | nn.ReLU(inplace=True),
140 | nn.Conv2d(branch_planes, branch_planes,
141 | kernel_size=3, padding=1, bias=False),
142 | )
143 | self.process3 = nn.Sequential(
144 | BatchNorm2d(branch_planes, momentum=bn_mom),
145 | nn.ReLU(inplace=True),
146 | nn.Conv2d(branch_planes, branch_planes,
147 | kernel_size=3, padding=1, bias=False),
148 | )
149 | self.process4 = nn.Sequential(
150 | BatchNorm2d(branch_planes, momentum=bn_mom),
151 | nn.ReLU(inplace=True),
152 | nn.Conv2d(branch_planes, branch_planes,
153 | kernel_size=3, padding=1, bias=False),
154 | )
155 | self.compression = nn.Sequential(
156 | BatchNorm2d(branch_planes * 5, momentum=bn_mom),
157 | nn.ReLU(inplace=True),
158 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
159 | )
160 | self.shortcut = nn.Sequential(
161 | BatchNorm2d(inplanes, momentum=bn_mom),
162 | nn.ReLU(inplace=True),
163 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
164 | )
165 |
166 | def forward(self, x):
167 |
168 | #x = self.downsample(x)
169 | width = x.shape[-1]
170 | height = x.shape[-2]
171 | x_list = []
172 |
173 | x_list.append(self.scale0(x))
174 | x_list.append(self.process1((F.interpolate(self.scale1(x),
175 | size=[height, width],
176 | mode='bilinear')+x_list[0])))
177 | x_list.append((self.process2((F.interpolate(self.scale2(x),
178 | size=[height, width],
179 | mode='bilinear')+x_list[1]))))
180 | x_list.append(self.process3((F.interpolate(self.scale3(x),
181 | size=[height, width],
182 | mode='bilinear')+x_list[2])))
183 | x_list.append(self.process4((F.interpolate(self.scale4(x),
184 | size=[height, width],
185 | mode='bilinear')+x_list[3])))
186 |
187 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
188 | return out
189 |
190 |
191 | class segmenthead(nn.Module):
192 |
193 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
194 | super(segmenthead, self).__init__()
195 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
196 | self.conv1 = nn.Conv2d(inplanes, interplanes,
197 | kernel_size=3, padding=1, bias=False)
198 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
199 | self.relu = nn.ReLU(inplace=True)
200 | self.conv2 = nn.Conv2d(interplanes, outplanes,
201 | kernel_size=1, padding=0, bias=True)
202 | self.scale_factor = scale_factor
203 |
204 | def forward(self, x):
205 |
206 | x = self.conv1(self.relu(self.bn1(x)))
207 | out = self.conv2(self.relu(self.bn2(x)))
208 |
209 | if self.scale_factor is not None:
210 | height = x.shape[-2] * self.scale_factor
211 | width = x.shape[-1] * self.scale_factor
212 | out = F.interpolate(out,
213 | size=[height, width],
214 | mode='bilinear')
215 |
216 | return out
217 |
218 |
219 | class DualResNet(nn.Module):
220 |
221 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False):
222 | super(DualResNet, self).__init__()
223 |
224 | highres_planes = planes * 2
225 | self.augment = augment
226 |
227 | self.conv1 = nn.Sequential(
228 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1),
229 | BatchNorm2d(planes, momentum=bn_mom),
230 | nn.ReLU(inplace=True),
231 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
232 | BatchNorm2d(planes, momentum=bn_mom),
233 | nn.ReLU(inplace=True),
234 | )
235 |
236 | self.relu = nn.ReLU(inplace=False)
237 | self.layer1 = self._make_layer(block, planes, planes, layers[0])
238 | self.layer2 = self._make_layer(
239 | block, planes, planes * 2, layers[1], stride=2)
240 | self.layer3_1 = self._make_layer(
241 | block, planes * 2, planes * 4, layers[2] // 2, stride=2)
242 | self.layer3_2 = self._make_layer(
243 | block, planes * 4, planes * 4, layers[2] // 2)
244 | self.layer4 = self._make_layer(
245 | block, planes * 4, planes * 8, layers[3], stride=2)
246 |
247 | self.compression3_1 = nn.Sequential(
248 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False),
249 | BatchNorm2d(highres_planes, momentum=bn_mom),
250 | )
251 |
252 | self.compression3_2 = nn.Sequential(
253 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False),
254 | BatchNorm2d(highres_planes, momentum=bn_mom),
255 | )
256 |
257 | self.compression4 = nn.Sequential(
258 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False),
259 | BatchNorm2d(highres_planes, momentum=bn_mom),
260 | )
261 |
262 | self.down3_1 = nn.Sequential(
263 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
264 | stride=2, padding=1, bias=False),
265 | BatchNorm2d(planes * 4, momentum=bn_mom),
266 | )
267 |
268 | self.down3_2 = nn.Sequential(
269 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
270 | stride=2, padding=1, bias=False),
271 | BatchNorm2d(planes * 4, momentum=bn_mom),
272 | )
273 |
274 | self.down4 = nn.Sequential(
275 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
276 | stride=2, padding=1, bias=False),
277 | BatchNorm2d(planes * 4, momentum=bn_mom),
278 | nn.ReLU(inplace=True),
279 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3,
280 | stride=2, padding=1, bias=False),
281 | BatchNorm2d(planes * 8, momentum=bn_mom),
282 | )
283 |
284 | self.layer3_1_ = self._make_layer(
285 | block, planes * 2, highres_planes, layers[2] // 2)
286 |
287 | self.layer3_2_ = self._make_layer(
288 | block, highres_planes, highres_planes, layers[2] // 2)
289 |
290 | self.layer4_ = self._make_layer(
291 | block, highres_planes, highres_planes, layers[3])
292 |
293 | self.layer5_ = self._make_layer(
294 | Bottleneck, highres_planes, highres_planes, 1)
295 |
296 | self.layer5 = self._make_layer(
297 | Bottleneck, planes * 8, planes * 8, 1, stride=2)
298 |
299 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4)
300 |
301 | if self.augment:
302 | self.seghead_extra = segmenthead(
303 | highres_planes, head_planes, num_classes)
304 |
305 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
306 |
307 | for m in self.modules():
308 | if isinstance(m, nn.Conv2d):
309 | nn.init.kaiming_normal_(
310 | m.weight, mode='fan_out', nonlinearity='relu')
311 | elif isinstance(m, BatchNorm2d):
312 | nn.init.constant_(m.weight, 1)
313 | nn.init.constant_(m.bias, 0)
314 |
315 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
316 | downsample = None
317 | if stride != 1 or inplanes != planes * block.expansion:
318 | downsample = nn.Sequential(
319 | nn.Conv2d(inplanes, planes * block.expansion,
320 | kernel_size=1, stride=stride, bias=False),
321 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
322 | )
323 |
324 | layers = []
325 | layers.append(block(inplanes, planes, stride, downsample))
326 | inplanes = planes * block.expansion
327 | for i in range(1, blocks):
328 | if i == (blocks-1):
329 | layers.append(block(inplanes, planes, stride=1, no_relu=True))
330 | else:
331 | layers.append(block(inplanes, planes, stride=1, no_relu=False))
332 |
333 | return nn.Sequential(*layers)
334 |
335 | def forward(self, x):
336 |
337 | width_output_or = x.shape[-1]
338 | height_output_or = x.shape[-2]
339 |
340 | width_output = x.shape[-1] // 8
341 | height_output = x.shape[-2] // 8
342 | layers = []
343 |
344 | x = self.conv1(x)
345 |
346 | x = self.layer1(x)
347 | layers.append(x)
348 |
349 | x = self.layer2(self.relu(x))
350 | layers.append(x)
351 |
352 | x = self.layer3_1(self.relu(x))
353 | layers.append(x)
354 | x_ = self.layer3_1_(self.relu(layers[1]))
355 | x = x + self.down3_1(self.relu(x_))
356 | x_ = x_ + F.interpolate(
357 | self.compression3_1(self.relu(layers[2])),
358 | size=[height_output, width_output],
359 | mode='bilinear')
360 |
361 | x = self.layer3_2(self.relu(x))
362 | layers.append(x)
363 | x_ = self.layer3_2_(self.relu(x_))
364 | x = x + self.down3_2(self.relu(x_))
365 | x_ = x_ + F.interpolate(
366 | self.compression3_2(self.relu(layers[3])),
367 | size=[height_output, width_output],
368 | mode='bilinear')
369 |
370 | temp = x_
371 |
372 | x = self.layer4(self.relu(x))
373 | layers.append(x)
374 | x_ = self.layer4_(self.relu(x_))
375 | x = x + self.down4(self.relu(x_))
376 | x_ = x_ + F.interpolate(
377 | self.compression4(self.relu(layers[4])),
378 | size=[height_output, width_output],
379 | mode='bilinear')
380 |
381 | x_ = self.layer5_(self.relu(x_))
382 | x = F.interpolate(
383 | self.spp(self.layer5(self.relu(x))),
384 | size=[height_output, width_output],
385 | mode='bilinear')
386 |
387 | x_ = self.final_layer(x + x_)
388 |
389 | outputs = []
390 |
391 | x_ = F.interpolate(x_,
392 | size=[height_output_or, width_output_or],
393 | mode='bilinear', align_corners=True)
394 | outputs.append(x_)
395 |
396 | if self.augment:
397 | x_extra = self.seghead_extra(temp)
398 | return [x_, x_extra]
399 | else:
400 | return tuple(outputs)
401 |
402 |
403 | def DualResNet_imagenet(pretrained=False):
404 | model = DualResNet(BasicBlock, [3, 4, 6, 3], num_classes=19,
405 | planes=64, spp_planes=128, head_planes=256, augment=False)
406 | if pretrained:
407 |
408 | pretrained_state = torch.load(
409 | "./models/DDRNet39_imagenet.pth", map_location='cpu')
410 | model_dict = model.state_dict()
411 | pretrained_state = {k: v for k, v in pretrained_state.items() if
412 | (k in model_dict and v.shape == model_dict[k].shape)}
413 | model_dict.update(pretrained_state)
414 |
415 | model.load_state_dict(model_dict, strict=False)
416 | print("Having loaded imagenet-pretrained weights successfully!")
417 | return model
418 |
419 |
420 | def get_ddrnet_39(pretrained=False):
421 |
422 | model = DualResNet_imagenet(pretrained=pretrained)
423 | return model
424 |
425 |
426 | if __name__ == "__main__":
427 |
428 | model = DualResNet_imagenet(pretrained=True)
429 |
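430 |     # Minimal smoke test (an illustrative sketch, not part of the original
431 |     # pipeline): push one dummy image through the network and check the
432 |     # output resolution. The 1024x1024 input size is an assumption matching
433 |     # the Cityscapes crop sizes noted in the comments of the other models.
434 |     model.eval()
435 |     with torch.no_grad():
436 |         dummy = torch.randn(1, 3, 1024, 1024)
437 |         (logits,) = model(dummy)  # forward returns a 1-tuple of logits
438 |     print(logits.shape)  # expected: torch.Size([1, 19, 1024, 1024])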
--------------------------------------------------------------------------------
/models/DDRNet_23_slim.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 | from collections import OrderedDict
8 |
9 |
10 | BatchNorm2d = nn.BatchNorm2d
11 | bn_mom = 0.1
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
30 | self.downsample = downsample
31 | self.stride = stride
32 | self.no_relu = no_relu
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 |
44 | if self.downsample is not None:
45 | residual = self.downsample(x)
46 |
47 | out += residual
48 |
49 | if self.no_relu:
50 | return out
51 | else:
52 | return self.relu(out)
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 2
57 |
58 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=1, bias=False)
64 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
65 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
66 | bias=False)
67 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.downsample = downsample
70 | self.stride = stride
71 | self.no_relu = no_relu
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | if self.no_relu:
92 | return out
93 | else:
94 | return self.relu(out)
95 |
96 |
97 | class DAPPM(nn.Module):
98 | def __init__(self, inplanes, branch_planes, outplanes):
99 | super(DAPPM, self).__init__()
100 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
101 | BatchNorm2d(inplanes, momentum=bn_mom),
102 | nn.ReLU(inplace=True),
103 | nn.Conv2d(inplanes, branch_planes,
104 | kernel_size=1, bias=False),
105 | )
106 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
107 | BatchNorm2d(inplanes, momentum=bn_mom),
108 | nn.ReLU(inplace=True),
109 | nn.Conv2d(inplanes, branch_planes,
110 | kernel_size=1, bias=False),
111 | )
112 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
113 | BatchNorm2d(inplanes, momentum=bn_mom),
114 | nn.ReLU(inplace=True),
115 | nn.Conv2d(inplanes, branch_planes,
116 | kernel_size=1, bias=False),
117 | )
118 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
119 | BatchNorm2d(inplanes, momentum=bn_mom),
120 | nn.ReLU(inplace=True),
121 | nn.Conv2d(inplanes, branch_planes,
122 | kernel_size=1, bias=False),
123 | )
124 | self.scale0 = nn.Sequential(
125 | BatchNorm2d(inplanes, momentum=bn_mom),
126 | nn.ReLU(inplace=True),
127 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
128 | )
129 | self.process1 = nn.Sequential(
130 | BatchNorm2d(branch_planes, momentum=bn_mom),
131 | nn.ReLU(inplace=True),
132 | nn.Conv2d(branch_planes, branch_planes,
133 | kernel_size=3, padding=1, bias=False),
134 | )
135 | self.process2 = nn.Sequential(
136 | BatchNorm2d(branch_planes, momentum=bn_mom),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(branch_planes, branch_planes,
139 | kernel_size=3, padding=1, bias=False),
140 | )
141 | self.process3 = nn.Sequential(
142 | BatchNorm2d(branch_planes, momentum=bn_mom),
143 | nn.ReLU(inplace=True),
144 | nn.Conv2d(branch_planes, branch_planes,
145 | kernel_size=3, padding=1, bias=False),
146 | )
147 | self.process4 = nn.Sequential(
148 | BatchNorm2d(branch_planes, momentum=bn_mom),
149 | nn.ReLU(inplace=True),
150 | nn.Conv2d(branch_planes, branch_planes,
151 | kernel_size=3, padding=1, bias=False),
152 | )
153 | self.compression = nn.Sequential(
154 | BatchNorm2d(branch_planes * 5, momentum=bn_mom),
155 | nn.ReLU(inplace=True),
156 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
157 | )
158 | self.shortcut = nn.Sequential(
159 | BatchNorm2d(inplanes, momentum=bn_mom),
160 | nn.ReLU(inplace=True),
161 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
162 | )
163 |
164 | def forward(self, x):
165 |
166 | # x = self.downsample(x)
167 | width = x.shape[-1]
168 | height = x.shape[-2]
169 | x_list = []
170 |
171 | x_list.append(self.scale0(x))
172 | x_list.append(self.process1((F.interpolate(self.scale1(x),
173 | size=[height, width],
174 | mode='bilinear')+x_list[0])))
175 | x_list.append((self.process2((F.interpolate(self.scale2(x),
176 | size=[height, width],
177 | mode='bilinear')+x_list[1]))))
178 | x_list.append(self.process3((F.interpolate(self.scale3(x),
179 | size=[height, width],
180 | mode='bilinear')+x_list[2])))
181 | x_list.append(self.process4((F.interpolate(self.scale4(x),
182 | size=[height, width],
183 | mode='bilinear')+x_list[3])))
184 |
185 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
186 | return out
187 |
188 |
189 | class segmenthead(nn.Module):
190 |
191 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
192 | super(segmenthead, self).__init__()
193 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
194 | self.conv1 = nn.Conv2d(inplanes, interplanes,
195 | kernel_size=3, padding=1, bias=False)
196 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
197 | self.relu = nn.ReLU(inplace=True)
198 | self.conv2 = nn.Conv2d(interplanes, outplanes,
199 | kernel_size=1, padding=0, bias=True)
200 | self.scale_factor = scale_factor
201 |
202 | def forward(self, x):
203 |
204 | x = self.conv1(self.relu(self.bn1(x)))
205 | out = self.conv2(self.relu(self.bn2(x)))
206 |
207 | if self.scale_factor is not None:
208 | height = x.shape[-2] * self.scale_factor
209 | width = x.shape[-1] * self.scale_factor
210 | out = F.interpolate(out,
211 | size=[height, width],
212 | mode='bilinear')
213 |
214 | return out
215 |
216 |
217 | class DualResNet(nn.Module):
218 |
219 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False):
220 | super(DualResNet, self).__init__()
221 |
222 | highres_planes = planes * 2
223 | self.augment = augment
224 |
225 | self.conv1 = nn.Sequential(
226 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1),
227 | BatchNorm2d(planes, momentum=bn_mom),
228 | nn.ReLU(inplace=True),
229 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
230 | BatchNorm2d(planes, momentum=bn_mom),
231 | nn.ReLU(inplace=True),
232 | )
233 |
234 | self.relu = nn.ReLU(inplace=False)
235 | self.layer1 = self._make_layer(block, planes, planes, layers[0])
236 | self.layer2 = self._make_layer(
237 | block, planes, planes * 2, layers[1], stride=2)
238 | self.layer3 = self._make_layer(
239 | block, planes * 2, planes * 4, layers[2], stride=2)
240 | self.layer4 = self._make_layer(
241 | block, planes * 4, planes * 8, layers[3], stride=2)
242 |
243 | self.compression3 = nn.Sequential(
244 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False),
245 | BatchNorm2d(highres_planes, momentum=bn_mom),
246 | )
247 |
248 | self.compression4 = nn.Sequential(
249 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False),
250 | BatchNorm2d(highres_planes, momentum=bn_mom),
251 | )
252 |
253 | self.down3 = nn.Sequential(
254 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
255 | stride=2, padding=1, bias=False),
256 | BatchNorm2d(planes * 4, momentum=bn_mom),
257 | )
258 |
259 | self.down4 = nn.Sequential(
260 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
261 | stride=2, padding=1, bias=False),
262 | BatchNorm2d(planes * 4, momentum=bn_mom),
263 | nn.ReLU(inplace=True),
264 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3,
265 | stride=2, padding=1, bias=False),
266 | BatchNorm2d(planes * 8, momentum=bn_mom),
267 | )
268 |
269 | self.layer3_ = self._make_layer(block, planes * 2, highres_planes, 2)
270 |
271 | self.layer4_ = self._make_layer(
272 | block, highres_planes, highres_planes, 2)
273 |
274 | self.layer5_ = self._make_layer(
275 | Bottleneck, highres_planes, highres_planes, 1)
276 |
277 | self.layer5 = self._make_layer(
278 | Bottleneck, planes * 8, planes * 8, 1, stride=2)
279 |
280 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4)
281 |
282 | if self.augment:
283 | self.seghead_extra = segmenthead(
284 | highres_planes, head_planes, num_classes)
285 |
286 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
287 |
288 | for m in self.modules():
289 | if isinstance(m, nn.Conv2d):
290 | nn.init.kaiming_normal_(
291 | m.weight, mode='fan_out', nonlinearity='relu')
292 | elif isinstance(m, BatchNorm2d):
293 | nn.init.constant_(m.weight, 1)
294 | nn.init.constant_(m.bias, 0)
295 |
296 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
297 | downsample = None
298 | if stride != 1 or inplanes != planes * block.expansion:
299 | downsample = nn.Sequential(
300 | nn.Conv2d(inplanes, planes * block.expansion,
301 | kernel_size=1, stride=stride, bias=False),
302 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
303 | )
304 |
305 | layers = []
306 | layers.append(block(inplanes, planes, stride, downsample))
307 | inplanes = planes * block.expansion
308 | for i in range(1, blocks):
309 | if i == (blocks-1):
310 | layers.append(block(inplanes, planes, stride=1, no_relu=True))
311 | else:
312 | layers.append(block(inplanes, planes, stride=1, no_relu=False))
313 |
314 | return nn.Sequential(*layers)
315 |
316 | def forward(self, x):
317 | width_output_or = x.shape[-1]
318 | height_output_or = x.shape[-2]
319 |
320 | width_output = x.shape[-1] // 8
321 | height_output = x.shape[-2] // 8
322 | layers = []
323 |
324 | x = self.conv1(x)
325 |
326 | x = self.layer1(x)
327 | layers.append(x)
328 |
329 | x = self.layer2(self.relu(x))
330 | layers.append(x)
331 |
332 | x = self.layer3(self.relu(x))
333 |         x_i = x  # intermediate (I) features, collected for the auxiliary outputs below
334 | layers.append(x)
335 | x_ = self.layer3_(self.relu(layers[1]))
336 |
337 | x = x + self.down3(self.relu(x_))
338 | x_ = x_ + F.interpolate(
339 | self.compression3(self.relu(layers[2])),
340 | size=[height_output, width_output],
341 | mode='bilinear')
342 | if self.augment:
343 | temp = x_
344 |
345 | x = self.layer4(self.relu(x))
346 | layers.append(x)
347 | x_ = self.layer4_(self.relu(x_))
348 |
349 | x = x + self.down4(self.relu(x_))
350 | x_ = x_ + F.interpolate(
351 | self.compression4(self.relu(layers[3])),
352 | size=[height_output, width_output],
353 | mode='bilinear')
354 |
355 |         ### high-resolution 1/8 appearance (A) branch
356 |         x_ = self.layer5_(self.relu(x_))
357 |         x_a = x_
358 |
359 |         ### low-resolution 1/64 content (C) branch; upsampled back to 1/8
360 |         x = F.interpolate(
361 |             self.spp(self.layer5(self.relu(x))),
362 |             size=[height_output, width_output],
363 |             mode='bilinear')
364 |         x_c = x
365 |
366 | x_ = self.final_layer(x + x_)
367 |
368 | outputs = []
369 |         outputs_c = []  # content (C) branch predictions
370 |         outputs_a = []  # appearance (A) branch predictions
371 |         outputs_i = []  # intermediate (I) features
372 |
373 | x_ = F.interpolate(x_,
374 | size=[height_output_or, width_output_or],
375 | mode='bilinear', align_corners=True)
376 | outputs.append(x_)
377 |
378 |
379 | x_c = F.interpolate(x_c,
380 | size=[height_output_or, width_output_or],
381 | mode='bilinear', align_corners=True)
382 | outputs_c.append(x_c)
383 |
384 | x_a = F.interpolate(x_a,
385 | size=[height_output_or, width_output_or],
386 | mode='bilinear', align_corners=True)
387 | outputs_a.append(x_a)
388 |
389 | x_i = F.interpolate(x_i,
390 | size=[height_output_or, width_output_or],
391 | mode='bilinear', align_corners=True)
392 | outputs_i.append(x_i)
393 |
394 |         # Earlier experimental returns (kept for reference):
395 |         #   if self.augment:
396 |         #       x_extra = self.seghead_extra(temp)
397 |         #       return [x_, x_extra]
398 |         #   else:
399 |         #       return outputs, outputs_c, outputs_a, outputs_i
400 |         # The released variant returns only the segmentation logits; the
401 |         # auxiliary content (x_c), appearance (x_a), and intermediate (x_i)
402 |         # outputs above are computed but not returned.
403 |         return tuple(outputs)
404 |
405 |
406 |
407 | def DualResNet_imagenet(pretrained=False):
408 |     # the slim variant halves the width (planes=32) and uses a lighter
409 |     # segmentation head (head_planes=64) than the standard DDRNet-23
410 |     model = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19,
411 |                        planes=32, spp_planes=128, head_planes=64, augment=False)
412 |     if pretrained:
413 |         # TODO: replace this hardcoded checkpoint path with a user-provided one
414 | checkpoint = torch.load(
415 | "D:/DDR/models/DDRNet23s_imagenet.pth", map_location='cpu')
416 |
417 |         new_state_dict = OrderedDict()
418 |         model_dict = model.state_dict()
419 |
420 |         # keep only the checkpoint entries whose keys and shapes match the
421 |         # model, mirroring the filtered loading used by the other models here
422 |         for k, v in checkpoint.items():
423 |             if k in model_dict and v.shape == model_dict[k].shape:
424 |                 new_state_dict[k] = v
425 |         model_dict.update(new_state_dict)
426 |         model.load_state_dict(model_dict)
427 |
428 |         print("Loaded ImageNet-pretrained weights for DDRNet-23-slim.")
429 |
430 |
431 |     return model
432 |
433 |
434 | def get_ddrnet_23_slim(pretrained=True):
435 |
436 | model = DualResNet_imagenet(pretrained=pretrained)
437 | return model
438 |
439 |
440 | if __name__ == "__main__":
441 |
442 | model = DualResNet_imagenet(pretrained=True)
443 |
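444 |     # Rough size check (a sketch; assumes a CPU-only run): the slim variant
445 |     # is built with planes=32 and head_planes=64, so it should have far
446 |     # fewer parameters than the standard DDRNet-23 (planes=64).
447 |     n_params = sum(p.numel() for p in model.parameters())
448 |     print(f"DDRNet-23-slim parameters: {n_params / 1e6:.2f}M")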
--------------------------------------------------------------------------------
/models/DDRNet_23.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 | from collections import OrderedDict
8 |
9 | BatchNorm2d = nn.BatchNorm2d
10 | bn_mom = 0.1
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=1, bias=False)
17 |
18 |
19 | class BasicBlock(nn.Module):
20 | expansion = 1
21 |
22 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
23 | super(BasicBlock, self).__init__()
24 | self.conv1 = conv3x3(inplanes, planes, stride)
25 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
26 | self.relu = nn.ReLU(inplace=True)
27 | self.conv2 = conv3x3(planes, planes)
28 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
29 | self.downsample = downsample
30 | self.stride = stride
31 | self.no_relu = no_relu
32 |
33 | def forward(self, x):
34 | residual = x
35 |
36 | out = self.conv1(x)
37 | out = self.bn1(out)
38 | out = self.relu(out)
39 |
40 | out = self.conv2(out)
41 | out = self.bn2(out)
42 |
43 | if self.downsample is not None:
44 | residual = self.downsample(x)
45 |
46 | out += residual
47 |
48 | if self.no_relu:
49 | return out
50 | else:
51 | return self.relu(out)
52 |
53 |
54 | class Bottleneck(nn.Module):
55 | expansion = 2
56 |
57 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
58 | super(Bottleneck, self).__init__()
59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
60 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
62 | padding=1, bias=False)
63 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
64 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
65 | bias=False)
66 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom)
67 | self.relu = nn.ReLU(inplace=True)
68 | self.downsample = downsample
69 | self.stride = stride
70 | self.no_relu = no_relu
71 |
72 | def forward(self, x):
73 | residual = x
74 |
75 | out = self.conv1(x)
76 | out = self.bn1(out)
77 | out = self.relu(out)
78 |
79 | out = self.conv2(out)
80 | out = self.bn2(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv3(out)
84 | out = self.bn3(out)
85 |
86 | if self.downsample is not None:
87 | residual = self.downsample(x)
88 |
89 | out += residual
90 | if self.no_relu:
91 | return out
92 | else:
93 | return self.relu(out)
94 |
95 |
96 | class DAPPM(nn.Module):
97 | def __init__(self, inplanes, branch_planes, outplanes):
98 | super(DAPPM, self).__init__()
99 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
100 | BatchNorm2d(inplanes, momentum=bn_mom),
101 | nn.ReLU(inplace=True),
102 | nn.Conv2d(inplanes, branch_planes,
103 | kernel_size=1, bias=False),
104 | )
105 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
106 | BatchNorm2d(inplanes, momentum=bn_mom),
107 | nn.ReLU(inplace=True),
108 | nn.Conv2d(inplanes, branch_planes,
109 | kernel_size=1, bias=False),
110 | )
111 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
112 | BatchNorm2d(inplanes, momentum=bn_mom),
113 | nn.ReLU(inplace=True),
114 | nn.Conv2d(inplanes, branch_planes,
115 | kernel_size=1, bias=False),
116 | )
117 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
118 | BatchNorm2d(inplanes, momentum=bn_mom),
119 | nn.ReLU(inplace=True),
120 | nn.Conv2d(inplanes, branch_planes,
121 | kernel_size=1, bias=False),
122 | )
123 | self.scale0 = nn.Sequential(
124 | BatchNorm2d(inplanes, momentum=bn_mom),
125 | nn.ReLU(inplace=True),
126 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
127 | )
128 | self.process1 = nn.Sequential(
129 | BatchNorm2d(branch_planes, momentum=bn_mom),
130 | nn.ReLU(inplace=True),
131 | nn.Conv2d(branch_planes, branch_planes,
132 | kernel_size=3, padding=1, bias=False),
133 | )
134 | self.process2 = nn.Sequential(
135 | BatchNorm2d(branch_planes, momentum=bn_mom),
136 | nn.ReLU(inplace=True),
137 | nn.Conv2d(branch_planes, branch_planes,
138 | kernel_size=3, padding=1, bias=False),
139 | )
140 | self.process3 = nn.Sequential(
141 | BatchNorm2d(branch_planes, momentum=bn_mom),
142 | nn.ReLU(inplace=True),
143 | nn.Conv2d(branch_planes, branch_planes,
144 | kernel_size=3, padding=1, bias=False),
145 | )
146 | self.process4 = nn.Sequential(
147 | BatchNorm2d(branch_planes, momentum=bn_mom),
148 | nn.ReLU(inplace=True),
149 | nn.Conv2d(branch_planes, branch_planes,
150 | kernel_size=3, padding=1, bias=False),
151 | )
152 | self.compression = nn.Sequential(
153 | BatchNorm2d(branch_planes * 5, momentum=bn_mom),
154 | nn.ReLU(inplace=True),
155 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
156 | )
157 | self.shortcut = nn.Sequential(
158 | BatchNorm2d(inplanes, momentum=bn_mom),
159 | nn.ReLU(inplace=True),
160 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
161 | )
162 |
163 | def forward(self, x):
164 |
165 | # x = self.downsample(x)
166 | width = x.shape[-1]
167 | height = x.shape[-2]
168 | x_list = []
169 |
170 | x_list.append(self.scale0(x))
171 | x_list.append(self.process1((F.interpolate(self.scale1(x),
172 | size=[height, width],
173 | mode='bilinear')+x_list[0])))
174 | x_list.append((self.process2((F.interpolate(self.scale2(x),
175 | size=[height, width],
176 | mode='bilinear')+x_list[1]))))
177 | x_list.append(self.process3((F.interpolate(self.scale3(x),
178 | size=[height, width],
179 | mode='bilinear')+x_list[2])))
180 | x_list.append(self.process4((F.interpolate(self.scale4(x),
181 | size=[height, width],
182 | mode='bilinear')+x_list[3])))
183 |
184 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
185 | return out
186 |
187 |
188 | class segmenthead(nn.Module):
189 |
190 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
191 | super(segmenthead, self).__init__()
192 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
193 | self.conv1 = nn.Conv2d(inplanes, interplanes,
194 | kernel_size=3, padding=1, bias=False)
195 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.conv2 = nn.Conv2d(interplanes, outplanes,
198 | kernel_size=1, padding=0, bias=True)
199 | self.scale_factor = scale_factor
200 |
201 |         ### C (content) branch: feature refinement and attention weight map
202 | self.conv3 = nn.Sequential(
203 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
204 | BatchNorm2d(1, momentum=bn_mom),
205 | nn.ReLU(inplace=True),
206 | )
207 |
208 |         ### A (appearance) branch: feature refinement and attention weight map
209 | self.conv4 = nn.Sequential(
210 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
211 | BatchNorm2d(1, momentum=bn_mom),
212 | nn.ReLU(inplace=True),
213 | )
214 |
215 | self.conv5 = nn.Sequential(
216 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0),
217 | BatchNorm2d(19, momentum=bn_mom),
218 | nn.ReLU(inplace=True),
219 | )
220 |
221 | self.conv6 = nn.Sequential(
222 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0),
223 | BatchNorm2d(19, momentum=bn_mom),
224 | nn.ReLU(inplace=True),
225 | )
226 |
227 | def forward(self, C, A):
228 |
229 |         attention_c = self.conv3(C)  # per-pixel attention map from C
230 |         attention_a = self.conv4(A)  # per-pixel attention map from A
231 |
232 |         # cross-attended fusion: C also receives attention-weighted A
233 |         C_ = 2 * C + attention_c * C + A * attention_c
234 |         A_ = A + attention_a * A
235 |
236 |         x = self.conv1(self.relu(self.bn1(C_ + A_)))
237 |         out = self.conv2(self.relu(self.bn2(x)))
238 |         Cupdate = self.conv5(C)    # 19-channel projection of raw C
239 |         C_update = self.conv6(C_)  # 19-channel projection of fused C_
240 |
241 | if self.scale_factor is not None:
242 | height = x.shape[-2] * self.scale_factor
243 | width = x.shape[-1] * self.scale_factor
244 | out = F.interpolate(out,
245 | size=[height, width],
246 | mode='bilinear')
247 |
248 | return out, C_update, Cupdate
249 |
250 |
251 | class segmentheadold(nn.Module):
252 |
253 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
254 |         super(segmentheadold, self).__init__()
255 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
256 | self.conv1 = nn.Conv2d(inplanes, interplanes,
257 | kernel_size=3, padding=1, bias=False)
258 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
259 | self.relu = nn.ReLU(inplace=True)
260 | self.conv2 = nn.Conv2d(interplanes, outplanes,
261 | kernel_size=1, padding=0, bias=True)
262 | self.scale_factor = scale_factor
263 |
264 | def forward(self, x):
265 |
266 | x = self.conv1(self.relu(self.bn1(x)))
267 | out = self.conv2(self.relu(self.bn2(x)))
268 |
269 | if self.scale_factor is not None:
270 | height = x.shape[-2] * self.scale_factor
271 | width = x.shape[-1] * self.scale_factor
272 | out = F.interpolate(out,
273 | size=[height, width],
274 | mode='bilinear')
275 |
276 | return out
277 |
278 |
279 | class DualResNet(nn.Module):
280 |
281 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False):
282 | super(DualResNet, self).__init__()
283 |
284 | highres_planes = planes * 2
285 | self.augment = augment
286 |
287 | self.conv1 = nn.Sequential(
288 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1),
289 | BatchNorm2d(planes, momentum=bn_mom),
290 | nn.ReLU(inplace=True),
291 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
292 | BatchNorm2d(planes, momentum=bn_mom),
293 | nn.ReLU(inplace=True),
294 | )
295 |
296 | self.relu = nn.ReLU(inplace=False)
297 | self.layer1 = self._make_layer(block, planes, planes, layers[0])
298 | self.layer2 = self._make_layer(
299 | block, planes, planes * 2, layers[1], stride=2)
300 | self.layer3 = self._make_layer(
301 | block, planes * 2, planes * 4, layers[2], stride=2)
302 | self.layer4 = self._make_layer(
303 | block, planes * 4, planes * 8, layers[3], stride=2)
304 |
305 | self.compression3 = nn.Sequential(
306 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False),
307 | BatchNorm2d(highres_planes, momentum=bn_mom),
308 | )
309 |
310 | self.compression4 = nn.Sequential(
311 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False),
312 | BatchNorm2d(highres_planes, momentum=bn_mom),
313 | )
314 |
315 | self.down3 = nn.Sequential(
316 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
317 | stride=2, padding=1, bias=False),
318 | BatchNorm2d(planes * 4, momentum=bn_mom),
319 | )
320 |
321 | self.down4 = nn.Sequential(
322 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
323 | stride=2, padding=1, bias=False),
324 | BatchNorm2d(planes * 4, momentum=bn_mom),
325 | nn.ReLU(inplace=True),
326 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3,
327 | stride=2, padding=1, bias=False),
328 | BatchNorm2d(planes * 8, momentum=bn_mom),
329 | )
330 |
331 | self.layer3_ = self._make_layer(block, planes * 2, highres_planes, 2)
332 |
333 | self.layer4_ = self._make_layer(
334 | block, highres_planes, highres_planes, 2)
335 |
336 | self.layer5_ = self._make_layer(
337 | Bottleneck, highres_planes, highres_planes, 1)
338 |
339 | self.layer5 = self._make_layer(
340 | Bottleneck, planes * 8, planes * 8, 1, stride=2)
341 |
342 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4)
343 |
344 | if self.augment:
345 | self.seghead_extra = segmenthead(
346 | highres_planes, head_planes, num_classes)
347 |
348 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
349 |
350 | for m in self.modules():
351 | if isinstance(m, nn.Conv2d):
352 | nn.init.kaiming_normal_(
353 | m.weight, mode='fan_out', nonlinearity='relu')
354 | elif isinstance(m, BatchNorm2d):
355 | nn.init.constant_(m.weight, 1)
356 | nn.init.constant_(m.bias, 0)
357 |
358 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
359 | downsample = None
360 | if stride != 1 or inplanes != planes * block.expansion:
361 | downsample = nn.Sequential(
362 | nn.Conv2d(inplanes, planes * block.expansion,
363 | kernel_size=1, stride=stride, bias=False),
364 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
365 | )
366 |
367 | layers = []
368 | layers.append(block(inplanes, planes, stride, downsample))
369 | inplanes = planes * block.expansion
370 | for i in range(1, blocks):
371 | if i == (blocks-1):
372 | layers.append(block(inplanes, planes, stride=1, no_relu=True))
373 | else:
374 | layers.append(block(inplanes, planes, stride=1, no_relu=False))
375 |
376 | return nn.Sequential(*layers)
377 |
378 | def forward(self, x):
379 |
380 | width_output_or = x.shape[-1]
381 | height_output_or = x.shape[-2]
382 |
383 | width_output = x.shape[-1] // 8
384 | height_output = x.shape[-2] // 8
385 | layers = []
386 |
387 | x = self.conv1(x)
388 |
389 | x = self.layer1(x)
390 | layers.append(x)
391 |
392 | x = self.layer2(self.relu(x))
393 | layers.append(x)
394 |
395 | x = self.layer3(self.relu(x))
396 | layers.append(x)
397 |
398 | x_ = self.layer3_(self.relu(layers[1])) #### [4,128,128,128]
399 |
400 | x = x + self.down3(self.relu(x_)) ###[4,256,64,64]
401 |
402 | x_ = x_ + F.interpolate(
403 | self.compression3(self.relu(layers[2])),
404 | size=[height_output, width_output],
405 | mode='bilinear') ###[4,128,64,64]
406 |
407 | if self.augment:
408 | temp = x_
409 |
410 | x = self.layer4(self.relu(x))
411 | layers.append(x)
412 | x_ = self.layer4_(self.relu(x_))
413 |
414 | x = x + self.down4(self.relu(x_))
415 | x_ = x_ + F.interpolate(
416 | self.compression4(self.relu(layers[3])),
417 | size=[height_output, width_output],
418 | mode='bilinear') ###[4,128,128,128]
419 |
420 |         ### high-resolution 1/8 appearance (A) branch
421 | before_E = x_
422 | x_ = self.layer5_(self.relu(x_))
423 | x_a = x_ ###[4,256,128,128]
424 | after_E = x_
425 |
426 |         ### low-resolution 1/64 content (C) branch; upsampled back to 1/8
427 | x = self.layer5(self.relu(x))
428 | before_I = F.interpolate(
429 | x,
430 | size=[height_output, width_output],
431 | mode='bilinear')
432 | x = F.interpolate(
433 | self.spp(x),
434 | size=[height_output, width_output],
435 | mode='bilinear')
436 | x_c = x ###[4,256,128,128]
437 | after_I = x
438 |
439 |         x_, C, C_ = self.final_layer(x, x_)  ### seghead: logits [4,19,128,128] plus two 19-channel C projections
440 |
441 | outputs = []
442 |
443 | x_ = F.interpolate(x_,
444 | size=[height_output_or, width_output_or],
445 | mode='bilinear', align_corners=True) #[4,19,1024,1024]
446 |
447 | outputs.append(x_)
448 |
449 | if self.augment:
450 | x_extra = self.seghead_extra(temp)
451 | return [x_, x_extra]
452 | else:
453 | return tuple(outputs), C_, C
454 |
455 |
456 | def DualResNet_imagenet(pretrained=True):
457 | model = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19,
458 | planes=64, spp_planes=128, head_planes=128, augment=False)
459 | if pretrained:
460 | pretrained_state = torch.load(
461 | "D:/DDR/models/DDRNet23_imagenet.pth", map_location='cpu')
462 | model_dict = model.state_dict()
463 | pretrained_state = {k: v for k, v in pretrained_state.items() if
464 | (k in model_dict and v.shape == model_dict[k].shape)}
465 | model_dict.update(pretrained_state)
466 |
467 | model.load_state_dict(model_dict, strict=False)
468 | print("Having loaded imagenet-pretrained weights successfully!")
469 |
470 | return model
471 |
472 |
473 | def get_ddrnet_23(pretrained=True):
474 |
475 | model = DualResNet_imagenet(pretrained=pretrained)
476 | return model
477 |
478 |
479 | ### C-learner & A-learner: later part
480 |
481 | class CAinteract(nn.Module):
482 |
483 | def __init__(self, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False):
484 | super(CAinteract, self).__init__()
485 |
486 | self.augment = augment
487 | highres_planes = planes * 2
488 |
489 |         ### C (content) branch: feature refinement and attention weight map
490 | self.conv1 = nn.Sequential(
491 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0),
492 | BatchNorm2d(256, momentum=bn_mom),
493 | nn.ReLU(inplace=True),
494 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
495 | BatchNorm2d(1, momentum=bn_mom),
496 | nn.ReLU(inplace=True),
497 | )
498 |
499 |         ### A (appearance) branch: feature refinement and attention weight map
500 | self.conv2 = nn.Sequential(
501 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0),
502 | BatchNorm2d(256, momentum=bn_mom),
503 | nn.ReLU(inplace=True),
504 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
505 | BatchNorm2d(1, momentum=bn_mom),
506 | nn.ReLU(inplace=True),
507 | )
508 |
509 | self.relu = nn.ReLU(inplace=False)
510 |
511 | if self.augment:
512 | self.seghead_extra = segmenthead(
513 | highres_planes, head_planes, num_classes)
514 |
515 |         self.final_layer = segmentheadold(planes * 4, head_planes, num_classes)  # single-input head; segmenthead here expects (C, A)
516 |
517 |
518 | for m in self.modules():
519 | if isinstance(m, nn.Conv2d):
520 | nn.init.kaiming_normal_(
521 | m.weight, mode='fan_out', nonlinearity='relu')
522 | elif isinstance(m, BatchNorm2d):
523 | nn.init.constant_(m.weight, 1)
524 | nn.init.constant_(m.bias, 0)
525 |
526 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
527 | downsample = None
528 | if stride != 1 or inplanes != planes * block.expansion:
529 | downsample = nn.Sequential(
530 | nn.Conv2d(inplanes, planes * block.expansion,
531 | kernel_size=1, stride=stride, bias=False),
532 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
533 | )
534 |
535 | layers = []
536 | layers.append(block(inplanes, planes, stride, downsample))
537 | inplanes = planes * block.expansion
538 | for i in range(1, blocks):
539 | if i == (blocks-1):
540 | layers.append(block(inplanes, planes, stride=1, no_relu=True))
541 | else:
542 | layers.append(block(inplanes, planes, stride=1, no_relu=False))
543 |
544 | return nn.Sequential(*layers)
545 |
546 | def forward(self, X_, C, A):
547 |
548 |         # output resolution is hardcoded; assumes 1024x1024 inputs
549 |         height_output = 1024
550 |         width_output = 1024
551 |
552 |         attention_c = self.conv1(C)
553 |         attention_a = self.conv2(A)
554 |
555 |         C = C * attention_c + C
556 |         A = A * attention_a + A
557 |
558 |         x_ = self.final_layer(2 * X_ + C + A)  ### seghead [4,19,128,128]
559 |
560 | outputs = []
561 |
562 | x_ = F.interpolate(x_,
563 | size=[height_output, width_output],
564 | mode='bilinear', align_corners=True) #[4,19,1024,1024]
565 |
566 | outputs.append(x_)
567 |
568 | return tuple(outputs), C, A
569 |
570 |
571 | ### C-A merge module
572 |
573 | class CAmerge(nn.Module):
574 |
575 | def __init__(self, planes=64, augment=False):
576 | super(CAmerge, self).__init__()
577 |
578 | highres_planes = planes * 2
579 | self.augment = augment
580 |
581 |         #### stage 1: upsample 1/8 -> 1/4, channels 256 -> 64
582 |
583 | self.conv1 = nn.Sequential(
584 | # nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=1),
585 | # BatchNorm2d(128, momentum=bn_mom),
586 | # nn.ReLU(inplace=True),
587 | nn.Conv2d(256, planes, kernel_size=3, stride=1, padding=1),
588 | BatchNorm2d(planes, momentum=bn_mom),
589 | nn.ReLU(inplace=True),
590 | )
591 |
592 |         ##### stage 2: upsample 1/4 -> 1/2, channels 64 -> 32
593 | self.conv2 = nn.Sequential(
594 | nn.Conv2d(planes, 32, kernel_size=3, stride=1, padding=1),
595 | BatchNorm2d(32, momentum=bn_mom),
596 | nn.ReLU(inplace=True),
597 | )
598 |
599 |         ##### stage 3: upsample 1/2 -> original resolution, channels 32 -> 3
600 | self.conv3 = nn.Sequential(
601 | nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
602 | BatchNorm2d(3, momentum=bn_mom),
603 | nn.ReLU(inplace=True),
604 | )
605 |
606 | self.relu = nn.ReLU(inplace=False)
607 |
608 | # if self.augment:
609 | # self.seghead_extra = segmenthead(
610 | # highres_planes, head_planes, num_classes)
611 |
612 | # self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
613 |
614 | for m in self.modules():
615 | if isinstance(m, nn.Conv2d):
616 | nn.init.kaiming_normal_(
617 | m.weight, mode='fan_out', nonlinearity='relu')
618 | elif isinstance(m, BatchNorm2d):
619 | nn.init.constant_(m.weight, 1)
620 | nn.init.constant_(m.bias, 0)
621 |
622 | def forward(self, C, A):
623 |
624 | layers = []
625 |
626 |         ### C and A are 1/8-resolution inputs to the reconstruction  ###[4,256,128,128] -> [4,64,128,128] -> [4,64,256,256]
627 | x = self.conv1(C+A)
628 | layers.append(x)
629 | ## upsample 128->256
630 | x = F.interpolate(
631 | self.relu(layers[0]),
632 | size=[256,256],
633 | mode='bilinear')
634 |
635 | ###[4,64,256,256] -> [4,32,256,256] -> [4,32,512,512]
636 | x = self.conv2(self.relu(x))
637 |         # if x_41 had more than 64 channels (planes), an extra compression layer down to 64 channels would be needed first
638 | layers.append(x)
639 | ## upsample 256->512
640 | x = F.interpolate(
641 | self.relu(layers[1]),
642 | size=[512,512],
643 | mode='bilinear')
644 |
645 | ###[4,32,512,512] -> [4,3,512,512] -> [4,3,1024,1024]
646 |
647 | x = self.conv3(self.relu(x))
648 | layers.append(x)
649 |
650 | x = F.interpolate(
651 | self.relu(layers[2]),
652 | size=[1024, 1024],
653 |             mode='bilinear')  ### [4,3,1024,1024]
654 |
655 |         if self.augment:
656 |             temp = x
657 |
658 |         outputs = []
659 |         outputs.append(x)
660 |
661 |         if self.augment:
662 |             # note: self.seghead_extra is commented out in __init__, so this
663 |             # branch needs it re-enabled there before use
664 |             x_extra = self.seghead_extra(temp)
665 |             return [x, x_extra]
666 |         return tuple(outputs), x
667 |
668 | def get_CA_interact():
669 |     model = CAinteract()
670 |     return model
671 |
672 | def get_CA_merge():
673 |     model = CAmerge()
674 |     return model
675 |
676 | if __name__ == "__main__":
677 |
678 | model = DualResNet_imagenet(pretrained=True)
679 |
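680 |     # Illustrative sketches with dummy inputs; the [1, 256, 128, 128] shapes
681 |     # are assumptions taken from the 1/8-resolution feature comments above
682 |     # (###[4,256,128,128]), not values the training scripts guarantee.
683 |     merge = get_CA_merge()
684 |     merge.eval()
685 |     with torch.no_grad():
686 |         C = torch.randn(1, 256, 128, 128)  # dummy content (C) features
687 |         A = torch.randn(1, 256, 128, 128)  # dummy appearance (A) features
688 |         (recon,), _ = merge(C, A)
689 |     print(recon.shape)  # expected: torch.Size([1, 3, 1024, 1024])
690 |
691 |     # The C-A interaction head (assumes the single-input segmentheadold used
692 |     # as CAinteract.final_layer above); X_ is a dummy fused feature map.
693 |     inter = get_CA_interact()
694 |     inter.eval()
695 |     with torch.no_grad():
696 |         X_ = torch.randn(1, 256, 128, 128)
697 |         (seg,), _, _ = inter(X_, C, A)
698 |     print(seg.shape)  # expected: torch.Size([1, 19, 1024, 1024])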
--------------------------------------------------------------------------------
/models/DDRNet_23_vis1.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import init
7 | from collections import OrderedDict
8 | import random
9 |
10 | BatchNorm2d = nn.BatchNorm2d
11 | bn_mom = 0.1
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | expansion = 1
22 |
23 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False):
24 | super(BasicBlock, self).__init__()
25 | self.conv1 = conv3x3(inplanes, planes, stride)
26 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.conv2 = conv3x3(planes, planes)
29 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
30 | self.downsample = downsample
31 | self.stride = stride
32 | self.no_relu = no_relu
33 |
34 | def forward(self, x):
35 | residual = x
36 |
37 | out = self.conv1(x)
38 | out = self.bn1(out)
39 | out = self.relu(out)
40 |
41 | out = self.conv2(out)
42 | out = self.bn2(out)
43 |
44 | if self.downsample is not None:
45 | residual = self.downsample(x)
46 |
47 | out += residual
48 |
49 | if self.no_relu:
50 | return out
51 | else:
52 | return self.relu(out)
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | expansion = 2
57 |
58 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True):
59 | super(Bottleneck, self).__init__()
60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom)
62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63 | padding=1, bias=False)
64 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom)
65 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
66 | bias=False)
67 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom)
68 | self.relu = nn.ReLU(inplace=True)
69 | self.downsample = downsample
70 | self.stride = stride
71 | self.no_relu = no_relu
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | if self.no_relu:
92 | return out
93 | else:
94 | return self.relu(out)
95 |
96 |
97 | class DAPPM(nn.Module):
98 | def __init__(self, inplanes, branch_planes, outplanes):
99 | super(DAPPM, self).__init__()
100 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2),
101 | BatchNorm2d(inplanes, momentum=bn_mom),
102 | nn.ReLU(inplace=True),
103 | nn.Conv2d(inplanes, branch_planes,
104 | kernel_size=1, bias=False),
105 | )
106 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4),
107 | BatchNorm2d(inplanes, momentum=bn_mom),
108 | nn.ReLU(inplace=True),
109 | nn.Conv2d(inplanes, branch_planes,
110 | kernel_size=1, bias=False),
111 | )
112 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8),
113 | BatchNorm2d(inplanes, momentum=bn_mom),
114 | nn.ReLU(inplace=True),
115 | nn.Conv2d(inplanes, branch_planes,
116 | kernel_size=1, bias=False),
117 | )
118 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
119 | BatchNorm2d(inplanes, momentum=bn_mom),
120 | nn.ReLU(inplace=True),
121 | nn.Conv2d(inplanes, branch_planes,
122 | kernel_size=1, bias=False),
123 | )
124 | self.scale0 = nn.Sequential(
125 | BatchNorm2d(inplanes, momentum=bn_mom),
126 | nn.ReLU(inplace=True),
127 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False),
128 | )
129 | self.process1 = nn.Sequential(
130 | BatchNorm2d(branch_planes, momentum=bn_mom),
131 | nn.ReLU(inplace=True),
132 | nn.Conv2d(branch_planes, branch_planes,
133 | kernel_size=3, padding=1, bias=False),
134 | )
135 | self.process2 = nn.Sequential(
136 | BatchNorm2d(branch_planes, momentum=bn_mom),
137 | nn.ReLU(inplace=True),
138 | nn.Conv2d(branch_planes, branch_planes,
139 | kernel_size=3, padding=1, bias=False),
140 | )
141 | self.process3 = nn.Sequential(
142 | BatchNorm2d(branch_planes, momentum=bn_mom),
143 | nn.ReLU(inplace=True),
144 | nn.Conv2d(branch_planes, branch_planes,
145 | kernel_size=3, padding=1, bias=False),
146 | )
147 | self.process4 = nn.Sequential(
148 | BatchNorm2d(branch_planes, momentum=bn_mom),
149 | nn.ReLU(inplace=True),
150 | nn.Conv2d(branch_planes, branch_planes,
151 | kernel_size=3, padding=1, bias=False),
152 | )
153 | self.compression = nn.Sequential(
154 | BatchNorm2d(branch_planes * 5, momentum=bn_mom),
155 | nn.ReLU(inplace=True),
156 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False),
157 | )
158 | self.shortcut = nn.Sequential(
159 | BatchNorm2d(inplanes, momentum=bn_mom),
160 | nn.ReLU(inplace=True),
161 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False),
162 | )
163 |
164 | def forward(self, x):
165 |
166 | # x = self.downsample(x)
167 | width = x.shape[-1]
168 | height = x.shape[-2]
169 | x_list = []
170 |
171 | x_list.append(self.scale0(x))
172 | x_list.append(self.process1((F.interpolate(self.scale1(x),
173 | size=[height, width],
174 | mode='bilinear')+x_list[0])))
175 | x_list.append((self.process2((F.interpolate(self.scale2(x),
176 | size=[height, width],
177 | mode='bilinear')+x_list[1]))))
178 | x_list.append(self.process3((F.interpolate(self.scale3(x),
179 | size=[height, width],
180 | mode='bilinear')+x_list[2])))
181 | x_list.append(self.process4((F.interpolate(self.scale4(x),
182 | size=[height, width],
183 | mode='bilinear')+x_list[3])))
184 |
185 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x)
186 | return out
187 |
188 |
189 | class segmenthead(nn.Module):
190 |
191 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
192 | super(segmenthead, self).__init__()
193 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
194 | self.conv1 = nn.Conv2d(inplanes, interplanes,
195 | kernel_size=3, padding=1, bias=False)
196 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
197 | self.relu = nn.ReLU(inplace=True)
198 | self.conv2 = nn.Conv2d(interplanes, outplanes,
199 | kernel_size=1, padding=0, bias=True)
200 | self.scale_factor = scale_factor
201 |
202 |         ### C (content) branch: feature refinement and attention weight map
203 | self.conv3 = nn.Sequential(
204 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
205 | BatchNorm2d(1, momentum=bn_mom),
206 | nn.ReLU(inplace=True),
207 | )
208 |
209 |         ### A (appearance) branch: feature refinement and attention weight map
210 | self.conv4 = nn.Sequential(
211 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
212 | BatchNorm2d(1, momentum=bn_mom),
213 | nn.ReLU(inplace=True),
214 | )
215 |
216 | self.conv5 = nn.Sequential(
217 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0),
218 | BatchNorm2d(19, momentum=bn_mom),
219 | nn.ReLU(inplace=True),
220 | )
221 |
222 | self.conv6 = nn.Sequential(
223 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0),
224 | BatchNorm2d(19, momentum=bn_mom),
225 | nn.ReLU(inplace=True),
226 | )
227 |
228 | def forward(self, C, A):
229 |
230 |         attention_c = self.conv3(C)  # per-pixel attention map from C
231 |         attention_a = self.conv4(A)  # per-pixel attention map from A
232 |
233 |         # cross-attended fusion: C also receives attention-weighted A
234 |         C_ = 2 * C + attention_c * C + A * attention_c
235 |         A_ = A + attention_a * A
236 |
237 |         x = self.conv1(self.relu(self.bn1(C_ + A_)))
238 |         out = self.conv2(self.relu(self.bn2(x)))
239 |         Cupdate = self.conv5(C)    # 19-channel projection of raw C
240 |         C_update = self.conv6(C_)  # 19-channel projection of fused C_
241 |
242 | if self.scale_factor is not None:
243 | height = x.shape[-2] * self.scale_factor
244 | width = x.shape[-1] * self.scale_factor
245 | out = F.interpolate(out,
246 | size=[height, width],
247 | mode='bilinear')
248 |
249 | return out, C_update, Cupdate
250 |
251 |
252 | class segmentheadold(nn.Module):
253 |
254 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None):
255 |         super(segmentheadold, self).__init__()
256 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom)
257 | self.conv1 = nn.Conv2d(inplanes, interplanes,
258 | kernel_size=3, padding=1, bias=False)
259 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom)
260 | self.relu = nn.ReLU(inplace=True)
261 | self.conv2 = nn.Conv2d(interplanes, outplanes,
262 | kernel_size=1, padding=0, bias=True)
263 | self.scale_factor = scale_factor
264 |
265 | def forward(self, x):
266 |
267 | x = self.conv1(self.relu(self.bn1(x)))
268 | out = self.conv2(self.relu(self.bn2(x)))
269 |
270 | if self.scale_factor is not None:
271 | height = x.shape[-2] * self.scale_factor
272 | width = x.shape[-1] * self.scale_factor
273 | out = F.interpolate(out,
274 | size=[height, width],
275 | mode='bilinear')
276 |
277 | return out
278 |
279 |
280 | class DualResNet(nn.Module):
281 |
282 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False):
283 | super(DualResNet, self).__init__()
284 |
285 | highres_planes = planes * 2
286 | self.augment = augment
287 |
288 | self.conv1 = nn.Sequential(
289 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1),
290 | BatchNorm2d(planes, momentum=bn_mom),
291 | nn.ReLU(inplace=True),
292 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1),
293 | BatchNorm2d(planes, momentum=bn_mom),
294 | nn.ReLU(inplace=True),
295 | )
296 |
297 | self.relu = nn.ReLU(inplace=False)
298 | self.layer1 = self._make_layer(block, planes, planes, layers[0])
299 | self.layer2 = self._make_layer(
300 | block, planes, planes * 2, layers[1], stride=2)
301 | self.layer3 = self._make_layer(
302 | block, planes * 2, planes * 4, layers[2], stride=2)
303 | self.layer4 = self._make_layer(
304 | block, planes * 4, planes * 8, layers[3], stride=2)
305 |
306 | self.compression3 = nn.Sequential(
307 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False),
308 | BatchNorm2d(highres_planes, momentum=bn_mom),
309 | )
310 |
311 | self.compression4 = nn.Sequential(
312 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False),
313 | BatchNorm2d(highres_planes, momentum=bn_mom),
314 | )
315 |
316 | self.down3 = nn.Sequential(
317 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
318 | stride=2, padding=1, bias=False),
319 | BatchNorm2d(planes * 4, momentum=bn_mom),
320 | )
321 |
322 | self.down4 = nn.Sequential(
323 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3,
324 | stride=2, padding=1, bias=False),
325 | BatchNorm2d(planes * 4, momentum=bn_mom),
326 | nn.ReLU(inplace=True),
327 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3,
328 | stride=2, padding=1, bias=False),
329 | BatchNorm2d(planes * 8, momentum=bn_mom),
330 | )
331 |
332 | self.layer3_ = self._make_layer(block, planes * 2, highres_planes, 2)
333 |
334 | self.layer4_ = self._make_layer(
335 | block, highres_planes, highres_planes, 2)
336 |
337 | self.layer5_ = self._make_layer(
338 | Bottleneck, highres_planes, highres_planes, 1)
339 |
340 | self.layer5 = self._make_layer(
341 | Bottleneck, planes * 8, planes * 8, 1, stride=2)
342 |
343 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4)
344 |
345 | if self.augment:
346 | self.seghead_extra = segmenthead(
347 | highres_planes, head_planes, num_classes)
348 |
349 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
350 |
351 | for m in self.modules():
352 | if isinstance(m, nn.Conv2d):
353 | nn.init.kaiming_normal_(
354 | m.weight, mode='fan_out', nonlinearity='relu')
355 | elif isinstance(m, BatchNorm2d):
356 | nn.init.constant_(m.weight, 1)
357 | nn.init.constant_(m.bias, 0)
358 |
359 | def _make_layer(self, block, inplanes, planes, blocks, stride=1):
360 | downsample = None
361 | if stride != 1 or inplanes != planes * block.expansion:
362 | downsample = nn.Sequential(
363 | nn.Conv2d(inplanes, planes * block.expansion,
364 | kernel_size=1, stride=stride, bias=False),
365 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
366 | )
367 |
368 | layers = []
369 | layers.append(block(inplanes, planes, stride, downsample))
370 | inplanes = planes * block.expansion
371 | for i in range(1, blocks):
372 | if i == (blocks-1):
373 | layers.append(block(inplanes, planes, stride=1, no_relu=True))
374 | else:
375 | layers.append(block(inplanes, planes, stride=1, no_relu=False))
376 |
377 | return nn.Sequential(*layers)
378 |
379 | def forward(self, x):
380 |
381 | width_output_or = x.shape[-1]
382 | height_output_or = x.shape[-2]
383 |
384 | width_output = x.shape[-1] // 8
385 | height_output = x.shape[-2] // 8
386 | layers = []
387 |
388 | x = self.conv1(x)
389 |
390 | x = self.layer1(x)
391 | layers.append(x)
392 |
393 | x = self.layer2(self.relu(x))
394 | layers.append(x)
395 |
396 | x = self.layer3(self.relu(x))
397 | layers.append(x)
398 |
399 | x_ = self.layer3_(self.relu(layers[1])) #### [4,128,128,128]
400 |
401 | x = x + self.down3(self.relu(x_)) ###[4,256,64,64]
402 |
403 | x_ = x_ + F.interpolate(
404 | self.compression3(self.relu(layers[2])),
405 | size=[height_output, width_output],
406 | mode='bilinear') ###[4,128,64,64]
407 |
408 | if self.augment:
409 | temp = x_
410 |
411 | x = self.layer4(self.relu(x))
412 | layers.append(x)
413 | x_ = self.layer4_(self.relu(x_))
414 |
415 | x = x + self.down4(self.relu(x_))
416 | x_ = x_ + F.interpolate(
417 | self.compression4(self.relu(layers[3])),
418 | size=[height_output, width_output],
419 | mode='bilinear') ###[4,128,128,128]
420 |
421 |         ### high-resolution 1/8 appearance (A) branch
422 | before_E = x_
423 | x_ = self.layer5_(self.relu(x_))
424 | x_a = x_ ###[4,256,128,128]
425 | after_E = x_
426 |
427 |         ### low-resolution 1/64 content (C) branch; upsampled back to 1/8
428 | x = self.layer5(self.relu(x))
429 | before_I = F.interpolate(
430 | x,
431 | size=[height_output, width_output],
432 | mode='bilinear')
433 | x = F.interpolate(
434 | self.spp(x),
435 | size=[height_output, width_output],
436 | mode='bilinear')
437 | x_c = x ###[4,256,128,128]
438 | after_I = x
439 |
440 | # tsne
441 | # before_Is = torch.mean(torch.mean(before_I, axis=1), 1)
442 | # after_Is = torch.mean(torch.mean(after_I, axis=1), 1)
443 | # before_Es = torch.mean(torch.mean(before_E, axis=1), 1)
444 | # after_Es = torch.mean(torch.mean(after_E, axis=1), 1)
445 |
446 | # heatmap
447 | # before_Is = torch.mean(before_I, axis=1) # 1024, 128, 256
448 | # after_Is = torch.mean(after_I, axis=1)
449 | # before_Es = torch.mean(before_E, axis=1)
450 | # after_Es = torch.mean(after_E, axis=1)
451 |         # sample 20 random channels from each feature map for the heatmaps;
452 |         # np.random.shuffle follows np.random.seed, keeping this reproducible
453 |         np.random.seed(123)
454 |         random_index1 = np.arange(before_I.shape[1])
455 |         random_index2 = np.arange(after_I.shape[1])
456 |         random_index3 = np.arange(before_E.shape[1])
457 |         random_index4 = np.arange(after_E.shape[1])
458 |         np.random.shuffle(random_index1)
459 |         np.random.shuffle(random_index2)
460 |         np.random.shuffle(random_index3)
461 |         np.random.shuffle(random_index4)
462 |         random_index1, random_index2 = random_index1[:20], random_index2[:20]
463 |         random_index3, random_index4 = random_index3[:20], random_index4[:20]
        # Stack the sampled channel maps along the batch dimension for
        # per-channel heatmap rendering.
        before_Is = torch.cat([before_I[:, ind, :, :] for ind in random_index1], 0)
        after_Is = torch.cat([after_I[:, ind, :, :] for ind in random_index2], 0)
        before_Es = torch.cat([before_E[:, ind, :, :] for ind in random_index3], 0)
        after_Es = torch.cat([after_E[:, ind, :, :] for ind in random_index4], 0)
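        # Worked shape example (assuming batch size 4 and a 1024x1024 input, as
        # the inline shape comments above do): before_I is [4, 256, 128, 128],
        # so sampling 20 channels and concatenating along dim 0 gives before_Is
        # of shape [80, 128, 128], i.e. 20 single-channel maps per image.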
484 |
        # vis1 variant: return the sampled feature maps for visualization; the
        # segmentation path below is kept from the base model but is unreachable.
        return before_Is, after_Is, before_Es, after_Es
486 |
487 | x_, C, C_ = self.final_layer(x, x_) ### seghead [4,19,128,128]
488 |
489 | outputs = []
490 |
491 | x_ = F.interpolate(x_,
492 | size=[height_output_or, width_output_or],
493 | mode='bilinear', align_corners=True) #[4,19,1024,1024]
494 |
495 | outputs.append(x_)
496 |
497 | if self.augment:
498 | x_extra = self.seghead_extra(temp)
499 | return [x_, x_extra]
500 | else:
501 | return tuple(outputs), C_, C
502 |
503 |
504 | def DualResNet_imagenet(pretrained=True):
505 | model = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19,
506 | planes=64, spp_planes=128, head_planes=128, augment=False)
507 | if pretrained:
508 | pretrained_state = torch.load(
509 | "D:/DDR/models/DDRNet23_imagenet.pth", map_location='cpu')
510 | model_dict = model.state_dict()
511 | pretrained_state = {k: v for k, v in pretrained_state.items() if
512 | (k in model_dict and v.shape == model_dict[k].shape)}
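        # Keep only checkpoint entries whose key and tensor shape match the
        # current model; anything else (e.g. an ImageNet classifier head) is
        # dropped, which is why strict=False is safe below.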
513 | model_dict.update(pretrained_state)
514 |
515 | model.load_state_dict(model_dict, strict=False)
516 | print("Having loaded imagenet-pretrained weights successfully!")
517 |
518 | return model
519 |
520 |
521 | def get_ddrnet_23_vis1(pretrained=True):
522 |
523 | model = DualResNet_imagenet(pretrained=pretrained)
524 | return model
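
# Minimal usage sketch (illustrative only; shapes follow the inline comments
# above, and pretrained=False avoids the machine-specific checkpoint path):
#
#   model = get_ddrnet_23_vis1(pretrained=False).eval()
#   with torch.no_grad():
#       before_Is, after_Is, before_Es, after_Es = model(torch.randn(1, 3, 1024, 1024))
#   # each tensor stacks 20 sampled channel maps along dim 0, e.g. [20, 128, 128]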
525 |
526 |
### C-learner & A-learner: later part
528 |
529 | class CAinteract(nn.Module):
530 |
531 | def __init__(self, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False):
532 | super(CAinteract, self).__init__()
533 |
534 | self.augment = augment
535 | highres_planes = planes * 2
536 |
        ### C (content) branch: feature refinement and a single-channel attention map
538 | self.conv1 = nn.Sequential(
539 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0),
540 | BatchNorm2d(256, momentum=bn_mom),
541 | nn.ReLU(inplace=True),
542 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
543 | BatchNorm2d(1, momentum=bn_mom),
544 | nn.ReLU(inplace=True),
545 | )
546 |
        ### A (appearance) branch: feature refinement and a single-channel attention map
548 | self.conv2 = nn.Sequential(
549 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0),
550 | BatchNorm2d(256, momentum=bn_mom),
551 | nn.ReLU(inplace=True),
552 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0),
553 | BatchNorm2d(1, momentum=bn_mom),
554 | nn.ReLU(inplace=True),
555 | )
556 |
557 | self.relu = nn.ReLU(inplace=False)
558 |
559 | if self.augment:
560 | self.seghead_extra = segmenthead(
561 | highres_planes, head_planes, num_classes)
562 |
563 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
564 |
565 |
566 | for m in self.modules():
567 | if isinstance(m, nn.Conv2d):
568 | nn.init.kaiming_normal_(
569 | m.weight, mode='fan_out', nonlinearity='relu')
570 | elif isinstance(m, BatchNorm2d):
571 | nn.init.constant_(m.weight, 1)
572 | nn.init.constant_(m.bias, 0)
573 |
    # NOTE: _make_layer is carried over from DualResNet; CAinteract itself never calls it.
    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
575 | downsample = None
576 | if stride != 1 or inplanes != planes * block.expansion:
577 | downsample = nn.Sequential(
578 | nn.Conv2d(inplanes, planes * block.expansion,
579 | kernel_size=1, stride=stride, bias=False),
580 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom),
581 | )
582 |
583 | layers = []
584 | layers.append(block(inplanes, planes, stride, downsample))
585 | inplanes = planes * block.expansion
586 | for i in range(1, blocks):
587 | if i == (blocks-1):
588 | layers.append(block(inplanes, planes, stride=1, no_relu=True))
589 | else:
590 | layers.append(block(inplanes, planes, stride=1, no_relu=False))
591 |
592 | return nn.Sequential(*layers)
593 |
594 | def forward(self, X_, C, A):
595 |
        # hardcoded output size: the module assumes 1024x1024 network inputs
        height_output = 1024
        width_output = 1024
599 |
600 | attention_c = self.conv1(C)
601 | attention_a = self.conv2(A)
602 |
        C = C * attention_c + C  # residual attention gating: C * (1 + attention_c)
        A = A * attention_a + A
605 |
        x_ = self.final_layer(2 * X_ + C + A)  ### seghead [4,19,128,128]
607 |
608 | outputs = []
609 |
610 | x_ = F.interpolate(x_,
611 | size=[height_output, width_output],
612 | mode='bilinear', align_corners=True) #[4,19,1024,1024]
613 |
614 | outputs.append(x_)
615 |
616 | return tuple(outputs), C, A
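
# Minimal sketch of the interaction step (illustrative; assumes 1/8-resolution
# inputs of shape [4, 256, 128, 128] as in the comments above):
#
#   inter = get_CA_interact()
#   (logits,), C_out, A_out = inter(X_, C, A)   # logits: [4, 19, 1024, 1024]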
617 |
618 |
619 | ### C-A merge module
620 |
621 | class CAmerge(nn.Module):
622 |
623 | def __init__(self, planes=64, augment=False):
624 | super(CAmerge, self).__init__()
625 |
626 | highres_planes = planes * 2
627 | self.augment = augment
628 |
629 | #### for upsample 1/8 to 1/4 256->64
630 |
631 | self.conv1 = nn.Sequential(
632 | # nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=1),
633 | # BatchNorm2d(128, momentum=bn_mom),
634 | # nn.ReLU(inplace=True),
635 | nn.Conv2d(256, planes, kernel_size=3, stride=1, padding=1),
636 | BatchNorm2d(planes, momentum=bn_mom),
637 | nn.ReLU(inplace=True),
638 | )
639 |
640 | ##### for upsample 1/4 to 1/2 64->32
641 | self.conv2 = nn.Sequential(
642 | nn.Conv2d(planes, 32, kernel_size=3, stride=1, padding=1),
643 | BatchNorm2d(32, momentum=bn_mom),
644 | nn.ReLU(inplace=True),
645 | )
646 |
        ##### for upsample 1/2 to original 1/1 32->3
648 | self.conv3 = nn.Sequential(
649 | nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
650 | BatchNorm2d(3, momentum=bn_mom),
651 | nn.ReLU(inplace=True),
652 | )
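        # Decoder summary: three conv+BN+ReLU stages (256 -> 64 -> 32 -> 3 channels),
        # each followed in forward() by 2x bilinear upsampling, reconstructing a
        # 3-channel image at the original 1024x1024 resolution.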
653 |
654 | self.relu = nn.ReLU(inplace=False)
655 |
656 | # if self.augment:
657 | # self.seghead_extra = segmenthead(
658 | # highres_planes, head_planes, num_classes)
659 |
660 | # self.final_layer = segmenthead(planes * 4, head_planes, num_classes)
661 |
662 | for m in self.modules():
663 | if isinstance(m, nn.Conv2d):
664 | nn.init.kaiming_normal_(
665 | m.weight, mode='fan_out', nonlinearity='relu')
666 | elif isinstance(m, BatchNorm2d):
667 | nn.init.constant_(m.weight, 1)
668 | nn.init.constant_(m.bias, 0)
669 |
670 | def forward(self, C, A):
671 |
672 | layers = []
673 |
674 | ### C and A are from 1/8 resolution, input for reconstruction ###[4,256,128,128] -> [4,64,128,128] -> [4,64,256,256]
675 | x = self.conv1(C+A)
676 | layers.append(x)
677 | ## upsample 128->256
678 | x = F.interpolate(
679 | self.relu(layers[0]),
680 | size=[256,256],
681 | mode='bilinear')
682 |
683 | ###[4,64,256,256] -> [4,32,256,256] -> [4,32,512,512]
684 | x = self.conv2(self.relu(x))
        # if x_41 had more than 64 channels (planes), an extra layer would be needed first to compress it to 64 channels
686 | layers.append(x)
687 | ## upsample 256->512
688 | x = F.interpolate(
689 | self.relu(layers[1]),
690 | size=[512,512],
691 | mode='bilinear')
692 |
693 | ###[4,32,512,512] -> [4,3,512,512] -> [4,3,1024,1024]
694 |
695 | x = self.conv3(self.relu(x))
696 | layers.append(x)
697 |
698 | x = F.interpolate(
699 | self.relu(layers[2]),
700 | size=[1024, 1024],
            mode='bilinear') ###[4,3,1024,1024]
702 |
703 | if self.augment:
704 | temp = x
705 |
706 | outputs = []
707 |
708 | outputs.append(x)
709 |
        if self.augment:
            # NOTE: seghead_extra is commented out in __init__, so this branch
            # would raise AttributeError; the default augment=False avoids it.
            x_extra = self.seghead_extra(temp)
            return [x, x_extra]
713 | else:
714 | return tuple(outputs), x
715 |
def get_CA_interact():
    model = CAinteract()
    return model


def get_CA_merge():
    model = CAmerge()
    return model
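
# How the two modules are meant to compose (an illustrative sketch, not the
# actual training wiring; X_, C, A are the 1/8-resolution features produced by
# the backbone, per the comments above):
#
#   inter = get_CA_interact()          # C-A interaction + segmentation head
#   merge = get_CA_merge()             # C + A -> image reconstruction
#   (logits,), C, A = inter(X_, C, A)  # [B, 19, 1024, 1024]
#   (recon,), _ = merge(C, A)          # [B, 3, 1024, 1024]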
723 |
724 | if __name__ == "__main__":
725 |
726 | model = DualResNet_imagenet(pretrained=True)
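
    # A possible smoke test (commented out: the checkpoint path above is
    # machine-specific, so pretrained=False may be needed on other machines):
    #   outs = model(torch.randn(1, 3, 1024, 1024))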
727 |
--------------------------------------------------------------------------------