├── .gitignore
├── README.md
├── dataloaders
│   ├── dataset.py
│   ├── my_transforms.py
│   └── prepare_data.py
├── figures
│   └── overview.png
├── model
│   ├── modelW.py
│   ├── model_WNet.py
│   ├── resnet.py
│   └── resnet1.py
├── options.py
├── requirements.txt
├── test.py
├── train.py
└── utils
    ├── accuracy.py
    ├── combine.py
    ├── divergence.py
    └── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | */.idea 2 | */__pycache__ 3 | */.vscode 4 | # */figures 5 | */requirements.txt 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SC-Net 2 | This is the official code for our MedIA paper: 3 | 4 | > [Nuclei Segmentation with Point Annotations from Pathology Images via Self-Supervised Learning and Co-Training](https://arxiv.org/abs/2202.08195)
5 | > Yi Lin*, Zhiyong Qu*, Hao Chen, Zhongke Gao, Yuexiang Li, Lili Xia, Kai Ma, Yefeng Zheng, Kwang-Ting Cheng 6 | 7 | ## Highlights 8 |

9 | In this work, we propose a weakly-supervised learning method for nuclei segmentation that requires only point annotations for training. The proposed method achieves label propagation in a coarse-to-fine manner as follows. First, coarse pixel-level labels are derived from the point annotations based on the Voronoi diagram and the k-means clustering method to avoid overfitting. Second, a co-training strategy with an exponential moving average method is designed to refine the incomplete supervision of the coarse labels. Third, a self-supervised visual representation learning method, which transforms the hematoxylin component images into the H&E stained images, is tailored for nuclei segmentation of pathology images to gain a better understanding of the relationship between the nuclei and cytoplasm. 10 | 11 | [comment]: <> () 12 | ![visualization](figures/overview.png) 13 |

14 | (a) The pipeline of the proposed method; (b) The framework of SC-Net; (c) The process of pseudo label generation. 15 |
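As a concrete illustration of the ideas above: the coarse Voronoi/cluster labels leave uncertain pixels unlabeled (encoded as class `2` by `LabelEncoding` in `dataloaders/my_transforms.py`), and the co-training branch maintains an exponential-moving-average (EMA) copy of the segmentation weights. The snippet below is only a minimal sketch of these two ingredients with illustrative names; the actual training logic lives in `train.py` and is not reproduced here.

```python
import torch
import torch.nn as nn

# Partial cross-entropy: pixels that the coarse labels leave undecided are encoded
# as 2 by LabelEncoding, so they are simply ignored (illustrative; see train.py).
criterion = nn.CrossEntropyLoss(ignore_index=2)

@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, alpha: float = 0.99) -> None:
    """Sketch of the exponential moving average update used by the co-training peer."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(alpha).add_(s, alpha=1.0 - alpha)
```

With `alpha` close to 1 the averaged weights change slowly, which yields more stable predictions for supervising the peer branch.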
16 | 17 | ### Using the code 18 | Please clone the repository: 19 | ``` 20 | git clone https://github.com/hust-linyi/SC-Net.git 21 | ``` 22 | 23 | ### Requirements 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ### Data preparation 29 | #### Download 30 | 1. **MoNuSeg** [Multi-Organ Nuclei Segmentation dataset](https://monuseg.grand-challenge.org) 31 | 2. **CPM** [Computational Precision Medicine dataset](https://drive.google.com/drive/folders/1sJ4nmkif6j4s2FOGj8j6i_Ye7z9w0TfA) 32 | 33 | #### Pre-processing 34 | Please refer to [dataloaders/prepare_data.py](https://github.com/hust-linyi/SC-Net/blob/main/dataloaders/prepare_data.py) for the pre-processing of the datasets. The script expects the raw images, instance labels and a `train_val_test.json` split under `./data/{dataset}/`, generates the point, Voronoi, cluster and H-component labels, and organizes the 250x250 training patches under `./data_for_train/{dataset}/`. 35 | 36 | ### Training 37 | 1. Configure your own parameters in [options.py](https://github.com/hust-linyi/SC-Net/blob/main/options.py), including the dataset path, the number of GPUs, the number of epochs, the batch size, the learning rate, etc. 38 | 2. Run the following command to train the model: 39 | ``` 40 | python train.py 41 | ``` 42 | 43 | ### Testing 44 | Run the following command to test the model: 45 | ``` 46 | python test.py 47 | ``` 48 | 49 | ## Citation 50 | Please cite the paper if you use the code. 51 | ```bibtex 52 | @article{lin2023nuclei, 53 | title={Nuclei segmentation with point annotations from pathology images via self-supervised learning and co-training}, 54 | author={Lin, Yi and Qu, Zhiyong and Chen, Hao and Gao, Zhongke and Li, Yuexiang and Xia, Lili and Ma, Kai and Zheng, Yefeng and Cheng, Kwang-Ting}, 55 | journal={Medical Image Analysis}, 56 | pages={102933}, 57 | year={2023}, 58 | publisher={Elsevier} 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /dataloaders/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | 6 | IMG_EXTENSIONS = [ 7 | '.jpg', '.JPG', '.jpeg', '.JPEG', 8 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 9 | ] 10 | 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | 16 | def img_loader(path, num_channels): 17 | if num_channels == 1: 18 | img = Image.open(path) 19 | else: 20 | img = Image.open(path).convert('RGB') 21 | 22 | return img 23 | 24 | 25 | # get the image list pairs 26 | def get_imgs_list(dir_list, post_fix=None): 27 | """ 28 | :param dir_list: [img1_dir, img2_dir, ...] 29 | :param post_fix: e.g. ['label_vor.png', 'label_cluster.png',...] 30 | :return: e.g. [(img1.png, img1_label_vor.png, img1_label_cluster.png), ...]
31 | """ 32 | img_list = [] 33 | if len(dir_list) == 0: 34 | return img_list 35 | 36 | img_filename_list = os.listdir(dir_list[0]) 37 | 38 | for img in img_filename_list: 39 | if not is_image_file(img): 40 | continue 41 | img1_name = os.path.splitext(img)[0] 42 | item = [os.path.join(dir_list[0], img), ] 43 | for i in range(1, len(dir_list)): 44 | img_name = '{:s}{:s}'.format(img1_name, post_fix[i - 1]) 45 | img_path = os.path.join(dir_list[i], img_name) 46 | item.append(img_path) 47 | 48 | if len(item) == len(dir_list): 49 | img_list.append(tuple(item)) 50 | 51 | return img_list 52 | 53 | # dataset that supports multiple images 54 | class DataFolder(data.Dataset): 55 | def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader): 56 | """ 57 | :param dir_list: [img_dir, label_voronoi_dir, label_cluster_dir] 58 | :param post_fix: ['label_vor.png', 'label_cluster.png'] 59 | :param num_channels: [3, 3, 3] 60 | :param data_transform: data transformations 61 | :param loader: image loader 62 | """ 63 | super(DataFolder, self).__init__() 64 | # if len(dir_list) != len(post_fix) + 1: 65 | # raise (RuntimeError('Length of dir_list is different from length of post_fix + 1.')) 66 | if len(dir_list) != len(num_channels): 67 | raise (RuntimeError('Length of dir_list is different from length of num_channels.')) 68 | 69 | self.img_list = get_imgs_list(dir_list, post_fix) 70 | if len(self.img_list) == 0: 71 | raise(RuntimeError('Found 0 image pairs in given directories.')) 72 | 73 | self.data_transform = data_transform 74 | self.num_channels = num_channels 75 | self.loader = loader 76 | 77 | def __getitem__(self, index): 78 | img_paths = self.img_list[index] 79 | sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))] 80 | 81 | if self.data_transform is not None: 82 | sample = self.data_transform(sample) 83 | 84 | return sample 85 | 86 | def __len__(self): 87 | return len(self.img_list) 88 | 89 | 90 | # dataset that supports multiple images 91 | class DataFolderTest(data.Dataset): 92 | def __init__(self, dir_list, post_fix, num_channels, data_transform=None, loader=img_loader): 93 | """ 94 | :param dir_list: [img_dir, label_voronoi_dir, label_cluster_dir] 95 | :param post_fix: ['label_vor.png', 'label_cluster.png'] 96 | :param num_channels: [3, 3, 3] 97 | :param data_transform: data transformations 98 | :param loader: image loader 99 | """ 100 | super(DataFolderTest, self).__init__() 101 | if len(dir_list) != len(num_channels): 102 | raise (RuntimeError('Length of dir_list is different from length of num_channels.')) 103 | 104 | self.img_list = get_imgs_list(dir_list, post_fix) 105 | if len(self.img_list) == 0: 106 | raise(RuntimeError('Found 0 image pairs in given directories.')) 107 | self.data_transform = data_transform 108 | self.num_channels = num_channels 109 | self.loader = loader 110 | 111 | 112 | def __getitem__(self, index): 113 | img_paths = self.img_list[index] 114 | sample = [self.loader(img_paths[i], self.num_channels[i]) for i in range(len(img_paths))] 115 | 116 | if self.data_transform is not None: 117 | sample = self.data_transform(sample) 118 | 119 | return sample 120 | 121 | def __len__(self): 122 | return len(self.img_list) 123 | -------------------------------------------------------------------------------- /dataloaders/my_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | from PIL import Image, ImageOps 4 | import numpy as np 5 | 
import numbers 6 | 7 | def get_transforms(p): 8 | """ data transforms for train, validation and test 9 | p: transform dictionary 10 | """ 11 | t_list = list() 12 | if 'random_resize' in p: 13 | t_list.append(RandomResize(p['random_resize'][0], p['random_resize'][1])) 14 | 15 | if 'horizontal_flip' in p: 16 | t_list.append(RandomHorizontalFlip()) 17 | 18 | if 'vertical_flip' in p: 19 | t_list.append(RandomVerticalFlip()) 20 | 21 | if 'random_affine' in p: 22 | t_list.append(RandomAffine(p['random_affine'])) 23 | 24 | if 'random_rotation' in p: 25 | t_list.append(RandomRotation(p['random_rotation'])) 26 | 27 | if 'random_crop' in p: 28 | t_list.append(RandomCrop(p['random_crop'])) 29 | 30 | if 'label_encoding' in p: 31 | t_list.append(LabelEncoding(p['label_encoding'])) 32 | 33 | if 'to_tensor' in p: 34 | t_list.append(ToTensor(p['to_tensor'])) 35 | 36 | if 'normalize' in p: 37 | t_list.append(Normalize(mean=p['normalize'][0], std=p['normalize'][1])) 38 | 39 | return Compose(t_list) 40 | 41 | 42 | class Compose(object): 43 | """ Composes several transforms together. 44 | Args: 45 | transforms (list of ``Transform`` objects): list of transforms to compose. 46 | """ 47 | 48 | def __init__(self, transforms): 49 | self.transforms = transforms 50 | 51 | def __call__(self, imgs): 52 | for t in self.transforms: 53 | imgs = t(imgs) 54 | return imgs 55 | 56 | 57 | class ToTensor(object): 58 | """ Convert (img, label) of type ``PIL.Image`` or ``numpy.ndarray`` to tensors. 59 | Converts img of type PIL.Image or numpy.ndarray (H x W x C) in the range 60 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 61 | Converts label of type PIL.Image or numpy.ndarray (H x W) in the range [0, 255] 62 | to a torch.LongTensor of shape (H x W) in the range [0, 255]. 63 | """ 64 | def __init__(self, index=1): 65 | self.index = index # index to distinguish between images and labels 66 | 67 | def __call__(self, imgs): 68 | """ 69 | Args: 70 | imgs (PIL.Image or numpy.ndarray): Image to be converted to tensor. 71 | Returns: 72 | Tensor: Converted image. 
73 | """ 74 | if len(imgs) < self.index: 75 | raise ValueError('The number of images is smaller than separation index!') 76 | 77 | pics = [] 78 | 79 | # process image 80 | for i in range(0, self.index): 81 | img = imgs[i] 82 | if isinstance(img, np.ndarray): 83 | # handle numpy array 84 | pic = torch.from_numpy(img.transpose((2, 0, 1))) 85 | # backward compatibility 86 | pics.append(pic.float().div(255)) 87 | 88 | # handle PIL Image 89 | if img.mode == 'I': 90 | pic = torch.from_numpy(np.array(img, np.int32, copy=False)) 91 | elif img.mode == 'I;16': 92 | pic = torch.from_numpy(np.array(img, np.int16, copy=False)) 93 | else: 94 | pic = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())) 95 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 96 | if img.mode == 'YCbCr': 97 | nchannel = 3 98 | elif img.mode == 'I;16': 99 | nchannel = 1 100 | else: 101 | nchannel = len(img.mode) 102 | pic = pic.view(img.size[1], img.size[0], nchannel) 103 | # put it from HWC to CHW format 104 | # yikes, this transpose takes 80% of the loading time/CPU 105 | pic = pic.transpose(0, 1).transpose(0, 2).contiguous() 106 | if isinstance(pic, torch.ByteTensor): 107 | pics.append(pic.float().div(255)) 108 | else: 109 | pics.append(pic) 110 | 111 | # process labels: 112 | for i in range(self.index, len(imgs)): 113 | # process label 114 | label = imgs[i] 115 | if isinstance(label, np.ndarray): 116 | # handle numpy array 117 | label_tensor = torch.from_numpy(label) 118 | # backward compatibility 119 | pics.append(label_tensor.long()) 120 | 121 | # handle PIL Image 122 | if label.mode == 'I': 123 | label_tensor = torch.from_numpy(np.array(label, np.int32, copy=False)) 124 | elif label.mode == 'I;16': 125 | label_tensor = torch.from_numpy(np.array(label, np.int16, copy=False)) 126 | else: 127 | label_tensor = torch.ByteTensor(torch.ByteStorage.from_buffer(label.tobytes())) 128 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 129 | if label.mode == 'YCbCr': 130 | nchannel = 3 131 | elif label.mode == 'I;16': 132 | nchannel = 1 133 | else: 134 | nchannel = len(label.mode) 135 | label_tensor = label_tensor.view(label.size[1], label.size[0], nchannel) 136 | # put it from HWC to CHW format 137 | # yikes, this transpose takes 80% of the loading time/CPU 138 | label_tensor = label_tensor.transpose(0, 1).transpose(0, 2).contiguous() 139 | # label_tensor = label_tensor.view(label.size[1], label.size[0]) 140 | pics.append(label_tensor.long()) 141 | 142 | return tuple(pics) 143 | 144 | 145 | class Normalize(object): 146 | """ Normalize an tensor image with mean and standard deviation. 147 | Given mean and std, will normalize each channel of the torch.*Tensor, 148 | i.e. channel = (channel - mean) / std 149 | Args: 150 | mean (sequence): Sequence of means for each channel. 151 | std (sequence): Sequence of standard deviations for each channel. 152 | ** only normalize the first image, keep the target image unchanged 153 | """ 154 | 155 | def __init__(self, mean, std): 156 | self.mean = mean 157 | self.std = std 158 | 159 | def __call__(self, tensors): 160 | """ 161 | Args: 162 | tensors (Tensor): Tensor images of size (C, H, W) to be normalized. 163 | Returns: 164 | Tensor: Normalized image. 165 | """ 166 | tensors = list(tensors) 167 | for t, m, s in zip(tensors[0], self.mean, self.std): 168 | t.sub_(m).div_(s) 169 | return tuple(tensors) 170 | 171 | 172 | class RandomCrop(object): 173 | """Crop the given PIL.Image at a random location. 
174 | Args: 175 | size (sequence or int): Desired output size of the crop. If size is an 176 | int instead of sequence like (w, h), a square crop (size, size) is 177 | made. 178 | padding (int or sequence, optional): Optional padding on each border 179 | of the image. Default is 0, i.e no padding. If a sequence of length 180 | 4 is provided, it is used to pad left, top, right, bottom borders 181 | respectively. 182 | """ 183 | 184 | def __init__(self, size, padding=0, fill_val=(0,)): 185 | if isinstance(size, numbers.Number): 186 | self.size = (int(size), int(size)) 187 | else: 188 | self.size = size 189 | self.padding = padding 190 | self.fill_val = fill_val 191 | 192 | def __call__(self, imgs): 193 | """ 194 | Args: 195 | img (PIL.Image): Image to be cropped. 196 | Returns: 197 | PIL.Image: Cropped image. 198 | """ 199 | pics = [] 200 | 201 | w, h = imgs[0].size 202 | th, tw = self.size 203 | x1 = random.randint(0, w - tw) 204 | y1 = random.randint(0, h - th) 205 | 206 | for k in range(len(imgs)): 207 | img = imgs[k] 208 | if self.padding > 0: 209 | img = ImageOps.expand(img, border=self.padding, fill=self.fill_val[k]) 210 | 211 | if w == tw and h == th: 212 | pics.append(img) 213 | continue 214 | 215 | pics.append(img.crop((x1, y1, x1 + tw, y1 + th))) 216 | 217 | return tuple(pics) 218 | 219 | 220 | class RandomHorizontalFlip(object): 221 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 222 | 223 | def __call__(self, imgs): 224 | """ 225 | Args: 226 | img (PIL.Image): Image to be flipped. 227 | Returns: 228 | PIL.Image: Randomly flipped image. 229 | """ 230 | pics = [] 231 | if random.random() < 0.5: 232 | for img in imgs: 233 | pics.append(img.transpose(Image.FLIP_LEFT_RIGHT)) 234 | return tuple(pics) 235 | else: 236 | return imgs 237 | 238 | 239 | class RandomVerticalFlip(object): 240 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 241 | 242 | def __call__(self, imgs): 243 | """ 244 | Args: 245 | img (PIL.Image): Image to be flipped. 246 | Returns: 247 | PIL.Image: Randomly flipped image. 248 | """ 249 | pics = [] 250 | if random.random() < 0.5: 251 | for img in imgs: 252 | pics.append(img.transpose(Image.FLIP_TOP_BOTTOM)) 253 | return tuple(pics) 254 | else: 255 | return imgs 256 | 257 | 258 | class RandomRotation(object): 259 | """Rotate the image by angle. 260 | Args: 261 | degrees (sequence or float or int): Range of degrees to select from. 262 | If degrees is a number instead of sequence like (min, max), the range of degrees 263 | will be (-degrees, +degrees). 264 | resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): 265 | An optional resampling filter. 266 | See http://pillow.readthedocs.io/en/3.4.x/handbook/concepts.html#filters 267 | If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. 268 | expand (bool, optional): Optional expansion flag. 269 | If true, expands the output to make it large enough to hold the entire rotated image. 270 | If false or omitted, make the output image the same size as the input image. 271 | Note that the expand flag assumes rotation around the center and no translation. 272 | center (2-tuple, optional): Optional center of rotation. 273 | Origin is the upper left corner. 274 | Default is the center of the image. 
275 | """ 276 | 277 | def __init__(self, degrees, resample=Image.BILINEAR, expand=False, center=None): 278 | if isinstance(degrees, numbers.Number): 279 | if degrees < 0: 280 | raise ValueError("If degrees is a single number, it must be positive.") 281 | self.degrees = (-degrees, degrees) 282 | else: 283 | if len(degrees) != 2: 284 | raise ValueError("If degrees is a sequence, it must be of len 2.") 285 | self.degrees = degrees 286 | 287 | self.resample = resample 288 | self.expand = expand 289 | self.center = center 290 | 291 | @staticmethod 292 | def get_params(degrees): 293 | """Get parameters for ``rotate`` for a random rotation. 294 | Returns: 295 | sequence: params to be passed to ``rotate`` for random rotation. 296 | """ 297 | angle = random.uniform(degrees[0], degrees[1]) 298 | 299 | return angle 300 | 301 | def __call__(self, imgs): 302 | """ 303 | imgs (PIL Image): Images to be rotated. 304 | Returns: 305 | PIL Image: Rotated image. 306 | """ 307 | 308 | angle = self.get_params(self.degrees) 309 | 310 | pics = [] 311 | for img in imgs: 312 | pics.append(img.rotate(angle, self.resample, self.expand, self.center)) 313 | 314 | # process the binary label 315 | # pics[1] = pics[1].point(lambda p: p > 127.5 and 255) 316 | 317 | return tuple(pics) 318 | 319 | 320 | class RandomResize(object): 321 | """Randomly Resize the input PIL Image using a scale of lb~ub. 322 | Args: 323 | size (sequence or int): Desired output size. If size is a sequence like 324 | (h, w), output size will be matched to this. If size is an int, 325 | smaller edge of the image will be matched to this number. 326 | i.e, if height > width, then image will be rescaled to 327 | (size * height / width, size) 328 | interpolation (int, optional): Desired interpolation. Default is 329 | ``PIL.Image.BILINEAR`` 330 | """ 331 | 332 | def __init__(self, lb=0.5, ub=1.5, interpolation=Image.BILINEAR): 333 | self.lb = lb 334 | self.ub = ub 335 | self.interpolation = interpolation 336 | 337 | def __call__(self, imgs): 338 | """ 339 | Args: 340 | imgs (PIL Images): Images to be scaled. 341 | Returns: 342 | PIL Images: Rescaled images. 343 | """ 344 | 345 | for img in imgs: 346 | if not isinstance(img, Image.Image): 347 | raise TypeError('img should be PIL Image. 
Got {}'.format(type(img))) 348 | 349 | scale = random.uniform(self.lb, self.ub) 350 | # print scale 351 | 352 | w, h = imgs[0].size 353 | ow = int(w * scale) 354 | oh = int(h * scale) 355 | 356 | if scale < 1: 357 | padding_l = (w - ow)//2 358 | padding_t = (h - oh)//2 359 | padding_r = w - ow - padding_l 360 | padding_b = h - oh - padding_t 361 | padding = (padding_l, padding_t, padding_r, padding_b) 362 | 363 | pics = [] 364 | for i in range(len(imgs)): 365 | img = imgs[i] 366 | img = img.resize((ow, oh), self.interpolation) 367 | if scale < 1: 368 | img = ImageOps.expand(img, border=padding, fill=0) 369 | pics.append(img) 370 | 371 | return tuple(pics) 372 | 373 | 374 | class RandomAffine(object): 375 | """ Transform the input PIL Image using a random affine transformation 376 | The parameters of an affine transformation [a, b, c=0 377 | d, e, f=0] 378 | are generated randomly according to the bound, and there is no translation 379 | (c=f=0) 380 | Args: 381 | bound: the largest possible deviation of random parameters 382 | """ 383 | 384 | def __init__(self, bound): 385 | if bound < 0 or bound > 0.5: 386 | raise ValueError("Bound is invalid, should be in range [0, 0.5)") 387 | 388 | self.bound = bound 389 | 390 | def __call__(self, imgs): 391 | img = imgs[0] 392 | x, y = img.size 393 | 394 | a = 1 + 2 * self.bound * (random.random() - 0.5) 395 | b = 2 * self.bound * (random.random() - 0.5) 396 | d = 2 * self.bound * (random.random() - 0.5) 397 | e = 1 + 2 * self.bound * (random.random() - 0.5) 398 | 399 | # correct the transformation center to image center 400 | c = -a * x / 2 - b * y / 2 + x / 2 401 | f = -d * x / 2 - e * y / 2 + y / 2 402 | 403 | trans_matrix = [a, b, c, d, e, f] 404 | 405 | pics = [] 406 | for img in imgs: 407 | pics.append(img.transform((x, y), Image.AFFINE, trans_matrix)) 408 | 409 | return tuple(pics) 410 | 411 | 412 | class LabelEncoding(object): 413 | """ 414 | encode the 3-channel labels into one channel integer label map 415 | """ 416 | def __init__(self, num_labels=1): 417 | self.num_labels = num_labels 418 | 419 | def __call__(self, imgs): 420 | assert self.num_labels < len(imgs) 421 | 422 | out_imgs = list(imgs[:-self.num_labels]) 423 | image = imgs[0] # input image 424 | if not isinstance(image, np.ndarray): 425 | image = np.array(image) 426 | 427 | for i in range(-self.num_labels, 0): # labels 428 | label = imgs[i] 429 | if not isinstance(label, np.ndarray): 430 | label = np.array(label) 431 | # print(len(label.shape)) 432 | 433 | # ----- for estimated pixel-level label 434 | new_label = np.ones((label.shape[0], label.shape[1]), dtype=np.uint8) * 2 435 | new_label[label[:, :, 0] > 255 * 0.3] = 0 436 | new_label[label[:, :, 1] > 255 * 0.5] = 1 437 | new_label[(image[:, :, 0] == 0) * (image[:, :, 1] == 0) * (image[:, :, 2] == 0)] = 0 438 | new_label = Image.fromarray(new_label.astype(np.uint8)) 439 | out_imgs.append(new_label) 440 | 441 | return tuple(out_imgs) 442 | -------------------------------------------------------------------------------- /dataloaders/prepare_data.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import os 3 | import shutil 4 | import numpy as np 5 | from skimage import morphology, measure 6 | from sklearn.cluster import KMeans 7 | from scipy.ndimage.morphology import distance_transform_edt as dist_tranform 8 | import glob 9 | import json 10 | 11 | 12 | def main(): 13 | dataset = 'MO' # LC: Lung Cancer, MO: MultiOrgan 14 | data_dir = './data/{:s}'.format(dataset) 15 | img_dir 
= './data/{:s}/images'.format(dataset) 16 | label_instance_dir = './data/{:s}/labels_instance'.format(dataset) 17 | label_point_dir = './data/{:s}/labels_point'.format(dataset) 18 | label_vor_dir = './data/{:s}/labels_voronoi'.format(dataset) 19 | label_cluster_dir = './data/{:s}/labels_cluster'.format(dataset) 20 | label_Hcomponent_dir = './data/{:s}/labels_Hcomponent'.format(dataset) 21 | patch_folder = './data/{:s}/patches'.format(dataset) 22 | train_data_dir = './data_for_train/{:s}'.format(dataset) 23 | create_folder('./data_for_train') 24 | create_folder(label_point_dir) 25 | create_folder(label_vor_dir) 26 | create_folder(label_cluster_dir) 27 | create_folder(patch_folder) 28 | create_folder(train_data_dir) 29 | 30 | with open('{:s}/train_val_test.json'.format(data_dir), 'r') as file: 31 | data_list = json.load(file) 32 | train_list = data_list['train'] 33 | 34 | # ------ create H component image from original image 35 | create_h_component_image(img_dir, label_Hcomponent_dir) 36 | 37 | # ------ create point label from instance label 38 | create_point_label_from_instance(label_instance_dir, label_point_dir, train_list) 39 | 40 | # ------ create Voronoi label from point label 41 | create_Voronoi_label(label_point_dir, label_vor_dir, train_list) 42 | 43 | # ------ create cluster label from point label and image 44 | create_cluster_label(img_dir, label_point_dir, label_vor_dir, label_cluster_dir, train_list) 45 | 46 | # ------ split large images into 250x250 patches 47 | print("Spliting large images into small patches...") 48 | split_patches(img_dir, '{:s}/images'.format(patch_folder)) 49 | split_patches(label_vor_dir, '{:s}/labels_voronoi'.format(patch_folder), 'label_vor') 50 | split_patches(label_cluster_dir, '{:s}/labels_cluster'.format(patch_folder), 'label_cluster') 51 | split_patches(label_Hcomponent_dir, '{:s}/labels_Hcomponent'.format(patch_folder), 'label_Hcomponent') 52 | 53 | # ------ divide dataset into train, val and test sets 54 | organize_data_for_training(data_dir, train_data_dir) 55 | 56 | # ------ compute mean and std 57 | compute_mean_std(data_dir, train_data_dir) 58 | 59 | 60 | def create_h_component_image(img_dir, save_dir): 61 | image_list = os.listdir(img_dir) 62 | N_total = len(image_list) 63 | N_processed = 0 64 | 65 | # define stain_matrix 66 | H = np.array([0.650, 0.704, 0.286]) 67 | E = np.array([0.072, 0.990, 0.105]) 68 | R = np.array([0.268, 0.570, 0.776]) 69 | HDABtoRGB = [(H/np.linalg.norm(H)).tolist(), (E/np.linalg.norm(E)).tolist(), (R/np.linalg.norm(R)).tolist()] 70 | stain_matrix = HDABtoRGB 71 | im_inv = np.linalg.inv(stain_matrix) 72 | 73 | for image_name in image_list: 74 | N_processed += 1 75 | flag = '' if N_processed < N_total else '\n' 76 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag) 77 | 78 | image_path = os.path.join(img_dir, image_name) 79 | image = imageio.imread(image_path) 80 | 81 | # transform 82 | im_temp = (-255)*np.log((np.float64(image)+1)/255)/np.log(255) 83 | image_out = np.reshape(np.dot(np.reshape(im_temp, [-1,3]), im_inv), np.shape(image)) 84 | image_out = np.exp((255-image_out)*np.log(255)/255) 85 | image_out[image_out > 255] = 255 86 | image_h = image_out[:, :, 0].astype(np.uint8) 87 | 88 | imageio.imsave('{:s}/{:s}_h.png'.format(save_dir, image_name[:-4]), image_h) 89 | 90 | 91 | def create_point_label_from_instance(data_dir, save_dir, train_list): 92 | def get_point(img): 93 | a = np.where(img != 0) 94 | rmin, rmax, cmin, cmax = np.min(a[0]), np.max(a[0]), np.min(a[1]), np.max(a[1]) 95 | return (rmin + 
rmax) // 2, (cmin + cmax) // 2 96 | 97 | print("Generating point label from instance label...") 98 | image_list = os.listdir(data_dir) 99 | N_total = len(train_list) 100 | N_processed = 0 101 | for image_name in image_list: 102 | name = image_name.split('.')[0] 103 | if '{:s}.png'.format(name[:-6]) not in train_list or name[-5:] != 'label': 104 | continue 105 | 106 | N_processed += 1 107 | flag = '' if N_processed < N_total else '\n' 108 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag) 109 | 110 | image_path = os.path.join(data_dir, image_name) 111 | image = imageio.imread(image_path) 112 | h, w = image.shape 113 | 114 | # extract bbox 115 | id_max = np.max(image) 116 | label_point = np.zeros((h, w), dtype=np.uint8) 117 | 118 | for i in range(1, id_max + 1): 119 | nucleus = image == i 120 | if np.sum(nucleus) == 0: 121 | continue 122 | x, y = get_point(nucleus) 123 | label_point[x, y] = 255 124 | 125 | imageio.imsave('{:s}/{:s}_point.png'.format(save_dir, name), label_point.astype(np.uint8)) 126 | 127 | 128 | def create_Voronoi_label(data_dir, save_dir, train_list): 129 | from scipy.spatial import Voronoi 130 | from shapely.geometry import Polygon 131 | from utils.utils import voronoi_finite_polygons_2d, poly2mask 132 | img_list = os.listdir(data_dir) 133 | 134 | print("Generating Voronoi label from point label...") 135 | N_total = len(train_list) 136 | N_processed = 0 137 | for img_name in sorted(img_list): 138 | name = img_name.split('.')[0] 139 | if '{:s}.png'.format(name[:-12]) not in train_list: 140 | continue 141 | 142 | N_processed += 1 143 | flag = '' if N_processed < N_total else '\n' 144 | print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag) 145 | 146 | img_path = '{:s}/{:s}'.format(data_dir, img_name) 147 | label_point = imageio.imread(img_path) 148 | h, w = label_point.shape 149 | 150 | points = np.argwhere(label_point > 0) 151 | vor = Voronoi(points) 152 | 153 | regions, vertices = voronoi_finite_polygons_2d(vor) 154 | box = Polygon([[0, 0], [0, w], [h, w], [h, 0]]) 155 | region_masks = np.zeros((h, w), dtype=np.int16) 156 | edges = np.zeros((h, w), dtype=np.bool) 157 | count = 1 158 | for region in regions: 159 | polygon = vertices[region] 160 | # Clipping polygon 161 | poly = Polygon(polygon) 162 | poly = poly.intersection(box) 163 | polygon = np.array([list(p) for p in poly.exterior.coords]) 164 | 165 | mask = poly2mask(polygon[:, 0], polygon[:, 1], (h, w)) 166 | edge = mask * (~morphology.erosion(mask, morphology.disk(1))) 167 | edges += edge 168 | region_masks[mask] = count 169 | count += 1 170 | 171 | # fuse Voronoi edge and dilated points 172 | label_point_dilated = morphology.dilation(label_point, morphology.disk(2)) 173 | label_vor = np.zeros((h, w, 3), dtype=np.uint8) 174 | label_vor[:, :, 0] = (edges > 0).astype(np.uint8) * 255 175 | label_vor[:, :, 1] = (label_point_dilated > 0).astype(np.uint8) * 255 176 | 177 | name = img_name.split('.')[0] 178 | imageio.imsave('{:s}/{:s}_label_vor.png'.format(save_dir, name[:-12]), label_vor) 179 | 180 | 181 | def create_cluster_label(data_dir, label_point_dir, label_vor_dir, save_dir, train_list): 182 | from scipy.ndimage import morphology as ndi_morph 183 | 184 | img_list = os.listdir(data_dir) 185 | print("Generating cluster label from point label...") 186 | N_total = len(train_list) 187 | N_processed = 0 188 | for img_name in sorted(img_list): 189 | if img_name not in train_list: 190 | continue 191 | 192 | N_processed += 1 193 | flag = '' if N_processed < N_total else '\n' 194 | 
print('\r\t{:d}/{:d}'.format(N_processed, N_total), end=flag) 195 | 196 | # print('\t[{:d}/{:d}] Processing image {:s} ...'.format(count, len(img_list), img_name)) 197 | ori_image = imageio.imread('{:s}/{:s}'.format(data_dir, img_name)) 198 | h, w, _ = ori_image.shape 199 | label_point = imageio.imread('{:s}/{:s}_label_point.png'.format(label_point_dir, img_name[:-4])) 200 | 201 | # k-means clustering 202 | dist_embeddings = dist_tranform(255 - label_point).reshape(-1, 1) 203 | clip_dist_embeddings = np.clip(dist_embeddings, a_min=0, a_max=20) 204 | color_embeddings = np.array(ori_image, dtype=np.float).reshape(-1, 3) / 10 205 | embeddings = np.concatenate((color_embeddings, clip_dist_embeddings), axis=1) 206 | 207 | # print("\t\tPerforming k-means clustering...") 208 | kmeans = KMeans(n_clusters=3, random_state=0).fit(embeddings) 209 | clusters = np.reshape(kmeans.labels_, (h, w)) 210 | 211 | # get nuclei and background clusters 212 | overlap_nums = [np.sum((clusters == i) * label_point) for i in range(3)] 213 | nuclei_idx = np.argmax(overlap_nums) 214 | remain_indices = np.delete(np.arange(3), nuclei_idx) 215 | dilated_label_point = morphology.binary_dilation(label_point, morphology.disk(5)) 216 | overlap_nums = [np.sum((clusters == i) * dilated_label_point) for i in remain_indices] 217 | background_idx = remain_indices[np.argmin(overlap_nums)] 218 | 219 | nuclei_cluster = clusters == nuclei_idx 220 | background_cluster = clusters == background_idx 221 | 222 | nuclei_labeled = measure.label(nuclei_cluster) 223 | initial_nuclei = morphology.remove_small_objects(nuclei_labeled, 30) 224 | refined_nuclei = np.zeros(initial_nuclei.shape, dtype=np.bool) 225 | 226 | label_vor = imageio.imread('{:s}/{:s}_label_vor.png'.format(label_vor_dir, img_name[:-4])) 227 | voronoi_cells = measure.label(label_vor[:, :, 0] == 0) 228 | voronoi_cells = morphology.dilation(voronoi_cells, morphology.disk(2)) 229 | 230 | unique_vals = np.unique(voronoi_cells) 231 | cell_indices = unique_vals[unique_vals != 0] 232 | N = len(cell_indices) 233 | for i in range(N): 234 | cell_i = voronoi_cells == cell_indices[i] 235 | nucleus_i = cell_i * initial_nuclei 236 | 237 | nucleus_i_dilated = morphology.binary_dilation(nucleus_i, morphology.disk(5)) 238 | nucleus_i_dilated_filled = ndi_morph.binary_fill_holes(nucleus_i_dilated) 239 | nucleus_i_final = morphology.binary_erosion(nucleus_i_dilated_filled, morphology.disk(7)) 240 | refined_nuclei += nucleus_i_final > 0 241 | 242 | refined_label = np.zeros((h, w, 3), dtype=np.uint8) 243 | label_point_dilated = morphology.dilation(label_point, morphology.disk(10)) 244 | refined_label[:, :, 0] = (background_cluster * (refined_nuclei == 0) * (label_point_dilated == 0)).astype(np.uint8) * 255 245 | refined_label[:, :, 1] = refined_nuclei.astype(np.uint8) * 255 246 | 247 | imageio.imsave('{:s}/{:s}_label_cluster.png'.format(save_dir, img_name[:-4]), refined_label) 248 | 249 | 250 | def split_patches(data_dir, save_dir, post_fix=None): 251 | import math 252 | """ split large image into small patches """ 253 | create_folder(save_dir) 254 | 255 | image_list = os.listdir(data_dir) 256 | for image_name in image_list: 257 | name = image_name.split('.')[0] 258 | if post_fix and name[-len(post_fix):] != post_fix: 259 | continue 260 | image_path = os.path.join(data_dir, image_name) 261 | image = imageio.imread(image_path) 262 | seg_imgs = [] 263 | 264 | # split into 16 patches of size 250x250 265 | h, w = image.shape[0], image.shape[1] 266 | patch_size = 250 267 | h_overlap = math.ceil((4 * 
patch_size - h) / 3) 268 | w_overlap = math.ceil((4 * patch_size - w) / 3) 269 | for i in range(0, h-patch_size+1, patch_size-h_overlap): 270 | for j in range(0, w-patch_size+1, patch_size-w_overlap): 271 | if len(image.shape) == 3: 272 | patch = image[i:i+patch_size, j:j+patch_size, :] 273 | else: 274 | patch = image[i:i + patch_size, j:j + patch_size] 275 | seg_imgs.append(patch) 276 | 277 | for k in range(len(seg_imgs)): 278 | if post_fix: 279 | imageio.imsave('{:s}/{:s}_{:d}_{:s}.png'.format(save_dir, name[:-len(post_fix)-1], k, post_fix), seg_imgs[k]) 280 | else: 281 | imageio.imsave('{:s}/{:s}_{:d}.png'.format(save_dir, name, k), seg_imgs[k]) 282 | 283 | 284 | def organize_data_for_training(data_dir, train_data_dir): 285 | # --- Step 1: create folders --- # 286 | create_folder(train_data_dir) 287 | create_folder('{:s}/images'.format(train_data_dir)) 288 | create_folder('{:s}/labels_voronoi'.format(train_data_dir)) 289 | create_folder('{:s}/labels_cluster'.format(train_data_dir)) 290 | create_folder('{:s}/labels_Hcomponent'.format(train_data_dir)) 291 | create_folder('{:s}/images/train'.format(train_data_dir)) 292 | create_folder('{:s}/images/val'.format(train_data_dir)) 293 | create_folder('{:s}/images/test'.format(train_data_dir)) 294 | create_folder('{:s}/labels_voronoi/train'.format(train_data_dir)) 295 | create_folder('{:s}/labels_cluster/train'.format(train_data_dir)) 296 | create_folder('{:s}/labels_Hcomponent/train'.format(train_data_dir)) 297 | 298 | # --- Step 2: move images and labels to each folder --- # 299 | print('Organizing data for training...') 300 | with open('{:s}/train_val_test.json'.format(data_dir), 'r') as file: 301 | data_list = json.load(file) 302 | train_list, val_list, test_list = data_list['train'], data_list['val'], data_list['test'] 303 | 304 | # train 305 | for img_name in train_list: 306 | name = img_name.split('.')[0] 307 | # images 308 | for file in glob.glob('{:s}/patches/images/{:s}*'.format(data_dir, name)): 309 | file_name = file.split('\\')[-1] 310 | dst = '{:s}/images/train/{:s}'.format(train_data_dir, file_name) 311 | shutil.copyfile(file, dst) 312 | # label_voronoi 313 | for file in glob.glob('{:s}/patches/labels_voronoi/{:s}*'.format(data_dir, name)): 314 | file_name = file.split('\\')[-1] 315 | dst = '{:s}/labels_voronoi/train/{:s}'.format(train_data_dir, file_name) 316 | shutil.copyfile(file, dst) 317 | # label_cluster 318 | for file in glob.glob('{:s}/patches/labels_cluster/{:s}*'.format(data_dir, name)): 319 | file_name = file.split('\\')[-1] 320 | dst = '{:s}/labels_cluster/train/{:s}'.format(train_data_dir, file_name) 321 | shutil.copyfile(file, dst) 322 | # label_Hcomponent 323 | for file in glob.glob('{:s}/patches/labels_Hcomponent/{:s}*'.format(data_dir, name)): 324 | file_name = file.split('\\')[-1] 325 | dst = '{:s}/labels_Hcomponent/train/{:s}'.format(train_data_dir, file_name) 326 | shutil.copyfile(file, dst) 327 | # val 328 | for img_name in val_list: 329 | name = img_name.split('.')[0] 330 | # images 331 | for file in glob.glob('{:s}/images/{:s}*'.format(data_dir, name)): 332 | file_name = file.split('\\')[-1] 333 | dst = '{:s}/images/val/{:s}'.format(train_data_dir, file_name) 334 | shutil.copyfile(file, dst) 335 | # test 336 | for img_name in test_list: 337 | name = img_name.split('.')[0] 338 | # images 339 | for file in glob.glob('{:s}/images/{:s}*'.format(data_dir, name)): 340 | file_name = file.split('\\')[-1] 341 | dst = '{:s}/images/test/{:s}'.format(train_data_dir, file_name) 342 | shutil.copyfile(file, dst) 343 | 344 
| 345 | def compute_mean_std(data_dir, train_data_dir): 346 | """ compute mean and standarad deviation of training images """ 347 | total_sum = np.zeros(3) # total sum of all pixel values in each channel 348 | total_square_sum = np.zeros(3) 349 | num_pixel = 0 # total num of all pixels 350 | 351 | with open('{:s}/train_val_test.json'.format(data_dir), 'r') as file: 352 | data_list = json.load(file) 353 | train_list = data_list['train'] 354 | 355 | print('Computing the mean and standard deviation of training data...') 356 | 357 | for file_name in train_list: 358 | img_name = '{:s}/images/{:s}'.format(data_dir, file_name) 359 | img = imageio.imread(img_name) 360 | if len(img.shape) != 3 or img.shape[2] < 3: 361 | continue 362 | img = img[:, :, :3].astype(int) 363 | total_sum += img.sum(axis=(0, 1)) 364 | total_square_sum += (img ** 2).sum(axis=(0, 1)) 365 | num_pixel += img.shape[0] * img.shape[1] 366 | 367 | # compute the mean values of each channel 368 | mean_values = total_sum / num_pixel 369 | 370 | # compute the standard deviation 371 | std_values = np.sqrt(total_square_sum / num_pixel - mean_values ** 2) 372 | 373 | # normalization 374 | mean_values = mean_values / 255 375 | std_values = std_values / 255 376 | 377 | np.save('{:s}/mean_std.npy'.format(train_data_dir), np.array([mean_values, std_values])) 378 | np.savetxt('{:s}/mean_std.txt'.format(train_data_dir), np.array([mean_values, std_values]), '%.4f', '\t') 379 | 380 | 381 | def create_folder(folder): 382 | if not os.path.exists(folder): 383 | os.mkdir(folder) 384 | 385 | 386 | if __name__ == '__main__': 387 | main() 388 | -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hust-linyi/SC-Net/90f2a1cc62310c1799795845e239df74b9564bd4/figures/overview.png -------------------------------------------------------------------------------- /model/modelW.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from model.resnet import resnet34 6 | from model.resnet1 import resnet34 as Resnet34 7 | 8 | 9 | class dilated_conv(nn.Module): 10 | """ same as original conv if dilation equals to 1 """ 11 | def __init__(self, in_channel, out_channel, kernel_size=3, dropout_rate=0.0, activation=F.relu, dilation=1): 12 | super().__init__() 13 | self.conv = nn.Conv2d(in_channel, out_channel, kernel_size, padding=dilation, dilation=dilation) 14 | self.norm = nn.BatchNorm2d(out_channel) 15 | self.activation = activation 16 | if dropout_rate > 0: 17 | self.drop = nn.Dropout2d(p=dropout_rate) 18 | else: 19 | self.drop = lambda x: x # no-op 20 | 21 | def forward(self, x): 22 | # CAB: conv -> activation -> batch normal 23 | x = self.norm(self.activation(self.conv(x))) 24 | x = self.drop(x) 25 | return x 26 | 27 | 28 | class ConvDownBlock(nn.Module): 29 | def __init__(self, in_channel, out_channel, dropout_rate=0.0, dilation=1): 30 | super().__init__() 31 | self.conv1 = dilated_conv(in_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation) 32 | self.conv2 = dilated_conv(out_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation) 33 | self.pool = nn.MaxPool2d(kernel_size=2) 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x = self.conv2(x) 38 | return self.pool(x), x 39 | 40 | 41 | class ConvUpBlock(nn.Module): 42 | def __init__(self, 
in_channel, out_channel, dropout_rate=0.0, dilation=1): 43 | super().__init__() 44 | self.up = nn.ConvTranspose2d(in_channel, in_channel // 2, 2, stride=2) 45 | self.conv1 = dilated_conv(in_channel // 2 + out_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation) 46 | self.conv2 = dilated_conv(out_channel, out_channel, dropout_rate=dropout_rate, dilation=dilation) 47 | 48 | def forward(self, x, x_skip): 49 | x = self.up(x) 50 | H_diff = x.shape[2] - x_skip.shape[2] 51 | W_diff = x.shape[3] - x_skip.shape[3] 52 | x_skip = F.pad(x_skip, (0, W_diff, 0, H_diff), mode='reflect') 53 | x = torch.cat([x, x_skip], 1) 54 | x = self.conv1(x) 55 | x = self.conv2(x) 56 | return x 57 | 58 | 59 | # Transfer Learning ResNet as Encoder part of UNet 60 | class ResWNet34(nn.Module): 61 | def __init__(self, seg_classes = 2, colour_classes = 3, fixed_feature=False): 62 | super().__init__() 63 | # load weight of pre-trained resnet 64 | self.resnet = resnet34(pretrained=False) 65 | self.resnet1 = Resnet34(pretrained=False) 66 | if fixed_feature: 67 | for param in self.resnet.parameters(): 68 | param.requires_grad = False 69 | # up conv 70 | l = [64, 64, 128, 256, 512] 71 | self.u5 = ConvUpBlock(l[4], l[3], dropout_rate=0.1) 72 | self.u6 = ConvUpBlock(l[3], l[2], dropout_rate=0.1) 73 | self.u7 = ConvUpBlock(l[2], l[1], dropout_rate=0.1) 74 | self.u8 = ConvUpBlock(l[1], l[0], dropout_rate=0.1) 75 | # final conv 76 | self.seg = nn.ConvTranspose2d(l[0], seg_classes, 2, stride=2) 77 | self.colour = nn.ConvTranspose2d(l[0], colour_classes, 2, stride=2) 78 | self.sigmoid = nn.Sigmoid() 79 | self.softmax = nn.Softmax(dim = 1) 80 | def forward(self, x): 81 | # refer https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 82 | x = self.resnet.conv1(x) 83 | x = self.resnet.bn1(x) 84 | x = s1 = self.resnet.relu(x) 85 | x = self.resnet.maxpool(x) 86 | x = s2 = self.resnet.layer1(x) 87 | x = s3 = self.resnet.layer2(x) 88 | x = s4 = self.resnet.layer3(x) 89 | x = self.resnet.layer4(x) 90 | 91 | x1 = self.u5(x, s4) 92 | x1 = self.u6(x1, s3) 93 | x1 = self.u7(x1, s2) 94 | x1 = self.u8(x1, s1) 95 | x1 = self.seg(x1) 96 | 97 | y_input = self.softmax(x1) 98 | y = self.resnet1.conv1(y_input) 99 | y = self.resnet1.bn1(y) 100 | y = c1 = self.resnet1.relu(y) 101 | y = self.resnet1.maxpool(y) 102 | y = c2 = self.resnet1.layer1(y) 103 | y = c3 = self.resnet1.layer2(y) 104 | y = c4 = self.resnet1.layer3(y) 105 | y = self.resnet1.layer4(y) 106 | 107 | y1 = self.u5(y, c4) 108 | y1 = self.u6(y1, c3) 109 | y1 = self.u7(y1, c2) 110 | y1 = self.u8(y1, c1) 111 | y1 = self.colour(y1) 112 | y1 = self.sigmoid(y1) 113 | 114 | return x1, y1 115 | -------------------------------------------------------------------------------- /model/model_WNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DoubleConv(nn.Module): 6 | """(convolution => [BN] => ReLU) * 2""" 7 | 8 | def __init__(self, in_channels, out_channels, mid_channels=None): 9 | super().__init__() 10 | if not mid_channels: 11 | mid_channels = out_channels 12 | self.double_conv = nn.Sequential( 13 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 14 | nn.BatchNorm2d(mid_channels), 15 | nn.ReLU(inplace=True), 16 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 17 | nn.BatchNorm2d(out_channels), 18 | nn.ReLU(inplace=True) 19 | ) 20 | 21 | def forward(self, x): 22 | return self.double_conv(x) 23 | 24 | 25 | 
class Down(nn.Module): 26 | """Downscaling with maxpool then double conv""" 27 | 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.maxpool_conv = nn.Sequential( 31 | nn.MaxPool2d(2), 32 | DoubleConv(in_channels, out_channels) 33 | ) 34 | 35 | def forward(self, x): 36 | return self.maxpool_conv(x) 37 | 38 | 39 | class Up(nn.Module): 40 | """Upscaling then double conv""" 41 | 42 | def __init__(self, in_channels, out_channels, bilinear=True): 43 | super().__init__() 44 | 45 | # if bilinear, use the normal convolutions to reduce the number of channels 46 | if bilinear: 47 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 48 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 49 | else: 50 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 51 | self.conv = DoubleConv(in_channels, out_channels) 52 | 53 | 54 | def forward(self, x1, x2): 55 | x1 = self.up(x1) 56 | # input is CHW 57 | diffY = x2.size()[2] - x1.size()[2] 58 | diffX = x2.size()[3] - x1.size()[3] 59 | 60 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 61 | diffY // 2, diffY - diffY // 2]) 62 | 63 | x = torch.cat([x2, x1], dim=1) 64 | return self.conv(x) 65 | 66 | class OutConv(nn.Module): 67 | def __init__(self, in_channels, out_channels): 68 | super(OutConv, self).__init__() 69 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 70 | def forward(self, x): 71 | return self.conv(x) 72 | 73 | class OutConvy(nn.Module): 74 | def __init__(self, in_channels, out_channels): 75 | super(OutConvy, self).__init__() 76 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 77 | self.sigmoid = nn.Sigmoid() 78 | def forward(self, x): 79 | x = self.conv(x) 80 | x = self.sigmoid(x) 81 | return x 82 | 83 | class WNet(nn.Module): 84 | def __init__(self, n_channels, seg_classes, colour_classes, bilinear=True): 85 | super(WNet, self).__init__() 86 | self.n_channels = n_channels 87 | self.seg_classes = seg_classes 88 | self.colour_classes = colour_classes 89 | self.bilinear = bilinear 90 | 91 | self.inc = DoubleConv(n_channels, 64) 92 | self.down1 = Down(64, 128) 93 | self.down2 = Down(128, 256) 94 | self.down3 = Down(256, 512) 95 | factor = 2 if bilinear else 1 96 | self.down4 = Down(512, 1024 // factor) 97 | self.up1 = Up(1024, 512 // factor, bilinear) 98 | self.up2 = Up(512, 256 // factor, bilinear) 99 | self.up3 = Up(256, 128 // factor, bilinear) 100 | self.up4 = Up(128, 64, bilinear) 101 | self.outc = OutConv(64, seg_classes) 102 | 103 | self.inc1 = DoubleConv(seg_classes, 64) 104 | self.down11 = Down(64, 128) 105 | self.down21 = Down(128, 256) 106 | self.down31 = Down(256, 512) 107 | factor = 2 if bilinear else 1 108 | self.down41 = Down(512, 1024 // factor) 109 | self.up11 = Up(1024, 512 // factor, bilinear) 110 | self.up21 = Up(512, 256 // factor, bilinear) 111 | self.up31 = Up(256, 128 // factor, bilinear) 112 | self.up41 = Up(128, 64, bilinear) 113 | self.outcy = OutConvy(64, colour_classes) 114 | self.softmax = nn.Softmax(dim = 1) 115 | 116 | def forward(self, x): 117 | x1 = self.inc(x) 118 | x2 = self.down1(x1) 119 | x3 = self.down2(x2) 120 | x4 = self.down3(x3) 121 | x5 = self.down4(x4) 122 | 123 | x = self.up1(x5, x4) 124 | x = self.up2(x, x3) 125 | x = self.up3(x, x2) 126 | x = self.up4(x, x1) 127 | x = self.outc(x) 128 | 129 | y_input = self.softmax(x) 130 | y1 = self.inc1(y_input) 131 | y2 = self.down11(y1) 132 | y3 = self.down21(y2) 133 | y4 = self.down31(y3) 134 | y5 = self.down41(y4) 135 | 
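        # Decoder of the second U-Net: reconstructs the colour (H&E) image from the
        # softmax of the segmentation branch (the self-supervised colour-reconstruction task).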
136 | y = self.up11(y5, y4) 137 | y = self.up21(y, y3) 138 | y = self.up31(y, y2) 139 | y = self.up41(y, y1) 140 | y = self.outcy(y) 141 | 142 | return x, y 143 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 37 | base_width=64, dilation=1, norm_layer=None): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 76 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
77 | # This variant is also known as ResNet V1.5 and improves accuracy according to 78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 79 | 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 83 | base_width=64, dilation=1, norm_layer=None): 84 | super(Bottleneck, self).__init__() 85 | if norm_layer is None: 86 | norm_layer = nn.BatchNorm2d 87 | width = int(planes * (base_width / 64.)) * groups 88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 89 | self.conv1 = conv1x1(inplanes, width) 90 | self.bn1 = norm_layer(width) 91 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 92 | self.bn2 = norm_layer(width) 93 | self.conv3 = conv1x1(width, planes * self.expansion) 94 | self.bn3 = norm_layer(planes * self.expansion) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | identity = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class ResNet(nn.Module): 123 | 124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 125 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 126 | norm_layer=None): 127 | super(ResNet, self).__init__() 128 | if norm_layer is None: 129 | norm_layer = nn.BatchNorm2d 130 | self._norm_layer = norm_layer 131 | 132 | self.inplanes = 64 133 | self.dilation = 1 134 | if replace_stride_with_dilation is None: 135 | # each element in the tuple indicates if we should replace 136 | # the 2x2 stride with a dilated convolution instead 137 | replace_stride_with_dilation = [False, False, False] 138 | if len(replace_stride_with_dilation) != 3: 139 | raise ValueError("replace_stride_with_dilation should be None " 140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 141 | self.groups = groups 142 | self.base_width = width_per_group 143 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 150 | dilate=replace_stride_with_dilation[0]) 151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 152 | dilate=replace_stride_with_dilation[1]) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 154 | dilate=replace_stride_with_dilation[2]) 155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | # Zero-initialize the last BN in each residual branch, 166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 168 | if zero_init_residual: 169 | for m in self.modules(): 170 | if isinstance(m, Bottleneck): 171 | nn.init.constant_(m.bn3.weight, 0) 172 | elif isinstance(m, BasicBlock): 173 | nn.init.constant_(m.bn2.weight, 0) 174 | 175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 176 | norm_layer = self._norm_layer 177 | downsample = None 178 | previous_dilation = self.dilation 179 | if dilate: 180 | self.dilation *= stride 181 | stride = 1 182 | if stride != 1 or self.inplanes != planes * block.expansion: 183 | downsample = nn.Sequential( 184 | conv1x1(self.inplanes, planes * block.expansion, stride), 185 | norm_layer(planes * block.expansion), 186 | ) 187 | 188 | layers = [] 189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 190 | self.base_width, previous_dilation, norm_layer)) 191 | self.inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(self.inplanes, planes, groups=self.groups, 194 | base_width=self.base_width, dilation=self.dilation, 195 | norm_layer=norm_layer)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def _forward_impl(self, x): 200 | # See note [TorchScript super()] 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | x = self.relu(x) 204 | x = self.maxpool(x) 205 | 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = torch.flatten(x, 1) 213 | x = self.fc(x) 214 | 215 | return x 216 | 217 | def forward(self, x): 218 | return self._forward_impl(x) 219 | 220 | 221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 222 | model = ResNet(block, layers, **kwargs) 223 | if pretrained: 224 | state_dict = load_state_dict_from_url(model_urls[arch], 225 | progress=progress) 226 | model.load_state_dict(state_dict) 227 | return model 228 | 229 | 230 | def resnet18(pretrained=False, progress=True, **kwargs): 231 | r"""ResNet-18 model from 232 | `"Deep Residual Learning for Image Recognition" `_ 233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | progress (bool): If True, displays a progress bar of the download to stderr 237 | """ 238 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 239 | **kwargs) 240 | 241 | 242 | def resnet34(pretrained=False, progress=True, **kwargs): 243 | r"""ResNet-34 model from 244 | `"Deep Residual Learning for Image Recognition" `_ 245 | 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | progress (bool): If True, displays a progress bar of the download to stderr 249 | """ 250 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 251 | **kwargs) 252 | 253 | 254 | def resnet50(pretrained=False, progress=True, **kwargs): 255 | r"""ResNet-50 model from 256 | `"Deep Residual Learning for Image Recognition" `_ 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | 266 | def resnet101(pretrained=False, progress=True, **kwargs): 267 | r"""ResNet-101 model from 268 | `"Deep Residual Learning for Image Recognition" `_ 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model pre-trained on ImageNet 272 | 
progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnet152(pretrained=False, progress=True, **kwargs): 279 | r"""ResNet-152 model from 280 | `"Deep Residual Learning for Image Recognition" `_ 281 | 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | progress (bool): If True, displays a progress bar of the download to stderr 285 | """ 286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 287 | **kwargs) 288 | 289 | 290 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 291 | r"""ResNeXt-50 32x4d model from 292 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | kwargs['groups'] = 32 299 | kwargs['width_per_group'] = 4 300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 301 | pretrained, progress, **kwargs) 302 | 303 | 304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 305 | r"""ResNeXt-101 32x8d model from 306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | kwargs['groups'] = 32 313 | kwargs['width_per_group'] = 8 314 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 315 | pretrained, progress, **kwargs) 316 | 317 | 318 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 319 | r"""Wide ResNet-50-2 model from 320 | `"Wide Residual Networks" `_ 321 | 322 | The model is the same as ResNet except for the bottleneck number of channels 323 | which is twice larger in every block. The number of channels in outer 1x1 324 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 325 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 326 | 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | progress (bool): If True, displays a progress bar of the download to stderr 330 | """ 331 | kwargs['width_per_group'] = 64 * 2 332 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 333 | pretrained, progress, **kwargs) 334 | 335 | 336 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 337 | r"""Wide ResNet-101-2 model from 338 | `"Wide Residual Networks" `_ 339 | 340 | The model is the same as ResNet except for the bottleneck number of channels 341 | which is twice larger in every block. The number of channels in outer 1x1 342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
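    As a quick sanity check of these numbers (a sketch assuming the standard torchvision
    bottleneck width rule width = int(planes * (base_width / 64.)) * groups, which is also
    the rule used by the Bottleneck class in this repository): with width_per_group = 64 * 2 = 128,
    the last stage has planes = 512 and groups = 1, so width = int(512 * 128 / 64) * 1 = 1024,
    while the outer 1x1 convolutions keep 512 * expansion = 2048 channels, i.e. the
    2048-1024-2048 pattern quoted above.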
344 | 345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | kwargs['width_per_group'] = 64 * 2 350 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 351 | pretrained, progress, **kwargs) 352 | -------------------------------------------------------------------------------- /model/resnet1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 16 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 17 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 18 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=dilation, groups=groups, bias=False, dilation=dilation) 26 | 27 | 28 | def conv1x1(in_planes, out_planes, stride=1): 29 | """1x1 convolution""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | expansion = 1 35 | 36 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 37 | base_width=64, dilation=1, norm_layer=None): 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | identity = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | if self.downsample is not None: 65 | identity = self.downsample(x) 66 | 67 | out += identity 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class Bottleneck(nn.Module): 74 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 75 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 76 | # according to "Deep residual learning for image 
recognition"https://arxiv.org/abs/1512.03385. 77 | # This variant is also known as ResNet V1.5 and improves accuracy according to 78 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 79 | 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 83 | base_width=64, dilation=1, norm_layer=None): 84 | super(Bottleneck, self).__init__() 85 | if norm_layer is None: 86 | norm_layer = nn.BatchNorm2d 87 | width = int(planes * (base_width / 64.)) * groups 88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 89 | self.conv1 = conv1x1(inplanes, width) 90 | self.bn1 = norm_layer(width) 91 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 92 | self.bn2 = norm_layer(width) 93 | self.conv3 = conv1x1(width, planes * self.expansion) 94 | self.bn3 = norm_layer(planes * self.expansion) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | identity = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class ResNet(nn.Module): 123 | 124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 125 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 126 | norm_layer=None): 127 | super(ResNet, self).__init__() 128 | if norm_layer is None: 129 | norm_layer = nn.BatchNorm2d 130 | self._norm_layer = norm_layer 131 | 132 | self.inplanes = 64 133 | self.dilation = 1 134 | if replace_stride_with_dilation is None: 135 | # each element in the tuple indicates if we should replace 136 | # the 2x2 stride with a dilated convolution instead 137 | replace_stride_with_dilation = [False, False, False] 138 | if len(replace_stride_with_dilation) != 3: 139 | raise ValueError("replace_stride_with_dilation should be None " 140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 141 | self.groups = groups 142 | self.base_width = width_per_group 143 | self.conv1 = nn.Conv2d(2, self.inplanes, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 150 | dilate=replace_stride_with_dilation[0]) 151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 152 | dilate=replace_stride_with_dilation[1]) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 154 | dilate=replace_stride_with_dilation[2]) 155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | # Zero-initialize the last BN in each residual branch, 166 | # so that the residual branch starts with zeros, and each residual block 
behaves like an identity. 167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 168 | if zero_init_residual: 169 | for m in self.modules(): 170 | if isinstance(m, Bottleneck): 171 | nn.init.constant_(m.bn3.weight, 0) 172 | elif isinstance(m, BasicBlock): 173 | nn.init.constant_(m.bn2.weight, 0) 174 | 175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 176 | norm_layer = self._norm_layer 177 | downsample = None 178 | previous_dilation = self.dilation 179 | if dilate: 180 | self.dilation *= stride 181 | stride = 1 182 | if stride != 1 or self.inplanes != planes * block.expansion: 183 | downsample = nn.Sequential( 184 | conv1x1(self.inplanes, planes * block.expansion, stride), 185 | norm_layer(planes * block.expansion), 186 | ) 187 | 188 | layers = [] 189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 190 | self.base_width, previous_dilation, norm_layer)) 191 | self.inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(self.inplanes, planes, groups=self.groups, 194 | base_width=self.base_width, dilation=self.dilation, 195 | norm_layer=norm_layer)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def _forward_impl(self, x): 200 | # See note [TorchScript super()] 201 | x = self.conv1(x) 202 | x = self.bn1(x) 203 | x = self.relu(x) 204 | x = self.maxpool(x) 205 | 206 | x = self.layer1(x) 207 | x = self.layer2(x) 208 | x = self.layer3(x) 209 | x = self.layer4(x) 210 | 211 | x = self.avgpool(x) 212 | x = torch.flatten(x, 1) 213 | x = self.fc(x) 214 | 215 | return x 216 | 217 | def forward(self, x): 218 | return self._forward_impl(x) 219 | 220 | 221 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 222 | model = ResNet(block, layers, **kwargs) 223 | if pretrained: 224 | state_dict = load_state_dict_from_url(model_urls[arch], 225 | progress=progress) 226 | model.load_state_dict(state_dict) 227 | return model 228 | 229 | 230 | def resnet18(pretrained=False, progress=True, **kwargs): 231 | r"""ResNet-18 model from 232 | `"Deep Residual Learning for Image Recognition" `_ 233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | progress (bool): If True, displays a progress bar of the download to stderr 237 | """ 238 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 239 | **kwargs) 240 | 241 | 242 | def resnet34(pretrained=False, progress=True, **kwargs): 243 | r"""ResNet-34 model from 244 | `"Deep Residual Learning for Image Recognition" `_ 245 | 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | progress (bool): If True, displays a progress bar of the download to stderr 249 | """ 250 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 251 | **kwargs) 252 | 253 | 254 | def resnet50(pretrained=False, progress=True, **kwargs): 255 | r"""ResNet-50 model from 256 | `"Deep Residual Learning for Image Recognition" `_ 257 | 258 | Args: 259 | pretrained (bool): If True, returns a model pre-trained on ImageNet 260 | progress (bool): If True, displays a progress bar of the download to stderr 261 | """ 262 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 263 | **kwargs) 264 | 265 | 266 | def resnet101(pretrained=False, progress=True, **kwargs): 267 | r"""ResNet-101 model from 268 | `"Deep Residual Learning for Image Recognition" `_ 269 | 270 | Args: 271 | pretrained (bool): If True, returns a model 
pre-trained on ImageNet 272 | progress (bool): If True, displays a progress bar of the download to stderr 273 | """ 274 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 275 | **kwargs) 276 | 277 | 278 | def resnet152(pretrained=False, progress=True, **kwargs): 279 | r"""ResNet-152 model from 280 | `"Deep Residual Learning for Image Recognition" `_ 281 | 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | progress (bool): If True, displays a progress bar of the download to stderr 285 | """ 286 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 287 | **kwargs) 288 | 289 | 290 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 291 | r"""ResNeXt-50 32x4d model from 292 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | kwargs['groups'] = 32 299 | kwargs['width_per_group'] = 4 300 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 301 | pretrained, progress, **kwargs) 302 | 303 | 304 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 305 | r"""ResNeXt-101 32x8d model from 306 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | kwargs['groups'] = 32 313 | kwargs['width_per_group'] = 8 314 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 315 | pretrained, progress, **kwargs) 316 | 317 | 318 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 319 | r"""Wide ResNet-50-2 model from 320 | `"Wide Residual Networks" `_ 321 | 322 | The model is the same as ResNet except for the bottleneck number of channels 323 | which is twice larger in every block. The number of channels in outer 1x1 324 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 325 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 326 | 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | progress (bool): If True, displays a progress bar of the download to stderr 330 | """ 331 | kwargs['width_per_group'] = 64 * 2 332 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 333 | pretrained, progress, **kwargs) 334 | 335 | 336 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 337 | r"""Wide ResNet-101-2 model from 338 | `"Wide Residual Networks" `_ 339 | 340 | The model is the same as ResNet except for the bottleneck number of channels 341 | which is twice larger in every block. The number of channels in outer 1x1 342 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 343 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
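    A minimal forward-pass sketch for the variant defined in this file (illustrative only;
    note that conv1 in this file is declared with 2 input channels, so ImageNet-pretrained
    weights would not load into the stem without modification):

        >>> model = wide_resnet101_2(pretrained=False)
        >>> x = torch.randn(1, 2, 224, 224)   # hypothetical 2-channel input batch
        >>> logits = model(x)                 # shape (1, 1000) with the default num_classes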
344 | 345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | kwargs['width_per_group'] = 64 * 2 350 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 351 | pretrained, progress, **kwargs) 352 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | 5 | 6 | class Options: 7 | def __init__(self, isTrain): 8 | self.dataset = 'MO' # dataset: LC: Lung Cancer, MO: MultiOrgan 9 | self.isTrain = isTrain # train or test mode 10 | self.rootDir = '' # rootdir of the data and experiment 11 | 12 | # --- model hyper-parameters --- # 13 | self.model = dict() 14 | self.model['name'] = 'ResUNet34' 15 | self.model['pretrained'] = False 16 | self.model['fix_params'] = False 17 | self.model['in_c'] = 1 # input channel 18 | 19 | # --- training params --- # 20 | self.train = dict() 21 | self.train['data_dir'] = '{:s}/Data/{:s}'.format(self.rootDir, self.dataset) # path to data 22 | self.train['save_dir'] = '{:s}/Exp/{:s}'.format(self.rootDir, self.dataset) # path to save results 23 | self.train['input_size'] = 224 # input size of the image 24 | self.train['train_epochs'] = 100 # number of training epochs 25 | self.train['batch_size'] = 8 # batch size 26 | self.train['checkpoint_freq'] = 20 # epoch to save checkpoints 27 | self.train['lr'] = 1e-3 # initial learning rate 28 | self.train['weight_decay'] = 5e-4 # weight decay 29 | self.train['log_interval'] = 37 # iterations to print training results 30 | self.train['ema_interval'] = 5 # eps to save ema 31 | self.train['workers'] = 16 # number of workers to load images 32 | self.train['gpus'] = [0, ] # select gpu devices 33 | # --- resume training --- # 34 | self.train['start_epoch'] = 0 # start epoch 35 | self.train['checkpoint'] = '' 36 | 37 | # --- data transform --- # 38 | self.transform = dict() 39 | 40 | # --- test parameters --- # 41 | self.test = dict() 42 | self.test['test_epoch'] = 80 43 | self.test['gpus'] = [0, ] 44 | self.test['img_dir'] = '{:s}/Data/{:s}/images/test'.format(self.rootDir, self.dataset) 45 | self.test['imgh_dir'] = '{:s}/Data/{:s}/images/testh'.format(self.rootDir, self.dataset) 46 | self.test['label_dir'] = '{:s}/Data/{:s}/labels_instance'.format(self.rootDir, self.dataset) 47 | self.test['save_flag'] = True 48 | self.test['patch_size'] = 224 49 | self.test['overlap'] = 80 50 | self.test['save_dir'] = '{:s}/Exp/{:s}/test_results'.format(self.rootDir, self.dataset) 51 | self.test['checkpoint_dir'] = '{:s}/Exp/{:s}/checkpoints/'.format(self.rootDir, self.dataset) 52 | self.test['model_path1'] = '{:s}/checkpoint1_{:d}.pth.tar'.format(self.test['checkpoint_dir'], self.test['test_epoch']) 53 | self.test['model_path2'] = '{:s}/checkpoint2_{:d}.pth.tar'.format(self.test['checkpoint_dir'], self.test['test_epoch']) 54 | # --- post processing --- # 55 | self.post = dict() 56 | self.post['min_area'] = 20 # minimum area for an object 57 | 58 | def parse(self): 59 | """ Parse the options, replace the default value if there is a new input """ 60 | parser = argparse.ArgumentParser(description='') 61 | if self.isTrain: 62 | parser.add_argument('--batch-size', type=int, default=self.train['batch_size'], help='input batch size for training') 63 | parser.add_argument('--epochs', type=int, default=self.train['train_epochs'], help='number 
of epochs to train') 64 | parser.add_argument('--lr', type=float, default=self.train['lr'], help='learning rate') 65 | parser.add_argument('--log-interval', type=int, default=self.train['log_interval'], help='how many batches to wait before logging training status') 66 | parser.add_argument('--gpus', type=int, nargs='+', default=self.train['gpus'], help='GPUs for training') 67 | parser.add_argument('--data-dir', type=str, default=self.train['data_dir'], help='directory of training data') 68 | parser.add_argument('--save-dir', type=str, default=self.train['save_dir'], help='directory to save training results') 69 | args = parser.parse_args() 70 | 71 | self.train['batch_size'] = args.batch_size 72 | self.train['train_epochs'] = args.epochs 73 | self.train['lr'] = args.lr 74 | self.train['log_interval'] = args.log_interval 75 | self.train['gpus'] = args.gpus 76 | self.train['data_dir'] = args.data_dir 77 | self.train['img_dir'] = '{:s}/images'.format(self.train['data_dir']) 78 | self.train['label_vor_dir'] = '{:s}/labels_voronoi'.format(self.train['data_dir']) 79 | self.train['label_cluster_dir'] = '{:s}/labels_cluster'.format(self.train['data_dir']) 80 | 81 | 82 | self.train['save_dir'] = args.save_dir 83 | if not os.path.exists(self.train['save_dir']): 84 | os.makedirs(self.train['save_dir'], exist_ok=True) 85 | 86 | # define data transforms for training 87 | self.transform['train'] = { 88 | 'random_resize': [0.8, 1.25], 89 | 'horizontal_flip': True, 90 | 'vertical_flip': True, 91 | 'random_affine': 0.3, 92 | 'random_rotation': 90, 93 | 'random_crop': self.train['input_size'], 94 | 'label_encoding': 2, 95 | 'to_tensor': 3 96 | } 97 | 98 | self.transform['val'] = { 99 | 'random_crop': self.train['input_size'], 100 | 'label_encoding': 2, 101 | 'to_tensor': 2 102 | } 103 | 104 | self.transform['test'] = { 105 | 'to_tensor': 1 106 | } 107 | 108 | self.transform['val_lin'] = { 109 | 'to_tensor': 3 110 | } 111 | 112 | else: 113 | parser.add_argument('--save-flag', type=bool, default=self.test['save_flag'], help='flag to save the network outputs and predictions') 114 | parser.add_argument('--img-dir', type=str, default=self.test['img_dir'], help='directory of test images') 115 | parser.add_argument('--imgh-dir', type=str, default=self.test['imgh_dir'], help='directory of test images') 116 | parser.add_argument('--label-dir', type=str, default=self.test['label_dir'], help='directory of labels') 117 | parser.add_argument('--save-dir', type=str, default=self.test['save_dir'], help='directory to save test results') 118 | parser.add_argument('--gpus', type=int, nargs='+', default=self.train['gpus'], help='GPUs for training') 119 | parser.add_argument('--model-path1', type=str, default=self.test['model_path1'], help='train model to be evaluated') 120 | parser.add_argument('--model-path2', type=str, default=self.test['model_path2'],help='train model to be evaluated') 121 | args = parser.parse_args() 122 | self.test['gpus'] = args.gpus 123 | self.test['save_flag'] = args.save_flag 124 | self.test['img_dir'] = args.img_dir 125 | self.test['imgh_dir'] = args.imgh_dir 126 | self.test['label_dir'] = args.label_dir 127 | self.test['save_dir'] = args.save_dir 128 | self.test['model_path1'] = args.model_path1 129 | self.test['model_path2'] = args.model_path2 130 | 131 | if not os.path.exists(self.test['save_dir']): 132 | os.makedirs(self.test['save_dir'], exist_ok=True) 133 | 134 | self.transform['test'] = { 135 | 'to_tensor': 1 136 | } 137 | 138 | def save_options(self): 139 | if self.isTrain: 140 | filename = 
'{:s}/train_options.txt'.format(self.train['save_dir']) 141 | else: 142 | filename = '{:s}/test_options.txt'.format(self.test['save_dir']) 143 | file = open(filename, 'w') 144 | groups = ['model', 'train', 'transform'] if self.isTrain else ['model', 'test', 'post', 'transform'] 145 | 146 | file.write("# ---------- Options ---------- #") 147 | file.write('\ndataset: {:s}\n'.format(self.dataset)) 148 | file.write('isTrain: {}\n'.format(self.isTrain)) 149 | for group, options in self.__dict__.items(): 150 | if group not in groups: 151 | continue 152 | file.write('\n\n-------- {:s} --------\n'.format(group)) 153 | if group == 'transform': 154 | for name, val in options.items(): 155 | if (self.isTrain and name != 'test') or (not self.isTrain and name == 'test'): 156 | file.write("{:s}:\n".format(name)) 157 | for t_name, t_val in val.items(): 158 | file.write("\t{:s}: {:s}\n".format(t_name, repr(t_val))) 159 | else: 160 | for name, val in options.items(): 161 | file.write("{:s} = {:s}\n".format(name, repr(val))) 162 | file.close() 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | astor==0.8.1 3 | backcall==0.2.0 4 | cached-property==1.5.2 5 | cachetools==4.2.4 6 | certifi==2020.6.20 7 | chardet==5.0.0 8 | click==8.0.4 9 | commonmark==0.9.1 10 | configparser==5.2.0 11 | contextlib2==21.6.0 12 | cycler==0.11.0 13 | dataclasses==0.8 14 | decorator==4.4.2 15 | gast==0.5.3 16 | google-pasta==0.2.0 17 | grad-cam==1.4.8 18 | grpcio==1.43.0 19 | h5py==3.1.0 20 | imageio==2.13.5 21 | importlib-metadata==4.8.3 22 | importlib-resources==5.4.0 23 | ipdb==0.13.13 24 | ipython==7.16.3 25 | ipython-genutils==0.2.0 26 | jedi==0.17.2 27 | joblib==1.1.1 28 | Keras==2.3.1 29 | Keras-Applications==1.0.8 30 | Keras-Preprocessing==1.1.2 31 | kiwisolver==1.3.1 32 | lxml==4.9.3 33 | Markdown==3.3.6 34 | matplotlib==3.3.4 35 | networkx==2.5.1 36 | nltk==3.6.7 37 | numpy==1.19.5 38 | nvidia-ml-py==11.525.131 39 | nvitop==1.0.0 40 | opencv-python==4.5.5.62 41 | packaging==21.3 42 | pandas==1.1.5 43 | parso==0.7.1 44 | pexpect==4.8.0 45 | pickleshare==0.7.5 46 | Pillow==10.0.1 47 | plotly==5.15.0 48 | prompt-toolkit==3.0.36 49 | protobuf==3.19.3 50 | psutil==5.9.5 51 | ptyprocess==0.7.0 52 | Pygments==2.14.0 53 | pyparsing==3.0.6 54 | python-dateutil==2.8.2 55 | pytz==2023.3 56 | PyWavelets==1.1.1 57 | PyYAML==6.0 58 | regex==2023.8.8 59 | rich==12.6.0 60 | scikit-image==0.17.2 61 | scikit-learn==0.24.2 62 | scipy==1.5.4 63 | Shapely==1.8.5.post1 64 | SimpleITK==2.0.2 65 | six==1.16.0 66 | summary==0.2.0 67 | tenacity==8.2.2 68 | tensorboard==1.14.0 69 | tensorboardX==2.6.2 70 | tensorflow-estimator==1.14.0 71 | tensorflow-gpu==1.14.0 72 | termcolor==1.1.0 73 | threadpoolctl==3.1.0 74 | tifffile==2020.9.3 75 | tomli==1.2.3 76 | torch==1.9.1+cu111 77 | torchaudio==0.9.1 78 | torchcam==0.3.2 79 | torchsummary==1.5.1 80 | torchvision==0.10.1+cu111 81 | tqdm==4.64.1 82 | traitlets==4.3.3 83 | ttach==0.0.3 84 | typing_extensions==4.0.1 85 | wcwidth==0.2.6 86 | Werkzeug==2.0.2 87 | wrapt==1.13.3 88 | xgboost==1.5.2 89 | zipp==3.6.0 90 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from PIL 
import Image 7 | import skimage.morphology as morph 8 | import scipy.ndimage.morphology as ndi_morph 9 | from skimage import measure 10 | from scipy import misc 11 | from model.modelW import ResWNet34 12 | from model.model_WNet import WNet 13 | import utils.utils as utils 14 | from utils.accuracy import compute_metrics 15 | import time 16 | import imageio 17 | from options import Options 18 | from dataloaders.my_transforms import get_transforms 19 | from tqdm import tqdm 20 | from rich.table import Column, Table 21 | from rich import print 22 | import time 23 | 24 | 25 | def main(): 26 | opt = Options(isTrain=False) 27 | opt.parse() 28 | opt.save_options() 29 | 30 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in opt.test['gpus']) 31 | 32 | img_dir = opt.test['img_dir'] 33 | imgh_dir = opt.test['imgh_dir'] 34 | label_dir = opt.test['label_dir'] 35 | save_dir = opt.test['save_dir'] 36 | model_path1 = opt.test['model_path1'] 37 | model_path2 = opt.test['model_path2'] 38 | save_flag = opt.test['save_flag'] 39 | 40 | # data transforms 41 | test_transform = get_transforms(opt.transform['test']) 42 | 43 | model1 = ResWNet34(seg_classes = 2, colour_classes = 3) 44 | model1 = torch.nn.DataParallel(model1) 45 | model1 = model1.cuda() 46 | model2 = WNet(n_channels = 1, seg_classes = 2, colour_classes = 3) 47 | # model2 = ResWNet34(seg_classes = 2, colour_classes = 3) 48 | model2 = torch.nn.DataParallel(model2) 49 | model2 = model2.cuda() 50 | cudnn.benchmark = True 51 | 52 | # ----- load trained model ----- # 53 | print("=> loading trained model") 54 | checkpoint1 = torch.load(model_path1) 55 | checkpoint2 = torch.load(model_path2) 56 | model1.load_state_dict(checkpoint1['state_dict']) 57 | model2.load_state_dict(checkpoint2['state_dict']) 58 | print("=> loaded model1 at epoch {}".format(checkpoint1['epoch'])) 59 | print("=> loaded model2 at epoch {}".format(checkpoint2['epoch'])) 60 | model1 = model1.module 61 | model2 = model2.module 62 | 63 | # switch to evaluate mode 64 | model1.eval() 65 | model2.eval() 66 | counter = 0 67 | print("=> Test begins:") 68 | 69 | img_names = os.listdir(img_dir) 70 | 71 | if save_flag: 72 | if not os.path.exists(save_dir): 73 | os.mkdir(save_dir) 74 | strs = img_dir.split('/') 75 | prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1]) 76 | seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1]) 77 | if not os.path.exists(prob_maps_folder): 78 | os.mkdir(prob_maps_folder) 79 | if not os.path.exists(seg_folder): 80 | os.mkdir(seg_folder) 81 | 82 | metric_names = ['acc', 'p_F1', 'dice', 'aji', 'dq', 'sq', 'pq'] 83 | test_results = dict() 84 | all_result = utils.AverageMeter(len(metric_names)) 85 | all_result1 = utils.AverageMeter(len(metric_names)) 86 | all_result2 = utils.AverageMeter(len(metric_names)) 87 | 88 | # calculte inference time 89 | time_list = [] 90 | 91 | img_process = tqdm(img_names) 92 | for img_name in img_process: 93 | # load test image 94 | img_process.set_description('=> Processing image {:s}'.format(img_name)) 95 | img_path = '{:s}/{:s}'.format(img_dir, img_name) 96 | imgh_path = '{:s}/{:s}'.format(imgh_dir, img_name) 97 | img = Image.open(img_path) 98 | imgh = Image.open(imgh_path) 99 | import ipdb; ipdb.set_trace() 100 | ori_h = img.size[1] 101 | ori_w = img.size[0] 102 | name = os.path.splitext(img_name)[0] 103 | label_path = '{:s}/{:s}_label.png'.format(label_dir, name) 104 | gt = imageio.imread(label_path) 105 | 106 | input = test_transform((img,))[0].unsqueeze(0) 107 | inputh = 
test_transform((imgh,))[0].unsqueeze(0) 108 | 109 | img_process.set_description('Computing output probability maps...') 110 | begin = time.time() 111 | prob_maps, prob_maps1, prob_maps2 = get_probmaps(input, inputh, model1, model2, opt) 112 | end = time.time() 113 | time_list.append(end - begin) 114 | pred = np.argmax(prob_maps, axis=0) # prediction 115 | 116 | pred_labeled = measure.label(pred) 117 | pred_labeled = morph.remove_small_objects(pred_labeled, opt.post['min_area']) 118 | pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0) 119 | pred_labeled = measure.label(pred_labeled) 120 | 121 | img_process.set_description('Computing metrics...') 122 | metrics = compute_metrics(pred_labeled, gt, metric_names) 123 | print(metrics) 124 | 125 | test_results[name] = [metrics[m] for m in metric_names] 126 | all_result.update([metrics[m] for m in metric_names]) 127 | 128 | pred1 = np.argmax(prob_maps1, axis=0) # prediction 129 | pred_labeled1 = measure.label(pred1) 130 | pred_labeled1 = morph.remove_small_objects(pred_labeled1, opt.post['min_area']) 131 | pred_labeled1 = ndi_morph.binary_fill_holes(pred_labeled1 > 0) 132 | pred_labeled1 = measure.label(pred_labeled1) 133 | metrics = compute_metrics(pred_labeled1, gt, metric_names) 134 | all_result1.update([metrics[m] for m in metric_names]) 135 | 136 | pred = np.argmax(prob_maps2, axis=0) # prediction 137 | pred_labeled2 = measure.label(pred) 138 | pred_labeled2 = morph.remove_small_objects(pred_labeled2, opt.post['min_area']) 139 | pred_labeled2 = ndi_morph.binary_fill_holes(pred_labeled2 > 0) 140 | pred_labeled2 = measure.label(pred_labeled2) 141 | metrics = compute_metrics(pred_labeled2, gt, metric_names) 142 | all_result2.update([metrics[m] for m in metric_names]) 143 | 144 | # save image 145 | if save_flag: 146 | img_process.set_description('Saving image results...') 147 | imageio.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name), pred.astype(np.uint8) * 255) 148 | imageio.imsave('{:s}/{:s}_prob.png'.format(prob_maps_folder, name), (prob_maps[1] * 255).astype(np.uint8)) 149 | final_pred = Image.fromarray(pred_labeled.astype(np.uint16)) 150 | final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name)) 151 | 152 | # save colored objects 153 | pred_colored_instance = np.zeros((ori_h, ori_w, 3)) 154 | for k in range(1, pred_labeled.max() + 1): 155 | pred_colored_instance[pred_labeled == k, :] = np.array(utils.get_random_color()) 156 | filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name) 157 | imageio.imsave(filename, (pred_colored_instance * 255).astype(np.uint8)) 158 | 159 | counter += 1 160 | 161 | print('=> Processed all {:d} images'.format(counter)) 162 | print('=> Average time per image: {:.4f} s'.format(np.mean(time_list))) 163 | print('=> Std time per image: {:.4f} s'.format(np.std(time_list))) 164 | 165 | table = Table(show_header=True, header_style="bold magenta") 166 | table.add_column("Model", style="dim", width=12) 167 | table.add_column("Acc") 168 | table.add_column("F1") 169 | table.add_column("Dice") 170 | table.add_column("AJI") 171 | table.add_column("DQ") 172 | table.add_column("SQ") 173 | table.add_column("PQ") 174 | a, a1, a2 = all_result.avg, all_result1.avg, all_result2.avg 175 | table.add_row('ens', f'{a[0]: .4f}', f'{a[1]: .4f}', f'{a[2]: .4f}', f'{a[3]: .4f}', f'{a[4]: .4f}', f'{a[5]: .4f}', f'{a[6]: .4f}'.format(a=all_result.avg)) 176 | table.add_row('m1', f'{a1[0]: .4f}', f'{a1[1]: .4f}', f'{a1[2]: .4f}', f'{a1[3]: .4f}', f'{a1[4]: .4f}', f'{a1[5]: .4f}', f'{a1[6]: 
.4f}'.format(a1=all_result1.avg)) 177 | table.add_row('m2', f'{a2[0]: .4f}', f'{a2[1]: .4f}', f'{a2[2]: .4f}', f'{a2[3]: .4f}', f'{a2[4]: .4f}', f'{a2[5]: .4f}', f'{a2[6]: .4f}'.format(a2=all_result2.avg)) 178 | print(table) 179 | 180 | header = metric_names 181 | utils.save_results(header, all_result.avg, test_results, f'{save_dir}/test_results_{checkpoint1["epoch"]}_{checkpoint2["epoch"]}_{all_result.avg[5]:.4f}_{all_result1.avg[5]:.4f}_{all_result2.avg[5]:.4f}.txt') 182 | 183 | 184 | def get_probmaps(input, inputh, model1, model2, opt): 185 | size = opt.test['patch_size'] 186 | overlap = opt.test['overlap'] 187 | output1 = split_forward1(model1, inputh, size, overlap) 188 | output1 = output1.squeeze(0) 189 | prob_maps1 = F.softmax(output1, dim=0).cpu().numpy() 190 | 191 | output2 = split_forward1(model2, inputh, size, overlap) 192 | output2 = output2.squeeze(0) 193 | prob_maps2 = F.softmax(output2, dim=0).cpu().numpy() 194 | 195 | prob_maps = (prob_maps1 + prob_maps2) / 2.0 196 | 197 | return prob_maps, prob_maps1, prob_maps2 198 | 199 | 200 | def split_forward1(model, input, size, overlap, outchannel = 2): 201 | ''' 202 | split the input image for forward passes 203 | motification: if single image, split it into patches and concat, forward once. 204 | ''' 205 | 206 | b, c, h0, w0 = input.size() 207 | 208 | # zero pad for border patches 209 | pad_h = 0 210 | if h0 - size > 0 and (h0 - size) % (size - overlap) > 0: 211 | pad_h = (size - overlap) - (h0 - size) % (size - overlap) 212 | tmp = torch.zeros((b, c, pad_h, w0)) 213 | input = torch.cat((input, tmp), dim=2) 214 | 215 | if w0 - size > 0 and (w0 - size) % (size - overlap) > 0: 216 | pad_w = (size - overlap) - (w0 - size) % (size - overlap) 217 | tmp = torch.zeros((b, c, h0 + pad_h, pad_w)) 218 | input = torch.cat((input, tmp), dim=3) 219 | 220 | _, c, h, w = input.size() 221 | 222 | output = torch.zeros((input.size(0), outchannel, h, w)) 223 | input_vars = [] 224 | for i in range(0, h-overlap, size-overlap): 225 | r_end = i + size if i + size < h else h 226 | ind1_s = i + overlap // 2 if i > 0 else 0 227 | ind1_e = i + size - overlap // 2 if i + size < h else h 228 | for j in range(0, w-overlap, size-overlap): 229 | c_end = j+size if j+size < w else w 230 | 231 | input_patch = input[:,:,i:r_end,j:c_end] 232 | # input_var = input_patch.cuda() 233 | input_var = input_patch.numpy() 234 | input_vars.append(input_var) 235 | input_vars = torch.as_tensor(input_vars) 236 | input_vars = input_vars.squeeze(1).cuda() 237 | with torch.no_grad(): 238 | output_patches, _ = model(input_vars) 239 | idx = 0 240 | for i in range(0, h-overlap, size-overlap): 241 | r_end = i + size if i + size < h else h 242 | ind1_s = i + overlap // 2 if i > 0 else 0 243 | ind1_e = i + size - overlap // 2 if i + size < h else h 244 | for j in range(0, w-overlap, size-overlap): 245 | c_end = j+size if j+size < w else w 246 | output_patch = output_patches[idx] 247 | idx += 1 248 | output_patch = output_patch.unsqueeze(0) 249 | ind2_s = j+overlap//2 if j>0 else 0 250 | ind2_e = j+size-overlap//2 if j+size Initial learning rate: {:g}".format(opt.train['lr'])) 122 | logger.info("=> Batch size: {:d}".format(opt.train['batch_size'])) 123 | logger.info("=> Number of training iterations: {:d}".format(num_iter)) 124 | logger.info("=> Training epochs: {:d}".format(opt.train['train_epochs'])) 125 | min_loss1 = 100 126 | min_loss2 = 100 127 | for epoch in range(opt.train['start_epoch'], num_epoch): 128 | # train for one epoch or len(train_loader) iterations 129 | 
logger.info('Epoch: [{:d}/{:d}]'.format(epoch+1, num_epoch)) 130 | if epoch % opt.train['ema_interval'] == 0: 131 | if epoch == 0: 132 | if os.path.exists(os.path.join(opt.train['save_dir'], 'weight12')): 133 | shutil.rmtree(os.path.join(opt.train['save_dir'], 'weight12')) 134 | shutil.rmtree(os.path.join(opt.train['save_dir'], 'weight21')) 135 | ensemble_prediction(imgh2_dir, os.path.join(opt.train['save_dir'], 'weight21'), model1) 136 | ensemble_prediction(imgh1_dir, os.path.join(opt.train['save_dir'], 'weight12'), model2) 137 | 138 | test_dice1 = val_lin(model1, test_loader) 139 | test_dice2 = val_lin(model2, test_loader) 140 | print(f'test_dice1: {test_dice1: .4f}\ttest_dice2: {test_dice2: .4f}') 141 | 142 | train_results = train(train_loader1, train_loader2, model1, model2, optimizer1, optimizer2, criterion, criterion_H, epoch, num_epoch) 143 | train_loss, train_loss_vor, train_loss_cluster, train_loss_ct, train_loss_self = train_results 144 | state1 = {'epoch': epoch + 1, 'state_dict': model1.state_dict(), 'optimizer': optimizer1.state_dict()} 145 | state2 = {'epoch': epoch + 1, 'state_dict': model2.state_dict(), 'optimizer': optimizer2.state_dict()} 146 | if (epoch + 1) < (num_epoch - 4): 147 | cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0 148 | else: 149 | cp_flag = True 150 | save_checkpoint1(state1, epoch, opt.train['save_dir'], cp_flag) 151 | save_checkpoint2(state2, epoch, opt.train['save_dir'], cp_flag) 152 | 153 | # save the training results to txt files 154 | logger_results.info('{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}' 155 | .format(epoch+1, train_loss, train_loss_vor, train_loss_cluster, train_loss_ct, train_loss_self, test_dice1, test_dice2)) 156 | scheduler1.step() 157 | scheduler2.step() 158 | 159 | val_loss1 = val(val_loader, model1, criterion) 160 | val_loss2 = val(val_loader, model2, criterion) 161 | 162 | if val_loss1 < min_loss1: 163 | print(f'val_loss1: {val_loss1:.4f}', f'min_loss1: {min_loss1:.4f}') 164 | min_loss1 = val_loss1 165 | save_bestcheckpoint1(state1, opt.train['save_dir']) 166 | if val_loss2 < min_loss2: 167 | print(f'val_loss2: {val_loss2:.4f}', f'min_loss2: {min_loss2:.4f}') 168 | min_loss2 = val_loss2 169 | save_bestcheckpoint2(state2, opt.train['save_dir']) 170 | for i in list(logger.handlers): 171 | logger.removeHandler(i) 172 | i.flush() 173 | i.close() 174 | for i in list(logger_results.handlers): 175 | logger_results.removeHandler(i) 176 | i.flush() 177 | i.close() 178 | 179 | 180 | def train(train_loader1, train_loader2, model1, model2, optimizer1, optimizer2, criterion, criterion_H, epoch, num_epoch): 181 | # list to store the average loss for this epoch 182 | results = utils.AverageMeter(5) 183 | # switch to train mode 184 | model1.train() 185 | model2.train() 186 | weight12, weight21 = None, None 187 | for i, (sample1, sample2) in enumerate(zip(train_loader1, train_loader2)): 188 | input1, inputh1, weight12, vor1, cluster1 = sample1 189 | input2, inputh2, weight21, vor2, cluster2 = sample2 190 | if vor1.dim() == 4: 191 | vor1 = vor1.squeeze(1) 192 | if vor2.dim() == 4: 193 | vor2 = vor2.squeeze(1) 194 | if cluster1.dim() == 4: 195 | cluster1 = cluster1.squeeze(1) 196 | if cluster2.dim() == 4: 197 | cluster2 = cluster2.squeeze(1) 198 | inputh_var1 = inputh1.float().cuda() 199 | inputh_var2 = inputh2.float().cuda() 200 | 201 | # compute output 202 | output11, output11l = model1(inputh_var1) 203 | output22, output22l = model2(inputh_var2) 204 | 205 | log_prob_maps1 = F.log_softmax(output11, dim=1) 206 | 
loss_vor1 = criterion(log_prob_maps1, vor1.cuda()) 207 | loss_cluster1 = criterion(log_prob_maps1, cluster1.cuda()) 208 | 209 | log_prob_maps2 = F.log_softmax(output22, dim=1) 210 | loss_vor2 = criterion(log_prob_maps2, vor2.cuda()) 211 | loss_cluster2 = criterion(log_prob_maps2, cluster2.cuda()) 212 | 213 | loss_vor = loss_vor1 + loss_vor2 214 | loss_cluster = loss_cluster1 + loss_cluster2 215 | 216 | pseudo12 = Combine(weight12.float().cuda(), cluster1) 217 | Pseudo12 = Variable(pseudo12, requires_grad=False) 218 | pseudo21 = Combine(weight21.float().cuda(), cluster2) 219 | Pseudo21 = Variable(pseudo21, requires_grad=False) 220 | 221 | loss_ct1 = loss_CT(output11, Pseudo12) 222 | loss_ct2 = loss_CT(output22, Pseudo21) 223 | loss_ct = loss_ct1 + loss_ct2 224 | 225 | loss_self1 = criterion_H(output11l, input1.cuda()) 226 | loss_self2 = criterion_H(output22l, input2.cuda()) 227 | loss_self = loss_self1 + loss_self2 228 | 229 | loss = loss_vor + loss_cluster + loss_ct * (epoch/num_epoch)**2 + loss_self * (1 - (epoch/num_epoch)**2) * 0.1 230 | 231 | result = [loss.item(), loss_vor.item(), loss_cluster.item(), loss_ct.item(), loss_self.item()] 232 | 233 | results.update(result, input1.size(0)) 234 | 235 | # compute gradient and do SGD step 236 | optimizer1.zero_grad() 237 | optimizer2.zero_grad() 238 | loss.backward() 239 | optimizer1.step() 240 | optimizer2.step() 241 | 242 | if i % opt.train['log_interval'] == 0: 243 | logger.info('Iteration: [{:d}/{:d}]' 244 | '\tLoss {r[0]:.4f}' 245 | '\tLoss_vor {r[1]:.4f}' 246 | '\tLoss_cluster {r[2]:.4f}' 247 | '\tLoss_ct {r[3]:.4f}' 248 | '\tLoss_self {r[4]:.4f}'.format(i, len(train_loader1), r=results.avg)) 249 | 250 | logger.info('===> Train Avg: Loss {r[0]:.4f}' 251 | '\tloss_vor {r[1]:.4f}' 252 | '\tloss_cluster {r[2]:.4f}' 253 | '\tloss_ct {r[3]:.4f}' 254 | '\tloss_self {r[4]:.4f}'.format(r=results.avg)) 255 | 256 | return results.avg 257 | 258 | def val(val_loader, model, criterion): 259 | model.eval() 260 | results = 0 261 | for i, sample in enumerate(val_loader): 262 | input, inputh, target1, target2 = sample 263 | if target1.dim() == 4: 264 | target1 = target1.squeeze(1) 265 | if target2.dim() == 4: 266 | target2 = target2.squeeze(1) 267 | 268 | input_var = inputh.cuda() 269 | 270 | # compute output 271 | output, _ = model(input_var) 272 | log_prob_maps = F.log_softmax(output, dim=1) 273 | loss_vor = criterion(log_prob_maps, target1.cuda()) 274 | loss_cluster = criterion(log_prob_maps, target2.cuda()) 275 | result = loss_vor.item() + loss_cluster.item() 276 | results += result 277 | val_loss = results / (opt.train['batch_size'] * len(val_loader)) 278 | return val_loss 279 | 280 | 281 | def dice_coeff(pred, gt): 282 | target = torch.zeros_like(gt) 283 | target[gt > 0.5] = 1 284 | target = gt 285 | smooth = 1. 286 | num = pred.size(0) 287 | m1 = pred.view(num, -1) # Flatten 288 | m2 = target.view(num, -1) # Flatten 289 | intersection = (m1 * m2) 290 | dice = (2. 
* intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth) 291 | avg_dice = dice.sum() / num 292 | return avg_dice 293 | 294 | def val_lin(model, test_loader): 295 | model.eval() 296 | metric_names = ['dice'] 297 | all_result = utils.AverageMeter(len(metric_names)) 298 | 299 | size = opt.test['patch_size'] 300 | overlap = opt.test['overlap'] 301 | 302 | for _, inputh, gt in test_loader: 303 | if gt.dim() == 4: 304 | gt = gt.squeeze(1).cuda() 305 | output = utils.split_forward(model, inputh, size, overlap) 306 | log_prob_maps = F.softmax(output, dim=1) 307 | pred_labeled = torch.argmax(log_prob_maps, axis=1) # prediction 308 | dice = dice_coeff(pred_labeled, gt).cpu().numpy() 309 | all_result.update([dice]) 310 | dice_avg = all_result.avg[0] 311 | return dice_avg 312 | 313 | 314 | def ensemble_prediction(img_dir, save_dir, model, alpha=0.1): 315 | ## save pred into the experiment fold 316 | ## load all the training images 317 | img_names = os.listdir(img_dir) 318 | img_process = tqdm(img_names) 319 | test_transform = get_transforms(opt.transform['test']) 320 | print('[bold magenta]saving EMA weights ...[/bold magenta]') 321 | for img_name in img_process: 322 | img = Image.open(os.path.join(img_dir, img_name)) 323 | input = test_transform((img,))[0].unsqueeze(0).cuda() 324 | output, _ = model(input) 325 | log_prob_maps = F.softmax(output, dim=1) 326 | pred = log_prob_maps.squeeze(0).cpu().detach().numpy()[1] 327 | 328 | try: 329 | weight = imageio.imread(os.path.join(save_dir, img_name)) 330 | weight = np.array(weight) 331 | weight = alpha * pred + (1 - alpha) * weight 332 | except: 333 | weight = pred 334 | os.makedirs(save_dir, exist_ok=True) 335 | imageio.imsave(os.path.join(save_dir, img_name), (weight * 255).astype(np.uint8)) 336 | 337 | 338 | def save_checkpoint1(state, epoch, save_dir, cp_flag): 339 | cp_dir = '{:s}/checkpoints'.format(save_dir) 340 | if not os.path.exists(cp_dir): 341 | os.mkdir(cp_dir) 342 | filename = '{:s}/checkpoint1.pth.tar'.format(cp_dir) 343 | torch.save(state, filename) 344 | if cp_flag: 345 | shutil.copyfile(filename, '{:s}/checkpoint1_{:d}.pth.tar'.format(cp_dir, epoch+1)) 346 | def save_checkpoint2(state, epoch, save_dir, cp_flag): 347 | cp_dir = '{:s}/checkpoints'.format(save_dir) 348 | if not os.path.exists(cp_dir): 349 | os.mkdir(cp_dir) 350 | filename = '{:s}/checkpoint2.pth.tar'.format(cp_dir) 351 | torch.save(state, filename) 352 | if cp_flag: 353 | shutil.copyfile(filename, '{:s}/checkpoint2_{:d}.pth.tar'.format(cp_dir, epoch+1)) 354 | 355 | def save_bestcheckpoint1(state, save_dir): 356 | cp_dir = '{:s}/checkpoints'.format(save_dir) 357 | torch.save(state, '{:s}/checkpoint1_0.pth.tar'.format(cp_dir)) 358 | 359 | def save_bestcheckpoint2(state, save_dir): 360 | cp_dir = '{:s}/checkpoints'.format(save_dir) 361 | torch.save(state, '{:s}/checkpoint2_0.pth.tar'.format(cp_dir)) 362 | 363 | def setup_logging(opt): 364 | mode = 'a' if opt.train['checkpoint'] else 'w' 365 | 366 | # create logger for training information 367 | logger = logging.getLogger('train_logger') 368 | logger.setLevel(logging.DEBUG) 369 | # create console handler and file handler 370 | console_handler = RichHandler(show_level=False, show_time=False, show_path=False) 371 | console_handler.setLevel(logging.INFO) 372 | file_handler = logging.FileHandler('{:s}/train_log.txt'.format(opt.train['save_dir']), mode=mode) 373 | file_handler.setLevel(logging.DEBUG) 374 | # create formatter 375 | formatter = logging.Formatter('%(message)s') 376 | # add formatter to handlers 377 | 
console_handler.setFormatter(formatter) 378 | file_handler.setFormatter(formatter) 379 | # add handlers to logger 380 | logger.addHandler(console_handler) 381 | logger.addHandler(file_handler) 382 | 383 | # create logger for epoch results 384 | logger_results = logging.getLogger('results') 385 | logger_results.setLevel(logging.DEBUG) 386 | file_handler2 = logging.FileHandler('{:s}/epoch_results.txt'.format(opt.train['save_dir']), mode=mode) 387 | file_handler2.setFormatter(logging.Formatter('%(message)s')) 388 | logger_results.addHandler(file_handler2) 389 | 390 | logger.info('***** Training starts *****') 391 | logger.info('save directory: {:s}'.format(opt.train['save_dir'])) 392 | if mode == 'w': 393 | logger_results.info('epoch\ttrain_loss\ttrain_loss_vor\ttrain_loss_cluster\ttrain_loss_ct\ttrain_loss_self\ttest_dice1\ttest_dice2') 394 | 395 | return logger, logger_results 396 | 397 | 398 | if __name__ == '__main__': 399 | main() 400 | -------------------------------------------------------------------------------- /utils/accuracy.py: -------------------------------------------------------------------------------- 1 | 2 | from skimage.measure import label 3 | from sklearn.metrics import accuracy_score, roc_auc_score 4 | from sklearn.metrics import f1_score 5 | from sklearn.metrics import recall_score, precision_score 6 | from scipy.spatial.distance import directed_hausdorff as hausdorff 7 | from scipy.ndimage.measurements import center_of_mass 8 | import numpy as np 9 | from scipy.optimize import linear_sum_assignment 10 | 11 | def compute_metrics(pred, gt, names): 12 | """ 13 | Computes metrics specified by names between predicted label and groundtruth label. 14 | """ 15 | 16 | gt_labeled = label(gt) 17 | pred_labeled = label(pred) 18 | 19 | gt_binary = gt_labeled.copy() 20 | pred_binary = pred_labeled.copy() 21 | gt_binary[gt_binary > 0] = 1 22 | pred_binary[pred_binary > 0] = 1 23 | gt_binary, pred_binary = gt_binary.flatten(), pred_binary.flatten() 24 | 25 | results = {} 26 | 27 | # pixel-level metrics 28 | if 'acc' in names: 29 | results['acc'] = accuracy_score(gt_binary, pred_binary) 30 | if 'roc' in names: 31 | results['roc'] = roc_auc_score(gt_binary, pred_binary) 32 | if 'p_F1' in names: # pixel-level F1 33 | results['p_F1'] = f1_score(gt_binary, pred_binary) 34 | if 'p_recall' in names: # pixel-level F1 35 | results['p_recall'] = recall_score(gt_binary, pred_binary) 36 | if 'p_precision' in names: # pixel-level F1 37 | results['p_precision'] = precision_score(gt_binary, pred_binary) 38 | 39 | # object-level metrics 40 | if 'aji' in names: 41 | results['aji'] = AJI_fast(gt_labeled, pred_labeled) 42 | if 'haus' in names: 43 | results['dice'], results['iou'], results['haus'] = accuracy_object_level(pred_labeled, gt_labeled, True) 44 | if 'dice' in names or 'iou' in names: 45 | results['dice'], results['iou'], _ = accuracy_object_level(pred_labeled, gt_labeled, False) 46 | if 'pq' in names: 47 | results['dq'], results['sq'], results['pq'] = get_pq(gt_labeled, pred_labeled) 48 | 49 | return results 50 | 51 | 52 | def accuracy_object_level(pred, gt, hausdorff_flag=True): 53 | """ Compute the object-level metrics between predicted and 54 | groundtruth: dice, iou, hausdorff """ 55 | if not isinstance(pred, np.ndarray): 56 | pred = np.array(pred) 57 | if not isinstance(gt, np.ndarray): 58 | gt = np.array(gt) 59 | 60 | # get connected components 61 | pred_labeled = label(pred, connectivity=2) 62 | Ns = len(np.unique(pred_labeled)) - 1 63 | gt_labeled = label(gt, connectivity=2) 64 | Ng 
= len(np.unique(gt_labeled)) - 1 65 | 66 | # --- compute dice, iou, hausdorff --- # 67 | pred_objs_area = np.sum(pred_labeled>0) # total area of objects in image 68 | gt_objs_area = np.sum(gt_labeled>0) # total area of objects in groundtruth gt 69 | 70 | # compute how well groundtruth object overlaps its segmented object 71 | dice_g = 0.0 72 | iou_g = 0.0 73 | hausdorff_g = 0.0 74 | for i in range(1, Ng + 1): 75 | gt_i = np.where(gt_labeled == i, 1, 0) 76 | overlap_parts = gt_i * pred_labeled 77 | 78 | # get intersection objects numbers in image 79 | obj_no = np.unique(overlap_parts) 80 | obj_no = obj_no[obj_no != 0] 81 | 82 | gamma_i = float(np.sum(gt_i)) / gt_objs_area 83 | 84 | if obj_no.size == 0: # no intersection object 85 | dice_i = 0 86 | iou_i = 0 87 | 88 | # find nearest segmented object in hausdorff distance 89 | if hausdorff_flag: 90 | min_haus = 1e3 91 | 92 | # find overlap object in a window [-50, 50] 93 | pred_cand_indices = find_candidates(gt_i, pred_labeled) 94 | 95 | for j in pred_cand_indices: 96 | pred_j = np.where(pred_labeled == j, 1, 0) 97 | seg_ind = np.argwhere(pred_j) 98 | gt_ind = np.argwhere(gt_i) 99 | haus_tmp = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0]) 100 | 101 | if haus_tmp < min_haus: 102 | min_haus = haus_tmp 103 | haus_i = min_haus 104 | else: 105 | # find max overlap object 106 | obj_areas = [np.sum(overlap_parts == k) for k in obj_no] 107 | seg_obj = obj_no[np.argmax(obj_areas)] # segmented object number 108 | pred_i = np.where(pred_labeled == seg_obj, 1, 0) # segmented object 109 | 110 | overlap_area = np.max(obj_areas) # overlap area 111 | 112 | dice_i = 2 * float(overlap_area) / (np.sum(pred_i) + np.sum(gt_i)) 113 | iou_i = float(overlap_area) / (np.sum(pred_i) + np.sum(gt_i) - overlap_area) 114 | 115 | # compute hausdorff distance 116 | if hausdorff_flag: 117 | seg_ind = np.argwhere(pred_i) 118 | gt_ind = np.argwhere(gt_i) 119 | haus_i = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0]) 120 | 121 | dice_g += gamma_i * dice_i 122 | iou_g += gamma_i * iou_i 123 | if hausdorff_flag: 124 | hausdorff_g += gamma_i * haus_i 125 | 126 | # compute how well segmented object overlaps its groundtruth object 127 | dice_s = 0.0 128 | iou_s = 0.0 129 | hausdorff_s = 0.0 130 | for j in range(1, Ns + 1): 131 | pred_j = np.where(pred_labeled == j, 1, 0) 132 | overlap_parts = pred_j * gt_labeled 133 | 134 | # get intersection objects number in gt 135 | obj_no = np.unique(overlap_parts) 136 | obj_no = obj_no[obj_no != 0] 137 | 138 | # show_figures((pred_j, gt_labeled, overlap_parts)) 139 | 140 | sigma_j = float(np.sum(pred_j)) / pred_objs_area 141 | # no intersection object 142 | if obj_no.size == 0: 143 | dice_j = 0 144 | iou_j = 0 145 | 146 | # find nearest groundtruth object in hausdorff distance 147 | if hausdorff_flag: 148 | min_haus = 1e3 149 | 150 | # find overlap object in a window [-50, 50] 151 | gt_cand_indices = find_candidates(pred_j, gt_labeled) 152 | 153 | for i in gt_cand_indices: 154 | gt_i = np.where(gt_labeled == i, 1, 0) 155 | seg_ind = np.argwhere(pred_j) 156 | gt_ind = np.argwhere(gt_i) 157 | haus_tmp = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0]) 158 | 159 | if haus_tmp < min_haus: 160 | min_haus = haus_tmp 161 | haus_j = min_haus 162 | else: 163 | # find max overlap gt 164 | gt_areas = [np.sum(overlap_parts == k) for k in obj_no] 165 | gt_obj = obj_no[np.argmax(gt_areas)] # groundtruth object number 166 | gt_j = np.where(gt_labeled == gt_obj, 1, 0) # groundtruth object 167 | 168 | 
overlap_area = np.max(gt_areas) # overlap area 169 | 170 | dice_j = 2 * float(overlap_area) / (np.sum(pred_j) + np.sum(gt_j)) 171 | iou_j = float(overlap_area) / (np.sum(pred_j) + np.sum(gt_j) - overlap_area) 172 | 173 | # compute hausdorff distance 174 | if hausdorff_flag: 175 | seg_ind = np.argwhere(pred_j) 176 | gt_ind = np.argwhere(gt_j) 177 | haus_j = max(hausdorff(seg_ind, gt_ind)[0], hausdorff(gt_ind, seg_ind)[0]) 178 | 179 | dice_s += sigma_j * dice_j 180 | iou_s += sigma_j * iou_j 181 | if hausdorff_flag: 182 | hausdorff_s += sigma_j * haus_j 183 | 184 | return (dice_g + dice_s) / 2, (iou_g + iou_s) / 2, (hausdorff_g + hausdorff_s) / 2 185 | 186 | 187 | def find_candidates(obj_i, objects_labeled, radius=50): 188 | """ 189 | find object indices in objects_labeled in a window centered at obj_i 190 | when computing object-level hausdorff distance 191 | 192 | """ 193 | if radius > 400: 194 | return np.array([]) 195 | 196 | h, w = objects_labeled.shape 197 | x, y = center_of_mass(obj_i) 198 | x, y = int(x), int(y) 199 | r1 = x-radius if x-radius >= 0 else 0 200 | r2 = x+radius if x+radius <= h else h 201 | c1 = y-radius if y-radius >= 0 else 0 202 | c2 = y+radius if y+radius < w else w 203 | indices = np.unique(objects_labeled[r1:r2, c1:c2]) 204 | indices = indices[indices != 0] 205 | 206 | if indices.size == 0: 207 | indices = find_candidates(obj_i, objects_labeled, 2*radius) 208 | 209 | return indices 210 | 211 | 212 | def AJI_fast(gt, pred_arr): 213 | gs, g_areas = np.unique(gt, return_counts=True) 214 | assert np.all(gs == np.arange(len(gs))) 215 | ss, s_areas = np.unique(pred_arr, return_counts=True) 216 | assert np.all(ss == np.arange(len(ss))) 217 | 218 | i_idx, i_cnt = np.unique(np.concatenate([gt.reshape(1, -1), pred_arr.reshape(1, -1)]), 219 | return_counts=True, axis=1) 220 | i_arr = np.zeros(shape=(len(gs), len(ss)), dtype=np.int) 221 | i_arr[i_idx[0], i_idx[1]] += i_cnt 222 | u_arr = g_areas.reshape(-1, 1) + s_areas.reshape(1, -1) - i_arr 223 | iou_arr = 1.0 * i_arr / u_arr 224 | 225 | i_arr = i_arr[1:, 1:] 226 | u_arr = u_arr[1:, 1:] 227 | iou_arr = iou_arr[1:, 1:] 228 | 229 | j = np.argmax(iou_arr, axis=1) 230 | 231 | c = np.sum(i_arr[np.arange(len(gs) - 1), j]) 232 | u = np.sum(u_arr[np.arange(len(gs) - 1), j]) 233 | used = np.zeros(shape=(len(ss) - 1), dtype=np.int) 234 | used[j] = 1 235 | u += (np.sum(s_areas[1:] * (1 - used))) 236 | return 1.0 * c / u 237 | 238 | 239 | def remap_label(pred, by_size=False): 240 | """Rename all instance id so that the id is contiguous i.e [0, 1, 2, 3] 241 | not [0, 2, 4, 6]. The ordering of instances (which one comes first) 242 | is preserved unless by_size=True, then the instances will be reordered 243 | so that bigger nucler has smaller ID. 244 | 245 | Args: 246 | pred (ndarray): the 2d array contain instances where each instances is marked 247 | by non-zero integer. 248 | by_size (bool): renaming such that larger nuclei have a smaller id (on-top). 249 | 250 | Returns: 251 | new_pred (ndarray): Array with continguous ordering of instances. 
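            For instance (an illustrative sketch): an input whose non-zero ids are
            [2, 6, 6] is relabeled to [1, 2, 2]; with by_size=True the larger
            instance (id 6, covering two pixels) would receive id 1 instead.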
252 | 253 | """
254 | pred_id = list(np.unique(pred))
255 | pred_id.remove(0)
256 | if len(pred_id) == 0:
257 | return pred # no label
258 | if by_size:
259 | pred_size = []
260 | for inst_id in pred_id:
261 | size = (pred == inst_id).sum()
262 | pred_size.append(size)
263 | # sort the id by size in descending order
264 | pair_list = zip(pred_id, pred_size)
265 | pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
266 | pred_id, pred_size = zip(*pair_list)
267 |
268 | new_pred = np.zeros(pred.shape, np.int32)
269 | for idx, inst_id in enumerate(pred_id):
270 | new_pred[pred == inst_id] = idx + 1
271 | return new_pred
272 |
273 |
274 | def get_bounding_box(img):
275 | """Get the bounding box coordinates of a binary input; assumes a single object.
276 |
277 | Args:
278 | img: input binary image.
279 |
280 | Returns:
281 | bounding box coordinates
282 |
283 | """
284 | rows = np.any(img, axis=1)
285 | cols = np.any(img, axis=0)
286 | rmin, rmax = np.where(rows)[0][[0, -1]]
287 | cmin, cmax = np.where(cols)[0][[0, -1]]
288 | # due to python indexing, need to add 1 to max
289 | # else accessing will be 1px in the box, not out
290 | rmax += 1
291 | cmax += 1
292 | return [rmin, rmax, cmin, cmax]
293 |
294 |
295 | def get_pq(true, pred, match_iou=0.5, remap=True):
296 | """Get the panoptic quality result.
297 |
298 | Fast computation requires instance IDs to be in contiguous order, i.e. [1, 2, 3, 4],
299 | not [2, 3, 6, 10]. Please call `remap_label` beforehand. Here, the `by_size` flag
300 | has no effect on the result.
301 |
302 | Args:
303 | true (ndarray): HxW ground truth instance segmentation map
304 | pred (ndarray): HxW predicted instance segmentation map
305 | match_iou (float): IoU threshold level to determine the pairing between
306 | GT instances `p` and prediction instances `g`. `p` and `g` are a pair
307 | if IoU > `match_iou`. However, the pair of `p` and `g` must be unique
308 | (1 prediction instance to 1 GT instance mapping). If `match_iou` < 0.5,
309 | Munkres assignment (solving minimum weight matching in bipartite graphs)
310 | is calculated to find the maximal number of unique pairings. If
311 | `match_iou` >= 0.5, every pairing with IoU(p,g) > 0.5 is provably unique and
312 | the number of pairs is also maximal.
313 | remap (bool): whether to ensure contiguous ordering of instances.
314 | 315 | Returns: 316 | [dq, sq, pq]: measurement statistic 317 | 318 | [paired_true, paired_pred, unpaired_true, unpaired_pred]: 319 | pairing information to perform measurement 320 | 321 | paired_iou.sum(): sum of IoU within true positive predictions 322 | 323 | """ 324 | assert match_iou >= 0.0, "Cant' be negative" 325 | # ensure instance maps are contiguous 326 | if remap: 327 | pred = remap_label(pred) 328 | true = remap_label(true) 329 | 330 | true = np.copy(true) 331 | pred = np.copy(pred) 332 | true = true.astype("int32") 333 | pred = pred.astype("int32") 334 | true_id_list = list(np.unique(true)) 335 | pred_id_list = list(np.unique(pred)) 336 | # prefill with value 337 | pairwise_iou = np.zeros([len(true_id_list), len(pred_id_list)], dtype=np.float64) 338 | 339 | # caching pairwise iou 340 | for true_id in true_id_list[1:]: # 0-th is background 341 | t_mask_lab = true == true_id 342 | rmin1, rmax1, cmin1, cmax1 = get_bounding_box(t_mask_lab) 343 | t_mask_crop = t_mask_lab[rmin1:rmax1, cmin1:cmax1] 344 | t_mask_crop = t_mask_crop.astype("int") 345 | p_mask_crop = pred[rmin1:rmax1, cmin1:cmax1] 346 | pred_true_overlap = p_mask_crop[t_mask_crop > 0] 347 | pred_true_overlap_id = np.unique(pred_true_overlap) 348 | pred_true_overlap_id = list(pred_true_overlap_id) 349 | for pred_id in pred_true_overlap_id: 350 | if pred_id == 0: # ignore 351 | continue # overlaping background 352 | p_mask_lab = pred == pred_id 353 | p_mask_lab = p_mask_lab.astype("int") 354 | 355 | # crop region to speed up computation 356 | rmin2, rmax2, cmin2, cmax2 = get_bounding_box(p_mask_lab) 357 | rmin = min(rmin1, rmin2) 358 | rmax = max(rmax1, rmax2) 359 | cmin = min(cmin1, cmin2) 360 | cmax = max(cmax1, cmax2) 361 | t_mask_crop2 = t_mask_lab[rmin:rmax, cmin:cmax] 362 | p_mask_crop2 = p_mask_lab[rmin:rmax, cmin:cmax] 363 | 364 | total = (t_mask_crop2 + p_mask_crop2).sum() 365 | inter = (t_mask_crop2 * p_mask_crop2).sum() 366 | iou = inter / (total - inter) 367 | pairwise_iou[true_id - 1, pred_id - 1] = iou 368 | 369 | if match_iou >= 0.5: 370 | paired_iou = pairwise_iou[pairwise_iou > match_iou] 371 | pairwise_iou[pairwise_iou <= match_iou] = 0.0 372 | paired_true, paired_pred = np.nonzero(pairwise_iou) 373 | paired_iou = pairwise_iou[paired_true, paired_pred] 374 | paired_true += 1 # index is instance id - 1 375 | paired_pred += 1 # hence return back to original 376 | else: # * Exhaustive maximal unique pairing 377 | #### Munkres pairing with scipy library 378 | # the algorithm return (row indices, matched column indices) 379 | # if there is multiple same cost in a row, index of first occurence 380 | # is return, thus the unique pairing is ensure 381 | # inverse pair to get high IoU as minimum 382 | paired_true, paired_pred = linear_sum_assignment(-pairwise_iou) 383 | ### extract the paired cost and remove invalid pair 384 | paired_iou = pairwise_iou[paired_true, paired_pred] 385 | 386 | # now select those above threshold level 387 | # paired with iou = 0.0 i.e no intersection => FP or FN 388 | paired_true = list(paired_true[paired_iou > match_iou] + 1) 389 | paired_pred = list(paired_pred[paired_iou > match_iou] + 1) 390 | paired_iou = paired_iou[paired_iou > match_iou] 391 | 392 | # get the actual FP and FN 393 | unpaired_true = [idx for idx in true_id_list[1:] if idx not in paired_true] 394 | unpaired_pred = [idx for idx in pred_id_list[1:] if idx not in paired_pred] 395 | # print(paired_iou.shape, paired_true.shape, len(unpaired_true), len(unpaired_pred)) 396 | 397 | # 398 | tp = len(paired_true) 399 | 
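# Clarifying note (added; not in the original source): the counts below feed the
# standard panoptic-quality formulas, which the next few lines implement:
#   DQ = TP / (TP + 0.5*FP + 0.5*FN)   (detection quality, an instance-level F1 score)
#   SQ = sum of matched IoUs / TP      (segmentation quality, mean IoU over true positives)
#   PQ = DQ * SQ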
fp = len(unpaired_pred)
400 | fn = len(unpaired_true)
401 | # get the F1-score, i.e. DQ
402 | dq = tp / ((tp + 0.5 * fp + 0.5 * fn) + 1.0e-6)
403 | # get the SQ; no paired instance has zero IoU, so the mean is unaffected
404 | sq = paired_iou.sum() / (tp + 1.0e-6)
405 |
406 | # return (
407 | # [dq, sq, dq * sq],
408 | # [tp, fp, fn],
409 | # paired_iou.sum(),
410 | return dq, sq, dq * sq
-------------------------------------------------------------------------------- /utils/combine.py: --------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch
3 |
4 |
5 | def Combine(prediction, target):
6 | predict = prediction
7 | pseudos = torch.zeros([prediction.shape[0], 2, prediction.shape[2], prediction.shape[3]]).cuda()
8 | for i in range(target.shape[0]):
9 | pseudo_label = predict[i].squeeze(0)
10 | pseudo_label[target[i] == 1] = 1.0 # nuclei
11 | pseudo_label[target[i] == 0] = 0.0 # background
12 | pseudo_0 = torch.ones((target[i].shape[0], target[i].shape[1])).cuda() - pseudo_label
13 | pseudo1 = pseudo_label.unsqueeze(0)
14 | pseudo0 = pseudo_0.unsqueeze(0)
15 | pseudo = torch.cat([pseudo0, pseudo1.float()], dim=0)
16 | pseudos[i] = pseudo
17 | return pseudos
-------------------------------------------------------------------------------- /utils/divergence.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | def loss_CT(M1, M2):
5 | criterion = torch.nn.KLDivLoss(reduction='mean').cuda()
6 | loss = criterion(F.log_softmax(M1, dim=1), M2)
7 | return loss
8 |
9 | def Epsilon(M1, M2):
10 | results = 0
11 | M1_1 = F.softmax(M1, dim=1)
12 | for i in range(M2.shape[0]):
13 | m1 = M1_1[i] - M2[i]
14 | m2 = M1_1[i] + M2[i]
15 | result = 2 * (torch.abs(m1)).sum() / m2.sum()
16 | result = result.item()
17 | results += result
18 | return results
-------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import random
4 | import torch
5 | from scipy.spatial import Voronoi
6 | from skimage import draw
7 |
8 |
9 | def poly2mask(vertex_row_coords, vertex_col_coords, shape):
10 | fill_row_coords, fill_col_coords = draw.polygon(vertex_row_coords, vertex_col_coords, shape)
11 | mask = np.zeros(shape, dtype=bool)
12 | mask[fill_row_coords, fill_col_coords] = True
13 | return mask
14 |
15 |
16 | # borrowed from https://gist.github.com/pv/8036995
17 | def voronoi_finite_polygons_2d(vor, radius=None):
18 | """
19 | Reconstruct infinite Voronoi regions in a 2D diagram as finite
20 | regions.
21 |
22 | Parameters
23 | ----------
24 | vor : Voronoi
25 | Input diagram
26 | radius : float, optional
27 | Distance to 'points at infinity'.
28 |
29 | Returns
30 | -------
31 | regions : list of tuples
32 | Indices of vertices in each revised Voronoi region.
33 | vertices : list of tuples
34 | Coordinates for revised Voronoi vertices. Same as coordinates
35 | of input vertices, with 'points at infinity' appended to the
36 | end.
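Example (illustrative usage, not part of the original docstring):
    vor = Voronoi(points)  # points: (N, 2) array of coordinates, e.g. nuclei centroids
    regions, vertices = voronoi_finite_polygons_2d(vor)
    cell = vertices[regions[0]]  # vertex coordinates of the Voronoi cell of the first point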
37 | 38 | """ 39 | 40 | if vor.points.shape[1] != 2: 41 | raise ValueError("Requires 2D input") 42 | 43 | new_regions = [] 44 | new_vertices = vor.vertices.tolist() 45 | 46 | center = vor.points.mean(axis=0) 47 | if radius is None: 48 | radius = vor.points.ptp().max() 49 | 50 | # Construct a map containing all ridges for a given point 51 | all_ridges = {} 52 | for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices): 53 | all_ridges.setdefault(p1, []).append((p2, v1, v2)) 54 | all_ridges.setdefault(p2, []).append((p1, v1, v2)) 55 | 56 | # Reconstruct infinite regions 57 | for p1, region in enumerate(vor.point_region): 58 | vertices = vor.regions[region] 59 | 60 | if all(v >= 0 for v in vertices): 61 | # finite region 62 | new_regions.append(vertices) 63 | continue 64 | 65 | # reconstruct a non-finite region 66 | ridges = all_ridges[p1] 67 | new_region = [v for v in vertices if v >= 0] 68 | 69 | for p2, v1, v2 in ridges: 70 | if v2 < 0: 71 | v1, v2 = v2, v1 72 | if v1 >= 0: 73 | # finite ridge: already in the region 74 | continue 75 | 76 | # Compute the missing endpoint of an infinite ridge 77 | 78 | t = vor.points[p2] - vor.points[p1] # tangent 79 | t /= np.linalg.norm(t) 80 | n = np.array([-t[1], t[0]]) # normal 81 | 82 | midpoint = vor.points[[p1, p2]].mean(axis=0) 83 | direction = np.sign(np.dot(midpoint - center, n)) * n 84 | far_point = vor.vertices[v2] + direction * radius 85 | 86 | new_region.append(len(new_vertices)) 87 | new_vertices.append(far_point.tolist()) 88 | 89 | # sort region counterclockwise 90 | vs = np.asarray([new_vertices[v] for v in new_region]) 91 | c = vs.mean(axis=0) 92 | angles = np.arctan2(vs[:,1] - c[1], vs[:,0] - c[0]) 93 | new_region = np.array(new_region)[np.argsort(angles)] 94 | 95 | # finish 96 | new_regions.append(new_region.tolist()) 97 | 98 | return new_regions, np.asarray(new_vertices) 99 | 100 | 101 | def split_forward(model, input, size, overlap, outchannel = 2): 102 | ''' 103 | split the input image for forward passes 104 | ''' 105 | 106 | b, c, h0, w0 = input.size() 107 | 108 | # zero pad for border patches 109 | pad_h = 0 110 | if h0 - size > 0 and (h0 - size) % (size - overlap) > 0: 111 | pad_h = (size - overlap) - (h0 - size) % (size - overlap) 112 | tmp = torch.zeros((b, c, pad_h, w0)) 113 | input = torch.cat((input, tmp), dim=2) 114 | 115 | if w0 - size > 0 and (w0 - size) % (size - overlap) > 0: 116 | pad_w = (size - overlap) - (w0 - size) % (size - overlap) 117 | tmp = torch.zeros((b, c, h0 + pad_h, pad_w)) 118 | input = torch.cat((input, tmp), dim=3) 119 | 120 | _, c, h, w = input.size() 121 | 122 | output = torch.zeros((input.size(0), outchannel, h, w)) 123 | for i in range(0, h-overlap, size-overlap): 124 | r_end = i + size if i + size < h else h 125 | ind1_s = i + overlap // 2 if i > 0 else 0 126 | ind1_e = i + size - overlap // 2 if i + size < h else h 127 | for j in range(0, w-overlap, size-overlap): 128 | c_end = j+size if j+size < w else w 129 | 130 | input_patch = input[:,:,i:r_end,j:c_end] 131 | input_var = input_patch.cuda() 132 | with torch.no_grad(): 133 | output_patch, _ = model(input_var) 134 | #_, output_patch = model(input_var) 135 | 136 | ind2_s = j+overlap//2 if j>0 else 0 137 | ind2_e = j+size-overlap//2 if j+size