├── .gitignore
├── README.md
├── dataloaders
│   ├── dataset.py
│   ├── my_transforms.py
│   └── prepare_data.py
├── figures
│   └── overview.png
├── model
│   ├── modelW.py
│   ├── model_WNet.py
│   ├── resnet.py
│   └── resnet1.py
├── options.py
├── requirements.txt
├── test.py
├── train.py
└── utils
    ├── accuracy.py
    ├── combine.py
    ├── divergence.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | */.idea
2 | */__pycache__
3 | */.vscode
4 | # */figures
5 | */requirements.txt
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SC-Net
2 | This is the official code for our MedIA paper:
3 |
4 | > [Nuclei Segmentation with Point Annotations from Pathology Images via Self-Supervised Learning and Co-Training](https://arxiv.org/abs/2202.08195)
5 | > Yi Lin*, Zhiyong Qu*, Hao Chen, Zhongke Gao, Yuexiang Li, Lili Xia, Kai Ma, Yefeng Zheng, Kwang-Ting Cheng
6 |
7 | ## Highlights
8 |
9 | In this work, we propose a weakly-supervised learning method for nuclei segmentation that only requires point annotations for training. The proposed method achieves label propagation in a coarse-to-fine manner as follows. First, coarse pixel-level labels are derived from the point annotations based on the Voronoi diagram and the k-means clustering method to avoid overfitting. Second, a co-training strategy with an exponential moving average method is designed to refine the incomplete supervision of the coarse labels. Third, a self-supervised visual representation learning method is tailored for nuclei segmentation of pathology images, which transforms the hematoxylin component images into the H&E stained images to gain a better understanding of the relationship between the nuclei and the cytoplasm.
10 |
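The co-training strategy relies on an exponential moving average (EMA) of network weights. A minimal sketch of a standard EMA update of this kind (PyTorch; the decay value here is illustrative, not taken from the released training configuration):

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.99):
    # EMA parameters track an exponential moving average of the current model parameters
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
```
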
11 | [comment]: <> ()
12 | ![Overview of SC-Net](figures/overview.png)
13 |
14 | (a) The pipeline of the proposed method; (b) The framework of SC-Net; (c) The process of pseudo label generation.
15 |
16 |
17 | ### Using the code
18 | Please clone the repository:
19 | ```
20 | git clone https://github.com/hust-linyi/SC-Net.git
21 | ```
22 |
23 | ### Requirements
24 | ```
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ### Data preparation
29 | #### Download
30 | 1. **MoNuSeg** [Multi-Organ Nuclei Segmentation dataset](https://monuseg.grand-challenge.org)
31 | 2. **CPM** [Computational Precision Medicine dataset](https://drive.google.com/drive/folders/1sJ4nmkif6j4s2FOGj8j6i_Ye7z9w0TfA)
32 |
33 | #### Pre-processing
34 | Please refer to [dataloaders/prepare_data.py](https://github.com/hust-linyi/SC-Net/blob/main/dataloaders/prepare_data.py) for the pre-processing of the datasets.
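
For example, with the MoNuSeg data placed under `./data/MO` (the default `dataset` in the script, which expects `images/`, `labels_instance/` and `train_val_test.json` in that folder), the point, Voronoi and cluster labels, the H-component images, and the training patches are generated by:
```
python dataloaders/prepare_data.py
```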
35 |
36 | ### Training
37 | 1. Configure your own parameters in [options.py](https://github.com/hust-linyi/SC-Net/blob/main/options.py), including the dataset path, the number of GPUs, the number of epochs, the batch size, the learning rate, etc.
38 | 2. Run the following command to train the model:
39 | ```
40 | python train.py
41 | ```
42 |
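The backbone is a W-shaped network with two heads: a segmentation output and a colour-reconstruction output used by the self-supervised branch. A quick interface check of the plain U-Net variant in [model/model_WNet.py](https://github.com/hust-linyi/SC-Net/blob/main/model/model_WNet.py) (the input size and class counts below are only illustrative):

```python
import torch
from model.model_WNet import WNet

net = WNet(n_channels=3, seg_classes=2, colour_classes=3)
seg_logits, colour = net(torch.randn(1, 3, 224, 224))
print(seg_logits.shape, colour.shape)  # torch.Size([1, 2, 224, 224]) torch.Size([1, 3, 224, 224])
```
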
43 | ### Testing
44 | Run the following command to test the model:
45 | ```
46 | python test.py
47 | ```
48 |
49 | ## Citation
50 | Please cite the paper if you use the code.
51 | ```bibtex
52 | @article{lin2023nuclei,
53 | title={Nuclei segmentation with point annotations from pathology images via self-supervised learning and co-training},
54 | author={Lin, Yi and Qu, Zhiyong and Chen, Hao and Gao, Zhongke and Li, Yuexiang and Xia, Lili and Ma, Kai and Zheng, Yefeng and Cheng, Kwang-Ting},
55 | journal={Medical Image Analysis},
56 | pages={102933},
57 | year={2023},
58 | publisher={Elsevier}
59 | }
60 | ```
61 |
--------------------------------------------------------------------------------
/dataloaders/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 |
6 | IMG_EXTENSIONS = [
7 | '.jpg', '.JPG', '.jpeg', '.JPEG',
8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
9 | ]
10 |
11 |
12 | def is_image_file(filename):
13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
14 |
15 |
16 | def img_loader(path, num_channels):
17 | if num_channels == 1:
18 | img = Image.open(path)
19 | else:
20 | img = Image.open(path).convert('RGB')
21 |
22 | return img
23 |
24 |
25 | # get the image list pairs
26 | def get_imgs_list(dir_list, post_fix=None):
27 | """
28 | :param dir_list: [img1_dir, img2_dir, ...]
29 | :param post_fix: e.g. ['label_vor.png', 'label_cluster.png',...]
30 | :return: e.g. [(img1.png, img1_label_vor.png, img1_label_cluster.png), ...]
31 | """
32 | img_list = []
33 | if len(dir_list) == 0:
34 | return img_list
35 |
36 | img_filename_list = os.listdir(dir_list[0])
37 |
38 | for img in img_filename_list:
39 | if not is_image_file(img):
40 | continue
41 | img1_name = os.path.splitext(img)[0]
42 | item = [os.path.join(dir_list[0], img), ]
43 | for i in range(1, len(dir_list)):
44 | img_name = '{:s}{:s}'.format(img1_name, post_fix[i - 1])
45 | img_path = os.path.join(dir_list[i], img_name)
46 | item.append(img_path)
47 |
48 | if len(item) == len(dir_list):
49 | img_list.append(tuple(item))
50 |
51 | return img_list
52 |
53 | # dataset that supports multiple images
54 | class DataFolder(data.Dataset):
55 | def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader):
56 | """
57 | :param dir_list: [img_dir, label_voronoi_dir, label_cluster_dir]
58 | :param post_fix: ['label_vor.png', 'label_cluster.png']
59 | :param num_channels: [3, 3, 3]
60 | :param data_transform: data transformations
61 | :param loader: image loader
62 | """
63 | super(DataFolder, self).__init__()
64 | # if len(dir_list) != len(post_fix) + 1:
65 | # raise (RuntimeError('Length of dir_list is different from length of post_fix + 1.'))
66 | if len(dir_list) != len(num_channels):
67 | raise (RuntimeError('Length of dir_list is different from length of num_channels.'))
68 |
69 | self.img_list = get_imgs_list(dir_list, post_fix)
70 | if len(self.img_list) == 0:
71 | raise(RuntimeError('Found 0 image pairs in given directories.'))
72 |
73 | self.data_transform = data_transform
74 | self.num_channels = num_channels
75 | self.loader = loader
76 |
77 | def __getitem__(self, index):
78 | img_paths = self.img_list[index]
79 | sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))]
80 |
81 | if self.data_transform is not None:
82 | sample = self.data_transform(sample)
83 |
84 | return sample
85 |
86 | def __len__(self):
87 | return len(self.img_list)
88 |
89 |
90 | # dataset that supports multiple images
91 | class DataFolderTest(data.Dataset):
92 | def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader):
93 | """
94 | :param dir_list: [img_dir, label_voronoi_dir, label_cluster_dir]
95 | :param post_fix: ['label_vor.png', 'label_cluster.png']
96 | :param num_channels: [3, 3, 3]
97 | :param data_transform: data transformations
98 | :param loader: image loader
99 | """
100 | super(DataFolderTest, self).__init__()
101 | if len(dir_list) != len(num_channels):
102 | raise (RuntimeError('Length of dir_list is different from length of num_channels.'))
103 |
104 | self.img_list = get_imgs_list(dir_list, post_fix)
105 | if len(self.img_list) == 0:
106 | raise(RuntimeError('Found 0 image pairs in given directories.'))
107 | self.data_transform = data_transform
108 | self.num_channels = num_channels
109 | self.loader = loader
110 |
111 |
112 | def __getitem__(self, index):
113 | img_paths = self.img_list[index]
114 | sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))]
115 |
116 | if self.data_transform is not None:
117 | sample = self.data_transform(sample)
118 |
119 | return sample
120 |
121 | def __len__(self):
122 | return len(self.img_list)
123 |
--------------------------------------------------------------------------------
/dataloaders/my_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from PIL import Image, ImageOps
4 | import numpy as np
5 | import numbers
6 |
7 | def get_transforms(p):
8 | """ data transforms for train, validation and test
9 | p: transform dictionary
10 | """
11 | t_list = list()
12 | if 'random_resize' in p:
13 | t_list.append(RandomResize(p['random_resize'][0], p['random_resize'][1]))
14 |
15 | if 'horizontal_flip' in p:
16 | t_list.append(RandomHorizontalFlip())
17 |
18 | if 'vertical_flip' in p:
19 | t_list.append(RandomVerticalFlip())
20 |
21 | if 'random_affine' in p:
22 | t_list.append(RandomAffine(p['random_affine']))
23 |
24 | if 'random_rotation' in p:
25 | t_list.append(RandomRotation(p['random_rotation']))
26 |
27 | if 'random_crop' in p:
28 | t_list.append(RandomCrop(p['random_crop']))
29 |
30 | if 'label_encoding' in p:
31 | t_list.append(LabelEncoding(p['label_encoding']))
32 |
33 | if 'to_tensor' in p:
34 | t_list.append(ToTensor(p['to_tensor']))
35 |
36 | if 'normalize' in p:
37 | t_list.append(Normalize(mean=p['normalize'][0], std=p['normalize'][1]))
38 |
39 | return Compose(t_list)
40 |
41 |
42 | class Compose(object):
43 | """ Composes several transforms together.
44 | Args:
45 | transforms (list of ``Transform`` objects): list of transforms to compose.
46 | """
47 |
48 | def __init__(self, transforms):
49 | self.transforms = transforms
50 |
51 | def __call__(self, imgs):
52 | for t in self.transforms:
53 | imgs = t(imgs)
54 | return imgs
55 |
56 |
57 | class ToTensor(object):
58 | """ Convert (img, label) of type ``PIL.Image`` or ``numpy.ndarray`` to tensors.
59 | Converts img of type PIL.Image or numpy.ndarray (H x W x C) in the range
60 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
61 | Converts label of type PIL.Image or numpy.ndarray (H x W) in the range [0, 255]
62 | to a torch.LongTensor of shape (H x W) in the range [0, 255].
63 | """
64 | def __init__(self, index=1):
65 | self.index = index # index to distinguish between images and labels
66 |
67 | def __call__(self, imgs):
68 | """
69 | Args:
70 | imgs (PIL.Image or numpy.ndarray): Image to be converted to tensor.
71 | Returns:
72 | Tensor: Converted image.
73 | """
74 | if len(imgs) < self.index:
75 | raise ValueError('The number of images is smaller than separation index!')
76 |
77 | pics = []
78 |
79 | # process image
80 | for i in range(0, self.index):
81 | img = imgs[i]
82 | if isinstance(img, np.ndarray):
83 | # handle numpy array
84 | pic = torch.from_numpy(img.transpose((2, 0, 1)))
85 | # backward compatibility
86 | pics.append(pic.float().div(255))
87 |                 continue  # numpy array already appended; skip the PIL branch below
88 | # handle PIL Image
89 | if img.mode == 'I':
90 | pic = torch.from_numpy(np.array(img, np.int32, copy=False))
91 | elif img.mode == 'I;16':
92 | pic = torch.from_numpy(np.array(img, np.int16, copy=False))
93 | else:
94 | pic = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes()))
95 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
96 | if img.mode == 'YCbCr':
97 | nchannel = 3
98 | elif img.mode == 'I;16':
99 | nchannel = 1
100 | else:
101 | nchannel = len(img.mode)
102 | pic = pic.view(img.size[1], img.size[0], nchannel)
103 | # put it from HWC to CHW format
104 | # yikes, this transpose takes 80% of the loading time/CPU
105 | pic = pic.transpose(0, 1).transpose(0, 2).contiguous()
106 | if isinstance(pic, torch.ByteTensor):
107 | pics.append(pic.float().div(255))
108 | else:
109 | pics.append(pic)
110 |
111 | # process labels:
112 | for i in range(self.index, len(imgs)):
113 | # process label
114 | label = imgs[i]
115 | if isinstance(label, np.ndarray):
116 | # handle numpy array
117 | label_tensor = torch.from_numpy(label)
118 | # backward compatibility
119 | pics.append(label_tensor.long())
120 |                 continue  # numpy array already appended; skip the PIL branch below
121 | # handle PIL Image
122 | if label.mode == 'I':
123 | label_tensor = torch.from_numpy(np.array(label, np.int32, copy=False))
124 | elif label.mode == 'I;16':
125 | label_tensor = torch.from_numpy(np.array(label, np.int16, copy=False))
126 | else:
127 | label_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(label.tobytes()))
128 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
129 | if label.mode == 'YCbCr':
130 | nchannel = 3
131 | elif label.mode == 'I;16':
132 | nchannel = 1
133 | else:
134 | nchannel = len(label.mode)
135 | label_tensor = label_tensor.view(label.size[1], label.size[0], nchannel)
136 | # put it from HWC to CHW format
137 | # yikes, this transpose takes 80% of the loading time/CPU
138 | label_tensor = label_tensor.transpose(0, 1).transpose(0, 2).contiguous()
139 | # label_tensor = label_tensor.view(label.size[1], label.size[0])
140 | pics.append(label_tensor.long())
141 |
142 | return tuple(pics)
143 |
144 |
145 | class Normalize(object):
146 | """ Normalize an tensor image with mean and standard deviation.
147 | Given mean and std, will normalize each channel of the torch.*Tensor,
148 | i.e. channel = (channel - mean) / std
149 | Args:
150 | mean (sequence): Sequence of means for each channel.
151 | std (sequence): Sequence of standard deviations for each channel.
152 | ** only normalize the first image, keep the target image unchanged
153 | """
154 |
155 | def __init__(self, mean, std):
156 | self.mean = mean
157 | self.std = std
158 |
159 | def __call__(self, tensors):
160 | """
161 | Args:
162 | tensors (Tensor): Tensor images of size (C, H, W) to be normalized.
163 | Returns:
164 | Tensor: Normalized image.
165 | """
166 | tensors = list(tensors)
167 | for t, m, s in zip(tensors[0], self.mean, self.std):
168 | t.sub_(m).div_(s)
169 | return tuple(tensors)
170 |
171 |
172 | class RandomCrop(object):
173 | """Crop the given PIL.Image at a random location.
174 | Args:
175 | size (sequence or int): Desired output size of the crop. If size is an
176 | int instead of sequence like (w, h), a square crop (size, size) is
177 | made.
178 | padding (int or sequence, optional): Optional padding on each border
179 | of the image. Default is 0, i.e no padding. If a sequence of length
180 | 4 is provided, it is used to pad left, top, right, bottom borders
181 | respectively.
182 | """
183 |
184 | def __init__(self, size, padding=0, fill_val=(0,)):
185 | if isinstance(size, numbers.Number):
186 | self.size = (int(size), int(size))
187 | else:
188 | self.size = size
189 | self.padding = padding
190 | self.fill_val = fill_val
191 |
192 | def __call__(self, imgs):
193 | """
194 | Args:
195 | img (PIL.Image): Image to be cropped.
196 | Returns:
197 | PIL.Image: Cropped image.
198 | """
199 | pics = []
200 |
201 | w, h = imgs[0].size
202 | th, tw = self.size
203 | x1 = random.randint(0, w - tw)
204 | y1 = random.randint(0, h - th)
205 |
206 | for k in range(len(imgs)):
207 | img = imgs[k]
208 | if self.padding > 0:
209 | img = ImageOps.expand(img, border=self.padding, fill=self.fill_val[k])
210 |
211 | if w == tw and h == th:
212 | pics.append(img)
213 | continue
214 |
215 | pics.append(img.crop((x1, y1, x1 + tw, y1 + th)))
216 |
217 | return tuple(pics)
218 |
219 |
220 | class RandomHorizontalFlip(object):
221 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""
222 |
223 | def __call__(self, imgs):
224 | """
225 | Args:
226 | img (PIL.Image): Image to be flipped.
227 | Returns:
228 | PIL.Image: Randomly flipped image.
229 | """
230 | pics = []
231 | if random.random() < 0.5:
232 | for img in imgs:
233 | pics.append(img.transpose(Image.FLIP_LEFT_RIGHT))
234 | return tuple(pics)
235 | else:
236 | return imgs
237 |
238 |
239 | class RandomVerticalFlip(object):
240 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""
241 |
242 | def __call__(self, imgs):
243 | """
244 | Args:
245 | img (PIL.Image): Image to be flipped.
246 | Returns:
247 | PIL.Image: Randomly flipped image.
248 | """
249 | pics = []
250 | if random.random() < 0.5:
251 | for img in imgs:
252 | pics.append(img.transpose(Image.FLIP_TOP_BOTTOM))
253 | return tuple(pics)
254 | else:
255 | return imgs
256 |
257 |
258 | class RandomRotation(object):
259 | """Rotate the image by angle.
260 | Args:
261 | degrees (sequence or float or int): Range of degrees to select from.
262 | If degrees is a number instead of sequence like (min, max), the range of degrees
263 | will be (-degrees, +degrees).
264 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
265 | An optional resampling filter.
266 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters
267 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
268 | expand (bool, optional): Optional expansion flag.
269 | If true, expands the output to make it large enough to hold the entire rotated image.
270 | If false or omitted, make the output image the same size as the input image.
271 | Note that the expand flag assumes rotation around the center and no translation.
272 | center (2-tuple, optional): Optional center of rotation.
273 | Origin is the upper left corner.
274 | Default is the center of the image.
275 | """
276 |
277 | def __init__(self, degrees, resample=Image.BILINEAR, expand=False, center=None):
278 | if isinstance(degrees, numbers.Number):
279 | if degrees < 0:
280 | raise ValueError("If degrees is a single number, it must be positive.")
281 | self.degrees = (-degrees, degrees)
282 | else:
283 | if len(degrees) != 2:
284 | raise ValueError("If degrees is a sequence, it must be of len 2.")
285 | self.degrees = degrees
286 |
287 | self.resample = resample
288 | self.expand = expand
289 | self.center = center
290 |
291 | @staticmethod
292 | def get_params(degrees):
293 | """Get parameters for ``rotate`` for a random rotation.
294 | Returns:
295 | sequence: params to be passed to ``rotate`` for random rotation.
296 | """
297 | angle = random.uniform(degrees[0], degrees[1])
298 |
299 | return angle
300 |
301 | def __call__(self, imgs):
302 | """
303 | imgs (PIL Image): Images to be rotated.
304 | Returns:
305 | PIL Image: Rotated image.
306 | """
307 |
308 | angle = self.get_params(self.degrees)
309 |
310 | pics = []
311 | for img in imgs:
312 | pics.append(img.rotate(angle, self.resample, self.expand, self.center))
313 |
314 | # process the binary label
315 | # pics[1] = pics[1].point(lambda p: p > 127.5 and 255)
316 |
317 | return tuple(pics)
318 |
319 |
320 | class RandomResize(object):
321 | """Randomly Resize the input PIL Image using a scale of lb~ub.
322 | Args:
323 | size (sequence or int): Desired output size. If size is a sequence like
324 | (h, w), output size will be matched to this. If size is an int,
325 | smaller edge of the image will be matched to this number.
326 | i.e, if height > width, then image will be rescaled to
327 | (size * height / width, size)
328 | interpolation (int, optional): Desired interpolation. Default is
329 | ``PIL.Image.BILINEAR``
330 | """
331 |
332 | def __init__(self, lb=0.5, ub=1.5, interpolation=Image.BILINEAR):
333 | self.lb = lb
334 | self.ub = ub
335 | self.interpolation = interpolation
336 |
337 | def __call__(self, imgs):
338 | """
339 | Args:
340 | imgs (PIL Images): Images to be scaled.
341 | Returns:
342 | PIL Images: Rescaled images.
343 | """
344 |
345 | for img in imgs:
346 | if not isinstance(img, Image.Image):
347 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
348 |
349 | scale = random.uniform(self.lb, self.ub)
350 | # print scale
351 |
352 | w, h = imgs[0].size
353 | ow = int(w * scale)
354 | oh = int(h * scale)
355 |
356 | if scale < 1:
357 | padding_l = (w - ow)//2
358 | padding_t = (h - oh)//2
359 | padding_r = w - ow - padding_l
360 | padding_b = h - oh - padding_t
361 | padding = (padding_l, padding_t, padding_r, padding_b)
362 |
363 | pics = []
364 | for i in range(len(imgs)):
365 | img = imgs[i]
366 | img = img.resize((ow, oh), self.interpolation)
367 | if scale < 1:
368 | img = ImageOps.expand(img, border=padding, fill=0)
369 | pics.append(img)
370 |
371 | return tuple(pics)
372 |
373 |
374 | class RandomAffine(object):
375 | """ Transform the input PIL Image using a random affine transformation
376 | The parameters of an affine transformation [a, b, c=0
377 | d, e, f=0]
378 | are generated randomly according to the bound, and there is no translation
379 | (c=f=0)
380 | Args:
381 | bound: the largest possible deviation of random parameters
382 | """
383 |
384 | def __init__(self, bound):
385 | if bound < 0 or bound > 0.5:
386 | raise ValueError("Bound is invalid, should be in range [0, 0.5]")
387 |
388 | self.bound = bound
389 |
390 | def __call__(self, imgs):
391 | img = imgs[0]
392 | x, y = img.size
393 |
394 | a = 1 + 2 * self.bound * (random.random() - 0.5)
395 | b = 2 * self.bound * (random.random() - 0.5)
396 | d = 2 * self.bound * (random.random() - 0.5)
397 | e = 1 + 2 * self.bound * (random.random() - 0.5)
398 |
399 | # correct the transformation center to image center
400 | c = -a * x / 2 - b * y / 2 + x / 2
401 | f = -d * x / 2 - e * y / 2 + y / 2
402 |
403 | trans_matrix = [a, b, c, d, e, f]
404 |
405 | pics = []
406 | for img in imgs:
407 | pics.append(img.transform((x, y), Image.AFFINE, trans_matrix))
408 |
409 | return tuple(pics)
410 |
411 |
412 | class LabelEncoding(object):
413 | """
414 | encode the 3-channel labels into one channel integer label map
415 | """
416 | def __init__(self, num_labels=1):
417 | self.num_labels = num_labels
418 |
419 | def __call__(self, imgs):
420 | assert self.num_labels < len(imgs)
421 |
422 | out_imgs = list(imgs[:-self.num_labels])
423 | image = imgs[0] # input image
424 | if not isinstance(image, np.ndarray):
425 | image = np.array(image)
426 |
427 | for i in range(-self.num_labels, 0): # labels
428 | label = imgs[i]
429 | if not isinstance(label, np.ndarray):
430 | label = np.array(label)
431 | # print(len(label.shape))
432 |
433 | # ----- for estimated pixel-level label: 0 = background, 1 = nuclei, 2 = ignored/uncertain
434 | new_label = np.ones((label.shape[0], label.shape[1]), dtype=np.uint8) * 2
435 | new_label[label[:, :, 0] > 255 * 0.3] = 0
436 | new_label[label[:, :, 1] > 255 * 0.5] = 1
437 | new_label[(image[:, :, 0] == 0) * (image[:, :, 1] == 0) * (image[:, :, 2] == 0)] = 0
438 | new_label = Image.fromarray(new_label.astype(np.uint8))
439 | out_imgs.append(new_label)
440 |
441 | return tuple(out_imgs)
442 |
--------------------------------------------------------------------------------
/dataloaders/prepare_data.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import os
3 | import shutil
4 | import numpy as np
5 | from skimage import morphology, measure
6 | from sklearn.cluster import KMeans
7 | from scipy.ndimage import distance_transform_edt as dist_tranform
8 | import glob
9 | import json
10 |
11 |
12 | def main():
13 | dataset = 'MO' # LC: Lung Cancer, MO: MultiOrgan
14 | data_dir = './data/{:s}'.format(dataset)
15 | img_dir = './data/{:s}/images'.format(dataset)
16 | label_instance_dir = './data/{:s}/labels_instance'.format(dataset)
17 | label_point_dir = './data/{:s}/labels_point'.format(dataset)
18 | label_vor_dir = './data/{:s}/labels_voronoi'.format(dataset)
19 | label_cluster_dir = './data/{:s}/labels_cluster'.format(dataset)
20 | label_Hcomponent_dir = './data/{:s}/labels_Hcomponent'.format(dataset)
21 | patch_folder = './data/{:s}/patches'.format(dataset)
22 | train_data_dir = './data_for_train/{:s}'.format(dataset)
23 | create_folder('./data_for_train')
24 | create_folder(label_point_dir)
25 | create_folder(label_vor_dir)
26 | create_folder(label_cluster_dir)
27 | create_folder(patch_folder)
28 | create_folder(train_data_dir)
29 |
30 | with open('{:s}/train_val_test.json'.format(data_dir), 'r') as file:
31 | data_list = json.load(file)
32 | train_list = data_list['train']
33 |
34 | # ------ create H component image from original image
35 | create_h_component_image(img_dir, label_Hcomponent_dir)
36 |
37 | # ------ create point label from instance label
38 | create_point_label_from_instance(label_instance_dir, label_point_dir, train_list)
39 |
40 | # ------ create Voronoi label from point label
41 | create_Voronoi_label(label_point_dir, label_vor_dir, train_list)
42 |
43 | # ------ create cluster label from point label and image
44 | create_cluster_label(img_dir, label_point_dir, label_vor_dir, label_cluster_dir, train_list)
45 |
46 | # ------ split large images into 250x250 patches
47 | print("Spliting large images into small patches...")
48 | split_patches(img_dir, '{:s}/images'.format(patch_folder))
49 | split_patches(label_vor_dir, '{:s}/labels_voronoi'.format(patch_folder), 'label_vor')
50 | split_patches(label_cluster_dir, '{:s}/labels_cluster'.format(patch_folder), 'label_cluster')
51 | split_patches(label_Hcomponent_dir, '{:s}/labels_Hcomponent'.format(patch_folder), 'label_Hcomponent')
52 |
53 | # ------ divide dataset into train, val and test sets
54 | organize_data_for_training(data_dir, train_data_dir)
55 |
56 | # ------ compute mean and std
57 | compute_mean_std(data_dir, train_data_dir)
58 |
59 |
60 | def create_h_component_image(img_dir, save_dir):
61 | image_list = os.listdir(img_dir)
62 | N_total = len(image_list)
63 | N_processed = 0
64 |
65 | # define stain_matrix
66 | H = np.array([0.650, 0.704, 0.286])
67 | E = np.array([0.072, 0.990, 0.105])
68 | R = np.array([0.268, 0.570, 0.776])
69 | HDABtoRGB = [(H/np.linalg.norm(H)).tolist(), (E/np.linalg.norm(E)).tolist(), (R/np.linalg.norm(R)).tolist()]
70 | stain_matrix = HDABtoRGB
71 | im_inv = np.linalg.inv(stain_matrix)
72 |
73 | for image_name in image_list:
74 | N_processed += 1
75 | flag = '' if N_processed < N_total else '\n'
76 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag)
77 |
78 | image_path = os.path.join(img_dir, image_name)
79 | image = imageio.imread(image_path)
80 |
81 | # transform
82 | im_temp = (-255)*np.log((np.float64(image)+1)/255)/np.log(255)
83 | image_out = np.reshape(np.dot(np.reshape(im_temp, [-1,3]), im_inv), np.shape(image))
84 | image_out = np.exp((255-image_out)*np.log(255)/255)
85 | image_out[image_out > 255] = 255
86 | image_h = image_out[:, :, 0].astype(np.uint8)
87 |
88 | imageio.imsave('{:s}/{:s}_h.png'.format(save_dir, image_name[:-4]), image_h)
89 |
90 |
91 | def create_point_label_from_instance(data_dir, save_dir, train_list):
92 | def get_point(img):
93 | a = np.where(img != 0)
94 | rmin, rmax, cmin, cmax = np.min(a[0]), np.max(a[0]), np.min(a[1]), np.max(a[1])
95 | return (rmin + rmax) // 2, (cmin + cmax) // 2
96 |
97 | print("Generating point label from instance label...")
98 | image_list = os.listdir(data_dir)
99 | N_total = len(train_list)
100 | N_processed = 0
101 | for image_name in image_list:
102 | name = image_name.split('.')[0]
103 | if '{:s}.png'.format(name[:-6]) not in train_list or name[-5:] != 'label':
104 | continue
105 |
106 | N_processed += 1
107 | flag = '' if N_processed < N_total else '\n'
108 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag)
109 |
110 | image_path = os.path.join(data_dir, image_name)
111 | image = imageio.imread(image_path)
112 | h, w = image.shape
113 |
114 | # extract bbox
115 | id_max = np.max(image)
116 | label_point = np.zeros((h, w), dtype=np.uint8)
117 |
118 | for i in range(1, id_max + 1):
119 | nucleus = image == i
120 | if np.sum(nucleus) == 0:
121 | continue
122 | x, y = get_point(nucleus)
123 | label_point[x, y] = 255
124 |
125 | imageio.imsave('{:s}/{:s}_point.png'.format(save_dir, name), label_point.astype(np.uint8))
126 |
127 |
128 | def create_Voronoi_label(data_dir, save_dir, train_list):
129 | from scipy.spatial import Voronoi
130 | from shapely.geometry import Polygon
131 | from utils.utils import voronoi_finite_polygons_2d, poly2mask
132 | img_list = os.listdir(data_dir)
133 |
134 | print("Generating Voronoi label from point label...")
135 | N_total = len(train_list)
136 | N_processed = 0
137 | for img_name in sorted(img_list):
138 | name = img_name.split('.')[0]
139 | if '{:s}.png'.format(name[:-12]) not in train_list:
140 | continue
141 |
142 | N_processed += 1
143 | flag = '' if N_processed < N_total else '\n'
144 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag)
145 |
146 | img_path = '{:s}/{:s}'.format(data_dir, img_name)
147 | label_point = imageio.imread(img_path)
148 | h, w = label_point.shape
149 |
150 | points = np.argwhere(label_point > 0)
151 | vor = Voronoi(points)
152 |
153 | regions, vertices = voronoi_finite_polygons_2d(vor)
154 | box = Polygon([[0, 0], [0, w], [h, w], [h, 0]])
155 | region_masks = np.zeros((h, w), dtype=np.int16)
156 | edges = np.zeros((h, w), dtype=bool)
157 | count = 1
158 | for region in regions:
159 | polygon = vertices[region]
160 | # Clipping polygon
161 | poly = Polygon(polygon)
162 | poly = poly.intersection(box)
163 | polygon = np.array([list(p) for p in poly.exterior.coords])
164 |
165 | mask = poly2mask(polygon[:, 0], polygon[:, 1], (h, w))
166 | edge = mask * (~morphology.erosion(mask, morphology.disk(1)))
167 | edges += edge
168 | region_masks[mask] = count
169 | count += 1
170 |
171 | # fuse Voronoi edge and dilated points
172 | label_point_dilated = morphology.dilation(label_point, morphology.disk(2))
173 | label_vor = np.zeros((h, w, 3), dtype=np.uint8)
174 | label_vor[:, :, 0] = (edges > 0).astype(np.uint8) * 255
175 | label_vor[:, :, 1] = (label_point_dilated > 0).astype(np.uint8) * 255
176 |
177 | name = img_name.split('.')[0]
178 | imageio.imsave('{:s}/{:s}_label_vor.png'.format(save_dir, name[:-12]), label_vor)
179 |
180 |
181 | def create_cluster_label(data_dir, label_point_dir, label_vor_dir, save_dir, train_list):
182 | from scipy import ndimage as ndi_morph
183 |
184 | img_list = os.listdir(data_dir)
185 | print("Generating cluster label from point label...")
186 | N_total = len(train_list)
187 | N_processed = 0
188 | for img_name in sorted(img_list):
189 | if img_name not in train_list:
190 | continue
191 |
192 | N_processed += 1
193 | flag = '' if N_processed < N_total else '\n'
194 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag)
195 |
196 | # print('\t[{:d}/{:d}] Processing image {:s} ...'.format(count, len(img_list), img_name))
197 | ori_image = imageio.imread('{:s}/{:s}'.format(data_dir, img_name))
198 | h, w, _ = ori_image.shape
199 | label_point = imageio.imread('{:s}/{:s}_label_point.png'.format(label_point_dir, img_name[:-4]))
200 |
201 | # k-means clustering
202 | dist_embeddings = dist_tranform(255 - label_point).reshape(-1, 1)
203 | clip_dist_embeddings = np.clip(dist_embeddings, a_min=0, a_max=20)
204 | color_embeddings = np.array(ori_image, dtype=float).reshape(-1, 3) / 10
205 | embeddings = np.concatenate((color_embeddings, clip_dist_embeddings), axis=1)
206 |
207 | # print("\t\tPerforming k-means clustering...")
208 | kmeans = KMeans(n_clusters=3, random_state=0).fit(embeddings)
209 | clusters = np.reshape(kmeans.labels_, (h, w))
210 |
211 | # get nuclei and background clusters
212 | overlap_nums = [np.sum((clusters == i) * label_point) for i in range(3)]
213 | nuclei_idx = np.argmax(overlap_nums)
214 | remain_indices = np.delete(np.arange(3), nuclei_idx)
215 | dilated_label_point = morphology.binary_dilation(label_point, morphology.disk(5))
216 | overlap_nums = [np.sum((clusters == i) * dilated_label_point) for i in remain_indices]
217 | background_idx = remain_indices[np.argmin(overlap_nums)]
218 |
219 | nuclei_cluster = clusters == nuclei_idx
220 | background_cluster = clusters == background_idx
221 |
222 | nuclei_labeled = measure.label(nuclei_cluster)
223 | initial_nuclei = morphology.remove_small_objects(nuclei_labeled, 30)
224 | refined_nuclei = np.zeros(initial_nuclei.shape, dtype=bool)
225 |
226 | label_vor = imageio.imread('{:s}/{:s}_label_vor.png'.format(label_vor_dir, img_name[:-4]))
227 | voronoi_cells = measure.label(label_vor[:, :, 0] == 0)
228 | voronoi_cells = morphology.dilation(voronoi_cells, morphology.disk(2))
229 |
230 | unique_vals = np.unique(voronoi_cells)
231 | cell_indices = unique_vals[unique_vals != 0]
232 | N = len(cell_indices)
233 | for i in range(N):
234 | cell_i = voronoi_cells == cell_indices[i]
235 | nucleus_i = cell_i * initial_nuclei
236 |
237 | nucleus_i_dilated = morphology.binary_dilation(nucleus_i, morphology.disk(5))
238 | nucleus_i_dilated_filled = ndi_morph.binary_fill_holes(nucleus_i_dilated)
239 | nucleus_i_final = morphology.binary_erosion(nucleus_i_dilated_filled, morphology.disk(7))
240 | refined_nuclei += nucleus_i_final > 0
241 |
242 | refined_label = np.zeros((h, w, 3), dtype=np.uint8)
243 | label_point_dilated = morphology.dilation(label_point, morphology.disk(10))
244 | refined_label[:, :, 0] = (background_cluster * (refined_nuclei == 0) * (label_point_dilated == 0)).astype(np.uint8) * 255
245 | refined_label[:, :, 1] = refined_nuclei.astype(np.uint8) * 255
246 |
247 | imageio.imsave('{:s}/{:s}_label_cluster.png'.format(save_dir, img_name[:-4]), refined_label)
248 |
249 |
250 | def split_patches(data_dir, save_dir, post_fix=None):
251 | """ split large images into small patches """
252 | import math
253 | create_folder(save_dir)
254 |
255 | image_list = os.listdir(data_dir)
256 | for image_name in image_list:
257 | name = image_name.split('.')[0]
258 | if post_fix and name[-len(post_fix):] != post_fix:
259 | continue
260 | image_path = os.path.join(data_dir, image_name)
261 | image = imageio.imread(image_path)
262 | seg_imgs = []
263 |
264 | # split into 16 patches of size 250x250
265 | h, w = image.shape[0], image.shape[1]
266 | patch_size = 250
267 | h_overlap = math.ceil((4 * patch_size - h) / 3)
268 | w_overlap = math.ceil((4 * patch_size - w) / 3)
269 | for i in range(0, h-patch_size+1, patch_size-h_overlap):
270 | for j in range(0, w-patch_size+1, patch_size-w_overlap):
271 | if len(image.shape) == 3:
272 | patch = image[i:i+patch_size, j:j+patch_size, :]
273 | else:
274 | patch = image[i:i + patch_size, j:j + patch_size]
275 | seg_imgs.append(patch)
276 |
277 | for k in range(len(seg_imgs)):
278 | if post_fix:
279 | imageio.imsave('{:s}/{:s}_{:d}_{:s}.png'.format(save_dir, name[:-len(post_fix)-1], k, post_fix), seg_imgs[k])
280 | else:
281 | imageio.imsave('{:s}/{:s}_{:d}.png'.format(save_dir, name, k), seg_imgs[k])
282 |
283 |
284 | def organize_data_for_training(data_dir, train_data_dir):
285 | # --- Step 1: create folders --- #
286 | create_folder(train_data_dir)
287 | create_folder('{:s}/images'.format(train_data_dir))
288 | create_folder('{:s}/labels_voronoi'.format(train_data_dir))
289 | create_folder('{:s}/labels_cluster'.format(train_data_dir))
290 | create_folder('{:s}/labels_Hcomponent'.format(train_data_dir))
291 | create_folder('{:s}/images/train'.format(train_data_dir))
292 | create_folder('{:s}/images/val'.format(train_data_dir))
293 | create_folder('{:s}/images/test'.format(train_data_dir))
294 | create_folder('{:s}/labels_voronoi/train'.format(train_data_dir))
295 | create_folder('{:s}/labels_cluster/train'.format(train_data_dir))
296 | create_folder('{:s}/labels_Hcomponent/train'.format(train_data_dir))
297 |
298 | # --- Step 2: move images and labels to each folder --- #
299 | print('Organizing data for training...')
300 | with open('{:s}/train_val_test.json'.format(data_dir), 'r') as file:
301 | data_list = json.load(file)
302 | train_list, val_list, test_list = data_list['train'], data_list['val'], data_list['test']
303 |
304 | # train
305 | for img_name in train_list:
306 | name = img_name.split('.')[0]
307 | # images
308 | for file in glob.glob('{:s}/patches/images/{:s}*'.format(data_dir, name)):
309 | file_name = os.path.basename(file)
310 | dst = '{:s}/images/train/{:s}'.format(train_data_dir, file_name)
311 | shutil.copyfile(file, dst)
312 | # label_voronoi
313 | for file in glob.glob('{:s}/patches/labels_voronoi/{:s}*'.format(data_dir, name)):
314 | file_name = os.path.basename(file)
315 | dst = '{:s}/labels_voronoi/train/{:s}'.format(train_data_dir, file_name)
316 | shutil.copyfile(file, dst)
317 | # label_cluster
318 | for file in glob.glob('{:s}/patches/labels_cluster/{:s}*'.format(data_dir, name)):
319 | file_name = os.path.basename(file)
320 | dst = '{:s}/labels_cluster/train/{:s}'.format(train_data_dir, file_name)
321 | shutil.copyfile(file, dst)
322 | # label_Hcomponent
323 | for file in glob.glob('{:s}/patches/labels_Hcomponent/{:s}*'.format(data_dir, name)):
324 | file_name = os.path.basename(file)
325 | dst = '{:s}/labels_Hcomponent/train/{:s}'.format(train_data_dir, file_name)
326 | shutil.copyfile(file, dst)
327 | # val
328 | for img_name in val_list:
329 | name = img_name.split('.')[0]
330 | # images
331 | for file in glob.glob('{:s}/images/{:s}*'.format(data_dir, name)):
332 | file_name = os.path.basename(file)
333 | dst = '{:s}/images/val/{:s}'.format(train_data_dir, file_name)
334 | shutil.copyfile(file, dst)
335 | # test
336 | for img_name in test_list:
337 | name = img_name.split('.')[0]
338 | # images
339 | for file in glob.glob('{:s}/images/{:s}*'.format(data_dir, name)):
340 | file_name = os.path.basename(file)
341 | dst = '{:s}/images/test/{:s}'.format(train_data_dir, file_name)
342 | shutil.copyfile(file, dst)
343 |
344 |
345 | def compute_mean_std(data_dir, train_data_dir):
346 | """ compute mean and standarad deviation of training images """
347 | total_sum = np.zeros(3) # total sum of all pixel values in each channel
348 | total_square_sum = np.zeros(3)
349 | num_pixel = 0 # total num of all pixels
350 |
351 | with open('{:s}/train_val_test.json'.format(data_dir), 'r') as file:
352 | data_list = json.load(file)
353 | train_list = data_list['train']
354 |
355 | print('Computing the mean and standard deviation of training data...')
356 |
357 | for file_name in train_list:
358 | img_name = '{:s}/images/{:s}'.format(data_dir, file_name)
359 | img = imageio.imread(img_name)
360 | if len(img.shape) != 3 or img.shape[2] < 3:
361 | continue
362 | img = img[:, :, :3].astype(int)
363 | total_sum += img.sum(axis=(0, 1))
364 | total_square_sum += (img ** 2).sum(axis=(0, 1))
365 | num_pixel += img.shape[0] * img.shape[1]
366 |
367 | # compute the mean values of each channel
368 | mean_values = total_sum / num_pixel
369 |
370 | # compute the standard deviation
371 | std_values = np.sqrt(total_square_sum / num_pixel - mean_values ** 2)
372 |
373 | # normalization
374 | mean_values = mean_values / 255
375 | std_values = std_values / 255
376 |
377 | np.save('{:s}/mean_std.npy'.format(train_data_dir), np.array([mean_values, std_values]))
378 | np.savetxt('{:s}/mean_std.txt'.format(train_data_dir), np.array([mean_values, std_values]), '%.4f', '\t')
379 |
380 |
381 | def create_folder(folder):
382 | if not os.path.exists(folder):
383 | os.mkdir(folder)
384 |
385 |
386 | if __name__ == '__main__':
387 | main()
388 |
--------------------------------------------------------------------------------
/figures/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hust-linyi/SC-Net/90f2a1cc62310c1799795845e239df74b9564bd4/figures/overview.png
--------------------------------------------------------------------------------
/model/modelW.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from model.resnet import resnet34
6 | from model.resnet1 import resnet34 as Resnet34
7 |
8 |
9 | class dilated_conv(nn.Module):
10 | """ same as original conv if dilation equals to 1 """
11 | def __init__(self, in_channel, out_channel, kernel_size=3, dropout_rate=0.0, activation=F.relu, dilation=1):
12 | super().__init__()
13 | self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, padding=dilation, dilation=dilation)
14 | self.norm = nn.BatchNorm2d(out_channel)
15 | self.activation = activation
16 | if dropout_rate > 0:
17 | self.drop = nn.Dropout2d(p=dropout_rate)
18 | else:
19 | self.drop = lambda x: x # no-op
20 |
21 | def forward(self, x):
22 | # CAB: conv -> activation -> batch norm
23 | x = self.norm(self.activation(self.conv(x)))
24 | x = self.drop(x)
25 | return x
26 |
27 |
28 | class ConvDownBlock(nn.Module):
29 | def __init__(self, in_channel, out_channel, dropout_rate=0.0, dilation=1):
30 | super().__init__()
31 | self.conv1 = dilated_conv(in_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation)
32 | self.conv2 = dilated_conv(out_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation)
33 | self.pool = nn.MaxPool2d(kernel_size=2)
34 |
35 | def forward(self, x):
36 | x = self.conv1(x)
37 | x = self.conv2(x)
38 | return self.pool(x), x
39 |
40 |
41 | class ConvUpBlock(nn.Module):
42 | def __init__(self, in_channel, out_channel, dropout_rate=0.0, dilation=1):
43 | super().__init__()
44 | self.up = nn.ConvTranspose2d(in_channel, in_channel // 2, 2, stride=2)
45 | self.conv1 = dilated_conv(in_channel // 2 + out_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation)
46 | self.conv2 = dilated_conv(out_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation)
47 |
48 | def forward(self, x, x_skip):
49 | x = self.up(x)
50 | H_diff = x.shape[2] - x_skip.shape[2]
51 | W_diff = x.shape[3] - x_skip.shape[3]
52 | x_skip = F.pad(x_skip, (0, W_diff, 0, H_diff), mode='reflect')
53 | x = torch.cat([x, x_skip], 1)
54 | x = self.conv1(x)
55 | x = self.conv2(x)
56 | return x
57 |
58 |
59 | # Transfer Learning ResNet as Encoder part of UNet
60 | class ResWNet34(nn.Module):
61 | def __init__(self, seg_classes = 2, colour_classes = 3, fixed_feature=False):
62 | super().__init__()
63 | # load weight of pre-trained resnet
64 | self.resnet = resnet34(pretrained=False)
65 | self.resnet1 = Resnet34(pretrained=False)
66 | if fixed_feature:
67 | for param in self.resnet.parameters():
68 | param.requires_grad = False
69 | # up conv
70 | l = [64, 64, 128, 256, 512]
71 | self.u5 = ConvUpBlock(l[4], l[3], dropout_rate=0.1)
72 | self.u6 = ConvUpBlock(l[3], l[2], dropout_rate=0.1)
73 | self.u7 = ConvUpBlock(l[2], l[1], dropout_rate=0.1)
74 | self.u8 = ConvUpBlock(l[1], l[0], dropout_rate=0.1)
75 | # final conv
76 | self.seg = nn.ConvTranspose2d(l[0], seg_classes, 2, stride=2)
77 | self.colour = nn.ConvTranspose2d(l[0], colour_classes, 2, stride=2)
78 | self.sigmoid = nn.Sigmoid()
79 | self.softmax = nn.Softmax(dim = 1)
80 | def forward(self, x):
81 | # refer https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
82 | x = self.resnet.conv1(x)
83 | x = self.resnet.bn1(x)
84 | x = s1 = self.resnet.relu(x)
85 | x = self.resnet.maxpool(x)
86 | x = s2 = self.resnet.layer1(x)
87 | x = s3 = self.resnet.layer2(x)
88 | x = s4 = self.resnet.layer3(x)
89 | x = self.resnet.layer4(x)
90 |
91 | x1 = self.u5(x, s4)
92 | x1 = self.u6(x1, s3)
93 | x1 = self.u7(x1, s2)
94 | x1 = self.u8(x1, s1)
95 | x1 = self.seg(x1)
96 |
97 | y_input = self.softmax(x1)
98 | y = self.resnet1.conv1(y_input)
99 | y = self.resnet1.bn1(y)
100 | y = c1 = self.resnet1.relu(y)
101 | y = self.resnet1.maxpool(y)
102 | y = c2 = self.resnet1.layer1(y)
103 | y = c3 = self.resnet1.layer2(y)
104 | y = c4 = self.resnet1.layer3(y)
105 | y = self.resnet1.layer4(y)
106 |
107 | y1 = self.u5(y, c4)
108 | y1 = self.u6(y1, c3)
109 | y1 = self.u7(y1, c2)
110 | y1 = self.u8(y1, c1)
111 | y1 = self.colour(y1)
112 | y1 = self.sigmoid(y1)
113 |
114 | return x1, y1
115 |
--------------------------------------------------------------------------------
/model/model_WNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class DoubleConv(nn.Module):
6 | """(convolution => [BN] => ReLU) * 2"""
7 |
8 | def __init__(self, in_channels, out_channels, mid_channels=None):
9 | super().__init__()
10 | if not mid_channels:
11 | mid_channels = out_channels
12 | self.double_conv = nn.Sequential(
13 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
14 | nn.BatchNorm2d(mid_channels),
15 | nn.ReLU(inplace=True),
16 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
17 | nn.BatchNorm2d(out_channels),
18 | nn.ReLU(inplace=True)
19 | )
20 |
21 | def forward(self, x):
22 | return self.double_conv(x)
23 |
24 |
25 | class Down(nn.Module):
26 | """Downscaling with maxpool then double conv"""
27 |
28 | def __init__(self, in_channels, out_channels):
29 | super().__init__()
30 | self.maxpool_conv = nn.Sequential(
31 | nn.MaxPool2d(2),
32 | DoubleConv(in_channels, out_channels)
33 | )
34 |
35 | def forward(self, x):
36 | return self.maxpool_conv(x)
37 |
38 |
39 | class Up(nn.Module):
40 | """Upscaling then double conv"""
41 |
42 | def __init__(self, in_channels, out_channels, bilinear=True):
43 | super().__init__()
44 |
45 | # if bilinear, use the normal convolutions to reduce the number of channels
46 | if bilinear:
47 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
48 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
49 | else:
50 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
51 | self.conv = DoubleConv(in_channels, out_channels)
52 |
53 |
54 | def forward(self, x1, x2):
55 | x1 = self.up(x1)
56 | # input is CHW
57 | diffY = x2.size()[2] - x1.size()[2]
58 | diffX = x2.size()[3] - x1.size()[3]
59 |
60 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
61 | diffY // 2, diffY - diffY // 2])
62 |
63 | x = torch.cat([x2, x1], dim=1)
64 | return self.conv(x)
65 |
66 | class OutConv(nn.Module):
67 | def __init__(self, in_channels, out_channels):
68 | super(OutConv, self).__init__()
69 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
70 | def forward(self, x):
71 | return self.conv(x)
72 |
73 | class OutConvy(nn.Module):
74 | def __init__(self, in_channels, out_channels):
75 | super(OutConvy, self).__init__()
76 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
77 | self.sigmoid = nn.Sigmoid()
78 | def forward(self, x):
79 | x = self.conv(x)
80 | x = self.sigmoid(x)
81 | return x
82 |
83 | class WNet(nn.Module):
84 | def __init__(self, n_channels, seg_classes, colour_classes, bilinear=True):
85 | super(WNet, self).__init__()
86 | self.n_channels = n_channels
87 | self.seg_classes = seg_classes
88 | self.colour_classes = colour_classes
89 | self.bilinear = bilinear
90 |
91 | self.inc = DoubleConv(n_channels, 64)
92 | self.down1 = Down(64, 128)
93 | self.down2 = Down(128, 256)
94 | self.down3 = Down(256, 512)
95 | factor = 2 if bilinear else 1
96 | self.down4 = Down(512, 1024 // factor)
97 | self.up1 = Up(1024, 512 // factor, bilinear)
98 | self.up2 = Up(512, 256 // factor, bilinear)
99 | self.up3 = Up(256, 128 // factor, bilinear)
100 | self.up4 = Up(128, 64, bilinear)
101 | self.outc = OutConv(64, seg_classes)
102 |
103 | self.inc1 = DoubleConv(seg_classes, 64)
104 | self.down11 = Down(64, 128)
105 | self.down21 = Down(128, 256)
106 | self.down31 = Down(256, 512)
107 | factor = 2 if bilinear else 1
108 | self.down41 = Down(512, 1024 // factor)
109 | self.up11 = Up(1024, 512 // factor, bilinear)
110 | self.up21 = Up(512, 256 // factor, bilinear)
111 | self.up31 = Up(256, 128 // factor, bilinear)
112 | self.up41 = Up(128, 64, bilinear)
113 | self.outcy = OutConvy(64, colour_classes)
114 | self.softmax = nn.Softmax(dim = 1)
115 |
116 | def forward(self, x):
117 | x1 = self.inc(x)
118 | x2 = self.down1(x1)
119 | x3 = self.down2(x2)
120 | x4 = self.down3(x3)
121 | x5 = self.down4(x4)
122 |
123 | x = self.up1(x5, x4)
124 | x = self.up2(x, x3)
125 | x = self.up3(x, x2)
126 | x = self.up4(x, x1)
127 | x = self.outc(x)
128 |
129 | y_input = self.softmax(x)
130 | y1 = self.inc1(y_input)
131 | y2 = self.down11(y1)
132 | y3 = self.down21(y2)
133 | y4 = self.down31(y3)
134 | y5 = self.down41(y4)
135 |
136 | y = self.up11(y5, y4)
137 | y = self.up21(y, y3)
138 | y = self.up31(y, y2)
139 | y = self.up41(y, y1)
140 | y = self.outcy(y)
141 |
142 | return x, y
143 |
--------------------------------------------------------------------------------
/model/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 | 'wide_resnet50_2', 'wide_resnet101_2']
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
23 | """3x3 convolution with padding"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=dilation, groups=groups, bias=False, dilation=dilation)
26 |
27 |
28 | def conv1x1(in_planes, out_planes, stride=1):
29 | """1x1 convolution"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
37 | base_width=64, dilation=1, norm_layer=None):
38 | super(BasicBlock, self).__init__()
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 | if groups != 1 or base_width != 64:
42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
43 | if dilation > 1:
44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = norm_layer(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.conv2 = conv3x3(planes, planes)
50 | self.bn2 = norm_layer(planes)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def forward(self, x):
55 | identity = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 |
64 | if self.downsample is not None:
65 | identity = self.downsample(x)
66 |
67 | out += identity
68 | out = self.relu(out)
69 |
70 | return out
71 |
72 |
73 | class Bottleneck(nn.Module):
74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
77 | # This variant is also known as ResNet V1.5 and improves accuracy according to
78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
79 |
80 | expansion = 4
81 |
82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
83 | base_width=64, dilation=1, norm_layer=None):
84 | super(Bottleneck, self).__init__()
85 | if norm_layer is None:
86 | norm_layer = nn.BatchNorm2d
87 | width = int(planes * (base_width / 64.)) * groups
88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
89 | self.conv1 = conv1x1(inplanes, width)
90 | self.bn1 = norm_layer(width)
91 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
92 | self.bn2 = norm_layer(width)
93 | self.conv3 = conv1x1(width, planes * self.expansion)
94 | self.bn3 = norm_layer(planes * self.expansion)
95 | self.relu = nn.ReLU(inplace=True)
96 | self.downsample = downsample
97 | self.stride = stride
98 |
99 | def forward(self, x):
100 | identity = x
101 |
102 | out = self.conv1(x)
103 | out = self.bn1(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv2(out)
107 | out = self.bn2(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv3(out)
111 | out = self.bn3(out)
112 |
113 | if self.downsample is not None:
114 | identity = self.downsample(x)
115 |
116 | out += identity
117 | out = self.relu(out)
118 |
119 | return out
120 |
121 |
122 | class ResNet(nn.Module):
123 |
124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
125 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
126 | norm_layer=None):
127 | super(ResNet, self).__init__()
128 | if norm_layer is None:
129 | norm_layer = nn.BatchNorm2d
130 | self._norm_layer = norm_layer
131 |
132 | self.inplanes = 64
133 | self.dilation = 1
134 | if replace_stride_with_dilation is None:
135 | # each element in the tuple indicates if we should replace
136 | # the 2x2 stride with a dilated convolution instead
137 | replace_stride_with_dilation = [False, False, False]
138 | if len(replace_stride_with_dilation) != 3:
139 | raise ValueError("replace_stride_with_dilation should be None "
140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
141 | self.groups = groups
142 | self.base_width = width_per_group
143 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
144 | bias=False)
145 | self.bn1 = norm_layer(self.inplanes)
146 | self.relu = nn.ReLU(inplace=True)
147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
148 | self.layer1 = self._make_layer(block, 64, layers[0])
149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
150 | dilate=replace_stride_with_dilation[0])
151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
152 | dilate=replace_stride_with_dilation[1])
153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
154 | dilate=replace_stride_with_dilation[2])
155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
156 | self.fc = nn.Linear(512 * block.expansion, num_classes)
157 |
158 | for m in self.modules():
159 | if isinstance(m, nn.Conv2d):
160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
162 | nn.init.constant_(m.weight, 1)
163 | nn.init.constant_(m.bias, 0)
164 |
165 | # Zero-initialize the last BN in each residual branch,
166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
168 | if zero_init_residual:
169 | for m in self.modules():
170 | if isinstance(m, Bottleneck):
171 | nn.init.constant_(m.bn3.weight, 0)
172 | elif isinstance(m, BasicBlock):
173 | nn.init.constant_(m.bn2.weight, 0)
174 |
175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
176 | norm_layer = self._norm_layer
177 | downsample = None
178 | previous_dilation = self.dilation
179 | if dilate:
180 | self.dilation *= stride
181 | stride = 1
182 | if stride != 1 or self.inplanes != planes * block.expansion:
183 | downsample = nn.Sequential(
184 | conv1x1(self.inplanes, planes * block.expansion, stride),
185 | norm_layer(planes * block.expansion),
186 | )
187 |
188 | layers = []
189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
190 | self.base_width, previous_dilation, norm_layer))
191 | self.inplanes = planes * block.expansion
192 | for _ in range(1, blocks):
193 | layers.append(block(self.inplanes, planes, groups=self.groups,
194 | base_width=self.base_width, dilation=self.dilation,
195 | norm_layer=norm_layer))
196 |
197 | return nn.Sequential(*layers)
198 |
199 | def _forward_impl(self, x):
200 | # See note [TorchScript super()]
201 | x = self.conv1(x)
202 | x = self.bn1(x)
203 | x = self.relu(x)
204 | x = self.maxpool(x)
205 |
206 | x = self.layer1(x)
207 | x = self.layer2(x)
208 | x = self.layer3(x)
209 | x = self.layer4(x)
210 |
211 | x = self.avgpool(x)
212 | x = torch.flatten(x, 1)
213 | x = self.fc(x)
214 |
215 | return x
216 |
217 | def forward(self, x):
218 | return self._forward_impl(x)
219 |
220 |
221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
222 | model = ResNet(block, layers, **kwargs)
223 | if pretrained:
224 | state_dict = load_state_dict_from_url(model_urls[arch],
225 | progress=progress)
226 | model.load_state_dict(state_dict)
227 | return model
228 |
229 |
230 | def resnet18(pretrained=False, progress=True, **kwargs):
231 | r"""ResNet-18 model from
232 | `"Deep Residual Learning for Image Recognition" `_
233 |
234 | Args:
235 | pretrained (bool): If True, returns a model pre-trained on ImageNet
236 | progress (bool): If True, displays a progress bar of the download to stderr
237 | """
238 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
239 | **kwargs)
240 |
241 |
242 | def resnet34(pretrained=False, progress=True, **kwargs):
243 | r"""ResNet-34 model from
244 | `"Deep Residual Learning for Image Recognition" `_
245 |
246 | Args:
247 | pretrained (bool): If True, returns a model pre-trained on ImageNet
248 | progress (bool): If True, displays a progress bar of the download to stderr
249 | """
250 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
251 | **kwargs)
252 |
253 |
254 | def resnet50(pretrained=False, progress=True, **kwargs):
255 | r"""ResNet-50 model from
256 | `"Deep Residual Learning for Image Recognition" `_
257 |
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
263 | **kwargs)
264 |
265 |
266 | def resnet101(pretrained=False, progress=True, **kwargs):
267 | r"""ResNet-101 model from
268 | `"Deep Residual Learning for Image Recognition" `_
269 |
270 | Args:
271 | pretrained (bool): If True, returns a model pre-trained on ImageNet
272 | progress (bool): If True, displays a progress bar of the download to stderr
273 | """
274 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
275 | **kwargs)
276 |
277 |
278 | def resnet152(pretrained=False, progress=True, **kwargs):
279 | r"""ResNet-152 model from
280 | `"Deep Residual Learning for Image Recognition" `_
281 |
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on ImageNet
284 | progress (bool): If True, displays a progress bar of the download to stderr
285 | """
286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
287 | **kwargs)
288 |
289 |
290 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
291 | r"""ResNeXt-50 32x4d model from
292 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
293 |
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | kwargs['groups'] = 32
299 | kwargs['width_per_group'] = 4
300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
301 | pretrained, progress, **kwargs)
302 |
303 |
304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
305 | r"""ResNeXt-101 32x8d model from
306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
307 |
308 | Args:
309 | pretrained (bool): If True, returns a model pre-trained on ImageNet
310 | progress (bool): If True, displays a progress bar of the download to stderr
311 | """
312 | kwargs['groups'] = 32
313 | kwargs['width_per_group'] = 8
314 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
315 | pretrained, progress, **kwargs)
316 |
317 |
318 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
319 | r"""Wide ResNet-50-2 model from
320 | `"Wide Residual Networks" `_
321 |
322 | The model is the same as ResNet except for the bottleneck number of channels
323 | which is twice larger in every block. The number of channels in outer 1x1
324 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
325 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
326 |
327 | Args:
328 | pretrained (bool): If True, returns a model pre-trained on ImageNet
329 | progress (bool): If True, displays a progress bar of the download to stderr
330 | """
331 | kwargs['width_per_group'] = 64 * 2
332 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
333 | pretrained, progress, **kwargs)
334 |
335 |
336 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
337 | r"""Wide ResNet-101-2 model from
338 | `"Wide Residual Networks" `_
339 |
340 | The model is the same as ResNet except for the bottleneck number of channels
341 | which is twice larger in every block. The number of channels in outer 1x1
342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
344 |
345 | Args:
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 | """
349 | kwargs['width_per_group'] = 64 * 2
350 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
351 | pretrained, progress, **kwargs)
352 |
--------------------------------------------------------------------------------
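Note that this vendored copy of the torchvision ResNet changes `conv1` to take a single input channel, so the ImageNet checkpoints referenced by `model_urls` no longer match `conv1.weight`; calling e.g. `resnet34(pretrained=True)` raises a size-mismatch error in `load_state_dict`. A minimal sketch of how one could still reuse the weights that do match; the `load_matching_weights` helper is illustrative and not part of this repository:

```python
import torch
from torch.hub import load_state_dict_from_url
from model.resnet import resnet34

def load_matching_weights(model, state_dict):
    """Copy only the tensors whose names and shapes match the target model."""
    own = model.state_dict()
    filtered = {k: v for k, v in state_dict.items()
                if k in own and own[k].shape == v.shape}
    own.update(filtered)
    model.load_state_dict(own)
    return model

model = resnet34(pretrained=False)  # 1-channel conv1, randomly initialized
imagenet = load_state_dict_from_url(
    'https://download.pytorch.org/models/resnet34-333f7ec4.pth', progress=True)
load_matching_weights(model, imagenet)  # conv1 (and fc, if num_classes differs) stay random
```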
/model/resnet1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.hub import load_state_dict_from_url
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
8 | 'wide_resnet50_2', 'wide_resnet101_2']
9 | model_urls = {
10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
23 | """3x3 convolution with padding"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=dilation, groups=groups, bias=False, dilation=dilation)
26 |
27 |
28 | def conv1x1(in_planes, out_planes, stride=1):
29 | """1x1 convolution"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
31 |
32 |
33 | class BasicBlock(nn.Module):
34 | expansion = 1
35 |
36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
37 | base_width=64, dilation=1, norm_layer=None):
38 | super(BasicBlock, self).__init__()
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 | if groups != 1 or base_width != 64:
42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
43 | if dilation > 1:
44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = norm_layer(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.conv2 = conv3x3(planes, planes)
50 | self.bn2 = norm_layer(planes)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | def forward(self, x):
55 | identity = x
56 |
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 |
64 | if self.downsample is not None:
65 | identity = self.downsample(x)
66 |
67 | out += identity
68 | out = self.relu(out)
69 |
70 | return out
71 |
72 |
73 | class Bottleneck(nn.Module):
74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
77 | # This variant is also known as ResNet V1.5 and improves accuracy according to
78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
79 |
80 | expansion = 4
81 |
82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
83 | base_width=64, dilation=1, norm_layer=None):
84 | super(Bottleneck, self).__init__()
85 | if norm_layer is None:
86 | norm_layer = nn.BatchNorm2d
87 | width = int(planes * (base_width / 64.)) * groups
88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
89 | self.conv1 = conv1x1(inplanes, width)
90 | self.bn1 = norm_layer(width)
91 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
92 | self.bn2 = norm_layer(width)
93 | self.conv3 = conv1x1(width, planes * self.expansion)
94 | self.bn3 = norm_layer(planes * self.expansion)
95 | self.relu = nn.ReLU(inplace=True)
96 | self.downsample = downsample
97 | self.stride = stride
98 |
99 | def forward(self, x):
100 | identity = x
101 |
102 | out = self.conv1(x)
103 | out = self.bn1(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv2(out)
107 | out = self.bn2(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv3(out)
111 | out = self.bn3(out)
112 |
113 | if self.downsample is not None:
114 | identity = self.downsample(x)
115 |
116 | out += identity
117 | out = self.relu(out)
118 |
119 | return out
120 |
121 |
122 | class ResNet(nn.Module):
123 |
124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
125 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
126 | norm_layer=None):
127 | super(ResNet, self).__init__()
128 | if norm_layer is None:
129 | norm_layer = nn.BatchNorm2d
130 | self._norm_layer = norm_layer
131 |
132 | self.inplanes = 64
133 | self.dilation = 1
134 | if replace_stride_with_dilation is None:
135 | # each element in the tuple indicates if we should replace
136 | # the 2x2 stride with a dilated convolution instead
137 | replace_stride_with_dilation = [False, False, False]
138 | if len(replace_stride_with_dilation) != 3:
139 | raise ValueError("replace_stride_with_dilation should be None "
140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
141 | self.groups = groups
142 | self.base_width = width_per_group
143 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3,
144 | bias=False)
145 | self.bn1 = norm_layer(self.inplanes)
146 | self.relu = nn.ReLU(inplace=True)
147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
148 | self.layer1 = self._make_layer(block, 64, layers[0])
149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
150 | dilate=replace_stride_with_dilation[0])
151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
152 | dilate=replace_stride_with_dilation[1])
153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
154 | dilate=replace_stride_with_dilation[2])
155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
156 | self.fc = nn.Linear(512 * block.expansion, num_classes)
157 |
158 | for m in self.modules():
159 | if isinstance(m, nn.Conv2d):
160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
162 | nn.init.constant_(m.weight, 1)
163 | nn.init.constant_(m.bias, 0)
164 |
165 | # Zero-initialize the last BN in each residual branch,
166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
168 | if zero_init_residual:
169 | for m in self.modules():
170 | if isinstance(m, Bottleneck):
171 | nn.init.constant_(m.bn3.weight, 0)
172 | elif isinstance(m, BasicBlock):
173 | nn.init.constant_(m.bn2.weight, 0)
174 |
175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
176 | norm_layer = self._norm_layer
177 | downsample = None
178 | previous_dilation = self.dilation
179 | if dilate:
180 | self.dilation *= stride
181 | stride = 1
182 | if stride != 1 or self.inplanes != planes * block.expansion:
183 | downsample = nn.Sequential(
184 | conv1x1(self.inplanes, planes * block.expansion, stride),
185 | norm_layer(planes * block.expansion),
186 | )
187 |
188 | layers = []
189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
190 | self.base_width, previous_dilation, norm_layer))
191 | self.inplanes = planes * block.expansion
192 | for _ in range(1, blocks):
193 | layers.append(block(self.inplanes, planes, groups=self.groups,
194 | base_width=self.base_width, dilation=self.dilation,
195 | norm_layer=norm_layer))
196 |
197 | return nn.Sequential(*layers)
198 |
199 | def _forward_impl(self, x):
200 | # See note [TorchScript super()]
201 | x = self.conv1(x)
202 | x = self.bn1(x)
203 | x = self.relu(x)
204 | x = self.maxpool(x)
205 |
206 | x = self.layer1(x)
207 | x = self.layer2(x)
208 | x = self.layer3(x)
209 | x = self.layer4(x)
210 |
211 | x = self.avgpool(x)
212 | x = torch.flatten(x, 1)
213 | x = self.fc(x)
214 |
215 | return x
216 |
217 | def forward(self, x):
218 | return self._forward_impl(x)
219 |
220 |
221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
222 | model = ResNet(block, layers, **kwargs)
223 | if pretrained:
224 | state_dict = load_state_dict_from_url(model_urls[arch],
225 | progress=progress)
226 | model.load_state_dict(state_dict)
227 | return model
228 |
229 |
230 | def resnet18(pretrained=False, progress=True, **kwargs):
231 | r"""ResNet-18 model from
232 | `"Deep Residual Learning for Image Recognition" `_
233 |
234 | Args:
235 | pretrained (bool): If True, returns a model pre-trained on ImageNet
236 | progress (bool): If True, displays a progress bar of the download to stderr
237 | """
238 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
239 | **kwargs)
240 |
241 |
242 | def resnet34(pretrained=False, progress=True, **kwargs):
243 | r"""ResNet-34 model from
244 | `"Deep Residual Learning for Image Recognition" `_
245 |
246 | Args:
247 | pretrained (bool): If True, returns a model pre-trained on ImageNet
248 | progress (bool): If True, displays a progress bar of the download to stderr
249 | """
250 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
251 | **kwargs)
252 |
253 |
254 | def resnet50(pretrained=False, progress=True, **kwargs):
255 | r"""ResNet-50 model from
256 | `"Deep Residual Learning for Image Recognition" `_
257 |
258 | Args:
259 | pretrained (bool): If True, returns a model pre-trained on ImageNet
260 | progress (bool): If True, displays a progress bar of the download to stderr
261 | """
262 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
263 | **kwargs)
264 |
265 |
266 | def resnet101(pretrained=False, progress=True, **kwargs):
267 | r"""ResNet-101 model from
268 | `"Deep Residual Learning for Image Recognition" `_
269 |
270 | Args:
271 | pretrained (bool): If True, returns a model pre-trained on ImageNet
272 | progress (bool): If True, displays a progress bar of the download to stderr
273 | """
274 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
275 | **kwargs)
276 |
277 |
278 | def resnet152(pretrained=False, progress=True, **kwargs):
279 | r"""ResNet-152 model from
280 | `"Deep Residual Learning for Image Recognition" `_
281 |
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on ImageNet
284 | progress (bool): If True, displays a progress bar of the download to stderr
285 | """
286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
287 | **kwargs)
288 |
289 |
290 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
291 | r"""ResNeXt-50 32x4d model from
292 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
293 |
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | kwargs['groups'] = 32
299 | kwargs['width_per_group'] = 4
300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
301 | pretrained, progress, **kwargs)
302 |
303 |
304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
305 | r"""ResNeXt-101 32x8d model from
306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
307 |
308 | Args:
309 | pretrained (bool): If True, returns a model pre-trained on ImageNet
310 | progress (bool): If True, displays a progress bar of the download to stderr
311 | """
312 | kwargs['groups'] = 32
313 | kwargs['width_per_group'] = 8
314 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
315 | pretrained, progress, **kwargs)
316 |
317 |
318 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
319 | r"""Wide ResNet-50-2 model from
320 | `"Wide Residual Networks" `_
321 |
322 | The model is the same as ResNet except for the bottleneck number of channels
323 | which is twice larger in every block. The number of channels in outer 1x1
324 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
325 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
326 |
327 | Args:
328 | pretrained (bool): If True, returns a model pre-trained on ImageNet
329 | progress (bool): If True, displays a progress bar of the download to stderr
330 | """
331 | kwargs['width_per_group'] = 64 * 2
332 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
333 | pretrained, progress, **kwargs)
334 |
335 |
336 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
337 | r"""Wide ResNet-101-2 model from
338 | `"Wide Residual Networks" `_
339 |
340 | The model is the same as ResNet except for the bottleneck number of channels
341 | which is twice larger in every block. The number of channels in outer 1x1
342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
344 |
345 | Args:
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 | """
349 | kwargs['width_per_group'] = 64 * 2
350 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
351 | pretrained, progress, **kwargs)
352 |
--------------------------------------------------------------------------------
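As far as this listing shows, `resnet1.py` differs from `resnet.py` only in that its `conv1` expects two input channels instead of one. A quick shape sanity-check sketch (purely illustrative, default `num_classes=1000`):

```python
import torch
from model import resnet, resnet1

m1 = resnet.resnet34()    # conv1 expects a 1-channel input
m2 = resnet1.resnet34()   # conv1 expects a 2-channel input

print(m1(torch.randn(2, 1, 224, 224)).shape)   # torch.Size([2, 1000])
print(m2(torch.randn(2, 2, 224, 224)).shape)   # torch.Size([2, 1000])
```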
/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 |
5 |
6 | class Options:
7 | def __init__(self, isTrain):
8 | self.dataset = 'MO' # dataset: LC: Lung Cancer, MO: MultiOrgan
9 | self.isTrain = isTrain # train or test mode
10 | self.rootDir = '' # rootdir of the data and experiment
11 |
12 | # --- model hyper-parameters --- #
13 | self.model = dict()
14 | self.model['name'] = 'ResUNet34'
15 | self.model['pretrained'] = False
16 | self.model['fix_params'] = False
17 | self.model['in_c'] = 1 # input channel
18 |
19 | # --- training params --- #
20 | self.train = dict()
21 | self.train['data_dir'] = '{:s}/Data/{:s}'.format(self.rootDir, self.dataset) # path to data
22 | self.train['save_dir'] = '{:s}/Exp/{:s}'.format(self.rootDir, self.dataset) # path to save results
23 | self.train['input_size'] = 224 # input size of the image
24 | self.train['train_epochs'] = 100 # number of training epochs
25 | self.train['batch_size'] = 8 # batch size
26 | self.train['checkpoint_freq'] = 20 # epoch to save checkpoints
27 | self.train['lr'] = 1e-3 # initial learning rate
28 | self.train['weight_decay'] = 5e-4 # weight decay
29 | self.train['log_interval'] = 37 # iterations to print training results
30 | self.train['ema_interval'] = 5 # epochs between EMA prediction-map updates
31 | self.train['workers'] = 16 # number of workers to load images
32 | self.train['gpus'] = [0, ] # select gpu devices
33 | # --- resume training --- #
34 | self.train['start_epoch'] = 0 # start epoch
35 | self.train['checkpoint'] = ''
36 |
37 | # --- data transform --- #
38 | self.transform = dict()
39 |
40 | # --- test parameters --- #
41 | self.test = dict()
42 | self.test['test_epoch'] = 80
43 | self.test['gpus'] = [0, ]
44 | self.test['img_dir'] = '{:s}/Data/{:s}/images/test'.format(self.rootDir, self.dataset)
45 | self.test['imgh_dir'] = '{:s}/Data/{:s}/images/testh'.format(self.rootDir, self.dataset)
46 | self.test['label_dir'] = '{:s}/Data/{:s}/labels_instance'.format(self.rootDir, self.dataset)
47 | self.test['save_flag'] = True
48 | self.test['patch_size'] = 224
49 | self.test['overlap'] = 80
50 | self.test['save_dir'] = '{:s}/Exp/{:s}/test_results'.format(self.rootDir, self.dataset)
51 | self.test['checkpoint_dir'] = '{:s}/Exp/{:s}/checkpoints/'.format(self.rootDir, self.dataset)
52 | self.test['model_path1'] = '{:s}/checkpoint1_{:d}.pth.tar'.format(self.test['checkpoint_dir'], self.test['test_epoch'])
53 | self.test['model_path2'] = '{:s}/checkpoint2_{:d}.pth.tar'.format(self.test['checkpoint_dir'], self.test['test_epoch'])
54 | # --- post processing --- #
55 | self.post = dict()
56 | self.post['min_area'] = 20 # minimum area for an object
57 |
58 | def parse(self):
59 | """ Parse the options, replace the default value if there is a new input """
60 | parser = argparse.ArgumentParser(description='')
61 | if self.isTrain:
62 | parser.add_argument('--batch-size', type=int, default=self.train['batch_size'], help='input batch size for training')
63 | parser.add_argument('--epochs', type=int, default=self.train['train_epochs'], help='number of epochs to train')
64 | parser.add_argument('--lr', type=float, default=self.train['lr'], help='learning rate')
65 | parser.add_argument('--log-interval', type=int, default=self.train['log_interval'], help='how many batches to wait before logging training status')
66 | parser.add_argument('--gpus', type=int, nargs='+', default=self.train['gpus'], help='GPUs for training')
67 | parser.add_argument('--data-dir', type=str, default=self.train['data_dir'], help='directory of training data')
68 | parser.add_argument('--save-dir', type=str, default=self.train['save_dir'], help='directory to save training results')
69 | args = parser.parse_args()
70 |
71 | self.train['batch_size'] = args.batch_size
72 | self.train['train_epochs'] = args.epochs
73 | self.train['lr'] = args.lr
74 | self.train['log_interval'] = args.log_interval
75 | self.train['gpus'] = args.gpus
76 | self.train['data_dir'] = args.data_dir
77 | self.train['img_dir'] = '{:s}/images'.format(self.train['data_dir'])
78 | self.train['label_vor_dir'] = '{:s}/labels_voronoi'.format(self.train['data_dir'])
79 | self.train['label_cluster_dir'] = '{:s}/labels_cluster'.format(self.train['data_dir'])
80 |
81 |
82 | self.train['save_dir'] = args.save_dir
83 | if not os.path.exists(self.train['save_dir']):
84 | os.makedirs(self.train['save_dir'], exist_ok=True)
85 |
86 | # define data transforms for training
87 | self.transform['train'] = {
88 | 'random_resize': [0.8, 1.25],
89 | 'horizontal_flip': True,
90 | 'vertical_flip': True,
91 | 'random_affine': 0.3,
92 | 'random_rotation': 90,
93 | 'random_crop': self.train['input_size'],
94 | 'label_encoding': 2,
95 | 'to_tensor': 3
96 | }
97 |
98 | self.transform['val'] = {
99 | 'random_crop': self.train['input_size'],
100 | 'label_encoding': 2,
101 | 'to_tensor': 2
102 | }
103 |
104 | self.transform['test'] = {
105 | 'to_tensor': 1
106 | }
107 |
108 | self.transform['val_lin'] = {
109 | 'to_tensor': 3
110 | }
111 |
112 | else:
113 | parser.add_argument('--save-flag', type=lambda s: s.lower() in ('1', 'true', 'yes'), default=self.test['save_flag'], help='flag to save the network outputs and predictions')  # note: argparse's type=bool would treat any non-empty string as True
114 | parser.add_argument('--img-dir', type=str, default=self.test['img_dir'], help='directory of test images')
115 | parser.add_argument('--imgh-dir', type=str, default=self.test['imgh_dir'], help='directory of test images')
116 | parser.add_argument('--label-dir', type=str, default=self.test['label_dir'], help='directory of labels')
117 | parser.add_argument('--save-dir', type=str, default=self.test['save_dir'], help='directory to save test results')
118 | parser.add_argument('--gpus', type=int, nargs='+', default=self.train['gpus'], help='GPUs for training')
119 | parser.add_argument('--model-path1', type=str, default=self.test['model_path1'], help='train model to be evaluated')
120 | parser.add_argument('--model-path2', type=str, default=self.test['model_path2'],help='train model to be evaluated')
121 | args = parser.parse_args()
122 | self.test['gpus'] = args.gpus
123 | self.test['save_flag'] = args.save_flag
124 | self.test['img_dir'] = args.img_dir
125 | self.test['imgh_dir'] = args.imgh_dir
126 | self.test['label_dir'] = args.label_dir
127 | self.test['save_dir'] = args.save_dir
128 | self.test['model_path1'] = args.model_path1
129 | self.test['model_path2'] = args.model_path2
130 |
131 | if not os.path.exists(self.test['save_dir']):
132 | os.makedirs(self.test['save_dir'], exist_ok=True)
133 |
134 | self.transform['test'] = {
135 | 'to_tensor': 1
136 | }
137 |
138 | def save_options(self):
139 | if self.isTrain:
140 | filename = '{:s}/train_options.txt'.format(self.train['save_dir'])
141 | else:
142 | filename = '{:s}/test_options.txt'.format(self.test['save_dir'])
143 | file = open(filename, 'w')
144 | groups = ['model', 'train', 'transform'] if self.isTrain else ['model', 'test', 'post', 'transform']
145 |
146 | file.write("# ---------- Options ---------- #")
147 | file.write('\ndataset: {:s}\n'.format(self.dataset))
148 | file.write('isTrain: {}\n'.format(self.isTrain))
149 | for group, options in self.__dict__.items():
150 | if group not in groups:
151 | continue
152 | file.write('\n\n-------- {:s} --------\n'.format(group))
153 | if group == 'transform':
154 | for name, val in options.items():
155 | if (self.isTrain and name != 'test') or (not self.isTrain and name == 'test'):
156 | file.write("{:s}:\n".format(name))
157 | for t_name, t_val in val.items():
158 | file.write("\t{:s}: {:s}\n".format(t_name, repr(t_val)))
159 | else:
160 | for name, val in options.items():
161 | file.write("{:s} = {:s}\n".format(name, repr(val)))
162 | file.close()
163 |
164 |
165 |
166 |
167 |
--------------------------------------------------------------------------------
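Both entry points consume these options the same way (`test.py` does exactly this with `isTrain=False`): defaults come from `Options.__init__` and the command-line flags registered in `parse()` override them. A small, self-contained sketch of the training-side pattern; the argument values below are placeholders, not recommended settings:

```python
import sys
from options import Options

# Simulate a command line; in practice these flags are passed to train.py directly.
sys.argv = ['train.py', '--batch-size', '4', '--lr', '5e-4',
            '--data-dir', './Data/MO', '--save-dir', './Exp/MO']

opt = Options(isTrain=True)   # defaults come from Options.__init__
opt.parse()                   # CLI flags override them; also builds opt.transform
opt.save_options()            # writes train_options.txt under opt.train['save_dir']
print(opt.train['batch_size'], opt.train['lr'])   # 4 0.0005
```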
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | astor==0.8.1
3 | backcall==0.2.0
4 | cached-property==1.5.2
5 | cachetools==4.2.4
6 | certifi==2020.6.20
7 | chardet==5.0.0
8 | click==8.0.4
9 | commonmark==0.9.1
10 | configparser==5.2.0
11 | contextlib2==21.6.0
12 | cycler==0.11.0
13 | dataclasses==0.8
14 | decorator==4.4.2
15 | gast==0.5.3
16 | google-pasta==0.2.0
17 | grad-cam==1.4.8
18 | grpcio==1.43.0
19 | h5py==3.1.0
20 | imageio==2.13.5
21 | importlib-metadata==4.8.3
22 | importlib-resources==5.4.0
23 | ipdb==0.13.13
24 | ipython==7.16.3
25 | ipython-genutils==0.2.0
26 | jedi==0.17.2
27 | joblib==1.1.1
28 | Keras==2.3.1
29 | Keras-Applications==1.0.8
30 | Keras-Preprocessing==1.1.2
31 | kiwisolver==1.3.1
32 | lxml==4.9.3
33 | Markdown==3.3.6
34 | matplotlib==3.3.4
35 | networkx==2.5.1
36 | nltk==3.6.7
37 | numpy==1.19.5
38 | nvidia-ml-py==11.525.131
39 | nvitop==1.0.0
40 | opencv-python==4.5.5.62
41 | packaging==21.3
42 | pandas==1.1.5
43 | parso==0.7.1
44 | pexpect==4.8.0
45 | pickleshare==0.7.5
46 | Pillow==10.0.1
47 | plotly==5.15.0
48 | prompt-toolkit==3.0.36
49 | protobuf==3.19.3
50 | psutil==5.9.5
51 | ptyprocess==0.7.0
52 | Pygments==2.14.0
53 | pyparsing==3.0.6
54 | python-dateutil==2.8.2
55 | pytz==2023.3
56 | PyWavelets==1.1.1
57 | PyYAML==6.0
58 | regex==2023.8.8
59 | rich==12.6.0
60 | scikit-image==0.17.2
61 | scikit-learn==0.24.2
62 | scipy==1.5.4
63 | Shapely==1.8.5.post1
64 | SimpleITK==2.0.2
65 | six==1.16.0
66 | summary==0.2.0
67 | tenacity==8.2.2
68 | tensorboard==1.14.0
69 | tensorboardX==2.6.2
70 | tensorflow-estimator==1.14.0
71 | tensorflow-gpu==1.14.0
72 | termcolor==1.1.0
73 | threadpoolctl==3.1.0
74 | tifffile==2020.9.3
75 | tomli==1.2.3
76 | torch==1.9.1+cu111
77 | torchaudio==0.9.1
78 | torchcam==0.3.2
79 | torchsummary==1.5.1
80 | torchvision==0.10.1+cu111
81 | tqdm==4.64.1
82 | traitlets==4.3.3
83 | ttach==0.0.3
84 | typing_extensions==4.0.1
85 | wcwidth==0.2.6
86 | Werkzeug==2.0.2
87 | wrapt==1.13.3
88 | xgboost==1.5.2
89 | zipp==3.6.0
90 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.backends.cudnn as cudnn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from PIL import Image
7 | import skimage.morphology as morph
8 | import scipy.ndimage.morphology as ndi_morph
9 | from skimage import measure
10 | from scipy import misc
11 | from model.modelW import ResWNet34
12 | from model.model_WNet import WNet
13 | import utils.utils as utils
14 | from utils.accuracy import compute_metrics
15 | import time
16 | import imageio
17 | from options import Options
18 | from dataloaders.my_transforms import get_transforms
19 | from tqdm import tqdm
20 | from rich.table import Column, Table
21 | from rich import print
22 | import time
23 |
24 |
25 | def main():
26 | opt = Options(isTrain=False)
27 | opt.parse()
28 | opt.save_options()
29 |
30 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in opt.test['gpus'])
31 |
32 | img_dir = opt.test['img_dir']
33 | imgh_dir = opt.test['imgh_dir']
34 | label_dir = opt.test['label_dir']
35 | save_dir = opt.test['save_dir']
36 | model_path1 = opt.test['model_path1']
37 | model_path2 = opt.test['model_path2']
38 | save_flag = opt.test['save_flag']
39 |
40 | # data transforms
41 | test_transform = get_transforms(opt.transform['test'])
42 |
43 | model1 = ResWNet34(seg_classes = 2, colour_classes = 3)
44 | model1 = torch.nn.DataParallel(model1)
45 | model1 = model1.cuda()
46 | model2 = WNet(n_channels = 1, seg_classes = 2, colour_classes = 3)
47 | # model2 = ResWNet34(seg_classes = 2, colour_classes = 3)
48 | model2 = torch.nn.DataParallel(model2)
49 | model2 = model2.cuda()
50 | cudnn.benchmark = True
51 |
52 | # ----- load trained model ----- #
53 | print("=> loading trained model")
54 | checkpoint1 = torch.load(model_path1)
55 | checkpoint2 = torch.load(model_path2)
56 | model1.load_state_dict(checkpoint1['state_dict'])
57 | model2.load_state_dict(checkpoint2['state_dict'])
58 | print("=> loaded model1 at epoch {}".format(checkpoint1['epoch']))
59 | print("=> loaded model2 at epoch {}".format(checkpoint2['epoch']))
60 | model1 = model1.module
61 | model2 = model2.module
62 |
63 | # switch to evaluate mode
64 | model1.eval()
65 | model2.eval()
66 | counter = 0
67 | print("=> Test begins:")
68 |
69 | img_names = os.listdir(img_dir)
70 |
71 | if save_flag:
72 | if not os.path.exists(save_dir):
73 | os.mkdir(save_dir)
74 | strs = img_dir.split('/')
75 | prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
76 | seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
77 | if not os.path.exists(prob_maps_folder):
78 | os.mkdir(prob_maps_folder)
79 | if not os.path.exists(seg_folder):
80 | os.mkdir(seg_folder)
81 |
82 | metric_names = ['acc', 'p_F1', 'dice', 'aji', 'dq', 'sq', 'pq']
83 | test_results = dict()
84 | all_result = utils.AverageMeter(len(metric_names))
85 | all_result1 = utils.AverageMeter(len(metric_names))
86 | all_result2 = utils.AverageMeter(len(metric_names))
87 |
88 | # calculate inference time
89 | time_list = []
90 |
91 | img_process = tqdm(img_names)
92 | for img_name in img_process:
93 | # load test image
94 | img_process.set_description('=> Processing image {:s}'.format(img_name))
95 | img_path = '{:s}/{:s}'.format(img_dir, img_name)
96 | imgh_path = '{:s}/{:s}'.format(imgh_dir, img_name)
97 | img = Image.open(img_path)
98 | imgh = Image.open(imgh_path)
99 | # import ipdb; ipdb.set_trace()  # leftover debugging breakpoint, disabled so the whole test set runs
100 | ori_h = img.size[1]
101 | ori_w = img.size[0]
102 | name = os.path.splitext(img_name)[0]
103 | label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
104 | gt = imageio.imread(label_path)
105 |
106 | input = test_transform((img,))[0].unsqueeze(0)
107 | inputh = test_transform((imgh,))[0].unsqueeze(0)
108 |
109 | img_process.set_description('Computing output probability maps...')
110 | begin = time.time()
111 | prob_maps, prob_maps1, prob_maps2 = get_probmaps(input, inputh, model1, model2, opt)
112 | end = time.time()
113 | time_list.append(end - begin)
114 | pred = np.argmax(prob_maps, axis=0) # prediction
115 |
116 | pred_labeled = measure.label(pred)
117 | pred_labeled = morph.remove_small_objects(pred_labeled, opt.post['min_area'])
118 | pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
119 | pred_labeled = measure.label(pred_labeled)
120 |
121 | img_process.set_description('Computing metrics...')
122 | metrics = compute_metrics(pred_labeled, gt, metric_names)
123 | print(metrics)
124 |
125 | test_results[name] = [metrics[m] for m in metric_names]
126 | all_result.update([metrics[m] for m in metric_names])
127 |
128 | pred1 = np.argmax(prob_maps1, axis=0) # prediction
129 | pred_labeled1 = measure.label(pred1)
130 | pred_labeled1 = morph.remove_small_objects(pred_labeled1, opt.post['min_area'])
131 | pred_labeled1 = ndi_morph.binary_fill_holes(pred_labeled1 > 0)
132 | pred_labeled1 = measure.label(pred_labeled1)
133 | metrics = compute_metrics(pred_labeled1, gt, metric_names)
134 | all_result1.update([metrics[m] for m in metric_names])
135 |
136 | pred2 = np.argmax(prob_maps2, axis=0) # prediction of model2 (kept separate so the ensemble pred is not overwritten)
137 | pred_labeled2 = measure.label(pred2)
138 | pred_labeled2 = morph.remove_small_objects(pred_labeled2, opt.post['min_area'])
139 | pred_labeled2 = ndi_morph.binary_fill_holes(pred_labeled2 > 0)
140 | pred_labeled2 = measure.label(pred_labeled2)
141 | metrics = compute_metrics(pred_labeled2, gt, metric_names)
142 | all_result2.update([metrics[m] for m in metric_names])
143 |
144 | # save image
145 | if save_flag:
146 | img_process.set_description('Saving image results...')
147 | imageio.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name), pred.astype(np.uint8) * 255)
148 | imageio.imsave('{:s}/{:s}_prob.png'.format(prob_maps_folder, name), (prob_maps[1] * 255).astype(np.uint8))
149 | final_pred = Image.fromarray(pred_labeled.astype(np.uint16))
150 | final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))
151 |
152 | # save colored objects
153 | pred_colored_instance = np.zeros((ori_h, ori_w, 3))
154 | for k in range(1, pred_labeled.max() + 1):
155 | pred_colored_instance[pred_labeled == k, :] = np.array(utils.get_random_color())
156 | filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
157 | imageio.imsave(filename, (pred_colored_instance * 255).astype(np.uint8))
158 |
159 | counter += 1
160 |
161 | print('=> Processed all {:d} images'.format(counter))
162 | print('=> Average time per image: {:.4f} s'.format(np.mean(time_list)))
163 | print('=> Std time per image: {:.4f} s'.format(np.std(time_list)))
164 |
165 | table = Table(show_header=True, header_style="bold magenta")
166 | table.add_column("Model", style="dim", width=12)
167 | table.add_column("Acc")
168 | table.add_column("F1")
169 | table.add_column("Dice")
170 | table.add_column("AJI")
171 | table.add_column("DQ")
172 | table.add_column("SQ")
173 | table.add_column("PQ")
174 | a, a1, a2 = all_result.avg, all_result1.avg, all_result2.avg
175 | table.add_row('ens', f'{a[0]: .4f}', f'{a[1]: .4f}', f'{a[2]: .4f}', f'{a[3]: .4f}', f'{a[4]: .4f}', f'{a[5]: .4f}', f'{a[6]: .4f}')
176 | table.add_row('m1', f'{a1[0]: .4f}', f'{a1[1]: .4f}', f'{a1[2]: .4f}', f'{a1[3]: .4f}', f'{a1[4]: .4f}', f'{a1[5]: .4f}', f'{a1[6]: .4f}')
177 | table.add_row('m2', f'{a2[0]: .4f}', f'{a2[1]: .4f}', f'{a2[2]: .4f}', f'{a2[3]: .4f}', f'{a2[4]: .4f}', f'{a2[5]: .4f}', f'{a2[6]: .4f}')
178 | print(table)
179 |
180 | header = metric_names
181 | utils.save_results(header, all_result.avg, test_results, f'{save_dir}/test_results_{checkpoint1["epoch"]}_{checkpoint2["epoch"]}_{all_result.avg[5]:.4f}_{all_result1.avg[5]:.4f}_{all_result2.avg[5]:.4f}.txt')
182 |
183 |
184 | def get_probmaps(input, inputh, model1, model2, opt):
185 | size = opt.test['patch_size']
186 | overlap = opt.test['overlap']
187 | output1 = split_forward1(model1, inputh, size, overlap)
188 | output1 = output1.squeeze(0)
189 | prob_maps1 = F.softmax(output1, dim=0).cpu().numpy()
190 |
191 | output2 = split_forward1(model2, inputh, size, overlap)
192 | output2 = output2.squeeze(0)
193 | prob_maps2 = F.softmax(output2, dim=0).cpu().numpy()
194 |
195 | prob_maps = (prob_maps1 + prob_maps2) / 2.0
196 |
197 | return prob_maps, prob_maps1, prob_maps2
198 |
199 |
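`split_forward1` below tiles the (zero-padded) input with patches of `opt.test['patch_size']` = 224 and an overlap of 80 px, i.e. a stride of 144 px, and trims half of the overlap from each side when stitching the outputs back together. A small sketch of just the padding rule it applies (the helper name is illustrative):

```python
def padded_extent(dim, size=224, overlap=80):
    """Length after zero-padding so that (dim - size) is a multiple of the stride."""
    stride = size - overlap
    if dim > size and (dim - size) % stride > 0:
        return dim + stride - (dim - size) % stride
    return dim

print(padded_extent(1000))  # 1088: patches start at 0, 144, ..., 864 (7 patches per axis)
```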
200 | def split_forward1(model, input, size, overlap, outchannel = 2):
201 | '''
202 | Split the input image into overlapping patches for the forward pass.
203 | Modification: for a single image, split it into patches, batch them together, and run one forward pass.
204 | '''
205 |
206 | b, c, h0, w0 = input.size()
207 |
208 | # zero pad for border patches
209 | pad_h = 0
210 | if h0 - size > 0 and (h0 - size) % (size - overlap) > 0:
211 | pad_h = (size - overlap) - (h0 - size) % (size - overlap)
212 | tmp = torch.zeros((b, c, pad_h, w0))
213 | input = torch.cat((input, tmp), dim=2)
214 |
215 | if w0 - size > 0 and (w0 - size) % (size - overlap) > 0:
216 | pad_w = (size - overlap) - (w0 - size) % (size - overlap)
217 | tmp = torch.zeros((b, c, h0 + pad_h, pad_w))
218 | input = torch.cat((input, tmp), dim=3)
219 |
220 | _, c, h, w = input.size()
221 |
222 | output = torch.zeros((input.size(0), outchannel, h, w))
223 | input_vars = []
224 | for i in range(0, h-overlap, size-overlap):
225 | r_end = i + size if i + size < h else h
226 | ind1_s = i + overlap // 2 if i > 0 else 0
227 | ind1_e = i + size - overlap // 2 if i + size < h else h
228 | for j in range(0, w-overlap, size-overlap):
229 | c_end = j+size if j+size < w else w
230 |
231 | input_patch = input[:,:,i:r_end,j:c_end]
232 | # input_var = input_patch.cuda()
233 | input_var = input_patch.numpy()
234 | input_vars.append(input_var)
235 | input_vars = torch.as_tensor(input_vars)
236 | input_vars = input_vars.squeeze(1).cuda()
237 | with torch.no_grad():
238 | output_patches, _ = model(input_vars)
239 | idx = 0
240 | for i in range(0, h-overlap, size-overlap):
241 | r_end = i + size if i + size < h else h
242 | ind1_s = i + overlap // 2 if i > 0 else 0
243 | ind1_e = i + size - overlap // 2 if i + size < h else h
244 | for j in range(0, w-overlap, size-overlap):
245 | c_end = j+size if j+size < w else w
246 | output_patch = output_patches[idx]
247 | idx += 1
248 | output_patch = output_patch.unsqueeze(0)
249 | ind2_s = j+overlap//2 if j>0 else 0
250 | ind2_e = j+size-overlap//2 if j+size < w else w
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
121 | logger.info("=> Initial learning rate: {:g}".format(opt.train['lr']))
122 | logger.info("=> Batch size: {:d}".format(opt.train['batch_size']))
123 | logger.info("=> Number of training iterations: {:d}".format(num_iter))
124 | logger.info("=> Training epochs: {:d}".format(opt.train['train_epochs']))
125 | min_loss1 = 100
126 | min_loss2 = 100
127 | for epoch in range(opt.train['start_epoch'], num_epoch):
128 | # train for one epoch or len(train_loader) iterations
129 | logger.info('Epoch: [{:d}/{:d}]'.format(epoch+1, num_epoch))
130 | if epoch % opt.train['ema_interval'] == 0:
131 | if epoch == 0:
132 | if os.path.exists(os.path.join(opt.train['save_dir'], 'weight12')):
133 | shutil.rmtree(os.path.join(opt.train['save_dir'], 'weight12'))
134 | shutil.rmtree(os.path.join(opt.train['save_dir'], 'weight21'))
135 | ensemble_prediction(imgh2_dir, os.path.join(opt.train['save_dir'], 'weight21'), model1)
136 | ensemble_prediction(imgh1_dir, os.path.join(opt.train['save_dir'], 'weight12'), model2)
137 |
138 | test_dice1 = val_lin(model1, test_loader)
139 | test_dice2 = val_lin(model2, test_loader)
140 | print(f'test_dice1: {test_dice1: .4f}\ttest_dice2: {test_dice2: .4f}')
141 |
142 | train_results = train(train_loader1, train_loader2, model1, model2, optimizer1, optimizer2, criterion, criterion_H, epoch, num_epoch)
143 | train_loss, train_loss_vor, train_loss_cluster, train_loss_ct, train_loss_self = train_results
144 | state1 = {'epoch': epoch + 1, 'state_dict': model1.state_dict(), 'optimizer': optimizer1.state_dict()}
145 | state2 = {'epoch': epoch + 1, 'state_dict': model2.state_dict(), 'optimizer': optimizer2.state_dict()}
146 | if (epoch + 1) < (num_epoch - 4):
147 | cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
148 | else:
149 | cp_flag = True
150 | save_checkpoint1(state1, epoch, opt.train['save_dir'], cp_flag)
151 | save_checkpoint2(state2, epoch, opt.train['save_dir'], cp_flag)
152 |
153 | # save the training results to txt files
154 | logger_results.info('{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'
155 | .format(epoch+1, train_loss, train_loss_vor, train_loss_cluster, train_loss_ct, train_loss_self, test_dice1, test_dice2))
156 | scheduler1.step()
157 | scheduler2.step()
158 |
159 | val_loss1 = val(val_loader, model1, criterion)
160 | val_loss2 = val(val_loader, model2, criterion)
161 |
162 | if val_loss1 < min_loss1:
163 | print(f'val_loss1: {val_loss1:.4f}', f'min_loss1: {min_loss1:.4f}')
164 | min_loss1 = val_loss1
165 | save_bestcheckpoint1(state1, opt.train['save_dir'])
166 | if val_loss2 < min_loss2:
167 | print(f'val_loss2: {val_loss2:.4f}', f'min_loss2: {min_loss2:.4f}')
168 | min_loss2 = val_loss2
169 | save_bestcheckpoint2(state2, opt.train['save_dir'])
170 | for i in list(logger.handlers):
171 | logger.removeHandler(i)
172 | i.flush()
173 | i.close()
174 | for i in list(logger_results.handlers):
175 | logger_results.removeHandler(i)
176 | i.flush()
177 | i.close()
178 |
179 |
180 | def train(train_loader1, train_loader2, model1, model2, optimizer1, optimizer2, criterion, criterion_H, epoch, num_epoch):
181 | # list to store the average loss for this epoch
182 | results = utils.AverageMeter(5)
183 | # switch to train mode
184 | model1.train()
185 | model2.train()
186 | weight12, weight21 = None, None
187 | for i, (sample1, sample2) in enumerate(zip(train_loader1, train_loader2)):
188 | input1, inputh1, weight12, vor1, cluster1 = sample1
189 | input2, inputh2, weight21, vor2, cluster2 = sample2
190 | if vor1.dim() == 4:
191 | vor1 = vor1.squeeze(1)
192 | if vor2.dim() == 4:
193 | vor2 = vor2.squeeze(1)
194 | if cluster1.dim() == 4:
195 | cluster1 = cluster1.squeeze(1)
196 | if cluster2.dim() == 4:
197 | cluster2 = cluster2.squeeze(1)
198 | inputh_var1 = inputh1.float().cuda()
199 | inputh_var2 = inputh2.float().cuda()
200 |
201 | # compute output
202 | output11, output11l = model1(inputh_var1)
203 | output22, output22l = model2(inputh_var2)
204 |
205 | log_prob_maps1 = F.log_softmax(output11, dim=1)
206 | loss_vor1 = criterion(log_prob_maps1, vor1.cuda())
207 | loss_cluster1 = criterion(log_prob_maps1, cluster1.cuda())
208 |
209 | log_prob_maps2 = F.log_softmax(output22, dim=1)
210 | loss_vor2 = criterion(log_prob_maps2, vor2.cuda())
211 | loss_cluster2 = criterion(log_prob_maps2, cluster2.cuda())
212 |
213 | loss_vor = loss_vor1 + loss_vor2
214 | loss_cluster = loss_cluster1 + loss_cluster2
215 |
216 | pseudo12 = Combine(weight12.float().cuda(), cluster1)
217 | Pseudo12 = Variable(pseudo12, requires_grad=False)
218 | pseudo21 = Combine(weight21.float().cuda(), cluster2)
219 | Pseudo21 = Variable(pseudo21, requires_grad=False)
220 |
221 | loss_ct1 = loss_CT(output11, Pseudo12)
222 | loss_ct2 = loss_CT(output22, Pseudo21)
223 | loss_ct = loss_ct1 + loss_ct2
224 |
225 | loss_self1 = criterion_H(output11l, input1.cuda())
226 | loss_self2 = criterion_H(output22l, input2.cuda())
227 | loss_self = loss_self1 + loss_self2
228 |
229 | loss = loss_vor + loss_cluster + loss_ct * (epoch/num_epoch)**2 + loss_self * (1 - (epoch/num_epoch)**2) * 0.1  # ramp the co-training loss up and the self-supervised loss down over training
230 |
231 | result = [loss.item(), loss_vor.item(), loss_cluster.item(), loss_ct.item(), loss_self.item()]
232 |
233 | results.update(result, input1.size(0))
234 |
235 | # compute gradient and do SGD step
236 | optimizer1.zero_grad()
237 | optimizer2.zero_grad()
238 | loss.backward()
239 | optimizer1.step()
240 | optimizer2.step()
241 |
242 | if i % opt.train['log_interval'] == 0:
243 | logger.info('Iteration: [{:d}/{:d}]'
244 | '\tLoss {r[0]:.4f}'
245 | '\tLoss_vor {r[1]:.4f}'
246 | '\tLoss_cluster {r[2]:.4f}'
247 | '\tLoss_ct {r[3]:.4f}'
248 | '\tLoss_self {r[4]:.4f}'.format(i, len(train_loader1), r=results.avg))
249 |
250 | logger.info('===> Train Avg: Loss {r[0]:.4f}'
251 | '\tloss_vor {r[1]:.4f}'
252 | '\tloss_cluster {r[2]:.4f}'
253 | '\tloss_ct {r[3]:.4f}'
254 | '\tloss_self {r[4]:.4f}'.format(r=results.avg))
255 |
256 | return results.avg
257 |
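For reference, the total loss in `train()` above mixes the four terms with an epoch-dependent schedule: the co-training term is ramped up quadratically while the self-supervised reconstruction term (scaled by 0.1) is ramped down. A standalone sketch of that weighting (the function name is illustrative):

```python
def loss_weights(epoch, num_epoch):
    """Quadratic ramp used when mixing the loss terms in train()."""
    w_ct = (epoch / num_epoch) ** 2      # co-training weight: 0 -> 1
    w_self = (1 - w_ct) * 0.1            # self-supervised weight: 0.1 -> 0
    return w_ct, w_self

for e in (0, 25, 50, 75, 99):
    print(e, loss_weights(e, 100))       # e.g. epoch 50 -> (0.25, 0.075)
```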
258 | def val(val_loader, model, criterion):
259 | model.eval()
260 | results = 0
261 | for i, sample in enumerate(val_loader):
262 | input, inputh, target1, target2 = sample
263 | if target1.dim() == 4:
264 | target1 = target1.squeeze(1)
265 | if target2.dim() == 4:
266 | target2 = target2.squeeze(1)
267 |
268 | input_var = inputh.cuda()
269 |
270 | # compute output
271 | output, _ = model(input_var)
272 | log_prob_maps = F.log_softmax(output, dim=1)
273 | loss_vor = criterion(log_prob_maps, target1.cuda())
274 | loss_cluster = criterion(log_prob_maps, target2.cuda())
275 | result = loss_vor.item() + loss_cluster.item()
276 | results += result
277 | val_loss = results / (opt.train['batch_size'] * len(val_loader))
278 | return val_loss
279 |
280 |
281 | def dice_coeff(pred, gt):
282 | target = torch.zeros_like(gt)
283 | target[gt > 0.5] = 1
284 | target = target.float()  # use the binarized mask (the raw gt may be 0/255 or instance-labelled)
285 | smooth = 1.
286 | num = pred.size(0)
287 | m1 = pred.view(num, -1) # Flatten
288 | m2 = target.view(num, -1) # Flatten
289 | intersection = (m1 * m2)
290 | dice = (2. * intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
291 | avg_dice = dice.sum() / num
292 | return avg_dice
293 |
294 | def val_lin(model, test_loader):
295 | model.eval()
296 | metric_names = ['dice']
297 | all_result = utils.AverageMeter(len(metric_names))
298 |
299 | size = opt.test['patch_size']
300 | overlap = opt.test['overlap']
301 |
302 | for _, inputh, gt in test_loader:
303 | if gt.dim() == 4:
304 | gt = gt.squeeze(1).cuda()
305 | output = utils.split_forward(model, inputh, size, overlap)
306 | log_prob_maps = F.softmax(output, dim=1)
307 | pred_labeled = torch.argmax(log_prob_maps, axis=1) # prediction
308 | dice = dice_coeff(pred_labeled, gt).cpu().numpy()
309 | all_result.update([dice])
310 | dice_avg = all_result.avg[0]
311 | return dice_avg
312 |
313 |
314 | def ensemble_prediction(img_dir, save_dir, model, alpha=0.1):
315 | ## save predictions into the experiment folder
316 | ## load all the training images
317 | img_names = os.listdir(img_dir)
318 | img_process = tqdm(img_names)
319 | test_transform = get_transforms(opt.transform['test'])
320 | print('[bold magenta]saving EMA weights ...[/bold magenta]')
321 | for img_name in img_process:
322 | img = Image.open(os.path.join(img_dir, img_name))
323 | input = test_transform((img,))[0].unsqueeze(0).cuda()
324 | output, _ = model(input)
325 | log_prob_maps = F.softmax(output, dim=1)
326 | pred = log_prob_maps.squeeze(0).cpu().detach().numpy()[1]
327 |
328 | try:
329 | weight = imageio.imread(os.path.join(save_dir, img_name))
330 | weight = np.array(weight)
331 | weight = alpha * pred + (1 - alpha) * weight
332 | except Exception:  # no previous EMA map for this image yet (first pass)
333 | weight = pred
334 | os.makedirs(save_dir, exist_ok=True)
335 | imageio.imsave(os.path.join(save_dir, img_name), (weight * 255).astype(np.uint8))
336 |
337 |
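`ensemble_prediction` above maintains an exponential moving average of each training image's foreground probability map on disk (stored as 8-bit PNGs). The update rule in isolation, as a tiny sketch with made-up values:

```python
import numpy as np

def ema_update(prev, new, alpha=0.1):
    """EMA of probability maps; prev is None on the first pass."""
    return new if prev is None else alpha * new + (1 - alpha) * prev

ema = None
for pred in (np.full((4, 4), 0.2), np.full((4, 4), 0.8), np.full((4, 4), 0.6)):
    ema = ema_update(ema, pred)
print(ema[0, 0])   # 0.2 -> 0.26 -> 0.294
```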
338 | def save_checkpoint1(state, epoch, save_dir, cp_flag):
339 | cp_dir = '{:s}/checkpoints'.format(save_dir)
340 | if not os.path.exists(cp_dir):
341 | os.mkdir(cp_dir)
342 | filename = '{:s}/checkpoint1.pth.tar'.format(cp_dir)
343 | torch.save(state, filename)
344 | if cp_flag:
345 | shutil.copyfile(filename, '{:s}/checkpoint1_{:d}.pth.tar'.format(cp_dir, epoch+1))
346 | def save_checkpoint2(state, epoch, save_dir, cp_flag):
347 | cp_dir = '{:s}/checkpoints'.format(save_dir)
348 | if not os.path.exists(cp_dir):
349 | os.mkdir(cp_dir)
350 | filename = '{:s}/checkpoint2.pth.tar'.format(cp_dir)
351 | torch.save(state, filename)
352 | if cp_flag:
353 | shutil.copyfile(filename, '{:s}/checkpoint2_{:d}.pth.tar'.format(cp_dir, epoch+1))
354 |
355 | def save_bestcheckpoint1(state, save_dir):
356 | cp_dir = '{:s}/checkpoints'.format(save_dir)
357 | torch.save(state, '{:s}/checkpoint1_0.pth.tar'.format(cp_dir))
358 |
359 | def save_bestcheckpoint2(state, save_dir):
360 | cp_dir = '{:s}/checkpoints'.format(save_dir)
361 | torch.save(state, '{:s}/checkpoint2_0.pth.tar'.format(cp_dir))
362 |
363 | def setup_logging(opt):
364 | mode = 'a' if opt.train['checkpoint'] else 'w'
365 |
366 | # create logger for training information
367 | logger = logging.getLogger('train_logger')
368 | logger.setLevel(logging.DEBUG)
369 | # create console handler and file handler
370 | console_handler = RichHandler(show_level=False, show_time=False, show_path=False)
371 | console_handler.setLevel(logging.INFO)
372 | file_handler = logging.FileHandler('{:s}/train_log.txt'.format(opt.train['save_dir']), mode=mode)
373 | file_handler.setLevel(logging.DEBUG)
374 | # create formatter
375 | formatter = logging.Formatter('%(message)s')
376 | # add formatter to handlers
377 | console_handler.setFormatter(formatter)
378 | file_handler.setFormatter(formatter)
379 | # add handlers to logger
380 | logger.addHandler(console_handler)
381 | logger.addHandler(file_handler)
382 |
383 | # create logger for epoch results
384 | logger_results = logging.getLogger('results')
385 | logger_results.setLevel(logging.DEBUG)
386 | file_handler2 = logging.FileHandler('{:s}/epoch_results.txt'.format(opt.train['save_dir']), mode=mode)
387 | file_handler2.setFormatter(logging.Formatter('%(message)s'))
388 | logger_results.addHandler(file_handler2)
389 |
390 | logger.info('***** Training starts *****')
391 | logger.info('save directory: {:s}'.format(opt.train['save_dir']))
392 | if mode == 'w':
393 | logger_results.info('epoch\ttrain_loss\ttrain_loss_vor\ttrain_loss_cluster\ttrain_loss_ct\ttrain_loss_self\ttest_dice1\ttest_dice2')
394 |
395 | return logger, logger_results
396 |
397 |
398 | if __name__ == '__main__':
399 | main()
400 |
--------------------------------------------------------------------------------
/utils/accuracy.py:
--------------------------------------------------------------------------------
1 |
2 | from skimage.measure import label
3 | from sklearn.metrics import accuracy_score, roc_auc_score
4 | from sklearn.metrics import f1_score
5 | from sklearn.metrics import recall_score, precision_score
6 | from scipy.spatial.distance import directed_hausdorff as hausdorff
7 | from scipy.ndimage import center_of_mass  # scipy.ndimage.measurements is deprecated
8 | import numpy as np
9 | from scipy.optimize import linear_sum_assignment
10 |
11 | def compute_metrics(pred, gt, names):
12 | """
13 | Computes metrics specified by names between predicted label and groundtruth label.
14 | """
15 |
16 | gt_labeled = label(gt)
17 | pred_labeled = label(pred)
18 |
19 | gt_binary = gt_labeled.copy()
20 | pred_binary = pred_labeled.copy()
21 | gt_binary[gt_binary > 0] = 1
22 | pred_binary[pred_binary > 0] = 1
23 | gt_binary, pred_binary = gt_binary.flatten(), pred_binary.flatten()
24 |
25 | results = {}
26 |
27 | # pixel-level metrics
28 | if 'acc' in names:
29 | results['acc'] = accuracy_score(gt_binary, pred_binary)
30 | if 'roc' in names:
31 | results['roc'] = roc_auc_score(gt_binary, pred_binary)
32 | if 'p_F1' in names: # pixel-level F1
33 | results['p_F1'] = f1_score(gt_binary, pred_binary)
34 | if 'p_recall' in names: # pixel-level F1
35 | results['p_recall'] = recall_score(gt_binary, pred_binary)
36 | if 'p_precision' in names: # pixel-level F1
37 | results['p_precision'] = precision_score(gt_binary, pred_binary)
38 |
39 | # object-level metrics
40 | if 'aji' in names:
41 | results['aji'] = AJI_fast(gt_labeled, pred_labeled)
42 | if 'haus' in names:
43 | results['dice'], results['iou'], results['haus'] = accuracy_object_level(pred_labeled, gt_labeled, True)
44 | if 'dice' in names or 'iou' in names:
45 | results['dice'], results['iou'], _ = accuracy_object_level(pred_labeled, gt_labeled, False)
46 | if 'pq' in names:
47 | results['dq'], results['sq'], results['pq'] = get_pq(gt_labeled, pred_labeled)
48 |
49 | return results
50 |
51 |
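A minimal usage sketch for `compute_metrics` on toy instance maps (integer instance ids, 0 = background); the shapes and values below are illustrative only:

```python
import numpy as np
from utils.accuracy import compute_metrics

gt = np.zeros((64, 64), dtype=np.int32)
gt[5:15, 5:15] = 1          # one ground-truth nucleus
pred = np.zeros_like(gt)
pred[6:16, 6:16] = 1        # one slightly shifted predicted nucleus

print(compute_metrics(pred, gt, ['acc', 'p_F1', 'dice', 'aji']))
```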
52 | def accuracy_object_level(pred, gt, hausdorff_flag=True):
53 | """ Compute the object-level metrics between predicted and
54 | groundtruth: dice, iou, hausdorff """
55 | if not isinstance(pred, np.ndarray):
56 | pred = np.array(pred)
57 | if not isinstance(gt, np.ndarray):
58 | gt = np.array(gt)
59 |
60 | # get connected components
61 | pred_labeled = label(pred, connectivity=2)
62 | Ns = len(np.unique(pred_labeled)) - 1
63 | gt_labeled = label(gt, connectivity=2)
64 | Ng = len(np.unique(gt_labeled)) - 1
65 |
66 | # --- compute dice, iou, hausdorff --- #
67 | pred_objs_area = np.sum(pred_labeled>0) # total area of objects in image
68 | gt_objs_area = np.sum(gt_labeled>0) # total area of objects in groundtruth gt
69 |
70 | # compute how well groundtruth object overlaps its segmented object
71 | dice_g = 0.0
72 | iou_g = 0.0
73 | hausdorff_g = 0.0
74 | for i in range(1, Ng + 1):
75 | gt_i = np.where(gt_labeled == i, 1, 0)
76 | overlap_parts = gt_i * pred_labeled
77 |
78 | # get intersection objects numbers in image
79 | obj_no = np.unique(overlap_parts)
80 | obj_no = obj_no[obj_no != 0]
81 |
82 | gamma_i = float(np.sum(gt_i)) / gt_objs_area
83 |
84 | if obj_no.size == 0: # no intersection object
85 | dice_i = 0
86 | iou_i = 0
87 |
88 | # find nearest segmented object in hausdorff distance
89 | if hausdorff_flag:
90 | min_haus = 1e3
91 |
92 | # find overlap object in a window [-50, 50]
93 | pred_cand_indices = find_candidates(gt_i, pred_labeled)
94 |
95 | for j in pred_cand_indices:
96 | pred_j = np.where(pred_labeled == j, 1, 0)
97 | seg_ind = np.argwhere(pred_j)
98 | gt_ind = np.argwhere(gt_i)
99 | haus_tmp = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0])
100 |
101 | if haus_tmp < min_haus:
102 | min_haus = haus_tmp
103 | haus_i = min_haus
104 | else:
105 | # find max overlap object
106 | obj_areas = [np.sum(overlap_parts == k) for k in obj_no]
107 | seg_obj = obj_no[np.argmax(obj_areas)] # segmented object number
108 | pred_i = np.where(pred_labeled == seg_obj, 1, 0) # segmented object
109 |
110 | overlap_area = np.max(obj_areas) # overlap area
111 |
112 | dice_i = 2 * float(overlap_area) / (np.sum(pred_i) + np.sum(gt_i))
113 | iou_i = float(overlap_area) / (np.sum(pred_i) + np.sum(gt_i) - overlap_area)
114 |
115 | # compute hausdorff distance
116 | if hausdorff_flag:
117 | seg_ind = np.argwhere(pred_i)
118 | gt_ind = np.argwhere(gt_i)
119 | haus_i = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0])
120 |
121 | dice_g += gamma_i * dice_i
122 | iou_g += gamma_i * iou_i
123 | if hausdorff_flag:
124 | hausdorff_g += gamma_i * haus_i
125 |
126 | # compute how well segmented object overlaps its groundtruth object
127 | dice_s = 0.0
128 | iou_s = 0.0
129 | hausdorff_s = 0.0
130 | for j in range(1, Ns + 1):
131 | pred_j = np.where(pred_labeled == j, 1, 0)
132 | overlap_parts = pred_j * gt_labeled
133 |
134 | # get intersection objects number in gt
135 | obj_no = np.unique(overlap_parts)
136 | obj_no = obj_no[obj_no != 0]
137 |
138 | # show_figures((pred_j, gt_labeled, overlap_parts))
139 |
140 | sigma_j = float(np.sum(pred_j)) / pred_objs_area
141 | # no intersection object
142 | if obj_no.size == 0:
143 | dice_j = 0
144 | iou_j = 0
145 |
146 | # find nearest groundtruth object in hausdorff distance
147 | if hausdorff_flag:
148 | min_haus = 1e3
149 |
150 | # find overlap object in a window [-50, 50]
151 | gt_cand_indices = find_candidates(pred_j, gt_labeled)
152 |
153 | for i in gt_cand_indices:
154 | gt_i = np.where(gt_labeled == i, 1, 0)
155 | seg_ind = np.argwhere(pred_j)
156 | gt_ind = np.argwhere(gt_i)
157 | haus_tmp = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0])
158 |
159 | if haus_tmp < min_haus:
160 | min_haus = haus_tmp
161 | haus_j = min_haus
162 | else:
163 | # find max overlap gt
164 | gt_areas = [np.sum(overlap_parts == k) for k in obj_no]
165 | gt_obj = obj_no[np.argmax(gt_areas)] # groundtruth object number
166 | gt_j = np.where(gt_labeled == gt_obj, 1, 0) # groundtruth object
167 |
168 | overlap_area = np.max(gt_areas) # overlap area
169 |
170 | dice_j = 2 * float(overlap_area) / (np.sum(pred_j) + np.sum(gt_j))
171 | iou_j = float(overlap_area) / (np.sum(pred_j) + np.sum(gt_j) - overlap_area)
172 |
173 | # compute hausdorff distance
174 | if hausdorff_flag:
175 | seg_ind = np.argwhere(pred_j)
176 | gt_ind = np.argwhere(gt_j)
177 | haus_j = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0])
178 |
179 | dice_s += sigma_j * dice_j
180 | iou_s += sigma_j * iou_j
181 | if hausdorff_flag:
182 | hausdorff_s += sigma_j * haus_j
183 |
184 | return (dice_g + dice_s) / 2, (iou_g + iou_s) / 2, (hausdorff_g + hausdorff_s) / 2
185 |
186 |
187 | def find_candidates(obj_i, objects_labeled, radius=50):
188 | """
189 | Find the indices of objects in `objects_labeled` that fall inside a window
190 | of the given radius centered at `obj_i`, used when computing the object-level Hausdorff distance.
191 |
192 | """
193 | if radius > 400:
194 | return np.array([])
195 |
196 | h, w = objects_labeled.shape
197 | x, y = center_of_mass(obj_i)
198 | x, y = int(x), int(y)
199 | r1 = x-radius if x-radius >= 0 else 0
200 | r2 = x+radius if x+radius <= h else h
201 | c1 = y-radius if y-radius >= 0 else 0
202 | c2 = y+radius if y+radius < w else w
203 | indices = np.unique(objects_labeled[r1:r2, c1:c2])
204 | indices = indices[indices != 0]
205 |
206 | if indices.size == 0:
207 | indices = find_candidates(obj_i, objects_labeled, 2*radius)
208 |
209 | return indices
210 |
211 |
212 | def AJI_fast(gt, pred_arr):  # Aggregated Jaccard Index (AJI), computed from pairwise intersection/union tables
213 | gs, g_areas = np.unique(gt, return_counts=True)
214 | assert np.all(gs == np.arange(len(gs)))
215 | ss, s_areas = np.unique(pred_arr, return_counts=True)
216 | assert np.all(ss == np.arange(len(ss)))
217 |
218 | i_idx, i_cnt = np.unique(np.concatenate([gt.reshape(1, -1), pred_arr.reshape(1, -1)]),
219 | return_counts=True, axis=1)
220 | i_arr = np.zeros(shape=(len(gs), len(ss)), dtype=np.int64)  # np.int was removed in NumPy >= 1.24
221 | i_arr[i_idx[0], i_idx[1]] += i_cnt
222 | u_arr = g_areas.reshape(-1, 1) + s_areas.reshape(1, -1) - i_arr
223 | iou_arr = 1.0 * i_arr / u_arr
224 |
225 | i_arr = i_arr[1:, 1:]
226 | u_arr = u_arr[1:, 1:]
227 | iou_arr = iou_arr[1:, 1:]
228 |
229 | j = np.argmax(iou_arr, axis=1)
230 |
231 | c = np.sum(i_arr[np.arange(len(gs) - 1), j])
232 | u = np.sum(u_arr[np.arange(len(gs) - 1), j])
233 | used = np.zeros(shape=(len(ss) - 1), dtype=np.int64)
234 | used[j] = 1
235 | u += (np.sum(s_areas[1:] * (1 - used)))
236 | return 1.0 * c / u
237 |
238 |
239 | def remap_label(pred, by_size=False):
240 | """Rename all instance IDs so that the IDs are contiguous, i.e. [0, 1, 2, 3]
241 | rather than [0, 2, 4, 6]. The ordering of instances (which one comes first)
242 | is preserved unless by_size=True, in which case the instances are reordered
243 | so that bigger nuclei have smaller IDs.
244 |
245 | Args:
246 | pred (ndarray): the 2d array containing instances, where each instance is marked
247 | by a non-zero integer.
248 | by_size (bool): renaming such that larger nuclei have a smaller id (on-top).
249 |
250 | Returns:
251 | new_pred (ndarray): Array with contiguous ordering of instances.
252 |
253 | """
254 | pred_id = list(np.unique(pred))
255 | pred_id.remove(0)
256 | if len(pred_id) == 0:
257 | return pred # no label
258 | if by_size:
259 | pred_size = []
260 | for inst_id in pred_id:
261 | size = (pred == inst_id).sum()
262 | pred_size.append(size)
263 | # sort the id by size in descending order
264 | pair_list = zip(pred_id, pred_size)
265 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
266 | pred_id, pred_size = zip(*pair_list)
267 |
268 | new_pred = np.zeros(pred.shape, np.int32)
269 | for idx, inst_id in enumerate(pred_id):
270 | new_pred[pred == inst_id] = idx + 1
271 | return new_pred
272 |
273 |
274 | def get_bounding_box(img):
275 | """Get the bounding box coordinates of a binary input; assumes a single object.
276 |
277 | Args:
278 | img: input binary image.
279 |
280 | Returns:
281 | bounding box coordinates
282 |
283 | """
284 | rows = np.any(img, axis=1)
285 | cols = np.any(img, axis=0)
286 | rmin, rmax = np.where(rows)[0][[0, -1]]
287 | cmin, cmax = np.where(cols)[0][[0, -1]]
288 | # due to python indexing, need to add 1 to max
289 | # else accessing will be 1px in the box, not out
290 | rmax += 1
291 | cmax += 1
292 | return [rmin, rmax, cmin, cmax]
293 |
294 |
295 | def get_pq(true, pred, match_iou=0.5, remap=True):
296 | """Get the panoptic quality result.
297 |
298 | Fast computation requires instance IDs to be in contiguous ordering, i.e. [1, 2, 3, 4]
299 | not [2, 3, 6, 10]. Please call `remap_label` beforehand. Here, the `by_size` flag
300 | has no effect on the result.
301 |
302 | Args:
303 | true (ndarray): HxW ground truth instance segmentation map
304 | pred (ndarray): HxW predicted instance segmentation map
305 | match_iou (float): IoU threshold level to determine the pairing between
306 | GT instances `p` and prediction instances `g`. `p` and `g` form a pair
307 | if IoU > `match_iou`. However, each pair of `p` and `g` must be unique
308 | (a 1-to-1 mapping between prediction and GT instances). If `match_iou` < 0.5,
309 | Munkres assignment (solving minimum-weight matching in bipartite graphs)
310 | is used to find the maximal number of unique pairings. If
311 | `match_iou` >= 0.5, every pairing with IoU(p, g) > 0.5 is provably unique and
312 | the number of pairs is also maximal.
313 | remap (bool): whether to ensure contiguous ordering of instances.
314 |
315 | Returns:
316 | [dq, sq, pq]: measurement statistic
317 |
318 | [paired_true, paired_pred, unpaired_true, unpaired_pred]:
319 | pairing information to perform measurement
320 |
321 | paired_iou.sum(): sum of IoU within true positive predictions
322 |
323 | """
324 | assert match_iou >= 0.0, "Can't be negative"
325 | # ensure instance maps are contiguous
326 | if remap:
327 | pred = remap_label(pred)
328 | true = remap_label(true)
329 |
330 | true = np.copy(true)
331 | pred = np.copy(pred)
332 | true = true.astype("int32")
333 | pred = pred.astype("int32")
334 | true_id_list = list(np.unique(true))
335 | pred_id_list = list(np.unique(pred))
336 | # prefill the pairwise IoU matrix with zeros
337 | pairwise_iou = np.zeros([len(true_id_list), len(pred_id_list)], dtype=np.float64)
338 |
339 | # caching pairwise iou
340 | for true_id in true_id_list[1:]: # 0-th is background
341 | t_mask_lab = true == true_id
342 | rmin1, rmax1, cmin1, cmax1 = get_bounding_box(t_mask_lab)
343 | t_mask_crop = t_mask_lab[rmin1:rmax1, cmin1:cmax1]
344 | t_mask_crop = t_mask_crop.astype("int")
345 | p_mask_crop = pred[rmin1:rmax1, cmin1:cmax1]
346 | pred_true_overlap = p_mask_crop[t_mask_crop > 0]
347 | pred_true_overlap_id = np.unique(pred_true_overlap)
348 | pred_true_overlap_id = list(pred_true_overlap_id)
349 | for pred_id in pred_true_overlap_id:
350 | if pred_id == 0: # ignore
351 | continue # overlapping background
352 | p_mask_lab = pred == pred_id
353 | p_mask_lab = p_mask_lab.astype("int")
354 |
355 | # crop region to speed up computation
356 | rmin2, rmax2, cmin2, cmax2 = get_bounding_box(p_mask_lab)
357 | rmin = min(rmin1, rmin2)
358 | rmax = max(rmax1, rmax2)
359 | cmin = min(cmin1, cmin2)
360 | cmax = max(cmax1, cmax2)
361 | t_mask_crop2 = t_mask_lab[rmin:rmax, cmin:cmax]
362 | p_mask_crop2 = p_mask_lab[rmin:rmax, cmin:cmax]
363 |
364 | total = (t_mask_crop2 + p_mask_crop2).sum()
365 | inter = (t_mask_crop2 * p_mask_crop2).sum()
366 | iou = inter / (total - inter)
367 | pairwise_iou[true_id - 1, pred_id - 1] = iou
368 |
369 | if match_iou >= 0.5:
370 | paired_iou = pairwise_iou[pairwise_iou > match_iou]
371 | pairwise_iou[pairwise_iou <= match_iou] = 0.0
372 | paired_true, paired_pred = np.nonzero(pairwise_iou)
373 | paired_iou = pairwise_iou[paired_true, paired_pred]
374 | paired_true += 1 # index is instance id - 1
375 | paired_pred += 1 # hence return back to original
376 | else: # * Exhaustive maximal unique pairing
377 | #### Munkres pairing with scipy library
378 | # the algorithm returns (row indices, matched column indices);
379 | # if there are multiple equal costs in a row, the index of the first
380 | # occurrence is returned, so the pairing is guaranteed to be unique
381 | # negate the IoU so that maximising IoU becomes a minimum-cost assignment
382 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
383 | ### extract the paired costs and remove invalid pairs
384 | paired_iou = pairwise_iou[paired_true, paired_pred]
385 |
386 | # now select those above threshold level
387 | # paired with iou = 0.0 i.e no intersection => FP or FN
388 | paired_true = list(paired_true[paired_iou > match_iou] + 1)
389 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1)
390 | paired_iou = paired_iou[paired_iou > match_iou]
391 |
392 | # get the actual FP and FN
393 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true]
394 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred]
395 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred))
396 |
397 | #
398 | tp = len(paired_true)
399 | fp = len(unpaired_pred)
400 | fn = len(unpaired_true)
401 | # get the F1-score i.e DQ
402 | dq = tp / ((tp + 0.5 * fp + 0.5 * fn) + 1.0e-6)
403 | # get the SQ; pairs with zero IoU have already been excluded, so they do not affect the sum
404 | sq = paired_iou.sum() / (tp + 1.0e-6)
405 |
406 | # return (
407 | # [dq, sq, dq * sq],
408 | # [tp, fp, fn],
409 | # paired_iou.sum(),
410 | return dq, sq, dq * sq
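411 |
412 |
413 | # Hedged usage sketch (illustrative only, not part of the original pipeline; the toy
414 | # arrays below are assumptions made for demonstration). With one matched pair of IoU 0.75
415 | # and one missed GT instance, DQ = 1 / (1 + 0.5*0 + 0.5*1) ≈ 0.667, SQ = 0.75 / 1 = 0.75,
416 | # and PQ = DQ * SQ ≈ 0.5:
417 | #
418 | # toy_true = np.zeros((8, 8), dtype=np.int32)
419 | # toy_true[1:4, 1:4] = 1        # GT nucleus 1
420 | # toy_true[5:7, 5:7] = 2        # GT nucleus 2 (missed by the prediction)
421 | # toy_pred = np.zeros_like(toy_true)
422 | # toy_pred[1:4, 1:5] = 3        # single prediction, IoU 0.75 with GT nucleus 1
423 | # dq, sq, pq = get_pq(toy_true, toy_pred)   # -> (~0.667, 0.75, ~0.5)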
--------------------------------------------------------------------------------
/utils/combine.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch
3 |
4 |
5 | def Combine(prediction, target):
6 | predict = prediction
7 | pseudos = torch.zeros([prediction.shape[0], 2, prediction.shape[2], prediction.shape[3]]).cuda()
8 | for i in range(target.shape[0]):
9 | pseudo_label = predict[i].squeeze(0)
10 | pseudo_label[target[i] == 1] = 1.0 # nuclei
11 | pseudo_label[target[i] == 0] = 0.0 # background
12 | pseudo_0 = torch.ones((target[i].shape[0], target[i].shape[1])).cuda() - pseudo_label
13 | pseudo1 = pseudo_label.unsqueeze(0)
14 | pseudo0 = pseudo_0.unsqueeze(0)
15 | pseudo = torch.cat([pseudo0, pseudo1.float()], dim=0)
16 | pseudos[i] = pseudo
17 | return pseudos
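18 |
19 |
20 | # Hedged usage sketch (illustrative; shapes and label values are inferred from the code above,
21 | # not taken from train.py). `prediction` is assumed to hold per-pixel nuclei probabilities of
22 | # shape (B, 1, H, W) and `target` the coarse point-derived labels of shape (B, H, W), with
23 | # 1 = nuclei, 0 = background, and any other value left as the network prediction. The result
24 | # is a two-channel (background, nuclei) pseudo label. Tensors are expected on the GPU.
25 | #
26 | # probs = torch.sigmoid(fg_logits)              # hypothetical (B, 1, H, W) foreground probabilities
27 | # pseudo = Combine(probs, coarse_labels)        # -> (B, 2, H, W) pseudo label for supervision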
--------------------------------------------------------------------------------
/utils/divergence.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def loss_CT(M1, M2):
5 | criterion = torch.nn.KLDivLoss(reduction='mean').cuda()
6 | loss = criterion(F.log_softmax(M1, dim=1), M2)
7 | return loss
8 |
9 | def Epsilon(M1, M2):
10 | results = 0
11 | M1_1 = F.softmax(M1, dim=1)
12 | for i in range(M2.shape[0]):
13 | m1 = M1_1[i] - M2[i]
14 | m2 = M1_1[i] + M2[i]
15 | result = 2 * (torch.abs(m1)).sum() / m2.sum()
16 | result = result.item()
17 | results += result
18 | return results
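19 |
20 |
21 | # Hedged usage sketch (illustrative; `logits_a` / `logits_b` are assumed raw (B, 2, H, W) outputs
22 | # of the two co-trained branches, not names taken from train.py). loss_CT expects raw logits as
23 | # its first argument and an already-softmaxed probability map as its second; Epsilon returns a
24 | # scalar L1-style disagreement between the two soft predictions.
25 | #
26 | # ct_loss = loss_CT(logits_a, F.softmax(logits_b, dim=1)) + loss_CT(logits_b, F.softmax(logits_a, dim=1))
27 | # disagreement = Epsilon(logits_a, F.softmax(logits_b, dim=1))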
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random
4 | import torch
5 | from scipy.spatial import Voronoi
6 | from skimage import draw
7 |
8 |
9 | def poly2mask(vertex_row_coords, vertex_col_coords, shape):
10 | fill_row_coords, fill_col_coords = draw.polygon(vertex_row_coords, vertex_col_coords, shape)
11 | mask = np.zeros(shape, dtype=bool)  # np.bool was removed in NumPy >= 1.24
12 | mask[fill_row_coords, fill_col_coords] = True
13 | return mask
14 |
15 |
16 | # borrowed from https://gist.github.com/pv/8036995
17 | def voronoi_finite_polygons_2d(vor, radius=None):
18 | """
19 | Reconstruct infinite voronoi regions in a 2D diagram to finite
20 | regions.
21 |
22 | Parameters
23 | ----------
24 | vor : Voronoi
25 | Input diagram
26 | radius : float, optional
27 | Distance to 'points at infinity'.
28 |
29 | Returns
30 | -------
31 | regions : list of tuples
32 | Indices of vertices in each revised Voronoi region.
33 | vertices : list of tuples
34 | Coordinates for revised Voronoi vertices. Same as coordinates
35 | of input vertices, with 'points at infinity' appended to the
36 | end.
37 |
38 | """
39 |
40 | if vor.points.shape[1] != 2:
41 | raise ValueError("Requires 2D input")
42 |
43 | new_regions = []
44 | new_vertices = vor.vertices.tolist()
45 |
46 | center = vor.points.mean(axis=0)
47 | if radius is None:
48 | radius = vor.points.ptp().max()
49 |
50 | # Construct a map containing all ridges for a given point
51 | all_ridges = {}
52 | for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
53 | all_ridges.setdefault(p1, []).append((p2, v1, v2))
54 | all_ridges.setdefault(p2, []).append((p1, v1, v2))
55 |
56 | # Reconstruct infinite regions
57 | for p1, region in enumerate(vor.point_region):
58 | vertices = vor.regions[region]
59 |
60 | if all(v >= 0 for v in vertices):
61 | # finite region
62 | new_regions.append(vertices)
63 | continue
64 |
65 | # reconstruct a non-finite region
66 | ridges = all_ridges[p1]
67 | new_region = [v for v in vertices if v >= 0]
68 |
69 | for p2, v1, v2 in ridges:
70 | if v2 < 0:
71 | v1, v2 = v2, v1
72 | if v1 >= 0:
73 | # finite ridge: already in the region
74 | continue
75 |
76 | # Compute the missing endpoint of an infinite ridge
77 |
78 | t = vor.points[p2] - vor.points[p1] # tangent
79 | t /= np.linalg.norm(t)
80 | n = np.array([-t[1], t[0]]) # normal
81 |
82 | midpoint = vor.points[[p1, p2]].mean(axis=0)
83 | direction = np.sign(np.dot(midpoint - center, n)) * n
84 | far_point = vor.vertices[v2] + direction * radius
85 |
86 | new_region.append(len(new_vertices))
87 | new_vertices.append(far_point.tolist())
88 |
89 | # sort region counterclockwise
90 | vs = np.asarray([new_vertices[v] for v in new_region])
91 | c = vs.mean(axis=0)
92 | angles = np.arctan2(vs[:,1] - c[1], vs[:,0] - c[0])
93 | new_region = np.array(new_region)[np.argsort(angles)]
94 |
95 | # finish
96 | new_regions.append(new_region.tolist())
97 |
98 | return new_regions, np.asarray(new_vertices)
99 |
100 |
101 | def split_forward(model, input, size, overlap, outchannel = 2):
102 | '''
103 | split the input image into overlapping patches for the forward pass and stitch the patch outputs back into a full-size prediction
104 | '''
105 |
106 | b, c, h0, w0 = input.size()
107 |
108 | # zero pad for border patches
109 | pad_h = 0
110 | if h0 - size > 0 and (h0 - size) % (size - overlap) > 0:
111 | pad_h = (size - overlap) - (h0 - size) % (size - overlap)
112 | tmp = torch.zeros((b, c, pad_h, w0))
113 | input = torch.cat((input, tmp), dim=2)
114 |
115 | if w0 - size > 0 and (w0 - size) % (size - overlap) > 0:
116 | pad_w = (size - overlap) - (w0 - size) % (size - overlap)
117 | tmp = torch.zeros((b, c, h0 + pad_h, pad_w))
118 | input = torch.cat((input, tmp), dim=3)
119 |
120 | _, c, h, w = input.size()
121 |
122 | output = torch.zeros((input.size(0), outchannel, h, w))
123 | for i in range(0, h-overlap, size-overlap):
124 | r_end = i + size if i + size < h else h
125 | ind1_s = i + overlap // 2 if i > 0 else 0
126 | ind1_e = i + size - overlap // 2 if i + size < h else h
127 | for j in range(0, w-overlap, size-overlap):
128 | c_end = j+size if j+size < w else w
129 |
130 | input_patch = input[:,:,i:r_end,j:c_end]
131 | input_var = input_patch.cuda()
132 | with torch.no_grad():
133 | output_patch, _ = model(input_var)
134 | #_, output_patch = model(input_var)
135 |
136 | ind2_s = j+overlap//2 if j>0 else 0
137 | ind2_e = j+size-overlap//2 if j+size