├── README.md
├── base
│   └── base_model.py
├── data_folder.py
├── data_prepare
│   ├── SegFix_offset_helper.py
│   ├── getDirectionDiffMap.py
│   └── logger.py
├── hhl_utils
│   ├── helpers.py
│   ├── pytorch_ssim.py
│   ├── radam.py
│   ├── ranger.py
│   └── torchsummary.py
├── loss.py
├── models
│   ├── FullNet.py
│   ├── dam
│   │   ├── model_unet_MandD.py
│   │   ├── model_unet_MandD16.py
│   │   ├── model_unet_MandD4.py
│   │   ├── model_unet_MandDandP.py
│   │   ├── model_unet_rev1.py
│   │   └── seg_hrnet_rev1.py
│   ├── deeplabv3_plus.py
│   ├── fcn8.py
│   ├── model_unet.py
│   ├── pspnet.py
│   ├── seg_hrnet.py
│   ├── segnet.py
│   └── unet.py
├── my_transforms.py
├── my_transforms_direction.py
├── options.py
├── postproc_other.py
├── stats_utils.py
├── test.py
├── test_dam.py
├── train.py
├── train_util.py
├── train_util_dam.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # CDNet: Centripetal Direction Network for Nuclear Instance Segmentation
2 |
3 |
4 | [[`ICCV2021`](https://openaccess.thecvf.com/content/ICCV2021/papers/He_CDNet_Centripetal_Direction_Network_for_Nuclear_Instance_Segmentation_ICCV_2021_paper.pdf)]
5 |
6 | The code includes training and inference procedures for CDNet.
7 |
8 | Tips:
9 | There is a mistake in the reported U-Net result in Table 4 of the original paper.
10 | The correct result is:
11 |
12 | ### MoNuSeg
13 |
14 | | Method Name | Dice   | AJI    |
15 | | ----------- | ------ | ------ |
16 | | U-Net       | 0.8184 | 0.5910 |
17 | | Mask-RCNN   | 0.7600 | 0.5460 |
18 | | DCAN        | 0.7920 | 0.5250 |
19 | | Micro-Net   | 0.7970 | 0.5600 |
20 | | DIST        | 0.7890 | 0.5590 |
21 | | CIA-Net     | 0.8180 | 0.6200 |
22 | | U-Net       | 0.8027 | 0.6039 |
23 | | Hover-Net   | 0.8260 | 0.6180 |
24 | | BRP-Net     | -      | 0.6422 |
25 | | PFF-Net     | 0.8091 | 0.6107 |
26 | | Our CDNet   | 0.8316 | 0.6331 |
27 |
28 | ## Getting Started
29 | #### Create a data folder (/data) and put the datasets (MoNuSeg, CPM17) in it.
30 |
31 | #### Train
32 | ```
33 | cd CDNet/
34 | python train.py
35 | ```
36 |
37 | #### Test
38 | ```
39 | cd CDNet/
40 | python test.py
41 | ```
--------------------------------------------------------------------------------
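
As a quick pre-flight check for the "Getting Started" step above, a minimal sketch (the sub-directory names under /data are assumptions that depend on how MoNuSeg and CPM17 were unpacked, not a layout the repo enforces):

```
import os

# Hypothetical layout: adjust the paths to your local dataset preparation.
for d in ['./data/MoNuSeg', './data/CPM17']:
    if not os.path.isdir(d):
        raise FileNotFoundError('expected dataset folder: ' + d)
print('data folders found; now run: python train.py')
```
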
/base/base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | class BaseModel(nn.Module):
7 | def __init__(self):
8 | super(BaseModel, self).__init__()
9 | self.logger = logging.getLogger(self.__class__.__name__)
10 |
11 | def forward(self):
12 | raise NotImplementedError
13 |
14 | def summary(self):
15 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
16 | nbr_params = sum([np.prod(p.size()) for p in model_parameters])
17 | self.logger.info(f'Nbr of trainable parameters: {nbr_params}')
18 |
19 | def __str__(self):
20 | model_parameters = filter(lambda p: p.requires_grad, self.parameters())
21 | nbr_params = sum([np.prod(p.size()) for p in model_parameters])
22 | return super(BaseModel, self).__str__() + f'\nNbr of trainable parameters: {nbr_params}'
23 | #return summary(self, input_shape=(2, 3, 224, 224))
--------------------------------------------------------------------------------
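
For context, a minimal sketch of how BaseModel is meant to be subclassed (the TinyNet class below is hypothetical, not part of this repo): forward is overridden, and summary() logs the trainable-parameter count.

```
import logging

import torch
import torch.nn as nn

from base.base_model import BaseModel


class TinyNet(BaseModel):
    """Hypothetical one-layer net, used only to illustrate BaseModel."""

    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, x):  # overrides the abstract forward above
        return self.conv(x)


logging.basicConfig(level=logging.INFO)
net = TinyNet()
net.summary()  # logs: Nbr of trainable parameters: 224 (3*8*3*3 weights + 8 biases)
print(net(torch.zeros(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 32, 32])
```
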
/data_folder.py:
--------------------------------------------------------------------------------
1 |
2 | import torch.utils.data as data
3 | import os
4 | from PIL import Image
5 | import numpy as np
6 | import scipy.io as scio
7 | import torch
8 | #from skimage import morphology, io
9 |
10 | IMG_EXTENSIONS = [
11 | '.jpg', '.JPG', '.jpeg', '.JPEG',
12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
13 | ]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
18 |
19 |
20 | def img_loader(path, num_channels):
21 |
22 |
23 | if num_channels == 1:
24 | if ('.mat' in path):
25 | img = scio.loadmat(path)['inst_map']
26 | img = Image.fromarray(img.astype(np.uint8))
27 | elif('.npy' in path):
28 | img = np.load(path)
29 | img = Image.fromarray(img.astype(np.uint8))
30 | else:
31 | img = Image.open(path)
32 | else:
33 | if('.mat' in path):
34 | img = scio.loadmat(path)['inst_map']
35 | elif ('.npy' in path):
36 | img = np.load(path)
37 | img = Image.fromarray(img.astype(np.uint8))
38 | else:
39 | img = Image.open(path).convert('RGB')
40 |
41 | return img
42 |
43 |
44 | # get the image list pairs
45 | def get_imgs_list(dir_list, post_fix=None):
46 | """
47 | :param dir_list: [img1_dir, img2_dir, ...]
48 | :param post_fix: e.g. ['label.png', 'weight.png',...]
49 | :return: e.g. [(img1.ext, img1_label.png, img1_weight.png), ...]
50 | """
51 | img_list = []
52 | if len(dir_list) == 0:
53 | return img_list
54 | if len(dir_list) != len(post_fix) + 1:
55 | raise (RuntimeError('Should specify the postfix of each img type except the first input.'))
56 |
57 | img_filename_list = [os.listdir(dir_list[i]) for i in range(len(dir_list))]
58 |
59 | for img in img_filename_list[0]:
60 | if not is_image_file(img):
61 | continue
62 | img1_name = os.path.splitext(img)[0]
63 | item = [os.path.join(dir_list[0], img),]
64 | for i in range(1, len(img_filename_list)):
65 | img_name = '{:s}_{:s}'.format(img1_name, post_fix[i-1])
66 | if img_name in img_filename_list[i]:
67 | img_path = os.path.join(dir_list[i], img_name)
68 | item.append(img_path)
69 |
70 | if len(item) == len(dir_list):
71 | img_list.append(tuple(item))
72 |
73 | return img_list
74 |
75 |
76 |
77 | # dataset that supports one input image, one target image, and one weight map (optional)
78 | class DataFolder(data.Dataset):
79 | def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader):
80 | super(DataFolder, self).__init__()
81 | if len(dir_list) != len(post_fix) + 1:
82 | raise (RuntimeError('Length of dir_list is different from length of post_fix + 1.'))
83 | if len(dir_list) != len(num_channels):
84 | raise (RuntimeError('Length of dir_list is different from length of num_channels.'))
85 |
86 | self.img_list = get_imgs_list(dir_list, post_fix)
87 | if len(self.img_list) == 0:
88 | raise(RuntimeError('Found 0 image pairs in given directories.'))
89 |
90 | self.data_transform = data_transform
91 | self.num_channels = num_channels
92 | self.loader = loader
93 |
94 | def __getitem__(self, index):
95 | img_paths = self.img_list[index]
96 |
97 | sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))]
98 |
99 |         if self.data_transform is None:
100 |             return sample
101 |
102 |         sample_tensor = self.data_transform(sample)
103 |         # re-apply the (random) transform until the third element of the
104 |         # sample contains more than one unique value
105 |         while len(torch.unique(sample_tensor[2])) <= 1:
106 |             sample_tensor = self.data_transform(sample)
107 |         return sample_tensor
108 |
109 | def __len__(self):
110 | return len(self.img_list)
111 |
112 |
--------------------------------------------------------------------------------
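
A usage sketch for DataFolder (the paths and suffixes below are assumptions for illustration; the pairing convention follows get_imgs_list, which matches img1.png in the first directory with img1_label.png, img1_weight.png, ... in the others):

```
from data_folder import DataFolder

# Hypothetical directory layout; adjust to your prepared dataset.
dir_list = ['./data/MoNuSeg/images',
            './data/MoNuSeg/labels',
            './data/MoNuSeg/weights']
post_fix = ['label.png', 'weight.png']  # suffix per directory after the first
num_channels = [3, 1, 1]                # RGB input, single-channel label/weight

dataset = DataFolder(dir_list, post_fix, num_channels, data_transform=None)
img, label, weight = dataset[0]  # PIL images until data_transform produces tensors
print(len(dataset), 'image triplets found')
```
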
/data_prepare/getDirectionDiffMap.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on 2021/6/7
4 |
5 | @author: he
6 | """
7 |
8 |
9 | from data_prepare.SegFix_offset_helper import DTOffsetHelper
10 | import numpy as np
11 | import torch
12 |
13 |
14 | def circshift(matrix_ori, direction, shiftnum1, shiftnum2):
15 |     # direction: 1 = shift up-left, 2 = up-right, 3 = down-left, 4 = down-right
16 | c, h, w = matrix_ori.shape
17 | matrix_new = np.zeros_like(matrix_ori)
18 |
19 | for k in range(c):
20 | matrix = matrix_ori[k]
21 | # matrix = matrix_ori[:,:,k]
22 | if (direction == 1):
23 |             # shift up-left
24 | matrix = np.vstack((matrix[shiftnum1:, :], np.zeros_like(matrix[:shiftnum1, :])))
25 | matrix = np.hstack((matrix[:, shiftnum2:], np.zeros_like(matrix[:, :shiftnum2])))
26 | elif (direction == 2):
27 |             # shift up-right
28 | matrix = np.vstack((matrix[shiftnum1:, :], np.zeros_like(matrix[:shiftnum1, :])))
29 | matrix = np.hstack((np.zeros_like(matrix[:, (w - shiftnum2):]), matrix[:, :(w - shiftnum2)]))
30 | elif (direction == 3):
31 |             # shift down-left
32 | matrix = np.vstack((np.zeros_like(matrix[(h - shiftnum1):, :]), matrix[:(h - shiftnum1), :]))
33 | matrix = np.hstack((matrix[:, shiftnum2:], np.zeros_like(matrix[:, :shiftnum2])))
34 | elif (direction == 4):
35 |             # shift down-right
36 | matrix = np.vstack((np.zeros_like(matrix[(h - shiftnum1):, :]), matrix[:(h - shiftnum1), :]))
37 | matrix = np.hstack((np.zeros_like(matrix[:, (w - shiftnum2):]), matrix[:, :(w - shiftnum2)]))
38 | # matrix_new[k]==>matrix_new[:,:, k]
39 | # matrix_new[:,:, k] = matrix
40 | matrix_new[k] = matrix
41 |
42 | return matrix_new
43 |
44 | def generate_dd_map(label_direction, direction_classes):
45 | direction_offsets = DTOffsetHelper.label_to_vector(torch.from_numpy(label_direction.reshape(1, label_direction.shape[0], label_direction.shape[1])),direction_classes)
46 | direction_offsets = direction_offsets[0].permute(1,2,0).detach().cpu().numpy()
47 |
48 | direction_os = direction_offsets #[256,256,2]
49 |
50 |     height, width = direction_os.shape[0], direction_os.shape[1]
51 |
52 |     cos_sim_map = np.zeros((height, width), dtype=np.float32)  # np.float was removed in NumPy >= 1.24
53 |
54 | feature_list = []
55 | feature5 = direction_os # .transpose(1, 2, 0)
56 | if (direction_classes - 1 == 4):
57 | direction_os = direction_os.transpose(2, 0, 1)
58 | feature2 = circshift(direction_os, 1, 1, 0).transpose(1, 2, 0)
59 | feature4 = circshift(direction_os, 3, 0, 1).transpose(1, 2, 0)
60 | feature6 = circshift(direction_os, 4, 0, 1).transpose(1, 2, 0)
61 | feature8 = circshift(direction_os, 3, 1, 0).transpose(1, 2, 0)
62 |
63 | feature_list.append(feature2)
64 | feature_list.append(feature4)
65 | # feature_list.append(feature5)
66 | feature_list.append(feature6)
67 | feature_list.append(feature8)
68 |
69 | elif (direction_classes - 1 == 8 or direction_classes - 1 == 16):
70 | direction_os = direction_os.transpose(2, 0, 1) # [2,256,256]
71 | feature1 = circshift(direction_os, 1, 1, 1).transpose(1, 2, 0)
72 | feature2 = circshift(direction_os, 1, 1, 0).transpose(1, 2, 0)
73 | feature3 = circshift(direction_os, 2, 1, 1).transpose(1, 2, 0)
74 | feature4 = circshift(direction_os, 3, 0, 1).transpose(1, 2, 0)
75 | feature6 = circshift(direction_os, 4, 0, 1).transpose(1, 2, 0)
76 | feature7 = circshift(direction_os, 3, 1, 1).transpose(1, 2, 0)
77 | feature8 = circshift(direction_os, 3, 1, 0).transpose(1, 2, 0)
78 | feature9 = circshift(direction_os, 4, 1, 1).transpose(1, 2, 0)
79 |
80 | feature_list.append(feature1)
81 | feature_list.append(feature2)
82 | feature_list.append(feature3)
83 | feature_list.append(feature4)
84 | # feature_list.append(feature5)
85 | feature_list.append(feature6)
86 | feature_list.append(feature7)
87 | feature_list.append(feature8)
88 | feature_list.append(feature9)
89 |
90 |     cos_value = np.zeros((height, width, direction_classes - 1), dtype=np.float32)
91 | # print('cos_value.shape = {}'.format(cos_value.shape))
92 | for k, feature_item in enumerate(feature_list):
93 |         numerator = (feature5[:, :, 0] * feature_item[:, :, 0] + feature5[:, :, 1] * feature_item[:, :, 1])
94 |         denominator = (np.sqrt(pow(feature5[:, :, 0], 2) + pow(feature5[:, :, 1], 2)) * np.sqrt(
95 |             pow(feature_item[:, :, 0], 2) + pow(feature_item[:, :, 1], 2)) + 0.000001)
96 |         cos_np = numerator / denominator
97 | cos_value[:, :, k] = cos_np
98 |
99 | cos_value_min = np.min(cos_value, axis=2)
100 | cos_sim_map = cos_value_min
101 | cos_sim_map[label_direction == 0] = 1
102 |
103 | cos_sim_map_np = (1 - np.around(cos_sim_map))
104 | cos_sim_map_np_max = np.max(cos_sim_map_np)
105 | cos_sim_map_np_min = np.min(cos_sim_map_np)
106 | cos_sim_map_np_normal = (cos_sim_map_np - cos_sim_map_np_min) / (cos_sim_map_np_max - cos_sim_map_np_min)
107 |
108 | return cos_sim_map_np_normal
109 |
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------
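
A usage sketch for generate_dd_map (assumptions: SegFix_offset_helper is importable as in this repo, 0 is the background label, and direction_classes = 9 corresponds to the 8-direction setting):

```
import numpy as np

from data_prepare.getDirectionDiffMap import generate_dd_map

# Toy direction-class label map: 0 = background, 1..8 = direction bins.
direction_classes = 9
label_direction = np.zeros((64, 64), dtype=np.int64)
label_direction[16:48, 16:48] = np.random.randint(1, direction_classes, size=(32, 32))

dd_map = generate_dd_map(label_direction, direction_classes)
print(dd_map.shape, dd_map.min(), dd_map.max())  # (64, 64), values normalized to [0, 1]
```
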
/data_prepare/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # Author: Donny You (youansheng@gmail.com)
5 | # Logging tool implemented with the Python logging package.
6 |
7 |
8 |
9 | from __future__ import absolute_import
10 | from __future__ import division
11 | from __future__ import print_function
12 |
13 | import argparse
14 | import logging
15 | import os
16 | import sys
17 |
18 |
19 | DEFAULT_LOGFILE_LEVEL = 'debug'
20 | DEFAULT_STDOUT_LEVEL = 'info'
21 | DEFAULT_LOG_FILE = './default.log'
22 | DEFAULT_LOG_FORMAT = '%(asctime)s %(levelname)-7s %(message)s'
23 |
24 | LOG_LEVEL_DICT = {
25 | 'debug': logging.DEBUG,
26 | 'info': logging.INFO,
27 | 'warning': logging.WARNING,
28 | 'error': logging.ERROR,
29 | 'critical': logging.CRITICAL
30 | }
31 |
32 |
33 | class Logger(object):
34 |     """
35 |     Args:
36 |         logfile_level: Log level for the log file (CRITICAL > ERROR > WARNING > INFO > DEBUG).
37 |         log_file: The file that stores the logging info.
38 |         rewrite: Whether to clear the log file before writing.
39 |         log_format: The format of log messages.
40 |         stdout_level: The log level to print on the screen.
41 |     """
42 | logfile_level = None
43 | log_file = None
44 | log_format = None
45 | rewrite = None
46 | stdout_level = None
47 | logger = None
48 |
49 | _caches = {}
50 |
51 | @staticmethod
52 | def init(logfile_level=DEFAULT_LOGFILE_LEVEL,
53 | log_file=DEFAULT_LOG_FILE,
54 | log_format=DEFAULT_LOG_FORMAT,
55 | rewrite=False,
56 | stdout_level=None):
57 | Logger.logfile_level = logfile_level
58 | Logger.log_file = log_file
59 | Logger.log_format = log_format
60 | Logger.rewrite = rewrite
61 | Logger.stdout_level = stdout_level
62 |
63 | Logger.logger = logging.getLogger()
64 | fmt = logging.Formatter(Logger.log_format)
65 |
66 | if Logger.logfile_level is not None:
67 | filemode = 'w'
68 | if not Logger.rewrite:
69 | filemode = 'a'
70 |
71 | dir_name = os.path.dirname(os.path.abspath(Logger.log_file))
72 | if not os.path.exists(dir_name):
73 | os.makedirs(dir_name)
74 |
75 | if Logger.logfile_level not in LOG_LEVEL_DICT:
76 | print('Invalid logging level: {}'.format(Logger.logfile_level))
77 | Logger.logfile_level = DEFAULT_LOGFILE_LEVEL
78 |
79 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.logfile_level])
80 |
81 | fh = logging.FileHandler(Logger.log_file, mode=filemode)
82 | fh.setFormatter(fmt)
83 | fh.setLevel(LOG_LEVEL_DICT[Logger.logfile_level])
84 |
85 | Logger.logger.addHandler(fh)
86 |
87 | if stdout_level is not None:
88 | if Logger.logfile_level is None:
89 | Logger.logger.setLevel(LOG_LEVEL_DICT[Logger.stdout_level])
90 |
91 | console = logging.StreamHandler()
92 | if Logger.stdout_level not in LOG_LEVEL_DICT:
93 | print('Invalid logging level: {}'.format(Logger.stdout_level))
94 | return
95 |
96 | console.setLevel(LOG_LEVEL_DICT[Logger.stdout_level])
97 | console.setFormatter(fmt)
98 | Logger.logger.addHandler(console)
99 |
100 | @staticmethod
101 | def set_log_file(file_path):
102 | Logger.log_file = file_path
103 | Logger.init(log_file=file_path)
104 |
105 | @staticmethod
106 | def set_logfile_level(log_level):
107 | if log_level not in LOG_LEVEL_DICT:
108 | print('Invalid logging level: {}'.format(log_level))
109 | return
110 |
111 | Logger.init(logfile_level=log_level)
112 |
113 | @staticmethod
114 | def clear_log_file():
115 | Logger.rewrite = True
116 | Logger.init(rewrite=True)
117 |
118 | @staticmethod
119 | def check_logger():
120 | if Logger.logger is None:
121 | Logger.init(logfile_level=None, stdout_level=DEFAULT_STDOUT_LEVEL)
122 |
123 | @staticmethod
124 | def set_stdout_level(log_level):
125 | if log_level not in LOG_LEVEL_DICT:
126 | print('Invalid logging level: {}'.format(log_level))
127 | return
128 |
129 | Logger.init(stdout_level=log_level)
130 |
131 | @staticmethod
132 | def debug(message):
133 | Logger.check_logger()
134 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
135 | lineno = sys._getframe().f_back.f_lineno
136 | prefix = '[{}, {}]'.format(filename,lineno)
137 | Logger.logger.debug('{} {}'.format(prefix, message))
138 |
139 | @staticmethod
140 | def info(message):
141 | Logger.check_logger()
142 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
143 | lineno = sys._getframe().f_back.f_lineno
144 | prefix = '[{}, {}]'.format(filename,lineno)
145 | Logger.logger.info('{} {}'.format(prefix, message))
146 |
147 | @staticmethod
148 | def info_once(message):
149 | Logger.check_logger()
150 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
151 | lineno = sys._getframe().f_back.f_lineno
152 | prefix = '[{}, {}]'.format(filename, lineno)
153 |
154 | if Logger._caches.get((prefix, message)) is not None:
155 | return
156 |
157 | Logger.logger.info('{} {}'.format(prefix, message))
158 | Logger._caches[(prefix, message)] = True
159 |
160 | @staticmethod
161 | def warn(message):
162 | Logger.check_logger()
163 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
164 | lineno = sys._getframe().f_back.f_lineno
165 | prefix = '[{}, {}]'.format(filename,lineno)
166 |         Logger.logger.warning('{} {}'.format(prefix, message))
167 |
168 | @staticmethod
169 | def error(message):
170 | Logger.check_logger()
171 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
172 | lineno = sys._getframe().f_back.f_lineno
173 | prefix = '[{}, {}]'.format(filename,lineno)
174 | Logger.logger.error('{} {}'.format(prefix, message))
175 |
176 | @staticmethod
177 | def critical(message):
178 | Logger.check_logger()
179 | filename = os.path.basename(sys._getframe().f_back.f_code.co_filename)
180 | lineno = sys._getframe().f_back.f_lineno
181 | prefix = '[{}, {}]'.format(filename,lineno)
182 | Logger.logger.critical('{} {}'.format(prefix, message))
183 |
184 |
185 | if __name__ == "__main__":
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument('--logfile_level', default="debug", type=str,
188 | dest='logfile_level', help='To set the log level to files.')
189 | parser.add_argument('--stdout_level', default=None, type=str,
190 | dest='stdout_level', help='To set the level to print to screen.')
191 | parser.add_argument('--log_file', default="./default.log", type=str,
192 | dest='log_file', help='The path of log files.')
193 | parser.add_argument('--log_format', default="%(asctime)s %(levelname)-7s %(message)s",
194 | type=str, dest='log_format', help='The format of log messages.')
195 |     parser.add_argument('--rewrite', default=False, action='store_true',
196 |                         dest='rewrite', help='Clear the existing log file before writing.')
197 |
198 | args = parser.parse_args()
199 | Logger.init(logfile_level=args.logfile_level, stdout_level=args.stdout_level,
200 | log_file=args.log_file, log_format=args.log_format, rewrite=args.rewrite)
201 |
202 | Logger.info("info test.")
203 | Logger.debug("debug test.")
204 | Logger.warn("warn test.")
205 | Logger.error("error test.")
206 | Logger.debug("debug test.")
--------------------------------------------------------------------------------
/hhl_utils/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import math
6 | import PIL
7 |
8 | def dir_exists(path):
9 | if not os.path.exists(path):
10 | os.makedirs(path)
11 |
12 | def initialize_weights(*models):
13 | for model in models:
14 | for m in model.modules():
15 | if isinstance(m, nn.Conv2d):
16 | nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
17 | elif isinstance(m, nn.BatchNorm2d):
18 | m.weight.data.fill_(1.)
19 | m.bias.data.fill_(1e-4)
20 | elif isinstance(m, nn.Linear):
21 | m.weight.data.normal_(0.0, 0.0001)
22 | m.bias.data.zero_()
23 |
24 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
25 | factor = (kernel_size + 1) // 2
26 | if kernel_size % 2 == 1:
27 | center = factor - 1
28 | else:
29 | center = factor - 0.5
30 | og = np.ogrid[:kernel_size, :kernel_size]
31 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
32 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
33 | weight[list(range(in_channels)), list(range(out_channels)), :, :] = filt
34 | return torch.from_numpy(weight).float()
35 |
36 | def colorize_mask(mask, palette):
37 | zero_pad = 256 * 3 - len(palette)
38 | for i in range(zero_pad):
39 | palette.append(0)
40 | new_mask = PIL.Image.fromarray(mask.astype(np.uint8)).convert('P')
41 | new_mask.putpalette(palette)
42 | return new_mask
43 |
44 | def set_trainable_attr(m,b):
45 | m.trainable = b
46 | for p in m.parameters(): p.requires_grad = b
47 |
48 | def apply_leaf(m, f):
49 | c = m if isinstance(m, (list, tuple)) else list(m.children())
50 | if isinstance(m, nn.Module):
51 | f(m)
52 | if len(c)>0:
53 | for l in c:
54 | apply_leaf(l,f)
55 |
56 | def set_trainable(l, b):
57 | apply_leaf(l, lambda m: set_trainable_attr(m,b))
--------------------------------------------------------------------------------
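
get_upsampling_weight builds the classic FCN-style bilinear kernel; a small sketch (layer sizes are illustrative) showing it used to initialize a transposed convolution as channel-wise 2x bilinear upsampling:

```
import torch
import torch.nn as nn

from hhl_utils.helpers import get_upsampling_weight

# For factor 2: kernel_size = 2 * factor - factor % 2 = 4, stride = 2, padding = 1.
up = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1, bias=False)
with torch.no_grad():
    up.weight.copy_(get_upsampling_weight(8, 8, 4))

x = torch.randn(1, 8, 16, 16)
print(up(x).shape)  # torch.Size([1, 8, 32, 32])
```
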
/hhl_utils/pytorch_ssim.py:
--------------------------------------------------------------------------------
1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 | from math import exp
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 | return gauss/gauss.sum()
11 |
12 | def create_window(window_size, channel):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 | return window
17 |
18 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
21 |
22 | mu1_sq = mu1.pow(2)
23 | mu2_sq = mu2.pow(2)
24 | mu1_mu2 = mu1*mu2
25 |
26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
29 |
30 | C1 = 0.01**2
31 | C2 = 0.03**2
32 |
33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
34 |
35 | if size_average:
36 | return ssim_map.mean()
37 | else:
38 | return ssim_map.mean(1).mean(1).mean(1)
39 |
40 | class SSIM(torch.nn.Module):
41 | def __init__(self, window_size = 11, size_average = True):
42 | super(SSIM, self).__init__()
43 | self.window_size = window_size
44 | self.size_average = size_average
45 | self.channel = 1
46 | self.window = create_window(window_size, self.channel)
47 |
48 | def forward(self, img1, img2):
49 | (_, channel, _, _) = img1.size()
50 |
51 | if channel == self.channel and self.window.data.type() == img1.data.type():
52 | window = self.window
53 | else:
54 | window = create_window(self.window_size, channel)
55 |
56 | if img1.is_cuda:
57 | window = window.cuda(img1.get_device())
58 | window = window.type_as(img1)
59 |
60 | self.window = window
61 | self.channel = channel
62 |
63 |
64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
65 |
66 | def _logssim(img1, img2, window, window_size, channel, size_average = True):
67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
69 |
70 | mu1_sq = mu1.pow(2)
71 | mu2_sq = mu2.pow(2)
72 | mu1_mu2 = mu1*mu2
73 |
74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
77 |
78 | C1 = 0.01**2
79 | C2 = 0.03**2
80 |
81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map))
83 | ssim_map = -torch.log(ssim_map + 1e-8)
84 |
85 | if size_average:
86 | return ssim_map.mean()
87 | else:
88 | return ssim_map.mean(1).mean(1).mean(1)
89 |
90 | class LOGSSIM(torch.nn.Module):
91 | def __init__(self, window_size = 11, size_average = True):
92 | super(LOGSSIM, self).__init__()
93 | self.window_size = window_size
94 | self.size_average = size_average
95 | self.channel = 1
96 | self.window = create_window(window_size, self.channel)
97 |
98 | def forward(self, img1, img2):
99 | (_, channel, _, _) = img1.size()
100 |
101 | if channel == self.channel and self.window.data.type() == img1.data.type():
102 | window = self.window
103 | else:
104 | window = create_window(self.window_size, channel)
105 |
106 | if img1.is_cuda:
107 | window = window.cuda(img1.get_device())
108 | window = window.type_as(img1)
109 |
110 | self.window = window
111 | self.channel = channel
112 |
113 |
114 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average)
115 |
116 |
117 | def ssim(img1, img2, window_size = 11, size_average = True):
118 | (_, channel, _, _) = img1.size()
119 | window = create_window(window_size, channel)
120 |
121 | if img1.is_cuda:
122 | window = window.cuda(img1.get_device())
123 | window = window.type_as(img1)
124 |
125 | return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
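
A usage sketch for the SSIM utilities above (inputs are expected in [0, 1], since the constants C1 and C2 assume a unit dynamic range):

```
import torch

from hhl_utils.pytorch_ssim import SSIM, ssim

img1 = torch.rand(1, 1, 64, 64)
img2 = img1.clone()

print(ssim(img1, img2).item())                      # ~1.0 for identical images
print(ssim(img1, torch.rand(1, 1, 64, 64)).item())  # much lower for noise

ssim_module = SSIM(window_size=11)
loss = 1 - ssim_module(img1, img2)  # a common way to turn SSIM into a loss
```
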
/hhl_utils/radam.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.optim.optimizer import Optimizer#, required
4 |
5 |
6 | class RAdam(Optimizer):
7 |
8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
9 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
10 | self.buffer = [[None, None, None] for ind in range(10)]
11 | super(RAdam, self).__init__(params, defaults)
12 |
13 | def __setstate__(self, state):
14 | super(RAdam, self).__setstate__(state)
15 |
16 | def step(self, closure=None):
17 |
18 | loss = None
19 | if closure is not None:
20 | loss = closure()
21 |
22 | for group in self.param_groups:
23 |
24 | for p in group['params']:
25 | if p.grad is None:
26 | continue
27 | grad = p.grad.data.float()
28 | if grad.is_sparse:
29 | raise RuntimeError('RAdam does not support sparse gradients')
30 |
31 | p_data_fp32 = p.data.float()
32 |
33 | state = self.state[p]
34 |
35 | if len(state) == 0:
36 | state['step'] = 0
37 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
38 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
39 | else:
40 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
41 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
42 |
43 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
44 | beta1, beta2 = group['betas']
45 |
46 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
47 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
48 |
49 | state['step'] += 1
50 | buffered = self.buffer[int(state['step'] % 10)]
51 | if state['step'] == buffered[0]:
52 | N_sma, step_size = buffered[1], buffered[2]
53 | else:
54 | buffered[0] = state['step']
55 | beta2_t = beta2 ** state['step']
56 | N_sma_max = 2 / (1 - beta2) - 1
57 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
58 | buffered[1] = N_sma
59 |
60 | # more conservative since it's an approximated value
61 | if N_sma >= 5:
62 | step_size = group['lr'] * math.sqrt(
63 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
64 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
65 | else:
66 | step_size = group['lr'] / (1 - beta1 ** state['step'])
67 | buffered[2] = step_size
68 |
69 | if group['weight_decay'] != 0:
70 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
71 |
72 |                 # more conservative since it's an approximated value
73 |                 if N_sma >= 5:
74 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
75 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
76 |                 else:
77 |                     p_data_fp32.add_(exp_avg, alpha=-step_size)
78 |
79 | p.data.copy_(p_data_fp32)
80 |
81 | return loss
82 |
83 |
84 | class RAdam_4step(Optimizer):
85 |
86 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, update_all=False,
87 | additional_four=False):
88 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
89 | self.update_all = update_all # whether update the first 4 steps
90 | self.additional_four = additional_four # whether use additional 4 steps for SGD
91 | self.buffer = [[None, None] for ind in range(10)]
92 | super(RAdam_4step, self).__init__(params, defaults)
93 |
94 | def __setstate__(self, state):
95 | super(RAdam_4step, self).__setstate__(state)
96 |
97 | def step(self, closure=None):
98 |
99 | loss = None
100 | if closure is not None:
101 | loss = closure()
102 |
103 | for group in self.param_groups:
104 |
105 | for p in group['params']:
106 | if p.grad is None:
107 | continue
108 | grad = p.grad.data.float()
109 | if grad.is_sparse:
110 | raise RuntimeError('RAdam_4step does not support sparse gradients')
111 |
112 | p_data_fp32 = p.data.float()
113 |
114 | state = self.state[p]
115 |
116 | if len(state) == 0:
117 |                     # this experiment requires exactly 4 extra steps, hard coded
118 |                     state['step'] = -4 if self.additional_four else 0
119 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
120 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
121 | else:
122 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
123 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
124 |
125 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
126 | beta1, beta2 = group['betas']
127 |
128 | state['step'] += 1
129 |
130 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
131 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
132 |
133 | if state['step'] > 0:
134 |
135 |                     # this experiment requires exactly 4 extra steps, hard coded
136 |                     state_step = state['step'] + 4 if self.additional_four else state['step']
137 |
138 | buffered = self.buffer[int(state_step % 10)]
139 | if state_step == buffered[0]:
140 | step_size = buffered[1]
141 | else:
142 | buffered[0] = state_step
143 | beta2_t = beta2 ** state['step']
144 |
145 | if state['step'] > 4: # since this exp requires exactly 4 step, it is hard coded
146 | N_sma_max = 2 / (1 - beta2) - 1
147 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
148 | step_size = group['lr'] * math.sqrt(
149 | (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (
150 | 1 - beta1 ** state_step)
151 | elif self.update_all:
152 | step_size = group['lr'] / (1 - beta1 ** state_step)
153 | else:
154 | step_size = 0
155 | buffered[1] = step_size
156 |
157 | if state['step'] > 4: # since this exp requires exactly 4 step, it is hard coded
158 | if group['weight_decay'] != 0:
159 |                             p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
160 |                         denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step)).add_(group['eps'])
161 |                         p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
162 |                         p.data.copy_(p_data_fp32)
163 |                     elif self.update_all:
164 |                         if group['weight_decay'] != 0:
165 |                             p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
166 |                         denom = (exp_avg_sq.sqrt() / math.sqrt(1 - beta2 ** state_step))
167 |                         p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
168 | p.data.copy_(p_data_fp32)
169 | else:
170 |                     # this experiment requires exactly 4 extra steps, hard coded
171 |                     state_step = state['step'] + 4 if self.additional_four else state['step']
172 |
173 |                     if group['weight_decay'] != 0:
174 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * 0.1)
175 |
176 |                     step_size = 0.1 / (1 - beta1 ** state_step)
177 |                     p_data_fp32.add_(exp_avg, alpha=-step_size)
178 | p.data.copy_(p_data_fp32)
179 |
180 | return loss
181 |
182 |
183 | class AdamW(Optimizer):
184 |
185 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
186 | weight_decay=0, use_variance=True, warmup=4000):
187 | defaults = dict(lr=lr, betas=betas, eps=eps,
188 | weight_decay=weight_decay, use_variance=True, warmup=warmup)
189 | print('======== Warmup: {} ========='.format(warmup))
190 | super(AdamW, self).__init__(params, defaults)
191 |
192 | def __setstate__(self, state):
193 | super(AdamW, self).__setstate__(state)
194 |
195 | def step(self, closure=None):
196 | #global iter_idx
197 | #siter_idx += 1
198 | grad_list = list()
199 | mom_list = list()
200 | mom_2rd_list = list()
201 |
202 | loss = None
203 | if closure is not None:
204 | loss = closure()
205 |
206 | for group in self.param_groups:
207 |
208 | for p in group['params']:
209 | if p.grad is None:
210 | continue
211 | grad = p.grad.data.float()
212 | if grad.is_sparse:
213 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
214 |
215 | p_data_fp32 = p.data.float()
216 |
217 | state = self.state[p]
218 |
219 | if len(state) == 0:
220 | state['step'] = 0
221 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
222 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
223 | else:
224 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
225 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
226 |
227 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
228 | beta1, beta2 = group['betas']
229 |
230 | state['step'] += 1
231 |
232 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
233 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
234 |
235 | denom = exp_avg_sq.sqrt().add_(group['eps'])
236 | bias_correction1 = 1 - beta1 ** state['step']
237 | bias_correction2 = 1 - beta2 ** state['step']
238 |
239 | if group['warmup'] > state['step']:
240 | scheduled_lr = 1e-6 + state['step'] * (group['lr'] - 1e-6) / group['warmup']
241 | else:
242 | scheduled_lr = group['lr']
243 |
244 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
245 | if group['weight_decay'] != 0:
246 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)
247 |
248 |                 p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
249 |
250 | p.data.copy_(p_data_fp32)
251 |
252 | return loss
--------------------------------------------------------------------------------
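
A minimal sketch of using the RAdam optimizer above as a drop-in replacement for Adam (the tiny model and data are illustrative only):

```
import torch
import torch.nn as nn
import torch.nn.functional as F

from hhl_utils.radam import RAdam

model = nn.Linear(10, 2)
optimizer = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = F.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```
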
/hhl_utils/ranger.py:
--------------------------------------------------------------------------------
1 | #Ranger deep learning optimizer - RAdam + Lookahead combined.
2 | #https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
3 |
4 | #Ranger has now been used to capture 12 records on the FastAI leaderboard.
5 |
6 | #This version = 9.3.19
7 |
8 | #Credits:
9 | #RAdam --> https://github.com/LiyuanLucasLiu/RAdam
10 | #Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
11 | #Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
12 |
13 | #summary of changes:
14 | #full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
15 | #supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
16 | #changes 8/31/19 - fix references to *self*.N_sma_threshold;
17 | #changed eps to 1e-5 as better default than 1e-8.
18 |
19 | import math
20 | import torch
21 | from torch.optim.optimizer import Optimizer#, required
22 | import itertools as it
23 |
24 |
25 |
26 | class Ranger(Optimizer):
27 |
28 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0):
29 | #parameter checks
30 | if not 0.0 <= alpha <= 1.0:
31 | raise ValueError(f'Invalid slow update rate: {alpha}')
32 | if not 1 <= k:
33 | raise ValueError(f'Invalid lookahead steps: {k}')
34 | if not lr > 0:
35 | raise ValueError(f'Invalid Learning Rate: {lr}')
36 | if not eps > 0:
37 | raise ValueError(f'Invalid eps: {eps}')
38 |
39 | #parameter comments:
40 | # beta1 (momentum) of .95 seems to work better than .90...
41 | #N_sma_threshold of 5 seems better in testing than 4.
42 | #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
43 |
44 | #prep defaults and init torch.optim base
45 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, eps=eps, weight_decay=weight_decay)
46 | super().__init__(params,defaults)
47 |
48 | #adjustable threshold
49 | self.N_sma_threshhold = N_sma_threshhold
50 |
51 | #now we can get to work...
52 | #removed as we now use step from RAdam...no need for duplicate step counting
53 | #for group in self.param_groups:
54 | # group["step_counter"] = 0
55 | #print("group step counter init")
56 |
57 | #look ahead params
58 | self.alpha = alpha
59 | self.k = k
60 |
61 | #radam buffer for state
62 | self.radam_buffer = [[None,None,None] for ind in range(10)]
63 |
64 | #self.first_run_check=0
65 |
66 | #lookahead weights
67 | #9/2/19 - lookahead param tensors have been moved to state storage.
68 | #This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
69 |
70 | #self.slow_weights = [[p.clone().detach() for p in group['params']]
71 | # for group in self.param_groups]
72 |
73 | #don't use grad for lookahead weights
74 | #for w in it.chain(*self.slow_weights):
75 | # w.requires_grad = False
76 |
77 | def __setstate__(self, state):
78 | print("set state called")
79 | super(Ranger, self).__setstate__(state)
80 |
81 |
82 | def step(self, closure=None):
83 | loss = None
84 | #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
85 | #Uncomment if you need to use the actual closure...
86 |
87 | #if closure is not None:
88 | #loss = closure()
89 |
90 | #Evaluate averages and grad, update param tensors
91 | for group in self.param_groups:
92 |
93 | for p in group['params']:
94 | if p.grad is None:
95 | continue
96 | grad = p.grad.data.float()
97 | if grad.is_sparse:
98 | raise RuntimeError('Ranger optimizer does not support sparse gradients')
99 |
100 | p_data_fp32 = p.data.float()
101 |
102 | state = self.state[p] #get state dict for this param
103 |
104 | if len(state) == 0: #if first time to run...init dictionary with our desired entries
105 | #if self.first_run_check==0:
106 | #self.first_run_check=1
107 | #print("Initializing slow buffer...should not see this at load from saved model!")
108 | state['step'] = 0
109 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
110 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
111 |
112 | #look ahead weight storage now in state dict
113 | state['slow_buffer'] = torch.empty_like(p.data)
114 | state['slow_buffer'].copy_(p.data)
115 |
116 | else:
117 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
118 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
119 |
120 | #begin computations
121 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
122 | beta1, beta2 = group['betas']
123 |
124 | #compute variance mov avg
125 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
126 |                 #compute mean moving avg
127 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
128 |
129 | state['step'] += 1
130 |
131 |
132 | buffered = self.radam_buffer[int(state['step'] % 10)]
133 | if state['step'] == buffered[0]:
134 | N_sma, step_size = buffered[1], buffered[2]
135 | else:
136 | buffered[0] = state['step']
137 | beta2_t = beta2 ** state['step']
138 | N_sma_max = 2 / (1 - beta2) - 1
139 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
140 | buffered[1] = N_sma
141 | if N_sma > self.N_sma_threshhold:
142 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
143 | else:
144 | step_size = 1.0 / (1 - beta1 ** state['step'])
145 | buffered[2] = step_size
146 |
147 | if group['weight_decay'] != 0:
148 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
149 |
150 |                 if N_sma > self.N_sma_threshhold:
151 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
152 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
153 |                 else:
154 |                     p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
155 |
156 | p.data.copy_(p_data_fp32)
157 |
158 | #integrated look ahead...
159 | #we do it at the param level instead of group level
160 | if state['step'] % group['k'] == 0:
161 | slow_p = state['slow_buffer'] #get access to slow param tensor
162 |                     slow_p.add_(p.data - slow_p, alpha=self.alpha)  #(fast weights - slow weights) * alpha
163 | p.data.copy_(slow_p) #copy interpolated weights to RAdam param tensor
164 |
165 | return loss
--------------------------------------------------------------------------------
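
A minimal sketch of using Ranger (illustrative model and data; note that this step() implementation ignores any closure argument, as the comment in the code explains):

```
import torch
import torch.nn as nn

from hhl_utils.ranger import Ranger

model = nn.Linear(10, 2)
# k is the lookahead interval; alpha interpolates the slow weights.
optimizer = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)

for _ in range(12):  # two full lookahead cycles with k=6
    loss = model(torch.randn(4, 10)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
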
/hhl_utils/torchsummary.py:
--------------------------------------------------------------------------------
1 | """
2 | A modified version of the code by Tae Hwan Jung
3 | https://github.com/graykode/modelsummary
4 | """
5 |
6 | import torch
7 | import numpy as np
8 | import torch.nn as nn
9 | from collections import OrderedDict
10 |
11 | def summary(model, input_shape, batch_size=-1, intputshow=True):
12 |
13 | def register_hook(module):
14 | def hook(module, input, output=None):
15 | class_name = str(module.__class__).split(".")[-1].split("'")[0]
16 | module_idx = len(summary)
17 |
18 | m_key = "%s-%i" % (class_name, module_idx + 1)
19 | summary[m_key] = OrderedDict()
20 | summary[m_key]["input_shape"] = list(input[0].size())
21 |             summary[m_key]["input_shape"][0] = batch_size
22 |             if output is not None: summary[m_key]["output_shape"] = [batch_size] + list(output.size())[1:]  # forward hooks record the output shape (was missing, breaking intputshow=False)
23 | params = 0
24 | if hasattr(module, "weight") and hasattr(module.weight, "size"):
25 | params += torch.prod(torch.LongTensor(list(module.weight.size())))
26 | summary[m_key]["trainable"] = module.weight.requires_grad
27 | if hasattr(module, "bias") and hasattr(module.bias, "size"):
28 | params += torch.prod(torch.LongTensor(list(module.bias.size())))
29 | summary[m_key]["nb_params"] = params
30 |
31 | if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)
32 | and not (module == model)) and 'torch' in str(module.__class__):
33 | if intputshow is True:
34 | hooks.append(module.register_forward_pre_hook(hook))
35 | else:
36 | hooks.append(module.register_forward_hook(hook))
37 |
38 | # create properties
39 | summary = OrderedDict()
40 | hooks = []
41 |
42 | # register hook
43 | model.apply(register_hook)
44 | model(torch.zeros(input_shape))
45 |
46 | # remove these hooks
47 | for h in hooks:
48 | h.remove()
49 |
50 | model_info = ''
51 |
52 | model_info += "-----------------------------------------------------------------------\n"
53 | line_new = "{:>25} {:>25} {:>15}".format("Layer (type)", "Input Shape", "Param #")
54 | model_info += line_new + '\n'
55 | model_info += "=======================================================================\n"
56 |
57 | total_params = 0
58 | total_output = 0
59 | trainable_params = 0
60 | for layer in summary:
61 | line_new = "{:>25} {:>25} {:>15}".format(
62 | layer,
63 | str(summary[layer]["input_shape"]),
64 | "{0:,}".format(summary[layer]["nb_params"]),
65 | )
66 |
67 | total_params += summary[layer]["nb_params"]
68 | if intputshow is True:
69 | total_output += np.prod(summary[layer]["input_shape"])
70 | else:
71 | total_output += np.prod(summary[layer]["output_shape"])
72 | if "trainable" in summary[layer]:
73 |             if summary[layer]["trainable"]:
74 | trainable_params += summary[layer]["nb_params"]
75 |
76 | model_info += line_new + '\n'
77 |
78 | model_info += "=======================================================================\n"
79 | model_info += "Total params: {0:,}\n".format(total_params)
80 | model_info += "Trainable params: {0:,}\n".format(trainable_params)
81 | model_info += "Non-trainable params: {0:,}\n".format(total_params - trainable_params)
82 | model_info += "-----------------------------------------------------------------------\n"
83 |
84 | return model_info
--------------------------------------------------------------------------------
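
A usage sketch for summary() (note that input_shape includes the batch dimension, since it is passed directly to torch.zeros()):

```
import torch.nn as nn

from hhl_utils.torchsummary import summary

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
)
print(summary(model, input_shape=(1, 3, 64, 64)))
```
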
/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import numpy as np
5 | import math
6 |
7 |
8 | # combined with cross entropy loss, instance level
9 | class LossVariance(nn.Module):
10 | """ The instances in target should be labeled
11 | """
12 |
13 | def __init__(self):
14 | super(LossVariance, self).__init__()
15 |
16 | def forward(self, input, target):
17 |
18 | B = input.size(0)
19 |
20 | loss = 0
21 | for k in range(B):
22 | unique_vals = target[k].unique()
23 | unique_vals = unique_vals[unique_vals != 0]
24 |
25 | sum_var = 0
26 | for val in unique_vals:
27 | instance = input[k][:, target[k] == val]
28 | if instance.size(1) > 1:
29 | sum_var += instance.var(dim=1).sum()
30 |
31 | loss += sum_var / (len(unique_vals) + 1e-8)
32 | loss /= B
33 | return loss
34 |
35 |
36 |
37 | class FocalLoss2d(nn.Module):
38 | def __init__(self, gamma=2, size_average=True, type="sigmoid"):
39 | super(FocalLoss2d, self).__init__()
40 | self.gamma = gamma
41 | self.size_average = size_average
42 | self.type = type
43 |
44 | def forward(self, logit, target, class_weight=None):
45 | target = target.view(-1, 1).long()
46 | if self.type == 'sigmoid':
47 | if class_weight is None:
48 | class_weight = [1]*2
49 |
50 |             prob = torch.sigmoid(logit)
51 | prob = prob.view(-1, 1)
52 | prob = torch.cat((1-prob, prob), 1)
53 | select = torch.FloatTensor(len(prob), 2).zero_().cuda()
54 | select.scatter_(1, target, 1.)
55 |
56 | elif self.type=='softmax':
57 | B,C,H,W = logit.size()
58 | if class_weight is None:
59 | class_weight =[1]*C
60 |
61 | logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
62 | prob = F.softmax(logit,1)
63 | select = torch.FloatTensor(len(prob), C).zero_().cuda()
64 | select.scatter_(1, target, 1.)
65 |
66 | class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1)
67 | class_weight = torch.gather(class_weight, 0, target)
68 |
69 | prob = (prob*select).sum(1).view(-1,1)
70 | prob = torch.clamp(prob,1e-8,1-1e-8)
71 | batch_loss = - class_weight *(torch.pow((1-prob), self.gamma))*prob.log()
72 |
73 | if self.size_average:
74 | loss = batch_loss.mean()
75 | else:
76 | loss = batch_loss
77 |
78 | return loss
79 |
80 | # Robust focal loss
81 | class RobustFocalLoss2d(nn.Module):
82 | #assume top 10% is outliers
83 | def __init__(self, gamma=2, size_average=True, type="sigmoid"):
84 | super(RobustFocalLoss2d, self).__init__()
85 | self.gamma = gamma
86 | self.size_average = size_average
87 | self.type = type
88 |
89 | def forward(self, logit, target, class_weight=None):
90 | target = target.view(-1, 1).long()
91 | if self.type=='sigmoid':
92 | if class_weight is None:
93 | class_weight = [1]*2
94 |
95 |             prob = torch.sigmoid(logit)
96 | prob = prob.view(-1, 1)
97 | prob = torch.cat((1-prob, prob), 1)
98 | select = torch.FloatTensor(len(prob), 2).zero_().cuda()
99 | select.scatter_(1, target, 1.)
100 |
101 | elif self.type=='softmax':
102 | B,C,H,W = logit.size()
103 | if class_weight is None:
104 | class_weight =[1]*C
105 |
106 | logit = logit.permute(0, 2, 3, 1).contiguous().view(-1, C)
107 | prob = F.softmax(logit,1)
108 | select = torch.FloatTensor(len(prob), C).zero_().cuda()
109 | select.scatter_(1, target, 1.)
110 |
111 | class_weight = torch.FloatTensor(class_weight).cuda().view(-1,1)
112 | class_weight = torch.gather(class_weight, 0, target)
113 |
114 | prob = (prob*select).sum(1).view(-1,1)
115 | prob = torch.clamp(prob,1e-8,1-1e-8)
116 |
117 | focus = torch.pow((1-prob), self.gamma)
118 | focus = torch.clamp(focus,0,2)
119 |
120 | batch_loss = - class_weight *focus*prob.log()
121 |
122 | if self.size_average:
123 | loss = batch_loss.mean()
124 | else:
125 | loss = batch_loss
126 |
127 | return loss
128 |
129 |
130 |
131 | class DiceLoss(nn.Module):
132 | def __init__(self):
133 | super(DiceLoss, self).__init__()
134 |
135 | def forward(self, input, target):
136 | N = target.size(0)
137 | smooth = 1
138 |
139 | input_flat = input.view(N, -1)
140 | target_flat = target.view(N, -1)
141 |
142 | intersection = input_flat * target_flat
143 |
144 | loss = 2 * (intersection.sum(1) + smooth) / (input_flat.sum(1) + target_flat.sum(1) + smooth)
145 | loss = 1 - loss.sum() / N
146 |
147 | return loss
148 |
149 |
150 | class MulticlassDiceLoss(nn.Module):
151 |     """
152 |     Requires a one-hot encoded target. Applies DiceLoss to each class iteratively.
153 |     Requires input.shape and target.shape to start with (N, C), where N is the
154 |     batch size and C is the number of classes.
155 |     """
156 |
157 | def __init__(self):
158 | super(MulticlassDiceLoss, self).__init__()
159 |
160 | def forward(self, input, target, weights=None):
161 |
162 | C = target.shape[1]
163 |
164 | # if weights is None:
165 | # weights = torch.ones(C) #uniform weights for all classes
166 |
167 | dice = DiceLoss()
168 | totalLoss = 0
169 |
170 | for i in range(C):
171 | diceLoss = dice(input[:, i], target[:, i])
172 | if weights is not None:
173 | diceLoss *= weights[i]
174 | totalLoss += diceLoss
175 |
176 | return totalLoss
177 |
178 |
179 |
180 |
181 | class Weight_DiceLoss(nn.Module):
182 | def __init__(self):
183 | super(Weight_DiceLoss, self).__init__()
184 |
185 | def forward(self, input, target, weights):
186 | N = target.size(0)
187 | smooth = 1
188 |
189 | input_flat = input.view(N, -1)
190 | target_flat = target.view(N, -1)
191 | weights = weights.view(N, -1)
192 |
193 | intersection = input_flat * target_flat
194 | intersection = intersection * weights
195 |
196 | dice = 2 * (intersection.sum(1) + smooth) / ((input_flat * weights).sum(1) + (target_flat * weights).sum(1) + smooth)
197 | loss = 1 - dice.sum() / N
198 |
199 | return loss
200 |
201 |
202 | class WeightMulticlassDiceLoss(nn.Module):
203 |     """
204 |     Requires a one-hot encoded target. Applies DiceLoss to each class iteratively.
205 |     Requires input.shape and target.shape to start with (N, C), where N is the
206 |     batch size and C is the number of classes.
207 |     """
208 |
209 | def __init__(self):
210 | super(WeightMulticlassDiceLoss, self).__init__()
211 |
212 | def forward(self, input, target, weights=None):
213 |
214 | C = target.shape[1]
215 |
216 | # if weights is None:
217 | # weights = torch.ones(C) #uniform weights for all classes
218 | # weights[0] = 3
219 | dice = DiceLoss()
220 | wdice = Weight_DiceLoss()
221 | totalLoss = 0
222 |
223 | for i in range(C):
224 | # diceLoss = dice(input[:, i], target[:, i])
225 | # diceLoss2 = 1 - wdice(input[:, i], target[:, i - 1])
226 | # diceLoss3 = 1 - wdice(input[:, i], target[:, i%(C-1) + 1])
227 | # diceLoss = diceLoss - diceLoss2 - diceLoss3
228 |
229 | # diceLoss = dice(input[:, i - 1] + input[:, i] + input[:, i%(C-1) + 1], target[:, i])
230 |
231 | if (i == 0):
232 | diceLoss = wdice(input[:, i], target[:, i], weights) * 2
233 | elif (i == 1):
234 | # diceLoss = dice(input[:, C - 1] + input[:, i] + input[:, i + 1], target[:, i])
235 | diceLoss = wdice(input[:, i], target[:, i], weights)
236 | diceLoss2 = 1 - wdice(input[:, i], target[:, C - 1], weights)
237 | diceLoss3 = 1 - wdice(input[:, i], target[:, i + 1], weights)
238 | diceLoss = diceLoss - diceLoss2 - diceLoss3
239 |
240 | elif (i == C - 1):
241 | # diceLoss = dice(input[:, i - 1] + input[:, i] + input[:, 1], target[:, i])
242 | diceLoss = wdice(input[:, i], target[:, i], weights)
243 | diceLoss2 = 1 - wdice(input[:, i], target[:, i - 1], weights)
244 | diceLoss3 = 1 - wdice(input[:, i], target[:, 1], weights)
245 | diceLoss = diceLoss - diceLoss2 - diceLoss3
246 |
247 | else:
248 | # diceLoss = dice(input[:, i - 1] + input[:, i] + input[:, i + 1], target[:, i])
249 | diceLoss = wdice(input[:, i], target[:, i], weights)
250 | diceLoss2 = 1 - wdice(input[:, i], target[:, i - 1], weights)
251 | diceLoss3 = 1 - wdice(input[:, i], target[:, i + 1], weights)
252 | diceLoss = diceLoss - diceLoss2 - diceLoss3
253 |
254 | #if weights is not None:
255 | #diceLoss *= weights[i]
256 |
257 | totalLoss += diceLoss
258 | avgLoss = totalLoss/C
259 |
260 | return avgLoss
261 |
262 |
263 |
264 |
265 |
266 | class CenterLoss(nn.Module):
267 | """Center loss.
268 |
269 | Reference:
270 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
271 |
272 | Args:
273 | num_classes (int): number of classes.
274 | feat_dim (int): feature dimension.
275 | """
276 |
277 | def __init__(self, num_classes=3, feat_dim=3, use_gpu=True):
278 | super(CenterLoss, self).__init__()
279 | self.num_classes = num_classes
280 | self.feat_dim = feat_dim
281 | self.use_gpu = use_gpu
282 |
283 | if self.use_gpu:
284 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
285 | print(self.centers)
286 | else:
287 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
288 |
289 | def forward(self, input_x, input_label):
290 | """
291 | Args:
292 | x: feature matrix with shape (batch_size, feat_dim).
293 | labels: ground truth labels with shape (batch_size).
294 | """
295 | labels = input_label
296 | batch_size = input_x.size(0)
297 | channels = input_x.size(1)
298 |
299 | distmat = torch.pow(input_x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
300 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
301 |         distmat.addmm_(input_x, self.centers.t(), beta=1, alpha=-2)  # out = beta * out + alpha * (input_x @ centers.T)
302 |
303 | classes = torch.arange(self.num_classes).long()
304 | if self.use_gpu: classes = classes.cuda()
305 | labels2 = input_label.unsqueeze(1).expand(batch_size, self.num_classes)
306 |         mask = labels2.cuda().eq(classes.expand(batch_size, self.num_classes))  # eq(): 1 where equal, 0 otherwise
307 |
308 | dist = distmat * mask.float()
309 |
310 |
311 |         # torch.clamp(input, min, max) clamps each element of input into [min, max] and returns a new tensor
312 |
313 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
314 |
315 | return loss
316 |
317 |
318 |
321 |
322 | def one_hot(label, n_classes, requires_grad=True):
323 | """Return One Hot Label"""
324 | device = label.device
325 | one_hot_label = torch.eye(n_classes, device=device, requires_grad=requires_grad)[label]
326 | one_hot_label = one_hot_label.transpose(1, 3).transpose(2, 3)
327 |
328 | return one_hot_label
329 |
330 |
331 | class BoundaryLoss(nn.Module):
332 | """Boundary Loss proposed in:
333 | Alexey Bokhovkin et al., Boundary Loss for Remote Sensing Imagery Semantic Segmentation
334 | https://arxiv.org/abs/1905.07852
335 | """
336 |
337 | def __init__(self, theta0=3, theta=5):
338 | super().__init__()
339 |
340 | self.theta0 = theta0
341 | self.theta = theta
342 |
343 | def forward(self, pred_output, gt):
344 | """
345 | Input:
346 | - pred_output: the output from model (before softmax)
347 | shape (N, C, H, W)
348 |             - gt: ground truth map, already one-hot encoded (originally (N, H, W))
349 |                     shape (N, C, H, W)
350 |         Return:
351 |             - boundary loss, averaged over mini-batch
352 | """
353 |
354 | n, c, _, _ = pred_output.shape
355 |
356 | # softmax so that predicted map can be distributed in [0, 1]
357 | pred = torch.softmax(pred_output, dim=1)
358 |
359 | # one-hot vector of ground truth
360 |         # one_hot_gt = one_hot(gt.long(), c)  # only needed when gt is (N, H, W); the current input is already one-hot
361 | one_hot_gt = gt
362 |
363 |
364 |
365 | # boundary map
366 | gt_b = F.max_pool2d(1 - one_hot_gt, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
367 | gt_b -= 1 - one_hot_gt
368 |
369 | pred_b = F.max_pool2d(1 - pred, kernel_size=self.theta0, stride=1, padding=(self.theta0 - 1) // 2)
370 | pred_b -= 1 - pred
371 |
372 | # extended boundary map
373 | gt_b_ext = F.max_pool2d(gt_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
374 |
375 | pred_b_ext = F.max_pool2d(pred_b, kernel_size=self.theta, stride=1, padding=(self.theta - 1) // 2)
376 |
377 | # reshape
378 | gt_b = gt_b.view(n, c, -1)
379 | pred_b = pred_b.view(n, c, -1)
380 | gt_b_ext = gt_b_ext.view(n, c, -1)
381 | pred_b_ext = pred_b_ext.view(n, c, -1)
382 |
383 | # Precision, Recall
384 | P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
385 | R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)
386 |
387 | # Boundary F1 Score
388 | BF1 = 2 * P * R / (P + R + 1e-7)
389 |
390 |         # average (1 - BF1) over classes and the mini-batch
391 | loss = torch.mean(1 - BF1)
392 |
393 | return loss
394 |
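# Minimal usage sketch (sizes are illustrative): gt must already be one-hot
# (N, C, H, W) to match the modified forward() above.
def _boundary_loss_demo():
    criterion = BoundaryLoss(theta0=3, theta=5)
    pred = torch.randn(2, 3, 64, 64)                           # raw logits
    gt = one_hot(torch.zeros(2, 64, 64, dtype=torch.long), 3)  # one-hot ground truth
    return criterion(pred, gt)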
395 |
396 |
397 |
398 |
399 | def dice_loss(input, target, eps=1e-7, if_sigmoid=True):
400 | if if_sigmoid:
401 |         input = torch.sigmoid(input)  # F.sigmoid is deprecated
402 | b = input.shape[0]
403 | iflat = input.contiguous().view(b, -1)
404 | tflat = target.float().contiguous().view(b, -1)
405 | intersection = (iflat * tflat).sum(dim=1)
406 | L = (1 - ((2. * intersection + eps) / (iflat.pow(2).sum(dim=1) + tflat.pow(2).sum(dim=1) + eps))).mean()
407 | return L
408 |
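# Worked example: a confidently correct prediction drives the dice loss toward 0.
def _dice_loss_demo():
    t = torch.ones(1, 1, 8, 8)
    return dice_loss(t * 100., t, if_sigmoid=True)  # sigmoid(100) ~ 1  =>  loss ~ 0
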
409 | def smooth_truncated_loss(p, t, ths=0.06, if_reduction=True, if_balance=True):
410 | n_log_pt = F.binary_cross_entropy_with_logits(p, t, reduction='none')
411 | pt = (-n_log_pt).exp()
412 | L = torch.where(pt>=ths, n_log_pt, -math.log(ths)+0.5*(1-pt.pow(2)/(ths**2)))
413 | if if_reduction:
414 | if if_balance:
415 | return 0.5*((L*t).sum()/t.sum().clamp(1) + (L*(1-t)).sum()/(1-t).sum().clamp(1))
416 | else:
417 | return L.mean()
418 | else:
419 | return L
420 |
421 | def balance_bce_loss(input, target):
422 | L0 = F.binary_cross_entropy_with_logits(input, target, reduction='none')
423 | return 0.5*((L0*target).sum()/target.sum().clamp(1)+(L0*(1-target)).sum()/(1-target).sum().clamp(1))
424 |
425 | def compute_loss_list(loss_func, pred=[], target=[], **kwargs):
426 | losses = []
427 | for ipred, itarget in zip(pred, target):
428 | losses.append(loss_func(ipred, itarget, **kwargs))
429 | return losses
430 |
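# Usage sketch for deep supervision: one loss function applied to predictions at
# several scales, then summed.
def _loss_list_demo():
    preds = [torch.randn(2, 1, s, s) for s in (64, 32)]
    gts = [torch.rand(2, 1, s, s).round() for s in (64, 32)]
    return sum(compute_loss_list(balance_bce_loss, pred=preds, target=gts))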
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
--------------------------------------------------------------------------------
/models/FullNet.py:
--------------------------------------------------------------------------------
1 | """
2 | This script defines the structure of FullNet
3 |
4 | Author: Hui Qu
5 | """
6 |
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import math
12 |
13 |
14 | class ConvLayer(nn.Sequential):
15 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
16 | groups=1):
17 | super(ConvLayer, self).__init__()
18 | self.add_module('conv', nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
19 | padding=padding, dilation=dilation, bias=False, groups=groups))
20 | self.add_module('relu', nn.LeakyReLU(inplace=True))
21 | self.add_module('bn', nn.BatchNorm2d(out_channels))
22 |
23 |
24 | # --- different types of layers --- #
25 | class BasicLayer(nn.Sequential):
26 | def __init__(self, in_channels, growth_rate, drop_rate, dilation=1):
27 | super(BasicLayer, self).__init__()
28 | self.conv = ConvLayer(in_channels, growth_rate, kernel_size=3, stride=1, padding=dilation,
29 | dilation=dilation)
30 | self.drop_rate = drop_rate
31 |
32 | def forward(self, x):
33 | out = self.conv(x)
34 | if self.drop_rate > 0:
35 | out = F.dropout(out, p=self.drop_rate, training=self.training)
36 | return torch.cat([x, out], 1)
37 |
38 |
39 | class BottleneckLayer(nn.Sequential):
40 | def __init__(self, in_channels, growth_rate, drop_rate, dilation=1):
41 | super(BottleneckLayer, self).__init__()
42 |
43 | inter_planes = growth_rate * 4
44 | self.conv1 = ConvLayer(in_channels, inter_planes, kernel_size=1, padding=0)
45 | self.conv2 = ConvLayer(inter_planes, growth_rate, kernel_size=3, padding=dilation, dilation=dilation)
46 | self.drop_rate = drop_rate
47 |
48 | def forward(self, x):
49 | out = self.conv2(self.conv1(x))
50 | if self.drop_rate > 0:
51 | out = F.dropout(out, p=self.drop_rate, training=self.training)
52 | return torch.cat([x, out], 1)
53 |
54 |
55 | # --- dense block structure --- #
56 | class DenseBlock(nn.Sequential):
57 | def __init__(self, in_channels, growth_rate, drop_rate, layer_type, dilations):
58 | super(DenseBlock, self).__init__()
59 | for i in range(len(dilations)):
60 | layer = layer_type(in_channels+i*growth_rate, growth_rate, drop_rate, dilations[i])
61 | self.add_module('denselayer{:d}'.format(i+1), layer)
62 |
63 |
64 | def choose_hybrid_dilations(n_layers, dilation_schedule, is_hybrid):
65 | import numpy as np
66 | # key: (dilation, n_layers)
67 | HD_dict = {(1, 4): [1, 1, 1, 1],
68 | (2, 4): [1, 2, 3, 2],
69 | (4, 4): [1, 2, 5, 9],
70 | (8, 4): [3, 7, 10, 13],
71 | (16, 4): [13, 15, 17, 19],
72 | (1, 6): [1, 1, 1, 1, 1, 1],
73 | (2, 6): [1, 2, 3, 1, 2, 3],
74 | (4, 6): [1, 2, 3, 5, 6, 7],
75 | (8, 6): [2, 5, 7, 9, 11, 14],
76 | (16, 6): [10, 13, 16, 17, 19, 21]}
77 |
78 | dilation_list = np.zeros((len(dilation_schedule), n_layers), dtype=np.int32)
79 |
80 | for i in range(len(dilation_schedule)):
81 | dilation = dilation_schedule[i]
82 | if is_hybrid:
83 | dilation_list[i] = HD_dict[(dilation, n_layers)]
84 | else:
85 | dilation_list[i] = [dilation for k in range(n_layers)]
86 |
87 | return dilation_list
88 |
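# Example: the default FullNet schedule expands each block's dilation into a
# hybrid per-layer list (avoiding the gridding artifact of repeated dilations).
def _hybrid_dilation_demo():
    d = choose_hybrid_dilations(n_layers=6, dilation_schedule=(1, 2, 4, 8, 16, 4, 1), is_hybrid=True)
    return d  # the dilation-8 row is [2, 5, 7, 9, 11, 14]; dilation-1 rows are all ones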
89 |
90 | class FullNet(nn.Module):
91 | def __init__(self, color_channels, output_channels=2, n_layers=6, growth_rate=24, compress_ratio=0.5,
92 | drop_rate=0.1, dilations=(1,2,4,8,16,4,1), is_hybrid=True, layer_type='basic'):
93 | super(FullNet, self).__init__()
94 | if layer_type == 'basic':
95 | layer_type = BasicLayer
96 | else:
97 | layer_type = BottleneckLayer
98 |
99 | # 1st conv before any dense block
100 | in_channels = 24
101 | self.conv1 = ConvLayer(color_channels, in_channels, kernel_size=3, padding=1)
102 |
103 | self.blocks = nn.Sequential()
104 | n_blocks = len(dilations)
105 |
106 | dilation_list = choose_hybrid_dilations(n_layers, dilations, is_hybrid)
107 |
108 |         for i in range(n_blocks):  # a 1x1 transition layer follows every dense block
109 | block = DenseBlock(in_channels, growth_rate, drop_rate, layer_type, dilation_list[i])
110 | self.blocks.add_module('block%d' % (i+1), block)
111 | num_trans_in = int(in_channels + n_layers * growth_rate)
112 | num_trans_out = int(math.floor(num_trans_in * compress_ratio))
113 | trans = ConvLayer(num_trans_in, num_trans_out, kernel_size=1, padding=0)
114 | self.blocks.add_module('trans%d' % (i+1), trans)
115 | in_channels = num_trans_out
116 | #print('block.size = ', block)
117 | #print('num_trans_in = ', num_trans_in, 'num_trans_out = ', num_trans_out)
118 | #print('trans.size = ', trans)
119 |
120 | # final conv
121 | self.conv2 = nn.Conv2d(in_channels, output_channels, kernel_size=3, stride=1,
122 | padding=1, bias=False)
123 | # initialization
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
127 | m.weight.data.normal_(0, math.sqrt(2. / n))
128 | elif isinstance(m, nn.BatchNorm2d):
129 | m.weight.data.fill_(1)
130 | m.bias.data.zero_()
131 | elif isinstance(m, nn.Linear):
132 | m.bias.data.zero_()
133 |
134 | def forward(self, x):
135 | out = self.conv1(x)
136 | out = self.blocks(out)
137 | out = self.conv2(out)
138 | return out
139 |
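# Minimal usage sketch (illustrative sizes): FullNet never downsamples, so the
# output spatial size equals the input size.
def _fullnet_demo():
    net = FullNet(color_channels=3, output_channels=2)
    return net(torch.randn(1, 3, 64, 64)).shape  # torch.Size([1, 2, 64, 64])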
140 |
141 | class FCN_pooling(nn.Module):
142 |     """Same structure as FullNet, except that pooling follows blocks 1-4
143 |     and upsampling follows blocks 5 and 6
144 | """
145 | def __init__(self, color_channels, output_channels=2, n_layers=6, growth_rate=24, compress_ratio=0.5,
146 | drop_rate=0.1, dilations=(1,2,4,8,16,4,1), is_hybrid=True, layer_type='basic'):
147 | super(FCN_pooling, self).__init__()
148 | if layer_type == 'basic':
149 | layer_type = BasicLayer
150 | else:
151 | layer_type = BottleneckLayer
152 |
153 | # 1st conv before any dense block
154 | in_channels = 24
155 | self.conv1 = ConvLayer(color_channels, in_channels, kernel_size=3, padding=1)
156 |
157 | self.blocks = nn.Sequential()
158 | n_blocks = len(dilations)
159 |
160 | dilation_list = choose_hybrid_dilations(n_layers, dilations, is_hybrid)
161 |
162 |         for i in range(n_blocks):
163 | block = DenseBlock(in_channels, growth_rate, drop_rate, layer_type, dilation_list[i])
164 | self.blocks.add_module('block{:d}'.format(i+1), block)
165 | num_trans_in = int(in_channels + n_layers * growth_rate)
166 | num_trans_out = int(math.floor(num_trans_in * compress_ratio))
167 | trans = ConvLayer(num_trans_in, num_trans_out, kernel_size=1, padding=0)
168 | self.blocks.add_module('trans{:d}'.format(i+1), trans)
169 | if i in range(0, 4):
170 | self.blocks.add_module('pool{:d}'.format(i+1), nn.MaxPool2d(kernel_size=2, stride=2))
171 | elif i in range(4, 6):
172 | self.blocks.add_module('upsample{:d}'.format(i + 1), nn.UpsamplingBilinear2d(scale_factor=4))
173 | in_channels = num_trans_out
174 |
175 | # final conv
176 | self.conv2 = nn.Conv2d(in_channels, output_channels, kernel_size=3, stride=1,
177 | padding=1, bias=False)
178 | # initialization
179 | for m in self.modules():
180 | if isinstance(m, nn.Conv2d):
181 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
182 | m.weight.data.normal_(0, math.sqrt(2. / n))
183 | elif isinstance(m, nn.BatchNorm2d):
184 | m.weight.data.fill_(1)
185 | m.bias.data.zero_()
186 | elif isinstance(m, nn.Linear):
187 | m.bias.data.zero_()
188 |
189 | def forward(self, x):
190 | out = self.conv1(x)
191 | out = self.blocks(out)
192 | out = self.conv2(out)
193 | return out
194 |
--------------------------------------------------------------------------------
/models/dam/model_unet_MandD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 |
7 |
8 | class revAttention(nn.Module): #sSE
9 | def __init__(self, in_channels):
10 | super().__init__()
11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
12 | self.norm = nn.Sigmoid()
13 |
14 | def forward(self, U, V):
15 |         q = self.Conv1x1(V)  # V:[bs,c,h,w] -> q:[bs,1,h,w]
16 | q = self.norm(q)
17 | return U * (1+q) #
18 |
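# Sketch of the attention: V is squeezed to a single-channel sigmoid gate q in
# (0, 1), and U is rescaled by (1 + q), amplifying attended positions by up to
# 2x without suppressing any.
def _rev_attention_demo():
    att = revAttention(in_channels=9)
    U = torch.randn(2, 3, 32, 32)  # features to re-weight
    V = torch.randn(2, 9, 32, 32)  # guidance (e.g. direction logits)
    return att(U, V).shape         # same shape as U: torch.Size([2, 3, 32, 32])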
19 |
20 |
21 |
22 | def get_backbone(name, pretrained=True):
23 |
24 | """ Loading backbone, defining names for skip-connections and encoder output. """
25 |
26 | # TODO: More backbones
27 |
28 | # loading backbone model
29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
30 | if name == 'resnet18':
31 | backbone = models.resnet18(pretrained=pretrained)
32 | elif name == 'resnet34':
33 | backbone = models.resnet34(pretrained=pretrained)
34 | elif name == 'resnet50':
35 | backbone = models.resnet50(pretrained=pretrained)
36 | elif name == 'resnet101':
37 | backbone = models.resnet101(pretrained=pretrained)
38 | elif name == 'resnet152':
39 | backbone = models.resnet152(pretrained=pretrained)
40 | elif name == 'vgg16_bn':
41 | backbone = models.vgg16_bn(pretrained=pretrained).features
42 | elif name == 'vgg19_bn':
43 | backbone = models.vgg19_bn(pretrained=pretrained).features
44 | # elif name == 'inception_v3':
45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
46 | elif name == 'densenet121':
47 | backbone = models.densenet121(pretrained=True).features
48 | elif name == 'densenet161':
49 | backbone = models.densenet161(pretrained=True).features
50 | elif name == 'densenet169':
51 | backbone = models.densenet169(pretrained=True).features
52 | elif name == 'densenet201':
53 | backbone = models.densenet201(pretrained=True).features
54 | elif name == 'unet_encoder':
55 | from unet_backbone import UnetEncoder
56 | backbone = UnetEncoder(3)
57 | else:
58 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
59 |
60 | # specifying skip feature and output names
61 | if name.startswith('resnet'):
62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
63 | backbone_output = 'layer4'
64 | elif name == 'vgg16_bn':
65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
66 | feature_names = ['5', '12', '22', '32', '42']
67 | backbone_output = '43'
68 | elif name == 'vgg19_bn':
69 | feature_names = ['5', '12', '25', '38', '51']
70 | backbone_output = '52'
71 | # elif name == 'inception_v3':
72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
73 | # backbone_output = 'Mixed_7c'
74 | elif name.startswith('densenet'):
75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
76 | backbone_output = 'denseblock4'
77 | elif name == 'unet_encoder':
78 | feature_names = ['module1', 'module2', 'module3', 'module4']
79 | backbone_output = 'module5'
80 | else:
81 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
82 |
83 | return backbone, feature_names, backbone_output
84 |
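# Usage sketch (pretrained=False avoids downloading weights):
def _backbone_demo():
    backbone, skip_names, out_name = get_backbone('resnet18', pretrained=False)
    return skip_names, out_name  # ([None, 'relu', 'layer1', 'layer2', 'layer3'], 'layer4')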
85 |
86 | class UpsampleBlock(nn.Module):
87 |
88 | # TODO: separate parametric and non-parametric classes?
89 | # TODO: skip connection concatenated OR added
90 |
91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
92 | super(UpsampleBlock, self).__init__()
93 |
94 | self.parametric = parametric
95 |         ch_out = ch_in // 2 if ch_out is None else ch_out  # integer channel count
96 |
97 | # first convolution: either transposed conv, or conv following the skip connection
98 | if parametric:
99 | # versions: kernel=4 padding=1, kernel=2 padding=0
100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
101 | stride=2, padding=1, output_padding=0, bias=(not use_bn))
102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
103 | else:
104 | self.up = None
105 | ch_in = ch_in + skip_in
106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
107 | stride=1, padding=1, bias=(not use_bn))
108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
109 |
110 | self.relu = nn.ReLU(inplace=True)
111 |
112 | # second convolution
113 | conv2_in = ch_out if not parametric else ch_out + skip_in
114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
115 | stride=1, padding=1, bias=(not use_bn))
116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
117 |
118 |     #def forward(self, x, skip_connection=1):  # an int default would crash .size() below
119 |     def forward(self, x, skip_connection=None):
120 |
121 |         x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
122 |                                                              align_corners=False)
123 | if self.parametric:
124 | x = self.bn1(x) if self.bn1 is not None else x
125 | x = self.relu(x)
126 |
127 | if skip_connection is not None:
128 |             # Padding in case the incoming volumes are of different sizes #hhl20200413add
129 | diffY = skip_connection.size()[2] - x.size()[2]
130 | diffX = skip_connection.size()[3] - x.size()[3]
131 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
132 | diffY // 2, diffY - diffY // 2))
133 |
134 | x = torch.cat([x, skip_connection], dim=1)
135 |
136 | if not self.parametric:
137 | x = self.conv1(x)
138 | x = self.bn1(x) if self.bn1 is not None else x
139 | x = self.relu(x)
140 | x = self.conv2(x)
141 | x = self.bn2(x) if self.bn2 is not None else x
142 | x = self.relu(x)
143 |
144 | return x
145 |
146 |
147 |
148 | def conv3x3(in_channels, out_channels, stride=1):
149 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
150 |
151 | class ResidualUnit(nn.Module):
152 | def __init__(self, in_channels, out_channels):
153 | super(ResidualUnit, self).__init__()
154 | self.conv1 = conv3x3(in_channels, out_channels, stride=1)
155 | self.bn1 = nn.BatchNorm2d(out_channels)
156 | self.relu1 = nn.ReLU(inplace=True)
157 | self.conv2 = conv3x3(out_channels, out_channels, stride=1)
158 | self.bn2 = nn.BatchNorm2d(out_channels)
159 | self.relu2 = nn.ReLU(inplace=True)
160 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
161 |
162 | def forward(self, x):
163 | residual = self.conv_1x1(x)
164 | out = self.conv1(x)
165 | out = self.bn1(out)
166 | out = self.relu1(out)
167 | out = self.conv2(out)
168 | out = self.bn2(out)
169 | out += residual
170 | out = self.relu2(out)
171 | return out
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | class Unet(nn.Module):
182 |
183 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
184 |
185 | def __init__(self,
186 | backbone_name='resnet50',
187 | pretrained=True,
188 | encoder_freeze=False,
189 | classes=21,
190 | decoder_filters=(256, 128, 64, 32, 16),
191 | parametric_upsampling=True,
192 | shortcut_features='default',
193 | decoder_use_batchnorm=True):
194 | super(Unet, self).__init__()
195 |
196 | self.backbone_name = backbone_name
197 |
198 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
199 | shortcut_chs, bb_out_chs = self.infer_skip_channels()
200 | if shortcut_features != 'default':
201 | self.shortcut_features = shortcut_features
202 |
203 | # build decoder part
204 | self.upsample_blocks = nn.ModuleList()
205 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections
206 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
207 | num_blocks = len(self.shortcut_features)
208 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
209 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out))
210 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
211 | skip_in=shortcut_chs[num_blocks-i-1],
212 | parametric=parametric_upsampling,
213 | use_bn=decoder_use_batchnorm))
214 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1))
215 |
216 | if encoder_freeze:
217 | self.freeze_encoder()
218 |
219 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later
220 |
221 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) #
222 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
223 |
224 |
225 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64)
226 | self.direction_feature = ResidualUnit(64, 64)
227 | self.point_feature = ResidualUnit(64, 64)
228 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1)
229 | self.directionAtt = revAttention(1)
230 | self.direction_conv = nn.Conv2d(64, 9, kernel_size=1)
231 | self.maskAtt = revAttention(9)
232 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1)
233 |
234 | self.residual = ResidualUnit(64, 64)
235 |
236 |
237 |
238 | def freeze_encoder(self):
239 |
240 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
241 |
242 | for param in self.backbone.parameters():
243 | param.requires_grad = False
244 |
245 | def forward(self, *input):
246 |
247 | """ Forward propagation in U-Net. """
248 |
249 | x, features = self.forward_backbone(*input)
250 |
251 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
252 | skip_features = features[skip_name]
253 | x = upsample_block(x, skip_features)
254 |
255 | x_F1 = self.mask_feature(x)
256 |
257 | x_F2 = self.direction_feature(x_F1)
258 |
259 | x_direction = self.direction_conv(x_F2)
260 |
261 | x_F1_mask = self.residual(x_F1)
262 | x_final_mask = self.mask_conv(x_F1_mask)
263 |
264 |
265 |
266 | return x_final_mask, x_direction
267 |
268 | def forward_backbone(self, x):
269 |
270 | """ Forward propagation in backbone encoder network. """
271 |
272 | features = {None: None} if None in self.shortcut_features else dict()
273 | for name, child in self.backbone.named_children():
274 |
275 | if(name == '0' and x.shape[1] !=3):
276 | x = self.child0(x)
277 | elif(name == 'conv1' and x.shape[1] !=3):
278 | x = self.child_conv1(x)
279 | else:
280 | x = child(x)
281 | #x = child(x)
282 | if name in self.shortcut_features:
283 | features[name] = x
284 | if name == self.bb_out_name:
285 | break
286 |
287 | return x, features
288 |
289 | def infer_skip_channels(self):
290 |
291 | """ Getting the number of channels at skip connections and at the output of the encoder. """
292 |
293 | x = torch.zeros(1, 3, 224, 224)
294 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
295 |         channels = [] if has_fullres_features else [0]  # only VGG and unet_encoder keep full-resolution features
296 |
297 | # forward run in backbone to count channels (dirty solution but works for *any* Module)
298 | for name, child in self.backbone.named_children():
299 | x = child(x)
300 | if name in self.shortcut_features:
301 | channels.append(x.shape[1])
302 | if name == self.bb_out_name:
303 | out_channels = x.shape[1]
304 | break
305 | return channels, out_channels
306 |
307 | def get_pretrained_parameters(self):
308 | for name, param in self.backbone.named_parameters():
309 | if not (self.replaced_conv1 and name == 'conv1.weight'):
310 | yield param
311 |
312 | def get_random_initialized_parameters(self):
313 | pretrained_param_names = set()
314 | for name, param in self.backbone.named_parameters():
315 | if not (self.replaced_conv1 and name == 'conv1.weight'):
316 | pretrained_param_names.add('backbone.{}'.format(name))
317 |
318 | for name, param in self.named_parameters():
319 | if name not in pretrained_param_names:
320 | yield param
321 |
322 |
323 | # if __name__ == "__main__":
324 |
325 | # # simple test run
326 | # net = Unet(backbone_name='resnet18')
327 |
328 | # criterion = nn.MSELoss()
329 | # optimizer = torch.optim.Adam(net.parameters())
330 | # print('Network initialized. Running a test batch.')
331 | # for _ in range(1):
332 | # with torch.set_grad_enabled(True):
333 | # batch = torch.empty(1, 3, 224, 224).normal_()
334 | # targets = torch.empty(1, 21, 224, 224).normal_()
335 |
336 | # out = net(batch)
337 | # loss = criterion(out, targets)
338 | # loss.backward()
339 | # optimizer.step()
340 | # print(out.shape)
341 |
342 | # print('fasza.')
343 |
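# Minimal end-to-end sketch for this mask-and-direction U-Net (untrained weights):
def _unet_mandd_demo():
    net = Unet(backbone_name='resnet18', pretrained=False)
    mask, direction = net(torch.randn(1, 3, 224, 224))
    return mask.shape, direction.shape  # (1, 3, 224, 224) mask logits, (1, 9, 224, 224) direction logits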
--------------------------------------------------------------------------------
/models/dam/model_unet_MandD16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 |
7 |
8 |
9 | class revAttention(nn.Module): #sSE
10 | def __init__(self, in_channels):
11 | super().__init__()
12 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
13 | self.norm = nn.Sigmoid()
14 |
15 | def forward(self, U, V):
16 |         q = self.Conv1x1(V)  # V:[bs,c,h,w] -> q:[bs,1,h,w]
17 | q = self.norm(q)
18 | return U * (1+q) #
19 |
20 |
21 |
22 |
23 | def get_backbone(name, pretrained=True):
24 |
25 | """ Loading backbone, defining names for skip-connections and encoder output. """
26 |
27 | # TODO: More backbones
28 |
29 | # loading backbone model
30 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
31 | if name == 'resnet18':
32 | backbone = models.resnet18(pretrained=pretrained)
33 | elif name == 'resnet34':
34 | backbone = models.resnet34(pretrained=pretrained)
35 | elif name == 'resnet50':
36 | backbone = models.resnet50(pretrained=pretrained)
37 | elif name == 'resnet101':
38 | backbone = models.resnet101(pretrained=pretrained)
39 | elif name == 'resnet152':
40 | backbone = models.resnet152(pretrained=pretrained)
41 | elif name == 'vgg16_bn':
42 | backbone = models.vgg16_bn(pretrained=pretrained).features
43 | elif name == 'vgg19_bn':
44 | backbone = models.vgg19_bn(pretrained=pretrained).features
45 | # elif name == 'inception_v3':
46 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
47 | elif name == 'densenet121':
48 | backbone = models.densenet121(pretrained=True).features
49 | elif name == 'densenet161':
50 | backbone = models.densenet161(pretrained=True).features
51 | elif name == 'densenet169':
52 | backbone = models.densenet169(pretrained=True).features
53 | elif name == 'densenet201':
54 | backbone = models.densenet201(pretrained=True).features
55 | elif name == 'unet_encoder':
56 | from unet_backbone import UnetEncoder
57 | backbone = UnetEncoder(3)
58 | else:
59 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
60 |
61 | # specifying skip feature and output names
62 | if name.startswith('resnet'):
63 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
64 | backbone_output = 'layer4'
65 | elif name == 'vgg16_bn':
66 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
67 | feature_names = ['5', '12', '22', '32', '42']
68 | backbone_output = '43'
69 | elif name == 'vgg19_bn':
70 | feature_names = ['5', '12', '25', '38', '51']
71 | backbone_output = '52'
72 | # elif name == 'inception_v3':
73 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
74 | # backbone_output = 'Mixed_7c'
75 | elif name.startswith('densenet'):
76 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
77 | backbone_output = 'denseblock4'
78 | elif name == 'unet_encoder':
79 | feature_names = ['module1', 'module2', 'module3', 'module4']
80 | backbone_output = 'module5'
81 | else:
82 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
83 |
84 | return backbone, feature_names, backbone_output
85 |
86 |
87 | class UpsampleBlock(nn.Module):
88 |
89 | # TODO: separate parametric and non-parametric classes?
90 | # TODO: skip connection concatenated OR added
91 |
92 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
93 | super(UpsampleBlock, self).__init__()
94 |
95 | self.parametric = parametric
96 |         ch_out = ch_in // 2 if ch_out is None else ch_out  # integer channel count
97 |
98 | # first convolution: either transposed conv, or conv following the skip connection
99 | if parametric:
100 | # versions: kernel=4 padding=1, kernel=2 padding=0
101 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
102 | stride=2, padding=1, output_padding=0, bias=(not use_bn))
103 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
104 | else:
105 | self.up = None
106 | ch_in = ch_in + skip_in
107 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
108 | stride=1, padding=1, bias=(not use_bn))
109 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
110 |
111 | self.relu = nn.ReLU(inplace=True)
112 |
113 | # second convolution
114 | conv2_in = ch_out if not parametric else ch_out + skip_in
115 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
116 | stride=1, padding=1, bias=(not use_bn))
117 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
118 |
119 |     #def forward(self, x, skip_connection=1):  # an int default would crash .size() below
120 |     def forward(self, x, skip_connection=None):
121 |
122 |         x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
123 |                                                              align_corners=False)
124 | if self.parametric:
125 | x = self.bn1(x) if self.bn1 is not None else x
126 | x = self.relu(x)
127 |
128 | if skip_connection is not None:
129 |             # Padding in case the incoming volumes are of different sizes
130 | diffY = skip_connection.size()[2] - x.size()[2]
131 | diffX = skip_connection.size()[3] - x.size()[3]
132 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
133 | diffY // 2, diffY - diffY // 2))
134 |
135 | x = torch.cat([x, skip_connection], dim=1)
136 |
137 | if not self.parametric:
138 | x = self.conv1(x)
139 | x = self.bn1(x) if self.bn1 is not None else x
140 | x = self.relu(x)
141 | x = self.conv2(x)
142 | x = self.bn2(x) if self.bn2 is not None else x
143 | x = self.relu(x)
144 |
145 | return x
146 |
147 |
148 |
149 | def conv3x3(in_channels, out_channels, stride=1):
150 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
151 |
152 | class ResidualUnit(nn.Module):
153 | def __init__(self, in_channels, out_channels):
154 | super(ResidualUnit, self).__init__()
155 | self.conv1 = conv3x3(in_channels, out_channels, stride=1)
156 | self.bn1 = nn.BatchNorm2d(out_channels)
157 | self.relu1 = nn.ReLU(inplace=True)
158 | self.conv2 = conv3x3(out_channels, out_channels, stride=1)
159 | self.bn2 = nn.BatchNorm2d(out_channels)
160 | self.relu2 = nn.ReLU(inplace=True)
161 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
162 |
163 | def forward(self, x):
164 | residual = self.conv_1x1(x)
165 | out = self.conv1(x)
166 | out = self.bn1(out)
167 | out = self.relu1(out)
168 | out = self.conv2(out)
169 | out = self.bn2(out)
170 | out += residual
171 | out = self.relu2(out)
172 | return out
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 | class Unet(nn.Module):
183 |
184 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
185 |
186 | def __init__(self,
187 | backbone_name='resnet50',
188 | pretrained=True,
189 | encoder_freeze=False,
190 | classes=21,
191 | decoder_filters=(256, 128, 64, 32, 16),
192 | parametric_upsampling=True,
193 | shortcut_features='default',
194 | decoder_use_batchnorm=True):
195 | super(Unet, self).__init__()
196 |
197 | self.backbone_name = backbone_name
198 |
199 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
200 | shortcut_chs, bb_out_chs = self.infer_skip_channels()
201 | if shortcut_features != 'default':
202 | self.shortcut_features = shortcut_features
203 |
204 | # build decoder part
205 | self.upsample_blocks = nn.ModuleList()
206 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections
207 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
208 | num_blocks = len(self.shortcut_features)
209 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
210 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out))
211 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
212 | skip_in=shortcut_chs[num_blocks-i-1],
213 | parametric=parametric_upsampling,
214 | use_bn=decoder_use_batchnorm))
215 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1))
216 |
217 | if encoder_freeze:
218 | self.freeze_encoder()
219 |
220 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later
221 |         #hhl20210611add: replaces the first conv for 1-channel inputs when the child module is '0'
222 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) #
223 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
224 |
225 |
226 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64)
227 | self.direction_feature = ResidualUnit(64, 64)
228 | self.point_feature = ResidualUnit(64, 64)
229 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1)
230 | self.directionAtt = revAttention(1)
231 | self.direction_conv = nn.Conv2d(64, 16+1, kernel_size=1)
232 | self.maskAtt = revAttention(16+1)
233 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1)
234 |
235 | self.residual = ResidualUnit(64, 64)
236 |
237 |
238 |
239 | def freeze_encoder(self):
240 |
241 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
242 |
243 | for param in self.backbone.parameters():
244 | param.requires_grad = False
245 |
246 | def forward(self, *input):
247 |
248 | """ Forward propagation in U-Net. """
249 |
250 | x, features = self.forward_backbone(*input)
251 |
252 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
253 | skip_features = features[skip_name]
254 | x = upsample_block(x, skip_features)
255 |
256 | x_F1 = self.mask_feature(x)
257 |
258 | x_F2 = self.direction_feature(x_F1)
259 |
260 | x_direction = self.direction_conv(x_F2)
261 |
262 | x_F1_mask = self.residual(x_F1)
263 | x_final_mask = self.mask_conv(x_F1_mask)
264 |
265 |
266 |
267 | return x_final_mask, x_direction
268 |
269 | def forward_backbone(self, x):
270 |
271 | """ Forward propagation in backbone encoder network. """
272 |
273 | features = {None: None} if None in self.shortcut_features else dict()
274 | for name, child in self.backbone.named_children():
275 |             # hhl20210611add: handle the x.shape[1] == 1 case
276 | if(name == '0' and x.shape[1] !=3):
277 | x = self.child0(x)
278 | elif(name == 'conv1' and x.shape[1] !=3):
279 | x = self.child_conv1(x)
280 | else:
281 | x = child(x)
282 | #x = child(x)
283 | if name in self.shortcut_features:
284 | features[name] = x
285 | if name == self.bb_out_name:
286 | break
287 |
288 | return x, features
289 |
290 | def infer_skip_channels(self):
291 |
292 | """ Getting the number of channels at skip connections and at the output of the encoder. """
293 |
294 | x = torch.zeros(1, 3, 224, 224)
295 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
296 |         channels = [] if has_fullres_features else [0]  # only VGG and unet_encoder keep full-resolution features
297 |
298 | # forward run in backbone to count channels (dirty solution but works for *any* Module)
299 | for name, child in self.backbone.named_children():
300 | x = child(x)
301 | if name in self.shortcut_features:
302 | channels.append(x.shape[1])
303 | if name == self.bb_out_name:
304 | out_channels = x.shape[1]
305 | break
306 | return channels, out_channels
307 |
308 | def get_pretrained_parameters(self):
309 | for name, param in self.backbone.named_parameters():
310 | if not (self.replaced_conv1 and name == 'conv1.weight'):
311 | yield param
312 |
313 | def get_random_initialized_parameters(self):
314 | pretrained_param_names = set()
315 | for name, param in self.backbone.named_parameters():
316 | if not (self.replaced_conv1 and name == 'conv1.weight'):
317 | pretrained_param_names.add('backbone.{}'.format(name))
318 |
319 | for name, param in self.named_parameters():
320 | if name not in pretrained_param_names:
321 | yield param
322 |
323 |
324 | # if __name__ == "__main__":
325 |
326 | # # simple test run
327 | # net = Unet(backbone_name='resnet18')
328 |
329 | # criterion = nn.MSELoss()
330 | # optimizer = torch.optim.Adam(net.parameters())
331 | # print('Network initialized. Running a test batch.')
332 | # for _ in range(1):
333 | # with torch.set_grad_enabled(True):
334 | # batch = torch.empty(1, 3, 224, 224).normal_()
335 | # targets = torch.empty(1, 21, 224, 224).normal_()
336 |
337 | # out = net(batch)
338 | # loss = criterion(out, targets)
339 | # loss.backward()
340 | # optimizer.step()
341 | # print(out.shape)
342 |
343 | # print('fasza.')
344 |
--------------------------------------------------------------------------------
/models/dam/model_unet_MandD4.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 |
7 |
8 | class revAttention(nn.Module): #sSE
9 | def __init__(self, in_channels):
10 | super().__init__()
11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
12 | self.norm = nn.Sigmoid()
13 |
14 | def forward(self, U, V):
15 |         q = self.Conv1x1(V)  # V:[bs,c,h,w] -> q:[bs,1,h,w]
16 | q = self.norm(q)
17 | return U * (1+q)
18 |
19 |
20 |
21 |
22 | def get_backbone(name, pretrained=True):
23 |
24 | """ Loading backbone, defining names for skip-connections and encoder output. """
25 |
26 | # TODO: More backbones
27 |
28 | # loading backbone model
29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
30 | if name == 'resnet18':
31 | backbone = models.resnet18(pretrained=pretrained)
32 | elif name == 'resnet34':
33 | backbone = models.resnet34(pretrained=pretrained)
34 | elif name == 'resnet50':
35 | backbone = models.resnet50(pretrained=pretrained)
36 | elif name == 'resnet101':
37 | backbone = models.resnet101(pretrained=pretrained)
38 | elif name == 'resnet152':
39 | backbone = models.resnet152(pretrained=pretrained)
40 | elif name == 'vgg16_bn':
41 | backbone = models.vgg16_bn(pretrained=pretrained).features
42 | elif name == 'vgg19_bn':
43 | backbone = models.vgg19_bn(pretrained=pretrained).features
44 | # elif name == 'inception_v3':
45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
46 | elif name == 'densenet121':
47 | backbone = models.densenet121(pretrained=True).features
48 | elif name == 'densenet161':
49 | backbone = models.densenet161(pretrained=True).features
50 | elif name == 'densenet169':
51 | backbone = models.densenet169(pretrained=True).features
52 | elif name == 'densenet201':
53 | backbone = models.densenet201(pretrained=True).features
54 | elif name == 'unet_encoder':
55 | from unet_backbone import UnetEncoder
56 | backbone = UnetEncoder(3)
57 | else:
58 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
59 |
60 | # specifying skip feature and output names
61 | if name.startswith('resnet'):
62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
63 | backbone_output = 'layer4'
64 | elif name == 'vgg16_bn':
65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
66 | feature_names = ['5', '12', '22', '32', '42']
67 | backbone_output = '43'
68 | elif name == 'vgg19_bn':
69 | feature_names = ['5', '12', '25', '38', '51']
70 | backbone_output = '52'
71 | # elif name == 'inception_v3':
72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
73 | # backbone_output = 'Mixed_7c'
74 | elif name.startswith('densenet'):
75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
76 | backbone_output = 'denseblock4'
77 | elif name == 'unet_encoder':
78 | feature_names = ['module1', 'module2', 'module3', 'module4']
79 | backbone_output = 'module5'
80 | else:
81 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
82 |
83 | return backbone, feature_names, backbone_output
84 |
85 |
86 | class UpsampleBlock(nn.Module):
87 |
88 | # TODO: separate parametric and non-parametric classes?
89 | # TODO: skip connection concatenated OR added
90 |
91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
92 | super(UpsampleBlock, self).__init__()
93 |
94 | self.parametric = parametric
95 |         ch_out = ch_in // 2 if ch_out is None else ch_out  # integer channel count
96 |
97 | # first convolution: either transposed conv, or conv following the skip connection
98 | if parametric:
99 | # versions: kernel=4 padding=1, kernel=2 padding=0
100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
101 | stride=2, padding=1, output_padding=0, bias=(not use_bn))
102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
103 | else:
104 | self.up = None
105 | ch_in = ch_in + skip_in
106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
107 | stride=1, padding=1, bias=(not use_bn))
108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
109 |
110 | self.relu = nn.ReLU(inplace=True)
111 |
112 | # second convolution
113 | conv2_in = ch_out if not parametric else ch_out + skip_in
114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
115 | stride=1, padding=1, bias=(not use_bn))
116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
117 |
118 |     #def forward(self, x, skip_connection=1):  # an int default would crash .size() below
119 |     def forward(self, x, skip_connection=None):
120 |
121 |         x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
122 |                                                              align_corners=False)
123 | if self.parametric:
124 | x = self.bn1(x) if self.bn1 is not None else x
125 | x = self.relu(x)
126 |
127 | if skip_connection is not None:
128 |             # Padding in case the incoming volumes are of different sizes #hhl20200413add
129 | diffY = skip_connection.size()[2] - x.size()[2]
130 | diffX = skip_connection.size()[3] - x.size()[3]
131 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
132 | diffY // 2, diffY - diffY // 2))
133 |
134 | x = torch.cat([x, skip_connection], dim=1)
135 |
136 | if not self.parametric:
137 | x = self.conv1(x)
138 | x = self.bn1(x) if self.bn1 is not None else x
139 | x = self.relu(x)
140 | x = self.conv2(x)
141 | x = self.bn2(x) if self.bn2 is not None else x
142 | x = self.relu(x)
143 |
144 | return x
145 |
146 |
147 |
148 | def conv3x3(in_channels, out_channels, stride=1):
149 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
150 |
151 | class ResidualUnit(nn.Module):
152 | def __init__(self, in_channels, out_channels):
153 | super(ResidualUnit, self).__init__()
154 | self.conv1 = conv3x3(in_channels, out_channels, stride=1)
155 | self.bn1 = nn.BatchNorm2d(out_channels)
156 | self.relu1 = nn.ReLU(inplace=True)
157 | self.conv2 = conv3x3(out_channels, out_channels, stride=1)
158 | self.bn2 = nn.BatchNorm2d(out_channels)
159 | self.relu2 = nn.ReLU(inplace=True)
160 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
161 |
162 | def forward(self, x):
163 | residual = self.conv_1x1(x)
164 | out = self.conv1(x)
165 | out = self.bn1(out)
166 | out = self.relu1(out)
167 | out = self.conv2(out)
168 | out = self.bn2(out)
169 | out += residual
170 | out = self.relu2(out)
171 | return out
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | class Unet(nn.Module):
182 |
183 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
184 |
185 | def __init__(self,
186 | backbone_name='resnet50',
187 | pretrained=True,
188 | encoder_freeze=False,
189 | classes=21,
190 | decoder_filters=(256, 128, 64, 32, 16),
191 | parametric_upsampling=True,
192 | shortcut_features='default',
193 | decoder_use_batchnorm=True):
194 | super(Unet, self).__init__()
195 |
196 | self.backbone_name = backbone_name
197 |
198 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
199 | shortcut_chs, bb_out_chs = self.infer_skip_channels()
200 | if shortcut_features != 'default':
201 | self.shortcut_features = shortcut_features
202 |
203 | # build decoder part
204 | self.upsample_blocks = nn.ModuleList()
205 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections
206 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
207 | num_blocks = len(self.shortcut_features)
208 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
209 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out))
210 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
211 | skip_in=shortcut_chs[num_blocks-i-1],
212 | parametric=parametric_upsampling,
213 | use_bn=decoder_use_batchnorm))
214 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1))
215 |
216 | if encoder_freeze:
217 | self.freeze_encoder()
218 |
219 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later
220 |
221 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) #
222 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
223 |
224 |
225 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64)
226 | self.direction_feature = ResidualUnit(64, 64)
227 | self.point_feature = ResidualUnit(64, 64)
228 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1)
229 | self.directionAtt = revAttention(1)
230 | self.direction_conv = nn.Conv2d(64, 4+1, kernel_size=1)
231 | self.maskAtt = revAttention(4+1)
232 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1)
233 |
234 | self.residual = ResidualUnit(64, 64)
235 |
236 |
237 |
238 | def freeze_encoder(self):
239 |
240 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
241 |
242 | for param in self.backbone.parameters():
243 | param.requires_grad = False
244 |
245 | def forward(self, *input):
246 |
247 | """ Forward propagation in U-Net. """
248 |
249 | x, features = self.forward_backbone(*input)
250 |
251 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
252 | skip_features = features[skip_name]
253 | x = upsample_block(x, skip_features)
254 |
255 | x_F1 = self.mask_feature(x)
256 |
257 | x_F2 = self.direction_feature(x_F1)
258 |
259 | x_direction = self.direction_conv(x_F2)
260 |
261 | x_F1_mask = self.residual(x_F1)
262 | x_final_mask = self.mask_conv(x_F1_mask)
263 |
264 |
265 |
266 | return x_final_mask, x_direction
267 |
268 | def forward_backbone(self, x):
269 |
270 | """ Forward propagation in backbone encoder network. """
271 |
272 | features = {None: None} if None in self.shortcut_features else dict()
273 | for name, child in self.backbone.named_children():
274 |
275 | if(name == '0' and x.shape[1] !=3):
276 | x = self.child0(x)
277 | elif(name == 'conv1' and x.shape[1] !=3):
278 | x = self.child_conv1(x)
279 | else:
280 | x = child(x)
281 | #x = child(x)
282 | if name in self.shortcut_features:
283 | features[name] = x
284 | if name == self.bb_out_name:
285 | break
286 |
287 | return x, features
288 |
289 | def infer_skip_channels(self):
290 |
291 | """ Getting the number of channels at skip connections and at the output of the encoder. """
292 |
293 | x = torch.zeros(1, 3, 224, 224)
294 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
295 |         channels = [] if has_fullres_features else [0]  # only VGG and unet_encoder keep full-resolution features
296 |
297 | # forward run in backbone to count channels (dirty solution but works for *any* Module)
298 | for name, child in self.backbone.named_children():
299 | x = child(x)
300 | if name in self.shortcut_features:
301 | channels.append(x.shape[1])
302 | if name == self.bb_out_name:
303 | out_channels = x.shape[1]
304 | break
305 | return channels, out_channels
306 |
307 | def get_pretrained_parameters(self):
308 | for name, param in self.backbone.named_parameters():
309 | if not (self.replaced_conv1 and name == 'conv1.weight'):
310 | yield param
311 |
312 | def get_random_initialized_parameters(self):
313 | pretrained_param_names = set()
314 | for name, param in self.backbone.named_parameters():
315 | if not (self.replaced_conv1 and name == 'conv1.weight'):
316 | pretrained_param_names.add('backbone.{}'.format(name))
317 |
318 | for name, param in self.named_parameters():
319 | if name not in pretrained_param_names:
320 | yield param
321 |
322 |
323 | # if __name__ == "__main__":
324 |
325 | # # simple test run
326 | # net = Unet(backbone_name='resnet18')
327 |
328 | # criterion = nn.MSELoss()
329 | # optimizer = torch.optim.Adam(net.parameters())
330 | # print('Network initialized. Running a test batch.')
331 | # for _ in range(1):
332 | # with torch.set_grad_enabled(True):
333 | # batch = torch.empty(1, 3, 224, 224).normal_()
334 | # targets = torch.empty(1, 21, 224, 224).normal_()
335 |
336 | # out = net(batch)
337 | # loss = criterion(out, targets)
338 | # loss.backward()
339 | # optimizer.step()
340 | # print(out.shape)
341 |
342 | # print('fasza.')
343 |
--------------------------------------------------------------------------------
/models/dam/model_unet_MandDandP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 |
7 |
8 | class revAttention(nn.Module): #sSE
9 | def __init__(self, in_channels):
10 | super().__init__()
11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
12 | self.norm = nn.Sigmoid()
13 |
14 | def forward(self, U, V):
15 |         q = self.Conv1x1(V)  # V:[bs,c,h,w] -> q:[bs,1,h,w]
16 | q = self.norm(q)
17 | return U * (1+q) #
18 |
19 |
20 |
21 |
22 | def get_backbone(name, pretrained=True):
23 |
24 | """ Loading backbone, defining names for skip-connections and encoder output. """
25 |
26 | # TODO: More backbones
27 |
28 | # loading backbone model
29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
30 | if name == 'resnet18':
31 | backbone = models.resnet18(pretrained=pretrained)
32 | elif name == 'resnet34':
33 | backbone = models.resnet34(pretrained=pretrained)
34 | elif name == 'resnet50':
35 | backbone = models.resnet50(pretrained=pretrained)
36 | elif name == 'resnet101':
37 | backbone = models.resnet101(pretrained=pretrained)
38 | elif name == 'resnet152':
39 | backbone = models.resnet152(pretrained=pretrained)
40 | elif name == 'vgg16_bn':
41 | backbone = models.vgg16_bn(pretrained=pretrained).features
42 | elif name == 'vgg19_bn':
43 | backbone = models.vgg19_bn(pretrained=pretrained).features
44 | # elif name == 'inception_v3':
45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
46 | elif name == 'densenet121':
47 | backbone = models.densenet121(pretrained=True).features
48 | elif name == 'densenet161':
49 | backbone = models.densenet161(pretrained=True).features
50 | elif name == 'densenet169':
51 | backbone = models.densenet169(pretrained=True).features
52 | elif name == 'densenet201':
53 | backbone = models.densenet201(pretrained=True).features
54 | elif name == 'unet_encoder':
55 | from unet_backbone import UnetEncoder
56 | backbone = UnetEncoder(3)
57 | else:
58 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
59 |
60 | # specifying skip feature and output names
61 | if name.startswith('resnet'):
62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
63 | backbone_output = 'layer4'
64 | elif name == 'vgg16_bn':
65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
66 | feature_names = ['5', '12', '22', '32', '42']
67 | backbone_output = '43'
68 | elif name == 'vgg19_bn':
69 | feature_names = ['5', '12', '25', '38', '51']
70 | backbone_output = '52'
71 | # elif name == 'inception_v3':
72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
73 | # backbone_output = 'Mixed_7c'
74 | elif name.startswith('densenet'):
75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
76 | backbone_output = 'denseblock4'
77 | elif name == 'unet_encoder':
78 | feature_names = ['module1', 'module2', 'module3', 'module4']
79 | backbone_output = 'module5'
80 | else:
81 |         raise NotImplementedError('{} backbone model is not implemented so far.'.format(name))
82 |
83 | return backbone, feature_names, backbone_output
84 |
85 |
86 | class UpsampleBlock(nn.Module):
87 |
88 | # TODO: separate parametric and non-parametric classes?
89 | # TODO: skip connection concatenated OR added
90 |
91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
92 | super(UpsampleBlock, self).__init__()
93 |
94 | self.parametric = parametric
95 |         ch_out = ch_in // 2 if ch_out is None else ch_out  # integer channel count
96 |
97 | # first convolution: either transposed conv, or conv following the skip connection
98 | if parametric:
99 | # versions: kernel=4 padding=1, kernel=2 padding=0
100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
101 | stride=2, padding=1, output_padding=0, bias=(not use_bn))
102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
103 | else:
104 | self.up = None
105 | ch_in = ch_in + skip_in
106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
107 | stride=1, padding=1, bias=(not use_bn))
108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
109 |
110 | self.relu = nn.ReLU(inplace=True)
111 |
112 | # second convolution
113 | conv2_in = ch_out if not parametric else ch_out + skip_in
114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
115 | stride=1, padding=1, bias=(not use_bn))
116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
117 |
118 |     #def forward(self, x, skip_connection=1):  # an int default would crash .size() below
119 |     def forward(self, x, skip_connection=None):
120 |
121 |         x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
122 |                                                              align_corners=False)
123 | if self.parametric:
124 | x = self.bn1(x) if self.bn1 is not None else x
125 | x = self.relu(x)
126 |
127 | if skip_connection is not None:
128 |             # Padding in case the incoming volumes are of different sizes
129 | diffY = skip_connection.size()[2] - x.size()[2]
130 | diffX = skip_connection.size()[3] - x.size()[3]
131 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
132 | diffY // 2, diffY - diffY // 2))
133 |
134 | x = torch.cat([x, skip_connection], dim=1)
135 |
136 | if not self.parametric:
137 | x = self.conv1(x)
138 | x = self.bn1(x) if self.bn1 is not None else x
139 | x = self.relu(x)
140 | x = self.conv2(x)
141 | x = self.bn2(x) if self.bn2 is not None else x
142 | x = self.relu(x)
143 |
144 | return x
145 |
146 |
147 |
148 | def conv3x3(in_channels, out_channels, stride=1):
149 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
150 |
151 | class ResidualUnit(nn.Module):
152 | def __init__(self, in_channels, out_channels):
153 | super(ResidualUnit, self).__init__()
154 | self.conv1 = conv3x3(in_channels, out_channels, stride=1)
155 | self.bn1 = nn.BatchNorm2d(out_channels)
156 | self.relu1 = nn.ReLU(inplace=True)
157 | self.conv2 = conv3x3(out_channels, out_channels, stride=1)
158 | self.bn2 = nn.BatchNorm2d(out_channels)
159 | self.relu2 = nn.ReLU(inplace=True)
160 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
161 |
162 | def forward(self, x):
163 | residual = self.conv_1x1(x)
164 | out = self.conv1(x)
165 | out = self.bn1(out)
166 | out = self.relu1(out)
167 | out = self.conv2(out)
168 | out = self.bn2(out)
169 | out += residual
170 | out = self.relu2(out)
171 | return out
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 | class Unet(nn.Module):
182 |
183 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
184 |
185 | def __init__(self,
186 | backbone_name='resnet50',
187 | pretrained=True,
188 | encoder_freeze=False,
189 | classes=21,
190 | decoder_filters=(256, 128, 64, 32, 16),
191 | parametric_upsampling=True,
192 | shortcut_features='default',
193 | decoder_use_batchnorm=True):
194 | super(Unet, self).__init__()
195 |
196 | self.backbone_name = backbone_name
197 |
198 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
199 | shortcut_chs, bb_out_chs = self.infer_skip_channels()
200 | if shortcut_features != 'default':
201 | self.shortcut_features = shortcut_features
202 |
203 | # build decoder part
204 | self.upsample_blocks = nn.ModuleList()
205 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections
206 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
207 | num_blocks = len(self.shortcut_features)
208 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
209 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out))
210 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
211 | skip_in=shortcut_chs[num_blocks-i-1],
212 | parametric=parametric_upsampling,
213 | use_bn=decoder_use_batchnorm))
214 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1))
215 |
216 | if encoder_freeze:
217 | self.freeze_encoder()
218 |
219 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later
220 | #
221 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) #
222 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
223 |
224 |
225 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64)
226 | self.direction_feature = ResidualUnit(64, 64)
227 | self.point_feature = ResidualUnit(64, 64)
228 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1)
229 | self.directionAtt = revAttention(1)
230 | self.direction_conv = nn.Conv2d(64, 9, kernel_size=1)
231 | self.maskAtt = revAttention(9)
232 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1)
233 |
234 | self.residual = ResidualUnit(64, 64)
235 |
236 |
237 |
238 | def freeze_encoder(self):
239 |
240 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
241 |
242 | for param in self.backbone.parameters():
243 | param.requires_grad = False
244 |
245 | def forward(self, *input):
246 |
247 | """ Forward propagation in U-Net. """
248 |
249 | x, features = self.forward_backbone(*input)
250 |
251 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
252 | skip_features = features[skip_name]
253 | x = upsample_block(x, skip_features)
254 |
255 | x_F1 = self.mask_feature(x)
256 |
257 | x_F2 = self.direction_feature(x_F1)
258 | x_F3 = self.point_feature(x_F2)
259 |
260 | x_direction = self.direction_conv(x_F2)
261 | x_point = self.point_conv(x_F3)
262 |
263 | x_F1_mask = self.residual(x_F1)
264 | x_final_mask = self.mask_conv(x_F1_mask)
265 |
266 |
267 |
268 | return x_final_mask, x_point, x_direction
269 |
270 | def forward_backbone(self, x):
271 |
272 | """ Forward propagation in backbone encoder network. """
273 |
274 | features = {None: None} if None in self.shortcut_features else dict()
275 | for name, child in self.backbone.named_children():
276 | #
277 | if(name == '0' and x.shape[1] !=3):
278 | x = self.child0(x)
279 | elif(name == 'conv1' and x.shape[1] !=3):
280 | x = self.child_conv1(x)
281 | else:
282 | x = child(x)
283 | #x = child(x)
284 | if name in self.shortcut_features:
285 | features[name] = x
286 | if name == self.bb_out_name:
287 | break
288 |
289 | return x, features
290 |
291 | def infer_skip_channels(self):
292 |
293 | """ Getting the number of channels at skip connections and at the output of the encoder. """
294 |
295 | x = torch.zeros(1, 3, 224, 224)
296 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
297 |         channels = [] if has_fullres_features else [0]  # only VGG-style and unet_encoder backbones keep full-resolution features
298 |
299 | # forward run in backbone to count channels (dirty solution but works for *any* Module)
300 | for name, child in self.backbone.named_children():
301 | x = child(x)
302 | if name in self.shortcut_features:
303 | channels.append(x.shape[1])
304 | if name == self.bb_out_name:
305 | out_channels = x.shape[1]
306 | break
307 | return channels, out_channels
308 |
309 | def get_pretrained_parameters(self):
310 | for name, param in self.backbone.named_parameters():
311 | if not (self.replaced_conv1 and name == 'conv1.weight'):
312 | yield param
313 |
314 | def get_random_initialized_parameters(self):
315 | pretrained_param_names = set()
316 | for name, param in self.backbone.named_parameters():
317 | if not (self.replaced_conv1 and name == 'conv1.weight'):
318 | pretrained_param_names.add('backbone.{}'.format(name))
319 |
320 | for name, param in self.named_parameters():
321 | if name not in pretrained_param_names:
322 | yield param
323 |
324 |
325 | # if __name__ == "__main__":
326 |
327 | # # simple test run
328 | # net = Unet(backbone_name='resnet18')
329 |
330 | # criterion = nn.MSELoss()
331 | # optimizer = torch.optim.Adam(net.parameters())
332 | # print('Network initialized. Running a test batch.')
333 | # for _ in range(1):
334 | # with torch.set_grad_enabled(True):
335 | # batch = torch.empty(1, 3, 224, 224).normal_()
336 | # targets = torch.empty(1, 21, 224, 224).normal_()
337 |
338 | # out = net(batch)
339 | # loss = criterion(out, targets)
340 | # loss.backward()
341 | # optimizer.step()
342 | # print(out.shape)
343 |
344 | # print('fasza.')
345 |
--------------------------------------------------------------------------------
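
The file just ended (presumably `models/dam/model_unet_MandDandP.py`, given the repo tree and its mask/point/direction heads) branches a shared U-Net decoder into three outputs: a 3-channel mask, a 1-channel point map, and a 9-channel direction map. A minimal smoke test, where the import path and the 224x224 input size are assumptions rather than anything stated in the repo:

```
# Hypothetical smoke test for the three-head U-Net above; the import path is
# an assumption and the backbone stays randomly initialized (pretrained=False).
import torch
from models.dam.model_unet_MandDandP import Unet

net = Unet(backbone_name='resnet18', pretrained=False)
mask, point, direction = net(torch.randn(1, 3, 224, 224))
print(mask.shape)       # torch.Size([1, 3, 224, 224])  3-class mask head
print(point.shape)      # torch.Size([1, 1, 224, 224])  point head
print(direction.shape)  # torch.Size([1, 9, 224, 224])  9-way direction head
```

With five decoder blocks each doubling the resolution, the 7x7 ResNet output returns to the full 224x224 input size, so all three heads are predicted at input resolution.
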
/models/dam/model_unet_rev1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 |
7 |
8 | class revAttention(nn.Module):  # sSE-style (spatial squeeze-and-excitation) attention
9 | def __init__(self, in_channels):
10 | super().__init__()
11 | self.Conv1x1 = nn.Conv2d(in_channels, 1, kernel_size=1, bias=False)
12 | self.norm = nn.Sigmoid()
13 |
14 | def forward(self, U, V):
15 |         q = self.Conv1x1(V)  # V:[bs,c,h,w] -> q:[bs,1,h,w]
16 | q = self.norm(q)
17 | return U * (1+q)
18 |
19 |
20 |
21 |
22 | def get_backbone(name, pretrained=True):
23 |
24 | """ Loading backbone, defining names for skip-connections and encoder output. """
25 |
26 | # TODO: More backbones
27 |
28 | # loading backbone model
29 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
30 | if name == 'resnet18':
31 | backbone = models.resnet18(pretrained=pretrained)
32 | elif name == 'resnet34':
33 | backbone = models.resnet34(pretrained=pretrained)
34 | elif name == 'resnet50':
35 | backbone = models.resnet50(pretrained=pretrained)
36 | elif name == 'resnet101':
37 | backbone = models.resnet101(pretrained=pretrained)
38 | elif name == 'resnet152':
39 | backbone = models.resnet152(pretrained=pretrained)
40 | elif name == 'vgg16_bn':
41 | backbone = models.vgg16_bn(pretrained=pretrained).features
42 | elif name == 'vgg19_bn':
43 | backbone = models.vgg19_bn(pretrained=pretrained).features
44 | # elif name == 'inception_v3':
45 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
46 |     elif name == 'densenet121':
47 |         backbone = models.densenet121(pretrained=pretrained).features
48 |     elif name == 'densenet161':
49 |         backbone = models.densenet161(pretrained=pretrained).features
50 |     elif name == 'densenet169':
51 |         backbone = models.densenet169(pretrained=pretrained).features
52 |     elif name == 'densenet201':
53 |         backbone = models.densenet201(pretrained=pretrained).features
54 | elif name == 'unet_encoder':
55 | from unet_backbone import UnetEncoder
56 | backbone = UnetEncoder(3)
57 | else:
58 |         raise NotImplementedError('{} backbone is not implemented yet.'.format(name))
59 |
60 | # specifying skip feature and output names
61 | if name.startswith('resnet'):
62 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
63 | backbone_output = 'layer4'
64 | elif name == 'vgg16_bn':
65 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
66 | feature_names = ['5', '12', '22', '32', '42']
67 | backbone_output = '43'
68 | elif name == 'vgg19_bn':
69 | feature_names = ['5', '12', '25', '38', '51']
70 | backbone_output = '52'
71 | # elif name == 'inception_v3':
72 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
73 | # backbone_output = 'Mixed_7c'
74 | elif name.startswith('densenet'):
75 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
76 | backbone_output = 'denseblock4'
77 | elif name == 'unet_encoder':
78 | feature_names = ['module1', 'module2', 'module3', 'module4']
79 | backbone_output = 'module5'
80 | else:
81 |         raise NotImplementedError('{} backbone is not implemented yet.'.format(name))
82 |
83 | return backbone, feature_names, backbone_output
84 |
85 |
86 | class UpsampleBlock(nn.Module):
87 |
88 | # TODO: separate parametric and non-parametric classes?
89 | # TODO: skip connection concatenated OR added
90 |
91 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
92 | super(UpsampleBlock, self).__init__()
93 |
94 | self.parametric = parametric
95 |         ch_out = ch_in // 2 if ch_out is None else ch_out
96 |
97 | # first convolution: either transposed conv, or conv following the skip connection
98 | if parametric:
99 | # versions: kernel=4 padding=1, kernel=2 padding=0
100 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
101 | stride=2, padding=1, output_padding=0, bias=(not use_bn))
102 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
103 | else:
104 | self.up = None
105 | ch_in = ch_in + skip_in
106 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
107 | stride=1, padding=1, bias=(not use_bn))
108 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
109 |
110 | self.relu = nn.ReLU(inplace=True)
111 |
112 | # second convolution
113 | conv2_in = ch_out if not parametric else ch_out + skip_in
114 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
115 | stride=1, padding=1, bias=(not use_bn))
116 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
117 |
118 | #def forward(self, x, skip_connection=None): #original code
119 | def forward(self, x, skip_connection=1): # hhl revised
120 |
121 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
122 | align_corners=None)
123 | if self.parametric:
124 | x = self.bn1(x) if self.bn1 is not None else x
125 | x = self.relu(x)
126 |
127 | if skip_connection is not None:
128 | diffY = skip_connection.size()[2] - x.size()[2]
129 | diffX = skip_connection.size()[3] - x.size()[3]
130 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
131 | diffY // 2, diffY - diffY // 2))
132 |
133 | x = torch.cat([x, skip_connection], dim=1)
134 |
135 | if not self.parametric:
136 | x = self.conv1(x)
137 | x = self.bn1(x) if self.bn1 is not None else x
138 | x = self.relu(x)
139 | x = self.conv2(x)
140 | x = self.bn2(x) if self.bn2 is not None else x
141 | x = self.relu(x)
142 |
143 | return x
144 |
145 |
146 |
147 | def conv3x3(in_channels, out_channels, stride=1):
148 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
149 |
150 | class ResidualUnit(nn.Module):
151 | def __init__(self, in_channels, out_channels):
152 | super(ResidualUnit, self).__init__()
153 | self.conv1 = conv3x3(in_channels, out_channels, stride=1)
154 | self.bn1 = nn.BatchNorm2d(out_channels)
155 | self.relu1 = nn.ReLU(inplace=True)
156 | self.conv2 = conv3x3(out_channels, out_channels, stride=1)
157 | self.bn2 = nn.BatchNorm2d(out_channels)
158 | self.relu2 = nn.ReLU(inplace=True)
159 | self.conv_1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
160 |
161 | def forward(self, x):
162 | residual = self.conv_1x1(x)
163 | out = self.conv1(x)
164 | out = self.bn1(out)
165 | out = self.relu1(out)
166 | out = self.conv2(out)
167 | out = self.bn2(out)
168 | out += residual
169 | out = self.relu2(out)
170 | return out
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 | class Unet(nn.Module):
181 |
182 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
183 |
184 | def __init__(self,
185 | backbone_name='resnet50',
186 | pretrained=True,
187 | encoder_freeze=False,
188 | classes=21,
189 | decoder_filters=(256, 128, 64, 32, 16),
190 | parametric_upsampling=True,
191 | shortcut_features='default',
192 | decoder_use_batchnorm=True):
193 | super(Unet, self).__init__()
194 |
195 | self.backbone_name = backbone_name
196 |
197 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
198 | shortcut_chs, bb_out_chs = self.infer_skip_channels()
199 | if shortcut_features != 'default':
200 | self.shortcut_features = shortcut_features
201 |
202 | # build decoder part
203 | self.upsample_blocks = nn.ModuleList()
204 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections
205 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
206 | num_blocks = len(self.shortcut_features)
207 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
208 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out))
209 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
210 | skip_in=shortcut_chs[num_blocks-i-1],
211 | parametric=parametric_upsampling,
212 | use_bn=decoder_use_batchnorm))
213 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1))
214 |
215 | if encoder_freeze:
216 | self.freeze_encoder()
217 |
218 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later
219 |
220 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) #
221 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
222 |
223 |
224 | self.mask_feature = ResidualUnit(decoder_filters[-1], 64)
225 | self.direction_feature = ResidualUnit(64, 64)
226 | self.point_feature = ResidualUnit(64, 64)
227 | self.point_conv = nn.Conv2d(64, 1, kernel_size=1)
228 | self.directionAtt = revAttention(1)
229 | self.direction_conv = nn.Conv2d(64, 9, kernel_size=1)
230 | self.maskAtt = revAttention(9)
231 | self.mask_conv = nn.Conv2d(64, 3, kernel_size=1)
232 |
233 |
234 |
235 |
236 |
237 | def freeze_encoder(self):
238 |
239 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
240 |
241 | for param in self.backbone.parameters():
242 | param.requires_grad = False
243 |
244 | def forward(self, *input):
245 |
246 | """ Forward propagation in U-Net. """
247 |
248 | x, features = self.forward_backbone(*input)
249 |
250 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
251 | skip_features = features[skip_name]
252 | x = upsample_block(x, skip_features)
253 |
254 | x_F1 = self.mask_feature(x)
255 |
256 | x_F2 = self.direction_feature(x_F1)
257 | x_F3 = self.point_feature(x_F2)
258 | x_point = self.point_conv(x_F3)
259 | x_F2_direction = self.directionAtt(x_F2, x_point)
260 | x_direction = self.direction_conv(x_F2_direction)
261 |
262 | x_F1_mask = self.maskAtt(x_F1, x_direction)
263 | x_final_mask = self.mask_conv(x_F1_mask)
264 |
265 |
266 | return x_final_mask, x_point, x_direction
267 |
268 | def forward_backbone(self, x):
269 |
270 | """ Forward propagation in backbone encoder network. """
271 |
272 | features = {None: None} if None in self.shortcut_features else dict()
273 | for name, child in self.backbone.named_children():
274 |
275 | if(name == '0' and x.shape[1] !=3):
276 | x = self.child0(x)
277 | elif(name == 'conv1' and x.shape[1] !=3):
278 | x = self.child_conv1(x)
279 | else:
280 | x = child(x)
281 | #x = child(x)
282 | if name in self.shortcut_features:
283 | features[name] = x
284 | if name == self.bb_out_name:
285 | break
286 |
287 | return x, features
288 |
289 | def infer_skip_channels(self):
290 |
291 | """ Getting the number of channels at skip connections and at the output of the encoder. """
292 |
293 | x = torch.zeros(1, 3, 224, 224)
294 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
295 |         channels = [] if has_fullres_features else [0]  # only VGG-style and unet_encoder backbones keep full-resolution features
296 |
297 | # forward run in backbone to count channels (dirty solution but works for *any* Module)
298 | for name, child in self.backbone.named_children():
299 | x = child(x)
300 | if name in self.shortcut_features:
301 | channels.append(x.shape[1])
302 | if name == self.bb_out_name:
303 | out_channels = x.shape[1]
304 | break
305 | return channels, out_channels
306 |
307 | def get_pretrained_parameters(self):
308 | for name, param in self.backbone.named_parameters():
309 | if not (self.replaced_conv1 and name == 'conv1.weight'):
310 | yield param
311 |
312 | def get_random_initialized_parameters(self):
313 | pretrained_param_names = set()
314 | for name, param in self.backbone.named_parameters():
315 | if not (self.replaced_conv1 and name == 'conv1.weight'):
316 | pretrained_param_names.add('backbone.{}'.format(name))
317 |
318 | for name, param in self.named_parameters():
319 | if name not in pretrained_param_names:
320 | yield param
321 |
322 |
323 | # if __name__ == "__main__":
324 |
325 | # # simple test run
326 | # net = Unet(backbone_name='resnet18')
327 |
328 | # criterion = nn.MSELoss()
329 | # optimizer = torch.optim.Adam(net.parameters())
330 | # print('Network initialized. Running a test batch.')
331 | # for _ in range(1):
332 | # with torch.set_grad_enabled(True):
333 | # batch = torch.empty(1, 3, 224, 224).normal_()
334 | # targets = torch.empty(1, 21, 224, 224).normal_()
335 |
336 | # out = net(batch)
337 | # loss = criterion(out, targets)
338 | # loss.backward()
339 | # optimizer.step()
340 | # print(out.shape)
341 |
342 | # print('fasza.')
343 |
--------------------------------------------------------------------------------
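
In `model_unet_rev1.py` the heads are cascaded through `revAttention`: the point logits gate the direction features, and the resulting direction logits gate the mask features. The gating arithmetic in isolation, as a small sketch (shapes chosen only for illustration):

```
# Sketch of the revAttention gating above: the guide V is squeezed to a
# single sigmoid map q, and U is rescaled by (1 + q). The factor stays in
# (1, 2), so attended positions are amplified without ever suppressing U.
import torch
import torch.nn as nn

conv1x1 = nn.Conv2d(9, 1, kernel_size=1, bias=False)  # e.g. 9 direction channels as guide
U = torch.randn(2, 64, 8, 8)                          # features to be gated
V = torch.randn(2, 9, 8, 8)                           # guiding logits
q = torch.sigmoid(conv1x1(V))                         # [2, 1, 8, 8] gate
out = U * (1 + q)                                     # broadcast over U's channels
print(out.shape)                                      # torch.Size([2, 64, 8, 8])
```
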
/models/deeplabv3_plus.py:
--------------------------------------------------------------------------------
1 | from base.base_model import BaseModel
2 | import torch
3 | import math
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torchvision import models
7 | import torch.utils.model_zoo as model_zoo
8 | from hhl_utils.helpers import initialize_weights
9 | from itertools import chain
10 |
11 | '''
12 | -> ResNet BackBone
13 | '''
14 |
15 | class ResNet(nn.Module):
16 | def __init__(self, in_channels=3, output_stride=16, backbone='resnet101', pretrained=True):
17 | super(ResNet, self).__init__()
18 | model = getattr(models, backbone)(pretrained)
19 | if not pretrained or in_channels != 3:
20 | self.layer0 = nn.Sequential(
21 | nn.Conv2d(in_channels, 64, 7, stride=2, padding=3, bias=False),
22 | nn.BatchNorm2d(64),
23 | nn.ReLU(inplace=True),
24 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
25 | )
26 | initialize_weights(self.layer0)
27 | else:
28 | self.layer0 = nn.Sequential(*list(model.children())[:4])
29 |
30 | self.layer1 = model.layer1
31 | self.layer2 = model.layer2
32 | self.layer3 = model.layer3
33 | self.layer4 = model.layer4
34 |
35 | if output_stride == 16: s3, s4, d3, d4 = (2, 1, 1, 2)
36 | elif output_stride == 8: s3, s4, d3, d4 = (1, 1, 2, 4)
37 |
38 | if output_stride == 8:
39 | for n, m in self.layer3.named_modules():
40 | if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'):
41 | m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3)
42 | elif 'conv2' in n:
43 | m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3)
44 | elif 'downsample.0' in n:
45 | m.stride = (s3, s3)
46 |
47 |         for n, m in self.layer4.named_modules():  # applied for both output strides
48 |             if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'):
49 |                 m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4)
50 |             elif 'conv2' in n:
51 |                 m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4)
52 |             elif 'downsample.0' in n:
53 |                 m.stride = (s4, s4)
54 |
55 | def forward(self, x):
56 | x = self.layer0(x)
57 | x = self.layer1(x)
58 | low_level_features = x
59 | x = self.layer2(x)
60 | x = self.layer3(x)
61 | x = self.layer4(x)
62 |
63 | return x, low_level_features
64 |
65 | '''
66 | -> (Aligned) Xception BackBone
67 | Pretrained model from https://github.com/Cadene/pretrained-models.pytorch
68 | by Remi Cadene
69 | '''
70 | class SeparableConv2d(nn.Module):
71 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=nn.BatchNorm2d):
72 | super(SeparableConv2d, self).__init__()
73 |
74 | if dilation > kernel_size//2: padding = dilation
75 | else: padding = kernel_size//2
76 |
77 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding=padding,
78 | dilation=dilation, groups=in_channels, bias=bias)
79 | self.bn = nn.BatchNorm2d(in_channels)
80 | self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, bias=bias)
81 |
82 | def forward(self, x):
83 | x = self.conv1(x)
84 | x = self.bn(x)
85 | x = self.pointwise(x)
86 | return x
87 |
88 |
89 | class Block(nn.Module):
90 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, exit_flow=False, use_1st_relu=True):
91 | super(Block, self).__init__()
92 |
93 | if in_channels != out_channels or stride !=1:
94 | self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False)
95 | self.skipbn = nn.BatchNorm2d(out_channels)
96 | else: self.skip = None
97 |
98 | rep = []
99 | self.relu = nn.ReLU(inplace=True)
100 |
101 | rep.append(self.relu)
102 | rep.append(SeparableConv2d(in_channels, out_channels, 3, stride=1, dilation=dilation))
103 | rep.append(nn.BatchNorm2d(out_channels))
104 |
105 | rep.append(self.relu)
106 | rep.append(SeparableConv2d(out_channels, out_channels, 3, stride=1, dilation=dilation))
107 | rep.append(nn.BatchNorm2d(out_channels))
108 |
109 | rep.append(self.relu)
110 | rep.append(SeparableConv2d(out_channels, out_channels, 3, stride=stride, dilation=dilation))
111 | rep.append(nn.BatchNorm2d(out_channels))
112 |
113 | if exit_flow:
114 | rep[3:6] = rep[:3]
115 | rep[:3] = [
116 | self.relu,
117 | SeparableConv2d(in_channels, in_channels, 3, 1, dilation),
118 | nn.BatchNorm2d(in_channels)]
119 |
120 | if not use_1st_relu: rep = rep[1:]
121 | self.rep = nn.Sequential(*rep)
122 |
123 | def forward(self, x):
124 | output = self.rep(x)
125 | if self.skip is not None:
126 | skip = self.skip(x)
127 | skip = self.skipbn(skip)
128 | else:
129 | skip = x
130 |
131 | x = output + skip
132 | return x
133 |
134 | class Xception(nn.Module):
135 | def __init__(self, output_stride=16, in_channels=3, pretrained=True):
136 | super(Xception, self).__init__()
137 |
138 | # Stride for block 3 (entry flow), and the dilation rates for middle flow and exit flow
139 | if output_stride == 16: b3_s, mf_d, ef_d = 2, 1, (1, 2)
140 | if output_stride == 8: b3_s, mf_d, ef_d = 1, 2, (2, 4)
141 |
142 | # Entry Flow
143 | self.conv1 = nn.Conv2d(in_channels, 32, 3, 2, padding=1, bias=False)
144 | self.bn1 = nn.BatchNorm2d(32)
145 | self.relu = nn.ReLU(inplace=True)
146 | self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1, bias=False)
147 | self.bn2 = nn.BatchNorm2d(64)
148 |
149 | self.block1 = Block(64, 128, stride=2, dilation=1, use_1st_relu=False)
150 | self.block2 = Block(128, 256, stride=2, dilation=1)
151 | self.block3 = Block(256, 728, stride=b3_s, dilation=1)
152 |
153 | # Middle Flow
154 | for i in range(16):
155 |             setattr(self, f'block{i+4}', Block(728, 728, stride=1, dilation=mf_d))
156 |
157 | # Exit flow
158 | self.block20 = Block(728, 1024, stride=1, dilation=ef_d[0], exit_flow=True)
159 |
160 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=ef_d[1])
161 | self.bn3 = nn.BatchNorm2d(1536)
162 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=ef_d[1])
163 | self.bn4 = nn.BatchNorm2d(1536)
164 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=ef_d[1])
165 | self.bn5 = nn.BatchNorm2d(2048)
166 |
167 | initialize_weights(self)
168 | if pretrained: self._load_pretrained_model()
169 |
170 |
171 | def _load_pretrained_model(self):
172 | url = 'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth'
173 | pretrained_weights = model_zoo.load_url(url)
174 | state_dict = self.state_dict()
175 | model_dict = {}
176 |
177 | for k, v in pretrained_weights.items():
178 | if k in state_dict:
179 | if 'pointwise' in k:
180 | v = v.unsqueeze(-1).unsqueeze(-1) # [C, C] -> [C, C, 1, 1]
181 | if k.startswith('block11'):
182 |                     # The pretrained Xception has only 8 middle-flow blocks; block11's weights seed blocks 12-19
183 | model_dict[k] = v
184 | for i in range(8):
185 | model_dict[k.replace('block11', f'block{i+12}')] = v
186 | elif k.startswith('block12'):
187 | model_dict[k.replace('block12', 'block20')] = v
188 | elif k.startswith('bn3'):
189 | model_dict[k] = v
190 | model_dict[k.replace('bn3', 'bn4')] = v
191 | elif k.startswith('conv4'):
192 | model_dict[k.replace('conv4', 'conv5')] = v
193 | elif k.startswith('bn4'):
194 | model_dict[k.replace('bn4', 'bn5')] = v
195 | else:
196 | model_dict[k] = v
197 |
198 | state_dict.update(model_dict)
199 | self.load_state_dict(state_dict)
200 |
201 | def forward(self, x):
202 | # Entry flow
203 | x = self.conv1(x)
204 | x = self.bn1(x)
205 | x = self.relu(x)
206 | x = self.conv2(x)
207 | x = self.bn2(x)
208 | x = self.block1(x)
209 | low_level_features = x
210 | x = F.relu(x)
211 | x = self.block2(x)
212 | x = self.block3(x)
213 |
214 | # Middle flow
215 | x = self.block4(x)
216 | x = self.block5(x)
217 | x = self.block6(x)
218 | x = self.block7(x)
219 | x = self.block8(x)
220 | x = self.block9(x)
221 | x = self.block10(x)
222 | x = self.block11(x)
223 | x = self.block12(x)
224 | x = self.block13(x)
225 | x = self.block14(x)
226 | x = self.block15(x)
227 | x = self.block16(x)
228 | x = self.block17(x)
229 | x = self.block18(x)
230 | x = self.block19(x)
231 |
232 | # Exit flow
233 | x = self.block20(x)
234 | x = self.relu(x)
235 | x = self.conv3(x)
236 | x = self.bn3(x)
237 | x = self.relu(x)
238 |
239 | x = self.conv4(x)
240 | x = self.bn4(x)
241 | x = self.relu(x)
242 |
243 | x = self.conv5(x)
244 | x = self.bn5(x)
245 | x = self.relu(x)
246 |
247 | return x, low_level_features
248 |
249 | '''
250 | -> The Atrous Spatial Pyramid Pooling
251 | '''
252 |
253 | def assp_branch(in_channels, out_channels, kernel_size, dilation):
254 |     padding = 0 if kernel_size == 1 else dilation
255 |     return nn.Sequential(
256 |         nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, bias=False),
257 |         nn.BatchNorm2d(out_channels),
258 | nn.ReLU(inplace=True))
259 |
260 | class ASSP(nn.Module):
261 | def __init__(self, in_channels, output_stride):
262 | super(ASSP, self).__init__()
263 |
264 |         assert output_stride in [8, 16], 'Only output strides of 8 or 16 are supported'
265 | if output_stride == 16: dilations = [1, 6, 12, 18]
266 | elif output_stride == 8: dilations = [1, 12, 24, 36]
267 |
268 | self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0])
269 | self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1])
270 | self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2])
271 | self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3])
272 |
273 | self.avg_pool = nn.Sequential(
274 | nn.AdaptiveAvgPool2d((1, 1)),
275 | nn.Conv2d(in_channels, 256, 1, bias=False),
276 | nn.BatchNorm2d(256),
277 | nn.ReLU(inplace=True))
278 |
279 | self.conv1 = nn.Conv2d(256*5, 256, 1, bias=False)
280 | self.bn1 = nn.BatchNorm2d(256)
281 | self.relu = nn.ReLU(inplace=True)
282 | self.dropout = nn.Dropout(0.5)
283 |
284 | initialize_weights(self)
285 |
286 | def forward(self, x):
287 | x1 = self.aspp1(x)
288 | x2 = self.aspp2(x)
289 | x3 = self.aspp3(x)
290 | x4 = self.aspp4(x)
291 | x5 = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)
292 |
293 | x = self.conv1(torch.cat((x1, x2, x3, x4, x5), dim=1))
294 | x = self.bn1(x)
295 | x = self.dropout(self.relu(x))
296 |
297 | return x
298 |
299 | '''
300 | -> Decoder
301 | '''
302 |
303 | class Decoder(nn.Module):
304 | def __init__(self, low_level_channels, num_classes):
305 | super(Decoder, self).__init__()
306 | self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False)
307 | self.bn1 = nn.BatchNorm2d(48)
308 | self.relu = nn.ReLU(inplace=True)
309 |
310 | # Table 2, best performance with two 3x3 convs
311 | self.output = nn.Sequential(
312 | nn.Conv2d(48+256, 256, 3, stride=1, padding=1, bias=False),
313 | nn.BatchNorm2d(256),
314 | nn.ReLU(inplace=True),
315 | nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),
316 | nn.BatchNorm2d(256),
317 | nn.ReLU(inplace=True),
318 | nn.Dropout(0.1),
319 | nn.Conv2d(256, num_classes, 1, stride=1),
320 | )
321 | initialize_weights(self)
322 |
323 | def forward(self, x, low_level_features):
324 | low_level_features = self.conv1(low_level_features)
325 | low_level_features = self.relu(self.bn1(low_level_features))
326 | H, W = low_level_features.size(2), low_level_features.size(3)
327 |
328 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
329 | x = self.output(torch.cat((low_level_features, x), dim=1))
330 | return x
331 |
332 | '''
333 | -> Deeplab V3 +
334 | '''
335 |
336 | class DeepLab(BaseModel):
337 |     def __init__(self, num_classes, in_channels=3, backbone='xception', pretrained=False,  # was pretrained=True; changed by hhl 2019-10-20
338 | output_stride=16, freeze_bn=False, **_):
339 |
340 | super(DeepLab, self).__init__()
341 |         assert 'xception' in backbone or 'resnet' in backbone
342 | if 'resnet' in backbone:
343 | self.backbone = ResNet(in_channels=in_channels, output_stride=output_stride, pretrained=pretrained)
344 | low_level_channels = 256
345 | else:
346 | self.backbone = Xception(output_stride=output_stride, pretrained=pretrained)
347 | low_level_channels = 128
348 |
349 | self.ASSP = ASSP(in_channels=2048, output_stride=output_stride)
350 | self.decoder = Decoder(low_level_channels, num_classes)
351 |
352 | if freeze_bn: self.freeze_bn()
353 |
354 | def forward(self, x):
355 | H, W = x.size(2), x.size(3)
356 | x, low_level_features = self.backbone(x)
357 | x = self.ASSP(x)
358 | x = self.decoder(x, low_level_features)
359 | x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)
360 | return x
361 |
362 | # Two functions to yield the parameters of the backbone
363 | # & Decoder / ASSP to use differentiable learning rates
364 | # FIXME: in xception, we use the parameters from xception and not aligned xception
365 | # better to have higher lr for this backbone
366 |
367 | def get_backbone_params(self):
368 | return self.backbone.parameters()
369 |
370 | def get_decoder_params(self):
371 | return chain(self.ASSP.parameters(), self.decoder.parameters())
372 |
373 | def freeze_bn(self):
374 | for module in self.modules():
375 | if isinstance(module, nn.BatchNorm2d): module.eval()
376 |
377 |
--------------------------------------------------------------------------------
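
The `ASSP` module above keeps every pyramid branch at the input resolution by setting `padding = dilation` for the 3x3 atrous branches and `padding = 0` for the 1x1 branch. A quick check of that arithmetic with the output_stride=16 rates:

```
# For a stride-1 conv, output size = H + 2*padding - dilation*(k-1); with
# padding = dilation and k = 3 this reduces to H, so all branches stay aligned.
import torch
import torch.nn as nn

x = torch.randn(1, 2048, 32, 32)
for d in (1, 6, 12, 18):                      # output_stride=16 rates from ASSP
    k = 1 if d == 1 else 3
    conv = nn.Conv2d(2048, 256, k, padding=0 if k == 1 else d, dilation=d, bias=False)
    print(d, tuple(conv(x).shape))            # (1, 256, 32, 32) for every rate
```
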
/models/fcn8.py:
--------------------------------------------------------------------------------
1 | from base.base_model import BaseModel
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import models
5 | from hhl_utils.helpers import get_upsampling_weight
6 | import torch
7 | from itertools import chain
8 |
9 | class FCN8(BaseModel):
10 |     def __init__(self, num_classes, pretrained=False, freeze_bn=False, **_):  # was pretrained=True; changed by hhl 2019-10-20
11 | super(FCN8, self).__init__()
12 | vgg = models.vgg16(pretrained)
13 | features = list(vgg.features.children())
14 | classifier = list(vgg.classifier.children())
15 |
16 | # Pad the input to enable small inputs and allow matching feature maps
17 | features[0].padding = (100, 100)
18 |
19 |         # Enable ceil mode in max pooling to avoid mismatched sizes when upsampling
20 | for layer in features:
21 | if 'MaxPool' in layer.__class__.__name__:
22 | layer.ceil_mode = True
23 |
24 | # Extract pool3, pool4 and pool5 from the VGG net
25 | self.pool3 = nn.Sequential(*features[:17])
26 | self.pool4 = nn.Sequential(*features[17:24])
27 | self.pool5 = nn.Sequential(*features[24:])
28 |
29 | # Adjust the depth of pool3 and pool4 to num_classes
30 | self.adj_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)
31 | self.adj_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)
32 |
33 | # Replace the FC layer of VGG with conv layers
34 | conv6 = nn.Conv2d(512, 4096, kernel_size=7)
35 | conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
36 | output = nn.Conv2d(4096, num_classes, kernel_size=1)
37 |
38 | # Copy the weights from VGG's FC pretrained layers
39 | conv6.weight.data.copy_(classifier[0].weight.data.view(
40 | conv6.weight.data.size()))
41 | conv6.bias.data.copy_(classifier[0].bias.data)
42 |
43 | conv7.weight.data.copy_(classifier[3].weight.data.view(
44 | conv7.weight.data.size()))
45 | conv7.bias.data.copy_(classifier[3].bias.data)
46 |
47 | # Get the outputs
48 | self.output = nn.Sequential(conv6, nn.ReLU(inplace=True), nn.Dropout(),
49 | conv7, nn.ReLU(inplace=True), nn.Dropout(),
50 | output)
51 |
52 |         # We'll need three upsampling layers: one (x2, +2) for the outputs,
53 |         # one (x2, +2) for the sum of pool4 and the upsampled output,
54 |         # and one (x8, +8) for the final value (pool3 + the previous sum)
55 | self.up_output = nn.ConvTranspose2d(num_classes, num_classes,
56 | kernel_size=4, stride=2, bias=False)
57 | self.up_pool4_out = nn.ConvTranspose2d(num_classes, num_classes,
58 | kernel_size=4, stride=2, bias=False)
59 | self.up_final = nn.ConvTranspose2d(num_classes, num_classes,
60 | kernel_size=16, stride=8, bias=False)
61 |
62 |         # We'll use Gaussian kernels for the upsampling weights
63 | self.up_output.weight.data.copy_(
64 | get_upsampling_weight(num_classes, num_classes, 4))
65 | self.up_pool4_out.weight.data.copy_(
66 | get_upsampling_weight(num_classes, num_classes, 4))
67 | self.up_final.weight.data.copy_(
68 | get_upsampling_weight(num_classes, num_classes, 16))
69 |
70 |         # We'll freeze the weights; this is fixed upsampling, not a learned deconvolution
71 | for m in self.modules():
72 | if isinstance(m, nn.ConvTranspose2d):
73 | m.weight.requires_grad = False
74 | if freeze_bn: self.freeze_bn()
75 |
76 | def forward(self, x):
77 |         img_H, img_W = x.size()[2], x.size()[3]
78 |
79 | # Forward the image
80 | pool3 = self.pool3(x)
81 | pool4 = self.pool4(pool3)
82 | pool5 = self.pool5(pool4)
83 |
84 |         # Get the outputs and upsample them
85 | output = self.output(pool5)
86 | up_output = self.up_output(output)
87 |
88 |         # Adjust pool4 and add the upsampled outputs to it
89 | adjstd_pool4 = self.adj_pool4(0.01 * pool4)
90 | add_out_pool4 = self.up_pool4_out(adjstd_pool4[:, :, 5: (5 + up_output.size()[2]),
91 | 5: (5 + up_output.size()[3])]
92 | + up_output)
93 |
94 |         # Adjust pool3 and add it to the upsampled previous sum
95 | adjstd_pool3 = self.adj_pool3(0.0001 * pool3)
96 | final_value = self.up_final(adjstd_pool3[:, :, 9: (9 + add_out_pool4.size()[2]), 9: (9 + add_out_pool4.size()[3])]
97 | + add_out_pool4)
98 |
99 | # Remove the corresponding padded regions to the input img size
100 |         final_value = final_value[:, :, 31: (31 + img_H), 31: (31 + img_W)].contiguous()
101 | return final_value
102 |
103 | def get_backbone_params(self):
104 | return chain(self.pool3.parameters(), self.pool4.parameters(), self.pool5.parameters(), self.output.parameters())
105 |
106 | def get_decoder_params(self):
107 | return chain(self.up_output.parameters(), self.adj_pool4.parameters(), self.up_pool4_out.parameters(),
108 | self.adj_pool3.parameters(), self.up_final.parameters())
109 |
110 | def freeze_bn(self):
111 | for module in self.modules():
112 | if isinstance(module, nn.BatchNorm2d): module.eval()
113 |
114 |
--------------------------------------------------------------------------------
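
`FCN8` copies fixed interpolation kernels into its transposed convolutions via `get_upsampling_weight` (defined in `hhl_utils/helpers.py`, not shown here). A common implementation of such a helper, sketched here as an assumption rather than the repo's exact code, builds a 2D bilinear kernel:

```
# Assumed sketch of an upsampling-weight helper (the repo's own version lives
# in hhl_utils/helpers.py): a fixed 2D bilinear interpolation kernel, placed
# on the channel diagonal so each class channel is upsampled independently.
import numpy as np
import torch

def bilinear_upsampling_weight(in_channels, out_channels, kernel_size):
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()

print(bilinear_upsampling_weight(21, 21, 4).shape)  # torch.Size([21, 21, 4, 4])
```

Because these weights are frozen in `__init__`, the transposed convolutions behave as plain bilinear interpolation rather than learned deconvolutions.
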
/models/model_unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models, datasets, transforms
4 | from torch.nn import functional as F
5 | import os
6 |
7 |
8 | def get_backbone(name, pretrained=True):
9 |
10 | """ Loading backbone, defining names for skip-connections and encoder output. """
11 |
12 | # TODO: More backbones
13 |
14 | # loading backbone model
15 | os.environ['TORCH_HOME'] = os.path.join(os.getcwd(), f'pretrained_models/{name}')
16 | if name == 'resnet18':
17 | backbone = models.resnet18(pretrained=pretrained)
18 | elif name == 'resnet34':
19 | backbone = models.resnet34(pretrained=pretrained)
20 | elif name == 'resnet50':
21 | backbone = models.resnet50(pretrained=pretrained)
22 | elif name == 'resnet101':
23 | backbone = models.resnet101(pretrained=pretrained)
24 | elif name == 'resnet152':
25 | backbone = models.resnet152(pretrained=pretrained)
26 | elif name == 'vgg16_bn':
27 | backbone = models.vgg16_bn(pretrained=pretrained).features
28 | elif name == 'vgg19_bn':
29 | backbone = models.vgg19_bn(pretrained=pretrained).features
30 | # elif name == 'inception_v3':
31 | # backbone = models.inception_v3(pretrained=pretrained, aux_logits=False)
32 |     elif name == 'densenet121':
33 |         backbone = models.densenet121(pretrained=pretrained).features
34 |     elif name == 'densenet161':
35 |         backbone = models.densenet161(pretrained=pretrained).features
36 |     elif name == 'densenet169':
37 |         backbone = models.densenet169(pretrained=pretrained).features
38 |     elif name == 'densenet201':
39 |         backbone = models.densenet201(pretrained=pretrained).features
40 | elif name == 'unet_encoder':
41 | from unet_backbone import UnetEncoder
42 | backbone = UnetEncoder(3)
43 | else:
44 |         raise NotImplementedError('{} backbone is not implemented yet.'.format(name))
45 |
46 | # specifying skip feature and output names
47 | if name.startswith('resnet'):
48 | feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
49 | backbone_output = 'layer4'
50 | elif name == 'vgg16_bn':
51 | # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
52 | feature_names = ['5', '12', '22', '32', '42']
53 | backbone_output = '43'
54 | elif name == 'vgg19_bn':
55 | feature_names = ['5', '12', '25', '38', '51']
56 | backbone_output = '52'
57 | # elif name == 'inception_v3':
58 | # feature_names = [None, 'Mixed_5d', 'Mixed_6e']
59 | # backbone_output = 'Mixed_7c'
60 | elif name.startswith('densenet'):
61 | feature_names = [None, 'relu0', 'denseblock1', 'denseblock2', 'denseblock3']
62 | backbone_output = 'denseblock4'
63 | elif name == 'unet_encoder':
64 | feature_names = ['module1', 'module2', 'module3', 'module4']
65 | backbone_output = 'module5'
66 | else:
67 |         raise NotImplementedError('{} backbone is not implemented yet.'.format(name))
68 |
69 | return backbone, feature_names, backbone_output
70 |
71 |
72 | class UpsampleBlock(nn.Module):
73 |
74 | # TODO: separate parametric and non-parametric classes?
75 | # TODO: skip connection concatenated OR added
76 |
77 | def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
78 | super(UpsampleBlock, self).__init__()
79 |
80 | self.parametric = parametric
81 |         ch_out = ch_in // 2 if ch_out is None else ch_out
82 |
83 | # first convolution: either transposed conv, or conv following the skip connection
84 | if parametric:
85 | # versions: kernel=4 padding=1, kernel=2 padding=0
86 | self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
87 | stride=2, padding=1, output_padding=0, bias=(not use_bn))
88 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
89 | else:
90 | self.up = None
91 | ch_in = ch_in + skip_in
92 | self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
93 | stride=1, padding=1, bias=(not use_bn))
94 | self.bn1 = nn.BatchNorm2d(ch_out) if use_bn else None
95 |
96 | self.relu = nn.ReLU(inplace=True)
97 |
98 | # second convolution
99 | conv2_in = ch_out if not parametric else ch_out + skip_in
100 | self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
101 | stride=1, padding=1, bias=(not use_bn))
102 | self.bn2 = nn.BatchNorm2d(ch_out) if use_bn else None
103 |
104 | #def forward(self, x, skip_connection=None): #
105 | def forward(self, x, skip_connection=1): #
106 |
107 | x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
108 | align_corners=None)
109 | if self.parametric:
110 | x = self.bn1(x) if self.bn1 is not None else x
111 | x = self.relu(x)
112 |
113 | if skip_connection is not None:
114 |             # Padding in case the incoming volumes are of different sizes  # added by hhl 2020-04-13
115 | diffY = skip_connection.size()[2] - x.size()[2]
116 | diffX = skip_connection.size()[3] - x.size()[3]
117 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
118 | diffY // 2, diffY - diffY // 2))
119 |
120 | x = torch.cat([x, skip_connection], dim=1)
121 |
122 | if not self.parametric:
123 | x = self.conv1(x)
124 | x = self.bn1(x) if self.bn1 is not None else x
125 | x = self.relu(x)
126 | x = self.conv2(x)
127 | x = self.bn2(x) if self.bn2 is not None else x
128 | x = self.relu(x)
129 |
130 | return x
131 |
132 |
133 | class Unet(nn.Module):
134 |
135 | """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
136 |
137 | def __init__(self,
138 | backbone_name='resnet50',
139 | pretrained=True,
140 | encoder_freeze=False,
141 | classes=21,
142 | decoder_filters=(256, 128, 64, 32, 16),
143 | parametric_upsampling=True,
144 | shortcut_features='default',
145 | decoder_use_batchnorm=True):
146 | super(Unet, self).__init__()
147 |
148 | self.backbone_name = backbone_name
149 |
150 | self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(backbone_name, pretrained=pretrained)
151 | shortcut_chs, bb_out_chs = self.infer_skip_channels()
152 | if shortcut_features != 'default':
153 | self.shortcut_features = shortcut_features
154 |
155 | # build decoder part
156 | self.upsample_blocks = nn.ModuleList()
157 | decoder_filters = decoder_filters[:len(self.shortcut_features)] # avoiding having more blocks than skip connections
158 | decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
159 | num_blocks = len(self.shortcut_features)
160 | for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
161 | print('upsample_blocks[{}] in: {} out: {}'.format(i, filters_in, filters_out))
162 | self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
163 | skip_in=shortcut_chs[num_blocks-i-1],
164 | parametric=parametric_upsampling,
165 | use_bn=decoder_use_batchnorm))
166 | self.final_conv = nn.Conv2d(decoder_filters[-1], classes, kernel_size=(1, 1))
167 |
168 | if encoder_freeze:
169 | self.freeze_encoder()
170 |
171 | self.replaced_conv1 = False # for accommodating inputs with different number of channels later
172 |
173 | #
174 | self.child0 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) #
175 | self.child_conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
176 |
177 | def freeze_encoder(self):
178 |
179 | """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
180 |
181 | for param in self.backbone.parameters():
182 | param.requires_grad = False
183 |
184 | def forward(self, *input):
185 |
186 | """ Forward propagation in U-Net. """
187 |
188 | x, features = self.forward_backbone(*input)
189 |
190 | for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
191 | skip_features = features[skip_name]
192 | x = upsample_block(x, skip_features)
193 |
194 | x = self.final_conv(x)
195 |
196 | return x
197 |
198 | def forward_backbone(self, x):
199 |
200 | """ Forward propagation in backbone encoder network. """
201 | #print('x.shape = ',x.shape)
202 | features = {None: None} if None in self.shortcut_features else dict()
203 | for name, child in self.backbone.named_children():
204 | #print(name,child)
205 | #
206 | if(name == '0' and x.shape[1] !=3):
207 | x = self.child0(x)
208 | elif(name == 'conv1' and x.shape[1] !=3):
209 | x = self.child_conv1(x)
210 | else:
211 | x = child(x)
212 |
213 | if name in self.shortcut_features:
214 | features[name] = x
215 | if name == self.bb_out_name:
216 | break
217 |
218 | return x, features
219 |
220 | def infer_skip_channels(self):
221 |
222 | """ Getting the number of channels at skip connections and at the output of the encoder. """
223 |
224 | x = torch.zeros(1, 3, 224, 224)
225 | has_fullres_features = self.backbone_name.startswith('vgg') or self.backbone_name == 'unet_encoder'
226 |         channels = [] if has_fullres_features else [0]  # only VGG-style and unet_encoder backbones keep full-resolution features
227 |
228 | # forward run in backbone to count channels (dirty solution but works for *any* Module)
229 | for name, child in self.backbone.named_children():
230 | x = child(x)
231 | if name in self.shortcut_features:
232 | channels.append(x.shape[1])
233 | if name == self.bb_out_name:
234 | out_channels = x.shape[1]
235 | break
236 | return channels, out_channels
237 |
238 | def get_pretrained_parameters(self):
239 | for name, param in self.backbone.named_parameters():
240 | if not (self.replaced_conv1 and name == 'conv1.weight'):
241 | yield param
242 |
243 | def get_random_initialized_parameters(self):
244 | pretrained_param_names = set()
245 | for name, param in self.backbone.named_parameters():
246 | if not (self.replaced_conv1 and name == 'conv1.weight'):
247 | pretrained_param_names.add('backbone.{}'.format(name))
248 |
249 | for name, param in self.named_parameters():
250 | if name not in pretrained_param_names:
251 | yield param
252 |
253 |
254 | # if __name__ == "__main__":
255 |
256 | # # simple test run
257 | # net = Unet(backbone_name='resnet18')
258 |
259 | # criterion = nn.MSELoss()
260 | # optimizer = torch.optim.Adam(net.parameters())
261 | # print('Network initialized. Running a test batch.')
262 | # for _ in range(1):
263 | # with torch.set_grad_enabled(True):
264 | # batch = torch.empty(1, 3, 224, 224).normal_()
265 | # targets = torch.empty(1, 21, 224, 224).normal_()
266 |
267 | # out = net(batch)
268 | # loss = criterion(out, targets)
269 | # loss.backward()
270 | # optimizer.step()
271 | # print(out.shape)
272 |
273 | # print('fasza.')
274 |
--------------------------------------------------------------------------------
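
`infer_skip_channels` above pushes a dummy tensor through the backbone and records the channel count at each skip point; the leading 0 stands in for the `None` placeholder, since ResNets have no full-resolution skip. Expected values for a resnet18 backbone (a sketch; the `models.model_unet` import path is an assumption about this repo's layout):

```
# Expected skip channels for resnet18, following torchvision's layer widths.
from models.model_unet import Unet  # assumed import path

net = Unet(backbone_name='resnet18', pretrained=False)
print(net.shortcut_features)      # [None, 'relu', 'layer1', 'layer2', 'layer3']
print(net.infer_skip_channels())  # ([0, 64, 64, 128, 256], 512)
```
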
/models/pspnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from models import resnet
6 | from torchvision import models
7 | from base.base_model import BaseModel
8 | from hhl_utils.helpers import initialize_weights, set_trainable
9 | from itertools import chain
10 |
11 | class _PSPModule(nn.Module):
12 | def __init__(self, in_channels, bin_sizes, norm_layer):
13 | super(_PSPModule, self).__init__()
14 | out_channels = in_channels // len(bin_sizes)
15 | self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer)
16 | for b_s in bin_sizes])
17 | self.bottleneck = nn.Sequential(
18 | nn.Conv2d(in_channels+(out_channels * len(bin_sizes)), out_channels,
19 | kernel_size=3, padding=1, bias=False),
20 | norm_layer(out_channels),
21 | nn.ReLU(inplace=True),
22 | nn.Dropout2d(0.1)
23 | )
24 |
25 | def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer):
26 | prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
27 | conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
28 | bn = norm_layer(out_channels)
29 | relu = nn.ReLU(inplace=True)
30 | return nn.Sequential(prior, conv, bn, relu)
31 |
32 | def forward(self, features):
33 | h, w = features.size()[2], features.size()[3]
34 | pyramids = [features]
35 | pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear',
36 | align_corners=True) for stage in self.stages])
37 | output = self.bottleneck(torch.cat(pyramids, dim=1))
38 | return output
39 |
40 |
41 | class PSPNet(BaseModel):
42 |     def __init__(self, num_classes, in_channels=3, backbone='resnet152', pretrained=False, use_aux=True, freeze_bn=False, freeze_backbone=False):  # was pretrained=True; changed by hhl 2019-10-20
43 | super(PSPNet, self).__init__()
44 | # TODO: Use synch batchnorm
45 | norm_layer = nn.BatchNorm2d
46 | model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer, )
47 | m_out_sz = model.fc.in_features
48 | self.use_aux = use_aux
49 |
50 | self.initial = nn.Sequential(*list(model.children())[:4])
51 | if in_channels != 3:
52 | self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
53 | self.initial = nn.Sequential(*self.initial)
54 |
55 | self.layer1 = model.layer1
56 | self.layer2 = model.layer2
57 | self.layer3 = model.layer3
58 | self.layer4 = model.layer4
59 |
60 | self.master_branch = nn.Sequential(
61 | _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer),
62 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)
63 | )
64 |
65 | self.auxiliary_branch = nn.Sequential(
66 | nn.Conv2d(m_out_sz//2, m_out_sz//4, kernel_size=3, padding=1, bias=False),
67 | norm_layer(m_out_sz//4),
68 | nn.ReLU(inplace=True),
69 | nn.Dropout2d(0.1),
70 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)
71 | )
72 |
73 | initialize_weights(self.master_branch, self.auxiliary_branch)
74 | if freeze_bn: self.freeze_bn()
75 | if freeze_backbone:
76 | set_trainable([self.initial, self.layer1, self.layer2, self.layer3, self.layer4], False)
77 |
78 | def forward(self, x):
79 | input_size = (x.size()[2], x.size()[3])
80 | x = self.initial(x)
81 | x = self.layer1(x)
82 | x = self.layer2(x)
83 | x_aux = self.layer3(x)
84 | x = self.layer4(x_aux)
85 |
86 | output = self.master_branch(x)
87 |         output = F.interpolate(output, size=input_size, mode='bilinear', align_corners=False)  # align_corners=False is the default for mode='bilinear' since PyTorch 0.4.0
88 | output = output[:, :, :input_size[0], :input_size[1]]
89 |
90 | if self.training and self.use_aux:
91 | aux = self.auxiliary_branch(x_aux)
92 |             aux = F.interpolate(aux, size=input_size, mode='bilinear', align_corners=False)  # align_corners=False is the default for mode='bilinear' since PyTorch 0.4.0
93 | aux = aux[:, :, :input_size[0], :input_size[1]]
94 | return output, aux
95 | return output
96 |
97 | def get_backbone_params(self):
98 | return chain(self.initial.parameters(), self.layer1.parameters(), self.layer2.parameters(),
99 | self.layer3.parameters(), self.layer4.parameters())
100 |
101 | def get_decoder_params(self):
102 | return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters())
103 |
104 | def freeze_bn(self):
105 | for module in self.modules():
106 | if isinstance(module, nn.BatchNorm2d): module.eval()
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 | ## PSP with dense net as the backbone
117 |
118 | class PSPDenseNet(BaseModel):
119 |     def __init__(self, num_classes, in_channels=3, backbone='densenet201', pretrained=False, use_aux=True, freeze_bn=False, **_):  # was pretrained=True; changed by hhl 2019-10-20
120 | super(PSPDenseNet, self).__init__()
121 | self.use_aux = use_aux
122 | model = getattr(models, backbone)(pretrained)
123 | m_out_sz = model.classifier.in_features
124 | aux_out_sz = model.features.transition3.conv.out_channels
125 |
126 | if not pretrained or in_channels != 3:
127 | # If we're training from scratch, better to use 3x3 convs
128 | block0 = [nn.Conv2d(in_channels, 64, 3, stride=2, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)]
129 | block0.extend(
130 | [nn.Conv2d(64, 64, 3, bias=False), nn.BatchNorm2d(64), nn.ReLU(inplace=True)] * 2
131 | )
132 | self.block0 = nn.Sequential(
133 | *block0,
134 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
135 | )
136 | initialize_weights(self.block0)
137 | else:
138 | self.block0 = nn.Sequential(*list(model.features.children())[:4])
139 |
140 | self.block1 = model.features.denseblock1
141 | self.block2 = model.features.denseblock2
142 | self.block3 = model.features.denseblock3
143 | self.block4 = model.features.denseblock4
144 |
145 | self.transition1 = model.features.transition1
146 | # No pooling
147 | self.transition2 = nn.Sequential(
148 | *list(model.features.transition2.children())[:-1])
149 | self.transition3 = nn.Sequential(
150 | *list(model.features.transition3.children())[:-1])
151 |
152 | for n, m in self.block3.named_modules():
153 | if 'conv2' in n:
154 | m.dilation, m.padding = (2,2), (2,2)
155 | for n, m in self.block4.named_modules():
156 | if 'conv2' in n:
157 | m.dilation, m.padding = (4,4), (4,4)
158 |
159 | self.master_branch = nn.Sequential(
160 | _PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=nn.BatchNorm2d),
161 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)
162 | )
163 |
164 | self.auxiliary_branch = nn.Sequential(
165 | nn.Conv2d(aux_out_sz, m_out_sz//4, kernel_size=3, padding=1, bias=False),
166 | nn.BatchNorm2d(m_out_sz//4),
167 | nn.ReLU(inplace=True),
168 | nn.Dropout2d(0.1),
169 | nn.Conv2d(m_out_sz//4, num_classes, kernel_size=1)
170 | )
171 |
172 | initialize_weights(self.master_branch, self.auxiliary_branch)
173 | if freeze_bn: self.freeze_bn()
174 |
175 | def forward(self, x):
176 | input_size = (x.size()[2], x.size()[3])
177 |
178 | x = self.block0(x)
179 | x = self.block1(x)
180 | x = self.transition1(x)
181 | x = self.block2(x)
182 | x = self.transition2(x)
183 | x = self.block3(x)
184 | x_aux = self.transition3(x)
185 | x = self.block4(x_aux)
186 |
187 | output = self.master_branch(x)
188 |         output = F.interpolate(output, size=input_size, mode='bilinear', align_corners=False)  # align_corners=False is the default for mode='bilinear' since PyTorch 0.4.0
189 |
190 | if self.training and self.use_aux:
191 | aux = self.auxiliary_branch(x_aux)
192 |             aux = F.interpolate(aux, size=input_size, mode='bilinear', align_corners=False)  # align_corners=False is the default for mode='bilinear' since PyTorch 0.4.0
193 | return output, aux
194 | return output
195 |
196 | def get_backbone_params(self):
197 | return chain(self.block0.parameters(), self.block1.parameters(), self.block2.parameters(),
198 | self.block3.parameters(), self.transition1.parameters(), self.transition2.parameters(),
199 | self.transition3.parameters())
200 |
201 | def get_decoder_params(self):
202 | return chain(self.master_branch.parameters(), self.auxiliary_branch.parameters())
203 |
204 | def freeze_bn(self):
205 | for module in self.modules():
206 | if isinstance(module, nn.BatchNorm2d): module.eval()
--------------------------------------------------------------------------------
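
`_PSPModule` above pools the backbone output into 1x1, 2x2, 3x3, and 6x6 bins, projects each to `in_channels // len(bin_sizes)` channels, upsamples back, and concatenates everything with the input before the bottleneck. The channel bookkeeping for `in_channels=2048`, as a sketch:

```
# With in_channels = 2048 and four bins, each stage emits 2048 // 4 = 512
# channels, so the concatenation feeds 2048 + 4*512 = 4096 channels into the
# 3x3 bottleneck that reduces back to 512.
import torch
import torch.nn as nn
import torch.nn.functional as F

feats = torch.randn(1, 2048, 30, 30)
pyramids = [feats]
for b in (1, 2, 3, 6):
    stage = nn.Sequential(nn.AdaptiveAvgPool2d(b),
                          nn.Conv2d(2048, 512, kernel_size=1, bias=False))
    pyramids.append(F.interpolate(stage(feats), size=(30, 30),
                                  mode='bilinear', align_corners=True))
print(torch.cat(pyramids, dim=1).shape)  # torch.Size([1, 4096, 30, 30])
```
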
/models/segnet.py:
--------------------------------------------------------------------------------
1 | from base.base_model import BaseModel
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import models
6 | from itertools import chain
7 | from math import ceil
8 |
9 | class SegNet(BaseModel):
10 |     def __init__(self, num_classes, in_channels=3, pretrained=False, freeze_bn=False, **_):  # was pretrained=True; changed by hhl 2019-10-20
11 |         super(SegNet, self).__init__()
12 |         vgg_bn = models.vgg16_bn(pretrained=pretrained)
13 | encoder = list(vgg_bn.features.children())
14 |
15 |         # Adjust the number of input channels
16 |         if in_channels != 3:
17 |             encoder[0] = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1)
18 |
19 | # Encoder, VGG without any maxpooling
20 | self.stage1_encoder = nn.Sequential(*encoder[:6])
21 | self.stage2_encoder = nn.Sequential(*encoder[7:13])
22 | self.stage3_encoder = nn.Sequential(*encoder[14:23])
23 | self.stage4_encoder = nn.Sequential(*encoder[24:33])
24 | self.stage5_encoder = nn.Sequential(*encoder[34:-1])
25 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
26 |
27 | # Decoder, same as the encoder but reversed, maxpool will not be used
28 | decoder = encoder
29 | decoder = [i for i in list(reversed(decoder)) if not isinstance(i, nn.MaxPool2d)]
30 | # Replace the last conv layer
31 | decoder[-1] = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
32 |         # Reversing also reverses each conv->BatchNorm->ReLU triple; restore the order
33 | decoder = [item for i in range(0, len(decoder), 3) for item in decoder[i:i+3][::-1]]
34 | # Replace some conv layers & batchN after them
35 | for i, module in enumerate(decoder):
36 | if isinstance(module, nn.Conv2d):
37 | if module.in_channels != module.out_channels:
38 | decoder[i+1] = nn.BatchNorm2d(module.in_channels)
39 | decoder[i] = nn.Conv2d(module.out_channels, module.in_channels, kernel_size=3, stride=1, padding=1)
40 |
41 | self.stage1_decoder = nn.Sequential(*decoder[0:9])
42 | self.stage2_decoder = nn.Sequential(*decoder[9:18])
43 | self.stage3_decoder = nn.Sequential(*decoder[18:27])
44 | self.stage4_decoder = nn.Sequential(*decoder[27:33])
45 | self.stage5_decoder = nn.Sequential(*decoder[33:],
46 | nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)
47 | )
48 | self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
49 |
50 | self._initialize_weights(self.stage1_decoder, self.stage2_decoder, self.stage3_decoder,
51 | self.stage4_decoder, self.stage5_decoder)
52 | if freeze_bn: self.freeze_bn()
53 |
54 | def _initialize_weights(self, *stages):
55 | for modules in stages:
56 | for module in modules.modules():
57 | if isinstance(module, nn.Conv2d):
58 | nn.init.kaiming_normal_(module.weight)
59 | if module.bias is not None:
60 | module.bias.data.zero_()
61 | elif isinstance(module, nn.BatchNorm2d):
62 | module.weight.data.fill_(1)
63 | module.bias.data.zero_()
64 |
65 | def forward(self, x):
66 | # Encoder
67 | x = self.stage1_encoder(x)
68 | x1_size = x.size()
69 | x, indices1 = self.pool(x)
70 |
71 | x = self.stage2_encoder(x)
72 | x2_size = x.size()
73 | x, indices2 = self.pool(x)
74 |
75 | x = self.stage3_encoder(x)
76 | x3_size = x.size()
77 | x, indices3 = self.pool(x)
78 |
79 | x = self.stage4_encoder(x)
80 | x4_size = x.size()
81 | x, indices4 = self.pool(x)
82 |
83 | x = self.stage5_encoder(x)
84 | x5_size = x.size()
85 | x, indices5 = self.pool(x)
86 |
87 | # Decoder
88 | x = self.unpool(x, indices=indices5, output_size=x5_size)
89 | x = self.stage1_decoder(x)
90 |
91 | x = self.unpool(x, indices=indices4, output_size=x4_size)
92 | x = self.stage2_decoder(x)
93 |
94 | x = self.unpool(x, indices=indices3, output_size=x3_size)
95 | x = self.stage3_decoder(x)
96 |
97 | x = self.unpool(x, indices=indices2, output_size=x2_size)
98 | x = self.stage4_decoder(x)
99 |
100 | x = self.unpool(x, indices=indices1, output_size=x1_size)
101 | x = self.stage5_decoder(x)
102 |
103 | return x
104 |
105 | def get_backbone_params(self):
106 | return []
107 |
108 | def get_decoder_params(self):
109 | return self.parameters()
110 |
111 | def freeze_bn(self):
112 | for module in self.modules():
113 | if isinstance(module, nn.BatchNorm2d): module.eval()
114 |
115 |
116 |
117 | class DecoderBottleneck(nn.Module):
118 | def __init__(self, inchannels):
119 | super(DecoderBottleneck, self).__init__()
120 | self.conv1 = nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False)
121 | self.bn1 = nn.BatchNorm2d(inchannels//4)
122 | self.conv2 = nn.ConvTranspose2d(inchannels//4, inchannels//4, kernel_size=2, stride=2, bias=False)
123 | self.bn2 = nn.BatchNorm2d(inchannels//4)
124 | self.conv3 = nn.Conv2d(inchannels//4, inchannels//2, 1, bias=False)
125 | self.bn3 = nn.BatchNorm2d(inchannels//2)
126 | self.relu = nn.ReLU(inplace=True)
127 | self.downsample = nn.Sequential(
128 | nn.ConvTranspose2d(inchannels, inchannels//2, kernel_size=2, stride=2, bias=False),
129 | nn.BatchNorm2d(inchannels//2))
130 |
131 | def forward(self, x):
132 | out = self.conv1(x)
133 | out = self.bn1(out)
134 | out = self.relu(out)
135 | out = self.conv2(out)
136 | out = self.bn2(out)
137 | out = self.relu(out)
138 | out = self.conv3(out)
139 | out = self.bn3(out)
140 |
141 | identity = self.downsample(x)
142 | out += identity
143 | out = self.relu(out)
144 | return out
145 |
146 | class LastBottleneck(nn.Module):
147 | def __init__(self, inchannels):
148 | super(LastBottleneck, self).__init__()
149 | self.conv1 = nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False)
150 | self.bn1 = nn.BatchNorm2d(inchannels//4)
151 | self.conv2 = nn.Conv2d(inchannels//4, inchannels//4, kernel_size=3, padding=1, bias=False)
152 | self.bn2 = nn.BatchNorm2d(inchannels//4)
153 | self.conv3 = nn.Conv2d(inchannels//4, inchannels//4, 1, bias=False)
154 | self.bn3 = nn.BatchNorm2d(inchannels//4)
155 | self.relu = nn.ReLU(inplace=True)
156 | self.downsample = nn.Sequential(
157 | nn.Conv2d(inchannels, inchannels//4, kernel_size=1, bias=False),
158 | nn.BatchNorm2d(inchannels//4))
159 |
160 | def forward(self, x):
161 | out = self.conv1(x)
162 | out = self.bn1(out)
163 | out = self.relu(out)
164 | out = self.conv2(out)
165 | out = self.bn2(out)
166 | out = self.relu(out)
167 | out = self.conv3(out)
168 | out = self.bn3(out)
169 |
170 | identity = self.downsample(x)
171 | out += identity
172 | out = self.relu(out)
173 | return out
174 |
175 | class SegResNet(BaseModel):
176 | def __init__(self, num_classes, in_channels=3, pretrained=True, freeze_bn=False, **_):
177 | super(SegResNet, self).__init__()
178 | resnet50 = models.resnet50(pretrained=pretrained)
179 | encoder = list(resnet50.children())
180 | if in_channels != 3:
181 |             encoder[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)  # fix: replace the stem conv itself (assigning a module to .in_channels was a bug)
182 | encoder[3].return_indices = True
183 |
184 | # Encoder
185 | self.first_conv = nn.Sequential(*encoder[:4])
186 | resnet50_blocks = list(resnet50.children())[4:-2]
187 | self.encoder = nn.Sequential(*resnet50_blocks)
188 |
189 | # Decoder
190 | resnet50_untrained = models.resnet50(pretrained=False)
191 | resnet50_blocks = list(resnet50_untrained.children())[4:-2][::-1]
192 | decoder = []
193 | channels = (2048, 1024, 512)
194 | for i, block in enumerate(resnet50_blocks[:-1]):
195 | new_block = list(block.children())[::-1][:-1]
196 | decoder.append(nn.Sequential(*new_block, DecoderBottleneck(channels[i])))
197 | new_block = list(resnet50_blocks[-1].children())[::-1][:-1]
198 | decoder.append(nn.Sequential(*new_block, LastBottleneck(256)))
199 |
200 | self.decoder = nn.Sequential(*decoder)
201 | self.last_conv = nn.Sequential(
202 | nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2, bias=False),
203 | nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1)
204 | )
205 | if freeze_bn: self.freeze_bn()
206 |
207 | def forward(self, x):
208 | inputsize = x.size()
209 |
210 | # Encoder
211 | x, indices = self.first_conv(x)
212 | x = self.encoder(x)
213 |
214 | # Decoder
215 | x = self.decoder(x)
216 | h_diff = ceil((x.size()[2] - indices.size()[2]) / 2)
217 | w_diff = ceil((x.size()[3] - indices.size()[3]) / 2)
218 | if indices.size()[2] % 2 == 1:
219 | x = x[:, :, h_diff:x.size()[2]-(h_diff-1), w_diff: x.size()[3]-(w_diff-1)]
220 | else:
221 | x = x[:, :, h_diff:x.size()[2]-h_diff, w_diff: x.size()[3]-w_diff]
222 |
223 | x = F.max_unpool2d(x, indices, kernel_size=2, stride=2)
224 | x = self.last_conv(x)
225 |
226 | if inputsize != x.size():
227 | h_diff = (x.size()[2] - inputsize[2]) // 2
228 | w_diff = (x.size()[3] - inputsize[3]) // 2
229 | x = x[:, :, h_diff:x.size()[2]-h_diff, w_diff: x.size()[3]-w_diff]
230 | if h_diff % 2 != 0: x = x[:, :, :-1, :]
231 | if w_diff % 2 != 0: x = x[:, :, :, :-1]
232 |
233 | return x
234 |
235 | def get_backbone_params(self):
236 | return chain(self.first_conv.parameters(), self.encoder.parameters())
237 |
238 | def get_decoder_params(self):
239 | return chain(self.decoder.parameters(), self.last_conv.parameters())
240 |
241 | def freeze_bn(self):
242 | for module in self.modules():
243 | if isinstance(module, nn.BatchNorm2d): module.eval()
244 |
245 |
246 |
--------------------------------------------------------------------------------
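
A quick smoke test for the model above (a hedged sketch, not part of the repo: `pretrained=False` avoids a weight download, and the import path is assumed from the repo layout; `SegNet`'s constructor sits above this excerpt, so only `SegResNet`'s visible signature is exercised):

```
import torch
from models.segnet import SegResNet

model = SegResNet(num_classes=2, pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 2, 224, 224])
```
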
/models/unet.py:
--------------------------------------------------------------------------------
1 | #from base.base_model import BaseModel
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from itertools import chain
6 |
7 |
8 | class encoder(nn.Module):
9 | def __init__(self, in_channels, out_channels):
10 | super(encoder, self).__init__()
11 | self.down_conv = nn.Sequential(
12 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
13 | nn.BatchNorm2d(out_channels),
14 | nn.ReLU(inplace=True),
15 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
16 | nn.BatchNorm2d(out_channels),
17 | nn.ReLU(inplace=True),
18 | )
19 | self.pool = nn.MaxPool2d(kernel_size=2, ceil_mode=True)
20 |
21 | def forward(self, x):
22 | x = self.down_conv(x)
23 | x_pooled = self.pool(x)
24 | return x, x_pooled
25 |
26 | #nn.Upsample(scale_factor=2)
27 | class decoder(nn.Module):
28 | def __init__(self, in_channels, out_channels):
29 | super(decoder, self).__init__()
30 | self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
31 | self.up_conv = nn.Sequential(
32 | nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
33 | nn.BatchNorm2d(out_channels),
34 | nn.ReLU(inplace=True),
35 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
36 | nn.BatchNorm2d(out_channels),
37 | nn.ReLU(inplace=True),
38 | )
39 |
40 | def forward(self, x_copy, x):
41 | x = self.up(x)
42 |         # Padding in case the incoming volumes are of different sizes
43 | diffY = x_copy.size()[2] - x.size()[2]
44 | diffX = x_copy.size()[3] - x.size()[3]
45 | x = F.pad(x, (diffX // 2, diffX - diffX // 2,
46 | diffY // 2, diffY - diffY // 2))
47 | # Concatenate
48 | x = torch.cat([x_copy, x], dim=1)
49 | x = self.up_conv(x)
50 | return x
51 |
52 |
53 | class UNet(nn.Module):#BaseModel
54 | def __init__(self, num_classes, in_channels=3, freeze_bn=False, **_):
55 | super(UNet, self).__init__()
56 | self.down1 = encoder(in_channels, 64)
57 | self.down2 = encoder(64, 128)
58 | self.down3 = encoder(128, 256)
59 | self.down4 = encoder(256, 512)
60 | self.middle_conv = nn.Sequential(
61 | nn.Conv2d(512, 1024, kernel_size=3, padding=1),
62 | nn.BatchNorm2d(1024),
63 | nn.ReLU(inplace=True),
64 | nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
65 | nn.BatchNorm2d(1024),
66 | nn.ReLU(inplace=True),
67 | )
68 | self.up1 = decoder(1024, 512)
69 | self.up2 = decoder(512, 256)
70 | self.up3 = decoder(256, 128)
71 | self.up4 = decoder(128, 64)
72 | self.up = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
73 | self.beforefinal2_conv = nn.Conv2d(128, num_classes, kernel_size=1) # 128
74 |
75 | self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
76 | self._initialize_weights()
77 | if freeze_bn:
78 | self.freeze_bn()
79 |
80 | def _initialize_weights(self):
81 | for module in self.modules():
82 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
83 | nn.init.kaiming_normal_(module.weight)
84 | if module.bias is not None:
85 | module.bias.data.zero_()
86 | elif isinstance(module, nn.BatchNorm2d):
87 | module.weight.data.fill_(1)
88 | module.bias.data.zero_()
89 |
90 | def forward(self, x):
91 | x1, x = self.down1(x)
92 | x2, x = self.down2(x)
93 | x3, x = self.down3(x)
94 | x4, x = self.down4(x)
95 | x = self.middle_conv(x)
96 | x = self.up1(x4, x)
97 | x = self.up2(x3, x)
98 | x = self.up3(x2, x)
99 | #x_beforefinal2_temp = self.up(x)
100 | #x_beforefinal2 = self.beforefinal2_conv(x_beforefinal2_temp)
101 |
102 | x = self.up4(x1, x)
103 | # x_beforefinal2 = self.beforefinal2_conv(x)
104 |
105 | x_final = self.final_conv(x)
106 | return x_final#, x_beforefinal2
107 |
108 | def get_backbone_params(self):
109 | # There is no backbone for unet, all the parameters are trained from scratch
110 | return []
111 |
112 | def get_decoder_params(self):
113 | return self.parameters()
114 |
115 | def freeze_bn(self):
116 | for module in self.modules():
117 | if isinstance(module, nn.BatchNorm2d): module.eval()
--------------------------------------------------------------------------------
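
A minimal forward-pass sketch for `UNet` (hedged: the import path is assumed from the repo layout). The padded 3x3 convolutions plus the `F.pad` alignment in `decoder.forward` keep the output at the input resolution:

```
import torch
from models.unet import UNet

net = UNet(num_classes=2)
net.eval()
with torch.no_grad():
    y = net(torch.randn(1, 3, 96, 96))
print(y.shape)  # torch.Size([1, 2, 96, 96])
```
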
/postproc_other.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import cv2
4 | import numpy as np
5 | from scipy.ndimage import filters, measurements
6 | from scipy.ndimage.morphology import (
7 | binary_erosion,
8 | binary_dilation,
9 | binary_fill_holes,
10 | distance_transform_cdt,
11 | distance_transform_edt)
12 | from skimage.morphology import remove_small_objects#, watershed
13 | from skimage.segmentation import watershed
14 |
15 | def process(pred, model_mode, min_size = 10, ws=True):
16 | def gen_inst_dst_map(ann):
17 | shape = ann.shape[:2] # HW
18 | nuc_list = list(np.unique(ann))
19 | nuc_list.remove(0) # 0 is background
20 |
21 | canvas = np.zeros(shape, dtype=np.uint8)
22 | for nuc_id in nuc_list:
23 | nuc_map = np.copy(ann == nuc_id)
24 | nuc_dst = distance_transform_edt(nuc_map)
25 | nuc_dst = 255 * (nuc_dst / np.amax(nuc_dst))
26 | canvas += nuc_dst.astype('uint8')
27 | return canvas
28 |
29 | if model_mode != 'dcan':
30 | assert len(pred.shape) == 2, 'Prediction shape is not HW'
31 | pred[pred > 0.5] = 1
32 | pred[pred <= 0.5] = 0
33 |
34 | # ! refactor these
35 | ws = False if model_mode == 'unet' or model_mode == 'micronet' else ws
36 | if ws:
37 | dist = measurements.label(pred)[0]
38 | dist = gen_inst_dst_map(dist)
39 | marker = np.copy(dist)
40 | marker[marker <= 125] = 0
41 | marker[marker > 125] = 1
42 | marker = binary_fill_holes(marker)
43 | marker = binary_erosion(marker, iterations=1)
44 | marker = measurements.label(marker)[0]
45 |
46 | marker = remove_small_objects(marker, min_size=min_size)
47 | pred = watershed(-dist, marker, mask=pred)
48 | pred = remove_small_objects(pred, min_size=min_size)
49 | #print('============================ ws = True ============================ ')
50 | else:
51 | pred = binary_fill_holes(pred)
52 | pred = measurements.label(pred)[0]
53 | pred = remove_small_objects(pred, min_size=min_size)
54 | print('binary_fill_holes(pred), measurements.label(pred)[0], remove_small_objects(pred, min_size=10)')
55 |
56 | if model_mode == 'micronet':
57 | # * dilate with same kernel size used for erosion during training
58 | kernel = np.array([[0, 1, 0],
59 | [1, 1, 1],
60 | [0, 1, 0]], np.uint8)
61 |
62 | canvas = np.zeros([pred.shape[0], pred.shape[1]])
63 | for inst_id in range(1, np.max(pred)+1):
64 | inst_map = np.array(pred == inst_id, dtype=np.uint8)
65 | inst_map = cv2.dilate(inst_map, kernel, iterations=1)
66 | inst_map = binary_fill_holes(inst_map)
67 | canvas[inst_map > 0] = inst_id
68 | pred = canvas
69 | else:
70 | assert (pred.shape[2]) == 2, 'Prediction should have contour and blb'
71 | blb = pred[...,0]
72 | blb = np.squeeze(blb)
73 | cnt = pred[...,1]
74 | cnt = np.squeeze(cnt)
75 |
76 | pred = blb - cnt # NOTE
77 | pred[pred > 0.3] = 1 # Kumar 0.3, UHCW 0.3
78 | pred[pred <= 0.3] = 0 # CPM2017 0.1
79 | pred = measurements.label(pred)[0]
80 | pred = remove_small_objects(pred, min_size=min_size) # 20
81 | canvas = np.zeros([pred.shape[0], pred.shape[1]])
82 |
83 | k_disk = np.array([
84 | [0, 0, 0, 1, 0, 0, 0],
85 | [0, 0, 1, 1, 1, 0, 0],
86 | [0, 1, 1, 1, 1, 1, 0],
87 | [1, 1, 1, 1, 1, 1, 1],
88 | [0, 1, 1, 1, 1, 1, 0],
89 | [0, 0, 1, 1, 1, 0, 0],
90 | [0, 0, 0, 1, 0, 0, 0],
91 | ], np.uint8)
92 | for inst_id in range(1, np.max(pred)+1):
93 | inst_map = np.array(pred == inst_id, dtype=np.uint8)
94 | inst_map = cv2.dilate(inst_map, k_disk, iterations=1)
95 | inst_map = binary_fill_holes(inst_map)
96 | canvas[inst_map > 0] = inst_id
97 | pred = canvas
98 |
99 | return pred
--------------------------------------------------------------------------------
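
A hypothetical usage sketch for `process` (the probability map below is synthetic; in this repo it would come from a network's foreground head). With `model_mode='unet'` the watershed branch is disabled, so the output is a connected-component labeling after hole filling and small-object removal:

```
import numpy as np
from postproc_other import process

prob = np.zeros((128, 128), dtype=np.float32)
prob[30:60, 30:60] = 0.9   # one synthetic "nucleus"
inst = process(np.copy(prob), model_mode='unet', min_size=10)
print(int(inst.max()))     # 1 -> a single labeled instance
```
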
/stats_utils.py:
--------------------------------------------------------------------------------
1 | # from HoverNet
2 | import warnings
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment
5 |
6 |
7 | def get_fast_aji(true, pred):
8 | """
9 |     AJI version distributed by MoNuSeg; it has no permutation problem but suffers from
10 |     over-penalisation similar to DICE2.
11 |     Fast computation requires instance IDs in contiguous ordering, i.e. [1, 2, 3, 4]
12 |     not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
13 |     effect on the result.
14 | """
15 | true = np.copy(true) # ? do we need this
16 | pred = np.copy(pred)
17 | true_id_list = list(np.unique(true))
18 | pred_id_list = list(np.unique(pred))
19 |
20 | true_masks = [None, ]
21 | for t in true_id_list[1:]:
22 | t_mask = np.array(true == t, np.uint8)
23 | true_masks.append(t_mask)
24 |
25 | pred_masks = [None, ]
26 | for p in pred_id_list[1:]:
27 | p_mask = np.array(pred == p, np.uint8)
28 | pred_masks.append(p_mask)
29 |
30 | # prefill with value
31 | pairwise_inter = np.zeros([len(true_id_list) - 1,
32 | len(pred_id_list) - 1], dtype=np.float64)
33 | pairwise_union = np.zeros([len(true_id_list) - 1,
34 | len(pred_id_list) - 1], dtype=np.float64)
35 |     # over-detection (false-positive pixels)
36 | pairwise_FP = np.zeros([len(true_id_list) - 1,
37 | len(pred_id_list) - 1], dtype=np.float64)
38 |     # missed detection (false-negative pixels)
39 | pairwise_FN = np.zeros([len(true_id_list) - 1,
40 | len(pred_id_list) - 1], dtype=np.float64)
41 |
42 | # caching pairwise
43 | for true_id in true_id_list[1:]: # 0-th is background
44 | t_mask = true_masks[true_id]
45 | pred_true_overlap = pred[t_mask > 0]
46 | pred_true_overlap_id = np.unique(pred_true_overlap)
47 | pred_true_overlap_id = list(pred_true_overlap_id)
48 | for pred_id in pred_true_overlap_id:
49 | if pred_id == 0: # ignore
50 |                 continue  # overlapping background
51 | p_mask = pred_masks[pred_id]
52 | total = (t_mask + p_mask).sum()
53 | inter = (t_mask * p_mask).sum()
54 | pairwise_inter[true_id - 1, pred_id - 1] = inter
55 | pairwise_union[true_id - 1, pred_id - 1] = total - inter
56 |
57 | pairwise_FP[true_id - 1, pred_id - 1] = p_mask.sum() - inter
58 | pairwise_FN[true_id - 1, pred_id - 1] = t_mask.sum() - inter
59 | #
60 | pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
61 |     # the pred that gives the highest IoU for each true; we don't care
62 |     # about a pred instance being reused multiple times
63 | paired_pred = np.argmax(pairwise_iou, axis=1)
64 | pairwise_iou = np.max(pairwise_iou, axis=1)
65 |     # exclude those that have no intersection
66 | paired_true = np.nonzero(pairwise_iou > 0.0)[0]
67 | paired_pred = paired_pred[paired_true]
68 | # print(paired_true.shape, paired_pred.shape)
69 |
70 | overall_inter = (pairwise_inter[paired_true, paired_pred]).sum()
71 | overall_union = (pairwise_union[paired_true, paired_pred]).sum()
72 |
73 | overall_FP = (pairwise_FP[paired_true, paired_pred]).sum()
74 | overall_FN = (pairwise_FN[paired_true, paired_pred]).sum()
75 |
76 |
77 | #
78 | paired_true = (list(paired_true + 1)) # index to instance ID
79 | paired_pred = (list(paired_pred + 1))
80 | # add all unpaired GT and Prediction into the union
81 | unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
82 | unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
83 |
84 | less_pred = 0
85 | more_pred = 0
86 |
87 | for true_id in unpaired_true:
88 | less_pred += true_masks[true_id].sum()
89 | overall_union += true_masks[true_id].sum()
90 | for pred_id in unpaired_pred:
91 | more_pred += pred_masks[pred_id].sum()
92 | overall_union += pred_masks[pred_id].sum()
93 | #
94 | aji_score = overall_inter / overall_union
95 | fm = overall_union - overall_inter
96 | print('\t [ana_FP = {:.4f}, ana_FN = {:.4f}, ana_less = {:.4f}, ana_more = {:.4f}]'.format((overall_FP / fm),(overall_FN / fm),(less_pred / fm),(more_pred / fm)))
97 |
98 | return aji_score, overall_FP / fm, overall_FN / fm, less_pred / fm, more_pred / fm
99 |
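# Worked micro-example (added note, not in the original HoverNet code): with one
# GT instance predicted exactly plus one spurious 1-pixel prediction,
#   true = np.zeros((4, 4), np.int32); true[:2, :2] = 1
#   pred = np.copy(true); pred[3, 3] = 2
# the matched intersection and union are both 4, and the unpaired prediction adds
# 1 pixel to the union, so get_fast_aji(true, pred)[0] == 4 / 5 == 0.8.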
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | #####
108 | def get_fast_aji_plus(true, pred):
109 | """
110 |     AJI+, an AJI version with maximal unique pairing to obtain the overall intersection.
111 |     Every prediction instance is paired with at most 1 GT instance (1-to-1 mapping), unlike AJI,
112 |     where a prediction instance can be paired against many GT instances (1-to-many).
113 |     Remaining unpaired GT and prediction instances are added to the overall union.
114 |     The 1-to-1 mapping prevents AJI's over-penalisation from happening.
115 |     Fast computation requires instance IDs in contiguous ordering, i.e. [1, 2, 3, 4]
116 |     not [2, 3, 6, 10]. Please call `remap_label` beforehand; the `by_size` flag has no
117 |     effect on the result.
118 | """
119 | true = np.copy(true) # ? do we need this
120 | pred = np.copy(pred)
121 | true_id_list = list(np.unique(true))
122 | pred_id_list = list(np.unique(pred))
123 |
124 | true_masks = [None, ]
125 | for t in true_id_list[1:]:
126 | t_mask = np.array(true == t, np.uint8)
127 | true_masks.append(t_mask)
128 |
129 | pred_masks = [None, ]
130 | for p in pred_id_list[1:]:
131 | p_mask = np.array(pred == p, np.uint8)
132 | pred_masks.append(p_mask)
133 |
134 | # prefill with value
135 | pairwise_inter = np.zeros([len(true_id_list) - 1,
136 | len(pred_id_list) - 1], dtype=np.float64)
137 | pairwise_union = np.zeros([len(true_id_list) - 1,
138 | len(pred_id_list) - 1], dtype=np.float64)
139 |
140 | # caching pairwise
141 | for true_id in true_id_list[1:]: # 0-th is background
142 | t_mask = true_masks[true_id]
143 | pred_true_overlap = pred[t_mask > 0]
144 | pred_true_overlap_id = np.unique(pred_true_overlap)
145 | pred_true_overlap_id = list(pred_true_overlap_id)
146 | for pred_id in pred_true_overlap_id:
147 | if pred_id == 0: # ignore
148 |                 continue  # overlapping background
149 | p_mask = pred_masks[pred_id]
150 | total = (t_mask + p_mask).sum()
151 | inter = (t_mask * p_mask).sum()
152 | pairwise_inter[true_id - 1, pred_id - 1] = inter
153 | pairwise_union[true_id - 1, pred_id - 1] = total - inter
154 | #
155 | pairwise_iou = pairwise_inter / (pairwise_union + 1.0e-6)
156 | #### Munkres pairing to find maximal unique pairing
157 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
158 | ### extract the paired cost and remove invalid pair
159 | paired_iou = pairwise_iou[paired_true, paired_pred]
160 |     # now select all those pairs with IoU != 0.0, i.e. those that have an intersection
161 | paired_true = paired_true[paired_iou > 0.0]
162 | paired_pred = paired_pred[paired_iou > 0.0]
163 | paired_inter = pairwise_inter[paired_true, paired_pred]
164 | paired_union = pairwise_union[paired_true, paired_pred]
165 | paired_true = (list(paired_true + 1)) # index to instance ID
166 | paired_pred = (list(paired_pred + 1))
167 | overall_inter = paired_inter.sum()
168 | overall_union = paired_union.sum()
169 | # add all unpaired GT and Prediction into the union
170 | unpaired_true = np.array([idx for idx in true_id_list[1:] if idx not in paired_true])
171 | unpaired_pred = np.array([idx for idx in pred_id_list[1:] if idx not in paired_pred])
172 | for true_id in unpaired_true:
173 | overall_union += true_masks[true_id].sum()
174 | for pred_id in unpaired_pred:
175 | overall_union += pred_masks[pred_id].sum()
176 | #
177 | aji_score = overall_inter / overall_union
178 | return aji_score
179 |
180 |
181 | #####
182 | def get_fast_pq(true, pred, match_iou=0.5):
183 | """
184 |     `match_iou` is the IoU threshold level that determines the pairing between
185 |     GT instances `p` and prediction instances `g`. `p` and `g` form a pair
186 |     if IoU > `match_iou`; moreover, each pair of `p` and `g` must be unique
187 |     (1 prediction instance to 1 GT instance mapping).
188 |     If `match_iou` < 0.5, Munkres assignment (solving minimum weight matching
189 |     in bipartite graphs) is computed to find the maximal number of unique pairings.
190 |     If `match_iou` >= 0.5, every pairing with IoU(p,g) > 0.5 is provably unique and
191 |     the number of pairs is also maximal.
192 |
193 |     Fast computation requires instance IDs in contiguous ordering,
194 |     i.e. [1, 2, 3, 4] not [2, 3, 6, 10]. Please call `remap_label` beforehand;
195 |     the `by_size` flag has no effect on the result.
196 |     Returns:
197 |         [dq, sq, pq]: measurement statistics
198 | [paired_true, paired_pred, unpaired_true, unpaired_pred]:
199 | pairing information to perform measurement
200 |
201 | """
202 |     assert match_iou >= 0.0, "Can't be negative"
203 |
204 | true = np.copy(true)
205 | pred = np.copy(pred)
206 | true_id_list = list(np.unique(true))
207 | pred_id_list = list(np.unique(pred))
208 |
209 | true_masks = [None, ]
210 | for t in true_id_list[1:]:
211 | t_mask = np.array(true == t, np.uint8)
212 | true_masks.append(t_mask)
213 |
214 | pred_masks = [None, ]
215 | for p in pred_id_list[1:]:
216 | p_mask = np.array(pred == p, np.uint8)
217 | pred_masks.append(p_mask)
218 |
219 | # prefill with value
220 | pairwise_iou = np.zeros([len(true_id_list) - 1,
221 | len(pred_id_list) - 1], dtype=np.float64)
222 |
223 | # caching pairwise iou
224 | for true_id in true_id_list[1:]: # 0-th is background
225 | t_mask = true_masks[true_id]
226 | pred_true_overlap = pred[t_mask > 0]
227 | pred_true_overlap_id = np.unique(pred_true_overlap)
228 | pred_true_overlap_id = list(pred_true_overlap_id)
229 | for pred_id in pred_true_overlap_id:
230 | if pred_id == 0: # ignore
231 |                 continue  # overlapping background
232 | p_mask = pred_masks[pred_id]
233 | total = (t_mask + p_mask).sum()
234 | inter = (t_mask * p_mask).sum()
235 | iou = inter / (total - inter)
236 | pairwise_iou[true_id - 1, pred_id - 1] = iou
237 | #
238 | if match_iou >= 0.5:
239 | paired_iou = pairwise_iou[pairwise_iou > match_iou]
240 | pairwise_iou[pairwise_iou <= match_iou] = 0.0
241 | paired_true, paired_pred = np.nonzero(pairwise_iou)
242 | paired_iou = pairwise_iou[paired_true, paired_pred]
243 | paired_true += 1 # index is instance id - 1
244 | paired_pred += 1 # hence return back to original
245 | else: # * Exhaustive maximal unique pairing
246 | #### Munkres pairing with scipy library
247 |         # the algorithm returns (row indices, matched column indices);
248 |         # if there are multiple equal costs in a row, the index of the first
249 |         # occurrence is returned, thus the pairing is guaranteed unique
250 |         # negate the IoUs so that the maximal IoU becomes the minimal cost
251 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
252 | ### extract the paired cost and remove invalid pair
253 | paired_iou = pairwise_iou[paired_true, paired_pred]
254 |
255 | # now select those above threshold level
256 | # paired with iou = 0.0 i.e no intersection => FP or FN
257 | paired_true = list(paired_true[paired_iou > match_iou] + 1)
258 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1)
259 | paired_iou = paired_iou[paired_iou > match_iou]
260 |
261 | # get the actual FP and FN
262 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
263 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
264 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred))
265 |
266 | #
267 | tp = len(paired_true)
268 | fp = len(unpaired_pred)
269 | fn = len(unpaired_true)
270 | # get the F1-score i.e DQ
271 | dq = tp / (tp + 0.5 * fp + 0.5 * fn)
272 |     # get the SQ; no pair has IoU 0, so the epsilon has no impact
273 | sq = paired_iou.sum() / (tp + 1.0e-6)
274 |
275 | return [dq, sq, dq * sq], [paired_true, paired_pred, unpaired_true, unpaired_pred]
276 |
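# Sanity check for the decomposition PQ = DQ * SQ (added note): one GT instance
# predicted perfectly and one GT instance missed gives tp=1, fp=0, fn=1, hence
# DQ = 1 / (1 + 0.5) ~= 0.667 and SQ ~= 1.0:
#   true = np.zeros((6, 6), np.int32); true[:2, :2] = 1; true[4:, 4:] = 2
#   pred = np.zeros_like(true); pred[:2, :2] = 1
#   (dq, sq, pq), _ = get_fast_pq(true, pred)   # pq ~= 0.667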
277 |
278 | #####
279 | def get_fast_dice_2(true, pred):
280 | """
281 | Ensemble dice
282 | """
283 | true = np.copy(true)
284 | pred = np.copy(pred)
285 | true_id = list(np.unique(true))
286 | pred_id = list(np.unique(pred))
287 |
288 | overall_total = 0
289 | overall_inter = 0
290 |
291 | true_masks = [np.zeros(true.shape)]
292 | for t in true_id[1:]:
293 | t_mask = np.array(true == t, np.uint8)
294 | true_masks.append(t_mask)
295 |
296 | pred_masks = [np.zeros(true.shape)]
297 | for p in pred_id[1:]:
298 | p_mask = np.array(pred == p, np.uint8)
299 | pred_masks.append(p_mask)
300 |
301 | for true_idx in range(1, len(true_id)):
302 | t_mask = true_masks[true_idx]
303 | pred_true_overlap = pred[t_mask > 0]
304 | pred_true_overlap_id = np.unique(pred_true_overlap)
305 | pred_true_overlap_id = list(pred_true_overlap_id)
306 |         try: # blindly remove background
307 | pred_true_overlap_id.remove(0)
308 | except ValueError:
309 |             pass # just means there is no background
310 | for pred_idx in pred_true_overlap_id:
311 | p_mask = pred_masks[pred_idx]
312 | total = (t_mask + p_mask).sum()
313 | inter = (t_mask * p_mask).sum()
314 | overall_total += total
315 | overall_inter += inter
316 |
317 | return 2 * overall_inter / overall_total
318 |
319 |
320 | #####
321 |
322 | #####--------------------------As pseudocode
323 | def get_dice_1(true, pred):
324 | """
325 | Traditional dice
326 | """
327 | # cast to binary 1st
328 | true = np.copy(true)
329 | pred = np.copy(pred)
330 | true[true > 0] = 1
331 | pred[pred > 0] = 1
332 | inter = true * pred
333 | denom = true + pred
334 | return 2.0 * np.sum(inter) / np.sum(denom)
335 |
336 |
337 | ####
338 | def get_dice_2(true, pred):
339 | true = np.copy(true)
340 | pred = np.copy(pred)
341 | true_id = list(np.unique(true))
342 | pred_id = list(np.unique(pred))
343 | # remove background aka id 0
344 | true_id.remove(0)
345 | pred_id.remove(0)
346 |
347 | total_markup = 0
348 | total_intersect = 0
349 | for t in true_id:
350 | t_mask = np.array(true == t, np.uint8)
351 | for p in pred_id:
352 | p_mask = np.array(pred == p, np.uint8)
353 | intersect = p_mask * t_mask
354 | if intersect.sum() > 0:
355 | total_intersect += intersect.sum()
356 | total_markup += (t_mask.sum() + p_mask.sum())
357 | return 2 * total_intersect / total_markup
358 |
359 |
360 | #####
361 | def remap_label(pred, by_size=False):
362 | """
363 |     Rename all instance IDs so that the IDs are contiguous, i.e. [0, 1, 2, 3]
364 |     not [0, 2, 4, 6]. The ordering of instances (which one comes first)
365 |     is preserved unless by_size=True, in which case the instances are reordered
366 |     so that bigger nuclei get smaller IDs
367 |     Args:
368 |         pred : the 2d array containing instances, where each instance is marked
369 |                by a non-zero integer
370 |         by_size : rename so that larger nuclei get smaller IDs (on top)
371 | """
372 | pred_id = list(np.unique(pred))
373 | pred_id.remove(0)
374 | if len(pred_id) == 0:
375 | return pred # no label
376 | if by_size:
377 | pred_size = []
378 | for inst_id in pred_id:
379 | size = (pred == inst_id).sum()
380 | pred_size.append(size)
381 | # sort the id by size in descending order
382 | pair_list = zip(pred_id, pred_size)
383 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
384 | pred_id, pred_size = zip(*pair_list)
385 |
386 | new_pred = np.zeros(pred.shape, np.int32)
387 | for idx, inst_id in enumerate(pred_id):
388 | new_pred[pred == inst_id] = idx + 1
389 | return new_pred
390 |
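# Example (added note): remap_label(np.array([[0, 2], [6, 6]])) -> [[0, 1], [2, 2]],
# giving the contiguous instance IDs that the fast metrics above rely on.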
391 |
392 | #####
393 | def pair_coordinates(setA, setB, radius):
394 | """
395 |     Use the Munkres (Kuhn-Munkres) algorithm to find the optimal
396 |     unique pairing (largest possible match) when pairing points in set B
397 |     against points in set A, using distance as the cost function
398 |     Args:
399 |         setA, setB: np.array (float32) of size Nx2 containing the XY coordinates
400 |                     of N different points
401 |         radius: valid area around a point in setA to consider
402 |                 a given coordinate in setB a candidate for match
403 |     Return:
404 |         pairing: an array of index pairs, where the point at index
405 |         pairing[0] in set A is paired with the point in set B
406 |         at index pairing[1]
407 |         unpairedA, unpairedB: remaining unpaired points in set A and set B
408 | """
409 |
410 | # * Euclidean distance as the cost matrix
411 | setA_tile = np.expand_dims(setA, axis=1)
412 | setB_tile = np.expand_dims(setB, axis=0)
413 | setA_tile = np.repeat(setA_tile, setB.shape[0], axis=1)
414 | setB_tile = np.repeat(setB_tile, setA.shape[0], axis=0)
415 | pair_distance = (setA_tile - setB_tile) ** 2
416 | # set A is row, and set B is paired against set A
417 | pair_distance = np.sqrt(np.sum(pair_distance, axis=-1))
418 |
419 | # * Munkres pairing with scipy library
420 |     # the algorithm returns (row indices, matched column indices);
421 |     # if there are multiple equal costs in a row, the index of the first
422 |     # occurrence is returned, thus the pairing is guaranteed unique
423 | indicesA, paired_indicesB = linear_sum_assignment(pair_distance)
424 |
425 | # extract the paired cost and remove instances
426 | # outside of designated radius
427 | pair_cost = pair_distance[indicesA, paired_indicesB]
428 |
429 | pairedA = indicesA[pair_cost <= radius]
430 | pairedB = paired_indicesB[pair_cost <= radius]
431 |
432 | unpairedA = [idx for idx in range(setA.shape[0]) if idx not in list(pairedA)]
433 | unpairedB = [idx for idx in range(setB.shape[0]) if idx not in list(pairedB)]
434 |
435 | pairing = np.array(list(zip(pairedA, pairedB)))
436 | unpairedA = np.array(unpairedA, dtype=np.int64)
437 | unpairedB = np.array(unpairedB, dtype=np.int64)
438 |
439 | return pairing, unpairedA, unpairedB
--------------------------------------------------------------------------------
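
Taken together, a typical single-image evaluation pass looks like the hedged sketch below (only functions from this file are used; the instance maps are synthetic stand-ins for a GT annotation and a post-processed prediction):

```
import numpy as np
from stats_utils import remap_label, get_dice_1, get_fast_aji_plus, get_fast_pq

true = np.zeros((64, 64), np.int32)
true[5:20, 5:20] = 1
true[30:50, 30:50] = 2

pred = np.copy(true)
pred[30:50, 30:50] = 0
pred[31:49, 31:49] = 2   # second instance slightly shrunken

true, pred = remap_label(true), remap_label(pred)
dice = get_dice_1(true, pred)
aji_p = get_fast_aji_plus(true, pred)
(dq, sq, pq), _ = get_fast_pq(true, pred, match_iou=0.5)
print('Dice {:.4f} | AJI+ {:.4f} | PQ {:.4f}'.format(dice, aji_p, pq))
```
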