├── utils ├── __init__.py ├── logger.py ├── visualize.py ├── score.py ├── lr_scheduler.py ├── loss.py └── distributed.py ├── heatmapAD.png ├── models ├── __init__.py ├── DDRNet_39.py ├── DDRNet_23_slim.py ├── DDRNet_23.py └── DDRNet_23_vis1.py ├── LICENSE ├── README.md └── eval.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /heatmapAD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlpacaLLaMa643/All-day-CityScapes-segmentation/HEAD/heatmapAD.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .DDRNet_23_slim import get_ddrnet_23_slim 2 | from .DDRNet_39 import get_ddrnet_39 3 | from .DDRNet_23 import get_ddrnet_23 4 | from .DDRNet_23_vis1 import get_ddrnet_23_vis1 5 | 6 | from .DDRNet_23 import get_CA_interact 7 | from .DDRNet_23 import get_CA_merge 8 | 9 | models = { 10 | 'ddrnet_39': get_ddrnet_39, 11 | 'ddrnet_23_slim': get_ddrnet_23_slim, 12 | 'ddrnet_23': get_ddrnet_23, 13 | 'ddrnet_23_vis1': get_ddrnet_23_vis1, 14 | } 15 | 16 | intermodule={'inter': get_CA_interact} 17 | 18 | mergemodule={'merge': get_CA_merge} 19 | 20 | def get_segmentation_model(model, **kwargs): 21 | """Segmentation models""" 22 | return models[model.lower()](**kwargs) 23 | 24 | def get_inter_model(model, **kwargs): 25 | """interaction C-A module""" 26 | return intermodule[model.lower()](**kwargs) 27 | 28 | def get_merge_model(model, **kwargs): 29 | """merge C-A models""" 30 | return mergemodule[model.lower()](**kwargs) 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Qi BI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | import sys 5 | 6 | __all__ = ['setup_logger'] 7 | 8 | 9 | # reference from: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/logger.py 10 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 11 | logger = logging.getLogger(name) 12 | logger.setLevel(logging.DEBUG) 13 | # don't log results for the non-master process 14 | if distributed_rank > 0: 15 | return logger 16 | ch = logging.StreamHandler(stream=sys.stdout) 17 | ch.setLevel(logging.DEBUG) 18 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 19 | ch.setFormatter(formatter) 20 | logger.addHandler(ch) 21 | 22 | if save_dir: 23 | if not os.path.exists(save_dir): 24 | os.makedirs(save_dir) 25 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' for add, 'w' for overwrite 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # All-day-CityScapes-segmentation 2 | All-day Semantic Segmentation & All-day CityScapes dataset 3 | 4 | This is the official implementation of our work entitled as ```Interactive Learning of Intrinsic and Extrinsic Properties for All-day Semantic Segmentation```, accepted by ```IEEE Transactions on Image Processing```. 5 | 6 | ![avatar](/heatmapAD.png) 7 | 8 | # Dataset Download 9 | 10 | Please download ```All-day CityScapes``` from: [https://isis-data.science.uva.nl/cv/1ADcityscape.zip] 11 | 12 | For CopyRight issue, we only provide the rendered samples on both training and validation set of the original ```CityScapes```. 13 | 14 | All the sample name and data folder organization from ```All-day CityScapes``` is the same as the original ```CityScapes```. 15 | 16 | # Source Code & Implementation 17 | 18 | The proposed ```interactive intrinsic-extrinsic learning``` can be embedded into a variety of ```CNN``` and ```ViT``` based segmentation models. 19 | 20 | Here we provide the source code that is implemented on DDRNet-23 backbone, which is: 1) simple and easy to config; 2) most of the experiments in this paper conduct on. 21 | This implementation is highly based on the DDRNet source code. The original implementation of DDRNet can be found in this page. 22 | 23 | Please follow the below steps to run the AO-SegNet (DDRNet-23 based backbone). 24 | 25 | ### Step 1: Configuration 26 | 27 | Follow the original DDRNet-23 to prepare all the packages and data folder. 28 | 29 | ### Step 2: Train the Model 30 | 31 | ```python train.py --data_pth D:/alldaycityscapes --nclass 19``` 32 | 33 | ### Step 3: Evaluation 34 | 35 | ```python eval.py``` 36 | 37 | # Citation and Reference 38 | If you find this project useful, please cite: 39 | ``` 40 | @ARTICLE{Bi2023AD, 41 | author={Bi, Qi and You, Shaodi and Gevers, Theo}, 42 | journal={IEEE Transactions on Image Processing}, 43 | title={Interactive Learning of Intrinsic and Extrinsic Properties for All-day Semantic Segmentation}, 44 | year={2023}, 45 | volume={32}, 46 | number={}, 47 | pages={3821-3835}, 48 | doi={10.1109/TIP.2023.3290469}} 49 | ``` 50 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | from train import parse_args 4 | from utils.distributed import synchronize, get_rank, make_data_sampler, make_batch_data_sampler 5 | from utils.logger import setup_logger 6 | from utils.visualize import get_color_pallete 7 | from utils.score import SegmentationMetric 8 | from models import get_segmentation_model 9 | from dataloader.cityscapes import CitySegmentation 10 | from torchvision import transforms 11 | import torch.backends.cudnn as cudnn 12 | import torch.utils.data as data 13 | import torch.nn as nn 14 | import torch 15 | import os 16 | import sys 17 | 18 | cur_path = os.path.abspath(os.path.dirname(__file__)) 19 | root_path = os.path.split(cur_path)[0] 20 | sys.path.append(root_path) 21 | 22 | 23 | class Evaluator(object): 24 | def __init__(self, args): 25 | self.args = args 26 | self.args.pretrained = True 27 | self.device = torch.device(args.device) 28 | 29 | # image transform 30 | input_transform = transforms.Compose([ 31 | transforms.ToTensor(), 32 | transforms.Normalize([.485, .456, .406], [.229, .224, .225]), 33 | ]) 34 | 35 | # dataset and dataloader 36 | val_dataset = CitySegmentation( 37 | args.data_path, split='val', mode='testval', transform=input_transform) 38 | val_sampler = make_data_sampler(val_dataset, False, args.distributed) 39 | val_batch_sampler = make_batch_data_sampler( 40 | val_sampler, images_per_batch=1) 41 | self.val_loader = data.DataLoader(dataset=val_dataset, 42 | batch_sampler=val_batch_sampler, 43 | num_workers=args.workers, 44 | pin_memory=True) 45 | 46 | # create network 47 | BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d 48 | self.model = get_segmentation_model( 49 | model=args.model, pretrained=False).to(self.device) 50 | 51 | if self.args.pretrained: 52 | self.model.load_state_dict(torch.load( 53 | "./trained_models/ddrnet_23_dualresnet_citys_best_model.pth", 54 | map_location=self.args.device)) 55 | logger.info("Model restored successfully!!!!") 56 | 57 | if args.distributed: 58 | self.model = nn.parallel.DistributedDataParallel(self.model, 59 | device_ids=[args.local_rank], output_device=args.local_rank) 60 | 61 | self.model.to(self.device) 62 | 63 | self.metric = SegmentationMetric(val_dataset.num_class) 64 | 65 | def eval(self): 66 | self.metric.reset() 67 | self.model.eval() 68 | if self.args.distributed: 69 | model = self.model.module 70 | else: 71 | model = self.model 72 | logger.info("Start validation, Total sample: {:d}".format( 73 | len(self.val_loader))) 74 | for i, (image, target, filename) in enumerate(self.val_loader): 75 | image = image.to(self.device) 76 | target = target.to(self.device) 77 | 78 | with torch.no_grad(): 79 | outputs, _, _ = model(image) 80 | self.metric.update(outputs[0], target) 81 | pixAcc, mIoU = self.metric.get() 82 | logger.info("Sample: {:d}, validation pixAcc: {:.3f}, mIoU: {:.3f}".format( 83 | i + 1, pixAcc * 100, mIoU * 100)) 84 | 85 | if self.args.save_pred: 86 | pred = torch.argmax(outputs[0], 1) 87 | pred = pred.cpu().data.numpy() 88 | 89 | predict = pred.squeeze(0) 90 | mask = get_color_pallete(predict, self.args.dataset) 91 | mask.save(os.path.join( 92 | outdir, os.path.splitext(filename[0])[0] + '.png')) 93 | logger.info("Whole validation set mIoU: {:.3f}".format(mIoU * 100)) 94 | 95 | synchronize() 96 | 97 | 98 | if __name__ == '__main__': 99 | args = parse_args() 100 | num_gpus = int(os.environ["WORLD_SIZE"] 101 | ) if "WORLD_SIZE" in os.environ else 1 102 | args.distributed = num_gpus > 1 103 | if not args.no_cuda and torch.cuda.is_available(): 104 | cudnn.benchmark = True 105 | args.device = "cuda" 106 | else: 107 | args.distributed = False 108 | args.device = "cpu" 109 | if args.distributed: 110 | torch.cuda.set_device(args.local_rank) 111 | torch.distributed.init_process_group( 112 | backend="nccl", init_method="env://") 113 | synchronize() 114 | 115 | # TODO: optim code 116 | args.save_pred = True 117 | if args.save_pred: 118 | outdir = 'runs/pred_pic/{}_{}_{}'.format( 119 | args.model, args.backbone, args.dataset) 120 | if not os.path.exists(outdir): 121 | os.makedirs(outdir) 122 | 123 | logger = setup_logger("semantic_segmentation", args.log_dir, get_rank(), 124 | filename='{}_{}_{}_log.txt'.format(args.model, args.backbone, args.dataset), mode='a+') 125 | 126 | evaluator = Evaluator(args) 127 | evaluator.eval() 128 | torch.cuda.empty_cache() 129 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | __all__ = ['get_color_pallete', 'print_iou', 'set_img_color', 6 | 'show_prediction', 'show_colorful_images', 'save_colorful_images'] 7 | 8 | 9 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False): 10 | n = iu.size 11 | lines = [] 12 | for i in range(n): 13 | if class_names is None: 14 | cls = 'Class %d:' % (i + 1) 15 | else: 16 | cls = '%d %s' % (i + 1, class_names[i]) 17 | # lines.append('%-8s: %.3f%%' % (cls, iu[i] * 100)) 18 | mean_IU = np.nanmean(iu) 19 | mean_IU_no_back = np.nanmean(iu[1:]) 20 | if show_no_back: 21 | lines.append('mean_IU: %.3f%% || mean_IU_no_back: %.3f%% || mean_pixel_acc: %.3f%%' % ( 22 | mean_IU * 100, mean_IU_no_back * 100, mean_pixel_acc * 100)) 23 | else: 24 | lines.append('mean_IU: %.3f%% || mean_pixel_acc: %.3f%%' % (mean_IU * 100, mean_pixel_acc * 100)) 25 | lines.append('=================================================') 26 | line = "\n".join(lines) 27 | 28 | print(line) 29 | 30 | 31 | def set_img_color(img, label, colors, background=0, show255=False): 32 | for i in range(len(colors)): 33 | if i != background: 34 | img[np.where(label == i)] = colors[i] 35 | if show255: 36 | img[np.where(label == 255)] = 255 37 | 38 | return img 39 | 40 | 41 | def show_prediction(img, pred, colors, background=0): 42 | im = np.array(img, np.uint8) 43 | set_img_color(im, pred, colors, background) 44 | out = np.array(im) 45 | 46 | return out 47 | 48 | 49 | def show_colorful_images(prediction, palettes): 50 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()]) 51 | im.show() 52 | 53 | 54 | def save_colorful_images(prediction, filename, output_dir, palettes): 55 | ''' 56 | :param prediction: [B, H, W, C] 57 | ''' 58 | im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()]) 59 | fn = os.path.join(output_dir, filename) 60 | out_dir = os.path.split(fn)[0] 61 | if not os.path.exists(out_dir): 62 | os.mkdir(out_dir) 63 | im.save(fn) 64 | 65 | 66 | def get_color_pallete(npimg, dataset='pascal_voc'): 67 | """Visualize image. 68 | 69 | Parameters 70 | ---------- 71 | npimg : numpy.ndarray 72 | Single channel image with shape `H, W, 1`. 73 | dataset : str, default: 'pascal_voc' 74 | The dataset that model pretrained on. ('pascal_voc', 'ade20k') 75 | Returns 76 | ------- 77 | out_img : PIL.Image 78 | Image with color pallete 79 | """ 80 | # recovery boundary 81 | if dataset in ('pascal_voc', 'pascal_aug'): 82 | npimg[npimg == -1] = 255 83 | # put colormap 84 | if dataset == 'ade20k': 85 | npimg = npimg + 1 86 | out_img = Image.fromarray(npimg.astype('uint8')) 87 | out_img.putpalette(adepallete) 88 | return out_img 89 | elif dataset == 'citys': 90 | out_img = Image.fromarray(npimg.astype('uint8')) 91 | out_img.putpalette(cityspallete) 92 | return out_img 93 | out_img = Image.fromarray(npimg.astype('uint8')) 94 | out_img.putpalette(vocpallete) 95 | return out_img 96 | 97 | 98 | def _getvocpallete(num_cls): 99 | n = num_cls 100 | pallete = [0] * (n * 3) 101 | for j in range(0, n): 102 | lab = j 103 | pallete[j * 3 + 0] = 0 104 | pallete[j * 3 + 1] = 0 105 | pallete[j * 3 + 2] = 0 106 | i = 0 107 | while (lab > 0): 108 | pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 109 | pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 110 | pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 111 | i = i + 1 112 | lab >>= 3 113 | return pallete 114 | 115 | 116 | vocpallete = _getvocpallete(256) 117 | 118 | adepallete = [ 119 | 0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204, 120 | 5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82, 121 | 143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255, 122 | 7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6, 123 | 10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255, 124 | 20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15, 125 | 20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255, 126 | 31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163, 127 | 0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255, 128 | 0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0, 129 | 31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0, 130 | 194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255, 131 | 0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255, 132 | 0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0, 133 | 163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0, 134 | 10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0, 135 | 255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0, 136 | 133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255] 137 | 138 | cityspallete = [ 139 | 128, 64, 128, 140 | 244, 35, 232, 141 | 70, 70, 70, 142 | 102, 102, 156, 143 | 190, 153, 153, 144 | 153, 153, 153, 145 | 250, 170, 30, 146 | 220, 220, 0, 147 | 107, 142, 35, 148 | 152, 251, 152, 149 | 0, 130, 180, 150 | 220, 20, 60, 151 | 255, 0, 0, 152 | 0, 0, 142, 153 | 0, 0, 70, 154 | 0, 60, 100, 155 | 0, 80, 100, 156 | 0, 0, 230, 157 | 119, 11, 32, 158 | ] 159 | -------------------------------------------------------------------------------- /utils/score.py: -------------------------------------------------------------------------------- 1 | """Evaluation Metrics for Semantic Segmentation""" 2 | import torch 3 | import numpy as np 4 | 5 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union', 6 | 'pixelAccuracy', 'intersectionAndUnion', 'hist_info', 'compute_score'] 7 | 8 | 9 | class SegmentationMetric(object): 10 | """Computes pixAcc and mIoU metric scores 11 | """ 12 | 13 | def __init__(self, nclass): 14 | super(SegmentationMetric, self).__init__() 15 | self.nclass = nclass 16 | self.reset() 17 | 18 | def update(self, preds, labels): 19 | """Updates the internal evaluation result. 20 | 21 | Parameters 22 | ---------- 23 | labels : 'NumpyArray' or list of `NumpyArray` 24 | The labels of the data. 25 | preds : 'NumpyArray' or list of `NumpyArray` 26 | Predicted values. 27 | """ 28 | 29 | def evaluate_worker(self, pred, label): 30 | correct, labeled = batch_pix_accuracy(pred, label) 31 | inter, union = batch_intersection_union(pred, label, self.nclass) 32 | 33 | self.total_correct += correct 34 | self.total_label += labeled 35 | if self.total_inter.device != inter.device: 36 | self.total_inter = self.total_inter.to(inter.device) 37 | self.total_union = self.total_union.to(union.device) 38 | self.total_inter += inter 39 | self.total_union += union 40 | 41 | if isinstance(preds, torch.Tensor): 42 | evaluate_worker(self, preds, labels) 43 | elif isinstance(preds, (list, tuple)): 44 | for (pred, label) in zip(preds, labels): 45 | evaluate_worker(self, pred, label) 46 | 47 | def get(self): 48 | """Gets the current evaluation result. 49 | 50 | Returns 51 | ------- 52 | metrics : tuple of float 53 | pixAcc and mIoU 54 | """ 55 | pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label) # remove np.spacing(1) 56 | IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union) 57 | mIoU = IoU.mean().item() 58 | return pixAcc, mIoU 59 | 60 | def reset(self): 61 | """Resets the internal evaluation result to initial state.""" 62 | self.total_inter = torch.zeros(self.nclass) 63 | self.total_union = torch.zeros(self.nclass) 64 | self.total_correct = 0 65 | self.total_label = 0 66 | 67 | 68 | # pytorch version 69 | def batch_pix_accuracy(output, target): 70 | """PixAcc""" 71 | # inputs are numpy array, output 4D, target 3D 72 | predict = torch.argmax(output.long(), 1) + 1 73 | target = target.long() + 1 74 | 75 | pixel_labeled = torch.sum(target > 0).item() 76 | pixel_correct = torch.sum((predict == target) * (target > 0)).item() 77 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 78 | return pixel_correct, pixel_labeled 79 | 80 | 81 | def batch_intersection_union(output, target, nclass): 82 | """mIoU""" 83 | # inputs are numpy array, output 4D, target 3D 84 | mini = 1 85 | maxi = nclass 86 | nbins = nclass 87 | predict = torch.argmax(output, 1) + 1 88 | target = target.float() + 1 89 | 90 | predict = predict.float() * (target > 0).float() 91 | intersection = predict * (predict == target).float() 92 | # areas of intersection and union 93 | # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary. 94 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 95 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 96 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 97 | area_union = area_pred + area_lab - area_inter 98 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area" 99 | return area_inter.float(), area_union.float() 100 | 101 | 102 | def pixelAccuracy(imPred, imLab): 103 | """ 104 | This function takes the prediction and label of a single image, returns pixel-wise accuracy 105 | To compute over many images do: 106 | for i = range(Nimages): 107 | (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = \ 108 | pixelAccuracy(imPred[i], imLab[i]) 109 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled)) 110 | """ 111 | # Remove classes from unlabeled pixels in gt image. 112 | # We should not penalize detections in unlabeled portions of the image. 113 | pixel_labeled = np.sum(imLab >= 0) 114 | pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) 115 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 116 | return (pixel_accuracy, pixel_correct, pixel_labeled) 117 | 118 | 119 | def intersectionAndUnion(imPred, imLab, numClass): 120 | """ 121 | This function takes the prediction and label of a single image, 122 | returns intersection and union areas for each class 123 | To compute over many images do: 124 | for i in range(Nimages): 125 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i]) 126 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1) 127 | """ 128 | # Remove classes from unlabeled pixels in gt image. 129 | # We should not penalize detections in unlabeled portions of the image. 130 | imPred = imPred * (imLab >= 0) 131 | 132 | # Compute area intersection: 133 | intersection = imPred * (imPred == imLab) 134 | (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass)) 135 | 136 | # Compute area union: 137 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 138 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 139 | area_union = area_pred + area_lab - area_intersection 140 | return (area_intersection, area_union) 141 | 142 | 143 | def hist_info(pred, label, num_cls): 144 | assert pred.shape == label.shape 145 | k = (label >= 0) & (label < num_cls) 146 | labeled = np.sum(k) 147 | correct = np.sum((pred[k] == label[k])) 148 | 149 | return np.bincount(num_cls * label[k].astype(int) + pred[k], minlength=num_cls ** 2).reshape(num_cls, 150 | num_cls), labeled, correct 151 | 152 | 153 | def compute_score(hist, correct, labeled): 154 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 155 | mean_IU = np.nanmean(iu) 156 | mean_IU_no_back = np.nanmean(iu[1:]) 157 | freq = hist.sum(1) / hist.sum() 158 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() 159 | mean_pixel_acc = correct / labeled 160 | 161 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 162 | -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """Popular Learning Rate Schedulers""" 2 | from __future__ import division 3 | import math 4 | import torch 5 | 6 | from bisect import bisect_right 7 | 8 | __all__ = ['LRScheduler', 'WarmupMultiStepLR', 'WarmupPolyLR'] 9 | 10 | 11 | class LRScheduler(object): 12 | r"""Learning Rate Scheduler 13 | 14 | Parameters 15 | ---------- 16 | mode : str 17 | Modes for learning rate scheduler. 18 | Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'. 19 | base_lr : float 20 | Base learning rate, i.e. the starting learning rate. 21 | target_lr : float 22 | Target learning rate, i.e. the ending learning rate. 23 | With constant mode target_lr is ignored. 24 | niters : int 25 | Number of iterations to be scheduled. 26 | nepochs : int 27 | Number of epochs to be scheduled. 28 | iters_per_epoch : int 29 | Number of iterations in each epoch. 30 | offset : int 31 | Number of iterations before this scheduler. 32 | power : float 33 | Power parameter of poly scheduler. 34 | step_iter : list 35 | A list of iterations to decay the learning rate. 36 | step_epoch : list 37 | A list of epochs to decay the learning rate. 38 | step_factor : float 39 | Learning rate decay factor. 40 | """ 41 | 42 | def __init__(self, mode, base_lr=0.01, target_lr=0, niters=0, nepochs=0, iters_per_epoch=0, 43 | offset=0, power=0.9, step_iter=None, step_epoch=None, step_factor=0.1, warmup_epochs=0): 44 | super(LRScheduler, self).__init__() 45 | assert (mode in ['constant', 'step', 'linear', 'poly', 'cosine']) 46 | 47 | if mode == 'step': 48 | assert (step_iter is not None or step_epoch is not None) 49 | self.niters = niters 50 | self.step = step_iter 51 | epoch_iters = nepochs * iters_per_epoch 52 | if epoch_iters > 0: 53 | self.niters = epoch_iters 54 | if step_epoch is not None: 55 | self.step = [s * iters_per_epoch for s in step_epoch] 56 | 57 | self.step_factor = step_factor 58 | self.base_lr = base_lr 59 | self.target_lr = base_lr if mode == 'constant' else target_lr 60 | self.offset = offset 61 | self.power = power 62 | self.warmup_iters = warmup_epochs * iters_per_epoch 63 | self.mode = mode 64 | 65 | def __call__(self, optimizer, num_update): 66 | self.update(num_update) 67 | assert self.learning_rate >= 0 68 | self._adjust_learning_rate(optimizer, self.learning_rate) 69 | 70 | def update(self, num_update): 71 | N = self.niters - 1 72 | T = num_update - self.offset 73 | T = min(max(0, T), N) 74 | 75 | if self.mode == 'constant': 76 | factor = 0 77 | elif self.mode == 'linear': 78 | factor = 1 - T / N 79 | elif self.mode == 'poly': 80 | factor = pow(1 - T / N, self.power) 81 | elif self.mode == 'cosine': 82 | factor = (1 + math.cos(math.pi * T / N)) / 2 83 | elif self.mode == 'step': 84 | if self.step is not None: 85 | count = sum([1 for s in self.step if s <= T]) 86 | factor = pow(self.step_factor, count) 87 | else: 88 | factor = 1 89 | else: 90 | raise NotImplementedError 91 | 92 | # warm up lr schedule 93 | if self.warmup_iters > 0 and T < self.warmup_iters: 94 | factor = factor * 1.0 * T / self.warmup_iters 95 | 96 | if self.mode == 'step': 97 | self.learning_rate = self.base_lr * factor 98 | else: 99 | self.learning_rate = self.target_lr + (self.base_lr - self.target_lr) * factor 100 | 101 | def _adjust_learning_rate(self, optimizer, lr): 102 | optimizer.param_groups[0]['lr'] = lr 103 | # enlarge the lr at the head 104 | for i in range(1, len(optimizer.param_groups)): 105 | optimizer.param_groups[i]['lr'] = lr * 10 106 | 107 | 108 | # separating MultiStepLR with WarmupLR 109 | # but the current LRScheduler design doesn't allow it 110 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/solver/lr_scheduler.py 111 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 112 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, 113 | warmup_iters=500, warmup_method="linear", last_epoch=-1): 114 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 115 | if not list(milestones) == sorted(milestones): 116 | raise ValueError( 117 | "Milestones should be a list of" " increasing integers. Got {}", milestones) 118 | if warmup_method not in ("constant", "linear"): 119 | raise ValueError( 120 | "Only 'constant' or 'linear' warmup_method accepted got {}".format(warmup_method)) 121 | 122 | self.milestones = milestones 123 | self.gamma = gamma 124 | self.warmup_factor = warmup_factor 125 | self.warmup_iters = warmup_iters 126 | self.warmup_method = warmup_method 127 | 128 | def get_lr(self): 129 | warmup_factor = 1 130 | if self.last_epoch < self.warmup_iters: 131 | if self.warmup_method == 'constant': 132 | warmup_factor = self.warmup_factor 133 | elif self.warmup_factor == 'linear': 134 | alpha = float(self.last_epoch) / self.warmup_iters 135 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 136 | return [base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 137 | for base_lr in self.base_lrs] 138 | 139 | 140 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 141 | def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3, 142 | warmup_iters=500, warmup_method='linear', last_epoch=-1): 143 | if warmup_method not in ("constant", "linear"): 144 | raise ValueError( 145 | "Only 'constant' or 'linear' warmup_method accepted " 146 | "got {}".format(warmup_method)) 147 | 148 | self.target_lr = target_lr 149 | self.max_iters = max_iters 150 | self.power = power 151 | self.warmup_factor = warmup_factor 152 | self.warmup_iters = warmup_iters 153 | self.warmup_method = warmup_method 154 | 155 | super(WarmupPolyLR, self).__init__(optimizer, last_epoch) 156 | 157 | def get_lr(self): 158 | N = self.max_iters - self.warmup_iters 159 | T = self.last_epoch - self.warmup_iters 160 | if self.last_epoch < self.warmup_iters: 161 | if self.warmup_method == 'constant': 162 | warmup_factor = self.warmup_factor 163 | elif self.warmup_method == 'linear': 164 | alpha = float(self.last_epoch) / self.warmup_iters 165 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 166 | else: 167 | raise ValueError("Unknown warmup type.") 168 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs] 169 | factor = pow(1 - T / N, self.power) 170 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs] 171 | 172 | 173 | if __name__ == '__main__': 174 | import torch 175 | import torch.nn as nn 176 | 177 | model = nn.Conv2d(16, 16, 3, 1, 1) 178 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 179 | lr_scheduler = WarmupPolyLR(optimizer, niters=1000) 180 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | """Custom losses.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | __all__ = ['MixSoftmaxCrossEntropyLoss', 'MixSoftmaxCrossEntropyOHEMLoss', 9 | 'EncNetLoss', 'ICNetLoss', 'get_segmentation_loss'] 10 | 11 | 12 | # TODO: optim function 13 | class MixSoftmaxCrossEntropyLoss(nn.CrossEntropyLoss): 14 | def __init__(self, aux=True, aux_weight=0.2, ignore_index=-1, **kwargs): 15 | super(MixSoftmaxCrossEntropyLoss, self).__init__( 16 | ignore_index=ignore_index) 17 | self.aux = aux 18 | self.aux_weight = aux_weight 19 | 20 | def _aux_forward(self, *inputs, **kwargs): 21 | *preds, target = tuple(inputs) 22 | 23 | loss = super(MixSoftmaxCrossEntropyLoss, self).forward(preds[0], target) 24 | for i in range(1, len(preds)): 25 | aux_loss = super(MixSoftmaxCrossEntropyLoss, 26 | self).forward(preds[i], target) 27 | loss += self.aux_weight * aux_loss 28 | return loss 29 | 30 | def forward(self, *inputs, **kwargs): 31 | preds, target = tuple(inputs) 32 | inputs = tuple(list(preds) + [target]) 33 | if self.aux: 34 | return dict(loss=self._aux_forward(*inputs)) 35 | else: 36 | return dict(loss=super(MixSoftmaxCrossEntropyLoss, self).forward(*inputs)) 37 | 38 | 39 | # reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/loss.py 40 | class EncNetLoss(nn.CrossEntropyLoss): 41 | """2D Cross Entropy Loss with SE Loss""" 42 | 43 | def __init__(self, se_loss=True, se_weight=0.2, nclass=19, aux=False, 44 | aux_weight=0.4, weight=None, ignore_index=-1, **kwargs): 45 | super(EncNetLoss, self).__init__(weight, None, ignore_index) 46 | self.se_loss = se_loss 47 | self.aux = aux 48 | self.nclass = nclass 49 | self.se_weight = se_weight 50 | self.aux_weight = aux_weight 51 | self.bceloss = nn.BCELoss(weight) 52 | 53 | def forward(self, *inputs): 54 | preds, target = tuple(inputs) 55 | inputs = tuple(list(preds) + [target]) 56 | if not self.se_loss and not self.aux: 57 | return super(EncNetLoss, self).forward(*inputs) 58 | elif not self.se_loss: 59 | pred1, pred2, target = tuple(inputs) 60 | loss1 = super(EncNetLoss, self).forward(pred1, target) 61 | loss2 = super(EncNetLoss, self).forward(pred2, target) 62 | return dict(loss=loss1 + self.aux_weight * loss2) 63 | elif not self.aux: 64 | pred, se_pred, target = tuple(inputs) 65 | se_target = self._get_batch_label_vector( 66 | target, nclass=self.nclass).type_as(pred) 67 | loss1 = super(EncNetLoss, self).forward(pred, target) 68 | loss2 = self.bceloss(torch.sigmoid(se_pred), se_target) 69 | return dict(loss=loss1 + self.se_weight * loss2) 70 | else: 71 | pred1, se_pred, pred2, target = tuple(inputs) 72 | se_target = self._get_batch_label_vector( 73 | target, nclass=self.nclass).type_as(pred1) 74 | loss1 = super(EncNetLoss, self).forward(pred1, target) 75 | loss2 = super(EncNetLoss, self).forward(pred2, target) 76 | loss3 = self.bceloss(torch.sigmoid(se_pred), se_target) 77 | return dict(loss=loss1 + self.aux_weight * loss2 + self.se_weight * loss3) 78 | 79 | @staticmethod 80 | def _get_batch_label_vector(target, nclass): 81 | # target is a 3D Variable BxHxW, output is 2D BxnClass 82 | batch = target.size(0) 83 | tvect = Variable(torch.zeros(batch, nclass)) 84 | for i in range(batch): 85 | hist = torch.histc(target[i].cpu().data.float(), 86 | bins=nclass, min=0, 87 | max=nclass - 1) 88 | vect = hist > 0 89 | tvect[i] = vect 90 | return tvect 91 | 92 | 93 | # TODO: optim function 94 | class ICNetLoss(nn.CrossEntropyLoss): 95 | """Cross Entropy Loss for ICNet""" 96 | 97 | def __init__(self, nclass, aux_weight=0.4, ignore_index=-1, **kwargs): 98 | super(ICNetLoss, self).__init__(ignore_index=ignore_index) 99 | self.nclass = nclass 100 | self.aux_weight = aux_weight 101 | 102 | def forward(self, *inputs): 103 | preds, target = tuple(inputs) 104 | inputs = tuple(list(preds) + [target]) 105 | 106 | pred, pred_sub4, pred_sub8, pred_sub16, target = tuple(inputs) 107 | # [batch, W, H] -> [batch, 1, W, H] 108 | target = target.unsqueeze(1).float() 109 | target_sub4 = F.interpolate(target, pred_sub4.size( 110 | )[2:], mode='bilinear', align_corners=True).squeeze(1).long() 111 | target_sub8 = F.interpolate(target, pred_sub8.size( 112 | )[2:], mode='bilinear', align_corners=True).squeeze(1).long() 113 | target_sub16 = F.interpolate(target, pred_sub16.size()[2:], mode='bilinear', align_corners=True).squeeze( 114 | 1).long() 115 | loss1 = super(ICNetLoss, self).forward(pred_sub4, target_sub4) 116 | loss2 = super(ICNetLoss, self).forward(pred_sub8, target_sub8) 117 | loss3 = super(ICNetLoss, self).forward(pred_sub16, target_sub16) 118 | return dict(loss=loss1 + loss2 * self.aux_weight + loss3 * self.aux_weight) 119 | 120 | 121 | class OhemCrossEntropy2d(nn.Module): 122 | def __init__(self, ignore_index=-1, thresh=0.7, min_kept=100000, use_weight=True, **kwargs): 123 | super(OhemCrossEntropy2d, self).__init__() 124 | self.ignore_index = ignore_index 125 | self.thresh = float(thresh) 126 | self.min_kept = int(min_kept) 127 | if use_weight: 128 | weight = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 129 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 130 | 1.0865, 1.1529, 1.0507]) 131 | self.criterion = torch.nn.CrossEntropyLoss( 132 | weight=weight, ignore_index=ignore_index) 133 | else: 134 | self.criterion = torch.nn.CrossEntropyLoss( 135 | ignore_index=ignore_index) 136 | 137 | def forward(self, pred, target): 138 | n, c, h, w = pred.size() 139 | target = target.view(-1) 140 | valid_mask = target.ne(self.ignore_index) 141 | target = target * valid_mask.long() 142 | num_valid = valid_mask.sum() 143 | 144 | prob = F.softmax(pred, dim=1) 145 | prob = prob.transpose(0, 1).reshape(c, -1) 146 | 147 | if self.min_kept > num_valid: 148 | print("Lables: {}".format(num_valid)) 149 | elif num_valid > 0: 150 | #prob = prob.masked_fill_(1 - valid_mask, 1) 151 | prob = prob.masked_fill_(~valid_mask, 1) 152 | mask_prob = prob[target, torch.arange( 153 | len(target), dtype=torch.long)] 154 | threshold = self.thresh 155 | if self.min_kept > 0: 156 | index = mask_prob.argsort() 157 | threshold_index = index[min(len(index), self.min_kept) - 1] 158 | if mask_prob[threshold_index] > self.thresh: 159 | threshold = mask_prob[threshold_index] 160 | kept_mask = mask_prob.le(threshold) 161 | valid_mask = valid_mask * kept_mask 162 | target = target * kept_mask.long() 163 | 164 | #target = target.masked_fill_(1 - valid_mask, self.ignore_index) 165 | target = target.masked_fill_(~valid_mask, self.ignore_index) 166 | target = target.view(n, h, w) 167 | 168 | return self.criterion(pred, target) 169 | 170 | 171 | class MixSoftmaxCrossEntropyOHEMLoss(OhemCrossEntropy2d): 172 | def __init__(self, aux=False, aux_weight=0.4, weight=None, ignore_index=-1, **kwargs): 173 | super(MixSoftmaxCrossEntropyOHEMLoss, self).__init__( 174 | ignore_index=ignore_index) 175 | self.aux = aux 176 | self.aux_weight = aux_weight 177 | self.bceloss = nn.BCELoss(weight) 178 | 179 | def _aux_forward(self, *inputs, **kwargs): 180 | *preds, target = tuple(inputs) 181 | 182 | loss = super(MixSoftmaxCrossEntropyOHEMLoss, 183 | self).forward(preds[0], target) 184 | for i in range(1, len(preds)): 185 | aux_loss = super(MixSoftmaxCrossEntropyOHEMLoss, 186 | self).forward(preds[i], target) 187 | loss += self.aux_weight * aux_loss 188 | return loss 189 | 190 | def forward(self, *inputs): 191 | preds, target = tuple(inputs) 192 | inputs = tuple(list(preds) + [target]) 193 | if self.aux: 194 | return dict(loss=self._aux_forward(*inputs)) 195 | else: 196 | return dict(loss=super(MixSoftmaxCrossEntropyOHEMLoss, self).forward(*inputs)) 197 | 198 | 199 | def get_segmentation_loss(model, use_ohem=False, **kwargs): 200 | if use_ohem: 201 | return MixSoftmaxCrossEntropyOHEMLoss(**kwargs) 202 | 203 | model = model.lower() 204 | if model == 'encnet': 205 | return EncNetLoss(**kwargs) 206 | elif model == 'icnet': 207 | return ICNetLoss(**kwargs) 208 | else: 209 | return MixSoftmaxCrossEntropyLoss(**kwargs) 210 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains primitives for multi-gpu communication. 3 | This is useful when doing distributed training. 4 | """ 5 | import math 6 | import pickle 7 | import torch 8 | import torch.utils.data as data 9 | import torch.distributed as dist 10 | 11 | from torch.utils.data.sampler import Sampler, BatchSampler 12 | 13 | __all__ = ['get_world_size', 'get_rank', 'synchronize', 'is_main_process', 14 | 'all_gather', 'make_data_sampler', 'make_batch_data_sampler', 15 | 'reduce_dict', 'reduce_loss_dict'] 16 | 17 | 18 | # reference: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/utils/comm.py 19 | def get_world_size(): 20 | if not dist.is_available(): 21 | return 1 22 | if not dist.is_initialized(): 23 | return 1 24 | return dist.get_world_size() 25 | 26 | 27 | def get_rank(): 28 | if not dist.is_available(): 29 | return 0 30 | if not dist.is_initialized(): 31 | return 0 32 | return dist.get_rank() 33 | 34 | 35 | def is_main_process(): 36 | return get_rank() == 0 37 | 38 | 39 | def synchronize(): 40 | """ 41 | Helper function to synchronize (barrier) among all processes when 42 | using distributed training 43 | """ 44 | if not dist.is_available(): 45 | return 46 | if not dist.is_initialized(): 47 | return 48 | world_size = dist.get_world_size() 49 | if world_size == 1: 50 | return 51 | dist.barrier() 52 | 53 | 54 | def all_gather(data): 55 | """ 56 | Run all_gather on arbitrary picklable data (not necessarily tensors) 57 | Args: 58 | data: any picklable object 59 | Returns: 60 | list[data]: list of data gathered from each rank 61 | """ 62 | world_size = get_world_size() 63 | if world_size == 1: 64 | return [data] 65 | 66 | # serialized to a Tensor 67 | buffer = pickle.dumps(data) 68 | storage = torch.ByteStorage.from_buffer(buffer) 69 | tensor = torch.ByteTensor(storage).to("cuda") 70 | 71 | # obtain Tensor size of each rank 72 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 73 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 74 | dist.all_gather(size_list, local_size) 75 | size_list = [int(size.item()) for size in size_list] 76 | max_size = max(size_list) 77 | 78 | # receiving Tensor from all ranks 79 | # we pad the tensor because torch all_gather does not support 80 | # gathering tensors of different shapes 81 | tensor_list = [] 82 | for _ in size_list: 83 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 84 | if local_size != max_size: 85 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 86 | tensor = torch.cat((tensor, padding), dim=0) 87 | dist.all_gather(tensor_list, tensor) 88 | 89 | data_list = [] 90 | for size, tensor in zip(size_list, tensor_list): 91 | buffer = tensor.cpu().numpy().tobytes()[:size] 92 | data_list.append(pickle.loads(buffer)) 93 | 94 | return data_list 95 | 96 | 97 | def reduce_dict(input_dict, average=True): 98 | """ 99 | Args: 100 | input_dict (dict): all the values will be reduced 101 | average (bool): whether to do average or sum 102 | Reduce the values in the dictionary from all processes so that process with rank 103 | 0 has the averaged results. Returns a dict with the same fields as 104 | input_dict, after reduction. 105 | """ 106 | world_size = get_world_size() 107 | if world_size < 2: 108 | return input_dict 109 | with torch.no_grad(): 110 | names = [] 111 | values = [] 112 | # sort the keys so that they are consistent across processes 113 | for k in sorted(input_dict.keys()): 114 | names.append(k) 115 | values.append(input_dict[k]) 116 | values = torch.stack(values, dim=0) 117 | dist.reduce(values, dst=0) 118 | if dist.get_rank() == 0 and average: 119 | # only main process gets accumulated, so only divide by 120 | # world_size in this case 121 | values /= world_size 122 | reduced_dict = {k: v for k, v in zip(names, values)} 123 | return reduced_dict 124 | 125 | 126 | def reduce_loss_dict(loss_dict): 127 | """ 128 | Reduce the loss dictionary from all processes so that process with rank 129 | 0 has the averaged results. Returns a dict with the same fields as 130 | loss_dict, after reduction. 131 | """ 132 | world_size = get_world_size() 133 | if world_size < 2: 134 | return loss_dict 135 | with torch.no_grad(): 136 | loss_names = [] 137 | all_losses = [] 138 | for k in sorted(loss_dict.keys()): 139 | loss_names.append(k) 140 | all_losses.append(loss_dict[k]) 141 | all_losses = torch.stack(all_losses, dim=0) 142 | dist.reduce(all_losses, dst=0) 143 | if dist.get_rank() == 0: 144 | # only main process gets accumulated, so only divide by 145 | # world_size in this case 146 | all_losses /= world_size 147 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} 148 | return reduced_losses 149 | 150 | 151 | def make_data_sampler(dataset, shuffle, distributed): 152 | if distributed: 153 | return DistributedSampler(dataset, shuffle=shuffle) 154 | if shuffle: 155 | sampler = data.sampler.RandomSampler(dataset) 156 | else: 157 | sampler = data.sampler.SequentialSampler(dataset) 158 | return sampler 159 | 160 | 161 | def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0): 162 | batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=True) 163 | if num_iters is not None: 164 | batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter) 165 | return batch_sampler 166 | 167 | 168 | # Code is copy-pasted from https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/data/samplers/distributed.py 169 | class DistributedSampler(Sampler): 170 | """Sampler that restricts data loading to a subset of the dataset. 171 | It is especially useful in conjunction with 172 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 173 | process can pass a DistributedSampler instance as a DataLoader sampler, 174 | and load a subset of the original dataset that is exclusive to it. 175 | .. note:: 176 | Dataset is assumed to be of constant size. 177 | Arguments: 178 | dataset: Dataset used for sampling. 179 | num_replicas (optional): Number of processes participating in 180 | distributed training. 181 | rank (optional): Rank of the current process within num_replicas. 182 | """ 183 | 184 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 185 | if num_replicas is None: 186 | if not dist.is_available(): 187 | raise RuntimeError("Requires distributed package to be available") 188 | num_replicas = dist.get_world_size() 189 | if rank is None: 190 | if not dist.is_available(): 191 | raise RuntimeError("Requires distributed package to be available") 192 | rank = dist.get_rank() 193 | self.dataset = dataset 194 | self.num_replicas = num_replicas 195 | self.rank = rank 196 | self.epoch = 0 197 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 198 | self.total_size = self.num_samples * self.num_replicas 199 | self.shuffle = shuffle 200 | 201 | def __iter__(self): 202 | if self.shuffle: 203 | # deterministically shuffle based on epoch 204 | g = torch.Generator() 205 | g.manual_seed(self.epoch) 206 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 207 | else: 208 | indices = torch.arange(len(self.dataset)).tolist() 209 | 210 | # add extra samples to make it evenly divisible 211 | indices += indices[: (self.total_size - len(indices))] 212 | assert len(indices) == self.total_size 213 | 214 | # subsample 215 | offset = self.num_samples * self.rank 216 | indices = indices[offset: offset + self.num_samples] 217 | assert len(indices) == self.num_samples 218 | 219 | return iter(indices) 220 | 221 | def __len__(self): 222 | return self.num_samples 223 | 224 | def set_epoch(self, epoch): 225 | self.epoch = epoch 226 | 227 | 228 | class IterationBasedBatchSampler(BatchSampler): 229 | """ 230 | Wraps a BatchSampler, resampling from it until 231 | a specified number of iterations have been sampled 232 | """ 233 | 234 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 235 | self.batch_sampler = batch_sampler 236 | self.num_iterations = num_iterations 237 | self.start_iter = start_iter 238 | 239 | def __iter__(self): 240 | iteration = self.start_iter 241 | while iteration <= self.num_iterations: 242 | # if the underlying sampler has a set_epoch method, like 243 | # DistributedSampler, used for making each process see 244 | # a different split of the dataset, then set it 245 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 246 | self.batch_sampler.sampler.set_epoch(iteration) 247 | for batch in self.batch_sampler: 248 | iteration += 1 249 | if iteration > self.num_iterations: 250 | break 251 | yield batch 252 | 253 | def __len__(self): 254 | return self.num_iterations 255 | 256 | 257 | if __name__ == '__main__': 258 | pass 259 | -------------------------------------------------------------------------------- /models/DDRNet_39.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import init 8 | from collections import OrderedDict 9 | 10 | # for single gpu 11 | BatchNorm2d = nn.BatchNorm2d 12 | bn_mom = 0.1 13 | 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False): 25 | super(BasicBlock, self).__init__() 26 | self.conv1 = conv3x3(inplanes, planes, stride) 27 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 31 | self.downsample = downsample 32 | self.stride = stride 33 | self.no_relu = no_relu 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | 50 | if self.no_relu: 51 | return out 52 | else: 53 | return self.relu(out) 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 2 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes, momentum=bn_mom) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes, momentum=bn_mom) 66 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 67 | bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * self.expansion, 69 | momentum=bn_mom) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | self.no_relu = no_relu 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | if self.no_relu: 94 | return out 95 | else: 96 | return self.relu(out) 97 | 98 | 99 | class DAPPM(nn.Module): 100 | def __init__(self, inplanes, branch_planes, outplanes): 101 | super(DAPPM, self).__init__() 102 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2), 103 | BatchNorm2d(inplanes, momentum=bn_mom), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(inplanes, branch_planes, 106 | kernel_size=1, bias=False), 107 | ) 108 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4), 109 | BatchNorm2d(inplanes, momentum=bn_mom), 110 | nn.ReLU(inplace=True), 111 | nn.Conv2d(inplanes, branch_planes, 112 | kernel_size=1, bias=False), 113 | ) 114 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8), 115 | BatchNorm2d(inplanes, momentum=bn_mom), 116 | nn.ReLU(inplace=True), 117 | nn.Conv2d(inplanes, branch_planes, 118 | kernel_size=1, bias=False), 119 | ) 120 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 121 | BatchNorm2d(inplanes, momentum=bn_mom), 122 | nn.ReLU(inplace=True), 123 | nn.Conv2d(inplanes, branch_planes, 124 | kernel_size=1, bias=False), 125 | ) 126 | self.scale0 = nn.Sequential( 127 | BatchNorm2d(inplanes, momentum=bn_mom), 128 | nn.ReLU(inplace=True), 129 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False), 130 | ) 131 | self.process1 = nn.Sequential( 132 | BatchNorm2d(branch_planes, momentum=bn_mom), 133 | nn.ReLU(inplace=True), 134 | nn.Conv2d(branch_planes, branch_planes, 135 | kernel_size=3, padding=1, bias=False), 136 | ) 137 | self.process2 = nn.Sequential( 138 | BatchNorm2d(branch_planes, momentum=bn_mom), 139 | nn.ReLU(inplace=True), 140 | nn.Conv2d(branch_planes, branch_planes, 141 | kernel_size=3, padding=1, bias=False), 142 | ) 143 | self.process3 = nn.Sequential( 144 | BatchNorm2d(branch_planes, momentum=bn_mom), 145 | nn.ReLU(inplace=True), 146 | nn.Conv2d(branch_planes, branch_planes, 147 | kernel_size=3, padding=1, bias=False), 148 | ) 149 | self.process4 = nn.Sequential( 150 | BatchNorm2d(branch_planes, momentum=bn_mom), 151 | nn.ReLU(inplace=True), 152 | nn.Conv2d(branch_planes, branch_planes, 153 | kernel_size=3, padding=1, bias=False), 154 | ) 155 | self.compression = nn.Sequential( 156 | BatchNorm2d(branch_planes * 5, momentum=bn_mom), 157 | nn.ReLU(inplace=True), 158 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False), 159 | ) 160 | self.shortcut = nn.Sequential( 161 | BatchNorm2d(inplanes, momentum=bn_mom), 162 | nn.ReLU(inplace=True), 163 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False), 164 | ) 165 | 166 | def forward(self, x): 167 | 168 | #x = self.downsample(x) 169 | width = x.shape[-1] 170 | height = x.shape[-2] 171 | x_list = [] 172 | 173 | x_list.append(self.scale0(x)) 174 | x_list.append(self.process1((F.interpolate(self.scale1(x), 175 | size=[height, width], 176 | mode='bilinear')+x_list[0]))) 177 | x_list.append((self.process2((F.interpolate(self.scale2(x), 178 | size=[height, width], 179 | mode='bilinear')+x_list[1])))) 180 | x_list.append(self.process3((F.interpolate(self.scale3(x), 181 | size=[height, width], 182 | mode='bilinear')+x_list[2]))) 183 | x_list.append(self.process4((F.interpolate(self.scale4(x), 184 | size=[height, width], 185 | mode='bilinear')+x_list[3]))) 186 | 187 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x) 188 | return out 189 | 190 | 191 | class segmenthead(nn.Module): 192 | 193 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None): 194 | super(segmenthead, self).__init__() 195 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom) 196 | self.conv1 = nn.Conv2d(inplanes, interplanes, 197 | kernel_size=3, padding=1, bias=False) 198 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom) 199 | self.relu = nn.ReLU(inplace=True) 200 | self.conv2 = nn.Conv2d(interplanes, outplanes, 201 | kernel_size=1, padding=0, bias=True) 202 | self.scale_factor = scale_factor 203 | 204 | def forward(self, x): 205 | 206 | x = self.conv1(self.relu(self.bn1(x))) 207 | out = self.conv2(self.relu(self.bn2(x))) 208 | 209 | if self.scale_factor is not None: 210 | height = x.shape[-2] * self.scale_factor 211 | width = x.shape[-1] * self.scale_factor 212 | out = F.interpolate(out, 213 | size=[height, width], 214 | mode='bilinear') 215 | 216 | return out 217 | 218 | 219 | class DualResNet(nn.Module): 220 | 221 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False): 222 | super(DualResNet, self).__init__() 223 | 224 | highres_planes = planes * 2 225 | self.augment = augment 226 | 227 | self.conv1 = nn.Sequential( 228 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1), 229 | BatchNorm2d(planes, momentum=bn_mom), 230 | nn.ReLU(inplace=True), 231 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1), 232 | BatchNorm2d(planes, momentum=bn_mom), 233 | nn.ReLU(inplace=True), 234 | ) 235 | 236 | self.relu = nn.ReLU(inplace=False) 237 | self.layer1 = self._make_layer(block, planes, planes, layers[0]) 238 | self.layer2 = self._make_layer( 239 | block, planes, planes * 2, layers[1], stride=2) 240 | self.layer3_1 = self._make_layer( 241 | block, planes * 2, planes * 4, layers[2] // 2, stride=2) 242 | self.layer3_2 = self._make_layer( 243 | block, planes * 4, planes * 4, layers[2] // 2) 244 | self.layer4 = self._make_layer( 245 | block, planes * 4, planes * 8, layers[3], stride=2) 246 | 247 | self.compression3_1 = nn.Sequential( 248 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False), 249 | BatchNorm2d(highres_planes, momentum=bn_mom), 250 | ) 251 | 252 | self.compression3_2 = nn.Sequential( 253 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False), 254 | BatchNorm2d(highres_planes, momentum=bn_mom), 255 | ) 256 | 257 | self.compression4 = nn.Sequential( 258 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False), 259 | BatchNorm2d(highres_planes, momentum=bn_mom), 260 | ) 261 | 262 | self.down3_1 = nn.Sequential( 263 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 264 | stride=2, padding=1, bias=False), 265 | BatchNorm2d(planes * 4, momentum=bn_mom), 266 | ) 267 | 268 | self.down3_2 = nn.Sequential( 269 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 270 | stride=2, padding=1, bias=False), 271 | BatchNorm2d(planes * 4, momentum=bn_mom), 272 | ) 273 | 274 | self.down4 = nn.Sequential( 275 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 276 | stride=2, padding=1, bias=False), 277 | BatchNorm2d(planes * 4, momentum=bn_mom), 278 | nn.ReLU(inplace=True), 279 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3, 280 | stride=2, padding=1, bias=False), 281 | BatchNorm2d(planes * 8, momentum=bn_mom), 282 | ) 283 | 284 | self.layer3_1_ = self._make_layer( 285 | block, planes * 2, highres_planes, layers[2] // 2) 286 | 287 | self.layer3_2_ = self._make_layer( 288 | block, highres_planes, highres_planes, layers[2] // 2) 289 | 290 | self.layer4_ = self._make_layer( 291 | block, highres_planes, highres_planes, layers[3]) 292 | 293 | self.layer5_ = self._make_layer( 294 | Bottleneck, highres_planes, highres_planes, 1) 295 | 296 | self.layer5 = self._make_layer( 297 | Bottleneck, planes * 8, planes * 8, 1, stride=2) 298 | 299 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4) 300 | 301 | if self.augment: 302 | self.seghead_extra = segmenthead( 303 | highres_planes, head_planes, num_classes) 304 | 305 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 306 | 307 | for m in self.modules(): 308 | if isinstance(m, nn.Conv2d): 309 | nn.init.kaiming_normal_( 310 | m.weight, mode='fan_out', nonlinearity='relu') 311 | elif isinstance(m, BatchNorm2d): 312 | nn.init.constant_(m.weight, 1) 313 | nn.init.constant_(m.bias, 0) 314 | 315 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 316 | downsample = None 317 | if stride != 1 or inplanes != planes * block.expansion: 318 | downsample = nn.Sequential( 319 | nn.Conv2d(inplanes, planes * block.expansion, 320 | kernel_size=1, stride=stride, bias=False), 321 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom), 322 | ) 323 | 324 | layers = [] 325 | layers.append(block(inplanes, planes, stride, downsample)) 326 | inplanes = planes * block.expansion 327 | for i in range(1, blocks): 328 | if i == (blocks-1): 329 | layers.append(block(inplanes, planes, stride=1, no_relu=True)) 330 | else: 331 | layers.append(block(inplanes, planes, stride=1, no_relu=False)) 332 | 333 | return nn.Sequential(*layers) 334 | 335 | def forward(self, x): 336 | 337 | width_output_or = x.shape[-1] 338 | height_output_or = x.shape[-2] 339 | 340 | width_output = x.shape[-1] // 8 341 | height_output = x.shape[-2] // 8 342 | layers = [] 343 | 344 | x = self.conv1(x) 345 | 346 | x = self.layer1(x) 347 | layers.append(x) 348 | 349 | x = self.layer2(self.relu(x)) 350 | layers.append(x) 351 | 352 | x = self.layer3_1(self.relu(x)) 353 | layers.append(x) 354 | x_ = self.layer3_1_(self.relu(layers[1])) 355 | x = x + self.down3_1(self.relu(x_)) 356 | x_ = x_ + F.interpolate( 357 | self.compression3_1(self.relu(layers[2])), 358 | size=[height_output, width_output], 359 | mode='bilinear') 360 | 361 | x = self.layer3_2(self.relu(x)) 362 | layers.append(x) 363 | x_ = self.layer3_2_(self.relu(x_)) 364 | x = x + self.down3_2(self.relu(x_)) 365 | x_ = x_ + F.interpolate( 366 | self.compression3_2(self.relu(layers[3])), 367 | size=[height_output, width_output], 368 | mode='bilinear') 369 | 370 | temp = x_ 371 | 372 | x = self.layer4(self.relu(x)) 373 | layers.append(x) 374 | x_ = self.layer4_(self.relu(x_)) 375 | x = x + self.down4(self.relu(x_)) 376 | x_ = x_ + F.interpolate( 377 | self.compression4(self.relu(layers[4])), 378 | size=[height_output, width_output], 379 | mode='bilinear') 380 | 381 | x_ = self.layer5_(self.relu(x_)) 382 | x = F.interpolate( 383 | self.spp(self.layer5(self.relu(x))), 384 | size=[height_output, width_output], 385 | mode='bilinear') 386 | 387 | x_ = self.final_layer(x + x_) 388 | 389 | outputs = [] 390 | 391 | x_ = F.interpolate(x_, 392 | size=[height_output_or, width_output_or], 393 | mode='bilinear', align_corners=True) 394 | outputs.append(x_) 395 | 396 | if self.augment: 397 | x_extra = self.seghead_extra(temp) 398 | return [x_, x_extra] 399 | else: 400 | return tuple(outputs) 401 | 402 | 403 | def DualResNet_imagenet(pretrained=False): 404 | model = DualResNet(BasicBlock, [3, 4, 6, 3], num_classes=19, 405 | planes=64, spp_planes=128, head_planes=256, augment=False) 406 | if pretrained: 407 | 408 | pretrained_state = torch.load( 409 | "./models/DDRNet39_imagenet.pth", map_location='cpu') 410 | model_dict = model.state_dict() 411 | pretrained_state = {k: v for k, v in pretrained_state.items() if 412 | (k in model_dict and v.shape == model_dict[k].shape)} 413 | model_dict.update(pretrained_state) 414 | 415 | model.load_state_dict(model_dict, strict=False) 416 | print("Having loaded imagenet-pretrained weights successfully!") 417 | return model 418 | 419 | 420 | def get_ddrnet_39(pretrained=False): 421 | 422 | model = DualResNet_imagenet(pretrained=pretrained) 423 | return model 424 | 425 | 426 | if __name__ == "__main__": 427 | 428 | model = DualResNet_imagenet(pretrained=True) 429 | -------------------------------------------------------------------------------- /models/DDRNet_23_slim.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | import torch 8 | from collections import OrderedDict 9 | 10 | BatchNorm2d = nn.BatchNorm2d 11 | bn_mom = 0.1 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.no_relu = no_relu 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | 49 | if self.no_relu: 50 | return out 51 | else: 52 | return self.relu(out) 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 2 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 65 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 66 | bias=False) 67 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | self.no_relu = no_relu 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | if self.no_relu: 92 | return out 93 | else: 94 | return self.relu(out) 95 | 96 | 97 | class DAPPM(nn.Module): 98 | def __init__(self, inplanes, branch_planes, outplanes): 99 | super(DAPPM, self).__init__() 100 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2), 101 | BatchNorm2d(inplanes, momentum=bn_mom), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inplanes, branch_planes, 104 | kernel_size=1, bias=False), 105 | ) 106 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4), 107 | BatchNorm2d(inplanes, momentum=bn_mom), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(inplanes, branch_planes, 110 | kernel_size=1, bias=False), 111 | ) 112 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8), 113 | BatchNorm2d(inplanes, momentum=bn_mom), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(inplanes, branch_planes, 116 | kernel_size=1, bias=False), 117 | ) 118 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 119 | BatchNorm2d(inplanes, momentum=bn_mom), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(inplanes, branch_planes, 122 | kernel_size=1, bias=False), 123 | ) 124 | self.scale0 = nn.Sequential( 125 | BatchNorm2d(inplanes, momentum=bn_mom), 126 | nn.ReLU(inplace=True), 127 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False), 128 | ) 129 | self.process1 = nn.Sequential( 130 | BatchNorm2d(branch_planes, momentum=bn_mom), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(branch_planes, branch_planes, 133 | kernel_size=3, padding=1, bias=False), 134 | ) 135 | self.process2 = nn.Sequential( 136 | BatchNorm2d(branch_planes, momentum=bn_mom), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(branch_planes, branch_planes, 139 | kernel_size=3, padding=1, bias=False), 140 | ) 141 | self.process3 = nn.Sequential( 142 | BatchNorm2d(branch_planes, momentum=bn_mom), 143 | nn.ReLU(inplace=True), 144 | nn.Conv2d(branch_planes, branch_planes, 145 | kernel_size=3, padding=1, bias=False), 146 | ) 147 | self.process4 = nn.Sequential( 148 | BatchNorm2d(branch_planes, momentum=bn_mom), 149 | nn.ReLU(inplace=True), 150 | nn.Conv2d(branch_planes, branch_planes, 151 | kernel_size=3, padding=1, bias=False), 152 | ) 153 | self.compression = nn.Sequential( 154 | BatchNorm2d(branch_planes * 5, momentum=bn_mom), 155 | nn.ReLU(inplace=True), 156 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False), 157 | ) 158 | self.shortcut = nn.Sequential( 159 | BatchNorm2d(inplanes, momentum=bn_mom), 160 | nn.ReLU(inplace=True), 161 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False), 162 | ) 163 | 164 | def forward(self, x): 165 | 166 | # x = self.downsample(x) 167 | width = x.shape[-1] 168 | height = x.shape[-2] 169 | x_list = [] 170 | 171 | x_list.append(self.scale0(x)) 172 | x_list.append(self.process1((F.interpolate(self.scale1(x), 173 | size=[height, width], 174 | mode='bilinear')+x_list[0]))) 175 | x_list.append((self.process2((F.interpolate(self.scale2(x), 176 | size=[height, width], 177 | mode='bilinear')+x_list[1])))) 178 | x_list.append(self.process3((F.interpolate(self.scale3(x), 179 | size=[height, width], 180 | mode='bilinear')+x_list[2]))) 181 | x_list.append(self.process4((F.interpolate(self.scale4(x), 182 | size=[height, width], 183 | mode='bilinear')+x_list[3]))) 184 | 185 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x) 186 | return out 187 | 188 | 189 | class segmenthead(nn.Module): 190 | 191 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None): 192 | super(segmenthead, self).__init__() 193 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom) 194 | self.conv1 = nn.Conv2d(inplanes, interplanes, 195 | kernel_size=3, padding=1, bias=False) 196 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom) 197 | self.relu = nn.ReLU(inplace=True) 198 | self.conv2 = nn.Conv2d(interplanes, outplanes, 199 | kernel_size=1, padding=0, bias=True) 200 | self.scale_factor = scale_factor 201 | 202 | def forward(self, x): 203 | 204 | x = self.conv1(self.relu(self.bn1(x))) 205 | out = self.conv2(self.relu(self.bn2(x))) 206 | 207 | if self.scale_factor is not None: 208 | height = x.shape[-2] * self.scale_factor 209 | width = x.shape[-1] * self.scale_factor 210 | out = F.interpolate(out, 211 | size=[height, width], 212 | mode='bilinear') 213 | 214 | return out 215 | 216 | 217 | class DualResNet(nn.Module): 218 | 219 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False): 220 | super(DualResNet, self).__init__() 221 | 222 | highres_planes = planes * 2 223 | self.augment = augment 224 | 225 | self.conv1 = nn.Sequential( 226 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1), 227 | BatchNorm2d(planes, momentum=bn_mom), 228 | nn.ReLU(inplace=True), 229 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1), 230 | BatchNorm2d(planes, momentum=bn_mom), 231 | nn.ReLU(inplace=True), 232 | ) 233 | 234 | self.relu = nn.ReLU(inplace=False) 235 | self.layer1 = self._make_layer(block, planes, planes, layers[0]) 236 | self.layer2 = self._make_layer( 237 | block, planes, planes * 2, layers[1], stride=2) 238 | self.layer3 = self._make_layer( 239 | block, planes * 2, planes * 4, layers[2], stride=2) 240 | self.layer4 = self._make_layer( 241 | block, planes * 4, planes * 8, layers[3], stride=2) 242 | 243 | self.compression3 = nn.Sequential( 244 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False), 245 | BatchNorm2d(highres_planes, momentum=bn_mom), 246 | ) 247 | 248 | self.compression4 = nn.Sequential( 249 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False), 250 | BatchNorm2d(highres_planes, momentum=bn_mom), 251 | ) 252 | 253 | self.down3 = nn.Sequential( 254 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 255 | stride=2, padding=1, bias=False), 256 | BatchNorm2d(planes * 4, momentum=bn_mom), 257 | ) 258 | 259 | self.down4 = nn.Sequential( 260 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 261 | stride=2, padding=1, bias=False), 262 | BatchNorm2d(planes * 4, momentum=bn_mom), 263 | nn.ReLU(inplace=True), 264 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3, 265 | stride=2, padding=1, bias=False), 266 | BatchNorm2d(planes * 8, momentum=bn_mom), 267 | ) 268 | 269 | self.layer3_ = self._make_layer(block, planes * 2, highres_planes, 2) 270 | 271 | self.layer4_ = self._make_layer( 272 | block, highres_planes, highres_planes, 2) 273 | 274 | self.layer5_ = self._make_layer( 275 | Bottleneck, highres_planes, highres_planes, 1) 276 | 277 | self.layer5 = self._make_layer( 278 | Bottleneck, planes * 8, planes * 8, 1, stride=2) 279 | 280 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4) 281 | 282 | if self.augment: 283 | self.seghead_extra = segmenthead( 284 | highres_planes, head_planes, num_classes) 285 | 286 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 287 | 288 | for m in self.modules(): 289 | if isinstance(m, nn.Conv2d): 290 | nn.init.kaiming_normal_( 291 | m.weight, mode='fan_out', nonlinearity='relu') 292 | elif isinstance(m, BatchNorm2d): 293 | nn.init.constant_(m.weight, 1) 294 | nn.init.constant_(m.bias, 0) 295 | 296 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 297 | downsample = None 298 | if stride != 1 or inplanes != planes * block.expansion: 299 | downsample = nn.Sequential( 300 | nn.Conv2d(inplanes, planes * block.expansion, 301 | kernel_size=1, stride=stride, bias=False), 302 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom), 303 | ) 304 | 305 | layers = [] 306 | layers.append(block(inplanes, planes, stride, downsample)) 307 | inplanes = planes * block.expansion 308 | for i in range(1, blocks): 309 | if i == (blocks-1): 310 | layers.append(block(inplanes, planes, stride=1, no_relu=True)) 311 | else: 312 | layers.append(block(inplanes, planes, stride=1, no_relu=False)) 313 | 314 | return nn.Sequential(*layers) 315 | 316 | def forward(self, x): 317 | width_output_or = x.shape[-1] 318 | height_output_or = x.shape[-2] 319 | 320 | width_output = x.shape[-1] // 8 321 | height_output = x.shape[-2] // 8 322 | layers = [] 323 | 324 | x = self.conv1(x) 325 | 326 | x = self.layer1(x) 327 | layers.append(x) 328 | 329 | x = self.layer2(self.relu(x)) 330 | layers.append(x) 331 | 332 | x = self.layer3(self.relu(x)) 333 | x_i=x 334 | layers.append(x) 335 | x_ = self.layer3_(self.relu(layers[1])) 336 | 337 | x = x + self.down3(self.relu(x_)) 338 | x_ = x_ + F.interpolate( 339 | self.compression3(self.relu(layers[2])), 340 | size=[height_output, width_output], 341 | mode='bilinear') 342 | if self.augment: 343 | temp = x_ 344 | 345 | x = self.layer4(self.relu(x)) 346 | layers.append(x) 347 | x_ = self.layer4_(self.relu(x_)) 348 | 349 | x = x + self.down4(self.relu(x_)) 350 | x_ = x_ + F.interpolate( 351 | self.compression4(self.relu(layers[3])), 352 | size=[height_output, width_output], 353 | mode='bilinear') 354 | 355 | ### high resolution 1/8 branch, apperance branch 356 | x_ = self.layer5_(self.relu(x_)) 357 | x_a =x_ 358 | 359 | ### low resolution 1/64 branch, need upsampling, content branch 360 | x = F.interpolate( 361 | self.spp(self.layer5(self.relu(x))), 362 | size=[height_output, width_output], 363 | mode='bilinear') 364 | x_c = x 365 | 366 | x_ = self.final_layer(x + x_) 367 | 368 | outputs = [] 369 | outputs_c=[] 370 | outputs_a=[] 371 | outputs_i=[] 372 | 373 | x_ = F.interpolate(x_, 374 | size=[height_output_or, width_output_or], 375 | mode='bilinear', align_corners=True) 376 | outputs.append(x_) 377 | 378 | 379 | x_c = F.interpolate(x_c, 380 | size=[height_output_or, width_output_or], 381 | mode='bilinear', align_corners=True) 382 | outputs_c.append(x_c) 383 | 384 | x_a = F.interpolate(x_a, 385 | size=[height_output_or, width_output_or], 386 | mode='bilinear', align_corners=True) 387 | outputs_a.append(x_a) 388 | 389 | x_i = F.interpolate(x_i, 390 | size=[height_output_or, width_output_or], 391 | mode='bilinear', align_corners=True) 392 | outputs_i.append(x_i) 393 | 394 | # if self.augment: 395 | # assert 1 == 0 396 | # x_extra = self.seghead_extra(temp) 397 | # return [x_, x_extra] 398 | # else: 399 | # #return tuple(outputs), tuple(outputs_c), tuple(outputs_a), tuple(outputs_i) 400 | # # return tuple(outputs)#, tuple(outputs_c), tuple(outputs_a), tuple(outputs_i) 401 | # return outputs, outputs_c ,outputs_a, outputs_i 402 | #return tuple(outputs), tuple(outputs_c), tuple(outputs_a), tuple(outputs_i) 403 | return tuple(outputs)#, x_c, x_a, x_i 404 | #return outputs, outputs_c ,outputs_a, outputs_i 405 | 406 | 407 | def DualResNet_imagenet(pretrained=False): 408 | #model, C, A, I = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19, 409 | # planes=32, spp_planes=128, head_planes=64, augment=False) 410 | model = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19, 411 | planes=32, spp_planes=128, head_planes=64, augment=False) 412 | if pretrained: 413 | # remove hardcoded path by user provided path 414 | checkpoint = torch.load( 415 | "D:/DDR/models/DDRNet23s_imagenet.pth", map_location='cpu') 416 | 417 | new_state_dict = OrderedDict() 418 | model_dict = model.state_dict() 419 | 420 | for existing_key, _ in model_dict.items(): 421 | for k, v in checkpoint.items(): 422 | if existing_key == k: 423 | name = k[7:] 424 | new_state_dict[k] = v 425 | model_dict.update(new_state_dict) 426 | model.load_state_dict(model_dict) 427 | 428 | print("Having loaded imagenet-pretrained successfully!") 429 | 430 | # model.load_state_dict(new_state_dict, strict=False) 431 | return model#, C, A, I 432 | 433 | 434 | def get_ddrnet_23_slim(pretrained=True): 435 | 436 | model = DualResNet_imagenet(pretrained=pretrained) 437 | return model 438 | 439 | 440 | if __name__ == "__main__": 441 | 442 | model = DualResNet_imagenet(pretrained=True) 443 | -------------------------------------------------------------------------------- /models/DDRNet_23.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | from collections import OrderedDict 8 | 9 | BatchNorm2d = nn.BatchNorm2d 10 | bn_mom = 0.1 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 29 | self.downsample = downsample 30 | self.stride = stride 31 | self.no_relu = no_relu 32 | 33 | def forward(self, x): 34 | residual = x 35 | 36 | out = self.conv1(x) 37 | out = self.bn1(out) 38 | out = self.relu(out) 39 | 40 | out = self.conv2(out) 41 | out = self.bn2(out) 42 | 43 | if self.downsample is not None: 44 | residual = self.downsample(x) 45 | 46 | out += residual 47 | 48 | if self.no_relu: 49 | return out 50 | else: 51 | return self.relu(out) 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 2 56 | 57 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 60 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 62 | padding=1, bias=False) 63 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 64 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 65 | bias=False) 66 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | self.no_relu = no_relu 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | if self.no_relu: 91 | return out 92 | else: 93 | return self.relu(out) 94 | 95 | 96 | class DAPPM(nn.Module): 97 | def __init__(self, inplanes, branch_planes, outplanes): 98 | super(DAPPM, self).__init__() 99 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2), 100 | BatchNorm2d(inplanes, momentum=bn_mom), 101 | nn.ReLU(inplace=True), 102 | nn.Conv2d(inplanes, branch_planes, 103 | kernel_size=1, bias=False), 104 | ) 105 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4), 106 | BatchNorm2d(inplanes, momentum=bn_mom), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(inplanes, branch_planes, 109 | kernel_size=1, bias=False), 110 | ) 111 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8), 112 | BatchNorm2d(inplanes, momentum=bn_mom), 113 | nn.ReLU(inplace=True), 114 | nn.Conv2d(inplanes, branch_planes, 115 | kernel_size=1, bias=False), 116 | ) 117 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 118 | BatchNorm2d(inplanes, momentum=bn_mom), 119 | nn.ReLU(inplace=True), 120 | nn.Conv2d(inplanes, branch_planes, 121 | kernel_size=1, bias=False), 122 | ) 123 | self.scale0 = nn.Sequential( 124 | BatchNorm2d(inplanes, momentum=bn_mom), 125 | nn.ReLU(inplace=True), 126 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False), 127 | ) 128 | self.process1 = nn.Sequential( 129 | BatchNorm2d(branch_planes, momentum=bn_mom), 130 | nn.ReLU(inplace=True), 131 | nn.Conv2d(branch_planes, branch_planes, 132 | kernel_size=3, padding=1, bias=False), 133 | ) 134 | self.process2 = nn.Sequential( 135 | BatchNorm2d(branch_planes, momentum=bn_mom), 136 | nn.ReLU(inplace=True), 137 | nn.Conv2d(branch_planes, branch_planes, 138 | kernel_size=3, padding=1, bias=False), 139 | ) 140 | self.process3 = nn.Sequential( 141 | BatchNorm2d(branch_planes, momentum=bn_mom), 142 | nn.ReLU(inplace=True), 143 | nn.Conv2d(branch_planes, branch_planes, 144 | kernel_size=3, padding=1, bias=False), 145 | ) 146 | self.process4 = nn.Sequential( 147 | BatchNorm2d(branch_planes, momentum=bn_mom), 148 | nn.ReLU(inplace=True), 149 | nn.Conv2d(branch_planes, branch_planes, 150 | kernel_size=3, padding=1, bias=False), 151 | ) 152 | self.compression = nn.Sequential( 153 | BatchNorm2d(branch_planes * 5, momentum=bn_mom), 154 | nn.ReLU(inplace=True), 155 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False), 156 | ) 157 | self.shortcut = nn.Sequential( 158 | BatchNorm2d(inplanes, momentum=bn_mom), 159 | nn.ReLU(inplace=True), 160 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False), 161 | ) 162 | 163 | def forward(self, x): 164 | 165 | # x = self.downsample(x) 166 | width = x.shape[-1] 167 | height = x.shape[-2] 168 | x_list = [] 169 | 170 | x_list.append(self.scale0(x)) 171 | x_list.append(self.process1((F.interpolate(self.scale1(x), 172 | size=[height, width], 173 | mode='bilinear')+x_list[0]))) 174 | x_list.append((self.process2((F.interpolate(self.scale2(x), 175 | size=[height, width], 176 | mode='bilinear')+x_list[1])))) 177 | x_list.append(self.process3((F.interpolate(self.scale3(x), 178 | size=[height, width], 179 | mode='bilinear')+x_list[2]))) 180 | x_list.append(self.process4((F.interpolate(self.scale4(x), 181 | size=[height, width], 182 | mode='bilinear')+x_list[3]))) 183 | 184 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x) 185 | return out 186 | 187 | 188 | class segmenthead(nn.Module): 189 | 190 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None): 191 | super(segmenthead, self).__init__() 192 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom) 193 | self.conv1 = nn.Conv2d(inplanes, interplanes, 194 | kernel_size=3, padding=1, bias=False) 195 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom) 196 | self.relu = nn.ReLU(inplace=True) 197 | self.conv2 = nn.Conv2d(interplanes, outplanes, 198 | kernel_size=1, padding=0, bias=True) 199 | self.scale_factor = scale_factor 200 | 201 | ### for C representation branch. feature refine and attention weight matrix 202 | self.conv3 = nn.Sequential( 203 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 204 | BatchNorm2d(1, momentum=bn_mom), 205 | nn.ReLU(inplace=True), 206 | ) 207 | 208 | ### for A representation branch. feature refine and attention weight matrix 209 | self.conv4 = nn.Sequential( 210 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 211 | BatchNorm2d(1, momentum=bn_mom), 212 | nn.ReLU(inplace=True), 213 | ) 214 | 215 | self.conv5 = nn.Sequential( 216 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0), 217 | BatchNorm2d(19, momentum=bn_mom), 218 | nn.ReLU(inplace=True), 219 | ) 220 | 221 | self.conv6 = nn.Sequential( 222 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0), 223 | BatchNorm2d(19, momentum=bn_mom), 224 | nn.ReLU(inplace=True), 225 | ) 226 | 227 | def forward(self, C, A): 228 | 229 | attention_c=self.conv3(C) 230 | attention_a=self.conv4(A) 231 | 232 | C_ = 2*C + attention_c *C + A *attention_c 233 | A_ =A + attention_a *A 234 | 235 | x = self.conv1(self.relu(self.bn1(C_+A_))) 236 | out = self.conv2(self.relu(self.bn2(x))) 237 | 238 | Cupdate=self.conv5(C) 239 | C_update=self.conv6(C_) 240 | 241 | if self.scale_factor is not None: 242 | height = x.shape[-2] * self.scale_factor 243 | width = x.shape[-1] * self.scale_factor 244 | out = F.interpolate(out, 245 | size=[height, width], 246 | mode='bilinear') 247 | 248 | return out, C_update, Cupdate 249 | 250 | 251 | class segmentheadold(nn.Module): 252 | 253 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None): 254 | super(segmenthead, self).__init__() 255 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom) 256 | self.conv1 = nn.Conv2d(inplanes, interplanes, 257 | kernel_size=3, padding=1, bias=False) 258 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom) 259 | self.relu = nn.ReLU(inplace=True) 260 | self.conv2 = nn.Conv2d(interplanes, outplanes, 261 | kernel_size=1, padding=0, bias=True) 262 | self.scale_factor = scale_factor 263 | 264 | def forward(self, x): 265 | 266 | x = self.conv1(self.relu(self.bn1(x))) 267 | out = self.conv2(self.relu(self.bn2(x))) 268 | 269 | if self.scale_factor is not None: 270 | height = x.shape[-2] * self.scale_factor 271 | width = x.shape[-1] * self.scale_factor 272 | out = F.interpolate(out, 273 | size=[height, width], 274 | mode='bilinear') 275 | 276 | return out 277 | 278 | 279 | class DualResNet(nn.Module): 280 | 281 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False): 282 | super(DualResNet, self).__init__() 283 | 284 | highres_planes = planes * 2 285 | self.augment = augment 286 | 287 | self.conv1 = nn.Sequential( 288 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1), 289 | BatchNorm2d(planes, momentum=bn_mom), 290 | nn.ReLU(inplace=True), 291 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1), 292 | BatchNorm2d(planes, momentum=bn_mom), 293 | nn.ReLU(inplace=True), 294 | ) 295 | 296 | self.relu = nn.ReLU(inplace=False) 297 | self.layer1 = self._make_layer(block, planes, planes, layers[0]) 298 | self.layer2 = self._make_layer( 299 | block, planes, planes * 2, layers[1], stride=2) 300 | self.layer3 = self._make_layer( 301 | block, planes * 2, planes * 4, layers[2], stride=2) 302 | self.layer4 = self._make_layer( 303 | block, planes * 4, planes * 8, layers[3], stride=2) 304 | 305 | self.compression3 = nn.Sequential( 306 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False), 307 | BatchNorm2d(highres_planes, momentum=bn_mom), 308 | ) 309 | 310 | self.compression4 = nn.Sequential( 311 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False), 312 | BatchNorm2d(highres_planes, momentum=bn_mom), 313 | ) 314 | 315 | self.down3 = nn.Sequential( 316 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 317 | stride=2, padding=1, bias=False), 318 | BatchNorm2d(planes * 4, momentum=bn_mom), 319 | ) 320 | 321 | self.down4 = nn.Sequential( 322 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 323 | stride=2, padding=1, bias=False), 324 | BatchNorm2d(planes * 4, momentum=bn_mom), 325 | nn.ReLU(inplace=True), 326 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3, 327 | stride=2, padding=1, bias=False), 328 | BatchNorm2d(planes * 8, momentum=bn_mom), 329 | ) 330 | 331 | self.layer3_ = self._make_layer(block, planes * 2, highres_planes, 2) 332 | 333 | self.layer4_ = self._make_layer( 334 | block, highres_planes, highres_planes, 2) 335 | 336 | self.layer5_ = self._make_layer( 337 | Bottleneck, highres_planes, highres_planes, 1) 338 | 339 | self.layer5 = self._make_layer( 340 | Bottleneck, planes * 8, planes * 8, 1, stride=2) 341 | 342 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4) 343 | 344 | if self.augment: 345 | self.seghead_extra = segmenthead( 346 | highres_planes, head_planes, num_classes) 347 | 348 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 349 | 350 | for m in self.modules(): 351 | if isinstance(m, nn.Conv2d): 352 | nn.init.kaiming_normal_( 353 | m.weight, mode='fan_out', nonlinearity='relu') 354 | elif isinstance(m, BatchNorm2d): 355 | nn.init.constant_(m.weight, 1) 356 | nn.init.constant_(m.bias, 0) 357 | 358 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 359 | downsample = None 360 | if stride != 1 or inplanes != planes * block.expansion: 361 | downsample = nn.Sequential( 362 | nn.Conv2d(inplanes, planes * block.expansion, 363 | kernel_size=1, stride=stride, bias=False), 364 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom), 365 | ) 366 | 367 | layers = [] 368 | layers.append(block(inplanes, planes, stride, downsample)) 369 | inplanes = planes * block.expansion 370 | for i in range(1, blocks): 371 | if i == (blocks-1): 372 | layers.append(block(inplanes, planes, stride=1, no_relu=True)) 373 | else: 374 | layers.append(block(inplanes, planes, stride=1, no_relu=False)) 375 | 376 | return nn.Sequential(*layers) 377 | 378 | def forward(self, x): 379 | 380 | width_output_or = x.shape[-1] 381 | height_output_or = x.shape[-2] 382 | 383 | width_output = x.shape[-1] // 8 384 | height_output = x.shape[-2] // 8 385 | layers = [] 386 | 387 | x = self.conv1(x) 388 | 389 | x = self.layer1(x) 390 | layers.append(x) 391 | 392 | x = self.layer2(self.relu(x)) 393 | layers.append(x) 394 | 395 | x = self.layer3(self.relu(x)) 396 | layers.append(x) 397 | 398 | x_ = self.layer3_(self.relu(layers[1])) #### [4,128,128,128] 399 | 400 | x = x + self.down3(self.relu(x_)) ###[4,256,64,64] 401 | 402 | x_ = x_ + F.interpolate( 403 | self.compression3(self.relu(layers[2])), 404 | size=[height_output, width_output], 405 | mode='bilinear') ###[4,128,64,64] 406 | 407 | if self.augment: 408 | temp = x_ 409 | 410 | x = self.layer4(self.relu(x)) 411 | layers.append(x) 412 | x_ = self.layer4_(self.relu(x_)) 413 | 414 | x = x + self.down4(self.relu(x_)) 415 | x_ = x_ + F.interpolate( 416 | self.compression4(self.relu(layers[3])), 417 | size=[height_output, width_output], 418 | mode='bilinear') ###[4,128,128,128] 419 | 420 | ### high-resolution 1/8 apperance branch 421 | before_E = x_ 422 | x_ = self.layer5_(self.relu(x_)) 423 | x_a = x_ ###[4,256,128,128] 424 | after_E = x_ 425 | 426 | ### low resolution 1/64 branch, need upsampling, content branch 427 | x = self.layer5(self.relu(x)) 428 | before_I = F.interpolate( 429 | x, 430 | size=[height_output, width_output], 431 | mode='bilinear') 432 | x = F.interpolate( 433 | self.spp(x), 434 | size=[height_output, width_output], 435 | mode='bilinear') 436 | x_c = x ###[4,256,128,128] 437 | after_I = x 438 | 439 | x_, C, C_ = self.final_layer(x, x_) ### seghead [4,19,128,128] 440 | 441 | outputs = [] 442 | 443 | x_ = F.interpolate(x_, 444 | size=[height_output_or, width_output_or], 445 | mode='bilinear', align_corners=True) #[4,19,1024,1024] 446 | 447 | outputs.append(x_) 448 | 449 | if self.augment: 450 | x_extra = self.seghead_extra(temp) 451 | return [x_, x_extra] 452 | else: 453 | return tuple(outputs), C_, C 454 | 455 | 456 | def DualResNet_imagenet(pretrained=True): 457 | model = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19, 458 | planes=64, spp_planes=128, head_planes=128, augment=False) 459 | if pretrained: 460 | pretrained_state = torch.load( 461 | "D:/DDR/models/DDRNet23_imagenet.pth", map_location='cpu') 462 | model_dict = model.state_dict() 463 | pretrained_state = {k: v for k, v in pretrained_state.items() if 464 | (k in model_dict and v.shape == model_dict[k].shape)} 465 | model_dict.update(pretrained_state) 466 | 467 | model.load_state_dict(model_dict, strict=False) 468 | print("Having loaded imagenet-pretrained weights successfully!") 469 | 470 | return model 471 | 472 | 473 | def get_ddrnet_23(pretrained=True): 474 | 475 | model = DualResNet_imagenet(pretrained=pretrained) 476 | return model 477 | 478 | 479 | ### C-leaner & A-learner later part 480 | 481 | class CAinteract(nn.Module): 482 | 483 | def __init__(self, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False): 484 | super(CAinteract, self).__init__() 485 | 486 | self.augment = augment 487 | highres_planes = planes * 2 488 | 489 | ### for C representation branch. feature refine and attention weight matrix 490 | self.conv1 = nn.Sequential( 491 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0), 492 | BatchNorm2d(256, momentum=bn_mom), 493 | nn.ReLU(inplace=True), 494 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 495 | BatchNorm2d(1, momentum=bn_mom), 496 | nn.ReLU(inplace=True), 497 | ) 498 | 499 | ### for A representation branch. feature refine and attention weight matrix 500 | self.conv2 = nn.Sequential( 501 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0), 502 | BatchNorm2d(256, momentum=bn_mom), 503 | nn.ReLU(inplace=True), 504 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 505 | BatchNorm2d(1, momentum=bn_mom), 506 | nn.ReLU(inplace=True), 507 | ) 508 | 509 | self.relu = nn.ReLU(inplace=False) 510 | 511 | if self.augment: 512 | self.seghead_extra = segmenthead( 513 | highres_planes, head_planes, num_classes) 514 | 515 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 516 | 517 | 518 | for m in self.modules(): 519 | if isinstance(m, nn.Conv2d): 520 | nn.init.kaiming_normal_( 521 | m.weight, mode='fan_out', nonlinearity='relu') 522 | elif isinstance(m, BatchNorm2d): 523 | nn.init.constant_(m.weight, 1) 524 | nn.init.constant_(m.bias, 0) 525 | 526 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 527 | downsample = None 528 | if stride != 1 or inplanes != planes * block.expansion: 529 | downsample = nn.Sequential( 530 | nn.Conv2d(inplanes, planes * block.expansion, 531 | kernel_size=1, stride=stride, bias=False), 532 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom), 533 | ) 534 | 535 | layers = [] 536 | layers.append(block(inplanes, planes, stride, downsample)) 537 | inplanes = planes * block.expansion 538 | for i in range(1, blocks): 539 | if i == (blocks-1): 540 | layers.append(block(inplanes, planes, stride=1, no_relu=True)) 541 | else: 542 | layers.append(block(inplanes, planes, stride=1, no_relu=False)) 543 | 544 | return nn.Sequential(*layers) 545 | 546 | def forward(self, X_, C, A): 547 | 548 | # layers = [] 549 | height_output=1024 550 | width_output=1024 551 | 552 | attention_c = self.conv1(C) 553 | attention_a = self.conv2(A) 554 | 555 | C=C*attention_c+C 556 | A=A*attention_a+A 557 | 558 | x_ = self.final_layer(2*X_+ C + A) ### seghead [4,19,128,128] 559 | 560 | outputs = [] 561 | 562 | x_ = F.interpolate(x_, 563 | size=[height_output, width_output], 564 | mode='bilinear', align_corners=True) #[4,19,1024,1024] 565 | 566 | outputs.append(x_) 567 | 568 | return tuple(outputs), C, A 569 | 570 | 571 | ### C-A merge module 572 | 573 | class CAmerge(nn.Module): 574 | 575 | def __init__(self, planes=64, augment=False): 576 | super(CAmerge, self).__init__() 577 | 578 | highres_planes = planes * 2 579 | self.augment = augment 580 | 581 | #### for upsample 1/8 to 1/4 256->64 582 | 583 | self.conv1 = nn.Sequential( 584 | # nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=1), 585 | # BatchNorm2d(128, momentum=bn_mom), 586 | # nn.ReLU(inplace=True), 587 | nn.Conv2d(256, planes, kernel_size=3, stride=1, padding=1), 588 | BatchNorm2d(planes, momentum=bn_mom), 589 | nn.ReLU(inplace=True), 590 | ) 591 | 592 | ##### for upsample 1/4 to 1/2 64->32 593 | self.conv2 = nn.Sequential( 594 | nn.Conv2d(planes, 32, kernel_size=3, stride=1, padding=1), 595 | BatchNorm2d(32, momentum=bn_mom), 596 | nn.ReLU(inplace=True), 597 | ) 598 | 599 | ##### for upsample 1/2 to orginal 1/1 32->3 600 | self.conv3 = nn.Sequential( 601 | nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1), 602 | BatchNorm2d(3, momentum=bn_mom), 603 | nn.ReLU(inplace=True), 604 | ) 605 | 606 | self.relu = nn.ReLU(inplace=False) 607 | 608 | # if self.augment: 609 | # self.seghead_extra = segmenthead( 610 | # highres_planes, head_planes, num_classes) 611 | 612 | # self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 613 | 614 | for m in self.modules(): 615 | if isinstance(m, nn.Conv2d): 616 | nn.init.kaiming_normal_( 617 | m.weight, mode='fan_out', nonlinearity='relu') 618 | elif isinstance(m, BatchNorm2d): 619 | nn.init.constant_(m.weight, 1) 620 | nn.init.constant_(m.bias, 0) 621 | 622 | def forward(self, C, A): 623 | 624 | layers = [] 625 | 626 | ### C and A are from 1/8 resolution, input for reconstruction ###[4,256,128,128] -> [4,64,128,128] -> [4,64,256,256] 627 | x = self.conv1(C+A) 628 | layers.append(x) 629 | ## upsample 128->256 630 | x = F.interpolate( 631 | self.relu(layers[0]), 632 | size=[256,256], 633 | mode='bilinear') 634 | 635 | ###[4,64,256,256] -> [4,32,256,256] -> [4,32,512,512] 636 | x = self.conv2(self.relu(x)) 637 | # if x_41 has more channels than 64 (plane), then first need to add another layer to compress it to 64 channels 638 | layers.append(x) 639 | ## upsample 256->512 640 | x = F.interpolate( 641 | self.relu(layers[1]), 642 | size=[512,512], 643 | mode='bilinear') 644 | 645 | ###[4,32,512,512] -> [4,3,512,512] -> [4,3,1024,1024] 646 | 647 | x = self.conv3(self.relu(x)) 648 | layers.append(x) 649 | 650 | x = F.interpolate( 651 | self.relu(layers[2]), 652 | size=[1024, 1024], 653 | mode='bilinear') ###[4,128,64,64] 654 | 655 | if self.augment: 656 | temp = x 657 | 658 | outputs = [] 659 | 660 | outputs.append(x) 661 | 662 | if self.augment: 663 | x_extra = self.seghead_extra(temp) 664 | return [x, x_extra] 665 | else: 666 | return tuple(outputs), x 667 | 668 | def get_CA_interact( ): 669 | model = CAinteract( ) 670 | return model 671 | 672 | def get_CA_merge( ): 673 | model = CAmerge( ) 674 | return model 675 | 676 | if __name__ == "__main__": 677 | 678 | model = DualResNet_imagenet(pretrained=True) 679 | -------------------------------------------------------------------------------- /models/DDRNet_23_vis1.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | from collections import OrderedDict 8 | import random 9 | 10 | BatchNorm2d = nn.BatchNorm2d 11 | bn_mom = 0.1 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | 20 | class BasicBlock(nn.Module): 21 | expansion = 1 22 | 23 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=False): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = conv3x3(inplanes, planes, stride) 26 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.conv2 = conv3x3(planes, planes) 29 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 30 | self.downsample = downsample 31 | self.stride = stride 32 | self.no_relu = no_relu 33 | 34 | def forward(self, x): 35 | residual = x 36 | 37 | out = self.conv1(x) 38 | out = self.bn1(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv2(out) 42 | out = self.bn2(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | 49 | if self.no_relu: 50 | return out 51 | else: 52 | return self.relu(out) 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 2 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None, no_relu=True): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = BatchNorm2d(planes, momentum=bn_mom) 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = BatchNorm2d(planes, momentum=bn_mom) 65 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, 66 | bias=False) 67 | self.bn3 = BatchNorm2d(planes * self.expansion, momentum=bn_mom) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | self.no_relu = no_relu 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | if self.no_relu: 92 | return out 93 | else: 94 | return self.relu(out) 95 | 96 | 97 | class DAPPM(nn.Module): 98 | def __init__(self, inplanes, branch_planes, outplanes): 99 | super(DAPPM, self).__init__() 100 | self.scale1 = nn.Sequential(nn.AvgPool2d(kernel_size=5, stride=2, padding=2), 101 | BatchNorm2d(inplanes, momentum=bn_mom), 102 | nn.ReLU(inplace=True), 103 | nn.Conv2d(inplanes, branch_planes, 104 | kernel_size=1, bias=False), 105 | ) 106 | self.scale2 = nn.Sequential(nn.AvgPool2d(kernel_size=9, stride=4, padding=4), 107 | BatchNorm2d(inplanes, momentum=bn_mom), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(inplanes, branch_planes, 110 | kernel_size=1, bias=False), 111 | ) 112 | self.scale3 = nn.Sequential(nn.AvgPool2d(kernel_size=17, stride=8, padding=8), 113 | BatchNorm2d(inplanes, momentum=bn_mom), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(inplanes, branch_planes, 116 | kernel_size=1, bias=False), 117 | ) 118 | self.scale4 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 119 | BatchNorm2d(inplanes, momentum=bn_mom), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(inplanes, branch_planes, 122 | kernel_size=1, bias=False), 123 | ) 124 | self.scale0 = nn.Sequential( 125 | BatchNorm2d(inplanes, momentum=bn_mom), 126 | nn.ReLU(inplace=True), 127 | nn.Conv2d(inplanes, branch_planes, kernel_size=1, bias=False), 128 | ) 129 | self.process1 = nn.Sequential( 130 | BatchNorm2d(branch_planes, momentum=bn_mom), 131 | nn.ReLU(inplace=True), 132 | nn.Conv2d(branch_planes, branch_planes, 133 | kernel_size=3, padding=1, bias=False), 134 | ) 135 | self.process2 = nn.Sequential( 136 | BatchNorm2d(branch_planes, momentum=bn_mom), 137 | nn.ReLU(inplace=True), 138 | nn.Conv2d(branch_planes, branch_planes, 139 | kernel_size=3, padding=1, bias=False), 140 | ) 141 | self.process3 = nn.Sequential( 142 | BatchNorm2d(branch_planes, momentum=bn_mom), 143 | nn.ReLU(inplace=True), 144 | nn.Conv2d(branch_planes, branch_planes, 145 | kernel_size=3, padding=1, bias=False), 146 | ) 147 | self.process4 = nn.Sequential( 148 | BatchNorm2d(branch_planes, momentum=bn_mom), 149 | nn.ReLU(inplace=True), 150 | nn.Conv2d(branch_planes, branch_planes, 151 | kernel_size=3, padding=1, bias=False), 152 | ) 153 | self.compression = nn.Sequential( 154 | BatchNorm2d(branch_planes * 5, momentum=bn_mom), 155 | nn.ReLU(inplace=True), 156 | nn.Conv2d(branch_planes * 5, outplanes, kernel_size=1, bias=False), 157 | ) 158 | self.shortcut = nn.Sequential( 159 | BatchNorm2d(inplanes, momentum=bn_mom), 160 | nn.ReLU(inplace=True), 161 | nn.Conv2d(inplanes, outplanes, kernel_size=1, bias=False), 162 | ) 163 | 164 | def forward(self, x): 165 | 166 | # x = self.downsample(x) 167 | width = x.shape[-1] 168 | height = x.shape[-2] 169 | x_list = [] 170 | 171 | x_list.append(self.scale0(x)) 172 | x_list.append(self.process1((F.interpolate(self.scale1(x), 173 | size=[height, width], 174 | mode='bilinear')+x_list[0]))) 175 | x_list.append((self.process2((F.interpolate(self.scale2(x), 176 | size=[height, width], 177 | mode='bilinear')+x_list[1])))) 178 | x_list.append(self.process3((F.interpolate(self.scale3(x), 179 | size=[height, width], 180 | mode='bilinear')+x_list[2]))) 181 | x_list.append(self.process4((F.interpolate(self.scale4(x), 182 | size=[height, width], 183 | mode='bilinear')+x_list[3]))) 184 | 185 | out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x) 186 | return out 187 | 188 | 189 | class segmenthead(nn.Module): 190 | 191 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None): 192 | super(segmenthead, self).__init__() 193 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom) 194 | self.conv1 = nn.Conv2d(inplanes, interplanes, 195 | kernel_size=3, padding=1, bias=False) 196 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom) 197 | self.relu = nn.ReLU(inplace=True) 198 | self.conv2 = nn.Conv2d(interplanes, outplanes, 199 | kernel_size=1, padding=0, bias=True) 200 | self.scale_factor = scale_factor 201 | 202 | ### for C representation branch. feature refine and attention weight matrix 203 | self.conv3 = nn.Sequential( 204 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 205 | BatchNorm2d(1, momentum=bn_mom), 206 | nn.ReLU(inplace=True), 207 | ) 208 | 209 | ### for A representation branch. feature refine and attention weight matrix 210 | self.conv4 = nn.Sequential( 211 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 212 | BatchNorm2d(1, momentum=bn_mom), 213 | nn.ReLU(inplace=True), 214 | ) 215 | 216 | self.conv5 = nn.Sequential( 217 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0), 218 | BatchNorm2d(19, momentum=bn_mom), 219 | nn.ReLU(inplace=True), 220 | ) 221 | 222 | self.conv6 = nn.Sequential( 223 | nn.Conv2d(256, 19, kernel_size=1, stride=1, padding=0), 224 | BatchNorm2d(19, momentum=bn_mom), 225 | nn.ReLU(inplace=True), 226 | ) 227 | 228 | def forward(self, C, A): 229 | 230 | attention_c=self.conv3(C) 231 | attention_a=self.conv4(A) 232 | 233 | C_ = 2*C + attention_c *C + A *attention_c 234 | A_ =A + attention_a *A 235 | 236 | x = self.conv1(self.relu(self.bn1(C_+A_))) 237 | out = self.conv2(self.relu(self.bn2(x))) 238 | 239 | Cupdate=self.conv5(C) 240 | C_update=self.conv6(C_) 241 | 242 | if self.scale_factor is not None: 243 | height = x.shape[-2] * self.scale_factor 244 | width = x.shape[-1] * self.scale_factor 245 | out = F.interpolate(out, 246 | size=[height, width], 247 | mode='bilinear') 248 | 249 | return out, C_update, Cupdate 250 | 251 | 252 | class segmentheadold(nn.Module): 253 | 254 | def __init__(self, inplanes, interplanes, outplanes, scale_factor=None): 255 | super(segmenthead, self).__init__() 256 | self.bn1 = BatchNorm2d(inplanes, momentum=bn_mom) 257 | self.conv1 = nn.Conv2d(inplanes, interplanes, 258 | kernel_size=3, padding=1, bias=False) 259 | self.bn2 = BatchNorm2d(interplanes, momentum=bn_mom) 260 | self.relu = nn.ReLU(inplace=True) 261 | self.conv2 = nn.Conv2d(interplanes, outplanes, 262 | kernel_size=1, padding=0, bias=True) 263 | self.scale_factor = scale_factor 264 | 265 | def forward(self, x): 266 | 267 | x = self.conv1(self.relu(self.bn1(x))) 268 | out = self.conv2(self.relu(self.bn2(x))) 269 | 270 | if self.scale_factor is not None: 271 | height = x.shape[-2] * self.scale_factor 272 | width = x.shape[-1] * self.scale_factor 273 | out = F.interpolate(out, 274 | size=[height, width], 275 | mode='bilinear') 276 | 277 | return out 278 | 279 | 280 | class DualResNet(nn.Module): 281 | 282 | def __init__(self, block, layers, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False): 283 | super(DualResNet, self).__init__() 284 | 285 | highres_planes = planes * 2 286 | self.augment = augment 287 | 288 | self.conv1 = nn.Sequential( 289 | nn.Conv2d(3, planes, kernel_size=3, stride=2, padding=1), 290 | BatchNorm2d(planes, momentum=bn_mom), 291 | nn.ReLU(inplace=True), 292 | nn.Conv2d(planes, planes, kernel_size=3, stride=2, padding=1), 293 | BatchNorm2d(planes, momentum=bn_mom), 294 | nn.ReLU(inplace=True), 295 | ) 296 | 297 | self.relu = nn.ReLU(inplace=False) 298 | self.layer1 = self._make_layer(block, planes, planes, layers[0]) 299 | self.layer2 = self._make_layer( 300 | block, planes, planes * 2, layers[1], stride=2) 301 | self.layer3 = self._make_layer( 302 | block, planes * 2, planes * 4, layers[2], stride=2) 303 | self.layer4 = self._make_layer( 304 | block, planes * 4, planes * 8, layers[3], stride=2) 305 | 306 | self.compression3 = nn.Sequential( 307 | nn.Conv2d(planes * 4, highres_planes, kernel_size=1, bias=False), 308 | BatchNorm2d(highres_planes, momentum=bn_mom), 309 | ) 310 | 311 | self.compression4 = nn.Sequential( 312 | nn.Conv2d(planes * 8, highres_planes, kernel_size=1, bias=False), 313 | BatchNorm2d(highres_planes, momentum=bn_mom), 314 | ) 315 | 316 | self.down3 = nn.Sequential( 317 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 318 | stride=2, padding=1, bias=False), 319 | BatchNorm2d(planes * 4, momentum=bn_mom), 320 | ) 321 | 322 | self.down4 = nn.Sequential( 323 | nn.Conv2d(highres_planes, planes * 4, kernel_size=3, 324 | stride=2, padding=1, bias=False), 325 | BatchNorm2d(planes * 4, momentum=bn_mom), 326 | nn.ReLU(inplace=True), 327 | nn.Conv2d(planes * 4, planes * 8, kernel_size=3, 328 | stride=2, padding=1, bias=False), 329 | BatchNorm2d(planes * 8, momentum=bn_mom), 330 | ) 331 | 332 | self.layer3_ = self._make_layer(block, planes * 2, highres_planes, 2) 333 | 334 | self.layer4_ = self._make_layer( 335 | block, highres_planes, highres_planes, 2) 336 | 337 | self.layer5_ = self._make_layer( 338 | Bottleneck, highres_planes, highres_planes, 1) 339 | 340 | self.layer5 = self._make_layer( 341 | Bottleneck, planes * 8, planes * 8, 1, stride=2) 342 | 343 | self.spp = DAPPM(planes * 16, spp_planes, planes * 4) 344 | 345 | if self.augment: 346 | self.seghead_extra = segmenthead( 347 | highres_planes, head_planes, num_classes) 348 | 349 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 350 | 351 | for m in self.modules(): 352 | if isinstance(m, nn.Conv2d): 353 | nn.init.kaiming_normal_( 354 | m.weight, mode='fan_out', nonlinearity='relu') 355 | elif isinstance(m, BatchNorm2d): 356 | nn.init.constant_(m.weight, 1) 357 | nn.init.constant_(m.bias, 0) 358 | 359 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 360 | downsample = None 361 | if stride != 1 or inplanes != planes * block.expansion: 362 | downsample = nn.Sequential( 363 | nn.Conv2d(inplanes, planes * block.expansion, 364 | kernel_size=1, stride=stride, bias=False), 365 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom), 366 | ) 367 | 368 | layers = [] 369 | layers.append(block(inplanes, planes, stride, downsample)) 370 | inplanes = planes * block.expansion 371 | for i in range(1, blocks): 372 | if i == (blocks-1): 373 | layers.append(block(inplanes, planes, stride=1, no_relu=True)) 374 | else: 375 | layers.append(block(inplanes, planes, stride=1, no_relu=False)) 376 | 377 | return nn.Sequential(*layers) 378 | 379 | def forward(self, x): 380 | 381 | width_output_or = x.shape[-1] 382 | height_output_or = x.shape[-2] 383 | 384 | width_output = x.shape[-1] // 8 385 | height_output = x.shape[-2] // 8 386 | layers = [] 387 | 388 | x = self.conv1(x) 389 | 390 | x = self.layer1(x) 391 | layers.append(x) 392 | 393 | x = self.layer2(self.relu(x)) 394 | layers.append(x) 395 | 396 | x = self.layer3(self.relu(x)) 397 | layers.append(x) 398 | 399 | x_ = self.layer3_(self.relu(layers[1])) #### [4,128,128,128] 400 | 401 | x = x + self.down3(self.relu(x_)) ###[4,256,64,64] 402 | 403 | x_ = x_ + F.interpolate( 404 | self.compression3(self.relu(layers[2])), 405 | size=[height_output, width_output], 406 | mode='bilinear') ###[4,128,64,64] 407 | 408 | if self.augment: 409 | temp = x_ 410 | 411 | x = self.layer4(self.relu(x)) 412 | layers.append(x) 413 | x_ = self.layer4_(self.relu(x_)) 414 | 415 | x = x + self.down4(self.relu(x_)) 416 | x_ = x_ + F.interpolate( 417 | self.compression4(self.relu(layers[3])), 418 | size=[height_output, width_output], 419 | mode='bilinear') ###[4,128,128,128] 420 | 421 | ### high-resolution 1/8 apperance branch 422 | before_E = x_ 423 | x_ = self.layer5_(self.relu(x_)) 424 | x_a = x_ ###[4,256,128,128] 425 | after_E = x_ 426 | 427 | ### low resolution 1/64 branch, need upsampling, content branch 428 | x = self.layer5(self.relu(x)) 429 | before_I = F.interpolate( 430 | x, 431 | size=[height_output, width_output], 432 | mode='bilinear') 433 | x = F.interpolate( 434 | self.spp(x), 435 | size=[height_output, width_output], 436 | mode='bilinear') 437 | x_c = x ###[4,256,128,128] 438 | after_I = x 439 | 440 | # tsne 441 | # before_Is = torch.mean(torch.mean(before_I, axis=1), 1) 442 | # after_Is = torch.mean(torch.mean(after_I, axis=1), 1) 443 | # before_Es = torch.mean(torch.mean(before_E, axis=1), 1) 444 | # after_Es = torch.mean(torch.mean(after_E, axis=1), 1) 445 | 446 | # heatmap 447 | # before_Is = torch.mean(before_I, axis=1) # 1024, 128, 256 448 | # after_Is = torch.mean(after_I, axis=1) 449 | # before_Es = torch.mean(before_E, axis=1) 450 | # after_Es = torch.mean(after_E, axis=1) 451 | np.random.seed(123) 452 | random_index1 = np.arange(before_I.shape[1]) 453 | random_index2 = np.arange(after_I.shape[1]) 454 | random_index3 = np.arange(before_E.shape[1]) 455 | random_index4 = np.arange(after_E.shape[1]) 456 | random.shuffle(random_index1) 457 | random.shuffle(random_index2) 458 | random.shuffle(random_index3) 459 | random.shuffle(random_index4) 460 | random_index1 = random_index1[:20] 461 | random_index2 = random_index2[:20] 462 | random_index3 = random_index3[:20] 463 | random_index4 = random_index4[:20] 464 | for i, ind in enumerate(random_index1): 465 | if i == 0: 466 | before_Is = before_I[:, ind, :, :] 467 | else: 468 | before_Is = torch.cat((before_Is, before_I[:, ind, :, :]), 0) 469 | for i, ind in enumerate(random_index2): 470 | if i == 0: 471 | after_Is = after_I[:, ind, :, :] 472 | else: 473 | after_Is = torch.cat((after_Is, after_I[:, ind, :, :]), 0) 474 | for i, ind in enumerate(random_index3): 475 | if i == 0: 476 | before_Es = before_E[:, ind, :, :] 477 | else: 478 | before_Es = torch.cat((before_Es, before_E[:, ind, :, :]), 0) 479 | for i, ind in enumerate(random_index4): 480 | if i == 0: 481 | after_Es = after_E[:, ind, :, :] 482 | else: 483 | after_Es = torch.cat((after_Es, after_E[:, ind, :, :]), 0) 484 | 485 | return before_Is, after_Is, before_Es, after_Es 486 | 487 | x_, C, C_ = self.final_layer(x, x_) ### seghead [4,19,128,128] 488 | 489 | outputs = [] 490 | 491 | x_ = F.interpolate(x_, 492 | size=[height_output_or, width_output_or], 493 | mode='bilinear', align_corners=True) #[4,19,1024,1024] 494 | 495 | outputs.append(x_) 496 | 497 | if self.augment: 498 | x_extra = self.seghead_extra(temp) 499 | return [x_, x_extra] 500 | else: 501 | return tuple(outputs), C_, C 502 | 503 | 504 | def DualResNet_imagenet(pretrained=True): 505 | model = DualResNet(BasicBlock, [2, 2, 2, 2], num_classes=19, 506 | planes=64, spp_planes=128, head_planes=128, augment=False) 507 | if pretrained: 508 | pretrained_state = torch.load( 509 | "D:/DDR/models/DDRNet23_imagenet.pth", map_location='cpu') 510 | model_dict = model.state_dict() 511 | pretrained_state = {k: v for k, v in pretrained_state.items() if 512 | (k in model_dict and v.shape == model_dict[k].shape)} 513 | model_dict.update(pretrained_state) 514 | 515 | model.load_state_dict(model_dict, strict=False) 516 | print("Having loaded imagenet-pretrained weights successfully!") 517 | 518 | return model 519 | 520 | 521 | def get_ddrnet_23_vis1(pretrained=True): 522 | 523 | model = DualResNet_imagenet(pretrained=pretrained) 524 | return model 525 | 526 | 527 | ### C-leaner & A-learner later part 528 | 529 | class CAinteract(nn.Module): 530 | 531 | def __init__(self, num_classes=19, planes=64, spp_planes=128, head_planes=128, augment=False): 532 | super(CAinteract, self).__init__() 533 | 534 | self.augment = augment 535 | highres_planes = planes * 2 536 | 537 | ### for C representation branch. feature refine and attention weight matrix 538 | self.conv1 = nn.Sequential( 539 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0), 540 | BatchNorm2d(256, momentum=bn_mom), 541 | nn.ReLU(inplace=True), 542 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 543 | BatchNorm2d(1, momentum=bn_mom), 544 | nn.ReLU(inplace=True), 545 | ) 546 | 547 | ### for A representation branch. feature refine and attention weight matrix 548 | self.conv2 = nn.Sequential( 549 | nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0), 550 | BatchNorm2d(256, momentum=bn_mom), 551 | nn.ReLU(inplace=True), 552 | nn.Conv2d(256, 1, kernel_size=1, stride=1, padding=0), 553 | BatchNorm2d(1, momentum=bn_mom), 554 | nn.ReLU(inplace=True), 555 | ) 556 | 557 | self.relu = nn.ReLU(inplace=False) 558 | 559 | if self.augment: 560 | self.seghead_extra = segmenthead( 561 | highres_planes, head_planes, num_classes) 562 | 563 | self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 564 | 565 | 566 | for m in self.modules(): 567 | if isinstance(m, nn.Conv2d): 568 | nn.init.kaiming_normal_( 569 | m.weight, mode='fan_out', nonlinearity='relu') 570 | elif isinstance(m, BatchNorm2d): 571 | nn.init.constant_(m.weight, 1) 572 | nn.init.constant_(m.bias, 0) 573 | 574 | def _make_layer(self, block, inplanes, planes, blocks, stride=1): 575 | downsample = None 576 | if stride != 1 or inplanes != planes * block.expansion: 577 | downsample = nn.Sequential( 578 | nn.Conv2d(inplanes, planes * block.expansion, 579 | kernel_size=1, stride=stride, bias=False), 580 | nn.BatchNorm2d(planes * block.expansion, momentum=bn_mom), 581 | ) 582 | 583 | layers = [] 584 | layers.append(block(inplanes, planes, stride, downsample)) 585 | inplanes = planes * block.expansion 586 | for i in range(1, blocks): 587 | if i == (blocks-1): 588 | layers.append(block(inplanes, planes, stride=1, no_relu=True)) 589 | else: 590 | layers.append(block(inplanes, planes, stride=1, no_relu=False)) 591 | 592 | return nn.Sequential(*layers) 593 | 594 | def forward(self, X_, C, A): 595 | 596 | # layers = [] 597 | height_output=1024 598 | width_output=1024 599 | 600 | attention_c = self.conv1(C) 601 | attention_a = self.conv2(A) 602 | 603 | C=C*attention_c+C 604 | A=A*attention_a+A 605 | 606 | x_ = self.final_layer(2*X_+ C + A) ### seghead [4,19,128,128] 607 | 608 | outputs = [] 609 | 610 | x_ = F.interpolate(x_, 611 | size=[height_output, width_output], 612 | mode='bilinear', align_corners=True) #[4,19,1024,1024] 613 | 614 | outputs.append(x_) 615 | 616 | return tuple(outputs), C, A 617 | 618 | 619 | ### C-A merge module 620 | 621 | class CAmerge(nn.Module): 622 | 623 | def __init__(self, planes=64, augment=False): 624 | super(CAmerge, self).__init__() 625 | 626 | highres_planes = planes * 2 627 | self.augment = augment 628 | 629 | #### for upsample 1/8 to 1/4 256->64 630 | 631 | self.conv1 = nn.Sequential( 632 | # nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=1), 633 | # BatchNorm2d(128, momentum=bn_mom), 634 | # nn.ReLU(inplace=True), 635 | nn.Conv2d(256, planes, kernel_size=3, stride=1, padding=1), 636 | BatchNorm2d(planes, momentum=bn_mom), 637 | nn.ReLU(inplace=True), 638 | ) 639 | 640 | ##### for upsample 1/4 to 1/2 64->32 641 | self.conv2 = nn.Sequential( 642 | nn.Conv2d(planes, 32, kernel_size=3, stride=1, padding=1), 643 | BatchNorm2d(32, momentum=bn_mom), 644 | nn.ReLU(inplace=True), 645 | ) 646 | 647 | ##### for upsample 1/2 to orginal 1/1 32->3 648 | self.conv3 = nn.Sequential( 649 | nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1), 650 | BatchNorm2d(3, momentum=bn_mom), 651 | nn.ReLU(inplace=True), 652 | ) 653 | 654 | self.relu = nn.ReLU(inplace=False) 655 | 656 | # if self.augment: 657 | # self.seghead_extra = segmenthead( 658 | # highres_planes, head_planes, num_classes) 659 | 660 | # self.final_layer = segmenthead(planes * 4, head_planes, num_classes) 661 | 662 | for m in self.modules(): 663 | if isinstance(m, nn.Conv2d): 664 | nn.init.kaiming_normal_( 665 | m.weight, mode='fan_out', nonlinearity='relu') 666 | elif isinstance(m, BatchNorm2d): 667 | nn.init.constant_(m.weight, 1) 668 | nn.init.constant_(m.bias, 0) 669 | 670 | def forward(self, C, A): 671 | 672 | layers = [] 673 | 674 | ### C and A are from 1/8 resolution, input for reconstruction ###[4,256,128,128] -> [4,64,128,128] -> [4,64,256,256] 675 | x = self.conv1(C+A) 676 | layers.append(x) 677 | ## upsample 128->256 678 | x = F.interpolate( 679 | self.relu(layers[0]), 680 | size=[256,256], 681 | mode='bilinear') 682 | 683 | ###[4,64,256,256] -> [4,32,256,256] -> [4,32,512,512] 684 | x = self.conv2(self.relu(x)) 685 | # if x_41 has more channels than 64 (plane), then first need to add another layer to compress it to 64 channels 686 | layers.append(x) 687 | ## upsample 256->512 688 | x = F.interpolate( 689 | self.relu(layers[1]), 690 | size=[512,512], 691 | mode='bilinear') 692 | 693 | ###[4,32,512,512] -> [4,3,512,512] -> [4,3,1024,1024] 694 | 695 | x = self.conv3(self.relu(x)) 696 | layers.append(x) 697 | 698 | x = F.interpolate( 699 | self.relu(layers[2]), 700 | size=[1024, 1024], 701 | mode='bilinear') ###[4,128,64,64] 702 | 703 | if self.augment: 704 | temp = x 705 | 706 | outputs = [] 707 | 708 | outputs.append(x) 709 | 710 | if self.augment: 711 | x_extra = self.seghead_extra(temp) 712 | return [x, x_extra] 713 | else: 714 | return tuple(outputs), x 715 | 716 | def get_CA_interact( ): 717 | model = CAinteract( ) 718 | return model 719 | 720 | def get_CA_merge( ): 721 | model = CAmerge( ) 722 | return model 723 | 724 | if __name__ == "__main__": 725 | 726 | model = DualResNet_imagenet(pretrained=True) 727 | --------------------------------------------------------------------------------