├── README.md
├── base
│   └── base_model.py
├── data_folder.py
├── data_prepare
│   ├── SegFix_offset_helper.py
│   ├── getDirectionDiffMap.py
│   └── logger.py
├── hhl_utils
│   ├── helpers.py
│   ├── pytorch_ssim.py
│   ├── radam.py
│   ├── ranger.py
│   └── torchsummary.py
├── loss.py
├── models
│   ├── FullNet.py
│   ├── dam
│   │   ├── model_unet_MandD.py
│   │   ├── model_unet_MandD16.py
│   │   ├── model_unet_MandD4.py
│   │   ├── model_unet_MandDandP.py
│   │   ├── model_unet_rev1.py
│   │   └── seg_hrnet_rev1.py
│   ├── deeplabv3_plus.py
│   ├── fcn8.py
│   ├── model_unet.py
│   ├── pspnet.py
│   ├── seg_hrnet.py
│   ├── segnet.py
│   └── unet.py
├── my_transforms.py
├── my_transforms_direction.py
├── options.py
├── postproc_other.py
├── stats_utils.py
├── test.py
├── test_dam.py
├── train.py
├── train_util.py
├── train_util_dam.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # CDNet: Centripetal Direction Network for Nuclear Instance Segmentation
2 | 
3 | 
4 | [[`ICCV2021`](https://openaccess.thecvf.com/content/ICCV2021/papers/He_CDNet_Centripetal_Direction_Network_for_Nuclear_Instance_Segmentation_ICCV_2021_paper.pdf)]
5 | 
6 | The code includes training and inference procedures for CDNet.
7 | 
8 | Tips:
9 | The U-Net result in Table 4 of the original paper was written incorrectly.
10 | The correct result is:
11 | 
12 | ### MoNuSeg
13 | 
14 | 
| Method Name | Dice | AJI |
| ----------- | ------ | ------ |
| U-Net | 0.8184 | 0.5910 |
| Mask-RCNN | 0.7600 | 0.5460 |
| DCAN | 0.7920 | 0.5250 |
| Micro-Net | 0.7970 | 0.5600 |
| DIST | 0.7890 | 0.5590 |
| CIA-Net | 0.8180 | 0.6200 |
| U-Net | 0.8027 | 0.6039 |
| Hover-Net | 0.8260 | 0.6180 |
| BRP-Net | - | 0.6422 |
| PFF-Net | 0.8091 | 0.6107 |
| Our CDNet | 0.8316 | 0.6331 |
77 | 
78 | 
79 | ## Getting Started
80 | #### Create a data folder (/data) and put the datasets (MoNuSeg, CPM17) in it.
81 | 
82 | #### Train
83 | ```
84 | cd CDNet/
85 | python train.py
86 | ```
87 | 
88 | #### Test
89 | ```
90 | cd CDNet/
91 | python test.py
92 | ```
93 | 
94 | 
95 | 
96 | 
97 | 
98 | 
99 | 
100 | 
101 | 
102 | 
103 | 
104 | 
105 | 
106 | 
107 | 
--------------------------------------------------------------------------------
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | import numpy as np
4 | 
5 | 
6 | class BaseModel(nn.Module):
7 |     def __init__(self):
8 |         super(BaseModel, self).__init__()
9 |         self.logger = logging.getLogger(self.__class__.__name__)
10 | 
11 |     def forward(self):
12 |         raise NotImplementedError
13 | 
14 |     def summary(self):
15 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
16 |         nbr_params = sum([np.prod(p.size()) for p in model_parameters])
17 |         self.logger.info(f'Nbr of trainable parameters: {nbr_params}')
18 | 
19 |     def __str__(self):
20 |         model_parameters = filter(lambda p: p.requires_grad, self.parameters())
21 |         nbr_params = sum([np.prod(p.size()) for p in model_parameters])
22 |         return super(BaseModel, self).__str__() + f'\nNbr of trainable parameters: {nbr_params}'
23 |         # return summary(self, input_shape=(2, 3, 224, 224))
--------------------------------------------------------------------------------
/data_folder.py:
--------------------------------------------------------------------------------
1 | 
2 | import torch.utils.data as data
3 | import os
4 | from PIL import Image
5 | import numpy as np
6 | import scipy.io as scio
7 | import torch
8 | # from skimage import morphology, io
9 | 
10 | IMG_EXTENSIONS = [
11 |     '.jpg', '.JPG', '.jpeg', '.JPEG',
12 |     '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
13 | ]
14 | 
15 | 
16 | def is_image_file(filename):
17 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
18 | 
19 | 
20 | def img_loader(path, num_channels):
21 | 
22 | 
23 |     if num_channels == 1:
24 |         if ('.mat' in path):
25 |             img = scio.loadmat(path)['inst_map']
26 |             img = Image.fromarray(img.astype(np.uint8))
27 |         elif ('.npy' in path):
28 |             img = np.load(path)
29 |             img = Image.fromarray(img.astype(np.uint8))
30 |         else:
31 |             img = Image.open(path)
32 |     else:
33 |         if ('.mat' in path):
34 |             img = scio.loadmat(path)['inst_map']
35 |         elif ('.npy' in path):
36 |             img = np.load(path)
37 |             img = Image.fromarray(img.astype(np.uint8))
38 |         else:
39 |             img = Image.open(path).convert('RGB')
40 | 
41 |     return img
42 | 
43 | 
44 | # get the image list pairs
45 | def get_imgs_list(dir_list, post_fix=None):
46 |     """
47 |     :param dir_list: [img1_dir, img2_dir, ...]
48 |     :param post_fix: e.g. ['label.png', 'weight.png', ...]
49 |     :return: e.g. [(img1.ext, img1_label.png, img1_weight.png), ...]
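    For example (hypothetical paths; each label file is assumed to be named '<image-name>_<post_fix>'):
        get_imgs_list(['data/train/images', 'data/train/labels'], post_fix=['label.png'])
        # -> [('data/train/images/img1.png', 'data/train/labels/img1_label.png'), ...]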
50 | """ 51 | img_list = [] 52 | if len(dir_list) == 0: 53 | return img_list 54 | if len(dir_list) != len(post_fix) + 1: 55 | raise (RuntimeError('Should specify the postfix of each img type except the first input.')) 56 | 57 | img_filename_list = [os.listdir(dir_list[i]) for i in range(len(dir_list))] 58 | 59 | for img in img_filename_list[0]: 60 | if not is_image_file(img): 61 | continue 62 | img1_name = os.path.splitext(img)[0] 63 | item = [os.path.join(dir_list[0], img),] 64 | for i in range(1, len(img_filename_list)): 65 | img_name = '{:s}_{:s}'.format(img1_name, post_fix[i-1]) 66 | if img_name in img_filename_list[i]: 67 | img_path = os.path.join(dir_list[i], img_name) 68 | item.append(img_path) 69 | 70 | if len(item) == len(dir_list): 71 | img_list.append(tuple(item)) 72 | 73 | return img_list 74 | 75 | 76 | 77 | # dataset that supports one input image, one target image, and one weight map (optional) 78 | class DataFolder(data.Dataset): 79 | def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader): 80 | super(DataFolder, self).__init__() 81 | if len(dir_list) != len(post_fix) + 1: 82 | raise (RuntimeError('Length of dir_list is different from length of post_fix + 1.')) 83 | if len(dir_list) != len(num_channels): 84 | raise (RuntimeError('Length of dir_list is different from length of num_channels.')) 85 | 86 | self.img_list = get_imgs_list(dir_list, post_fix) 87 | if len(self.img_list) == 0: 88 | raise(RuntimeError('Found 0 image pairs in given directories.')) 89 | 90 | self.data_transform = data_transform 91 | self.num_channels = num_channels 92 | self.loader = loader 93 | 94 | def __getitem__(self, index): 95 | img_paths = self.img_list[index] 96 | 97 | sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))] 98 | 99 | if self.data_transform is not None: 100 | sample_tensor = self.data_transform(sample) 101 | 102 | 103 | while(len(torch.unique(sample_tensor[2]))<=1): # sample[2].detach().cpu().numpy() 104 | if self.data_transform is not None: 105 | sample_tensor = self.data_transform(sample) 106 | 107 | return sample_tensor 108 | 109 | def __len__(self): 110 | return len(self.img_list) 111 | 112 | -------------------------------------------------------------------------------- /data_prepare/getDirectionDiffMap.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 2021/6/7 4 | 5 | @author: he 6 | """ 7 | 8 | 9 | from data_prepare.SegFix_offset_helper import DTOffsetHelper 10 | import numpy as np 11 | import torch 12 | 13 | 14 | def circshift(matrix_ori, direction, shiftnum1, shiftnum2): 15 | # direction = 1,2,3,4 # 偏移方向 1:左上; 2:右上; 3:左下; 4:右下; 16 | c, h, w = matrix_ori.shape 17 | matrix_new = np.zeros_like(matrix_ori) 18 | 19 | for k in range(c): 20 | matrix = matrix_ori[k] 21 | # matrix = matrix_ori[:,:,k] 22 | if (direction == 1): 23 | # 左上 24 | matrix = np.vstack((matrix[shiftnum1:, :], np.zeros_like(matrix[:shiftnum1, :]))) 25 | matrix = np.hstack((matrix[:, shiftnum2:], np.zeros_like(matrix[:, :shiftnum2]))) 26 | elif (direction == 2): 27 | # 右上 28 | matrix = np.vstack((matrix[shiftnum1:, :], np.zeros_like(matrix[:shiftnum1, :]))) 29 | matrix = np.hstack((np.zeros_like(matrix[:, (w - shiftnum2):]), matrix[:, :(w - shiftnum2)])) 30 | elif (direction == 3): 31 | # 左下 32 | matrix = np.vstack((np.zeros_like(matrix[(h - shiftnum1):, :]), matrix[:(h - shiftnum1), :])) 33 | matrix = np.hstack((matrix[:, shiftnum2:], 
                                np.zeros_like(matrix[:, :shiftnum2])))
34 |         elif (direction == 4):
35 |             # down-right
36 |             matrix = np.vstack((np.zeros_like(matrix[(h - shiftnum1):, :]), matrix[:(h - shiftnum1), :]))
37 |             matrix = np.hstack((np.zeros_like(matrix[:, (w - shiftnum2):]), matrix[:, :(w - shiftnum2)]))
38 |         # matrix_new[k] ==> matrix_new[:,:, k]
39 |         # matrix_new[:,:, k] = matrix
40 |         matrix_new[k] = matrix
41 | 
42 |     return matrix_new
43 | 
44 | def generate_dd_map(label_direction, direction_classes):
45 |     direction_offsets = DTOffsetHelper.label_to_vector(torch.from_numpy(label_direction.reshape(1, label_direction.shape[0], label_direction.shape[1])), direction_classes)
46 |     direction_offsets = direction_offsets[0].permute(1, 2, 0).detach().cpu().numpy()
47 | 
48 |     direction_os = direction_offsets  # [256, 256, 2]
49 | 
50 |     height, weight = direction_os.shape[0], direction_os.shape[1]
51 | 
52 |     cos_sim_map = np.zeros((height, weight), dtype=np.float32)
53 | 
54 |     feature_list = []
55 |     feature5 = direction_os  # .transpose(1, 2, 0)
56 |     if (direction_classes - 1 == 4):
57 |         direction_os = direction_os.transpose(2, 0, 1)
58 |         feature2 = circshift(direction_os, 1, 1, 0).transpose(1, 2, 0)
59 |         feature4 = circshift(direction_os, 3, 0, 1).transpose(1, 2, 0)
60 |         feature6 = circshift(direction_os, 4, 0, 1).transpose(1, 2, 0)
61 |         feature8 = circshift(direction_os, 3, 1, 0).transpose(1, 2, 0)
62 | 
63 |         feature_list.append(feature2)
64 |         feature_list.append(feature4)
65 |         # feature_list.append(feature5)
66 |         feature_list.append(feature6)
67 |         feature_list.append(feature8)
68 | 
69 |     elif (direction_classes - 1 == 8 or direction_classes - 1 == 16):
70 |         direction_os = direction_os.transpose(2, 0, 1)  # [2, 256, 256]
71 |         feature1 = circshift(direction_os, 1, 1, 1).transpose(1, 2, 0)
72 |         feature2 = circshift(direction_os, 1, 1, 0).transpose(1, 2, 0)
73 |         feature3 = circshift(direction_os, 2, 1, 1).transpose(1, 2, 0)
74 |         feature4 = circshift(direction_os, 3, 0, 1).transpose(1, 2, 0)
75 |         feature6 = circshift(direction_os, 4, 0, 1).transpose(1, 2, 0)
76 |         feature7 = circshift(direction_os, 3, 1, 1).transpose(1, 2, 0)
77 |         feature8 = circshift(direction_os, 3, 1, 0).transpose(1, 2, 0)
78 |         feature9 = circshift(direction_os, 4, 1, 1).transpose(1, 2, 0)
79 | 
80 |         feature_list.append(feature1)
81 |         feature_list.append(feature2)
82 |         feature_list.append(feature3)
83 |         feature_list.append(feature4)
84 |         # feature_list.append(feature5)
85 |         feature_list.append(feature6)
86 |         feature_list.append(feature7)
87 |         feature_list.append(feature8)
88 |         feature_list.append(feature9)
89 | 
90 |     cos_value = np.zeros((height, weight, direction_classes - 1), dtype=np.float32)
91 |     # print('cos_value.shape = {}'.format(cos_value.shape))
92 |     for k, feature_item in enumerate(feature_list):
93 |         numerator = (feature5[:, :, 0] * feature_item[:, :, 0] + feature5[:, :, 1] * feature_item[:, :, 1])
94 |         denominator = (np.sqrt(pow(feature5[:, :, 0], 2) + pow(feature5[:, :, 1], 2)) * np.sqrt(
95 |             pow(feature_item[:, :, 0], 2) + pow(feature_item[:, :, 1], 2)) + 0.000001)
96 |         cos_np = numerator / denominator
97 |         cos_value[:, :, k] = cos_np
98 | 
99 |     cos_value_min = np.min(cos_value, axis=2)
100 |     cos_sim_map = cos_value_min
101 |     cos_sim_map[label_direction == 0] = 1
102 | 
103 |     cos_sim_map_np = (1 - np.around(cos_sim_map))
104 |     cos_sim_map_np_max = np.max(cos_sim_map_np)
105 |     cos_sim_map_np_min = np.min(cos_sim_map_np)
106 |     cos_sim_map_np_normal = (cos_sim_map_np - cos_sim_map_np_min) / (cos_sim_map_np_max - cos_sim_map_np_min)
107 | 
108 |     return cos_sim_map_np_normal
109 | 
110 | 
111 | 
112 
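A minimal usage sketch for `generate_dd_map` (my own illustration, not part of the repository): given an H x W direction-class label map (0 = background, 1..K = direction classes) and `direction_classes` = K + 1, it returns a direction-difference map normalized to [0, 1] that is high where neighboring pixels point in conflicting centripetal directions.

```
# Hedged example: assumes an 8-direction setup (direction_classes = 8 + 1) and that
# DTOffsetHelper.label_to_vector accepts a label map built this way.
import numpy as np
from data_prepare.getDirectionDiffMap import generate_dd_map

label_direction = np.random.randint(0, 9, size=(256, 256))  # 0 = background
dd_map = generate_dd_map(label_direction, direction_classes=9)
print(dd_map.shape)  # (256, 256), values in [0, 1]
```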
| 113 | -------------------------------------------------------------------------------- /data_prepare/logger.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #!/usr/bin/env python 4 | # -*- coding:utf-8 -*- 5 | # Author: Donny You(youansheng@gmail.com) 6 | # Logging tool implemented with the python Package logging. 7 | 8 | 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import argparse 14 | import logging 15 | import os 16 | import sys 17 | 18 | 19 | DEFAULT_LOGFILE_LEVEL = 'debug' 20 | DEFAULT_STDOUT_LEVEL = 'info' 21 | DEFAULT_LOG_FILE = './default.log' 22 | DEFAULT_LOG_FORMAT = '%(asctime)s %(levelname)-7s %(message)s' 23 | 24 | LOG_LEVEL_DICT = { 25 | 'debug': logging.DEBUG, 26 | 'info': logging.INFO, 27 | 'warning': logging.WARNING, 28 | 'error': logging.ERROR, 29 | 'critical': logging.CRITICAL 30 | } 31 | 32 | 33 | class Logger(object): 34 | """ 35 | Args: 36 | Log level: CRITICAL>ERROR>WARNING>INFO>DEBUG. 37 | Log file: The file that stores the logging info. 38 | rewrite: Clear the log file. 39 | log format: The format of log messages. 40 | stdout level: The log level to print on the screen. 41 | """ 42 | logfile_level = None 43 | log_file = None 44 | log_format = None 45 | rewrite = None 46 | stdout_level = None 47 | logger = None 48 | 49 | _caches = {} 50 | 51 | @staticmethod 52 | def init(logfile_level=DEFAULT_LOGFILE_LEVEL, 53 | log_file=DEFAULT_LOG_FILE, 54 | log_format=DEFAULT_LOG_FORMAT, 55 | rewrite=False, 56 | stdout_level=None): 57 | Logger.logfile_level = logfile_level 58 | Logger.log_file = log_file 59 | Logger.log_format = log_format 60 | Logger.rewrite = rewrite 61 | Logger.stdout_level = stdout_level 62 | 63 | Logger.logger = logging.getLogger() 64 | fmt = logging.Formatter(Logger.log_format) 65 | 66 | if Logger.logfile_level is not None: 67 | filemode = 'w' 68 | if not Logger.rewrite: 69 | filemode = 'a' 70 | 71 | dir_name = os.path.dirname(os.path.abspath(Logger.log_file)) 72 | if not os.path.exists(dir_name): 73 | os.makedirs(dir_name) 74 | 75 | if Logger.logfile_level not in LOG_LEVEL_DICT: 76 | print('Invalid logging level: {}'.format(Logger.logfile_level)) 77 | Logger.logfile_level = DEFAULT_LOGFILE_LEVEL 78 | 79 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.logfile_level]) 80 | 81 | fh = logging.FileHandler(Logger.log_file, mode=filemode) 82 | fh.setFormatter(fmt) 83 | fh.setLevel(LOG_LEVEL_DICT[Logger.logfile_level]) 84 | 85 | Logger.logger.addHandler(fh) 86 | 87 | if stdout_level is not None: 88 | if Logger.logfile_level is None: 89 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.stdout_level]) 90 | 91 | console = logging.StreamHandler() 92 | if Logger.stdout_level not in LOG_LEVEL_DICT: 93 | print('Invalid logging level: {}'.format(Logger.stdout_level)) 94 | return 95 | 96 | console.setLevel(LOG_LEVEL_DICT[Logger.stdout_level]) 97 | console.setFormatter(fmt) 98 | Logger.logger.addHandler(console) 99 | 100 | @staticmethod 101 | def set_log_file(file_path): 102 | Logger.log_file = file_path 103 | Logger.init(log_file=file_path) 104 | 105 | @staticmethod 106 | def set_logfile_level(log_level): 107 | if log_level not in LOG_LEVEL_DICT: 108 | print('Invalid logging level: {}'.format(log_level)) 109 | return 110 | 111 | Logger.init(logfile_level=log_level) 112 | 113 | @staticmethod 114 | def clear_log_file(): 115 | Logger.rewrite = True 116 | Logger.init(rewrite=True) 117 | 118 | @staticmethod 119 | def 
check_logger(): 120 | if Logger.logger is None: 121 | Logger.init(logfile_level=None, stdout_level=DEFAULT_STDOUT_LEVEL) 122 | 123 | @staticmethod 124 | def set_stdout_level(log_level): 125 | if log_level not in LOG_LEVEL_DICT: 126 | print('Invalid logging level: {}'.format(log_level)) 127 | return 128 | 129 | Logger.init(stdout_level=log_level) 130 | 131 | @staticmethod 132 | def debug(message): 133 | Logger.check_logger() 134 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 135 | lineno = sys._getframe().f_back.f_lineno 136 | prefix = '[{}, {}]'.format(filename,lineno) 137 | Logger.logger.debug('{} {}'.format(prefix, message)) 138 | 139 | @staticmethod 140 | def info(message): 141 | Logger.check_logger() 142 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 143 | lineno = sys._getframe().f_back.f_lineno 144 | prefix = '[{}, {}]'.format(filename,lineno) 145 | Logger.logger.info('{} {}'.format(prefix, message)) 146 | 147 | @staticmethod 148 | def info_once(message): 149 | Logger.check_logger() 150 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 151 | lineno = sys._getframe().f_back.f_lineno 152 | prefix = '[{}, {}]'.format(filename, lineno) 153 | 154 | if Logger._caches.get((prefix, message)) is not None: 155 | return 156 | 157 | Logger.logger.info('{} {}'.format(prefix, message)) 158 | Logger._caches[(prefix, message)] = True 159 | 160 | @staticmethod 161 | def warn(message): 162 | Logger.check_logger() 163 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 164 | lineno = sys._getframe().f_back.f_lineno 165 | prefix = '[{}, {}]'.format(filename,lineno) 166 | Logger.logger.warn('{} {}'.format(prefix, message)) 167 | 168 | @staticmethod 169 | def error(message): 170 | Logger.check_logger() 171 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 172 | lineno = sys._getframe().f_back.f_lineno 173 | prefix = '[{}, {}]'.format(filename,lineno) 174 | Logger.logger.error('{} {}'.format(prefix, message)) 175 | 176 | @staticmethod 177 | def critical(message): 178 | Logger.check_logger() 179 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename) 180 | lineno = sys._getframe().f_back.f_lineno 181 | prefix = '[{}, {}]'.format(filename,lineno) 182 | Logger.logger.critical('{} {}'.format(prefix, message)) 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--logfile_level', default="debug", type=str, 188 | dest='logfile_level', help='To set the log level to files.') 189 | parser.add_argument('--stdout_level', default=None, type=str, 190 | dest='stdout_level', help='To set the level to print to screen.') 191 | parser.add_argument('--log_file', default="./default.log", type=str, 192 | dest='log_file', help='The path of log files.') 193 | parser.add_argument('--log_format', default="%(asctime)s %(levelname)-7s %(message)s", 194 | type=str, dest='log_format', help='The format of log messages.') 195 | parser.add_argument('--rewrite', default=False, type=bool, 196 | dest='rewrite', help='Clear the log files existed.') 197 | 198 | args = parser.parse_args() 199 | Logger.init(logfile_level=args.logfile_level, stdout_level=args.stdout_level, 200 | log_file=args.log_file, log_format=args.log_format, rewrite=args.rewrite) 201 | 202 | Logger.info("info test.") 203 | Logger.debug("debug test.") 204 | Logger.warn("warn test.") 205 | Logger.error("error test.") 206 | Logger.debug("debug test.") 
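A short usage sketch for the `Logger` above (the log-file path is illustrative). Because every logging method calls `check_logger()` first, the class also works without an explicit `init`, falling back to stdout at INFO level.

```
from data_prepare.logger import Logger

# Log DEBUG and above to a file, and INFO and above to the screen.
Logger.init(logfile_level='debug', stdout_level='info', log_file='./logs/train.log')
Logger.info('training started')  # messages are prefixed with [filename, lineno]
Logger.warn('no GPU found, falling back to CPU')
```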
-------------------------------------------------------------------------------- /hhl_utils/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import math 6 | import PIL 7 | 8 | def dir_exists(path): 9 | if not os.path.exists(path): 10 | os.makedirs(path) 11 | 12 | def initialize_weights(*models): 13 | for model in models: 14 | for m in model.modules(): 15 | if isinstance(m, nn.Conv2d): 16 | nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu') 17 | elif isinstance(m, nn.BatchNorm2d): 18 | m.weight.data.fill_(1.) 19 | m.bias.data.fill_(1e-4) 20 | elif isinstance(m, nn.Linear): 21 | m.weight.data.normal_(0.0, 0.0001) 22 | m.bias.data.zero_() 23 | 24 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 25 | factor = (kernel_size + 1) // 2 26 | if kernel_size % 2 == 1: 27 | center = factor - 1 28 | else: 29 | center = factor - 0.5 30 | og = np.ogrid[:kernel_size, :kernel_size] 31 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 32 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 33 | weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt 34 | return torch.from_numpy(weight).float() 35 | 36 | def colorize_mask(mask, palette): 37 | zero_pad = 256 * 3 - len(palette) 38 | for i in range(zero_pad): 39 | palette.append(0) 40 | new_mask = PIL.Image.fromarray(mask.astype(np.uint8)).convert('P') 41 | new_mask.putpalette(palette) 42 | return new_mask 43 | 44 | def set_trainable_attr(m,b): 45 | m.trainable = b 46 | for p in m.parameters(): p.requires_grad = b 47 | 48 | def apply_leaf(m, f): 49 | c = m if isinstance(m, (list, tuple)) else list(m.children()) 50 | if isinstance(m, nn.Module): 51 | f(m) 52 | if len(c)>0: 53 | for l in c: 54 | apply_leaf(l,f) 55 | 56 | def set_trainable(l, b): 57 | apply_leaf(l, lambda m: set_trainable_attr(m,b)) -------------------------------------------------------------------------------- /hhl_utils/pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 
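    # C1 = (K1*L)^2 and C2 = (K2*L)^2 with K1 = 0.01, K2 = 0.03 as in the SSIM paper
    # (Wang et al., 2004); the dynamic range L is implicitly 1, so inputs are assumed
    # to be scaled to [0, 1]. The map below is the standard SSIM formula:
    # ((2*mu1*mu2 + C1) * (2*sigma12 + C2)) / ((mu1^2 + mu2^2 + C1) * (sigma1^2 + sigma2^2 + C2))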
32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 69 | 70 | mu1_sq = mu1.pow(2) 71 | mu2_sq = mu2.pow(2) 72 | mu1_mu2 = mu1*mu2 73 | 74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 77 | 78 | C1 = 0.01**2 79 | C2 = 0.03**2 80 | 81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 83 | ssim_map = -torch.log(ssim_map + 1e-8) 84 | 85 | if size_average: 86 | return ssim_map.mean() 87 | else: 88 | return ssim_map.mean(1).mean(1).mean(1) 89 | 90 | class LOGSSIM(torch.nn.Module): 91 | def __init__(self, window_size = 11, size_average = True): 92 | super(LOGSSIM, self).__init__() 93 | self.window_size = window_size 94 | self.size_average = size_average 95 | self.channel = 1 96 | self.window = create_window(window_size, self.channel) 97 | 98 | def forward(self, img1, img2): 99 | (_, channel, _, _) = img1.size() 100 | 101 | if channel == self.channel and self.window.data.type() == img1.data.type(): 102 | window = self.window 103 | else: 104 | window = create_window(self.window_size, channel) 105 | 106 | if img1.is_cuda: 107 | window = window.cuda(img1.get_device()) 108 | window = window.type_as(img1) 109 | 110 | self.window = window 111 | self.channel = channel 112 | 113 | 114 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 115 | 116 | 117 | def ssim(img1, img2, window_size = 11, size_average = True): 118 | (_, channel, _, _) = img1.size() 119 | window = create_window(window_size, channel) 120 | 121 | if img1.is_cuda: 122 | window = window.cuda(img1.get_device()) 123 | window = window.type_as(img1) 124 | 125 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /hhl_utils/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer#, required 4 | 5 | 6 | class RAdam(Optimizer): 
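    # Rectified Adam ("On the Variance of the Adaptive Learning Rate and Beyond",
    # Liu et al., 2019, https://arxiv.org/abs/1908.03265). While the approximated
    # SMA length (N_sma below) is small, the variance of the adaptive learning rate
    # is too large, so the update falls back to an un-adapted (SGD-like) step.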
7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 9 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 10 | self.buffer = [[None, None, None] for ind in range(10)] 11 | super(RAdam, self).__init__(params, defaults) 12 | 13 | def __setstate__(self, state): 14 | super(RAdam, self).__setstate__(state) 15 | 16 | def step(self, closure=None): 17 | 18 | loss = None 19 | if closure is not None: 20 | loss = closure() 21 | 22 | for group in self.param_groups: 23 | 24 | for p in group['params']: 25 | if p.grad is None: 26 | continue 27 | grad = p.grad.data.float() 28 | if grad.is_sparse: 29 | raise RuntimeError('RAdam does not support sparse gradients') 30 | 31 | p_data_fp32 = p.data.float() 32 | 33 | state = self.state[p] 34 | 35 | if len(state) == 0: 36 | state['step'] = 0 37 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 38 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 39 | else: 40 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 41 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 42 | 43 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 44 | beta1, beta2 = group['betas'] 45 | 46 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 47 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 48 | 49 | state['step'] += 1 50 | buffered = self.buffer[int(state['step'] % 10)] 51 | if state['step'] == buffered[0]: 52 | N_sma, step_size = buffered[1], buffered[2] 53 | else: 54 | buffered[0] = state['step'] 55 | beta2_t = beta2 ** state['step'] 56 | N_sma_max = 2 / (1 - beta2) - 1 57 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 58 | buffered[1] = N_sma 59 | 60 | # more conservative since it's an approximated value 61 | if N_sma >= 5: 62 | step_size = group['lr'] * math.sqrt( 63 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 64 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 65 | else: 66 | step_size = group['lr'] / (1 - beta1 ** state['step']) 67 | buffered[2] = step_size 68 | 69 | if group['weight_decay'] != 0: 70 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 71 | 72 | # more conservative since it's an approximated value 73 | if N_sma >= 5: 74 | denom = exp_avg_sq.sqrt().add_(group['eps']) 75 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 76 | else: 77 | p_data_fp32.add_(-step_size, exp_avg) 78 | 79 | p.data.copy_(p_data_fp32) 80 | 81 | return loss 82 | 83 | 84 | class RAdam_4step(Optimizer): 85 | 86 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, update_all=False, 87 | additional_four=False): 88 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 89 | self.update_all = update_all # whether update the first 4 steps 90 | self.additional_four = additional_four # whether use additional 4 steps for SGD 91 | self.buffer = [[None, None] for ind in range(10)] 92 | super(RAdam_4step, self).__init__(params, defaults) 93 | 94 | def __setstate__(self, state): 95 | super(RAdam_4step, self).__setstate__(state) 96 | 97 | def step(self, closure=None): 98 | 99 | loss = None 100 | if closure is not None: 101 | loss = closure() 102 | 103 | for group in self.param_groups: 104 | 105 | for p in group['params']: 106 | if p.grad is None: 107 | continue 108 | grad = p.grad.data.float() 109 | if grad.is_sparse: 110 | raise RuntimeError('RAdam_4step does not support sparse gradients') 111 | 112 | p_data_fp32 = p.data.float() 113 | 114 | state = self.state[p] 115 | 116 | if 
len(state) == 0: 117 | state[ 118 | 'step'] = -4 if self.additional_four else 0 # since this exp requires exactly 4 step, it is hard coded 119 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 121 | else: 122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 124 | 125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 126 | beta1, beta2 = group['betas'] 127 | 128 | state['step'] += 1 129 | 130 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 131 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 132 | 133 | if state['step'] > 0: 134 | 135 | state_step = state['step'] + 4 if self.additional_four else state[ 136 | 'step'] # since this exp requires exactly 4 step, it is hard coded 137 | 138 | buffered = self.buffer[int(state_step % 10)] 139 | if state_step == buffered[0]: 140 | step_size = buffered[1] 141 | else: 142 | buffered[0] = state_step 143 | beta2_t = beta2 ** state['step'] 144 | 145 | if state['step'] > 4: # since this exp requires exactly 4 step, it is hard coded 146 | N_sma_max = 2 / (1 - beta2) - 1 147 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 148 | step_size = group['lr'] * math.sqrt( 149 | (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / ( 150 | 1 - beta1 ** state_step) 151 | elif self.update_all: 152 | step_size = group['lr'] / (1 - beta1 ** state_step) 153 | else: 154 | step_size = 0 155 | buffered[1] = step_size 156 | 157 | if state['step'] > 4: # since this exp requires exactly 4 step, it is hard coded 158 | if group['weight_decay'] != 0: 159 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 160 | denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step)).add_(group['eps']) 161 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 162 | p.data.copy_(p_data_fp32) 163 | elif self.update_all: 164 | if group['weight_decay'] != 0: 165 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 166 | denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step)) 167 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 168 | p.data.copy_(p_data_fp32) 169 | else: 170 | state_step = state['step'] + 4 if self.additional_four else state[ 171 | 'step'] # since this exp requires exactly 4 step, it is hard coded 172 | 173 | if group['weight_decay'] != 0: 174 | p_data_fp32.add_(-group['weight_decay'] * 0.1, p_data_fp32) 175 | 176 | step_size = 0.1 / (1 - beta1 ** state_step) 177 | p_data_fp32.add_(-step_size, exp_avg) 178 | p.data.copy_(p_data_fp32) 179 | 180 | return loss 181 | 182 | 183 | class AdamW(Optimizer): 184 | 185 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 186 | weight_decay=0, use_variance=True, warmup=4000): 187 | defaults = dict(lr=lr, betas=betas, eps=eps, 188 | weight_decay=weight_decay, use_variance=True, warmup=warmup) 189 | print('======== Warmup: {} ========='.format(warmup)) 190 | super(AdamW, self).__init__(params, defaults) 191 | 192 | def __setstate__(self, state): 193 | super(AdamW, self).__setstate__(state) 194 | 195 | def step(self, closure=None): 196 | #global iter_idx 197 | #siter_idx += 1 198 | grad_list = list() 199 | mom_list = list() 200 | mom_2rd_list = list() 201 | 202 | loss = None 203 | if closure is not None: 204 | loss = closure() 205 | 206 | for group in self.param_groups: 207 | 208 | for p in group['params']: 209 | if p.grad is None: 210 | continue 211 | grad = p.grad.data.float() 212 | 
if grad.is_sparse: 213 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 214 | 215 | p_data_fp32 = p.data.float() 216 | 217 | state = self.state[p] 218 | 219 | if len(state) == 0: 220 | state['step'] = 0 221 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 222 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 223 | else: 224 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 225 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 226 | 227 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 228 | beta1, beta2 = group['betas'] 229 | 230 | state['step'] += 1 231 | 232 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 233 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 234 | 235 | denom = exp_avg_sq.sqrt().add_(group['eps']) 236 | bias_correction1 = 1 - beta1 ** state['step'] 237 | bias_correction2 = 1 - beta2 ** state['step'] 238 | 239 | if group['warmup'] > state['step']: 240 | scheduled_lr = 1e-6 + state['step'] * (group['lr'] - 1e-6) / group['warmup'] 241 | else: 242 | scheduled_lr = group['lr'] 243 | 244 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 245 | if group['weight_decay'] != 0: 246 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 247 | 248 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 249 | 250 | p.data.copy_(p_data_fp32) 251 | 252 | return loss -------------------------------------------------------------------------------- /hhl_utils/ranger.py: -------------------------------------------------------------------------------- 1 | #Ranger deep learning optimizer - RAdam + Lookahead combined. 2 | #https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 3 | 4 | #Ranger has now been used to capture 12 records on the FastAI leaderboard. 5 | 6 | #This version = 9.3.19 7 | 8 | #Credits: 9 | #RAdam --> https://github.com/LiyuanLucasLiu/RAdam 10 | #Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 11 | #Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 12 | 13 | #summary of changes: 14 | #full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 15 | #supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 16 | #changes 8/31/19 - fix references to *self*.N_sma_threshold; 17 | #changed eps to 1e-5 as better default than 1e-8. 18 | 19 | import math 20 | import torch 21 | from torch.optim.optimizer import Optimizer#, required 22 | import itertools as it 23 | 24 | 25 | 26 | class Ranger(Optimizer): 27 | 28 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0): 29 | #parameter checks 30 | if not 0.0 <= alpha <= 1.0: 31 | raise ValueError(f'Invalid slow update rate: {alpha}') 32 | if not 1 <= k: 33 | raise ValueError(f'Invalid lookahead steps: {k}') 34 | if not lr > 0: 35 | raise ValueError(f'Invalid Learning Rate: {lr}') 36 | if not eps > 0: 37 | raise ValueError(f'Invalid eps: {eps}') 38 | 39 | #parameter comments: 40 | # beta1 (momentum) of .95 seems to work better than .90... 41 | #N_sma_threshold of 5 seems better in testing than 4. 42 | #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
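        # How the Lookahead half works (Zhang et al., https://arxiv.org/abs/1907.08610):
        # a slow copy of every parameter is kept in state['slow_buffer']; every k steps
        # the slow weights move alpha of the way toward the fast (RAdam) weights and are
        # copied back into p.data -- see the end of step() below.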
43 | 44 | #prep defaults and init torch.optim base 45 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay) 46 | super().__init__(params,defaults) 47 | 48 | #adjustable threshold 49 | self.N_sma_threshhold = N_sma_threshhold 50 | 51 | #now we can get to work... 52 | #removed as we now use step from RAdam...no need for duplicate step counting 53 | #for group in self.param_groups: 54 | # group["step_counter"] = 0 55 | #print("group step counter init") 56 | 57 | #look ahead params 58 | self.alpha = alpha 59 | self.k = k 60 | 61 | #radam buffer for state 62 | self.radam_buffer = [[None,None,None] for ind in range(10)] 63 | 64 | #self.first_run_check=0 65 | 66 | #lookahead weights 67 | #9/2/19 - lookahead param tensors have been moved to state storage. 68 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs. 69 | 70 | #self.slow_weights = [[p.clone().detach() for p in group['params']] 71 | # for group in self.param_groups] 72 | 73 | #don't use grad for lookahead weights 74 | #for w in it.chain(*self.slow_weights): 75 | # w.requires_grad = False 76 | 77 | def __setstate__(self, state): 78 | print("set state called") 79 | super(Ranger, self).__setstate__(state) 80 | 81 | 82 | def step(self, closure=None): 83 | loss = None 84 | #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 85 | #Uncomment if you need to use the actual closure... 86 | 87 | #if closure is not None: 88 | #loss = closure() 89 | 90 | #Evaluate averages and grad, update param tensors 91 | for group in self.param_groups: 92 | 93 | for p in group['params']: 94 | if p.grad is None: 95 | continue 96 | grad = p.grad.data.float() 97 | if grad.is_sparse: 98 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 99 | 100 | p_data_fp32 = p.data.float() 101 | 102 | state = self.state[p] #get state dict for this param 103 | 104 | if len(state) == 0: #if first time to run...init dictionary with our desired entries 105 | #if self.first_run_check==0: 106 | #self.first_run_check=1 107 | #print("Initializing slow buffer...should not see this at load from saved model!") 108 | state['step'] = 0 109 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 110 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 111 | 112 | #look ahead weight storage now in state dict 113 | state['slow_buffer'] = torch.empty_like(p.data) 114 | state['slow_buffer'].copy_(p.data) 115 | 116 | else: 117 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 118 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 119 | 120 | #begin computations 121 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 122 | beta1, beta2 = group['betas'] 123 | 124 | #compute variance mov avg 125 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 126 | #compute mean moving avg 127 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 128 | 129 | state['step'] += 1 130 | 131 | 132 | buffered = self.radam_buffer[int(state['step'] % 10)] 133 | if state['step'] == buffered[0]: 134 | N_sma, step_size = buffered[1], buffered[2] 135 | else: 136 | buffered[0] = state['step'] 137 | beta2_t = beta2 ** state['step'] 138 | N_sma_max = 2 / (1 - beta2) - 1 139 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 140 | buffered[1] = N_sma 141 | if N_sma > self.N_sma_threshhold: 142 | step_size = math.sqrt((1 - beta2_t) * (N_sma 
- 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 143 | else: 144 | step_size = 1.0 / (1 - beta1 ** state['step']) 145 | buffered[2] = step_size 146 | 147 | if group['weight_decay'] != 0: 148 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 149 | 150 | if N_sma > self.N_sma_threshhold: 151 | denom = exp_avg_sq.sqrt().add_(group['eps']) 152 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 153 | else: 154 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 155 | 156 | p.data.copy_(p_data_fp32) 157 | 158 | #integrated look ahead... 159 | #we do it at the param level instead of group level 160 | if state['step'] % group['k'] == 0: 161 | slow_p = state['slow_buffer'] #get access to slow param tensor 162 | slow_p.add_(self.alpha, p.data - slow_p) #(fast weights - slow weights) * alpha 163 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor 164 | 165 | return loss -------------------------------------------------------------------------------- /hhl_utils/torchsummary.py: -------------------------------------------------------------------------------- 1 | """ 2 | A modied version of the code by Tae Hwan Jung 3 | https://github.com/graykode/modelsummary 4 | """ 5 | 6 | import torch 7 | import numpy as np 8 | import torch.nn as nn 9 | from collections import OrderedDict 10 | 11 | def summary(model, input_shape, batch_size=-1, intputshow=True): 12 | 13 | def register_hook(module): 14 | def hook(module, input, output=None): 15 | class_name = str(module.__class__).split(".")[-1].split("'")[0] 16 | module_idx = len(summary) 17 | 18 | m_key = "%s-%i" % (class_name, module_idx + 1) 19 | summary[m_key] = OrderedDict() 20 | summary[m_key]["input_shape"] = list(input[0].size()) 21 | summary[m_key]["input_shape"][0] = batch_size 22 | 23 | params = 0 24 | if hasattr(module, "weight") and hasattr(module.weight, "size"): 25 | params += torch.prod(torch.LongTensor(list(module.weight.size()))) 26 | summary[m_key]["trainable"] = module.weight.requires_grad 27 | if hasattr(module, "bias") and hasattr(module.bias, "size"): 28 | params += torch.prod(torch.LongTensor(list(module.bias.size()))) 29 | summary[m_key]["nb_params"] = params 30 | 31 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) 32 | and not (module == model)) and 'torch' in str(module.__class__): 33 | if intputshow is True: 34 | hooks.append(module.register_forward_pre_hook(hook)) 35 | else: 36 | hooks.append(module.register_forward_hook(hook)) 37 | 38 | # create properties 39 | summary = OrderedDict() 40 | hooks = [] 41 | 42 | # register hook 43 | model.apply(register_hook) 44 | model(torch.zeros(input_shape)) 45 | 46 | # remove these hooks 47 | for h in hooks: 48 | h.remove() 49 | 50 | model_info = '' 51 | 52 | model_info += "-----------------------------------------------------------------------\n" 53 | line_new = "{:>25} {:>25} {:>15}".format("Layer (type)", "Input Shape", "Param #") 54 | model_info += line_new + '\n' 55 | model_info += "=======================================================================\n" 56 | 57 | total_params = 0 58 | total_output = 0 59 | trainable_params = 0 60 | for layer in summary: 61 | line_new = "{:>25} {:>25} {:>15}".format( 62 | layer, 63 | str(summary[layer]["input_shape"]), 64 | "{0:,}".format(summary[layer]["nb_params"]), 65 | ) 66 | 67 | total_params += summary[layer]["nb_params"] 68 | if intputshow is True: 69 | total_output += 
np.prod(summary[layer]["input_shape"]) 70 | else: 71 | total_output += np.prod(summary[layer]["output_shape"]) 72 | if "trainable" in summary[layer]: 73 | if summary[layer]["trainable"] == True: 74 | trainable_params += summary[layer]["nb_params"] 75 | 76 | model_info += line_new + '\n' 77 | 78 | model_info += "=======================================================================\n" 79 | model_info += "Total params: {0:,}\n".format(total_params) 80 | model_info += "Trainable params: {0:,}\n".format(trainable_params) 81 | model_info += "Non-trainable params: {0:,}\n".format(total_params - trainable_params) 82 | model_info += "-----------------------------------------------------------------------\n" 83 | 84 | return model_info -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch import nn 4 | import numpy as np 5 | import math 6 | 7 | 8 | # combined with cross entropy loss, instance level 9 | class LossVariance(nn.Module): 10 | """ The instances in target should be labeled 11 | """ 12 | 13 | def __init__(self): 14 | super(LossVariance, self).__init__() 15 | 16 | def forward(self, input, target): 17 | 18 | B = input.size(0) 19 | 20 | loss = 0 21 | for k in range(B): 22 | unique_vals = target[k].unique() 23 | unique_vals = unique_vals[unique_vals != 0] 24 | 25 | sum_var = 0 26 | for val in unique_vals: 27 | instance = input[k][:, target[k] == val] 28 | if instance.size(1) > 1: 29 | sum_var += instance.var(dim=1).sum() 30 | 31 | loss += sum_var / (len(unique_vals) + 1e-8) 32 | loss /= B 33 | return loss 34 | 35 | 36 | 37 | class FocalLoss2d(nn.Module): 38 | def __init__(self, gamma=2, size_average=True, type="sigmoid"): 39 | super(FocalLoss2d, self).__init__() 40 | self.gamma = gamma 41 | self.size_average = size_average 42 | self.type = type 43 | 44 | def forward(self, logit, target, class_weight=None): 45 | target = target.view(-1, 1).long() 46 | if self.type == 'sigmoid': 47 | if class_weight is None: 48 | class_weight = [1]*2 49 | 50 | prob = F.sigmoid(logit) 51 | prob = prob.view(-1, 1) 52 | prob = torch.cat((1-prob, prob), 1) 53 | select = torch.FloatTensor(len(prob), 2).zero_().cuda() 54 | select.scatter_(1, target, 1.) 55 | 56 | elif self.type=='softmax': 57 | B,C,H,W = logit.size() 58 | if class_weight is None: 59 | class_weight =[1]*C 60 | 61 | logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C) 62 | prob = F.softmax(logit,1) 63 | select = torch.FloatTensor(len(prob), C).zero_().cuda() 64 | select.scatter_(1, target, 1.) 
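        # Focal loss (Lin et al., 2017, https://arxiv.org/abs/1708.02002):
        # FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t is the predicted
        # probability of the true class; class_weight below plays the role of alpha_t.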
65 | 66 | class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1) 67 | class_weight = torch.gather(class_weight, 0, target) 68 | 69 | prob = (prob*select).sum(1).view(-1,1) 70 | prob = torch.clamp(prob,1e-8,1-1e-8) 71 | batch_loss = - class_weight *(torch.pow((1-prob), self.gamma))*prob.log() 72 | 73 | if self.size_average: 74 | loss = batch_loss.mean() 75 | else: 76 | loss = batch_loss 77 | 78 | return loss 79 | 80 | # Robust focal loss 81 | class RobustFocalLoss2d(nn.Module): 82 | #assume top 10% is outliers 83 | def __init__(self, gamma=2, size_average=True, type="sigmoid"): 84 | super(RobustFocalLoss2d, self).__init__() 85 | self.gamma = gamma 86 | self.size_average = size_average 87 | self.type = type 88 | 89 | def forward(self, logit, target, class_weight=None): 90 | target = target.view(-1, 1).long() 91 | if self.type=='sigmoid': 92 | if class_weight is None: 93 | class_weight = [1]*2 94 | 95 | prob = F.sigmoid(logit) 96 | prob = prob.view(-1, 1) 97 | prob = torch.cat((1-prob, prob), 1) 98 | select = torch.FloatTensor(len(prob), 2).zero_().cuda() 99 | select.scatter_(1, target, 1.) 100 | 101 | elif self.type=='softmax': 102 | B,C,H,W = logit.size() 103 | if class_weight is None: 104 | class_weight =[1]*C 105 | 106 | logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C) 107 | prob = F.softmax(logit,1) 108 | select = torch.FloatTensor(len(prob), C).zero_().cuda() 109 | select.scatter_(1, target, 1.) 110 | 111 | class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1) 112 | class_weight = torch.gather(class_weight, 0, target) 113 | 114 | prob = (prob*select).sum(1).view(-1,1) 115 | prob = torch.clamp(prob,1e-8,1-1e-8) 116 | 117 | focus = torch.pow((1-prob), self.gamma) 118 | focus = torch.clamp(focus,0,2) 119 | 120 | batch_loss = - class_weight *focus*prob.log() 121 | 122 | if self.size_average: 123 | loss = batch_loss.mean() 124 | else: 125 | loss = batch_loss 126 | 127 | return loss 128 | 129 | 130 | 131 | class DiceLoss(nn.Module): 132 | def __init__(self): 133 | super(DiceLoss, self).__init__() 134 | 135 | def forward(self, input, target): 136 | N = target.size(0) 137 | smooth = 1 138 | 139 | input_flat = input.view(N, -1) 140 | target_flat = target.view(N, -1) 141 | 142 | intersection = input_flat * target_flat 143 | 144 | loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth) 145 | loss = 1 - loss.sum() / N 146 | 147 | return loss 148 | 149 | 150 | class MulticlassDiceLoss(nn.Module): 151 | """ 152 | requires one hot encoded target. Applies DiceLoss on each class iteratively. 
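    (Each class is scored with DiceLoss above: dice = 2 * (|p * t| + 1) / (|p| + |t| + 1),
    and the per-class losses 1 - dice are summed, optionally weighted.)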
153 | requires input.shape[0:1] and target.shape[0:1] to be (N, C) where N is 154 | batch size and C is number of classes 155 | """ 156 | 157 | def __init__(self): 158 | super(MulticlassDiceLoss, self).__init__() 159 | 160 | def forward(self, input, target, weights=None): 161 | 162 | C = target.shape[1] 163 | 164 | # if weights is None: 165 | # weights = torch.ones(C) #uniform weights for all classes 166 | 167 | dice = DiceLoss() 168 | totalLoss = 0 169 | 170 | for i in range(C): 171 | diceLoss = dice(input[:, i], target[:, i]) 172 | if weights is not None: 173 | diceLoss *= weights[i] 174 | totalLoss += diceLoss 175 | 176 | return totalLoss 177 | 178 | 179 | 180 | 181 | class Weight_DiceLoss(nn.Module): 182 | def __init__(self): 183 | super(Weight_DiceLoss, self).__init__() 184 | 185 | def forward(self, input, target, weights): 186 | N = target.size(0) 187 | smooth = 1 188 | 189 | input_flat = input.view(N, -1) 190 | target_flat = target.view(N, -1) 191 | weights = weights.view(N, -1) 192 | 193 | intersection = input_flat * target_flat 194 | intersection = intersection * weights 195 | 196 | dice = 2 * (intersection.sum(1) + smooth) / ((input_flat * weights).sum(1) + (target_flat * weights).sum(1) + smooth) 197 | loss = 1 - dice.sum() / N 198 | 199 | return loss 200 | 201 | 202 | class WeightMulticlassDiceLoss(nn.Module): 203 | """ 204 | requires one hot encoded target. Applies DiceLoss on each class iteratively. 205 | requires input.shape[0:1] and target.shape[0:1] to be (N, C) where N is 206 | batch size and C is number of classes 207 | """ 208 | 209 | def __init__(self): 210 | super(WeightMulticlassDiceLoss, self).__init__() 211 | 212 | def forward(self, input, target, weights=None): 213 | 214 | C = target.shape[1] 215 | 216 | # if weights is None: 217 | # weights = torch.ones(C) #uniform weights for all classes 218 | # weights[0] = 3 219 | dice = DiceLoss() 220 | wdice = Weight_DiceLoss() 221 | totalLoss = 0 222 | 223 | for i in range(C): 224 | # diceLoss = dice(input[:, i], target[:, i]) 225 | # diceLoss2 = 1 - wdice(input[:, i], target[:, i - 1]) 226 | # diceLoss3 = 1 - wdice(input[:, i], target[:, i%(C-1) + 1]) 227 | # diceLoss = diceLoss - diceLoss2 - diceLoss3 228 | 229 | # diceLoss = dice(input[:, i - 1] + input[:, i] + input[:, i%(C-1) + 1], target[:, i]) 230 | '''''' 231 | if (i == 0): 232 | diceLoss = wdice(input[:, i], target[:, i], weights) * 2 233 | elif (i == 1): 234 | # diceLoss = dice(input[:, C - 1] + input[:, i] + input[:, i + 1], target[:, i]) 235 | diceLoss = wdice(input[:, i], target[:, i], weights) 236 | diceLoss2 = 1 - wdice(input[:, i], target[:, C - 1], weights) 237 | diceLoss3 = 1 - wdice(input[:, i], target[:, i + 1], weights) 238 | diceLoss = diceLoss - diceLoss2 - diceLoss3 239 | 240 | elif (i == C - 1): 241 | # diceLoss = dice(input[:, i - 1] + input[:, i] + input[:, 1], target[:, i]) 242 | diceLoss = wdice(input[:, i], target[:, i], weights) 243 | diceLoss2 = 1 - wdice(input[:, i], target[:, i - 1], weights) 244 | diceLoss3 = 1 - wdice(input[:, i], target[:, 1], weights) 245 | diceLoss = diceLoss - diceLoss2 - diceLoss3 246 | 247 | else: 248 | # diceLoss = dice(input[:, i - 1] + input[:, i] + input[:, i + 1], target[:, i]) 249 | diceLoss = wdice(input[:, i], target[:, i], weights) 250 | diceLoss2 = 1 - wdice(input[:, i], target[:, i - 1], weights) 251 | diceLoss3 = 1 - wdice(input[:, i], target[:, i + 1], weights) 252 | diceLoss = diceLoss - diceLoss2 - diceLoss3 253 | 254 | #if weights is not None: 255 | #diceLoss *= weights[i] 256 | 257 | totalLoss += 
diceLoss
258 |         avgLoss = totalLoss / C
259 | 
260 |         return avgLoss
261 | 
262 | 
263 | 
264 | 
265 | 
266 | class CenterLoss(nn.Module):
267 |     """Center loss.
268 | 
269 |     Reference:
270 |     Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
271 | 
272 |     Args:
273 |         num_classes (int): number of classes.
274 |         feat_dim (int): feature dimension.
275 |     """
276 | 
277 |     def __init__(self, num_classes=3, feat_dim=3, use_gpu=True):
278 |         super(CenterLoss, self).__init__()
279 |         self.num_classes = num_classes
280 |         self.feat_dim = feat_dim
281 |         self.use_gpu = use_gpu
282 | 
283 |         if self.use_gpu:
284 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
285 |             print(self.centers)
286 |         else:
287 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
288 | 
289 |     def forward(self, input_x, input_label):
290 |         """
291 |         Args:
292 |             x: feature matrix with shape (batch_size, feat_dim).
293 |             labels: ground truth labels with shape (batch_size).
294 |         """
295 |         labels = input_label
296 |         batch_size = input_x.size(0)
297 |         channels = input_x.size(1)
298 | 
299 |         distmat = torch.pow(input_x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
300 |                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
301 |         distmat.addmm_(1, -2, input_x, self.centers.t())  # math:: out = beta * mat + alpha * (mat1_i @ mat2_i)
302 | 
303 |         classes = torch.arange(self.num_classes).long()
304 |         if self.use_gpu: classes = classes.cuda()
305 |         labels2 = input_label.unsqueeze(1).expand(batch_size, self.num_classes)
306 |         mask = labels2.cuda().eq(classes.expand(batch_size, self.num_classes))  # eq() returns 1 where elements are equal, 0 where they differ
307 | 
308 |         dist = distmat * mask.float()
309 | 
310 | 
311 |         # torch.clamp(input, min, max, out=None): clamp every element of input into the range [min, max] and return the result as a new tensor
312 | 
313 |         loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
314 | 
315 |         return loss
316 | 
317 | 
318 | 
319 | import torch.nn as nn
320 | import torch.nn.functional as F
321 | 
322 | def one_hot(label, n_classes, requires_grad=True):
323 |     """Return One Hot Label"""
324 |     device = label.device
325 |     one_hot_label = torch.eye(n_classes, device=device, requires_grad=requires_grad)[label]
326 |     one_hot_label = one_hot_label.transpose(1, 3).transpose(2, 3)
327 | 
328 |     return one_hot_label
329 | 
330 | 
331 | class BoundaryLoss(nn.Module):
332 |     """Boundary Loss proposed in:
333 |     Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery Semantic Segmentation
334 |     https://arxiv.org/abs/1905.07852
335 |     """
336 | 
337 |     def __init__(self, theta0=3, theta=5):
338 |         super().__init__()
339 | 
340 |         self.theta0 = theta0
341 |         self.theta = theta
342 | 
343 |     def forward(self, pred_output, gt):
344 |         """
345 |         Input:
346 |             - pred_output: the output from model (before softmax)
347 |                     shape (N, C, H, W)
348 |             - gt: ground truth map  # this was the original input format; the latest input is (N, C, H, W)
349 |                     shape (N, H, W)
350 |         Return:
351 |             - boundary loss, averaged over mini-batch
352 |         """
353 | 
354 |         n, c, _, _ = pred_output.shape
355 | 
356 |         # softmax so that predicted map can be distributed in [0, 1]
357 |         pred = torch.softmax(pred_output, dim=1)
358 | 
359 |         # one-hot vector of ground truth
360 |         # one_hot_gt = one_hot(gt.long(), c)  # this was the original input format; the latest input is (N, C, H, W)
361 |         one_hot_gt = gt
362 | 
363 | 
364 | 
365 |         # boundary map
366 |         gt_b = F.max_pool2d(1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
367 |         gt_b -= 1 - one_hot_gt
368 | 
369 |         pred_b = 
F.max_pool2d(1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2) 370 | pred_b -= 1 - pred 371 | 372 | # extended boundary map 373 | gt_b_ext = F.max_pool2d(gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2) 374 | 375 | pred_b_ext = F.max_pool2d(pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2) 376 | 377 | # reshape 378 | gt_b = gt_b.view(n, c, -1) 379 | pred_b = pred_b.view(n, c, -1) 380 | gt_b_ext = gt_b_ext.view(n, c, -1) 381 | pred_b_ext = pred_b_ext.view(n, c, -1) 382 | 383 | # Precision, Recall 384 | P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7) 385 | R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7) 386 | 387 | # Boundary F1 Score 388 | BF1 = 2 * P * R / (P + R + 1e-7) 389 | 390 | # summing BF1 Score for each class and average over mini-batch 391 | loss = torch.mean(1 - BF1) 392 | 393 | return loss 394 | 395 | 396 | 397 | 398 | 399 | def dice_loss(input, target, eps=1e-7, if_sigmoid=True): 400 | if if_sigmoid: 401 | input = F.sigmoid(input) 402 | b = input.shape[0] 403 | iflat = input.contiguous().view(b, -1) 404 | tflat = target.float().contiguous().view(b, -1) 405 | intersection = (iflat * tflat).sum(dim=1) 406 | L = (1 - ((2. * intersection + eps) / (iflat.pow(2).sum(dim=1) + tflat.pow(2).sum(dim=1) + eps))).mean() 407 | return L 408 | 409 | def smooth_truncated_loss(p, t, ths=0.06, if_reduction=True, if_balance=True): 410 | n_log_pt = F.binary_cross_entropy_with_logits(p, t, reduction='none') 411 | pt = (-n_log_pt).exp() 412 | L = torch.where(pt>=ths, n_log_pt, -math.log(ths)+0.5*(1-pt.pow(2)/(ths**2))) 413 | if if_reduction: 414 | if if_balance: 415 | return 0.5*((L*t).sum()/t.sum().clamp(1) + (L*(1-t)).sum()/(1-t).sum().clamp(1)) 416 | else: 417 | return L.mean() 418 | else: 419 | return L 420 | 421 | def balance_bce_loss(input, target): 422 | L0 = F.binary_cross_entropy_with_logits(input, target, reduction='none') 423 | return 0.5*((L0*target).sum()/target.sum().clamp(1)+(L0*(1-target)).sum()/(1-target).sum().clamp(1)) 424 | 425 | def compute_loss_list(loss_func, pred=[], target=[], **kwargs): 426 | losses = [] 427 | for ipred, itarget in zip(pred, target): 428 | losses.append(loss_func(ipred, itarget, **kwargs)) 429 | return losses 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | -------------------------------------------------------------------------------- /models/FullNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script defines the structure of FullNet 3 | 4 | Author: Hui Qu 5 | """ 6 | 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import math 12 | 13 | 14 | class ConvLayer(nn.Sequential): 15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 16 | groups=1): 17 | super(ConvLayer, self).__init__() 18 | self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, 19 | padding=padding, dilation=dilation, bias=False, groups=groups)) 20 | self.add_module('relu', nn.LeakyReLU(inplace=True)) 21 | self.add_module('bn', nn.BatchNorm2d(out_channels)) 22 | 23 | 24 | # --- different types of layers --- # 25 | class BasicLayer(nn.Sequential): 26 | def __init__(self, in_channels, growth_rate, drop_rate, dilation=1): 27 | super(BasicLayer, self).__init__() 28 | self.conv = ConvLayer(in_channels, growth_rate, kernel_size=3, stride=1, 
padding=dilation, 29 | dilation=dilation) 30 | self.drop_rate = drop_rate 31 | 32 | def forward(self, x): 33 | out = self.conv(x) 34 | if self.drop_rate > 0: 35 | out = F.dropout(out, p=self.drop_rate, training=self.training) 36 | return torch.cat([x, out], 1) 37 | 38 | 39 | class BottleneckLayer(nn.Sequential): 40 | def __init__(self, in_channels, growth_rate, drop_rate, dilation=1): 41 | super(BottleneckLayer, self).__init__() 42 | 43 | inter_planes = growth_rate * 4 44 | self.conv1 = ConvLayer(in_channels, inter_planes, kernel_size=1, padding=0) 45 | self.conv2 = ConvLayer(inter_planes, growth_rate, kernel_size=3, padding=dilation, dilation=dilation) 46 | self.drop_rate = drop_rate 47 | 48 | def forward(self, x): 49 | out = self.conv2(self.conv1(x)) 50 | if self.drop_rate > 0: 51 | out = F.dropout(out, p=self.drop_rate, training=self.training) 52 | return torch.cat([x, out], 1) 53 | 54 | 55 | # --- dense block structure --- # 56 | class DenseBlock(nn.Sequential): 57 | def __init__(self, in_channels, growth_rate, drop_rate, layer_type, dilations): 58 | super(DenseBlock, self).__init__() 59 | for i in range(len(dilations)): 60 | layer = layer_type(in_channels+i*growth_rate, growth_rate, drop_rate, dilations[i]) 61 | self.add_module('denselayer{:d}'.format(i+1), layer) 62 | 63 | 64 | def choose_hybrid_dilations(n_layers, dilation_schedule, is_hybrid): 65 | import numpy as np 66 | # key: (dilation, n_layers) 67 | HD_dict = {(1, 4): [1, 1, 1, 1], 68 | (2, 4): [1, 2, 3, 2], 69 | (4, 4): [1, 2, 5, 9], 70 | (8, 4): [3, 7, 10, 13], 71 | (16, 4): [13, 15, 17, 19], 72 | (1, 6): [1, 1, 1, 1, 1, 1], 73 | (2, 6): [1, 2, 3, 1, 2, 3], 74 | (4, 6): [1, 2, 3, 5, 6, 7], 75 | (8, 6): [2, 5, 7, 9, 11, 14], 76 | (16, 6): [10, 13, 16, 17, 19, 21]} 77 | 78 | dilation_list = np.zeros((len(dilation_schedule), n_layers), dtype=np.int32) 79 | 80 | for i in range(len(dilation_schedule)): 81 | dilation = dilation_schedule[i] 82 | if is_hybrid: 83 | dilation_list[i] = HD_dict[(dilation, n_layers)] 84 | else: 85 | dilation_list[i] = [dilation for k in range(n_layers)] 86 | 87 | return dilation_list 88 | 89 | 90 | class FullNet(nn.Module): 91 | def __init__(self, color_channels, output_channels=2, n_layers=6, growth_rate=24, compress_ratio=0.5, 92 | drop_rate=0.1, dilations=(1,2,4,8,16,4,1), is_hybrid=True, layer_type='basic'): 93 | super(FullNet, self).__init__() 94 | if layer_type == 'basic': 95 | layer_type = BasicLayer 96 | else: 97 | layer_type = BottleneckLayer 98 | 99 | # 1st conv before any dense block 100 | in_channels = 24 101 | self.conv1 = ConvLayer(color_channels, in_channels, kernel_size=3, padding=1) 102 | 103 | self.blocks = nn.Sequential() 104 | n_blocks = len(dilations) 105 | 106 | dilation_list = choose_hybrid_dilations(n_layers, dilations, is_hybrid) 107 | 108 | for i in range(n_blocks): # no trans in last block 109 | block = DenseBlock(in_channels, growth_rate, drop_rate, layer_type, dilation_list[i]) 110 | self.blocks.add_module('block%d' % (i+1), block) 111 | num_trans_in = int(in_channels + n_layers * growth_rate) 112 | num_trans_out = int(math.floor(num_trans_in * compress_ratio)) 113 | trans = ConvLayer(num_trans_in, num_trans_out, kernel_size=1, padding=0) 114 | self.blocks.add_module('trans%d' % (i+1), trans) 115 | in_channels = num_trans_out 116 | #print('block.size = ', block) 117 | #print('num_trans_in = ', num_trans_in, 'num_trans_out = ', num_trans_out) 118 | #print('trans.size = ', trans) 119 | 120 | # final conv 121 | self.conv2 = nn.Conv2d(in_channels, output_channels, kernel_size=3, 
stride=1, 122 | padding=1, bias=False) 123 | # initialization 124 | for m in self.modules(): 125 | if isinstance(m, nn.Conv2d): 126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 127 | m.weight.data.normal_(0, math.sqrt(2. / n)) 128 | elif isinstance(m, nn.BatchNorm2d): 129 | m.weight.data.fill_(1) 130 | m.bias.data.zero_() 131 | elif isinstance(m, nn.Linear): 132 | m.bias.data.zero_() 133 | 134 | def forward(self, x): 135 | out = self.conv1(x) 136 | out = self.blocks(out) 137 | out = self.conv2(out) 138 | return out 139 | 140 | 141 | class FCN_pooling(nn.Module): 142 | """same structure with FullNet, except that there are pooling operations after block 1, 2, 3, 4 143 | and upsampling after block 5, 6 144 | """ 145 | def __init__(self, color_channels, output_channels=2, n_layers=6, growth_rate=24, compress_ratio=0.5, 146 | drop_rate=0.1, dilations=(1,2,4,8,16,4,1), is_hybrid=True, layer_type='basic'): 147 | super(FCN_pooling, self).__init__() 148 | if layer_type == 'basic': 149 | layer_type = BasicLayer 150 | else: 151 | layer_type = BottleneckLayer 152 | 153 | # 1st conv before any dense block 154 | in_channels = 24 155 | self.conv1 = ConvLayer(color_channels, in_channels, kernel_size=3, padding=1) 156 | 157 | self.blocks = nn.Sequential() 158 | n_blocks = len(dilations) 159 | 160 | dilation_list = choose_hybrid_dilations(n_layers, dilations, is_hybrid) 161 | 162 | for i in range(7): 163 | block = DenseBlock(in_channels, growth_rate, drop_rate, layer_type, dilation_list[i]) 164 | self.blocks.add_module('block{:d}'.format(i+1), block) 165 | num_trans_in = int(in_channels + n_layers * growth_rate) 166 | num_trans_out = int(math.floor(num_trans_in * compress_ratio)) 167 | trans = ConvLayer(num_trans_in, num_trans_out, kernel_size=1, padding=0) 168 | self.blocks.add_module('trans{:d}'.format(i+1), trans) 169 | if i in range(0, 4): 170 | self.blocks.add_module('pool{:d}'.format(i+1), nn.MaxPool2d(kernel_size=2, stride=2)) 171 | elif i in range(4, 6): 172 | self.blocks.add_module('upsample{:d}'.format(i + 1), nn.UpsamplingBilinear2d(scale_factor=4)) 173 | in_channels = num_trans_out 174 | 175 | # final conv 176 | self.conv2 = nn.Conv2d(in_channels, output_channels, kernel_size=3, stride=1, 177 | padding=1, bias=False) 178 | # initialization 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv2d): 181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 182 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 183 | elif isinstance(m, nn.BatchNorm2d): 184 | m.weight.data.fill_(1) 185 | m.bias.data.zero_() 186 | elif isinstance(m, nn.Linear): 187 | m.bias.data.zero_() 188 | 189 | def forward(self, x): 190 | out = self.conv1(x) 191 | out = self.blocks(out) 192 | out = self.conv2(out) 193 | return out 194 | -------------------------------------------------------------------------------- /models/dam/model_unet_MandD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models, datasets, transforms 4 | from torch.nn import functional as F 5 | import os 6 | 7 | 8 | class revAttention(nn.Module): #sSE 9 | def __init__(self, in_channels): 10 | super().__init__() 11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False) 12 | self.norm = nn.Sigmoid() 13 | 14 | def forward(self, U, V): 15 | q = self.Conv1x1(V) # U:[bs,c,h,w] to q:[bs,1,h,w] 16 | q = self.norm(q) 17 | return U * (1+q) # 18 | 19 | 20 | 21 | 22 | def get_backbone(name, pretrained=True): 23 | 24 | """ Loading backbone, defining names for skip-connections and encoder output. """ 25 | 26 | # TODO: More backbones 27 | 28 | # loading backbone model 29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}') 30 | if name == 'resnet18': 31 | backbone = models.resnet18(pretrained=pretrained) 32 | elif name == 'resnet34': 33 | backbone = models.resnet34(pretrained=pretrained) 34 | elif name == 'resnet50': 35 | backbone = models.resnet50(pretrained=pretrained) 36 | elif name == 'resnet101': 37 | backbone = models.resnet101(pretrained=pretrained) 38 | elif name == 'resnet152': 39 | backbone = models.resnet152(pretrained=pretrained) 40 | elif name == 'vgg16_bn': 41 | backbone = models.vgg16_bn(pretrained=pretrained).features 42 | elif name == 'vgg19_bn': 43 | backbone = models.vgg19_bn(pretrained=pretrained).features 44 | # elif name == 'inception_v3': 45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False) 46 | elif name == 'densenet121': 47 | backbone = models.densenet121(pretrained=True).features 48 | elif name == 'densenet161': 49 | backbone = models.densenet161(pretrained=True).features 50 | elif name == 'densenet169': 51 | backbone = models.densenet169(pretrained=True).features 52 | elif name == 'densenet201': 53 | backbone = models.densenet201(pretrained=True).features 54 | elif name == 'unet_encoder': 55 | from unet_backbone import UnetEncoder 56 | backbone = UnetEncoder(3) 57 | else: 58 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 59 | 60 | # specifying skip feature and output names 61 | if name.startswith('resnet'): 62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3'] 63 | backbone_output = 'layer4' 64 | elif name == 'vgg16_bn': 65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output 66 | feature_names = ['5', '12', '22', '32', '42'] 67 | backbone_output = '43' 68 | elif name == 'vgg19_bn': 69 | feature_names = ['5', '12', '25', '38', '51'] 70 | backbone_output = '52' 71 | # elif name == 'inception_v3': 72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e'] 73 | # backbone_output = 'Mixed_7c' 74 | elif name.startswith('densenet'): 75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3'] 76 | backbone_output = 'denseblock4' 77 | elif name == 'unet_encoder': 78 | feature_names = ['module1', 'module2', 'module3', 'module4'] 79 | 
backbone_output = 'module5'
80 |     else:
81 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
82 | 
83 |     return backbone, feature_names, backbone_output
84 | 
85 | 
86 | class UpsampleBlock(nn.Module):
87 | 
88 |     # TODO: separate parametric and non-parametric classes?
89 |     # TODO: skip connection concatenated OR added
90 | 
91 |     def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
92 |         super(UpsampleBlock, self).__init__()
93 | 
94 |         self.parametric = parametric
95 |         ch_out = ch_in // 2 if ch_out is None else ch_out  # integer division: a channel count must be an int
96 | 
97 |         # first convolution: either transposed conv, or conv following the skip connection
98 |         if parametric:
99 |             # versions: kernel=4 padding=1, kernel=2 padding=0
100 |             self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
101 |                                          stride=2, padding=1, output_padding=0, bias=(not use_bn))
102 |             self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
103 |         else:
104 |             self.up = None
105 |             ch_in = ch_in + skip_in
106 |             self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
107 |                                    stride=1, padding=1, bias=(not use_bn))
108 |             self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
109 | 
110 |         self.relu = nn.ReLU(inplace=True)
111 | 
112 |         # second convolution
113 |         conv2_in = ch_out if not parametric else ch_out + skip_in
114 |         self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
115 |                                stride=1, padding=1, bias=(not use_bn))
116 |         self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
117 | 
118 |     #def forward(self, x, skip_connection=None):  # original default
119 |     def forward(self, x, skip_connection=1):  # revised default
120 | 
121 |         x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
122 |                                                              align_corners=None)
123 |         if self.parametric:
124 |             x = self.bn1(x) if self.bn1 is not None else x
125 |             x = self.relu(x)
126 | 
127 |         if skip_connection is not None:
128 |             # Padding in case the incoming volumes are of different sizes  #hhl20200413add
129 |             diffY = skip_connection.size()[2] - x.size()[2]
130 |             diffX = skip_connection.size()[3] - x.size()[3]
131 |             x = F.pad(x, (diffX // 2, diffX - diffX // 2,
132 |                           diffY // 2, diffY - diffY // 2))
133 | 
134 |             x = torch.cat([x, skip_connection], dim=1)
135 | 
136 |         if not self.parametric:
137 |             x = self.conv1(x)
138 |             x = self.bn1(x) if self.bn1 is not None else x
139 |             x = self.relu(x)
140 |         x = self.conv2(x)
141 |         x = self.bn2(x) if self.bn2 is not None else x
142 |         x = self.relu(x)
143 | 
144 |         return x
145 | 
146 | 
147 | 
148 | def conv3x3(in_channels, out_channels, stride=1):
149 |     return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
150 | 
151 | class ResidualUnit(nn.Module):
152 |     def __init__(self, in_channels, out_channels):
153 |         super(ResidualUnit, self).__init__()
154 |         self.conv1 = conv3x3(in_channels, out_channels, stride=1)
155 |         self.bn1 = nn.BatchNorm2d(out_channels)
156 |         self.relu1 = nn.ReLU(inplace=True)
157 |         self.conv2 = conv3x3(out_channels, out_channels, stride=1)
158 |         self.bn2 = nn.BatchNorm2d(out_channels)
159 |         self.relu2 = nn.ReLU(inplace=True)
160 |         self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
161 | 
162 |     def forward(self, x):
163 |         residual = self.conv_1x1(x)
164 |         out = self.conv1(x)
165 |         out = self.bn1(out)
166 |         out = self.relu1(out)
167 |         out = self.conv2(out)
168 |         out = self.bn2(out)
169 |         out += residual
170 |         out = self.relu2(out)
171 |         return out
172 | 
173 | 
174 | 
175 | 
176 | 
177 | 
178 | 
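The `revAttention` block defined near the top of this file is a residual sigmoid gate: `V` is squeezed to a single channel by a 1x1 convolution, passed through a sigmoid, and `U` is rescaled by `(1 + q)`, so the gate can only amplify features, never suppress them below identity. Note that in this particular variant the instantiated `directionAtt`/`maskAtt` are not actually called in `Unet.forward` below. A minimal shape check, with tensor sizes chosen arbitrarily for illustration:

```
import torch

# revAttention is the class defined above in this file.
att = revAttention(in_channels=9)   # gate is computed from a 9-channel map V
U = torch.randn(2, 64, 96, 96)      # features to be re-weighted
V = torch.randn(2, 9, 96, 96)       # e.g. direction logits used as the gate
out = att(U, V)                     # U * (1 + sigmoid(conv1x1(V)))
assert out.shape == U.shape         # the single-channel gate broadcasts over U
```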
179 | 180 | 181 | class Unet(nn.Module): 182 | 183 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.""" 184 | 185 | def __init__(self, 186 | backbone_name='resnet50', 187 | pretrained=True, 188 | encoder_freeze=False, 189 | classes=21, 190 | decoder_filters=(256, 128, 64, 32, 16), 191 | parametric_upsampling=True, 192 | shortcut_features='default', 193 | decoder_use_batchnorm=True): 194 | super(Unet, self).__init__() 195 | 196 | self.backbone_name = backbone_name 197 | 198 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained) 199 | shortcut_chs, bb_out_chs = self.infer_skip_channels() 200 | if shortcut_features != 'default': 201 | self.shortcut_features = shortcut_features 202 | 203 | # build decoder part 204 | self.upsample_blocks = nn.ModuleList() 205 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections 206 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1]) 207 | num_blocks = len(self.shortcut_features) 208 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)): 209 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out)) 210 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out, 211 | skip_in=shortcut_chs[num_blocks-i-1], 212 | parametric=parametric_upsampling, 213 | use_bn=decoder_use_batchnorm)) 214 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1)) 215 | 216 | if encoder_freeze: 217 | self.freeze_encoder() 218 | 219 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later 220 | 221 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 222 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 223 | 224 | 225 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64) 226 | self.direction_feature = ResidualUnit(64, 64) 227 | self.point_feature = ResidualUnit(64, 64) 228 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1) 229 | self.directionAtt = revAttention(1) 230 | self.direction_conv = nn.Conv2d(64, 9, kernel_size=1) 231 | self.maskAtt = revAttention(9) 232 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1) 233 | 234 | self.residual = ResidualUnit(64, 64) 235 | 236 | 237 | 238 | def freeze_encoder(self): 239 | 240 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """ 241 | 242 | for param in self.backbone.parameters(): 243 | param.requires_grad = False 244 | 245 | def forward(self, *input): 246 | 247 | """ Forward propagation in U-Net. """ 248 | 249 | x, features = self.forward_backbone(*input) 250 | 251 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks): 252 | skip_features = features[skip_name] 253 | x = upsample_block(x, skip_features) 254 | 255 | x_F1 = self.mask_feature(x) 256 | 257 | x_F2 = self.direction_feature(x_F1) 258 | 259 | x_direction = self.direction_conv(x_F2) 260 | 261 | x_F1_mask = self.residual(x_F1) 262 | x_final_mask = self.mask_conv(x_F1_mask) 263 | 264 | 265 | 266 | return x_final_mask, x_direction 267 | 268 | def forward_backbone(self, x): 269 | 270 | """ Forward propagation in backbone encoder network. 
""" 271 | 272 | features = {None: None} if None in self.shortcut_features else dict() 273 | for name, child in self.backbone.named_children(): 274 | 275 | if(name == '0' and x.shape[1] !=3): 276 | x = self.child0(x) 277 | elif(name == 'conv1' and x.shape[1] !=3): 278 | x = self.child_conv1(x) 279 | else: 280 | x = child(x) 281 | #x = child(x) 282 | if name in self.shortcut_features: 283 | features[name] = x 284 | if name == self.bb_out_name: 285 | break 286 | 287 | return x, features 288 | 289 | def infer_skip_channels(self): 290 | 291 | """ Getting the number of channels at skip connections and at the output of the encoder. """ 292 | 293 | x = torch.zeros(1, 3, 224, 224) 294 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder' 295 | channels = [] if has_fullres_features else [0] # only VGG has features at full resolution 296 | 297 | # forward run in backbone to count channels (dirty solution but works for *any* Module) 298 | for name, child in self.backbone.named_children(): 299 | x = child(x) 300 | if name in self.shortcut_features: 301 | channels.append(x.shape[1]) 302 | if name == self.bb_out_name: 303 | out_channels = x.shape[1] 304 | break 305 | return channels, out_channels 306 | 307 | def get_pretrained_parameters(self): 308 | for name, param in self.backbone.named_parameters(): 309 | if not (self.replaced_conv1 and name == 'conv1.weight'): 310 | yield param 311 | 312 | def get_random_initialized_parameters(self): 313 | pretrained_param_names = set() 314 | for name, param in self.backbone.named_parameters(): 315 | if not (self.replaced_conv1 and name == 'conv1.weight'): 316 | pretrained_param_names.add('backbone.{}'.format(name)) 317 | 318 | for name, param in self.named_parameters(): 319 | if name not in pretrained_param_names: 320 | yield param 321 | 322 | 323 | # if __name__ == "__main__": 324 | 325 | # # simple test run 326 | # net = Unet(backbone_name='resnet18') 327 | 328 | # criterion = nn.MSELoss() 329 | # optimizer = torch.optim.Adam(net.parameters()) 330 | # print('Network initialized. Running a test batch.') 331 | # for _ in range(1): 332 | # with torch.set_grad_enabled(True): 333 | # batch = torch.empty(1, 3, 224, 224).normal_() 334 | # targets = torch.empty(1, 21, 224, 224).normal_() 335 | 336 | # out = net(batch) 337 | # loss = criterion(out, targets) 338 | # loss.backward() 339 | # optimizer.step() 340 | # print(out.shape) 341 | 342 | # print('fasza.') 343 | -------------------------------------------------------------------------------- /models/dam/model_unet_MandD16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models, datasets, transforms 4 | from torch.nn import functional as F 5 | import os 6 | 7 | 8 | 9 | class revAttention(nn.Module): #sSE 10 | def __init__(self, in_channels): 11 | super().__init__() 12 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False) 13 | self.norm = nn.Sigmoid() 14 | 15 | def forward(self, U, V): 16 | q = self.Conv1x1(V) # U:[bs,c,h,w] to q:[bs,1,h,w] 17 | q = self.norm(q) 18 | return U * (1+q) # 19 | 20 | 21 | 22 | 23 | def get_backbone(name, pretrained=True): 24 | 25 | """ Loading backbone, defining names for skip-connections and encoder output. 
""" 26 | 27 | # TODO: More backbones 28 | 29 | # loading backbone model 30 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}') 31 | if name == 'resnet18': 32 | backbone = models.resnet18(pretrained=pretrained) 33 | elif name == 'resnet34': 34 | backbone = models.resnet34(pretrained=pretrained) 35 | elif name == 'resnet50': 36 | backbone = models.resnet50(pretrained=pretrained) 37 | elif name == 'resnet101': 38 | backbone = models.resnet101(pretrained=pretrained) 39 | elif name == 'resnet152': 40 | backbone = models.resnet152(pretrained=pretrained) 41 | elif name == 'vgg16_bn': 42 | backbone = models.vgg16_bn(pretrained=pretrained).features 43 | elif name == 'vgg19_bn': 44 | backbone = models.vgg19_bn(pretrained=pretrained).features 45 | # elif name == 'inception_v3': 46 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False) 47 | elif name == 'densenet121': 48 | backbone = models.densenet121(pretrained=True).features 49 | elif name == 'densenet161': 50 | backbone = models.densenet161(pretrained=True).features 51 | elif name == 'densenet169': 52 | backbone = models.densenet169(pretrained=True).features 53 | elif name == 'densenet201': 54 | backbone = models.densenet201(pretrained=True).features 55 | elif name == 'unet_encoder': 56 | from unet_backbone import UnetEncoder 57 | backbone = UnetEncoder(3) 58 | else: 59 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 60 | 61 | # specifying skip feature and output names 62 | if name.startswith('resnet'): 63 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3'] 64 | backbone_output = 'layer4' 65 | elif name == 'vgg16_bn': 66 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output 67 | feature_names = ['5', '12', '22', '32', '42'] 68 | backbone_output = '43' 69 | elif name == 'vgg19_bn': 70 | feature_names = ['5', '12', '25', '38', '51'] 71 | backbone_output = '52' 72 | # elif name == 'inception_v3': 73 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e'] 74 | # backbone_output = 'Mixed_7c' 75 | elif name.startswith('densenet'): 76 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3'] 77 | backbone_output = 'denseblock4' 78 | elif name == 'unet_encoder': 79 | feature_names = ['module1', 'module2', 'module3', 'module4'] 80 | backbone_output = 'module5' 81 | else: 82 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 83 | 84 | return backbone, feature_names, backbone_output 85 | 86 | 87 | class UpsampleBlock(nn.Module): 88 | 89 | # TODO: separate parametric and non-parametric classes? 
90 | # TODO: skip connection concatenated OR added 91 | 92 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False): 93 | super(UpsampleBlock, self).__init__() 94 | 95 | self.parametric = parametric 96 | ch_out = ch_in/2 if ch_out is None else ch_out 97 | 98 | # first convolution: either transposed conv, or conv following the skip connection 99 | if parametric: 100 | # versions: kernel=4 padding=1, kernel=2 padding=0 101 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4), 102 | stride=2, padding=1, output_padding=0, bias=(not use_bn)) 103 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 104 | else: 105 | self.up = None 106 | ch_in = ch_in + skip_in 107 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3), 108 | stride=1, padding=1, bias=(not use_bn)) 109 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 110 | 111 | self.relu = nn.ReLU(inplace=True) 112 | 113 | # second convolution 114 | conv2_in = ch_out if not parametric else ch_out + skip_in 115 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3), 116 | stride=1, padding=1, bias=(not use_bn)) 117 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None 118 | 119 | #def forward(self, x, skip_connection=None): # 120 | def forward(self, x, skip_connection=1): # 121 | 122 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear', 123 | align_corners=None) 124 | if self.parametric: 125 | x = self.bn1(x) if self.bn1 is not None else x 126 | x = self.relu(x) 127 | 128 | if skip_connection is not None: 129 | # Padding in case the incomping volumes are of different sizes 130 | diffY = skip_connection.size()[2] - x.size()[2] 131 | diffX = skip_connection.size()[3] - x.size()[3] 132 | x = F.pad(x, (diffX // 2, diffX - diffX // 2, 133 | diffY // 2, diffY - diffY // 2)) 134 | 135 | x = torch.cat([x, skip_connection], dim=1) 136 | 137 | if not self.parametric: 138 | x = self.conv1(x) 139 | x = self.bn1(x) if self.bn1 is not None else x 140 | x = self.relu(x) 141 | x = self.conv2(x) 142 | x = self.bn2(x) if self.bn2 is not None else x 143 | x = self.relu(x) 144 | 145 | return x 146 | 147 | 148 | 149 | def conv3x3(in_channels, out_channels, stride=1): 150 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 151 | 152 | class ResidualUnit(nn.Module): 153 | def __init__(self, in_channels, out_channels): 154 | super(ResidualUnit, self).__init__() 155 | self.conv1 = conv3x3(in_channels, out_channels, stride=1) 156 | self.bn1 = nn.BatchNorm2d(out_channels) 157 | self.relu1 = nn.ReLU(inplace=True) 158 | self.conv2 = conv3x3(out_channels, out_channels, stride=1) 159 | self.bn2 = nn.BatchNorm2d(out_channels) 160 | self.relu2 = nn.ReLU(inplace=True) 161 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 162 | 163 | def forward(self, x): 164 | residual = self.conv_1x1(x) 165 | out = self.conv1(x) 166 | out = self.bn1(out) 167 | out = self.relu1(out) 168 | out = self.conv2(out) 169 | out = self.bn2(out) 170 | out += residual 171 | out = self.relu2(out) 172 | return out 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | class Unet(nn.Module): 183 | 184 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.""" 185 | 186 | def __init__(self, 187 | backbone_name='resnet50', 188 | pretrained=True, 189 | encoder_freeze=False, 190 | classes=21, 191 | 
decoder_filters=(256, 128, 64, 32, 16), 192 | parametric_upsampling=True, 193 | shortcut_features='default', 194 | decoder_use_batchnorm=True): 195 | super(Unet, self).__init__() 196 | 197 | self.backbone_name = backbone_name 198 | 199 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained) 200 | shortcut_chs, bb_out_chs = self.infer_skip_channels() 201 | if shortcut_features != 'default': 202 | self.shortcut_features = shortcut_features 203 | 204 | # build decoder part 205 | self.upsample_blocks = nn.ModuleList() 206 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections 207 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1]) 208 | num_blocks = len(self.shortcut_features) 209 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)): 210 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out)) 211 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out, 212 | skip_in=shortcut_chs[num_blocks-i-1], 213 | parametric=parametric_upsampling, 214 | use_bn=decoder_use_batchnorm)) 215 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1)) 216 | 217 | if encoder_freeze: 218 | self.freeze_encoder() 219 | 220 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later 221 | #hhl20210611add 用来替代1通道input在child=0时的卷积 222 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 223 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 224 | 225 | 226 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64) 227 | self.direction_feature = ResidualUnit(64, 64) 228 | self.point_feature = ResidualUnit(64, 64) 229 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1) 230 | self.directionAtt = revAttention(1) 231 | self.direction_conv = nn.Conv2d(64, 16+1, kernel_size=1) 232 | self.maskAtt = revAttention(16+1) 233 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1) 234 | 235 | self.residual = ResidualUnit(64, 64) 236 | 237 | 238 | 239 | def freeze_encoder(self): 240 | 241 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """ 242 | 243 | for param in self.backbone.parameters(): 244 | param.requires_grad = False 245 | 246 | def forward(self, *input): 247 | 248 | """ Forward propagation in U-Net. """ 249 | 250 | x, features = self.forward_backbone(*input) 251 | 252 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks): 253 | skip_features = features[skip_name] 254 | x = upsample_block(x, skip_features) 255 | 256 | x_F1 = self.mask_feature(x) 257 | 258 | x_F2 = self.direction_feature(x_F1) 259 | 260 | x_direction = self.direction_conv(x_F2) 261 | 262 | x_F1_mask = self.residual(x_F1) 263 | x_final_mask = self.mask_conv(x_F1_mask) 264 | 265 | 266 | 267 | return x_final_mask, x_direction 268 | 269 | def forward_backbone(self, x): 270 | 271 | """ Forward propagation in backbone encoder network. 
""" 272 | 273 | features = {None: None} if None in self.shortcut_features else dict() 274 | for name, child in self.backbone.named_children(): 275 | # hhl20210611add x.shape[1] = 1的情况 276 | if(name == '0' and x.shape[1] !=3): 277 | x = self.child0(x) 278 | elif(name == 'conv1' and x.shape[1] !=3): 279 | x = self.child_conv1(x) 280 | else: 281 | x = child(x) 282 | #x = child(x) 283 | if name in self.shortcut_features: 284 | features[name] = x 285 | if name == self.bb_out_name: 286 | break 287 | 288 | return x, features 289 | 290 | def infer_skip_channels(self): 291 | 292 | """ Getting the number of channels at skip connections and at the output of the encoder. """ 293 | 294 | x = torch.zeros(1, 3, 224, 224) 295 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder' 296 | channels = [] if has_fullres_features else [0] # only VGG has features at full resolution 297 | 298 | # forward run in backbone to count channels (dirty solution but works for *any* Module) 299 | for name, child in self.backbone.named_children(): 300 | x = child(x) 301 | if name in self.shortcut_features: 302 | channels.append(x.shape[1]) 303 | if name == self.bb_out_name: 304 | out_channels = x.shape[1] 305 | break 306 | return channels, out_channels 307 | 308 | def get_pretrained_parameters(self): 309 | for name, param in self.backbone.named_parameters(): 310 | if not (self.replaced_conv1 and name == 'conv1.weight'): 311 | yield param 312 | 313 | def get_random_initialized_parameters(self): 314 | pretrained_param_names = set() 315 | for name, param in self.backbone.named_parameters(): 316 | if not (self.replaced_conv1 and name == 'conv1.weight'): 317 | pretrained_param_names.add('backbone.{}'.format(name)) 318 | 319 | for name, param in self.named_parameters(): 320 | if name not in pretrained_param_names: 321 | yield param 322 | 323 | 324 | # if __name__ == "__main__": 325 | 326 | # # simple test run 327 | # net = Unet(backbone_name='resnet18') 328 | 329 | # criterion = nn.MSELoss() 330 | # optimizer = torch.optim.Adam(net.parameters()) 331 | # print('Network initialized. Running a test batch.') 332 | # for _ in range(1): 333 | # with torch.set_grad_enabled(True): 334 | # batch = torch.empty(1, 3, 224, 224).normal_() 335 | # targets = torch.empty(1, 21, 224, 224).normal_() 336 | 337 | # out = net(batch) 338 | # loss = criterion(out, targets) 339 | # loss.backward() 340 | # optimizer.step() 341 | # print(out.shape) 342 | 343 | # print('fasza.') 344 | -------------------------------------------------------------------------------- /models/dam/model_unet_MandD4.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models, datasets, transforms 4 | from torch.nn import functional as F 5 | import os 6 | 7 | 8 | class revAttention(nn.Module): #sSE 9 | def __init__(self, in_channels): 10 | super().__init__() 11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False) 12 | self.norm = nn.Sigmoid() 13 | 14 | def forward(self, U, V): 15 | q = self.Conv1x1(V) # U:[bs,c,h,w] to q:[bs,1,h,w] 16 | q = self.norm(q) 17 | return U * (1+q) 18 | 19 | 20 | 21 | 22 | def get_backbone(name, pretrained=True): 23 | 24 | """ Loading backbone, defining names for skip-connections and encoder output. 
""" 25 | 26 | # TODO: More backbones 27 | 28 | # loading backbone model 29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}') 30 | if name == 'resnet18': 31 | backbone = models.resnet18(pretrained=pretrained) 32 | elif name == 'resnet34': 33 | backbone = models.resnet34(pretrained=pretrained) 34 | elif name == 'resnet50': 35 | backbone = models.resnet50(pretrained=pretrained) 36 | elif name == 'resnet101': 37 | backbone = models.resnet101(pretrained=pretrained) 38 | elif name == 'resnet152': 39 | backbone = models.resnet152(pretrained=pretrained) 40 | elif name == 'vgg16_bn': 41 | backbone = models.vgg16_bn(pretrained=pretrained).features 42 | elif name == 'vgg19_bn': 43 | backbone = models.vgg19_bn(pretrained=pretrained).features 44 | # elif name == 'inception_v3': 45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False) 46 | elif name == 'densenet121': 47 | backbone = models.densenet121(pretrained=True).features 48 | elif name == 'densenet161': 49 | backbone = models.densenet161(pretrained=True).features 50 | elif name == 'densenet169': 51 | backbone = models.densenet169(pretrained=True).features 52 | elif name == 'densenet201': 53 | backbone = models.densenet201(pretrained=True).features 54 | elif name == 'unet_encoder': 55 | from unet_backbone import UnetEncoder 56 | backbone = UnetEncoder(3) 57 | else: 58 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 59 | 60 | # specifying skip feature and output names 61 | if name.startswith('resnet'): 62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3'] 63 | backbone_output = 'layer4' 64 | elif name == 'vgg16_bn': 65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output 66 | feature_names = ['5', '12', '22', '32', '42'] 67 | backbone_output = '43' 68 | elif name == 'vgg19_bn': 69 | feature_names = ['5', '12', '25', '38', '51'] 70 | backbone_output = '52' 71 | # elif name == 'inception_v3': 72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e'] 73 | # backbone_output = 'Mixed_7c' 74 | elif name.startswith('densenet'): 75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3'] 76 | backbone_output = 'denseblock4' 77 | elif name == 'unet_encoder': 78 | feature_names = ['module1', 'module2', 'module3', 'module4'] 79 | backbone_output = 'module5' 80 | else: 81 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 82 | 83 | return backbone, feature_names, backbone_output 84 | 85 | 86 | class UpsampleBlock(nn.Module): 87 | 88 | # TODO: separate parametric and non-parametric classes? 
89 | # TODO: skip connection concatenated OR added 90 | 91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False): 92 | super(UpsampleBlock, self).__init__() 93 | 94 | self.parametric = parametric 95 | ch_out = ch_in/2 if ch_out is None else ch_out 96 | 97 | # first convolution: either transposed conv, or conv following the skip connection 98 | if parametric: 99 | # versions: kernel=4 padding=1, kernel=2 padding=0 100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4), 101 | stride=2, padding=1, output_padding=0, bias=(not use_bn)) 102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 103 | else: 104 | self.up = None 105 | ch_in = ch_in + skip_in 106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3), 107 | stride=1, padding=1, bias=(not use_bn)) 108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 109 | 110 | self.relu = nn.ReLU(inplace=True) 111 | 112 | # second convolution 113 | conv2_in = ch_out if not parametric else ch_out + skip_in 114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3), 115 | stride=1, padding=1, bias=(not use_bn)) 116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None 117 | 118 | #def forward(self, x, skip_connection=None): # 119 | def forward(self, x, skip_connection=1): # 120 | 121 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear', 122 | align_corners=None) 123 | if self.parametric: 124 | x = self.bn1(x) if self.bn1 is not None else x 125 | x = self.relu(x) 126 | 127 | if skip_connection is not None: 128 | # Padding in case the incomping volumes are of different sizes #hhl20200413add 129 | diffY = skip_connection.size()[2] - x.size()[2] 130 | diffX = skip_connection.size()[3] - x.size()[3] 131 | x = F.pad(x, (diffX // 2, diffX - diffX // 2, 132 | diffY // 2, diffY - diffY // 2)) 133 | 134 | x = torch.cat([x, skip_connection], dim=1) 135 | 136 | if not self.parametric: 137 | x = self.conv1(x) 138 | x = self.bn1(x) if self.bn1 is not None else x 139 | x = self.relu(x) 140 | x = self.conv2(x) 141 | x = self.bn2(x) if self.bn2 is not None else x 142 | x = self.relu(x) 143 | 144 | return x 145 | 146 | 147 | 148 | def conv3x3(in_channels, out_channels, stride=1): 149 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 150 | 151 | class ResidualUnit(nn.Module): 152 | def __init__(self, in_channels, out_channels): 153 | super(ResidualUnit, self).__init__() 154 | self.conv1 = conv3x3(in_channels, out_channels, stride=1) 155 | self.bn1 = nn.BatchNorm2d(out_channels) 156 | self.relu1 = nn.ReLU(inplace=True) 157 | self.conv2 = conv3x3(out_channels, out_channels, stride=1) 158 | self.bn2 = nn.BatchNorm2d(out_channels) 159 | self.relu2 = nn.ReLU(inplace=True) 160 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 161 | 162 | def forward(self, x): 163 | residual = self.conv_1x1(x) 164 | out = self.conv1(x) 165 | out = self.bn1(out) 166 | out = self.relu1(out) 167 | out = self.conv2(out) 168 | out = self.bn2(out) 169 | out += residual 170 | out = self.relu2(out) 171 | return out 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | class Unet(nn.Module): 182 | 183 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.""" 184 | 185 | def __init__(self, 186 | backbone_name='resnet50', 187 | pretrained=True, 188 | encoder_freeze=False, 189 | classes=21, 190 | 
decoder_filters=(256, 128, 64, 32, 16), 191 | parametric_upsampling=True, 192 | shortcut_features='default', 193 | decoder_use_batchnorm=True): 194 | super(Unet, self).__init__() 195 | 196 | self.backbone_name = backbone_name 197 | 198 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained) 199 | shortcut_chs, bb_out_chs = self.infer_skip_channels() 200 | if shortcut_features != 'default': 201 | self.shortcut_features = shortcut_features 202 | 203 | # build decoder part 204 | self.upsample_blocks = nn.ModuleList() 205 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections 206 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1]) 207 | num_blocks = len(self.shortcut_features) 208 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)): 209 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out)) 210 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out, 211 | skip_in=shortcut_chs[num_blocks-i-1], 212 | parametric=parametric_upsampling, 213 | use_bn=decoder_use_batchnorm)) 214 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1)) 215 | 216 | if encoder_freeze: 217 | self.freeze_encoder() 218 | 219 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later 220 | 221 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 222 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 223 | 224 | 225 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64) 226 | self.direction_feature = ResidualUnit(64, 64) 227 | self.point_feature = ResidualUnit(64, 64) 228 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1) 229 | self.directionAtt = revAttention(1) 230 | self.direction_conv = nn.Conv2d(64, 4+1, kernel_size=1) 231 | self.maskAtt = revAttention(4+1) 232 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1) 233 | 234 | self.residual = ResidualUnit(64, 64) 235 | 236 | 237 | 238 | def freeze_encoder(self): 239 | 240 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """ 241 | 242 | for param in self.backbone.parameters(): 243 | param.requires_grad = False 244 | 245 | def forward(self, *input): 246 | 247 | """ Forward propagation in U-Net. """ 248 | 249 | x, features = self.forward_backbone(*input) 250 | 251 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks): 252 | skip_features = features[skip_name] 253 | x = upsample_block(x, skip_features) 254 | 255 | x_F1 = self.mask_feature(x) 256 | 257 | x_F2 = self.direction_feature(x_F1) 258 | 259 | x_direction = self.direction_conv(x_F2) 260 | 261 | x_F1_mask = self.residual(x_F1) 262 | x_final_mask = self.mask_conv(x_F1_mask) 263 | 264 | 265 | 266 | return x_final_mask, x_direction 267 | 268 | def forward_backbone(self, x): 269 | 270 | """ Forward propagation in backbone encoder network. 
""" 271 | 272 | features = {None: None} if None in self.shortcut_features else dict() 273 | for name, child in self.backbone.named_children(): 274 | 275 | if(name == '0' and x.shape[1] !=3): 276 | x = self.child0(x) 277 | elif(name == 'conv1' and x.shape[1] !=3): 278 | x = self.child_conv1(x) 279 | else: 280 | x = child(x) 281 | #x = child(x) 282 | if name in self.shortcut_features: 283 | features[name] = x 284 | if name == self.bb_out_name: 285 | break 286 | 287 | return x, features 288 | 289 | def infer_skip_channels(self): 290 | 291 | """ Getting the number of channels at skip connections and at the output of the encoder. """ 292 | 293 | x = torch.zeros(1, 3, 224, 224) 294 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder' 295 | channels = [] if has_fullres_features else [0] # only VGG has features at full resolution 296 | 297 | # forward run in backbone to count channels (dirty solution but works for *any* Module) 298 | for name, child in self.backbone.named_children(): 299 | x = child(x) 300 | if name in self.shortcut_features: 301 | channels.append(x.shape[1]) 302 | if name == self.bb_out_name: 303 | out_channels = x.shape[1] 304 | break 305 | return channels, out_channels 306 | 307 | def get_pretrained_parameters(self): 308 | for name, param in self.backbone.named_parameters(): 309 | if not (self.replaced_conv1 and name == 'conv1.weight'): 310 | yield param 311 | 312 | def get_random_initialized_parameters(self): 313 | pretrained_param_names = set() 314 | for name, param in self.backbone.named_parameters(): 315 | if not (self.replaced_conv1 and name == 'conv1.weight'): 316 | pretrained_param_names.add('backbone.{}'.format(name)) 317 | 318 | for name, param in self.named_parameters(): 319 | if name not in pretrained_param_names: 320 | yield param 321 | 322 | 323 | # if __name__ == "__main__": 324 | 325 | # # simple test run 326 | # net = Unet(backbone_name='resnet18') 327 | 328 | # criterion = nn.MSELoss() 329 | # optimizer = torch.optim.Adam(net.parameters()) 330 | # print('Network initialized. Running a test batch.') 331 | # for _ in range(1): 332 | # with torch.set_grad_enabled(True): 333 | # batch = torch.empty(1, 3, 224, 224).normal_() 334 | # targets = torch.empty(1, 21, 224, 224).normal_() 335 | 336 | # out = net(batch) 337 | # loss = criterion(out, targets) 338 | # loss.backward() 339 | # optimizer.step() 340 | # print(out.shape) 341 | 342 | # print('fasza.') 343 | -------------------------------------------------------------------------------- /models/dam/model_unet_MandDandP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models, datasets, transforms 4 | from torch.nn import functional as F 5 | import os 6 | 7 | 8 | class revAttention(nn.Module): #sSE 9 | def __init__(self, in_channels): 10 | super().__init__() 11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False) 12 | self.norm = nn.Sigmoid() 13 | 14 | def forward(self, U, V): 15 | q = self.Conv1x1(V) # U:[bs,c,h,w] to q:[bs,1,h,w] 16 | q = self.norm(q) 17 | return U * (1+q) # 18 | 19 | 20 | 21 | 22 | def get_backbone(name, pretrained=True): 23 | 24 | """ Loading backbone, defining names for skip-connections and encoder output. 
""" 25 | 26 | # TODO: More backbones 27 | 28 | # loading backbone model 29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}') 30 | if name == 'resnet18': 31 | backbone = models.resnet18(pretrained=pretrained) 32 | elif name == 'resnet34': 33 | backbone = models.resnet34(pretrained=pretrained) 34 | elif name == 'resnet50': 35 | backbone = models.resnet50(pretrained=pretrained) 36 | elif name == 'resnet101': 37 | backbone = models.resnet101(pretrained=pretrained) 38 | elif name == 'resnet152': 39 | backbone = models.resnet152(pretrained=pretrained) 40 | elif name == 'vgg16_bn': 41 | backbone = models.vgg16_bn(pretrained=pretrained).features 42 | elif name == 'vgg19_bn': 43 | backbone = models.vgg19_bn(pretrained=pretrained).features 44 | # elif name == 'inception_v3': 45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False) 46 | elif name == 'densenet121': 47 | backbone = models.densenet121(pretrained=True).features 48 | elif name == 'densenet161': 49 | backbone = models.densenet161(pretrained=True).features 50 | elif name == 'densenet169': 51 | backbone = models.densenet169(pretrained=True).features 52 | elif name == 'densenet201': 53 | backbone = models.densenet201(pretrained=True).features 54 | elif name == 'unet_encoder': 55 | from unet_backbone import UnetEncoder 56 | backbone = UnetEncoder(3) 57 | else: 58 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 59 | 60 | # specifying skip feature and output names 61 | if name.startswith('resnet'): 62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3'] 63 | backbone_output = 'layer4' 64 | elif name == 'vgg16_bn': 65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output 66 | feature_names = ['5', '12', '22', '32', '42'] 67 | backbone_output = '43' 68 | elif name == 'vgg19_bn': 69 | feature_names = ['5', '12', '25', '38', '51'] 70 | backbone_output = '52' 71 | # elif name == 'inception_v3': 72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e'] 73 | # backbone_output = 'Mixed_7c' 74 | elif name.startswith('densenet'): 75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3'] 76 | backbone_output = 'denseblock4' 77 | elif name == 'unet_encoder': 78 | feature_names = ['module1', 'module2', 'module3', 'module4'] 79 | backbone_output = 'module5' 80 | else: 81 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 82 | 83 | return backbone, feature_names, backbone_output 84 | 85 | 86 | class UpsampleBlock(nn.Module): 87 | 88 | # TODO: separate parametric and non-parametric classes? 
89 | # TODO: skip connection concatenated OR added 90 | 91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False): 92 | super(UpsampleBlock, self).__init__() 93 | 94 | self.parametric = parametric 95 | ch_out = ch_in/2 if ch_out is None else ch_out 96 | 97 | # first convolution: either transposed conv, or conv following the skip connection 98 | if parametric: 99 | # versions: kernel=4 padding=1, kernel=2 padding=0 100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4), 101 | stride=2, padding=1, output_padding=0, bias=(not use_bn)) 102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 103 | else: 104 | self.up = None 105 | ch_in = ch_in + skip_in 106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3), 107 | stride=1, padding=1, bias=(not use_bn)) 108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 109 | 110 | self.relu = nn.ReLU(inplace=True) 111 | 112 | # second convolution 113 | conv2_in = ch_out if not parametric else ch_out + skip_in 114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3), 115 | stride=1, padding=1, bias=(not use_bn)) 116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None 117 | 118 | #def forward(self, x, skip_connection=None): # 119 | def forward(self, x, skip_connection=1): 120 | 121 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear', 122 | align_corners=None) 123 | if self.parametric: 124 | x = self.bn1(x) if self.bn1 is not None else x 125 | x = self.relu(x) 126 | 127 | if skip_connection is not None: 128 | # Padding in case the incomping volumes are of different sizes 129 | diffY = skip_connection.size()[2] - x.size()[2] 130 | diffX = skip_connection.size()[3] - x.size()[3] 131 | x = F.pad(x, (diffX // 2, diffX - diffX // 2, 132 | diffY // 2, diffY - diffY // 2)) 133 | 134 | x = torch.cat([x, skip_connection], dim=1) 135 | 136 | if not self.parametric: 137 | x = self.conv1(x) 138 | x = self.bn1(x) if self.bn1 is not None else x 139 | x = self.relu(x) 140 | x = self.conv2(x) 141 | x = self.bn2(x) if self.bn2 is not None else x 142 | x = self.relu(x) 143 | 144 | return x 145 | 146 | 147 | 148 | def conv3x3(in_channels, out_channels, stride=1): 149 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 150 | 151 | class ResidualUnit(nn.Module): 152 | def __init__(self, in_channels, out_channels): 153 | super(ResidualUnit, self).__init__() 154 | self.conv1 = conv3x3(in_channels, out_channels, stride=1) 155 | self.bn1 = nn.BatchNorm2d(out_channels) 156 | self.relu1 = nn.ReLU(inplace=True) 157 | self.conv2 = conv3x3(out_channels, out_channels, stride=1) 158 | self.bn2 = nn.BatchNorm2d(out_channels) 159 | self.relu2 = nn.ReLU(inplace=True) 160 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 161 | 162 | def forward(self, x): 163 | residual = self.conv_1x1(x) 164 | out = self.conv1(x) 165 | out = self.bn1(out) 166 | out = self.relu1(out) 167 | out = self.conv2(out) 168 | out = self.bn2(out) 169 | out += residual 170 | out = self.relu2(out) 171 | return out 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | class Unet(nn.Module): 182 | 183 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.""" 184 | 185 | def __init__(self, 186 | backbone_name='resnet50', 187 | pretrained=True, 188 | encoder_freeze=False, 189 | classes=21, 190 | 
decoder_filters=(256, 128, 64, 32, 16), 191 | parametric_upsampling=True, 192 | shortcut_features='default', 193 | decoder_use_batchnorm=True): 194 | super(Unet, self).__init__() 195 | 196 | self.backbone_name = backbone_name 197 | 198 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained) 199 | shortcut_chs, bb_out_chs = self.infer_skip_channels() 200 | if shortcut_features != 'default': 201 | self.shortcut_features = shortcut_features 202 | 203 | # build decoder part 204 | self.upsample_blocks = nn.ModuleList() 205 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections 206 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1]) 207 | num_blocks = len(self.shortcut_features) 208 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)): 209 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out)) 210 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out, 211 | skip_in=shortcut_chs[num_blocks-i-1], 212 | parametric=parametric_upsampling, 213 | use_bn=decoder_use_batchnorm)) 214 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1)) 215 | 216 | if encoder_freeze: 217 | self.freeze_encoder() 218 | 219 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later 220 | # 221 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 222 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 223 | 224 | 225 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64) 226 | self.direction_feature = ResidualUnit(64, 64) 227 | self.point_feature = ResidualUnit(64, 64) 228 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1) 229 | self.directionAtt = revAttention(1) 230 | self.direction_conv = nn.Conv2d(64, 9, kernel_size=1) 231 | self.maskAtt = revAttention(9) 232 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1) 233 | 234 | self.residual = ResidualUnit(64, 64) 235 | 236 | 237 | 238 | def freeze_encoder(self): 239 | 240 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """ 241 | 242 | for param in self.backbone.parameters(): 243 | param.requires_grad = False 244 | 245 | def forward(self, *input): 246 | 247 | """ Forward propagation in U-Net. """ 248 | 249 | x, features = self.forward_backbone(*input) 250 | 251 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks): 252 | skip_features = features[skip_name] 253 | x = upsample_block(x, skip_features) 254 | 255 | x_F1 = self.mask_feature(x) 256 | 257 | x_F2 = self.direction_feature(x_F1) 258 | x_F3 = self.point_feature(x_F2) 259 | 260 | x_direction = self.direction_conv(x_F2) 261 | x_point = self.point_conv(x_F3) 262 | 263 | x_F1_mask = self.residual(x_F1) 264 | x_final_mask = self.mask_conv(x_F1_mask) 265 | 266 | 267 | 268 | return x_final_mask, x_point, x_direction 269 | 270 | def forward_backbone(self, x): 271 | 272 | """ Forward propagation in backbone encoder network. 
""" 273 | 274 | features = {None: None} if None in self.shortcut_features else dict() 275 | for name, child in self.backbone.named_children(): 276 | # 277 | if(name == '0' and x.shape[1] !=3): 278 | x = self.child0(x) 279 | elif(name == 'conv1' and x.shape[1] !=3): 280 | x = self.child_conv1(x) 281 | else: 282 | x = child(x) 283 | #x = child(x) 284 | if name in self.shortcut_features: 285 | features[name] = x 286 | if name == self.bb_out_name: 287 | break 288 | 289 | return x, features 290 | 291 | def infer_skip_channels(self): 292 | 293 | """ Getting the number of channels at skip connections and at the output of the encoder. """ 294 | 295 | x = torch.zeros(1, 3, 224, 224) 296 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder' 297 | channels = [] if has_fullres_features else [0] # only VGG has features at full resolution 298 | 299 | # forward run in backbone to count channels (dirty solution but works for *any* Module) 300 | for name, child in self.backbone.named_children(): 301 | x = child(x) 302 | if name in self.shortcut_features: 303 | channels.append(x.shape[1]) 304 | if name == self.bb_out_name: 305 | out_channels = x.shape[1] 306 | break 307 | return channels, out_channels 308 | 309 | def get_pretrained_parameters(self): 310 | for name, param in self.backbone.named_parameters(): 311 | if not (self.replaced_conv1 and name == 'conv1.weight'): 312 | yield param 313 | 314 | def get_random_initialized_parameters(self): 315 | pretrained_param_names = set() 316 | for name, param in self.backbone.named_parameters(): 317 | if not (self.replaced_conv1 and name == 'conv1.weight'): 318 | pretrained_param_names.add('backbone.{}'.format(name)) 319 | 320 | for name, param in self.named_parameters(): 321 | if name not in pretrained_param_names: 322 | yield param 323 | 324 | 325 | # if __name__ == "__main__": 326 | 327 | # # simple test run 328 | # net = Unet(backbone_name='resnet18') 329 | 330 | # criterion = nn.MSELoss() 331 | # optimizer = torch.optim.Adam(net.parameters()) 332 | # print('Network initialized. Running a test batch.') 333 | # for _ in range(1): 334 | # with torch.set_grad_enabled(True): 335 | # batch = torch.empty(1, 3, 224, 224).normal_() 336 | # targets = torch.empty(1, 21, 224, 224).normal_() 337 | 338 | # out = net(batch) 339 | # loss = criterion(out, targets) 340 | # loss.backward() 341 | # optimizer.step() 342 | # print(out.shape) 343 | 344 | # print('fasza.') 345 | -------------------------------------------------------------------------------- /models/dam/model_unet_rev1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models, datasets, transforms 4 | from torch.nn import functional as F 5 | import os 6 | 7 | 8 | class revAttention(nn.Module): #sSE 9 | def __init__(self, in_channels): 10 | super().__init__() 11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False) 12 | self.norm = nn.Sigmoid() 13 | 14 | def forward(self, U, V): 15 | q = self.Conv1x1(V) # U:[bs,c,h,w] to q:[bs,1,h,w] 16 | q = self.norm(q) 17 | return U * (1+q) 18 | 19 | 20 | 21 | 22 | def get_backbone(name, pretrained=True): 23 | 24 | """ Loading backbone, defining names for skip-connections and encoder output. 
""" 25 | 26 | # TODO: More backbones 27 | 28 | # loading backbone model 29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}') 30 | if name == 'resnet18': 31 | backbone = models.resnet18(pretrained=pretrained) 32 | elif name == 'resnet34': 33 | backbone = models.resnet34(pretrained=pretrained) 34 | elif name == 'resnet50': 35 | backbone = models.resnet50(pretrained=pretrained) 36 | elif name == 'resnet101': 37 | backbone = models.resnet101(pretrained=pretrained) 38 | elif name == 'resnet152': 39 | backbone = models.resnet152(pretrained=pretrained) 40 | elif name == 'vgg16_bn': 41 | backbone = models.vgg16_bn(pretrained=pretrained).features 42 | elif name == 'vgg19_bn': 43 | backbone = models.vgg19_bn(pretrained=pretrained).features 44 | # elif name == 'inception_v3': 45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False) 46 | elif name == 'densenet121': 47 | backbone = models.densenet121(pretrained=True).features 48 | elif name == 'densenet161': 49 | backbone = models.densenet161(pretrained=True).features 50 | elif name == 'densenet169': 51 | backbone = models.densenet169(pretrained=True).features 52 | elif name == 'densenet201': 53 | backbone = models.densenet201(pretrained=True).features 54 | elif name == 'unet_encoder': 55 | from unet_backbone import UnetEncoder 56 | backbone = UnetEncoder(3) 57 | else: 58 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 59 | 60 | # specifying skip feature and output names 61 | if name.startswith('resnet'): 62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3'] 63 | backbone_output = 'layer4' 64 | elif name == 'vgg16_bn': 65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output 66 | feature_names = ['5', '12', '22', '32', '42'] 67 | backbone_output = '43' 68 | elif name == 'vgg19_bn': 69 | feature_names = ['5', '12', '25', '38', '51'] 70 | backbone_output = '52' 71 | # elif name == 'inception_v3': 72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e'] 73 | # backbone_output = 'Mixed_7c' 74 | elif name.startswith('densenet'): 75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3'] 76 | backbone_output = 'denseblock4' 77 | elif name == 'unet_encoder': 78 | feature_names = ['module1', 'module2', 'module3', 'module4'] 79 | backbone_output = 'module5' 80 | else: 81 | raise NotImplemented('{} backbone model is not implemented so far.'.format(name)) 82 | 83 | return backbone, feature_names, backbone_output 84 | 85 | 86 | class UpsampleBlock(nn.Module): 87 | 88 | # TODO: separate parametric and non-parametric classes? 
89 | # TODO: skip connection concatenated OR added 90 | 91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False): 92 | super(UpsampleBlock, self).__init__() 93 | 94 | self.parametric = parametric 95 | ch_out = ch_in/2 if ch_out is None else ch_out 96 | 97 | # first convolution: either transposed conv, or conv following the skip connection 98 | if parametric: 99 | # versions: kernel=4 padding=1, kernel=2 padding=0 100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4), 101 | stride=2, padding=1, output_padding=0, bias=(not use_bn)) 102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 103 | else: 104 | self.up = None 105 | ch_in = ch_in + skip_in 106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3), 107 | stride=1, padding=1, bias=(not use_bn)) 108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 109 | 110 | self.relu = nn.ReLU(inplace=True) 111 | 112 | # second convolution 113 | conv2_in = ch_out if not parametric else ch_out + skip_in 114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3), 115 | stride=1, padding=1, bias=(not use_bn)) 116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None 117 | 118 | #def forward(self, x, skip_connection=None): #original code 119 | def forward(self, x, skip_connection=1): # hhl revised 120 | 121 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear', 122 | align_corners=None) 123 | if self.parametric: 124 | x = self.bn1(x) if self.bn1 is not None else x 125 | x = self.relu(x) 126 | 127 | if skip_connection is not None: 128 | diffY = skip_connection.size()[2] - x.size()[2] 129 | diffX = skip_connection.size()[3] - x.size()[3] 130 | x = F.pad(x, (diffX // 2, diffX - diffX // 2, 131 | diffY // 2, diffY - diffY // 2)) 132 | 133 | x = torch.cat([x, skip_connection], dim=1) 134 | 135 | if not self.parametric: 136 | x = self.conv1(x) 137 | x = self.bn1(x) if self.bn1 is not None else x 138 | x = self.relu(x) 139 | x = self.conv2(x) 140 | x = self.bn2(x) if self.bn2 is not None else x 141 | x = self.relu(x) 142 | 143 | return x 144 | 145 | 146 | 147 | def conv3x3(in_channels, out_channels, stride=1): 148 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) 149 | 150 | class ResidualUnit(nn.Module): 151 | def __init__(self, in_channels, out_channels): 152 | super(ResidualUnit, self).__init__() 153 | self.conv1 = conv3x3(in_channels, out_channels, stride=1) 154 | self.bn1 = nn.BatchNorm2d(out_channels) 155 | self.relu1 = nn.ReLU(inplace=True) 156 | self.conv2 = conv3x3(out_channels, out_channels, stride=1) 157 | self.bn2 = nn.BatchNorm2d(out_channels) 158 | self.relu2 = nn.ReLU(inplace=True) 159 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) 160 | 161 | def forward(self, x): 162 | residual = self.conv_1x1(x) 163 | out = self.conv1(x) 164 | out = self.bn1(out) 165 | out = self.relu1(out) 166 | out = self.conv2(out) 167 | out = self.bn2(out) 168 | out += residual 169 | out = self.relu2(out) 170 | return out 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | class Unet(nn.Module): 181 | 182 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.""" 183 | 184 | def __init__(self, 185 | backbone_name='resnet50', 186 | pretrained=True, 187 | encoder_freeze=False, 188 | classes=21, 189 | decoder_filters=(256, 128, 64, 32, 16), 190 | 
parametric_upsampling=True, 191 | shortcut_features='default', 192 | decoder_use_batchnorm=True): 193 | super(Unet, self).__init__() 194 | 195 | self.backbone_name = backbone_name 196 | 197 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained) 198 | shortcut_chs, bb_out_chs = self.infer_skip_channels() 199 | if shortcut_features != 'default': 200 | self.shortcut_features = shortcut_features 201 | 202 | # build decoder part 203 | self.upsample_blocks = nn.ModuleList() 204 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections 205 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1]) 206 | num_blocks = len(self.shortcut_features) 207 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)): 208 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out)) 209 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out, 210 | skip_in=shortcut_chs[num_blocks-i-1], 211 | parametric=parametric_upsampling, 212 | use_bn=decoder_use_batchnorm)) 213 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1)) 214 | 215 | if encoder_freeze: 216 | self.freeze_encoder() 217 | 218 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later 219 | 220 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 221 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 222 | 223 | 224 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64) 225 | self.direction_feature = ResidualUnit(64, 64) 226 | self.point_feature = ResidualUnit(64, 64) 227 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1) 228 | self.directionAtt = revAttention(1) 229 | self.direction_conv = nn.Conv2d(64, 9, kernel_size=1) 230 | self.maskAtt = revAttention(9) 231 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1) 232 | 233 | 234 | 235 | 236 | 237 | def freeze_encoder(self): 238 | 239 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """ 240 | 241 | for param in self.backbone.parameters(): 242 | param.requires_grad = False 243 | 244 | def forward(self, *input): 245 | 246 | """ Forward propagation in U-Net. """ 247 | 248 | x, features = self.forward_backbone(*input) 249 | 250 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks): 251 | skip_features = features[skip_name] 252 | x = upsample_block(x, skip_features) 253 | 254 | x_F1 = self.mask_feature(x) 255 | 256 | x_F2 = self.direction_feature(x_F1) 257 | x_F3 = self.point_feature(x_F2) 258 | x_point = self.point_conv(x_F3) 259 | x_F2_direction = self.directionAtt(x_F2, x_point) 260 | x_direction = self.direction_conv(x_F2_direction) 261 | 262 | x_F1_mask = self.maskAtt(x_F1, x_direction) 263 | x_final_mask = self.mask_conv(x_F1_mask) 264 | 265 | 266 | return x_final_mask, x_point, x_direction 267 | 268 | def forward_backbone(self, x): 269 | 270 | """ Forward propagation in backbone encoder network. 
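Non-RGB inputs (the single-channel case this repo uses) are routed through the
replacement stems instead of the pretrained first layer: child0 for feature stacks
whose first child is named '0' (VGG/DenseNet), child_conv1 for ResNet-style
backbones whose first child is 'conv1'.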
""" 271 | 272 | features = {None: None} if None in self.shortcut_features else dict() 273 | for name, child in self.backbone.named_children(): 274 | 275 | if(name == '0' and x.shape[1] !=3): 276 | x = self.child0(x) 277 | elif(name == 'conv1' and x.shape[1] !=3): 278 | x = self.child_conv1(x) 279 | else: 280 | x = child(x) 281 | #x = child(x) 282 | if name in self.shortcut_features: 283 | features[name] = x 284 | if name == self.bb_out_name: 285 | break 286 | 287 | return x, features 288 | 289 | def infer_skip_channels(self): 290 | 291 | """ Getting the number of channels at skip connections and at the output of the encoder. """ 292 | 293 | x = torch.zeros(1, 3, 224, 224) 294 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder' 295 | channels = [] if has_fullres_features else [0] # only VGG has features at full resolution 296 | 297 | # forward run in backbone to count channels (dirty solution but works for *any* Module) 298 | for name, child in self.backbone.named_children(): 299 | x = child(x) 300 | if name in self.shortcut_features: 301 | channels.append(x.shape[1]) 302 | if name == self.bb_out_name: 303 | out_channels = x.shape[1] 304 | break 305 | return channels, out_channels 306 | 307 | def get_pretrained_parameters(self): 308 | for name, param in self.backbone.named_parameters(): 309 | if not (self.replaced_conv1 and name == 'conv1.weight'): 310 | yield param 311 | 312 | def get_random_initialized_parameters(self): 313 | pretrained_param_names = set() 314 | for name, param in self.backbone.named_parameters(): 315 | if not (self.replaced_conv1 and name == 'conv1.weight'): 316 | pretrained_param_names.add('backbone.{}'.format(name)) 317 | 318 | for name, param in self.named_parameters(): 319 | if name not in pretrained_param_names: 320 | yield param 321 | 322 | 323 | # if __name__ == "__main__": 324 | 325 | # # simple test run 326 | # net = Unet(backbone_name='resnet18') 327 | 328 | # criterion = nn.MSELoss() 329 | # optimizer = torch.optim.Adam(net.parameters()) 330 | # print('Network initialized. 
Running a test batch.') 331 | # for _ in range(1): 332 | # with torch.set_grad_enabled(True): 333 | # batch = torch.empty(1, 3, 224, 224).normal_() 334 | # targets = torch.empty(1, 21, 224, 224).normal_() 335 | 336 | # out = net(batch) 337 | # loss = criterion(out, targets) 338 | # loss.backward() 339 | # optimizer.step() 340 | # print(out.shape) 341 | 342 | # print('fasza.') 343 | -------------------------------------------------------------------------------- /models/deeplabv3_plus.py: -------------------------------------------------------------------------------- 1 | from base.base_model import BaseModel 2 | import torch 3 | import math 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torchvision import models 7 | import torch.utils.model_zoo as model_zoo 8 | from hhl_utils.helpers import initialize_weights 9 | from itertools import chain 10 | 11 | ''' 12 | -> ResNet BackBone 13 | ''' 14 | 15 | class ResNet(nn.Module): 16 | def __init__(self, in_channels=3, output_stride=16, backbone='resnet101', pretrained=True): 17 | super(ResNet, self).__init__() 18 | model = getattr(models, backbone)(pretrained) 19 | if not pretrained or in_channels != 3: 20 | self.layer0 = nn.Sequential( 21 | nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False), 22 | nn.BatchNorm2d(64), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 25 | ) 26 | initialize_weights(self.layer0) 27 | else: 28 | self.layer0 = nn.Sequential(*list(model.children())[:4]) 29 | 30 | self.layer1 = model.layer1 31 | self.layer2 = model.layer2 32 | self.layer3 = model.layer3 33 | self.layer4 = model.layer4 34 | 35 | if output_stride == 16: s3, s4, d3, d4 = (2, 1, 1, 2) 36 | elif output_stride == 8: s3, s4, d3, d4 = (1, 1, 2, 4) 37 | 38 | if output_stride == 8: 39 | for n, m in self.layer3.named_modules(): 40 | if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'): 41 | m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3) 42 | elif 'conv2' in n: 43 | m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3) 44 | elif 'downsample.0' in n: 45 | m.stride = (s3, s3) 46 | 47 | for n, m in self.layer4.named_modules(): 48 | if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'): 49 | m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4) 50 | elif 'conv2' in n: 51 | m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4) 52 | elif 'downsample.0' in n: 53 | m.stride = (s4, s4) 54 | 55 | def forward(self, x): 56 | x = self.layer0(x) 57 | x = self.layer1(x) 58 | low_level_features = x 59 | x = self.layer2(x) 60 | x = self.layer3(x) 61 | x = self.layer4(x) 62 | 63 | return x, low_level_features 64 | 65 | ''' 66 | -> (Aligned) Xception BackBone 67 | Pretrained model from https://github.com/Cadene/pretrained-models.pytorch 68 | by Remi Cadene 69 | ''' 70 | class SeparableConv2d(nn.Module): 71 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=nn.BatchNorm2d): 72 | super(SeparableConv2d, self).__init__() 73 | 74 | if dilation > kernel_size//2: padding = dilation 75 | else: padding = kernel_size//2 76 | 77 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding=padding, 78 | dilation=dilation, groups=in_channels, bias=bias) 79 | self.bn = nn.BatchNorm2d(in_channels) 80 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, bias=bias) 81 | 82 | def forward(self, x): 83 | x = self.conv1(x) 84 | x = self.bn(x) 85 | x = self.pointwise(x) 86 | return x 87 | 88 | 
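# Why SeparableConv2d: factorizing a dense KxK convolution into a depthwise
# spatial conv (groups=in_channels) plus a 1x1 pointwise mix shrinks the
# parameter count from C_in*C_out*K*K to roughly C_in*K*K + C_in*C_out.
# A minimal sketch of the saving (hypothetical names, illustrative only and
# not used by the model; counts assume the default bias=False):
#
#   count = lambda m: sum(p.numel() for p in m.parameters())
#   dense = nn.Conv2d(728, 728, 3, padding=1, bias=False)
#   sep = SeparableConv2d(728, 728, kernel_size=3)
#   print(count(dense))  # 4,769,856
#   print(count(sep))    # 537,992 (including the intermediate BatchNorm)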
89 | class Block(nn.Module): 90 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, exit_flow=False, use_1st_relu=True): 91 | super(Block, self).__init__() 92 | 93 | if in_channels != out_channels or stride !=1: 94 | self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False) 95 | self.skipbn = nn.BatchNorm2d(out_channels) 96 | else: self.skip = None 97 | 98 | rep = [] 99 | self.relu = nn.ReLU(inplace=True) 100 | 101 | rep.append(self.relu) 102 | rep.append(SeparableConv2d(in_channels, out_channels, 3, stride=1, dilation=dilation)) 103 | rep.append(nn.BatchNorm2d(out_channels)) 104 | 105 | rep.append(self.relu) 106 | rep.append(SeparableConv2d(out_channels, out_channels, 3, stride=1, dilation=dilation)) 107 | rep.append(nn.BatchNorm2d(out_channels)) 108 | 109 | rep.append(self.relu) 110 | rep.append(SeparableConv2d(out_channels, out_channels, 3, stride=stride, dilation=dilation)) 111 | rep.append(nn.BatchNorm2d(out_channels)) 112 | 113 | if exit_flow: 114 | rep[3:6] = rep[:3] 115 | rep[:3] = [ 116 | self.relu, 117 | SeparableConv2d(in_channels, in_channels, 3, 1, dilation), 118 | nn.BatchNorm2d(in_channels)] 119 | 120 | if not use_1st_relu: rep = rep[1:] 121 | self.rep = nn.Sequential(*rep) 122 | 123 | def forward(self, x): 124 | output = self.rep(x) 125 | if self.skip is not None: 126 | skip = self.skip(x) 127 | skip = self.skipbn(skip) 128 | else: 129 | skip = x 130 | 131 | x = output + skip 132 | return x 133 | 134 | class Xception(nn.Module): 135 | def __init__(self, output_stride=16, in_channels=3, pretrained=True): 136 | super(Xception, self).__init__() 137 | 138 | # Stride for block 3 (entry flow), and the dilation rates for middle flow and exit flow 139 | if output_stride == 16: b3_s, mf_d, ef_d = 2, 1, (1, 2) 140 | if output_stride == 8: b3_s, mf_d, ef_d = 1, 2, (2, 4) 141 | 142 | # Entry Flow 143 | self.conv1 = nn.Conv2d(in_channels, 32, 3, 2, padding=1, bias=False) 144 | self.bn1 = nn.BatchNorm2d(32) 145 | self.relu = nn.ReLU(inplace=True) 146 | self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1, bias=False) 147 | self.bn2 = nn.BatchNorm2d(64) 148 | 149 | self.block1 = Block(64, 128, stride=2, dilation=1, use_1st_relu=False) 150 | self.block2 = Block(128, 256, stride=2, dilation=1) 151 | self.block3 = Block(256, 728, stride=b3_s, dilation=1) 152 | 153 | # Middle Flow 154 | for i in range(16): 155 | exec(f'self.block{i+4} = Block(728, 728, stride=1, dilation=mf_d)') 156 | 157 | # Exit flow 158 | self.block20 = Block(728, 1024, stride=1, dilation=ef_d[0], exit_flow=True) 159 | 160 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=ef_d[1]) 161 | self.bn3 = nn.BatchNorm2d(1536) 162 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=ef_d[1]) 163 | self.bn4 = nn.BatchNorm2d(1536) 164 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=ef_d[1]) 165 | self.bn5 = nn.BatchNorm2d(2048) 166 | 167 | initialize_weights(self) 168 | if pretrained: self._load_pretrained_model() 169 | 170 | 171 | def _load_pretrained_model(self): 172 | url = 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth' 173 | pretrained_weights = model_zoo.load_url(url) 174 | state_dict = self.state_dict() 175 | model_dict = {} 176 | 177 | for k, v in pretrained_weights.items(): 178 | if k in state_dict: 179 | if 'pointwise' in k: 180 | v = v.unsqueeze(-1).unsqueeze(-1) # [C, C] -> [C, C, 1, 1] 181 | if k.startswith('block11'): 182 | # In Xception there is only 8 blocks in Middle flow 183 | model_dict[k] = v 184 | for i in range(8): 
185 | model_dict[k.replace('block11', f'block{i+12}')] = v 186 | elif k.startswith('block12'): 187 | model_dict[k.replace('block12', 'block20')] = v 188 | elif k.startswith('bn3'): 189 | model_dict[k] = v 190 | model_dict[k.replace('bn3', 'bn4')] = v 191 | elif k.startswith('conv4'): 192 | model_dict[k.replace('conv4', 'conv5')] = v 193 | elif k.startswith('bn4'): 194 | model_dict[k.replace('bn4', 'bn5')] = v 195 | else: 196 | model_dict[k] = v 197 | 198 | state_dict.update(model_dict) 199 | self.load_state_dict(state_dict) 200 | 201 | def forward(self, x): 202 | # Entry flow 203 | x = self.conv1(x) 204 | x = self.bn1(x) 205 | x = self.relu(x) 206 | x = self.conv2(x) 207 | x = self.bn2(x) 208 | x = self.block1(x) 209 | low_level_features = x 210 | x = F.relu(x) 211 | x = self.block2(x) 212 | x = self.block3(x) 213 | 214 | # Middle flow 215 | x = self.block4(x) 216 | x = self.block5(x) 217 | x = self.block6(x) 218 | x = self.block7(x) 219 | x = self.block8(x) 220 | x = self.block9(x) 221 | x = self.block10(x) 222 | x = self.block11(x) 223 | x = self.block12(x) 224 | x = self.block13(x) 225 | x = self.block14(x) 226 | x = self.block15(x) 227 | x = self.block16(x) 228 | x = self.block17(x) 229 | x = self.block18(x) 230 | x = self.block19(x) 231 | 232 | # Exit flow 233 | x = self.block20(x) 234 | x = self.relu(x) 235 | x = self.conv3(x) 236 | x = self.bn3(x) 237 | x = self.relu(x) 238 | 239 | x = self.conv4(x) 240 | x = self.bn4(x) 241 | x = self.relu(x) 242 | 243 | x = self.conv5(x) 244 | x = self.bn5(x) 245 | x = self.relu(x) 246 | 247 | return x, low_level_features 248 | 249 | ''' 250 | -> The Atrous Spatial Pyramid Pooling 251 | ''' 252 | 253 | def assp_branch(in_channels, out_channles, kernel_size, dilation): 254 | padding = 0 if kernel_size == 1 else dilation 255 | return nn.Sequential( 256 | nn.Conv2d(in_channels, out_channles, kernel_size, padding=padding, dilation=dilation, bias=False), 257 | nn.BatchNorm2d(out_channles), 258 | nn.ReLU(inplace=True)) 259 | 260 | class ASSP(nn.Module): 261 | def __init__(self, in_channels, output_stride): 262 | super(ASSP, self).__init__() 263 | 264 | assert output_stride in [8, 16], 'Only output strides of 8 or 16 are suported' 265 | if output_stride == 16: dilations = [1, 6, 12, 18] 266 | elif output_stride == 8: dilations = [1, 12, 24, 36] 267 | 268 | self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0]) 269 | self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1]) 270 | self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2]) 271 | self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3]) 272 | 273 | self.avg_pool = nn.Sequential( 274 | nn.AdaptiveAvgPool2d((1, 1)), 275 | nn.Conv2d(in_channels, 256, 1, bias=False), 276 | nn.BatchNorm2d(256), 277 | nn.ReLU(inplace=True)) 278 | 279 | self.conv1 = nn.Conv2d(256*5, 256, 1, bias=False) 280 | self.bn1 = nn.BatchNorm2d(256) 281 | self.relu = nn.ReLU(inplace=True) 282 | self.dropout = nn.Dropout(0.5) 283 | 284 | initialize_weights(self) 285 | 286 | def forward(self, x): 287 | x1 = self.aspp1(x) 288 | x2 = self.aspp2(x) 289 | x3 = self.aspp3(x) 290 | x4 = self.aspp4(x) 291 | x5 = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True) 292 | 293 | x = self.conv1(torch.cat((x1, x2, x3, x4, x5), dim=1)) 294 | x = self.bn1(x) 295 | x = self.dropout(self.relu(x)) 296 | 297 | return x 298 | 299 | ''' 300 | -> Decoder 301 | ''' 302 | 303 | class Decoder(nn.Module): 304 | def __init__(self, low_level_channels, num_classes): 
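        # DeepLabV3+ decoder: the low-level backbone features are projected down
        # to 48 channels so they do not drown out the 256-channel ASPP output,
        # the ASPP output is bilinearly upsampled to the same spatial size, and
        # the concatenated 48+256=304 channels are refined by two 3x3 convs
        # before the per-class 1x1 conv.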
305 | super(Decoder, self).__init__() 306 | self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False) 307 | self.bn1 = nn.BatchNorm2d(48) 308 | self.relu = nn.ReLU(inplace=True) 309 | 310 | # Table 2, best performance with two 3x3 convs 311 | self.output = nn.Sequential( 312 | nn.Conv2d(48+256, 256, 3, stride=1, padding=1, bias=False), 313 | nn.BatchNorm2d(256), 314 | nn.ReLU(inplace=True), 315 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False), 316 | nn.BatchNorm2d(256), 317 | nn.ReLU(inplace=True), 318 | nn.Dropout(0.1), 319 | nn.Conv2d(256, num_classes, 1, stride=1), 320 | ) 321 | initialize_weights(self) 322 | 323 | def forward(self, x, low_level_features): 324 | low_level_features = self.conv1(low_level_features) 325 | low_level_features = self.relu(self.bn1(low_level_features)) 326 | H, W = low_level_features.size(2), low_level_features.size(3) 327 | 328 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) 329 | x = self.output(torch.cat((low_level_features, x), dim=1)) 330 | return x 331 | 332 | ''' 333 | -> Deeplab V3 + 334 | ''' 335 | 336 | class DeepLab(BaseModel): 337 | def __init__(self, num_classes, in_channels=3, backbone='xception', pretrained=False,#pretrained=True, hhl20191020gai 338 | output_stride=16, freeze_bn=False, **_): 339 | 340 | super(DeepLab, self).__init__() 341 | assert ('xception' or 'resnet' in backbone) 342 | if 'resnet' in backbone: 343 | self.backbone = ResNet(in_channels=in_channels, output_stride=output_stride, pretrained=pretrained) 344 | low_level_channels = 256 345 | else: 346 | self.backbone = Xception(output_stride=output_stride, pretrained=pretrained) 347 | low_level_channels = 128 348 | 349 | self.ASSP = ASSP(in_channels=2048, output_stride=output_stride) 350 | self.decoder = Decoder(low_level_channels, num_classes) 351 | 352 | if freeze_bn: self.freeze_bn() 353 | 354 | def forward(self, x): 355 | H, W = x.size(2), x.size(3) 356 | x, low_level_features = self.backbone(x) 357 | x = self.ASSP(x) 358 | x = self.decoder(x, low_level_features) 359 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) 360 | return x 361 | 362 | # Two functions to yield the parameters of the backbone 363 | # & Decoder / ASSP to use differentiable learning rates 364 | # FIXME: in xception, we use the parameters from xception and not aligned xception 365 | # better to have higher lr for this backbone 366 | 367 | def get_backbone_params(self): 368 | return self.backbone.parameters() 369 | 370 | def get_decoder_params(self): 371 | return chain(self.ASSP.parameters(), self.decoder.parameters()) 372 | 373 | def freeze_bn(self): 374 | for module in self.modules(): 375 | if isinstance(module, nn.BatchNorm2d): module.eval() 376 | 377 | -------------------------------------------------------------------------------- /models/fcn8.py: -------------------------------------------------------------------------------- 1 | from base.base_model import BaseModel 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | from hhl_utils.helpers import get_upsampling_weight 6 | import torch 7 | from itertools import chain 8 | 9 | class FCN8(BaseModel): 10 | def __init__(self, num_classes, pretrained=False, freeze_bn=False, **_):#pretrained=True hhl20191020gai 11 | super(FCN8, self).__init__() 12 | vgg = models.vgg16(pretrained) 13 | features = list(vgg.features.children()) 14 | classifier = list(vgg.classifier.children()) 15 | 16 | # Pad the input to enable small inputs and allow matching feature maps 
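        # (100 pixels is the classic FCN offset: it keeps the feature maps valid
        # for small inputs after five poolings plus the 7x7 fc6 conv, and is
        # undone again by the fixed crops at offsets 5, 9 and 31 in forward.)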
17 | features[0].padding = (100, 100) 18 | 19 | # Enbale ceil in max pool, to avoid different sizes when upsampling 20 | for layer in features: 21 | if 'MaxPool' in layer.__class__.__name__: 22 | layer.ceil_mode = True 23 | 24 | # Extract pool3, pool4 and pool5 from the VGG net 25 | self.pool3 = nn.Sequential(*features[:17]) 26 | self.pool4 = nn.Sequential(*features[17:24]) 27 | self.pool5 = nn.Sequential(*features[24:]) 28 | 29 | # Adjust the depth of pool3 and pool4 to num_classes 30 | self.adj_pool3 = nn.Conv2d(256, num_classes, kernel_size=1) 31 | self.adj_pool4 = nn.Conv2d(512, num_classes, kernel_size=1) 32 | 33 | # Replace the FC layer of VGG with conv layers 34 | conv6 = nn.Conv2d(512, 4096, kernel_size=7) 35 | conv7 = nn.Conv2d(4096, 4096, kernel_size=1) 36 | output = nn.Conv2d(4096, num_classes, kernel_size=1) 37 | 38 | # Copy the weights from VGG's FC pretrained layers 39 | conv6.weight.data.copy_(classifier[0].weight.data.view( 40 | conv6.weight.data.size())) 41 | conv6.bias.data.copy_(classifier[0].bias.data) 42 | 43 | conv7.weight.data.copy_(classifier[3].weight.data.view( 44 | conv7.weight.data.size())) 45 | conv7.bias.data.copy_(classifier[3].bias.data) 46 | 47 | # Get the outputs 48 | self.output = nn.Sequential(conv6, nn.ReLU(inplace=True), nn.Dropout(), 49 | conv7, nn.ReLU(inplace=True), nn.Dropout(), 50 | output) 51 | 52 | # We'll need three upsampling layers, upsampling (x2 +2) the ouputs 53 | # upsampling (x2 +2) addition of pool4 and upsampled output 54 | # upsampling (x8 +8) the final value (pool3 + added output and pool4) 55 | self.up_output = nn.ConvTranspose2d(num_classes, num_classes, 56 | kernel_size=4, stride=2, bias=False) 57 | self.up_pool4_out = nn.ConvTranspose2d(num_classes, num_classes, 58 | kernel_size=4, stride=2, bias=False) 59 | self.up_final = nn.ConvTranspose2d(num_classes, num_classes, 60 | kernel_size=16, stride=8, bias=False) 61 | 62 | # We'll use guassian kernels for the upsampling weights 63 | self.up_output.weight.data.copy_( 64 | get_upsampling_weight(num_classes, num_classes, 4)) 65 | self.up_pool4_out.weight.data.copy_( 66 | get_upsampling_weight(num_classes, num_classes, 4)) 67 | self.up_final.weight.data.copy_( 68 | get_upsampling_weight(num_classes, num_classes, 16)) 69 | 70 | # We'll freeze the wights, this is a fixed upsampling and not deconv 71 | for m in self.modules(): 72 | if isinstance(m, nn.ConvTranspose2d): 73 | m.weight.requires_grad = False 74 | if freeze_bn: self.freeze_bn() 75 | 76 | def forward(self, x): 77 | imh_H, img_W = x.size()[2], x.size()[3] 78 | 79 | # Forward the image 80 | pool3 = self.pool3(x) 81 | pool4 = self.pool4(pool3) 82 | pool5 = self.pool5(pool4) 83 | 84 | # Get the outputs and upsmaple them 85 | output = self.output(pool5) 86 | up_output = self.up_output(output) 87 | 88 | # Adjust pool4 and add the uped-outputs to pool4 89 | adjstd_pool4 = self.adj_pool4(0.01 * pool4) 90 | add_out_pool4 = self.up_pool4_out(adjstd_pool4[:, :, 5: (5 + up_output.size()[2]), 91 | 5: (5 + up_output.size()[3])] 92 | + up_output) 93 | 94 | # Adjust pool3 and add it to the uped last addition 95 | adjstd_pool3 = self.adj_pool3(0.0001 * pool3) 96 | final_value = self.up_final(adjstd_pool3[:, :, 9: (9 + add_out_pool4.size()[2]), 9: (9 + add_out_pool4.size()[3])] 97 | + add_out_pool4) 98 | 99 | # Remove the corresponding padded regions to the input img size 100 | final_value = final_value[:, :, 31: (31 + imh_H), 31: (31 + img_W)].contiguous() 101 | return final_value 102 | 103 | def get_backbone_params(self): 104 | return 
chain(self.pool3.parameters(), self.pool4.parameters(), self.pool5.parameters(), self.output.parameters())
105 | 106 | def get_decoder_params(self):
107 | return chain(self.up_output.parameters(), self.adj_pool4.parameters(), self.up_pool4_out.parameters(),
108 | self.adj_pool3.parameters(), self.up_final.parameters())
109 | 110 | def freeze_bn(self):
111 | for module in self.modules():
112 | if isinstance(module, nn.BatchNorm2d): module.eval()
113 | 114 | -------------------------------------------------------------------------------- /models/model_unet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 | 7 | 8 | def get_backbone(name, pretrained=True):
9 | 10 | """ Loading backbone, defining names for skip-connections and encoder output. """
11 | 12 | # TODO: More backbones
13 | 14 | # loading backbone model
15 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
16 | if name == 'resnet18':
17 | backbone = models.resnet18(pretrained=pretrained)
18 | elif name == 'resnet34':
19 | backbone = models.resnet34(pretrained=pretrained)
20 | elif name == 'resnet50':
21 | backbone = models.resnet50(pretrained=pretrained)
22 | elif name == 'resnet101':
23 | backbone = models.resnet101(pretrained=pretrained)
24 | elif name == 'resnet152':
25 | backbone = models.resnet152(pretrained=pretrained)
26 | elif name == 'vgg16_bn':
27 | backbone = models.vgg16_bn(pretrained=pretrained).features
28 | elif name == 'vgg19_bn':
29 | backbone = models.vgg19_bn(pretrained=pretrained).features
30 | # elif name == 'inception_v3':
31 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
32 | elif name == 'densenet121':
33 | backbone = models.densenet121(pretrained=pretrained).features
34 | elif name == 'densenet161':
35 | backbone = models.densenet161(pretrained=pretrained).features
36 | elif name == 'densenet169':
37 | backbone = models.densenet169(pretrained=pretrained).features
38 | elif name == 'densenet201':
39 | backbone = models.densenet201(pretrained=pretrained).features
40 | elif name == 'unet_encoder':
41 | from unet_backbone import UnetEncoder
42 | backbone = UnetEncoder(3)
43 | else:
44 | raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
45 | 46 | # specifying skip feature and output names
47 | if name.startswith('resnet'):
48 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
49 | backbone_output = 'layer4'
50 | elif name == 'vgg16_bn':
51 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
52 | feature_names = ['5', '12', '22', '32', '42']
53 | backbone_output = '43'
54 | elif name == 'vgg19_bn':
55 | feature_names = ['5', '12', '25', '38', '51']
56 | backbone_output = '52'
57 | # elif name == 'inception_v3':
58 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
59 | # backbone_output = 'Mixed_7c'
60 | elif name.startswith('densenet'):
61 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
62 | backbone_output = 'denseblock4'
63 | elif name == 'unet_encoder':
64 | feature_names = ['module1', 'module2', 'module3', 'module4']
65 | backbone_output = 'module5'
66 | else:
67 | raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
68 | 69 | return backbone, feature_names, backbone_output
70 | 71 | 72 | class UpsampleBlock(nn.Module): 
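    # Decoder stage: 2x upsampling (transposed conv when parametric=True,
    # bilinear interpolation otherwise), zero-padding so the upsampled tensor
    # matches the skip tensor, channel-wise concatenation with the skip
    # features, then 3x3 conv + BN + ReLU refinement.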
73 | 74 | # TODO: separate parametric and non-parametric classes? 75 | # TODO: skip connection concatenated OR added 76 | 77 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False): 78 | super(UpsampleBlock, self).__init__() 79 | 80 | self.parametric = parametric 81 | ch_out = ch_in/2 if ch_out is None else ch_out 82 | 83 | # first convolution: either transposed conv, or conv following the skip connection 84 | if parametric: 85 | # versions: kernel=4 padding=1, kernel=2 padding=0 86 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4), 87 | stride=2, padding=1, output_padding=0, bias=(not use_bn)) 88 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 89 | else: 90 | self.up = None 91 | ch_in = ch_in + skip_in 92 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3), 93 | stride=1, padding=1, bias=(not use_bn)) 94 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None 95 | 96 | self.relu = nn.ReLU(inplace=True) 97 | 98 | # second convolution 99 | conv2_in = ch_out if not parametric else ch_out + skip_in 100 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3), 101 | stride=1, padding=1, bias=(not use_bn)) 102 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None 103 | 104 | #def forward(self, x, skip_connection=None): # 105 | def forward(self, x, skip_connection=1): # 106 | 107 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear', 108 | align_corners=None) 109 | if self.parametric: 110 | x = self.bn1(x) if self.bn1 is not None else x 111 | x = self.relu(x) 112 | 113 | if skip_connection is not None: 114 | # Padding in case the incomping volumes are of different sizes #hhl20200413add 115 | diffY = skip_connection.size()[2] - x.size()[2] 116 | diffX = skip_connection.size()[3] - x.size()[3] 117 | x = F.pad(x, (diffX // 2, diffX - diffX // 2, 118 | diffY // 2, diffY - diffY // 2)) 119 | 120 | x = torch.cat([x, skip_connection], dim=1) 121 | 122 | if not self.parametric: 123 | x = self.conv1(x) 124 | x = self.bn1(x) if self.bn1 is not None else x 125 | x = self.relu(x) 126 | x = self.conv2(x) 127 | x = self.bn2(x) if self.bn2 is not None else x 128 | x = self.relu(x) 129 | 130 | return x 131 | 132 | 133 | class Unet(nn.Module): 134 | 135 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones.""" 136 | 137 | def __init__(self, 138 | backbone_name='resnet50', 139 | pretrained=True, 140 | encoder_freeze=False, 141 | classes=21, 142 | decoder_filters=(256, 128, 64, 32, 16), 143 | parametric_upsampling=True, 144 | shortcut_features='default', 145 | decoder_use_batchnorm=True): 146 | super(Unet, self).__init__() 147 | 148 | self.backbone_name = backbone_name 149 | 150 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained) 151 | shortcut_chs, bb_out_chs = self.infer_skip_channels() 152 | if shortcut_features != 'default': 153 | self.shortcut_features = shortcut_features 154 | 155 | # build decoder part 156 | self.upsample_blocks = nn.ModuleList() 157 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections 158 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1]) 159 | num_blocks = len(self.shortcut_features) 160 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)): 161 | print('upsample_blocks[{}] in: {} out: 
{}'.format(i, filters_in, filters_out)) 162 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out, 163 | skip_in=shortcut_chs[num_blocks-i-1], 164 | parametric=parametric_upsampling, 165 | use_bn=decoder_use_batchnorm)) 166 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1)) 167 | 168 | if encoder_freeze: 169 | self.freeze_encoder() 170 | 171 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later 172 | 173 | # 174 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # 175 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 176 | 177 | def freeze_encoder(self): 178 | 179 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """ 180 | 181 | for param in self.backbone.parameters(): 182 | param.requires_grad = False 183 | 184 | def forward(self, *input): 185 | 186 | """ Forward propagation in U-Net. """ 187 | 188 | x, features = self.forward_backbone(*input) 189 | 190 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks): 191 | skip_features = features[skip_name] 192 | x = upsample_block(x, skip_features) 193 | 194 | x = self.final_conv(x) 195 | 196 | return x 197 | 198 | def forward_backbone(self, x): 199 | 200 | """ Forward propagation in backbone encoder network. """ 201 | #print('x.shape = ',x.shape) 202 | features = {None: None} if None in self.shortcut_features else dict() 203 | for name, child in self.backbone.named_children(): 204 | #print(name,child) 205 | # 206 | if(name == '0' and x.shape[1] !=3): 207 | x = self.child0(x) 208 | elif(name == 'conv1' and x.shape[1] !=3): 209 | x = self.child_conv1(x) 210 | else: 211 | x = child(x) 212 | 213 | if name in self.shortcut_features: 214 | features[name] = x 215 | if name == self.bb_out_name: 216 | break 217 | 218 | return x, features 219 | 220 | def infer_skip_channels(self): 221 | 222 | """ Getting the number of channels at skip connections and at the output of the encoder. 
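A dummy 1x3x224x224 tensor is traced through the encoder and the channel count
is recorded at every skip point, so the decoder widths adapt to any backbone
without hard-coded channel tables.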
""" 223 | 224 | x = torch.zeros(1, 3, 224, 224) 225 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder' 226 | channels = [] if has_fullres_features else [0] # only VGG has features at full resolution 227 | 228 | # forward run in backbone to count channels (dirty solution but works for *any* Module) 229 | for name, child in self.backbone.named_children(): 230 | x = child(x) 231 | if name in self.shortcut_features: 232 | channels.append(x.shape[1]) 233 | if name == self.bb_out_name: 234 | out_channels = x.shape[1] 235 | break 236 | return channels, out_channels 237 | 238 | def get_pretrained_parameters(self): 239 | for name, param in self.backbone.named_parameters(): 240 | if not (self.replaced_conv1 and name == 'conv1.weight'): 241 | yield param 242 | 243 | def get_random_initialized_parameters(self): 244 | pretrained_param_names = set() 245 | for name, param in self.backbone.named_parameters(): 246 | if not (self.replaced_conv1 and name == 'conv1.weight'): 247 | pretrained_param_names.add('backbone.{}'.format(name)) 248 | 249 | for name, param in self.named_parameters(): 250 | if name not in pretrained_param_names: 251 | yield param 252 | 253 | 254 | # if __name__ == "__main__": 255 | 256 | # # simple test run 257 | # net = Unet(backbone_name='resnet18') 258 | 259 | # criterion = nn.MSELoss() 260 | # optimizer = torch.optim.Adam(net.parameters()) 261 | # print('Network initialized. Running a test batch.') 262 | # for _ in range(1): 263 | # with torch.set_grad_enabled(True): 264 | # batch = torch.empty(1, 3, 224, 224).normal_() 265 | # targets = torch.empty(1, 21, 224, 224).normal_() 266 | 267 | # out = net(batch) 268 | # loss = criterion(out, targets) 269 | # loss.backward() 270 | # optimizer.step() 271 | # print(out.shape) 272 | 273 | # print('fasza.') 274 | -------------------------------------------------------------------------------- /models/pspnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from models import resnet 6 | from torchvision import models 7 | from base.base_model import BaseModel 8 | from hhl_utils.helpers import initialize_weights, set_trainable 9 | from itertools import chain 10 | 11 | class _PSPModule(nn.Module): 12 | def __init__(self, in_channels, bin_sizes, norm_layer): 13 | super(_PSPModule, self).__init__() 14 | out_channels = in_channels // len(bin_sizes) 15 | self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer) 16 | for b_s in bin_sizes]) 17 | self.bottleneck = nn.Sequential( 18 | nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), out_channels, 19 | kernel_size=3, padding=1, bias=False), 20 | norm_layer(out_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Dropout2d(0.1) 23 | ) 24 | 25 | def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer): 26 | prior = nn.AdaptiveAvgPool2d(output_size=bin_sz) 27 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 28 | bn = norm_layer(out_channels) 29 | relu = nn.ReLU(inplace=True) 30 | return nn.Sequential(prior, conv, bn, relu) 31 | 32 | def forward(self, features): 33 | h, w = features.size()[2], features.size()[3] 34 | pyramids = [features] 35 | pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear', 36 | align_corners=True) for stage in self.stages]) 37 | output = self.bottleneck(torch.cat(pyramids, dim=1)) 38 | return output 39 | 40 | 41 | 
class PSPNet(BaseModel): 42 | def __init__(self, num_classes, in_channels=3, backbone='resnet152', pretrained=False, use_aux=True, freeze_bn=False, freeze_backbone=False):#pretrained=True hhl20191020gai 43 | super(PSPNet, self).__init__() 44 | # TODO: Use synch batchnorm 45 | norm_layer = nn.BatchNorm2d 46 | model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer, ) 47 | m_out_sz = model.fc.in_features 48 | self.use_aux = use_aux 49 | 50 | self.initial = nn.Sequential(*list(model.children())[:4]) 51 | if in_channels != 3: 52 | self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) 53 | self.initial = nn.Sequential(*self.initial) 54 | 55 | self.layer1 = model.layer1 56 | self.layer2 = model.layer2 57 | self.layer3 = model.layer3 58 | self.layer4 = model.layer4 59 | 60 | self.master_branch = nn.Sequential( 61 | _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer), 62 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 63 | ) 64 | 65 | self.auxiliary_branch = nn.Sequential( 66 | nn.Conv2d(m_out_sz//2, m_out_sz//4, kernel_size=3, padding=1, bias=False), 67 | norm_layer(m_out_sz//4), 68 | nn.ReLU(inplace=True), 69 | nn.Dropout2d(0.1), 70 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 71 | ) 72 | 73 | initialize_weights(self.master_branch, self.auxiliary_branch) 74 | if freeze_bn: self.freeze_bn() 75 | if freeze_backbone: 76 | set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False) 77 | 78 | def forward(self, x): 79 | input_size = (x.size()[2], x.size()[3]) 80 | x = self.initial(x) 81 | x = self.layer1(x) 82 | x = self.layer2(x) 83 | x_aux = self.layer3(x) 84 | x = self.layer4(x_aux) 85 | 86 | output = self.master_branch(x) 87 | output = F.interpolate(output, size=input_size, mode='bilinear', align_corners=False)# UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. 88 | output = output[:, :, :input_size[0], :input_size[1]] 89 | 90 | if self.training and self.use_aux: 91 | aux = self.auxiliary_branch(x_aux) 92 | aux = F.interpolate(aux, size=input_size, mode='bilinear', align_corners=False)# UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. 
93 | aux = aux[:, :, :input_size[0], :input_size[1]] 94 | return output, aux 95 | return output 96 | 97 | def get_backbone_params(self): 98 | return chain(self.initial.parameters(), self.layer1.parameters(), self.layer2.parameters(), 99 | self.layer3.parameters(), self.layer4.parameters()) 100 | 101 | def get_decoder_params(self): 102 | return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters()) 103 | 104 | def freeze_bn(self): 105 | for module in self.modules(): 106 | if isinstance(module, nn.BatchNorm2d): module.eval() 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | ## PSP with dense net as the backbone 117 | 118 | class PSPDenseNet(BaseModel): 119 | def __init__(self, num_classes, in_channels=3, backbone='densenet201', pretrained=False, use_aux=True, freeze_bn=False, **_):#pretrained=True hhl20191020gai 120 | super(PSPDenseNet, self).__init__() 121 | self.use_aux = use_aux 122 | model = getattr(models, backbone)(pretrained) 123 | m_out_sz = model.classifier.in_features 124 | aux_out_sz = model.features.transition3.conv.out_channels 125 | 126 | if not pretrained or in_channels != 3: 127 | # If we're training from scratch, better to use 3x3 convs 128 | block0 = [nn.Conv2d(in_channels, 64, 3, stride=2, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)] 129 | block0.extend( 130 | [nn.Conv2d(64, 64, 3, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)] * 2 131 | ) 132 | self.block0 = nn.Sequential( 133 | *block0, 134 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 135 | ) 136 | initialize_weights(self.block0) 137 | else: 138 | self.block0 = nn.Sequential(*list(model.features.children())[:4]) 139 | 140 | self.block1 = model.features.denseblock1 141 | self.block2 = model.features.denseblock2 142 | self.block3 = model.features.denseblock3 143 | self.block4 = model.features.denseblock4 144 | 145 | self.transition1 = model.features.transition1 146 | # No pooling 147 | self.transition2 = nn.Sequential( 148 | *list(model.features.transition2.children())[:-1]) 149 | self.transition3 = nn.Sequential( 150 | *list(model.features.transition3.children())[:-1]) 151 | 152 | for n, m in self.block3.named_modules(): 153 | if 'conv2' in n: 154 | m.dilation, m.padding = (2,2), (2,2) 155 | for n, m in self.block4.named_modules(): 156 | if 'conv2' in n: 157 | m.dilation, m.padding = (4,4), (4,4) 158 | 159 | self.master_branch = nn.Sequential( 160 | _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=nn.BatchNorm2d), 161 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 162 | ) 163 | 164 | self.auxiliary_branch = nn.Sequential( 165 | nn.Conv2d(aux_out_sz, m_out_sz//4, kernel_size=3, padding=1, bias=False), 166 | nn.BatchNorm2d(m_out_sz//4), 167 | nn.ReLU(inplace=True), 168 | nn.Dropout2d(0.1), 169 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1) 170 | ) 171 | 172 | initialize_weights(self.master_branch, self.auxiliary_branch) 173 | if freeze_bn: self.freeze_bn() 174 | 175 | def forward(self, x): 176 | input_size = (x.size()[2], x.size()[3]) 177 | 178 | x = self.block0(x) 179 | x = self.block1(x) 180 | x = self.transition1(x) 181 | x = self.block2(x) 182 | x = self.transition2(x) 183 | x = self.block3(x) 184 | x_aux = self.transition3(x) 185 | x = self.block4(x_aux) 186 | 187 | output = self.master_branch(x) 188 | output = F.interpolate(output, size=input_size, mode='bilinear', align_corners=False)#UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. 
189 | 190 | if self.training and self.use_aux: 191 | aux = self.auxiliary_branch(x_aux) 192 | aux = F.interpolate(aux, size=input_size, mode='bilinear', align_corners=False)#UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. 193 | return output, aux 194 | return output 195 | 196 | def get_backbone_params(self): 197 | return chain(self.block0.parameters(), self.block1.parameters(), self.block2.parameters(), 198 | self.block3.parameters(), self.transition1.parameters(), self.transition2.parameters(), 199 | self.transition3.parameters()) 200 | 201 | def get_decoder_params(self): 202 | return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters()) 203 | 204 | def freeze_bn(self): 205 | for module in self.modules(): 206 | if isinstance(module, nn.BatchNorm2d): module.eval() -------------------------------------------------------------------------------- /models/segnet.py: -------------------------------------------------------------------------------- 1 | from base.base_model import BaseModel 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | from itertools import chain 7 | from math import ceil 8 | 9 | class SegNet(BaseModel): 10 | def __init__(self, num_classes, in_channels=3, pretrained=False, freeze_bn=False, **_):#pretrained=True hhl20191020gai 11 | super(SegNet, self).__init__() 12 | vgg_bn = models.vgg16_bn(pretrained= pretrained) 13 | encoder = list(vgg_bn.features.children()) 14 | 15 | # Adjust the input size 16 | if in_channels != 3: 17 | encoder[0].in_channels = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1) 18 | 19 | # Encoder, VGG without any maxpooling 20 | self.stage1_encoder = nn.Sequential(*encoder[:6]) 21 | self.stage2_encoder = nn.Sequential(*encoder[7:13]) 22 | self.stage3_encoder = nn.Sequential(*encoder[14:23]) 23 | self.stage4_encoder = nn.Sequential(*encoder[24:33]) 24 | self.stage5_encoder = nn.Sequential(*encoder[34:-1]) 25 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True) 26 | 27 | # Decoder, same as the encoder but reversed, maxpool will not be used 28 | decoder = encoder 29 | decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)] 30 | # Replace the last conv layer 31 | decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 32 | # When reversing, we also reversed conv->batchN->relu, correct it 33 | decoder = [item for i in range(0, len(decoder), 3) for item in decoder[i:i+3][::-1]] 34 | # Replace some conv layers & batchN after them 35 | for i, module in enumerate(decoder): 36 | if isinstance(module, nn.Conv2d): 37 | if module.in_channels != module.out_channels: 38 | decoder[i+1] = nn.BatchNorm2d(module.in_channels) 39 | decoder[i] = nn.Conv2d(module.out_channels, module.in_channels, kernel_size=3, stride=1, padding=1) 40 | 41 | self.stage1_decoder = nn.Sequential(*decoder[0:9]) 42 | self.stage2_decoder = nn.Sequential(*decoder[9:18]) 43 | self.stage3_decoder = nn.Sequential(*decoder[18:27]) 44 | self.stage4_decoder = nn.Sequential(*decoder[27:33]) 45 | self.stage5_decoder = nn.Sequential(*decoder[33:], 46 | nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1) 47 | ) 48 | self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2) 49 | 50 | self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder, 51 | self.stage4_decoder, self.stage5_decoder) 52 | if freeze_bn: self.freeze_bn() 53 | 54 | def 
_initialize_weights(self, *stages): 55 | for modules in stages: 56 | for module in modules.modules(): 57 | if isinstance(module, nn.Conv2d): 58 | nn.init.kaiming_normal_(module.weight) 59 | if module.bias is not None: 60 | module.bias.data.zero_() 61 | elif isinstance(module, nn.BatchNorm2d): 62 | module.weight.data.fill_(1) 63 | module.bias.data.zero_() 64 | 65 | def forward(self, x): 66 | # Encoder 67 | x = self.stage1_encoder(x) 68 | x1_size = x.size() 69 | x, indices1 = self.pool(x) 70 | 71 | x = self.stage2_encoder(x) 72 | x2_size = x.size() 73 | x, indices2 = self.pool(x) 74 | 75 | x = self.stage3_encoder(x) 76 | x3_size = x.size() 77 | x, indices3 = self.pool(x) 78 | 79 | x = self.stage4_encoder(x) 80 | x4_size = x.size() 81 | x, indices4 = self.pool(x) 82 | 83 | x = self.stage5_encoder(x) 84 | x5_size = x.size() 85 | x, indices5 = self.pool(x) 86 | 87 | # Decoder 88 | x = self.unpool(x, indices=indices5, output_size=x5_size) 89 | x = self.stage1_decoder(x) 90 | 91 | x = self.unpool(x, indices=indices4, output_size=x4_size) 92 | x = self.stage2_decoder(x) 93 | 94 | x = self.unpool(x, indices=indices3, output_size=x3_size) 95 | x = self.stage3_decoder(x) 96 | 97 | x = self.unpool(x, indices=indices2, output_size=x2_size) 98 | x = self.stage4_decoder(x) 99 | 100 | x = self.unpool(x, indices=indices1, output_size=x1_size) 101 | x = self.stage5_decoder(x) 102 | 103 | return x 104 | 105 | def get_backbone_params(self): 106 | return [] 107 | 108 | def get_decoder_params(self): 109 | return self.parameters() 110 | 111 | def freeze_bn(self): 112 | for module in self.modules(): 113 | if isinstance(module, nn.BatchNorm2d): module.eval() 114 | 115 | 116 | 117 | class DecoderBottleneck(nn.Module): 118 | def __init__(self, inchannels): 119 | super(DecoderBottleneck, self).__init__() 120 | self.conv1 = nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False) 121 | self.bn1 = nn.BatchNorm2d(inchannels//4) 122 | self.conv2 = nn.ConvTranspose2d(inchannels//4, inchannels//4, kernel_size=2, stride=2, bias=False) 123 | self.bn2 = nn.BatchNorm2d(inchannels//4) 124 | self.conv3 = nn.Conv2d(inchannels//4, inchannels//2, 1, bias=False) 125 | self.bn3 = nn.BatchNorm2d(inchannels//2) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.downsample = nn.Sequential( 128 | nn.ConvTranspose2d(inchannels, inchannels//2, kernel_size=2, stride=2, bias=False), 129 | nn.BatchNorm2d(inchannels//2)) 130 | 131 | def forward(self, x): 132 | out = self.conv1(x) 133 | out = self.bn1(out) 134 | out = self.relu(out) 135 | out = self.conv2(out) 136 | out = self.bn2(out) 137 | out = self.relu(out) 138 | out = self.conv3(out) 139 | out = self.bn3(out) 140 | 141 | identity = self.downsample(x) 142 | out += identity 143 | out = self.relu(out) 144 | return out 145 | 146 | class LastBottleneck(nn.Module): 147 | def __init__(self, inchannels): 148 | super(LastBottleneck, self).__init__() 149 | self.conv1 = nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False) 150 | self.bn1 = nn.BatchNorm2d(inchannels//4) 151 | self.conv2 = nn.Conv2d(inchannels//4, inchannels//4, kernel_size=3, padding=1, bias=False) 152 | self.bn2 = nn.BatchNorm2d(inchannels//4) 153 | self.conv3 = nn.Conv2d(inchannels//4, inchannels//4, 1, bias=False) 154 | self.bn3 = nn.BatchNorm2d(inchannels//4) 155 | self.relu = nn.ReLU(inplace=True) 156 | self.downsample = nn.Sequential( 157 | nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False), 158 | nn.BatchNorm2d(inchannels//4)) 159 | 160 | def forward(self, x): 161 | out = self.conv1(x) 162 | out = 
self.bn1(out) 163 | out = self.relu(out) 164 | out = self.conv2(out) 165 | out = self.bn2(out) 166 | out = self.relu(out) 167 | out = self.conv3(out) 168 | out = self.bn3(out) 169 | 170 | identity = self.downsample(x) 171 | out += identity 172 | out = self.relu(out) 173 | return out 174 | 175 | class SegResNet(BaseModel): 176 | def __init__(self, num_classes, in_channels=3, pretrained=True, freeze_bn=False, **_): 177 | super(SegResNet, self).__init__() 178 | resnet50 = models.resnet50(pretrained=pretrained) 179 | encoder = list(resnet50.children()) 180 | if in_channels != 3: 181 | encoder[0].in_channels = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1) 182 | encoder[3].return_indices = True 183 | 184 | # Encoder 185 | self.first_conv = nn.Sequential(*encoder[:4]) 186 | resnet50_blocks = list(resnet50.children())[4:-2] 187 | self.encoder = nn.Sequential(*resnet50_blocks) 188 | 189 | # Decoder 190 | resnet50_untrained = models.resnet50(pretrained=False) 191 | resnet50_blocks = list(resnet50_untrained.children())[4:-2][::-1] 192 | decoder = [] 193 | channels = (2048, 1024, 512) 194 | for i, block in enumerate(resnet50_blocks[:-1]): 195 | new_block = list(block.children())[::-1][:-1] 196 | decoder.append(nn.Sequential(*new_block, DecoderBottleneck(channels[i]))) 197 | new_block = list(resnet50_blocks[-1].children())[::-1][:-1] 198 | decoder.append(nn.Sequential(*new_block, LastBottleneck(256))) 199 | 200 | self.decoder = nn.Sequential(*decoder) 201 | self.last_conv = nn.Sequential( 202 | nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False), 203 | nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1) 204 | ) 205 | if freeze_bn: self.freeze_bn() 206 | 207 | def forward(self, x): 208 | inputsize = x.size() 209 | 210 | # Encoder 211 | x, indices = self.first_conv(x) 212 | x = self.encoder(x) 213 | 214 | # Decoder 215 | x = self.decoder(x) 216 | h_diff = ceil((x.size()[2] - indices.size()[2]) / 2) 217 | w_diff = ceil((x.size()[3] - indices.size()[3]) / 2) 218 | if indices.size()[2] % 2 == 1: 219 | x = x[:, :, h_diff:x.size()[2]-(h_diff-1), w_diff: x.size()[3]-(w_diff-1)] 220 | else: 221 | x = x[:, :, h_diff:x.size()[2]-h_diff, w_diff: x.size()[3]-w_diff] 222 | 223 | x = F.max_unpool2d(x, indices, kernel_size=2, stride=2) 224 | x = self.last_conv(x) 225 | 226 | if inputsize != x.size(): 227 | h_diff = (x.size()[2] - inputsize[2]) // 2 228 | w_diff = (x.size()[3] - inputsize[3]) // 2 229 | x = x[:, :, h_diff:x.size()[2]-h_diff, w_diff: x.size()[3]-w_diff] 230 | if h_diff % 2 != 0: x = x[:, :, :-1, :] 231 | if w_diff % 2 != 0: x = x[:, :, :, :-1] 232 | 233 | return x 234 | 235 | def get_backbone_params(self): 236 | return chain(self.first_conv.parameters(), self.encoder.parameters()) 237 | 238 | def get_decoder_params(self): 239 | return chain(self.decoder.parameters(), self.last_conv.parameters()) 240 | 241 | def freeze_bn(self): 242 | for module in self.modules(): 243 | if isinstance(module, nn.BatchNorm2d): module.eval() 244 | 245 | 246 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | #from base.base_model import BaseModel 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from itertools import chain 6 | 7 | 8 | class encoder(nn.Module): 9 | def __init__(self, in_channels, out_channels): 10 | super(encoder, self).__init__() 11 | self.down_conv = nn.Sequential( 12 | nn.Conv2d(in_channels, 
/models/unet.py:
--------------------------------------------------------------------------------
1 | #from base.base_model import BaseModel
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from itertools import chain
6 | 
7 | 
8 | class encoder(nn.Module):
9 |     def __init__(self, in_channels, out_channels):
10 |         super(encoder, self).__init__()
11 |         self.down_conv = nn.Sequential(
12 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
13 |             nn.BatchNorm2d(out_channels),
14 |             nn.ReLU(inplace=True),
15 |             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
16 |             nn.BatchNorm2d(out_channels),
17 |             nn.ReLU(inplace=True),
18 |         )
19 |         self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)
20 | 
21 |     def forward(self, x):
22 |         x = self.down_conv(x)
23 |         x_pooled = self.pool(x)
24 |         return x, x_pooled
25 | 
26 | #nn.Upsample(scale_factor=2)
27 | class decoder(nn.Module):
28 |     def __init__(self, in_channels, out_channels):
29 |         super(decoder, self).__init__()
30 |         self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
31 |         self.up_conv = nn.Sequential(
32 |             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
33 |             nn.BatchNorm2d(out_channels),
34 |             nn.ReLU(inplace=True),
35 |             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
36 |             nn.BatchNorm2d(out_channels),
37 |             nn.ReLU(inplace=True),
38 |         )
39 | 
40 |     def forward(self, x_copy, x):
41 |         x = self.up(x)
42 |         # Padding in case the incoming volumes are of different sizes
43 |         diffY = x_copy.size()[2] - x.size()[2]
44 |         diffX = x_copy.size()[3] - x.size()[3]
45 |         x = F.pad(x, (diffX // 2, diffX - diffX // 2,
46 |                       diffY // 2, diffY - diffY // 2))
47 |         # Concatenate with the skip connection
48 |         x = torch.cat([x_copy, x], dim=1)
49 |         x = self.up_conv(x)
50 |         return x
51 | 
52 | 
53 | class UNet(nn.Module):  # BaseModel
54 |     def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
55 |         super(UNet, self).__init__()
56 |         self.down1 = encoder(in_channels, 64)
57 |         self.down2 = encoder(64, 128)
58 |         self.down3 = encoder(128, 256)
59 |         self.down4 = encoder(256, 512)
60 |         self.middle_conv = nn.Sequential(
61 |             nn.Conv2d(512, 1024, kernel_size=3, padding=1),
62 |             nn.BatchNorm2d(1024),
63 |             nn.ReLU(inplace=True),
64 |             nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
65 |             nn.BatchNorm2d(1024),
66 |             nn.ReLU(inplace=True),
67 |         )
68 |         self.up1 = decoder(1024, 512)
69 |         self.up2 = decoder(512, 256)
70 |         self.up3 = decoder(256, 128)
71 |         self.up4 = decoder(128, 64)
72 |         self.up = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
73 |         self.beforefinal2_conv = nn.Conv2d(128, num_classes, kernel_size=1)  # 128
74 | 
75 |         self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
76 |         self._initialize_weights()
77 |         if freeze_bn:
78 |             self.freeze_bn()
79 | 
80 |     def _initialize_weights(self):
81 |         for module in self.modules():
82 |             if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
83 |                 nn.init.kaiming_normal_(module.weight)
84 |                 if module.bias is not None:
85 |                     module.bias.data.zero_()
86 |             elif isinstance(module, nn.BatchNorm2d):
87 |                 module.weight.data.fill_(1)
88 |                 module.bias.data.zero_()
89 | 
90 |     def forward(self, x):
91 |         x1, x = self.down1(x)
92 |         x2, x = self.down2(x)
93 |         x3, x = self.down3(x)
94 |         x4, x = self.down4(x)
95 |         x = self.middle_conv(x)
96 |         x = self.up1(x4, x)
97 |         x = self.up2(x3, x)
98 |         x = self.up3(x2, x)
99 |         #x_beforefinal2_temp = self.up(x)
100 |         #x_beforefinal2 = self.beforefinal2_conv(x_beforefinal2_temp)
101 | 
102 |         x = self.up4(x1, x)
103 |         # x_beforefinal2 = self.beforefinal2_conv(x)
104 | 
105 |         x_final = self.final_conv(x)
106 |         return x_final  #, x_beforefinal2
107 | 
108 |     def get_backbone_params(self):
109 |         # There is no backbone for unet, all the parameters are trained from scratch
110 |         return []
111 | 
112 |     def get_decoder_params(self):
113 |         return self.parameters()
114 | 
115 |     def freeze_bn(self):
116 |         for module in self.modules():
117 |             if isinstance(module, nn.BatchNorm2d): module.eval()
--------------------------------------------------------------------------------
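`decoder.forward` above pads the upsampled tensor so it matches the encoder feature before concatenation, which is what keeps the network usable on inputs whose sizes are not powers of two. A standalone sketch of just that alignment step, with made-up sizes:

```
import torch
import torch.nn.functional as F

x_copy = torch.randn(1, 64, 101, 101)  # encoder feature (odd size, from ceil_mode pooling)
x = torch.randn(1, 64, 100, 100)       # upsampled decoder feature, one pixel short
diffY = x_copy.size(2) - x.size(2)
diffX = x_copy.size(3) - x.size(3)
# pad left/right then top/bottom, splitting any odd remainder
x = F.pad(x, (diffX // 2, diffX - diffX // 2,
              diffY // 2, diffY - diffY // 2))
merged = torch.cat([x_copy, x], dim=1)
print(merged.shape)                    # torch.Size([1, 128, 101, 101])
```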
/postproc_other.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import cv2
4 | import numpy as np
5 | from scipy.ndimage import filters, measurements
6 | from scipy.ndimage.morphology import (
7 |     binary_erosion,
8 |     binary_dilation,
9 |     binary_fill_holes,
10 |     distance_transform_cdt,
11 |     distance_transform_edt)
12 | from skimage.morphology import remove_small_objects#, watershed
13 | from skimage.segmentation import watershed
14 | 
15 | def process(pred, model_mode, min_size=10, ws=True):
16 |     def gen_inst_dst_map(ann):
17 |         shape = ann.shape[:2]  # HW
18 |         nuc_list = list(np.unique(ann))
19 |         nuc_list.remove(0)  # 0 is background
20 | 
21 |         canvas = np.zeros(shape, dtype=np.uint8)
22 |         for nuc_id in nuc_list:
23 |             nuc_map = np.copy(ann == nuc_id)
24 |             nuc_dst = distance_transform_edt(nuc_map)
25 |             nuc_dst = 255 * (nuc_dst / np.amax(nuc_dst))
26 |             canvas += nuc_dst.astype('uint8')
27 |         return canvas
28 | 
29 |     if model_mode != 'dcan':
30 |         assert len(pred.shape) == 2, 'Prediction shape is not HW'
31 |         pred[pred > 0.5] = 1
32 |         pred[pred <= 0.5] = 0
33 | 
34 |         # ! refactor these
35 |         ws = False if model_mode == 'unet' or model_mode == 'micronet' else ws
36 |         if ws:
37 |             dist = measurements.label(pred)[0]
38 |             dist = gen_inst_dst_map(dist)
39 |             marker = np.copy(dist)
40 |             marker[marker <= 125] = 0
41 |             marker[marker > 125] = 1
42 |             marker = binary_fill_holes(marker)
43 |             marker = binary_erosion(marker, iterations=1)
44 |             marker = measurements.label(marker)[0]
45 | 
46 |             marker = remove_small_objects(marker, min_size=min_size)
47 |             pred = watershed(-dist, marker, mask=pred)
48 |             pred = remove_small_objects(pred, min_size=min_size)
49 |             #print('============================ ws = True ============================ ')
50 |         else:
51 |             pred = binary_fill_holes(pred)
52 |             pred = measurements.label(pred)[0]
53 |             pred = remove_small_objects(pred, min_size=min_size)
54 |             print('binary_fill_holes(pred), measurements.label(pred)[0], remove_small_objects(pred, min_size=10)')
55 | 
56 |         if model_mode == 'micronet':
57 |             # * dilate with same kernel size used for erosion during training
58 |             kernel = np.array([[0, 1, 0],
59 |                                [1, 1, 1],
60 |                                [0, 1, 0]], np.uint8)
61 | 
62 |             canvas = np.zeros([pred.shape[0], pred.shape[1]])
63 |             for inst_id in range(1, np.max(pred)+1):
64 |                 inst_map = np.array(pred == inst_id, dtype=np.uint8)
65 |                 inst_map = cv2.dilate(inst_map, kernel, iterations=1)
66 |                 inst_map = binary_fill_holes(inst_map)
67 |                 canvas[inst_map > 0] = inst_id
68 |             pred = canvas
69 |     else:
70 |         assert (pred.shape[2]) == 2, 'Prediction should have contour and blb'
71 |         blb = pred[...,0]
72 |         blb = np.squeeze(blb)
73 |         cnt = pred[...,1]
74 |         cnt = np.squeeze(cnt)
75 | 
76 |         pred = blb - cnt  # NOTE
77 |         pred[pred > 0.3] = 1   # Kumar 0.3, UHCW 0.3
78 |         pred[pred <= 0.3] = 0  # CPM2017 0.1
79 |         pred = measurements.label(pred)[0]
80 |         pred = remove_small_objects(pred, min_size=min_size)  # 20
81 |         canvas = np.zeros([pred.shape[0], pred.shape[1]])
82 | 
83 |         k_disk = np.array([
84 |             [0, 0, 0, 1, 0, 0, 0],
85 |             [0, 0, 1, 1, 1, 0, 0],
86 |             [0, 1, 1, 1, 1, 1, 0],
87 |             [1, 1, 1, 1, 1, 1, 1],
88 |             [0, 1, 1, 1, 1, 1, 0],
89 |             [0, 0, 1, 1, 1, 0, 0],
90 |             [0, 0, 0, 1, 0, 0, 0],
91 |         ], np.uint8)
92 |         for inst_id in range(1, np.max(pred)+1):
93 |             inst_map = np.array(pred == inst_id, dtype=np.uint8)
94 |             inst_map = cv2.dilate(inst_map, k_disk, iterations=1)
95 |             inst_map = binary_fill_holes(inst_map)
96 |             canvas[inst_map > 0] = inst_id
97 |         pred = canvas
98 | 
99 |     return pred
--------------------------------------------------------------------------------
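A toy run of `process`, as a sketch: the foreground map and mode string below are made up for illustration, and any mode other than 'dcan', 'unet', or 'micronet' takes the marker-controlled watershed branch, which splits the binary foreground into labelled instances.

```
import numpy as np
from postproc_other import process  # this repo's module

pred = np.zeros((64, 64), dtype=np.float32)
pred[10:30, 10:30] = 0.9            # one well-separated blob
pred[35:55, 35:55] = 0.8            # a second blob
inst_map = process(pred, model_mode='cdnet', min_size=10, ws=True)
print(np.unique(inst_map))          # background 0 plus one integer ID per instance
```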
/stats_utils.py:
--------------------------------------------------------------------------------
1 | # from HoverNet
2 | import warnings
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment
5 | 
6 | 
7 | def get_fast_aji(true, pred):
8 |     """
9 |     AJI version distributed by MoNuSeg; it has no permutation problem but suffers from
10 |     over-penalisation, similar to DICE2.
11 |     Fast computation requires instance IDs in contiguous ordering, i.e. [1, 2, 3, 4],
12 |     not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
13 |     effect on the result.
14 |     """
15 |     true = np.copy(true)  # ? do we need this
16 |     pred = np.copy(pred)
17 |     true_id_list = list(np.unique(true))
18 |     pred_id_list = list(np.unique(pred))
19 | 
20 |     true_masks = [None, ]
21 |     for t in true_id_list[1:]:
22 |         t_mask = np.array(true == t, np.uint8)
23 |         true_masks.append(t_mask)
24 | 
25 |     pred_masks = [None, ]
26 |     for p in pred_id_list[1:]:
27 |         p_mask = np.array(pred == p, np.uint8)
28 |         pred_masks.append(p_mask)
29 | 
30 |     # prefill with value
31 |     pairwise_inter = np.zeros([len(true_id_list) - 1,
32 |                                len(pred_id_list) - 1], dtype=np.float64)
33 |     pairwise_union = np.zeros([len(true_id_list) - 1,
34 |                                len(pred_id_list) - 1], dtype=np.float64)
35 |     # extra detections (false-positive pixels per pair)
36 |     pairwise_FP = np.zeros([len(true_id_list) - 1,
37 |                             len(pred_id_list) - 1], dtype=np.float64)
38 |     # missed detections (false-negative pixels per pair)
39 |     pairwise_FN = np.zeros([len(true_id_list) - 1,
40 |                             len(pred_id_list) - 1], dtype=np.float64)
41 | 
42 |     # caching pairwise
43 |     for true_id in true_id_list[1:]:  # 0-th is background
44 |         t_mask = true_masks[true_id]
45 |         pred_true_overlap = pred[t_mask > 0]
46 |         pred_true_overlap_id = np.unique(pred_true_overlap)
47 |         pred_true_overlap_id = list(pred_true_overlap_id)
48 |         for pred_id in pred_true_overlap_id:
49 |             if pred_id == 0:  # ignore
50 |                 continue  # overlapping background
51 |             p_mask = pred_masks[pred_id]
52 |             total = (t_mask + p_mask).sum()
53 |             inter = (t_mask * p_mask).sum()
54 |             pairwise_inter[true_id - 1, pred_id - 1] = inter
55 |             pairwise_union[true_id - 1, pred_id - 1] = total - inter
56 | 
57 |             pairwise_FP[true_id - 1, pred_id - 1] = p_mask.sum() - inter
58 |             pairwise_FN[true_id - 1, pred_id - 1] = t_mask.sum() - inter
59 |     #
60 |     pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
61 |     # pair of pred that gives the highest iou for each true; don't care
62 |     # about reusing a pred instance multiple times
63 |     paired_pred = np.argmax(pairwise_iou, axis=1)
64 |     pairwise_iou = np.max(pairwise_iou, axis=1)
65 |     # exclude those that don't have an intersection
66 |     paired_true = np.nonzero(pairwise_iou > 0.0)[0]
67 |     paired_pred = paired_pred[paired_true]
68 |     # print(paired_true.shape, paired_pred.shape)
69 | 
70 |     overall_inter = (pairwise_inter[paired_true, paired_pred]).sum()
71 |     overall_union = (pairwise_union[paired_true, paired_pred]).sum()
72 | 
73 |     overall_FP = (pairwise_FP[paired_true, paired_pred]).sum()
74 |     overall_FN = (pairwise_FN[paired_true, paired_pred]).sum()
75 | 
76 | 
77 |     #
78 |     paired_true = (list(paired_true + 1))  # index to instance ID
79 |     paired_pred = (list(paired_pred + 1))
80 |     # add all unpaired GT and Prediction into the union
81 |     unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
82 |     unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
83 | 
84 |     less_pred = 0
85 |     more_pred = 0
86 | 
87 |     for true_id in unpaired_true:
88 |         less_pred += true_masks[true_id].sum()
89 |         overall_union += true_masks[true_id].sum()
90 |     for pred_id in unpaired_pred:
91 |         more_pred += pred_masks[pred_id].sum()
92 |         overall_union += pred_masks[pred_id].sum()
93 |     #
94 |     aji_score = overall_inter / overall_union
95 |     fm = overall_union - overall_inter
96 |     print('\t [ana_FP = {:.4f}, ana_FN = {:.4f}, ana_less = {:.4f}, ana_more = {:.4f}]'.format((overall_FP / fm), (overall_FN / fm), (less_pred / fm), (more_pred / fm)))
97 | 
98 |     return aji_score, overall_FP / fm, overall_FN / fm, less_pred / fm, more_pred / fm
99 | 
100 | 
101 | 
102 | 
103 | 
104 | 
105 | 
106 | 
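A tiny worked example for `get_fast_aji`, as a sketch: one GT instance is matched exactly while a second is missed entirely, so the missed instance inflates the union and AJI = 9 / 18 = 0.5 (the function also prints its FP/FN diagnostic line).

```
import numpy as np
from stats_utils import get_fast_aji  # this repo's module

true = np.zeros((8, 8), dtype=np.int32)
true[0:3, 0:3] = 1                    # instance 1 (9 px)
true[5:8, 5:8] = 2                    # instance 2 (9 px)
pred = np.zeros((8, 8), dtype=np.int32)
pred[0:3, 0:3] = 1                    # instance 1 recovered, instance 2 missed

aji, *_ = get_fast_aji(true, pred)
print(aji)                            # 0.5: intersection 9, union 9 + 9 unmatched GT px
```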
107 | #####
108 | def get_fast_aji_plus(true, pred):
109 |     """
110 |     AJI+, an AJI version with maximal unique pairing to obtain the overall intersection.
111 |     Every prediction instance is paired with at most 1 GT instance (1-to-1 mapping), unlike AJI,
112 |     where a prediction instance can be paired against many GT instances (1-to-many).
113 |     Remaining unpaired GT and prediction instances will be added to the overall union.
114 |     The 1-to-1 mapping prevents AJI's over-penalisation from happening.
115 |     Fast computation requires instance IDs in contiguous ordering, i.e. [1, 2, 3, 4],
116 |     not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
117 |     effect on the result.
118 |     """
119 |     true = np.copy(true)  # ? do we need this
120 |     pred = np.copy(pred)
121 |     true_id_list = list(np.unique(true))
122 |     pred_id_list = list(np.unique(pred))
123 | 
124 |     true_masks = [None, ]
125 |     for t in true_id_list[1:]:
126 |         t_mask = np.array(true == t, np.uint8)
127 |         true_masks.append(t_mask)
128 | 
129 |     pred_masks = [None, ]
130 |     for p in pred_id_list[1:]:
131 |         p_mask = np.array(pred == p, np.uint8)
132 |         pred_masks.append(p_mask)
133 | 
134 |     # prefill with value
135 |     pairwise_inter = np.zeros([len(true_id_list) - 1,
136 |                                len(pred_id_list) - 1], dtype=np.float64)
137 |     pairwise_union = np.zeros([len(true_id_list) - 1,
138 |                                len(pred_id_list) - 1], dtype=np.float64)
139 | 
140 |     # caching pairwise
141 |     for true_id in true_id_list[1:]:  # 0-th is background
142 |         t_mask = true_masks[true_id]
143 |         pred_true_overlap = pred[t_mask > 0]
144 |         pred_true_overlap_id = np.unique(pred_true_overlap)
145 |         pred_true_overlap_id = list(pred_true_overlap_id)
146 |         for pred_id in pred_true_overlap_id:
147 |             if pred_id == 0:  # ignore
148 |                 continue  # overlapping background
149 |             p_mask = pred_masks[pred_id]
150 |             total = (t_mask + p_mask).sum()
151 |             inter = (t_mask * p_mask).sum()
152 |             pairwise_inter[true_id - 1, pred_id - 1] = inter
153 |             pairwise_union[true_id - 1, pred_id - 1] = total - inter
154 |     #
155 |     pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
156 |     #### Munkres pairing to find maximal unique pairing
157 |     paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
158 |     ### extract the paired cost and remove invalid pairs
159 |     paired_iou = pairwise_iou[paired_true, paired_pred]
160 |     # now select all those paired with iou != 0.0, i.e. that have an intersection
161 |     paired_true = paired_true[paired_iou > 0.0]
162 |     paired_pred = paired_pred[paired_iou > 0.0]
163 |     paired_inter = pairwise_inter[paired_true, paired_pred]
164 |     paired_union = pairwise_union[paired_true, paired_pred]
165 |     paired_true = (list(paired_true + 1))  # index to instance ID
166 |     paired_pred = (list(paired_pred + 1))
167 |     overall_inter = paired_inter.sum()
168 |     overall_union = paired_union.sum()
169 |     # add all unpaired GT and Prediction into the union
170 |     unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
171 |     unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
172 |     for true_id in unpaired_true:
173 |         overall_union += true_masks[true_id].sum()
174 |     for pred_id in unpaired_pred:
175 |         overall_union += pred_masks[pred_id].sum()
176 |     #
177 |     aji_score = overall_inter / overall_union
178 |     return aji_score
179 | 
180 | 
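The `linear_sum_assignment` call is the entire difference between AJI and AJI+. A standalone sketch with a made-up 2x2 IoU matrix shows the effect: the greedy per-row argmax used by `get_fast_aji` would take the 0.70 pair and leave the second GT instance with 0.00 (total 0.70), while the assignment finds the unique matching with the larger total.

```
import numpy as np
from scipy.optimize import linear_sum_assignment

# rows = GT instances, cols = predicted instances (illustrative values)
pairwise_iou = np.array([[0.70, 0.20],
                         [0.60, 0.00]])
# negate: linear_sum_assignment minimises cost, we want to maximise total IoU
rows, cols = linear_sum_assignment(-pairwise_iou)
print([(int(r), int(c)) for r, c in zip(rows, cols)])  # [(0, 1), (1, 0)]
print(pairwise_iou[rows, cols].sum())                  # ~0.8, beating the greedy 0.7
```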
181 | #####
182 | def get_fast_pq(true, pred, match_iou=0.5):
183 |     """
184 |     `match_iou` is the IoU threshold level to determine the pairing between
185 |     GT instances `p` and prediction instances `g`. `p` and `g` are a pair
186 |     if IoU > `match_iou`. However, the pair of `p` and `g` must be unique
187 |     (1 prediction instance to 1 GT instance mapping).
188 |     If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching
189 |     in bipartite graphs) is calculated to find the maximal amount of unique pairing.
190 |     If `match_iou` >= 0.5, all IoU(p,g) > 0.5 pairings are proven to be unique and
191 |     the number of pairs is also maximal.
192 | 
193 |     Fast computation requires instance IDs in contiguous ordering,
194 |     i.e. [1, 2, 3, 4] not [2, 3, 6, 10]. Please call `remap_label` beforehand;
195 |     the `by_size` flag has no effect on the result.
196 |     Returns:
197 |         [dq, sq, pq]: measurement statistic
198 |         [paired_true, paired_pred, unpaired_true, unpaired_pred]:
199 |                       pairing information to perform measurement
200 | 
201 |     """
202 |     assert match_iou >= 0.0, "Can't be negative"
203 | 
204 |     true = np.copy(true)
205 |     pred = np.copy(pred)
206 |     true_id_list = list(np.unique(true))
207 |     pred_id_list = list(np.unique(pred))
208 | 
209 |     true_masks = [None, ]
210 |     for t in true_id_list[1:]:
211 |         t_mask = np.array(true == t, np.uint8)
212 |         true_masks.append(t_mask)
213 | 
214 |     pred_masks = [None, ]
215 |     for p in pred_id_list[1:]:
216 |         p_mask = np.array(pred == p, np.uint8)
217 |         pred_masks.append(p_mask)
218 | 
219 |     # prefill with value
220 |     pairwise_iou = np.zeros([len(true_id_list) - 1,
221 |                              len(pred_id_list) - 1], dtype=np.float64)
222 | 
223 |     # caching pairwise iou
224 |     for true_id in true_id_list[1:]:  # 0-th is background
225 |         t_mask = true_masks[true_id]
226 |         pred_true_overlap = pred[t_mask > 0]
227 |         pred_true_overlap_id = np.unique(pred_true_overlap)
228 |         pred_true_overlap_id = list(pred_true_overlap_id)
229 |         for pred_id in pred_true_overlap_id:
230 |             if pred_id == 0:  # ignore
231 |                 continue  # overlapping background
232 |             p_mask = pred_masks[pred_id]
233 |             total = (t_mask + p_mask).sum()
234 |             inter = (t_mask * p_mask).sum()
235 |             iou = inter / (total - inter)
236 |             pairwise_iou[true_id - 1, pred_id - 1] = iou
237 |     #
238 |     if match_iou >= 0.5:
239 |         paired_iou = pairwise_iou[pairwise_iou > match_iou]
240 |         pairwise_iou[pairwise_iou <= match_iou] = 0.0
241 |         paired_true, paired_pred = np.nonzero(pairwise_iou)
242 |         paired_iou = pairwise_iou[paired_true, paired_pred]
243 |         paired_true += 1  # index is instance id - 1
244 |         paired_pred += 1  # hence return back to original
245 |     else:  # * Exhaustive maximal unique pairing
246 |         #### Munkres pairing with scipy library
247 |         # the algorithm returns (row indices, matched column indices);
248 |         # if there are multiple identical costs in a row, the index of the first
249 |         # occurrence is returned, thus the unique pairing is ensured
250 |         # invert the sign so that the highest IoU becomes the minimum cost
251 |         paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
252 |         ### extract the paired cost and remove invalid pairs
253 |         paired_iou = pairwise_iou[paired_true, paired_pred]
254 | 
255 |         # now select those above the threshold level;
256 |         # paired with iou = 0.0, i.e. no intersection => FP or FN
257 |         paired_true = list(paired_true[paired_iou > match_iou] + 1)
258 |         paired_pred = list(paired_pred[paired_iou > match_iou] + 1)
259 |         paired_iou = paired_iou[paired_iou > match_iou]
260 | 
261 |     # get the actual FP and FN
262 |     unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
263 |     unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
264 |     # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred))
265 | 
266 |     #
267 |     tp = len(paired_true)
268 |     fp = len(unpaired_pred)
269 |     fn = len(unpaired_true)
270 |     # get the F1-score, i.e. DQ
271 |     dq = tp / (tp + 0.5 * fp + 0.5 * fn)
272 |     # get the SQ; no surviving pair has IoU of 0, so none skews the mean
273 |     sq = paired_iou.sum() / (tp + 1.0e-6)
274 | 
275 |     return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]
276 | 
277 | 
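A matching sanity check for `get_fast_pq`, as a sketch: with one exact match (TP = 1), one missed GT instance (FN = 1) and no spurious predictions (FP = 0), DQ = 1 / (1 + 0.5) which is about 0.667, and SQ is about 1 since the one surviving pair has IoU 1.0.

```
import numpy as np
from stats_utils import get_fast_pq  # this repo's module

true = np.zeros((8, 8), dtype=np.int32)
true[0:3, 0:3] = 1
true[5:8, 5:8] = 2
pred = np.zeros((8, 8), dtype=np.int32)
pred[0:3, 0:3] = 1

(dq, sq, pq), _ = get_fast_pq(true, pred, match_iou=0.5)
print(dq, sq, pq)   # ~0.667, ~1.0, ~0.667 (PQ = DQ * SQ)
```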
278 | #####
279 | def get_fast_dice_2(true, pred):
280 |     """
281 |     Ensemble Dice
282 |     """
283 |     true = np.copy(true)
284 |     pred = np.copy(pred)
285 |     true_id = list(np.unique(true))
286 |     pred_id = list(np.unique(pred))
287 | 
288 |     overall_total = 0
289 |     overall_inter = 0
290 | 
291 |     true_masks = [np.zeros(true.shape)]
292 |     for t in true_id[1:]:
293 |         t_mask = np.array(true == t, np.uint8)
294 |         true_masks.append(t_mask)
295 | 
296 |     pred_masks = [np.zeros(true.shape)]
297 |     for p in pred_id[1:]:
298 |         p_mask = np.array(pred == p, np.uint8)
299 |         pred_masks.append(p_mask)
300 | 
301 |     for true_idx in range(1, len(true_id)):
302 |         t_mask = true_masks[true_idx]
303 |         pred_true_overlap = pred[t_mask > 0]
304 |         pred_true_overlap_id = np.unique(pred_true_overlap)
305 |         pred_true_overlap_id = list(pred_true_overlap_id)
306 |         try:  # blindly remove background
307 |             pred_true_overlap_id.remove(0)
308 |         except ValueError:
309 |             pass  # just means there is no background
310 |         for pred_idx in pred_true_overlap_id:
311 |             p_mask = pred_masks[pred_idx]
312 |             total = (t_mask + p_mask).sum()
313 |             inter = (t_mask * p_mask).sum()
314 |             overall_total += total
315 |             overall_inter += inter
316 | 
317 |     return 2 * overall_inter / overall_total
318 | 
319 | 
320 | #####
321 | 
322 | #####--------------------------As pseudocode
323 | def get_dice_1(true, pred):
324 |     """
325 |     Traditional Dice
326 |     """
327 |     # cast to binary first
328 |     true = np.copy(true)
329 |     pred = np.copy(pred)
330 |     true[true > 0] = 1
331 |     pred[pred > 0] = 1
332 |     inter = true * pred
333 |     denom = true + pred
334 |     return 2.0 * np.sum(inter) / np.sum(denom)
335 | 
336 | 
337 | ####
338 | def get_dice_2(true, pred):
339 |     true = np.copy(true)
340 |     pred = np.copy(pred)
341 |     true_id = list(np.unique(true))
342 |     pred_id = list(np.unique(pred))
343 |     # remove background aka id 0
344 |     true_id.remove(0)
345 |     pred_id.remove(0)
346 | 
347 |     total_markup = 0
348 |     total_intersect = 0
349 |     for t in true_id:
350 |         t_mask = np.array(true == t, np.uint8)
351 |         for p in pred_id:
352 |             p_mask = np.array(pred == p, np.uint8)
353 |             intersect = p_mask * t_mask
354 |             if intersect.sum() > 0:
355 |                 total_intersect += intersect.sum()
356 |                 total_markup += (t_mask.sum() + p_mask.sum())
357 |     return 2 * total_intersect / total_markup
358 | 
359 | 
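`get_dice_1` binarises both maps and ignores instance identity entirely, so two equally sized squares with half their area in common score exactly 0.5. A one-line check, as a sketch:

```
import numpy as np
from stats_utils import get_dice_1  # this repo's module

true = np.zeros((8, 8), dtype=np.int32); true[0:4, 0:4] = 1   # 16 px
pred = np.zeros((8, 8), dtype=np.int32); pred[2:6, 0:4] = 1   # 16 px, 8 px overlap
print(get_dice_1(true, pred))   # 2*8 / (16 + 16) = 0.5
```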
360 | #####
361 | def remap_label(pred, by_size=False):
362 |     """
363 |     Rename all instance IDs so that the IDs are contiguous, i.e. [0, 1, 2, 3],
364 |     not [0, 2, 4, 6]. The ordering of instances (which one comes first)
365 |     is preserved unless by_size=True, in which case the instances are reordered
366 |     so that bigger nuclei get smaller IDs.
367 |     Args:
368 |         pred    : the 2D array containing instances, where each instance is marked
369 |                   by a non-zero integer
370 |         by_size : renaming such that larger nuclei get smaller IDs (on top)
371 |     """
372 |     pred_id = list(np.unique(pred))
373 |     pred_id.remove(0)
374 |     if len(pred_id) == 0:
375 |         return pred  # no label
376 |     if by_size:
377 |         pred_size = []
378 |         for inst_id in pred_id:
379 |             size = (pred == inst_id).sum()
380 |             pred_size.append(size)
381 |         # sort the IDs by size in descending order
382 |         pair_list = zip(pred_id, pred_size)
383 |         pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
384 |         pred_id, pred_size = zip(*pair_list)
385 | 
386 |     new_pred = np.zeros(pred.shape, np.int32)
387 |     for idx, inst_id in enumerate(pred_id):
388 |         new_pred[pred == inst_id] = idx + 1
389 |     return new_pred
390 | 
391 | 
392 | #####
393 | def pair_coordinates(setA, setB, radius):
394 |     """
395 |     Use the Munkres (Kuhn-Munkres) algorithm to find an optimal
396 |     unique pairing (largest possible match) when pairing points in set B
397 |     against points in set A, using distance as the cost function.
398 |     Args:
399 |         setA, setB: np.array (float32) of size Nx2 containing the XY coordinates
400 |                     of N different points
401 |         radius: valid area around a point in setA to consider
402 |                 a given coordinate in setB a candidate for a match
403 |     Return:
404 |         pairing: an array of index pairs, where the point at index pairing[0]
405 |                  in set A is paired with the point in set B
406 |                  at index pairing[1]
407 |         unpairedA, unpairedB: remaining unpaired points in set A and set B
408 |     """
409 | 
410 |     # * Euclidean distance as the cost matrix
411 |     setA_tile = np.expand_dims(setA, axis=1)
412 |     setB_tile = np.expand_dims(setB, axis=0)
413 |     setA_tile = np.repeat(setA_tile, setB.shape[0], axis=1)
414 |     setB_tile = np.repeat(setB_tile, setA.shape[0], axis=0)
415 |     pair_distance = (setA_tile - setB_tile) ** 2
416 |     # set A is the rows, and set B is paired against set A
417 |     pair_distance = np.sqrt(np.sum(pair_distance, axis=-1))
418 | 
419 |     # * Munkres pairing with scipy library
420 |     # the algorithm returns (row indices, matched column indices);
421 |     # if there are multiple identical costs in a row, the index of the first
422 |     # occurrence is returned, thus the unique pairing is ensured
423 |     indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
424 | 
425 |     # extract the paired cost and remove instances
426 |     # outside of the designated radius
427 |     pair_cost = pair_distance[indicesA, paired_indicesB]
428 | 
429 |     pairedA = indicesA[pair_cost <= radius]
430 |     pairedB = paired_indicesB[pair_cost <= radius]
431 | 
432 |     unpairedA = [idx for idx in range(setA.shape[0]) if idx not in list(pairedA)]
433 |     unpairedB = [idx for idx in range(setB.shape[0]) if idx not in list(pairedB)]
434 | 
435 |     pairing = np.array(list(zip(pairedA, pairedB)))
436 |     unpairedA = np.array(unpairedA, dtype=np.int64)
437 |     unpairedB = np.array(unpairedB, dtype=np.int64)
438 | 
439 |     return pairing, unpairedA, unpairedB
--------------------------------------------------------------------------------
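Since every fast metric above indexes its mask lists by instance ID, labels must be contiguous before any of them is called. A short sketch of `remap_label` doing that normalisation:

```
import numpy as np
from stats_utils import remap_label  # this repo's module

pred = np.zeros((6, 6), dtype=np.int32)
pred[0:2, 0:2] = 5          # small instance, non-contiguous ID
pred[3:6, 3:6] = 2          # larger instance
print(np.unique(remap_label(pred)))                # [0 1 2]
print(np.unique(remap_label(pred, by_size=True)))  # [0 1 2], largest instance gets ID 1
```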