├── LICENSE
├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── custom_transforms.py
│   └── fundus_dataloader.py
├── mypath.py
├── networks
│   ├── GAN.py
│   ├── __init__.py
│   ├── aspp.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── drn.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── xception.py
│   ├── decoder.py
│   ├── deeplabv3.py
│   └── sync_batchnorm
│       ├── batchnorm.py
│       └── comm.py
├── test.py
├── train.py
├── train_process
│   ├── Trainer.py
│   └── __init__.py
└── utils
    ├── Utils.py
    ├── __init__.py
    └── metrics.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Shujun WANG
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-BEAL
2 | 
3 | Code for the paper 'Boundary and Entropy-driven Adversarial
4 | Learning for Fundus Image Segmentation', early accepted by MICCAI 2019.
5 | 
6 | ### Introduction
7 | This is a PyTorch (1.0.1.post2) implementation of [BEAL](https://github.com/emma-sjwang/BEAL).
8 | The code was tested with Anaconda and Python 3.7.1.
9 | ```Shell
10 | conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
11 | ```
12 | 
13 | ### Installation
14 | 
15 | After installing PyTorch (see above), install the remaining dependencies:
16 | ```Shell
17 | pip install pyyaml
18 | pip install pytz
19 | pip install tensorboardX==1.4 matplotlib pillow
20 | pip install tqdm
21 | conda install scipy==1.1.0
22 | conda install -c conda-forge opencv
23 | ```
24 | 
25 | 0. Clone the repo:
26 | ```Shell
27 | git clone https://github.com/emma-sjwang/BEAL.git
28 | cd BEAL
29 | ```
30 | 
31 | 1. Install dependencies:
32 | For the PyTorch dependency, see [pytorch.org](https://pytorch.org/) for more details.
33 | 
34 | For the custom dependencies:
35 | ```Shell
36 | # use the pip and conda commands listed at the top of this section
37 | ```
38 | 
39 | 2. Configure your dataset path in [train.py](https://github.com/emma-sjwang/BEAL/blob/master/train.py) via the '--data-dir' parameter.
40 | Dataset download links:
41 | [DGS](http://cvit.iiit.ac.in/projects/mip/drishti-gs/mip-dataset2/enter.php)
42 | [RIM-ONE](http://medimrg.webs.ull.es/research/downloads/)
43 | [Refuge](https://refuge.grand-challenge.org)
44 | 
45 | Alternatively, you can download the already preprocessed data from this [link](https://drive.google.com/file/d/1B7ArHRBjt2Dx29a3A6X_lGhD0vDVr3sy/view?usp=sharing).
46 | 
47 | 
48 | 3. You can train DeepLabv3+ with a MobileNetV2 backbone or one of the other supported backbones (see the sketch below).
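For reference, here is a minimal sketch (not part of this repo) of how a backbone is constructed through the factory in networks/backbone/__init__.py; only `build_backbone`, its arguments, and the `(features, low_level_feat)` return convention are taken from the code below, the rest is illustrative:
```python
import torch
import torch.nn as nn

from networks.backbone import build_backbone

# 'mobilenet' selects MobileNetV2(output_stride, BatchNorm); since pretrained
# defaults to True, ImageNet weights are downloaded on first use.
backbone = build_backbone('mobilenet', output_stride=16, BatchNorm=nn.BatchNorm2d)

x = torch.rand(1, 3, 512, 512)
feats, low_level_feats = backbone(x)  # every backbone returns both feature maps
print(feats.size(), low_level_feats.size())
```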
49 | 
50 | To train it, please do:
51 | ```Shell
52 | python train.py -g 0 --data-dir /data/ssd/public/sjwang/fundus_data/domain_adaptation --batch-size 8 --datasetT RIM-ONE_r3
53 | ```
54 | To test it, first download the weights from this [link](https://drive.google.com/open?id=1ZPLX937VT31KOZLtIOZjc2IBpnZIYawU)
55 | and put them into the log folder, then run:
56 | ```Shell
57 | python test.py --model-file ./logs/DGS_weights.tar --dataset Drishti-GS
58 | ```
59 | 
60 | 
61 | ### Citation
62 | ```
63 | @inproceedings{wang2019boundary,
64 |   title={Boundary and Entropy-driven Adversarial Learning for Fundus Image Segmentation},
65 |   author={Wang, Shujun and Yu, Lequan and Li, Kang and Yang, Xin and Fu, Chi-Wing and Heng, Pheng-Ann},
66 |   booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
67 |   pages={102--110},
68 |   year={2019},
69 |   organization={Springer}
70 | }
71 | ```
72 | 
73 | 
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emma-sjwang/BEAL/945cad38a354605b8bca5bc01ae1b65848d605e1/dataloaders/__init__.py
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numbers
4 | import random
5 | import numpy as np
6 | 
7 | from PIL import Image, ImageOps
8 | from scipy.ndimage.filters import gaussian_filter
9 | from matplotlib.pyplot import imshow, imsave
10 | from scipy.ndimage.interpolation import map_coordinates
11 | import cv2
12 | from scipy import ndimage
13 | 
14 | 
15 | def to_multilabel(pre_mask, classes=2):
16 |     mask = np.zeros((pre_mask.shape[0], pre_mask.shape[1], classes))
17 |     mask[pre_mask == 1] = [0, 1]  # class 1 (optic disc): disc channel only
18 |     mask[pre_mask == 2] = [1, 1]  # class 2 (optic cup): the cup lies inside the disc, so both channels
19 |     return mask
20 | 
21 | 
22 | class add_salt_pepper_noise():
23 |     def __call__(self, sample):
24 |         image = sample['image']
25 |         X_imgs_copy = image.copy()
26 |         # row = image.shape[0]
27 |         # col = image.shape[1]
28 |         salt_vs_pepper = 0.2
29 |         amount = 0.004
30 | 
31 |         num_salt = np.ceil(amount * X_imgs_copy.size * salt_vs_pepper)
32 |         num_pepper = np.ceil(amount * X_imgs_copy.size * (1.0 - salt_vs_pepper))
33 | 
34 |         seed = random.random()
35 |         if seed > 0.75:
36 |             # Add Salt noise
37 |             coords = [np.random.randint(0, i - 1, int(num_salt)) for i in X_imgs_copy.shape]
38 |             X_imgs_copy[coords[0], coords[1], :] = 1
39 |         elif seed > 0.5:
40 |             # Add Pepper noise
41 |             coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in X_imgs_copy.shape]
42 |             X_imgs_copy[coords[0], coords[1], :] = 0
43 | 
44 |         return {'image': X_imgs_copy,
45 |                 'label': sample['label'],
46 |                 'img_name': sample['img_name']}
47 | 
48 | class adjust_light():
49 |     def __call__(self, sample):
50 |         image = sample['image']
51 |         seed = random.random()
52 |         if seed > 0.5:
53 |             gamma = random.random() * 3 + 0.5  # random gamma in [0.5, 3.5)
54 |             invGamma = 1.0 / gamma
55 |             table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype(np.uint8)
56 |             image = cv2.LUT(np.array(image).astype(np.uint8), table).astype(np.uint8)
57 |             return {'image': image,
58 |                     'label': sample['label'],
59 |                     'img_name': sample['img_name']}
60 |         else:
61 |             return sample
62 | 
63 | 
64 | class eraser():
65 |     def __call__(self, sample, s_l=0.02, s_h=0.06, r_1=0.3, r_2=0.6, v_l=0, v_h=255, pixel_level=False):
66 |         image = 
sample['image'] 67 | img_h, img_w, img_c = image.shape 68 | 69 | 70 | if random.random() > 0.5: 71 | return sample 72 | 73 | while True: 74 | s = np.random.uniform(s_l, s_h) * img_h * img_w 75 | r = np.random.uniform(r_1, r_2) 76 | w = int(np.sqrt(s / r)) 77 | h = int(np.sqrt(s * r)) 78 | left = np.random.randint(0, img_w) 79 | top = np.random.randint(0, img_h) 80 | 81 | if left + w <= img_w and top + h <= img_h: 82 | break 83 | 84 | if pixel_level: 85 | c = np.random.uniform(v_l, v_h, (h, w, img_c)) 86 | else: 87 | c = np.random.uniform(v_l, v_h) 88 | 89 | image[top:top + h, left:left + w, :] = c 90 | 91 | return {'image': image, 92 | 'label': sample['label'], 93 | 'img_name': sample['img_name']} 94 | 95 | class elastic_transform(): 96 | """Elastic deformation of images as described in [Simard2003]_. 97 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 98 | Convolutional Neural Networks applied to Visual Document Analysis", in 99 | Proc. of the International Conference on Document Analysis and 100 | Recognition, 2003. 101 | """ 102 | 103 | # def __init__(self): 104 | 105 | def __call__(self, sample): 106 | image, label = sample['image'], sample['label'] 107 | alpha = image.size[1] * 2 108 | sigma = image.size[1] * 0.08 109 | random_state = None 110 | seed = random.random() 111 | if seed > 0.5: 112 | # print(image.size) 113 | assert len(image.size) == 2 114 | 115 | if random_state is None: 116 | random_state = np.random.RandomState(None) 117 | 118 | shape = image.size[0:2] 119 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 120 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 121 | 122 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 123 | indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)) 124 | 125 | transformed_image = np.zeros([image.size[0], image.size[1], 3]) 126 | transformed_label = np.zeros([image.size[0], image.size[1]]) 127 | 128 | for i in range(3): 129 | # print(i) 130 | transformed_image[:, :, i] = map_coordinates(np.array(image)[:, :, i], indices, order=1).reshape(shape) 131 | # break 132 | if label is not None: 133 | transformed_label[:, :] = map_coordinates(np.array(label)[:, :], indices, order=1, mode='nearest').reshape(shape) 134 | else: 135 | transformed_label = None 136 | transformed_image = transformed_image.astype(np.uint8) 137 | 138 | if label is not None: 139 | transformed_label = transformed_label.astype(np.uint8) 140 | 141 | return {'image': transformed_image, 142 | 'label': transformed_label, 143 | 'img_name': sample['img_name']} 144 | else: 145 | return {'image': np.array(sample['image']), 146 | 'label': np.array(sample['label']), 147 | 'img_name': sample['img_name']} 148 | 149 | 150 | 151 | 152 | class RandomCrop(object): 153 | def __init__(self, size, padding=0): 154 | if isinstance(size, numbers.Number): 155 | self.size = (int(size), int(size)) 156 | else: 157 | self.size = size # h, w 158 | self.padding = padding 159 | 160 | def __call__(self, sample): 161 | img, mask = sample['image'], sample['label'] 162 | w, h = img.size 163 | if self.padding > 0 or w < self.size[0] or h < self.size[1]: 164 | padding = np.maximum(self.padding,np.maximum((self.size[0]-w)//2+5,(self.size[1]-h)//2+5)) 165 | img = ImageOps.expand(img, border=padding, fill=0) 166 | mask = ImageOps.expand(mask, border=padding, fill=255) 167 | 168 | assert img.width == mask.width 169 | assert img.height == mask.height 170 | w, h = 
img.size 171 | th, tw = self.size # target size 172 | if w == tw and h == th: 173 | return {'image': img, 174 | 'label': mask, 175 | 'img_name': sample['img_name']} 176 | x1 = random.randint(0, w - tw) 177 | y1 = random.randint(0, h - th) 178 | img = img.crop((x1, y1, x1 + tw, y1 + th)) 179 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 180 | return {'image': img, 181 | 'label': mask, 182 | 'img_name': sample['img_name']} 183 | 184 | 185 | class CenterCrop(object): 186 | def __init__(self, size): 187 | if isinstance(size, numbers.Number): 188 | self.size = (int(size), int(size)) 189 | else: 190 | self.size = size 191 | 192 | def __call__(self, sample): 193 | img = sample['image'] 194 | mask = sample['label'] 195 | 196 | w, h = img.size 197 | th, tw = self.size 198 | x1 = int(round((w - tw) / 2.)) 199 | y1 = int(round((h - th) / 2.)) 200 | img = img.crop((x1, y1, x1 + tw, y1 + th)) 201 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 202 | 203 | return {'image': img, 204 | 'label': mask, 205 | 'img_name': sample['img_name']} 206 | 207 | 208 | class RandomFlip(object): 209 | def __call__(self, sample): 210 | img = sample['image'] 211 | mask = sample['label'] 212 | name = sample['img_name'] 213 | if random.random() < 0.5: 214 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 215 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 216 | if random.random() < 0.5: 217 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 218 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 219 | 220 | return {'image': img, 221 | 'label': mask, 222 | 'img_name': name 223 | } 224 | 225 | 226 | class FixedResize(object): 227 | def __init__(self, size): 228 | self.size = tuple(reversed(size)) # size: (h, w) 229 | 230 | def __call__(self, sample): 231 | img = sample['image'] 232 | mask = sample['label'] 233 | name = sample['img_name'] 234 | 235 | assert img.width == mask.width 236 | assert img.height == mask.height 237 | img = img.resize(self.size, Image.BILINEAR) 238 | mask = mask.resize(self.size, Image.NEAREST) 239 | 240 | return {'image': img, 241 | 'label': mask, 242 | 'img_name': name} 243 | 244 | 245 | class Scale(object): 246 | def __init__(self, size): 247 | if isinstance(size, numbers.Number): 248 | self.size = (int(size), int(size)) 249 | else: 250 | self.size = size 251 | 252 | def __call__(self, sample): 253 | img = sample['image'] 254 | mask = sample['label'] 255 | assert img.width == mask.width 256 | assert img.height == mask.height 257 | w, h = img.size 258 | 259 | if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]): 260 | return {'image': img, 261 | 'label': mask, 262 | 'img_name': sample['img_name']} 263 | oh, ow = self.size 264 | img = img.resize((ow, oh), Image.BILINEAR) 265 | mask = mask.resize((ow, oh), Image.NEAREST) 266 | 267 | return {'image': img, 268 | 'label': mask, 269 | 'img_name': sample['img_name']} 270 | 271 | 272 | class RandomSizedCrop(object): 273 | def __init__(self, size): 274 | self.size = size 275 | 276 | def __call__(self, sample): 277 | img = sample['image'] 278 | mask = sample['label'] 279 | name = sample['img_name'] 280 | assert img.width == mask.width 281 | assert img.height == mask.height 282 | for attempt in range(10): 283 | area = img.size[0] * img.size[1] 284 | target_area = random.uniform(0.45, 1.0) * area 285 | aspect_ratio = random.uniform(0.5, 2) 286 | 287 | w = int(round(math.sqrt(target_area * aspect_ratio))) 288 | h = int(round(math.sqrt(target_area / aspect_ratio))) 289 | 290 | if random.random() < 0.5: 291 | w, h = h, w 292 | 293 | if w <= img.size[0] and h <= 
img.size[1]: 294 | x1 = random.randint(0, img.size[0] - w) 295 | y1 = random.randint(0, img.size[1] - h) 296 | 297 | img = img.crop((x1, y1, x1 + w, y1 + h)) 298 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 299 | assert (img.size == (w, h)) 300 | 301 | img = img.resize((self.size, self.size), Image.BILINEAR) 302 | mask = mask.resize((self.size, self.size), Image.NEAREST) 303 | 304 | return {'image': img, 305 | 'label': mask, 306 | 'img_name': name} 307 | 308 | # Fallback 309 | scale = Scale(self.size) 310 | crop = CenterCrop(self.size) 311 | sample = crop(scale(sample)) 312 | return sample 313 | 314 | 315 | class RandomRotate(object): 316 | def __init__(self, size=512): 317 | self.degree = random.randint(1, 4) * 90 318 | self.size = size 319 | 320 | def __call__(self, sample): 321 | img = sample['image'] 322 | mask = sample['label'] 323 | 324 | seed = random.random() 325 | if seed > 0.5: 326 | rotate_degree = self.degree 327 | img = img.rotate(rotate_degree, Image.BILINEAR, expand=0) 328 | mask = mask.rotate(rotate_degree, Image.NEAREST, expand=255) 329 | 330 | sample = {'image': img, 'label': mask, 'img_name': sample['img_name']} 331 | return sample 332 | 333 | 334 | class RandomScaleCrop(object): 335 | def __init__(self, size): 336 | self.size = size 337 | self.crop = RandomCrop(self.size) 338 | 339 | def __call__(self, sample): 340 | img = sample['image'] 341 | mask = sample['label'] 342 | name = sample['img_name'] 343 | # print(img.size) 344 | assert img.width == mask.width 345 | assert img.height == mask.height 346 | 347 | seed = random.random() 348 | if seed > 0.5: 349 | w = int(random.uniform(0.5, 1.5) * img.size[0]) 350 | h = int(random.uniform(0.5, 1.5) * img.size[1]) 351 | 352 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST) 353 | sample = {'image': img, 'label': mask, 'img_name': name} 354 | 355 | return self.crop(sample) 356 | 357 | 358 | class ResizeImg(object): 359 | def __init__(self, size): 360 | self.size = size 361 | 362 | def __call__(self, sample): 363 | img = sample['image'] 364 | mask = sample['label'] 365 | name = sample['img_name'] 366 | assert img.width == mask.width 367 | assert img.height == mask.height 368 | 369 | img = img.resize((self.size, self.size)) 370 | 371 | sample = {'image': img, 'label': mask, 'img_name': name} 372 | return sample 373 | 374 | 375 | class Resize(object): 376 | def __init__(self, size): 377 | self.size = size 378 | 379 | def __call__(self, sample): 380 | img = sample['image'] 381 | mask = sample['label'] 382 | name = sample['img_name'] 383 | assert img.width == mask.width 384 | assert img.height == mask.height 385 | 386 | img = img.resize((self.size, self.size)) 387 | mask = mask.resize((self.size, self.size)) 388 | 389 | sample = {'image': img, 'label': mask, 'img_name': name} 390 | return sample 391 | 392 | class Normalize(object): 393 | """Normalize a tensor image with mean and standard deviation. 394 | Args: 395 | mean (tuple): means for each channel. 396 | std (tuple): standard deviations for each channel. 
397 | """ 398 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 399 | self.mean = mean 400 | self.std = std 401 | 402 | def __call__(self, sample): 403 | img = np.array(sample['image']).astype(np.float32) 404 | mask = np.array(sample['label']).astype(np.float32) 405 | img /= 255.0 406 | img -= self.mean 407 | img /= self.std 408 | 409 | return {'image': img, 410 | 'label': mask, 411 | 'img_name': sample['img_name']} 412 | 413 | 414 | class GetBoundary(object): 415 | def __init__(self, width = 5): 416 | self.width = width 417 | def __call__(self, mask): 418 | cup = mask[:, :, 0] 419 | disc = mask[:, :, 1] 420 | dila_cup = ndimage.binary_dilation(cup, iterations=self.width).astype(cup.dtype) 421 | eros_cup = ndimage.binary_erosion(cup, iterations=self.width).astype(cup.dtype) 422 | dila_disc= ndimage.binary_dilation(disc, iterations=self.width).astype(disc.dtype) 423 | eros_disc= ndimage.binary_erosion(disc, iterations=self.width).astype(disc.dtype) 424 | cup = dila_cup + eros_cup 425 | disc = dila_disc + eros_disc 426 | cup[cup==2]=0 427 | disc[disc==2]=0 428 | boundary = (cup + disc) > 0 429 | return boundary.astype(np.uint8) 430 | 431 | 432 | class Normalize_tf(object): 433 | """Normalize a tensor image with mean and standard deviation. 434 | Args: 435 | mean (tuple): means for each channel. 436 | std (tuple): standard deviations for each channel. 437 | """ 438 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 439 | self.mean = mean 440 | self.std = std 441 | self.get_boundary = GetBoundary() 442 | 443 | def __call__(self, sample): 444 | img = np.array(sample['image']).astype(np.float32) 445 | __mask = np.array(sample['label']).astype(np.uint8) 446 | name = sample['img_name'] 447 | img /= 127.5 448 | img -= 1.0 449 | _mask = np.zeros([__mask.shape[0], __mask.shape[1]]) 450 | _mask[__mask > 200] = 255 451 | _mask[(__mask > 50) & (__mask < 201)] = 128 452 | 453 | __mask[_mask == 0] = 2 454 | __mask[_mask == 255] = 0 455 | __mask[_mask == 128] = 1 456 | 457 | mask = to_multilabel(__mask) 458 | boundary = self.get_boundary(mask) * 255 459 | boundary = ndimage.gaussian_filter(boundary, sigma=3) / 255.0 460 | boundary = np.expand_dims(boundary, -1) 461 | 462 | return {'image': img, 463 | 'map': mask, 464 | 'boundary': boundary, 465 | 'img_name': name 466 | } 467 | 468 | 469 | class Normalize_cityscapes(object): 470 | """Normalize a tensor image with mean and standard deviation. 471 | Args: 472 | mean (tuple): means for each channel. 473 | std (tuple): standard deviations for each channel. 
474 | """ 475 | def __init__(self, mean=(0., 0., 0.)): 476 | self.mean = mean 477 | 478 | def __call__(self, sample): 479 | img = np.array(sample['image']).astype(np.float32) 480 | mask = np.array(sample['label']).astype(np.float32) 481 | img -= self.mean 482 | img /= 255.0 483 | 484 | return {'image': img, 485 | 'label': mask, 486 | 'img_name': sample['img_name']} 487 | 488 | 489 | class ToTensor(object): 490 | """Convert ndarrays in sample to Tensors.""" 491 | 492 | def __call__(self, sample): 493 | # swap color axis because 494 | # numpy image: H x W x C 495 | # torch image: C X H X W 496 | img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1)) 497 | map = np.array(sample['map']).astype(np.uint8).transpose((2, 0, 1)) 498 | boundary = np.array(sample['boundary']).astype(np.float).transpose((2, 0, 1)) 499 | name = sample['img_name'] 500 | img = torch.from_numpy(img).float() 501 | map = torch.from_numpy(map).float() 502 | boundary = torch.from_numpy(boundary).float() 503 | 504 | return {'image': img, 505 | 'map': map, 506 | 'boundary': boundary, 507 | 'img_name': name} -------------------------------------------------------------------------------- /dataloaders/fundus_dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from mypath import Path 7 | from glob import glob 8 | import random 9 | 10 | 11 | class FundusSegmentation(Dataset): 12 | """ 13 | Fundus segmentation dataset 14 | including 5 domain dataset 15 | one for test others for training 16 | """ 17 | 18 | def __init__(self, 19 | base_dir=Path.db_root_dir('fundus'), 20 | dataset='refuge', 21 | split='train', 22 | testid=None, 23 | transform=None 24 | ): 25 | """ 26 | :param base_dir: path to VOC dataset directory 27 | :param split: train/val 28 | :param transform: transform to apply 29 | """ 30 | # super().__init__() 31 | self._base_dir = base_dir 32 | self.image_list = [] 33 | self.split = split 34 | 35 | self.image_pool = [] 36 | self.label_pool = [] 37 | self.img_name_pool = [] 38 | SEED = 1212 39 | random.seed(SEED) 40 | 41 | self._image_dir = os.path.join(self._base_dir, dataset, split, 'image') 42 | print(self._image_dir) 43 | imagelist = glob(self._image_dir + "/*.png") 44 | for image_path in imagelist: 45 | gt_path = image_path.replace('image', 'mask') 46 | self.image_list.append({'image': image_path, 'label': gt_path, 'id': testid}) 47 | 48 | self.transform = transform 49 | self._read_img_into_memory() 50 | # Display stats 51 | print('Number of images in {}: {:d}'.format(split, len(self.image_list))) 52 | 53 | def __len__(self): 54 | return len(self.image_list) 55 | 56 | def __getitem__(self, index): 57 | _img = self.image_pool[index] 58 | _target = self.label_pool[index] 59 | _img_name = self.img_name_pool[index] 60 | anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name} 61 | 62 | if self.transform is not None: 63 | anco_sample = self.transform(anco_sample) 64 | 65 | return anco_sample 66 | 67 | def _read_img_into_memory(self): 68 | 69 | img_num = len(self.image_list) 70 | for index in range(img_num): 71 | self.image_pool.append(Image.open(self.image_list[index]['image']).convert('RGB')) 72 | _target = Image.open(self.image_list[index]['label']) 73 | if _target.mode is 'RGB': 74 | _target = _target.convert('L') 75 | self.label_pool.append(_target) 76 | _img_name = 
self.image_list[index]['image'].split('/')[-1] 77 | self.img_name_pool.append(_img_name) 78 | 79 | 80 | def __str__(self): 81 | return 'Fundus(split=' + str(self.split) + ')' 82 | 83 | 84 | -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | class Path(object): 2 | @staticmethod 3 | def db_root_dir(database): 4 | if database == 'fundus': 5 | return '../../../../data/disc_cup_split/' # foler that contains leftImg8bit/ 6 | else: 7 | print('Database {} not available.'.format(database)) 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /networks/GAN.py: -------------------------------------------------------------------------------- 1 | # camera-ready 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, ): 10 | super(Discriminator, self).__init__() 11 | 12 | filter_num_list = [4096, 2048, 1024, 1] 13 | 14 | self.fc1 = nn.Linear(24576, filter_num_list[0]) 15 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 16 | self.fc2 = nn.Linear(filter_num_list[0], filter_num_list[1]) 17 | self.fc3 = nn.Linear(filter_num_list[1], filter_num_list[2]) 18 | self.fc4 = nn.Linear(filter_num_list[2], filter_num_list[3]) 19 | 20 | # self.sigmoid = nn.Sigmoid() 21 | self._initialize_weights() 22 | 23 | 24 | def _initialize_weights(self): 25 | 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | m.weight.data.normal_(0.0, 0.02) 29 | if m.bias is not None: 30 | m.bias.data.zero_() 31 | 32 | if isinstance(m, nn.ConvTranspose2d): 33 | m.weight.data.normal_(0.0, 0.02) 34 | if m.bias is not None: 35 | m.bias.data.zero_() 36 | 37 | if isinstance(m, nn.Linear): 38 | m.weight.data.normal_(0.0, 0.02) 39 | if m.bias is not None: 40 | # m.bias.data.copy_(1.0) 41 | m.bias.data.zero_() 42 | 43 | 44 | def forward(self, x): 45 | 46 | x = self.leakyrelu(self.fc1(x)) 47 | x = self.leakyrelu(self.fc2(x)) 48 | x = self.leakyrelu(self.fc3(x)) 49 | x = self.fc4(x) 50 | return x 51 | 52 | 53 | class OutputDiscriminator(nn.Module): 54 | def __init__(self, ): 55 | super(OutputDiscriminator, self).__init__() 56 | 57 | filter_num_list = [64, 128, 256, 512, 1] 58 | 59 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 60 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 61 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 62 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 63 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 64 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 65 | # self.sigmoid = nn.Sigmoid() 66 | self._initialize_weights() 67 | 68 | 69 | def _initialize_weights(self): 70 | for m in self.modules(): 71 | if isinstance(m, nn.Conv2d): 72 | m.weight.data.normal_(0.0, 0.02) 73 | if m.bias is not None: 74 | m.bias.data.zero_() 75 | 76 | 77 | def forward(self, x): 78 | x = self.leakyrelu(self.conv1(x)) 79 | x = self.leakyrelu(self.conv2(x)) 80 | x = self.leakyrelu(self.conv3(x)) 81 | x = self.leakyrelu(self.conv4(x)) 82 | x = self.conv5(x) 83 | return x 84 | 85 | 86 | class UncertaintyDiscriminator(nn.Module): 87 | def __init__(self, ): 
88 | super(UncertaintyDiscriminator, self).__init__() 89 | 90 | filter_num_list = [64, 128, 256, 512, 1] 91 | 92 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 93 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 94 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 95 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 96 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 97 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 98 | # self.sigmoid = nn.Sigmoid() 99 | self._initialize_weights() 100 | 101 | 102 | def _initialize_weights(self): 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | m.weight.data.normal_(0.0, 0.02) 106 | if m.bias is not None: 107 | m.bias.data.zero_() 108 | 109 | 110 | def forward(self, x): 111 | x = self.leakyrelu(self.conv1(x)) 112 | x = self.leakyrelu(self.conv2(x)) 113 | x = self.leakyrelu(self.conv3(x)) 114 | x = self.leakyrelu(self.conv4(x)) 115 | x = self.conv5(x) 116 | return x 117 | 118 | class BoundaryDiscriminator(nn.Module): 119 | def __init__(self, ): 120 | super(BoundaryDiscriminator, self).__init__() 121 | 122 | filter_num_list = [64, 128, 256, 512, 1] 123 | 124 | self.conv1 = nn.Conv2d(1, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 125 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 126 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 127 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 128 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 129 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 130 | # self.sigmoid = nn.Sigmoid() 131 | self._initialize_weights() 132 | 133 | 134 | def _initialize_weights(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | m.weight.data.normal_(0.0, 0.02) 138 | if m.bias is not None: 139 | m.bias.data.zero_() 140 | 141 | 142 | def forward(self, x): 143 | x = self.leakyrelu(self.conv1(x)) 144 | x = self.leakyrelu(self.conv2(x)) 145 | x = self.leakyrelu(self.conv3(x)) 146 | x = self.leakyrelu(self.conv4(x)) 147 | x = self.conv5(x) 148 | return x 149 | 150 | class BoundaryEntDiscriminator(nn.Module): 151 | def __init__(self, ): 152 | super(BoundaryEntDiscriminator, self).__init__() 153 | 154 | filter_num_list = [64, 128, 256, 512, 1] 155 | 156 | self.conv1 = nn.Conv2d(3, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 157 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 158 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 159 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 160 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 161 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 162 | # self.sigmoid = nn.Sigmoid() 163 | self._initialize_weights() 164 | 165 | 166 | def _initialize_weights(self): 167 | for m in self.modules(): 168 | if 
isinstance(m, nn.Conv2d): 169 | m.weight.data.normal_(0.0, 0.02) 170 | if m.bias is not None: 171 | m.bias.data.zero_() 172 | 173 | 174 | def forward(self, x): 175 | x = self.leakyrelu(self.conv1(x)) 176 | x = self.leakyrelu(self.conv2(x)) 177 | x = self.leakyrelu(self.conv3(x)) 178 | x = self.leakyrelu(self.conv4(x)) 179 | x = self.conv5(x) 180 | return x 181 | 182 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/emma-sjwang/BEAL/945cad38a354605b8bca5bc01ae1b65848d605e1/networks/__init__.py -------------------------------------------------------------------------------- /networks/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 60 | self.bn1 = BatchNorm(256) 61 | self.relu = nn.ReLU() 62 | self.dropout = nn.Dropout(0.5) 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x1 = self.aspp1(x) 67 | x2 = self.aspp2(x) 68 | x3 = self.aspp3(x) 69 | x4 = self.aspp4(x) 70 | x5 = self.global_avg_pool(x) 71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 73 | 74 | x = self.conv1(x) 75 | x = self.bn1(x) 76 | x = self.relu(x) 77 | 78 | return self.dropout(x) 79 | 80 | def _init_weight(self): 81 | for m in self.modules(): 82 | if 
isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) 96 | -------------------------------------------------------------------------------- /networks/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /networks/backbone/drn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/' 7 | 8 | model_urls = { 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth', 11 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth', 12 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth', 13 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=padding, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, 29 | dilation=(1, 1), residual=True, BatchNorm=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride, 32 | padding=dilation[0], dilation=dilation[0]) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes, 36 | padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, 65 | dilation=(1, 1), residual=True, BatchNorm=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = 
nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BatchNorm(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=dilation[1], bias=False, 71 | dilation=dilation[1]) 72 | self.bn2 = BatchNorm(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = BatchNorm(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class DRN(nn.Module): 103 | 104 | def __init__(self, block, layers, arch='D', 105 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 106 | BatchNorm=None): 107 | super(DRN, self).__init__() 108 | self.inplanes = channels[0] 109 | self.out_dim = channels[-1] 110 | self.arch = arch 111 | 112 | if arch == 'C': 113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 114 | padding=3, bias=False) 115 | self.bn1 = BatchNorm(channels[0]) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | self.layer1 = self._make_layer( 119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 120 | self.layer2 = self._make_layer( 121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 122 | 123 | elif arch == 'D': 124 | self.layer0 = nn.Sequential( 125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 126 | bias=False), 127 | BatchNorm(channels[0]), 128 | nn.ReLU(inplace=True) 129 | ) 130 | 131 | self.layer1 = self._make_conv_layers( 132 | channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 133 | self.layer2 = self._make_conv_layers( 134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 135 | 136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 138 | self.layer5 = self._make_layer(block, channels[4], layers[4], 139 | dilation=2, new_level=False, BatchNorm=BatchNorm) 140 | self.layer6 = None if layers[5] == 0 else \ 141 | self._make_layer(block, channels[5], layers[5], dilation=4, 142 | new_level=False, BatchNorm=BatchNorm) 143 | 144 | if arch == 'C': 145 | self.layer7 = None if layers[6] == 0 else \ 146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 197 | modules = [] 198 | for i in range(convs): 199 | modules.extend([ 200 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 201 | stride=stride if i == 0 else 1, 202 | padding=dilation, bias=False, dilation=dilation), 203 | BatchNorm(channels), 204 | nn.ReLU(inplace=True)]) 205 | self.inplanes = channels 206 | return nn.Sequential(*modules) 207 | 208 | def forward(self, x): 209 | if self.arch == 'C': 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | elif self.arch == 'D': 214 | x = self.layer0(x) 215 | 216 | x = self.layer1(x) 217 | x = self.layer2(x) 218 | 219 | x = self.layer3(x) 220 | low_level_feat = x 221 | 222 | x = self.layer4(x) 223 | x = self.layer5(x) 224 | 225 | if self.layer6 is not None: 226 | x = self.layer6(x) 227 | 228 | if self.layer7 is not None: 229 | x = self.layer7(x) 230 | 231 | if self.layer8 is not None: 232 | x = self.layer8(x) 233 | 234 | return x, low_level_feat 235 | 236 | 237 | class DRN_A(nn.Module): 238 | 239 | def __init__(self, block, layers, BatchNorm=None): 240 | self.inplanes = 64 241 | super(DRN_A, self).__init__() 242 | self.out_dim = 512 * block.expansion 243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 244 | bias=False) 245 | self.bn1 = BatchNorm(64) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 248 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm) 249 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm) 250 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 251 | dilation=2, BatchNorm=BatchNorm) 252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 253 | dilation=4, BatchNorm=BatchNorm) 254 | 255 | self._init_weight() 256 | 257 | def _init_weight(self): 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 261 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 262 | elif isinstance(m, SynchronizedBatchNorm2d): 263 | m.weight.data.fill_(1) 264 | m.bias.data.zero_() 265 | elif isinstance(m, nn.BatchNorm2d): 266 | m.weight.data.fill_(1) 267 | m.bias.data.zero_() 268 | 269 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 270 | downsample = None 271 | if stride != 1 or self.inplanes != planes * block.expansion: 272 | downsample = nn.Sequential( 273 | nn.Conv2d(self.inplanes, planes * block.expansion, 274 | kernel_size=1, stride=stride, bias=False), 275 | BatchNorm(planes * block.expansion), 276 | ) 277 | 278 | layers = [] 279 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 280 | self.inplanes = planes * block.expansion 281 | for i in range(1, blocks): 282 | layers.append(block(self.inplanes, planes, 283 | dilation=(dilation, dilation, ), BatchNorm=BatchNorm)) 284 | 285 | return nn.Sequential(*layers) 286 | 287 | def forward(self, x): 288 | x = self.conv1(x) 289 | x = self.bn1(x) 290 | x = self.relu(x) 291 | x = self.maxpool(x) 292 | 293 | x = self.layer1(x) 294 | x = self.layer2(x) 295 | x = self.layer3(x) 296 | x = self.layer4(x) 297 | 298 | return x 299 | 300 | def drn_a_50(BatchNorm, pretrained=True): 301 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 304 | return model 305 | 306 | 307 | def drn_c_26(BatchNorm, pretrained=True): 308 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm) 309 | if pretrained: 310 | pretrained = model_zoo.load_url(model_urls['drn-c-26']) 311 | del pretrained['fc.weight'] 312 | del pretrained['fc.bias'] 313 | model.load_state_dict(pretrained) 314 | return model 315 | 316 | 317 | def drn_c_42(BatchNorm, pretrained=True): 318 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 319 | if pretrained: 320 | pretrained = model_zoo.load_url(model_urls['drn-c-42']) 321 | del pretrained['fc.weight'] 322 | del pretrained['fc.bias'] 323 | model.load_state_dict(pretrained) 324 | return model 325 | 326 | 327 | def drn_c_58(BatchNorm, pretrained=True): 328 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 329 | if pretrained: 330 | pretrained = model_zoo.load_url(model_urls['drn-c-58']) 331 | del pretrained['fc.weight'] 332 | del pretrained['fc.bias'] 333 | model.load_state_dict(pretrained) 334 | return model 335 | 336 | 337 | def drn_d_22(BatchNorm, pretrained=True): 338 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm) 339 | if pretrained: 340 | pretrained = model_zoo.load_url(model_urls['drn-d-22']) 341 | del pretrained['fc.weight'] 342 | del pretrained['fc.bias'] 343 | model.load_state_dict(pretrained) 344 | return model 345 | 346 | 347 | def drn_d_24(BatchNorm, pretrained=True): 348 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm) 349 | if pretrained: 350 | pretrained = model_zoo.load_url(model_urls['drn-d-24']) 351 | del pretrained['fc.weight'] 352 | del pretrained['fc.bias'] 353 | model.load_state_dict(pretrained) 354 | return model 355 | 356 | 357 | def drn_d_38(BatchNorm, pretrained=True): 358 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 359 | if pretrained: 360 | pretrained = model_zoo.load_url(model_urls['drn-d-38']) 361 | del pretrained['fc.weight'] 362 | del pretrained['fc.bias'] 363 | model.load_state_dict(pretrained) 364 | return model 
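# NOTE: drn_d_24 above and drn_d_40 below look up model_urls['drn-d-24'] and
# model_urls['drn-d-40'], but neither key is defined in the model_urls dict at
# the top of this file, so calling either constructor with pretrained=True
# raises a KeyError.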
365 | 366 | 367 | def drn_d_40(BatchNorm, pretrained=True): 368 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm) 369 | if pretrained: 370 | pretrained = model_zoo.load_url(model_urls['drn-d-40']) 371 | del pretrained['fc.weight'] 372 | del pretrained['fc.bias'] 373 | model.load_state_dict(pretrained) 374 | return model 375 | 376 | 377 | def drn_d_54(BatchNorm, pretrained=True): 378 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 379 | if pretrained: 380 | pretrained = model_zoo.load_url(model_urls['drn-d-54']) 381 | del pretrained['fc.weight'] 382 | del pretrained['fc.bias'] 383 | model.load_state_dict(pretrained) 384 | return model 385 | 386 | 387 | def drn_d_105(BatchNorm, pretrained=True): 388 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 389 | if pretrained: 390 | pretrained = model_zoo.load_url(model_urls['drn-d-105']) 391 | del pretrained['fc.weight'] 392 | del pretrained['fc.bias'] 393 | model.load_state_dict(pretrained) 394 | return model 395 | 396 | if __name__ == "__main__": 397 | import torch 398 | model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True) 399 | input = torch.rand(1, 3, 512, 512) 400 | output, low_level_feat = model(input) 401 | print(output.size()) 402 | print(low_level_feat.size()) 403 | -------------------------------------------------------------------------------- /networks/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | 
if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /networks/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | if pretrained: 75 | self._load_pretrained_model() 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * 
block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=True): 149 | """Constructs a ResNet-101 model. 
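The `_make_MG_unit` above implements the DeepLab multi-grid trick: with `output_stride=16`, `layer4` keeps stride 1 and its three bottlenecks receive dilations `blocks[i] * dilation`, i.e. 2, 4 and 8. A minimal sanity check of that wiring (a sketch only; `pretrained=False` is used here just to skip the ImageNet download, and it assumes the repository root is on `PYTHONPATH`):

```python
import torch.nn as nn
from networks.backbone.resnet import ResNet101

model = ResNet101(output_stride=16, BatchNorm=nn.BatchNorm2d, pretrained=False)
# layer4 is the multi-grid unit built by _make_MG_unit with blocks = [1, 2, 4].
print([block.conv2.dilation for block in model.layer4])  # [(2, 2), (4, 4), (8, 8)]
```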
150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 154 | return model 155 | 156 | if __name__ == "__main__": 157 | import torch 158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 159 | input = torch.rand(1, 3, 512, 512) 160 | output, low_level_feat = model(input) 161 | print(output.size()) 162 | print(low_level_feat.size()) 163 | -------------------------------------------------------------------------------- /networks/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | 
skip = inp 88 | 89 | x = x + skip 90 | 91 | return x 92 | 93 | 94 | class AlignedXception(nn.Module): 95 | """ 96 | Modified Aligned Xception 97 | """ 98 | def __init__(self, output_stride, BatchNorm, 99 | pretrained=True): 100 | super(AlignedXception, self).__init__() 101 | 102 | if output_stride == 16: 103 | entry_block3_stride = 2 104 | middle_block_dilation = 1 105 | exit_block_dilations = (1, 2) 106 | elif output_stride == 8: 107 | entry_block3_stride = 1 108 | middle_block_dilation = 2 109 | exit_block_dilations = (2, 4) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | # Entry flow 115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 116 | self.bn1 = BatchNorm(32) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 120 | self.bn2 = BatchNorm(64) 121 | 122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 124 | grow_first=True) 125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 126 | start_with_relu=True, grow_first=True, is_last=True) 127 | 128 | # Middle flow 129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
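`fixed_padding` (defined at the top of this file) emulates TensorFlow-style "SAME" padding for the strided and dilated separable convolutions used throughout these blocks: the effective kernel extent is `k + (k - 1)(d - 1)`, and the resulting `k_eff - 1` total padding is split between the two sides. A quick check of the arithmetic (assuming the repository root is importable):

```python
import torch
from networks.backbone.xception import fixed_padding

x = torch.rand(1, 3, 65, 65)
# kernel 3 with dilation 2 -> effective extent 5 -> pad 2 on each side.
print(fixed_padding(x, kernel_size=3, dilation=2).shape)  # torch.Size([1, 3, 69, 69])
```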
159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 239 | elif isinstance(m, SynchronizedBatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | 246 | 247 | def _load_pretrained_model(self): 248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 249 | model_dict = {} 250 | state_dict = self.state_dict() 251 | 252 | for k, v in pretrain_dict.items(): 253 | if k in state_dict: 254 | if 'pointwise' in k: 255 | v = v.unsqueeze(-1).unsqueeze(-1) 256 | if k.startswith('block11'): 257 | model_dict[k] = v 258 | model_dict[k.replace('block11', 'block12')] = v 259 | model_dict[k.replace('block11', 'block13')] = v 260 | model_dict[k.replace('block11', 'block14')] = v 261 | model_dict[k.replace('block11', 'block15')] = v 262 | model_dict[k.replace('block11', 'block16')] = v 263 | model_dict[k.replace('block11', 'block17')] = v 264 | model_dict[k.replace('block11', 'block18')] = v 265 | model_dict[k.replace('block11', 'block19')] = v 266 | elif k.startswith('block12'): 267 | model_dict[k.replace('block12', 'block20')] = v 268 | elif k.startswith('bn3'): 269 | model_dict[k] = v 270 | model_dict[k.replace('bn3', 'bn4')] = v 271 | elif k.startswith('conv4'): 272 | model_dict[k.replace('conv4', 'conv5')] = v 273 | elif k.startswith('bn4'): 274 | model_dict[k.replace('bn4', 'bn5')] = v 275 | else: 276 | model_dict[k] = v 277 | state_dict.update(model_dict) 278 | self.load_state_dict(state_dict) 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) 289 | -------------------------------------------------------------------------------- /networks/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv = nn.Sequential( 23 | # nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 24 | # BatchNorm(256), 25 | # nn.ReLU(), 26 | # nn.Dropout(0.5), 27 | # nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 28 | BatchNorm(305), 29 | nn.ReLU(), 30 | nn.Dropout(0.1), 31 | nn.Conv2d(305, num_classes, kernel_size=1, stride=1)) 32 | self.last_conv_boundary = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 33 | BatchNorm(256), 34 | nn.ReLU(), 35 | nn.Dropout(0.5), 36 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 37 | BatchNorm(256), 38 | nn.ReLU(), 39 | nn.Dropout(0.1), 40 | nn.Conv2d(256, 1, kernel_size=1, stride=1)) 41 | self._init_weight() 42 | 43 | 44 | def forward(self, x, low_level_feat): 45 | low_level_feat = self.conv1(low_level_feat)
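The loader above adapts the legacy Cadene checkpoint in two ways: its pointwise weights are stored as 2-D `(out, in)` tensors and must gain the two trailing spatial dimensions that `nn.Conv2d` expects, and the single pretrained middle-flow block (`block11`) is replicated into `block12`–`block19` of this deeper variant. A sketch of the reshape step:

```python
import torch

w = torch.rand(128, 64)             # pointwise weight as stored in the old checkpoint
w4 = w.unsqueeze(-1).unsqueeze(-1)  # -> (128, 64, 1, 1), a valid 1x1 conv kernel
print(w4.shape)                     # torch.Size([128, 64, 1, 1])
```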
46 | low_level_feat = self.bn1(low_level_feat) 47 | low_level_feat = self.relu(low_level_feat) 48 | 49 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 50 | x = torch.cat((x, low_level_feat), dim=1) 51 | boundary = self.last_conv_boundary(x) 52 | x = torch.cat([x, boundary], 1) 53 | x1 = self.last_conv(x) 54 | 55 | return x1, boundary 56 | 57 | def _init_weight(self): 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | torch.nn.init.kaiming_normal_(m.weight) 61 | elif isinstance(m, SynchronizedBatchNorm2d): 62 | m.weight.data.fill_(1) 63 | m.bias.data.zero_() 64 | elif isinstance(m, nn.BatchNorm2d): 65 | m.weight.data.fill_(1) 66 | m.bias.data.zero_() 67 | 68 | def build_decoder(num_classes, backbone, BatchNorm): 69 | return Decoder(num_classes, backbone, BatchNorm) 70 | -------------------------------------------------------------------------------- /networks/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from networks.aspp import build_aspp 6 | from networks.decoder import build_decoder 7 | from networks.backbone import build_backbone 8 | 9 | 10 | class DeepLab(nn.Module): 11 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21, 12 | sync_bn=True, freeze_bn=False): 13 | super(DeepLab, self).__init__() 14 | if backbone == 'drn': 15 | output_stride = 8 16 | 17 | if sync_bn: 18 | BatchNorm = SynchronizedBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 24 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 25 | 26 | if freeze_bn: 27 | self.freeze_bn() 28 | 29 | def forward(self, input): 30 | x, low_level_feat = self.backbone(input) 31 | x = self.aspp(x) 32 | feature = x 33 | x1, x2 = self.decoder(x, low_level_feat) 34 | 35 | x2 = F.interpolate(x2, size=input.size()[2:], mode='bilinear', align_corners=True) 36 | x1 = F.interpolate(x1, size=input.size()[2:], mode='bilinear', align_corners=True) 37 | return x1, x2, feature 38 | 39 | def freeze_bn(self): 40 | for m in self.modules(): 41 | if isinstance(m, SynchronizedBatchNorm2d): 42 | m.eval() 43 | elif isinstance(m, nn.BatchNorm2d): 44 | m.eval() 45 | 46 | def get_1x_lr_params(self): 47 | modules = [self.backbone] 48 | for i in range(len(modules)): 49 | for m in modules[i].named_modules(): 50 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 51 | or isinstance(m[1], nn.BatchNorm2d): 52 | for p in m[1].parameters(): 53 | if p.requires_grad: 54 | yield p 55 | 56 | def get_10x_lr_params(self): 57 | modules = [self.aspp, self.decoder] 58 | for i in range(len(modules)): 59 | for m in modules[i].named_modules(): 60 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \ 61 | or isinstance(m[1], nn.BatchNorm2d): 62 | for p in m[1].parameters(): 63 | if p.requires_grad: 64 | yield p 65 | 66 | 67 | if __name__ == "__main__": 68 | model = DeepLab(backbone='mobilenet', output_stride=16) 69 | model.eval() 70 | input = torch.rand(1, 3, 513, 513) 71 | mask, boundary, feature = model(input) 72 | print(mask.size(), boundary.size()) 73 | 74 | 75 | -------------------------------------------------------------------------------- /networks/sync_batchnorm/batchnorm.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using the same "device order" makes the ReduceAdd operation faster.
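The forward pass above ships only per-channel `sum` and `ssum` between devices; the master then recovers the statistics as `mean = sum / n` and the biased variance `ssum / n - mean**2` (see `_compute_mean_std` below). A small single-device check of that identity, independent of the multi-GPU plumbing:

```python
import torch

x = torch.rand(4, 3, 8, 8)                         # (B, C, H, W)
flat = x.transpose(0, 1).contiguous().view(3, -1)  # per-channel flatten, like the (B, C, -1) view
sum_, ssum, n = flat.sum(1), (flat ** 2).sum(1), flat.size(1)
mean = sum_ / n
bias_var = ssum / n - mean ** 2                    # equals sumvar / size in _compute_mean_std
ref = torch.stack([x[:, c].var(unbiased=False) for c in range(3)])
print(torch.allclose(bias_var, ref, atol=1e-5))    # True
```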
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 142 | as the built-in PyTorch implementation. 143 | The mean and standard-deviation are calculated per-dimension over 144 | the mini-batches and gamma and beta are learnable parameter vectors 145 | of size C (where C is the input size). 146 | During training, this layer keeps a running estimate of its computed mean 147 | and variance. The running sum is kept with a default momentum of 0.1. 148 | During evaluation, this running mean/variance is used for normalization. 149 | Because the BatchNorm is done over the `C` dimension, computing statistics 150 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 151 | Args: 152 | num_features: num_features from an expected input of size 153 | `batch_size x num_features [x width]` 154 | eps: a value added to the denominator for numerical stability. 155 | Default: 1e-5 156 | momentum: the value used for the running_mean and running_var 157 | computation. Default: 0.1 158 | affine: a boolean value that when set to ``True``, gives the layer learnable 159 | affine parameters. 
Default: ``True`` 160 | Shape: 161 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 162 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 163 | Examples: 164 | >>> # With Learnable Parameters 165 | >>> m = SynchronizedBatchNorm1d(100) 166 | >>> # Without Learnable Parameters 167 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 168 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 169 | >>> output = m(input) 170 | """ 171 | 172 | def _check_input_dim(self, input): 173 | if input.dim() != 2 and input.dim() != 3: 174 | raise ValueError('expected 2D or 3D input (got {}D input)' 175 | .format(input.dim())) 176 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 177 | 178 | 179 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 180 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 181 | of 3d inputs 182 | .. math:: 183 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 184 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 185 | standard-deviation are reduced across all devices during training. 186 | For example, when one uses `nn.DataParallel` to wrap the network during 187 | training, PyTorch's implementation normalize the tensor on each device using 188 | the statistics only on that device, which accelerated the computation and 189 | is also easy to implement, but the statistics might be inaccurate. 190 | Instead, in this synchronized version, the statistics will be computed 191 | over all training samples distributed on multiple devices. 192 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 193 | as the built-in PyTorch implementation. 194 | The mean and standard-deviation are calculated per-dimension over 195 | the mini-batches and gamma and beta are learnable parameter vectors 196 | of size C (where C is the input size). 197 | During training, this layer keeps a running estimate of its computed mean 198 | and variance. The running sum is kept with a default momentum of 0.1. 199 | During evaluation, this running mean/variance is used for normalization. 200 | Because the BatchNorm is done over the `C` dimension, computing statistics 201 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 202 | Args: 203 | num_features: num_features from an expected input of 204 | size batch_size x num_features x height x width 205 | eps: a value added to the denominator for numerical stability. 206 | Default: 1e-5 207 | momentum: the value used for the running_mean and running_var 208 | computation. Default: 0.1 209 | affine: a boolean value that when set to ``True``, gives the layer learnable 210 | affine parameters. 
Default: ``True`` 211 | Shape: 212 | - Input: :math:`(N, C, H, W)` 213 | - Output: :math:`(N, C, H, W)` (same shape as input) 214 | Examples: 215 | >>> # With Learnable Parameters 216 | >>> m = SynchronizedBatchNorm2d(100) 217 | >>> # Without Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 219 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 220 | >>> output = m(input) 221 | """ 222 | 223 | def _check_input_dim(self, input): 224 | if input.dim() != 4: 225 | raise ValueError('expected 4D input (got {}D input)' 226 | .format(input.dim())) 227 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 228 | 229 | 230 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 231 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 232 | of 4d inputs 233 | .. math:: 234 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 235 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 236 | standard-deviation are reduced across all devices during training. 237 | For example, when one uses `nn.DataParallel` to wrap the network during 238 | training, PyTorch's implementation normalize the tensor on each device using 239 | the statistics only on that device, which accelerated the computation and 240 | is also easy to implement, but the statistics might be inaccurate. 241 | Instead, in this synchronized version, the statistics will be computed 242 | over all training samples distributed on multiple devices. 243 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 244 | as the built-in PyTorch implementation. 245 | The mean and standard-deviation are calculated per-dimension over 246 | the mini-batches and gamma and beta are learnable parameter vectors 247 | of size C (where C is the input size). 248 | During training, this layer keeps a running estimate of its computed mean 249 | and variance. The running sum is kept with a default momentum of 0.1. 250 | During evaluation, this running mean/variance is used for normalization. 251 | Because the BatchNorm is done over the `C` dimension, computing statistics 252 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 253 | or Spatio-temporal BatchNorm 254 | Args: 255 | num_features: num_features from an expected input of 256 | size batch_size x num_features x depth x height x width 257 | eps: a value added to the denominator for numerical stability. 258 | Default: 1e-5 259 | momentum: the value used for the running_mean and running_var 260 | computation. Default: 0.1 261 | affine: a boolean value that when set to ``True``, gives the layer learnable 262 | affine parameters. 
Default: ``True`` 263 | Shape: 264 | - Input: :math:`(N, C, D, H, W)` 265 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 266 | Examples: 267 | >>> # With Learnable Parameters 268 | >>> m = SynchronizedBatchNorm3d(100) 269 | >>> # Without Learnable Parameters 270 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 271 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 272 | >>> output = m(input) 273 | """ 274 | 275 | def _check_input_dim(self, input): 276 | if input.dim() != 5: 277 | raise ValueError('expected 5D input (got {}D input)' 278 | .format(input.dim())) 279 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase):
 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | - During the replication, as data parallel triggers a callback on each module, all slave devices should 59 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 60 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected, 61 | and passed to a registered callback. 62 | - After receiving the messages, the master device should gather the information and determine the message to be passed 63 | back to each slave device. 64 | """ 65 | 66 | def __init__(self, master_callback): 67 | """ 68 | Args: 69 | master_callback: a callback to be invoked after having collected messages from slave devices. 70 | """ 71 | self._master_callback = master_callback 72 | self._queue = queue.Queue() 73 | self._registry = collections.OrderedDict() 74 | self._activated = False 75 | 76 | def __getstate__(self): 77 | return {'master_callback': self._master_callback} 78 | 79 | def __setstate__(self, state): 80 | self.__init__(state['master_callback']) 81 | 82 | def register_slave(self, identifier): 83 | """ 84 | Register a slave device.
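`FutureResult` is the one-to-one handoff underneath `SlavePipe` and `SyncMaster`: the master thread `put`s the reduced statistics and each slave thread blocks in `get` until they arrive. A tiny standalone demonstration (assuming the repository root is importable):

```python
import threading
from networks.sync_batchnorm.comm import FutureResult

fut = FutureResult()
threading.Thread(target=lambda: fut.put(42)).start()
print(fut.get())  # 42; get() blocks until the producer thread has called put()
```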
85 | Args: 86 | identifier: an identifier, usually the device id. 87 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 88 | """ 89 | if self._activated: 90 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 91 | self._activated = False 92 | self._registry.clear() 93 | future = FutureResult() 94 | self._registry[identifier] = _MasterRegistry(future) 95 | return SlavePipe(identifier, self._queue, future) 96 | 97 | def run_master(self, master_msg): 98 | """ 99 | Main entry for the master device in each forward pass. 100 | The messages are first collected from each device (including the master device), and then 101 | a callback will be invoked to compute the message to be sent back to each device 102 | (including the master device). 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | Returns: the message to be sent back to the master device. 107 | """ 108 | self._activated = True 109 | 110 | intermediates = [(0, master_msg)] 111 | for i in range(self.nr_slaves): 112 | intermediates.append(self._queue.get()) 113 | 114 | results = self._master_callback(intermediates) 115 | assert results[0][0] == 0, 'The first result should belong to the master.' 116 | 117 | for i, res in results: 118 | if i == 0: 119 | continue 120 | self._registry[i].result.put(res) 121 | 122 | for i in range(self.nr_slaves): 123 | assert self._queue.get() is True 124 | 125 | return results[0][1] 126 | 127 | @property 128 | def nr_slaves(self): 129 | return len(self._registry) 130 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import os.path as osp 6 | import torch.nn.functional as F 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import tqdm 11 | from dataloaders import fundus_dataloader as DL 12 | from torch.utils.data import DataLoader 13 | from dataloaders import custom_transforms as tr 14 | from torchvision import transforms 15 | from scipy.misc import imsave 16 | from utils.Utils import * 17 | from utils.metrics import * 18 | from datetime import datetime 19 | import pytz 20 | from networks.deeplabv3 import * 21 | import cv2 22 | 23 | 24 | def main(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--model-file', type=str, default='./logs/train2/20181202_160326.365442/checkpoint_9.pth.tar', 27 | help='Model path') 28 | parser.add_argument( 29 | '--dataset', type=str, default='Drishti-GS', help='dataset folder containing the image ROIs to test' 30 | ) 31 | parser.add_argument('-g', '--gpu', type=int, default=0) 32 | 33 | parser.add_argument( 34 | '--data-dir', 35 | default='/home/sjwang/ssd1T/fundus/domain_adaptation/', 36 | help='data root path' 37 | ) 38 | parser.add_argument( 39 | '--out-stride', 40 | type=int, 41 | default=16, 42 | help='out-stride of deeplabv3+', 43 | ) 44 | parser.add_argument( 45 | '--save-root-ent', 46 | type=str, 47 | default='./results/ent/', 48 | help='path to save ent', 49 | ) 50 | parser.add_argument( 51 | '--save-root-mask', 52 | type=str, 53 | default='./results/mask/', 54 | help='path to save mask', 55 | ) 56 | parser.add_argument( 57 | '--sync-bn', 58 | type=bool, 59 | default=True, 60 | help='sync-bn in deeplabv3+',
61 | ) 62 | parser.add_argument( 63 | '--freeze-bn', 64 | type=bool, 65 | default=False, 66 | help='freeze batch normalization of deeplabv3+', 67 | ) 68 | parser.add_argument('--test-prediction-save-path', type=str, 69 | default='./results/baseline/', 70 | help='Path root for test image and mask') 71 | args = parser.parse_args() 72 | 73 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 74 | model_file = args.model_file 75 | 76 | # 1. dataset 77 | composed_transforms_test = transforms.Compose([ 78 | tr.Normalize_tf(), 79 | tr.ToTensor() 80 | ]) 81 | db_test = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='test', 82 | transform=composed_transforms_test) 83 | 84 | test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 85 | 86 | # 2. model 87 | model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, 88 | sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda() 89 | 90 | if torch.cuda.is_available(): 91 | model = model.cuda() 92 | print('==> Loading %s model file: %s' % 93 | (model.__class__.__name__, model_file)) 94 | checkpoint = torch.load(model_file) 95 | try: 96 | pretrained_dict = checkpoint['model_state_dict'] 97 | model_dict = model.state_dict() 98 | # 1. filter out unnecessary keys 99 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 100 | # 2. overwrite entries in the existing state dict 101 | model_dict.update(pretrained_dict) 102 | # 3. load the new state dict 103 | model.load_state_dict(model_dict) 104 | 105 | 106 | except Exception: 107 | model.load_state_dict(checkpoint['model_state_dict']) 108 | model.eval() 109 | print('==> Evaluating with %s' % (args.dataset)) 110 | 111 | val_cup_dice = 0.0 112 | val_disc_dice = 0.0 113 | timestamp_start = \ 114 | datetime.now(pytz.timezone('Asia/Hong_Kong')) 115 | 116 | for batch_idx, (sample) in tqdm.tqdm(enumerate(test_loader), 117 | total=len(test_loader), 118 | ncols=80, leave=False): 119 | data = sample['image'] 120 | target = sample['map'] 121 | img_name = sample['img_name'] 122 | if torch.cuda.is_available(): 123 | data, target = data.cuda(), target.cuda() 124 | data, target = Variable(data), Variable(target) 125 | prediction, boundary, _ = model(data) 126 | prediction = torch.nn.functional.interpolate(prediction, size=(target.size()[2], target.size()[3]), 127 | mode="bilinear") 128 | boundary = torch.nn.functional.interpolate(boundary, size=(target.size()[2], target.size()[3]), 129 | mode="bilinear") 130 | data = torch.nn.functional.interpolate(data, size=(target.size()[2], target.size()[3]), mode="bilinear") 131 | prediction = torch.sigmoid(prediction) 132 | boundary = torch.sigmoid(boundary) 133 | draw_ent(prediction.data.cpu()[0].numpy(), os.path.join(args.save_root_ent, args.dataset), img_name[0]) 134 | draw_mask(prediction.data.cpu()[0].numpy(), os.path.join(args.save_root_mask, args.dataset), img_name[0]) 135 | draw_boundary(boundary.data.cpu()[0].numpy(), os.path.join(args.save_root_mask, args.dataset), img_name[0]) 136 | 137 | prediction = postprocessing(prediction.data.cpu()[0], dataset=args.dataset) 138 | target_numpy = target.data.cpu() 139 | cup_dice = dice_coefficient_numpy(prediction[0, ...], target_numpy[0, 0, ...]) 140 | disc_dice = dice_coefficient_numpy(prediction[1, ...], target_numpy[0, 1, ...]) 141 | 142 | val_cup_dice += cup_dice 143 | val_disc_dice += disc_dice 144 | 145 | imgs = data.data.cpu() 146 | 147 | for img, lt, lp in zip(imgs, target_numpy, [prediction]):
148 | img, lt = untransform(img, lt) 149 | save_per_img(img.numpy().transpose(1, 2, 0), os.path.join(args.test_prediction_save_path, args.dataset), 150 | img_name[0], 151 | lp, mask_path=None, ext="bmp") 152 | 153 | val_cup_dice /= len(test_loader) 154 | val_disc_dice /= len(test_loader) 155 | 156 | print('''\n==>val_cup_dice : {0}'''.format(val_cup_dice)) 157 | print('''\n==>val_disc_dice : {0}'''.format(val_disc_dice)) 158 | with open(osp.join(args.test_prediction_save_path, 'test_log.csv'), 'a') as f: 159 | elapsed_time = ( 160 | datetime.now(pytz.timezone('Asia/Hong_Kong')) - 161 | timestamp_start).total_seconds() 162 | log = [args.model_file] + ['cup dice coefficient: '] + \ 163 | [val_cup_dice] + ['disc dice coefficient: '] + \ 164 | [val_disc_dice] + [elapsed_time] 165 | log = map(str, log) 166 | f.write(','.join(log) + '\n') 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import os.path as osp 4 | 5 | # PyTorch includes 6 | import torch 7 | from torchvision import transforms 8 | from torch.utils.data import DataLoader 9 | import argparse 10 | import yaml 11 | from train_process import Trainer 12 | 13 | # Custom includes 14 | from dataloaders import fundus_dataloader as DL 15 | from dataloaders import custom_transforms as tr 16 | from networks.deeplabv3 import * 17 | from networks.GAN import BoundaryDiscriminator, UncertaintyDiscriminator 18 | 19 | 20 | here = osp.dirname(osp.abspath(__file__)) 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser( 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 25 | ) 26 | parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id') 27 | parser.add_argument('--resume', default=None, help='checkpoint path') 28 | 29 | # configurations (same configuration as original work) 30 | # https://github.com/shelhamer/fcn.berkeleyvision.org 31 | parser.add_argument( 32 | '--datasetS', type=str, default='refuge', help='source domain dataset' 33 | ) 34 | parser.add_argument( 35 | '--datasetT', type=str, default='Drishti-GS', help='target domain dataset: refuge / Drishti-GS / RIM-ONE_r3' 36 | ) 37 | parser.add_argument( 38 | '--batch-size', type=int, default=8, help='batch size for training the model' 39 | ) 40 | parser.add_argument( 41 | '--group-num', type=int, default=1, help='group number for group normalization' 42 | ) 43 | parser.add_argument( 44 | '--max-epoch', type=int, default=200, help='max epoch' 45 | ) 46 | parser.add_argument( 47 | '--stop-epoch', type=int, default=200, help='stop epoch' 48 | ) 49 | parser.add_argument( 50 | '--warmup-epoch', type=int, default=-1, help='number of warm-up epochs before the GAN losses are enabled' 51 | ) 52 | 53 | parser.add_argument( 54 | '--interval-validate', type=int, default=10, help='epoch interval at which to validate the model' 55 | ) 56 | parser.add_argument( 57 | '--lr-gen', type=float, default=1e-3, help='generator learning rate', 58 | ) 59 | parser.add_argument( 60 | '--lr-dis', type=float, default=2.5e-5, help='discriminator learning rate', 61 | ) 62 | parser.add_argument( 63 | '--lr-decrease-rate', type=float, default=0.1, help='ratio multiplied to initial lr', 64 | ) 65 | parser.add_argument( 66 | '--weight-decay', type=float, default=0.0005, help='weight decay', 67 | ) 68 | parser.add_argument( 69 | '--momentum', type=float, default=0.99, help='momentum', 70 | ) 71 | parser.add_argument(
'--data-dir', 73 | default='/home/sjwang/ssd1T/fundus/domain_adaptation/', 74 | help='data root path' 75 | ) 76 | parser.add_argument( 77 | '--pretrained-model', 78 | default='../../../models/pytorch/fcn16s_from_caffe.pth', 79 | help='pretrained model of FCN16s', 80 | ) 81 | parser.add_argument( 82 | '--out-stride', 83 | type=int, 84 | default=16, 85 | help='out-stride of deeplabv3+', 86 | ) 87 | parser.add_argument( 88 | '--sync-bn', 89 | type=bool, 90 | default=True, 91 | help='sync-bn in deeplabv3+', 92 | ) 93 | parser.add_argument( 94 | '--freeze-bn', 95 | type=bool, 96 | default=False, 97 | help='freeze batch normalization of deeplabv3+', 98 | ) 99 | 100 | args = parser.parse_args() 101 | 102 | args.model = 'FCN8s' 103 | 104 | now = datetime.now() 105 | args.out = osp.join(here, 'logs', args.datasetT, now.strftime('%Y%m%d_%H%M%S.%f')) 106 | 107 | os.makedirs(args.out) 108 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 109 | yaml.safe_dump(args.__dict__, f, default_flow_style=False) 110 | 111 | 112 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 113 | cuda = torch.cuda.is_available() 114 | 115 | torch.manual_seed(1337) 116 | if cuda: 117 | torch.cuda.manual_seed(1337) 118 | 119 | # 1. dataset 120 | composed_transforms_tr = transforms.Compose([ 121 | tr.RandomScaleCrop(512), 122 | tr.RandomRotate(), 123 | tr.RandomFlip(), 124 | tr.elastic_transform(), 125 | tr.add_salt_pepper_noise(), 126 | tr.adjust_light(), 127 | tr.eraser(), 128 | tr.Normalize_tf(), 129 | tr.ToTensor() 130 | ]) 131 | 132 | composed_transforms_ts = transforms.Compose([ 133 | tr.RandomCrop(512), 134 | tr.Normalize_tf(), 135 | tr.ToTensor() 136 | ]) 137 | 138 | domain = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetS, split='train', 139 | transform=composed_transforms_tr) 140 | domain_loaderS = DataLoader(domain, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True) 141 | domain_T = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train', 142 | transform=composed_transforms_tr) 143 | domain_loaderT = DataLoader(domain_T, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=True) 144 | domain_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train', 145 | transform=composed_transforms_ts) 146 | domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2, pin_memory=True) 147 | 148 | # 2. model 149 | model_gen = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, 150 | sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda() 151 | 152 | model_dis = BoundaryDiscriminator().cuda() 153 | model_dis2 = UncertaintyDiscriminator().cuda() 154 | 155 | start_epoch = 0 156 | start_iteration = 0 157 | 158 | # 3. optimizer 159 | 160 | optim_gen = torch.optim.Adam( 161 | model_gen.parameters(), 162 | lr=args.lr_gen, 163 | betas=(0.9, 0.99) 164 | ) 165 | optim_dis = torch.optim.SGD( 166 | model_dis.parameters(), 167 | lr=args.lr_dis, 168 | momentum=args.momentum, 169 | weight_decay=args.weight_decay 170 | ) 171 | optim_dis2 = torch.optim.SGD( 172 | model_dis2.parameters(), 173 | lr=args.lr_dis, 174 | momentum=args.momentum, 175 | weight_decay=args.weight_decay 176 | ) 177 | 178 | if args.resume: 179 | checkpoint = torch.load(args.resume) 180 | pretrained_dict = checkpoint['model_state_dict'] 181 | model_dict = model_gen.state_dict() 182 | # 1. 
filter out unnecessary keys 183 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 184 | # 2. overwrite entries in the existing state dict 185 | model_dict.update(pretrained_dict) 186 | # 3. load the new state dict 187 | model_gen.load_state_dict(model_dict) 188 | 189 | pretrained_dict = checkpoint['model_dis_state_dict'] 190 | model_dict = model_dis.state_dict() 191 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 192 | model_dict.update(pretrained_dict) 193 | model_dis.load_state_dict(model_dict) 194 | 195 | pretrained_dict = checkpoint['model_dis2_state_dict'] 196 | model_dict = model_dis2.state_dict() 197 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 198 | model_dict.update(pretrained_dict) 199 | model_dis2.load_state_dict(model_dict) 200 | 201 | 202 | start_epoch = checkpoint['epoch'] + 1 203 | start_iteration = checkpoint['iteration'] + 1 204 | optim_gen.load_state_dict(checkpoint['optim_state_dict']) 205 | optim_dis.load_state_dict(checkpoint['optim_dis_state_dict']) 206 | optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict']) 207 | 208 | trainer = Trainer.Trainer( 209 | cuda=cuda, 210 | model_gen=model_gen, 211 | model_dis=model_dis, 212 | model_uncertainty_dis=model_dis2, 213 | optimizer_gen=optim_gen, 214 | optimizer_dis=optim_dis, 215 | optimizer_uncertainty_dis=optim_dis2, 216 | lr_gen=args.lr_gen, 217 | lr_dis=args.lr_dis, 218 | lr_decrease_rate=args.lr_decrease_rate, 219 | val_loader=domain_loader_val, 220 | domain_loaderS=domain_loaderS, 221 | domain_loaderT=domain_loaderT, 222 | out=args.out, 223 | max_epoch=args.max_epoch, 224 | stop_epoch=args.stop_epoch, 225 | interval_validate=args.interval_validate, 226 | batch_size=args.batch_size, 227 | warmup_epoch=args.warmup_epoch, 228 | ) 229 | trainer.epoch = start_epoch 230 | trainer.iteration = start_iteration 231 | trainer.train() 232 | 233 | if __name__ == '__main__': 234 | main() 235 | -------------------------------------------------------------------------------- /train_process/Trainer.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import os.path as osp 4 | import timeit 5 | from torchvision.utils import make_grid 6 | import time 7 | 8 | import numpy as np 9 | import pytz 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from tensorboardX import SummaryWriter 14 | 15 | import tqdm 16 | import socket 17 | from utils.metrics import * 18 | from utils.Utils import * 19 | 20 | bceloss = torch.nn.BCELoss() 21 | mseloss = torch.nn.MSELoss() 22 | 23 | def get_lr(optimizer): 24 | for param_group in optimizer.param_groups: 25 | return param_group['lr'] 26 | 27 | class Trainer(object): 28 | 29 | def __init__(self, cuda, model_gen, model_dis, model_uncertainty_dis, optimizer_gen, optimizer_dis, optimizer_uncertainty_dis, 30 | val_loader, domain_loaderS, domain_loaderT, out, max_epoch, stop_epoch=None, 31 | lr_gen=1e-3, lr_dis=1e-3, lr_decrease_rate=0.1, interval_validate=None, batch_size=8, warmup_epoch=10): 32 | self.cuda = cuda 33 | self.warmup_epoch = warmup_epoch 34 | self.model_gen = model_gen 35 | self.model_dis2 = model_uncertainty_dis 36 | self.model_dis = model_dis 37 | self.optim_gen = optimizer_gen 38 | self.optim_dis = optimizer_dis 39 | self.optim_dis2 = optimizer_uncertainty_dis 40 | self.lr_gen = lr_gen 41 | self.lr_dis = lr_dis 42 | self.lr_decrease_rate = lr_decrease_rate 43 | self.batch_size = batch_size 
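The resume block above repeats the same filter / update / load dance for the generator and both discriminators. A small helper capturing the pattern (a hypothetical refactor, not part of this codebase):

```python
def load_matched(module, pretrained_state):
    """Load only the checkpoint entries whose keys exist in `module` (hypothetical helper)."""
    own = module.state_dict()
    own.update({k: v for k, v in pretrained_state.items() if k in own})
    module.load_state_dict(own)

# e.g.: load_matched(model_gen, checkpoint['model_state_dict'])
```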
44 | 45 | self.val_loader = val_loader 46 | self.domain_loaderS = domain_loaderS 47 | self.domain_loaderT = domain_loaderT 48 | self.time_zone = 'Asia/Hong_Kong' 49 | self.timestamp_start = \ 50 | datetime.now(pytz.timezone(self.time_zone)) 51 | 52 | if interval_validate is None: 53 | self.interval_validate = int(10) 54 | else: 55 | self.interval_validate = interval_validate 56 | 57 | self.out = out 58 | if not osp.exists(self.out): 59 | os.makedirs(self.out) 60 | 61 | self.log_headers = [ 62 | 'epoch', 63 | 'iteration', 64 | 'train/loss_seg', 65 | 'train/cup_dice', 66 | 'train/disc_dice', 67 | 'train/loss_adv', 68 | 'train/loss_D_same', 69 | 'train/loss_D_diff', 70 | 'valid/loss_CE', 71 | 'valid/cup_dice', 72 | 'valid/disc_dice', 73 | 'elapsed_time', 74 | ] 75 | if not osp.exists(osp.join(self.out, 'log.csv')): 76 | with open(osp.join(self.out, 'log.csv'), 'w') as f: 77 | f.write(','.join(self.log_headers) + '\n') 78 | 79 | log_dir = os.path.join(self.out, 'tensorboard', 80 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 81 | self.writer = SummaryWriter(log_dir=log_dir) 82 | 83 | self.epoch = 0 84 | self.iteration = 0 85 | self.max_epoch = max_epoch 86 | self.stop_epoch = stop_epoch if stop_epoch is not None else max_epoch 87 | self.best_disc_dice = 0.0 88 | self.running_loss_tr = 0.0 89 | self.running_adv_diff_loss = 0.0 90 | self.running_adv_same_loss = 0.0 91 | self.best_mean_dice = 0.0 92 | self.best_epoch = -1 93 | 94 | 95 | def validate(self): 96 | training = self.model_gen.training 97 | self.model_gen.eval() 98 | 99 | val_loss = 0 100 | val_cup_dice = 0 101 | val_disc_dice = 0 102 | metrics = [] 103 | with torch.no_grad(): 104 | 105 | for batch_idx, sample in tqdm.tqdm( 106 | enumerate(self.val_loader), total=len(self.val_loader), 107 | desc='Valid iteration=%d' % self.iteration, ncols=80, 108 | leave=False): 109 | data = sample['image'] 110 | target_map = sample['map'] 111 | target_boundary = sample['boundary'] 112 | if self.cuda: 113 | data, target_map, target_boundary = data.cuda(), target_map.cuda(), target_boundary.cuda() 114 | with torch.no_grad(): 115 | predictions, boundary, _ = self.model_gen(data) 116 | 117 | loss = F.binary_cross_entropy_with_logits(predictions, target_map) 118 | loss_data = loss.data.item() 119 | if np.isnan(loss_data): 120 | raise ValueError('loss is nan while validating') 121 | val_loss += loss_data 122 | 123 | dice_cup, dice_disc = dice_coeff_2label(predictions, target_map) 124 | val_cup_dice += dice_cup 125 | val_disc_dice += dice_disc 126 | val_loss /= len(self.val_loader) 127 | val_cup_dice /= len(self.val_loader) 128 | val_disc_dice /= len(self.val_loader) 129 | metrics.append((val_loss, val_cup_dice, val_disc_dice)) 130 | self.writer.add_scalar('val_data/loss_CE', val_loss, self.epoch * (len(self.domain_loaderS))) 131 | self.writer.add_scalar('val_data/val_CUP_dice', val_cup_dice, self.epoch * (len(self.domain_loaderS))) 132 | self.writer.add_scalar('val_data/val_DISC_dice', val_disc_dice, self.epoch * (len(self.domain_loaderS))) 133 | 134 | mean_dice = val_cup_dice + val_disc_dice 135 | is_best = mean_dice > self.best_mean_dice 136 | if is_best: 137 | self.best_epoch = self.epoch + 1 138 | self.best_mean_dice = mean_dice 139 | 140 | torch.save({ 141 | 'epoch': self.epoch, 142 | 'iteration': self.iteration, 143 | 'arch': self.model_gen.__class__.__name__, 144 | 'optim_state_dict': self.optim_gen.state_dict(), 145 | 'optim_dis_state_dict': self.optim_dis.state_dict(), 146 | 'optim_dis2_state_dict':
self.optim_dis2.state_dict(), 147 | 'model_state_dict': self.model_gen.state_dict(), 148 | 'model_dis_state_dict': self.model_dis.state_dict(), 149 | 'model_dis2_state_dict': self.model_dis2.state_dict(), 150 | 'learning_rate_gen': get_lr(self.optim_gen), 151 | 'learning_rate_dis': get_lr(self.optim_dis), 152 | 'learning_rate_dis2': get_lr(self.optim_dis2), 153 | 'best_mean_dice': self.best_mean_dice, 154 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % self.best_epoch)) 155 | else: 156 | if (self.epoch + 1) % 50 == 0: 157 | torch.save({ 158 | 'epoch': self.epoch, 159 | 'iteration': self.iteration, 160 | 'arch': self.model_gen.__class__.__name__, 161 | 'optim_state_dict': self.optim_gen.state_dict(), 162 | 'optim_dis_state_dict': self.optim_dis.state_dict(), 163 | 'optim_dis2_state_dict': self.optim_dis2.state_dict(), 164 | 'model_state_dict': self.model_gen.state_dict(), 165 | 'model_dis_state_dict': self.model_dis.state_dict(), 166 | 'model_dis2_state_dict': self.model_dis2.state_dict(), 167 | 'learning_rate_gen': get_lr(self.optim_gen), 168 | 'learning_rate_dis': get_lr(self.optim_dis), 169 | 'learning_rate_dis2': get_lr(self.optim_dis2), 170 | 'best_mean_dice': self.best_mean_dice, 171 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % (self.epoch + 1))) 172 | 173 | 174 | with open(osp.join(self.out, 'log.csv'), 'a') as f: 175 | elapsed_time = ( 176 | datetime.now(pytz.timezone(self.time_zone)) - 177 | self.timestamp_start).total_seconds() 178 | log = [self.epoch, self.iteration] + [''] * 5 + \ 179 | list(metrics) + [elapsed_time] + ['best model epoch: %d' % self.best_epoch] 180 | log = map(str, log) 181 | f.write(','.join(log) + '\n') 182 | self.writer.add_scalar('best_model_epoch', self.best_epoch, self.epoch * (len(self.domain_loaderS))) 183 | if training: 184 | self.model_gen.train() 185 | self.model_dis.train() 186 | self.model_dis2.train() 187 | 188 | 189 | def train_epoch(self): 190 | source_domain_label = 1 191 | target_domain_label = 0 192 | smooth = 1e-7 193 | self.model_gen.train() 194 | self.model_dis.train() 195 | self.model_dis2.train() 196 | self.running_seg_loss = 0.0 197 | self.running_adv_loss = 0.0 198 | self.running_dis_diff_loss = 0.0 199 | self.running_dis_same_loss = 0.0 200 | self.running_total_loss = 0.0 201 | self.running_cup_dice_tr = 0.0 202 | self.running_disc_dice_tr = 0.0 203 | loss_adv_diff_data = 0 204 | loss_D_same_data = 0 205 | loss_D_diff_data = 0 206 | 207 | domain_t_loader = enumerate(self.domain_loaderT) 208 | start_time = timeit.default_timer() 209 | for batch_idx, sampleS in tqdm.tqdm( 210 | enumerate(self.domain_loaderS), total=len(self.domain_loaderS), 211 | desc='Train epoch=%d' % self.epoch, ncols=80, leave=False): 212 | 213 | metrics = [] 214 | 215 | iteration = batch_idx + self.epoch * len(self.domain_loaderS) 216 | self.iteration = iteration 217 | 218 | assert self.model_gen.training 219 | assert self.model_dis.training 220 | assert self.model_dis2.training 221 | 222 | self.optim_gen.zero_grad() 223 | self.optim_dis.zero_grad() 224 | self.optim_dis2.zero_grad() 225 | 226 | # 1. 
227 |             for param in self.model_dis.parameters():
228 |                 param.requires_grad = False
229 |             for param in self.model_dis2.parameters():
230 |                 param.requires_grad = False
231 |             for param in self.model_gen.parameters():
232 |                 param.requires_grad = True
233 | 
234 |             imageS = sampleS['image'].cuda()
235 |             target_map = sampleS['map'].cuda()
236 |             target_boundary = sampleS['boundary'].cuda()
237 | 
238 |             oS, boundaryS = self.model_gen(imageS)
239 | 
240 |             loss_seg1 = bceloss(torch.sigmoid(oS), target_map)
241 |             loss_seg2 = mseloss(torch.sigmoid(boundaryS), target_boundary)
242 |             loss_seg = loss_seg1 + loss_seg2
243 | 
244 |             self.running_seg_loss += loss_seg.item()
245 |             loss_seg_data = loss_seg.data.item()
246 |             if np.isnan(loss_seg_data):
247 |                 raise ValueError('loss_seg is nan while training')
248 | 
249 |             # cup_dice, disc_dice = dice_coeff_2label(oS, target_map)
250 | 
251 |             loss_seg.backward()
252 |             # self.optim_gen.step()
253 | 
254 |             # write image log
255 |             if iteration % 30 == 0:
256 |                 grid_image = make_grid(
257 |                     imageS[0, ...].clone().cpu().data, 1, normalize=True)
258 |                 self.writer.add_image('DomainS/image', grid_image, iteration)
259 |                 grid_image = make_grid(
260 |                     target_map[0, 0, ...].clone().cpu().data, 1, normalize=True)
261 |                 self.writer.add_image('DomainS/target_cup', grid_image, iteration)
262 |                 grid_image = make_grid(
263 |                     target_map[0, 1, ...].clone().cpu().data, 1, normalize=True)
264 |                 self.writer.add_image('DomainS/target_disc', grid_image, iteration)
265 |                 grid_image = make_grid(
266 |                     target_boundary[0, 0, ...].clone().cpu().data, 1, normalize=True)
267 |                 self.writer.add_image('DomainS/target_boundary', grid_image, iteration)
268 |                 grid_image = make_grid(torch.sigmoid(oS)[0, 0, ...].clone().cpu().data, 1, normalize=True)
269 |                 self.writer.add_image('DomainS/prediction_cup', grid_image, iteration)
270 |                 grid_image = make_grid(torch.sigmoid(oS)[0, 1, ...].clone().cpu().data, 1, normalize=True)
271 |                 self.writer.add_image('DomainS/prediction_disc', grid_image, iteration)
272 |                 grid_image = make_grid(torch.sigmoid(boundaryS)[0, 0, ...].clone().cpu().data, 1, normalize=True)
273 |                 self.writer.add_image('DomainS/prediction_boundary', grid_image, iteration)
274 | 
275 |             if self.epoch > self.warmup_epoch:
276 |                 # 2. train generator with images from the target domain
277 |                 try:
278 |                     id_, sampleT = next(domain_t_loader)
279 |                 except StopIteration:  # target loader exhausted; restart it
280 |                     domain_t_loader = enumerate(self.domain_loaderT)
281 |                     id_, sampleT = next(domain_t_loader)
282 | 
283 |                 imageT = sampleT['image'].cuda()
284 | 
285 |                 oT, boundaryT = self.model_gen(imageT)
286 |                 uncertainty_mapT = -1.0 * torch.sigmoid(oT) * torch.log(torch.sigmoid(oT) + smooth)  # entropy-style map: -p * log(p)
287 |                 D_out2 = self.model_dis(torch.sigmoid(boundaryT))
288 |                 D_out1 = self.model_dis2(uncertainty_mapT)
289 | 
290 |                 # adversarial objective: label target-domain outputs as source to fool the discriminators
291 |                 loss_adv_diff1 = F.binary_cross_entropy_with_logits(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(source_domain_label).cuda())
292 |                 loss_adv_diff2 = F.binary_cross_entropy_with_logits(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(source_domain_label).cuda())
293 |                 loss_adv_diff = 0.01 * (loss_adv_diff1 + loss_adv_diff2)
294 |                 self.running_adv_diff_loss += loss_adv_diff.item()
295 |                 loss_adv_diff_data = loss_adv_diff.data.item()
296 |                 if np.isnan(loss_adv_diff_data):
297 |                     raise ValueError('loss_adv_diff_data is nan while training')
298 | 
299 |                 loss_adv_diff.backward()
300 |                 self.optim_gen.step()
301 | 
302 |                 # 3. train discriminators with images from the source domain
303 |                 for param in self.model_dis.parameters():
304 |                     param.requires_grad = True
305 |                 for param in self.model_dis2.parameters():
306 |                     param.requires_grad = True
307 |                 for param in self.model_gen.parameters():
308 |                     param.requires_grad = False
309 | 
310 |                 boundaryS = boundaryS.detach()  # stop gradients from flowing back into the generator
311 |                 oS = oS.detach()
312 |                 uncertainty_mapS = -1.0 * torch.sigmoid(oS) * torch.log(torch.sigmoid(oS) + smooth)
313 |                 D_out2 = self.model_dis(torch.sigmoid(boundaryS))
314 |                 D_out1 = self.model_dis2(uncertainty_mapS)
315 | 
316 |                 loss_D_same1 = F.binary_cross_entropy_with_logits(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(
317 |                     source_domain_label).cuda())
318 |                 loss_D_same2 = F.binary_cross_entropy_with_logits(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(
319 |                     source_domain_label).cuda())
320 |                 loss_D_same = loss_D_same1 + loss_D_same2
321 | 
322 |                 self.running_dis_same_loss += loss_D_same.item()
323 |                 loss_D_same_data = loss_D_same.data.item()
324 |                 if np.isnan(loss_D_same_data):
325 |                     raise ValueError('loss_D_same is nan while training')
326 |                 loss_D_same.backward()
327 | 
328 |                 # 4. train discriminators with images from the target domain
329 | 
330 |                 boundaryT = boundaryT.detach()
331 |                 oT = oT.detach()
332 |                 uncertainty_mapT = -1.0 * torch.sigmoid(oT) * torch.log(torch.sigmoid(oT) + smooth)
333 |                 D_out2 = self.model_dis(torch.sigmoid(boundaryT))
334 |                 D_out1 = self.model_dis2(uncertainty_mapT)
335 | 
336 |                 loss_D_diff1 = F.binary_cross_entropy_with_logits(D_out1, torch.FloatTensor(D_out1.data.size()).fill_(
337 |                     target_domain_label).cuda())
338 |                 loss_D_diff2 = F.binary_cross_entropy_with_logits(D_out2, torch.FloatTensor(D_out2.data.size()).fill_(
339 |                     target_domain_label).cuda())
340 |                 loss_D_diff = loss_D_diff1 + loss_D_diff2
341 |                 self.running_dis_diff_loss += loss_D_diff.item()
342 |                 loss_D_diff_data = loss_D_diff.data.item()
343 |                 if np.isnan(loss_D_diff_data):
344 |                     raise ValueError('loss_D_diff is nan while training')
345 |                 loss_D_diff.backward()
346 | 
347 |                 # 5. update parameters
348 |                 self.optim_dis.step()
349 |                 self.optim_dis2.step()
350 | 
351 |                 if iteration % 30 == 0:
352 |                     grid_image = make_grid(
353 |                         imageT[0, ...].clone().cpu().data, 1, normalize=True)
354 |                     self.writer.add_image('DomainT/image', grid_image, iteration)
355 |                     grid_image = make_grid(
356 |                         sampleT['map'][0, 0, ...].clone().cpu().data, 1, normalize=True)
357 |                     self.writer.add_image('DomainT/target_cup', grid_image, iteration)
358 |                     grid_image = make_grid(
359 |                         sampleT['map'][0, 1, ...].clone().cpu().data, 1, normalize=True)
360 |                     self.writer.add_image('DomainT/target_disc', grid_image, iteration)
361 |                     grid_image = make_grid(torch.sigmoid(oT)[0, 0, ...].clone().cpu().data, 1, normalize=True)
362 |                     self.writer.add_image('DomainT/prediction_cup', grid_image, iteration)
363 |                     grid_image = make_grid(torch.sigmoid(oT)[0, 1, ...].clone().cpu().data, 1, normalize=True)
364 |                     self.writer.add_image('DomainT/prediction_disc', grid_image, iteration)
365 |                     grid_image = make_grid(boundaryS[0, 0, ...].clone().cpu().data, 1, normalize=True)
366 |                     self.writer.add_image('DomainS/boundaryS', grid_image, iteration)
367 |                     grid_image = make_grid(boundaryT[0, 0, ...].clone().cpu().data, 1,
368 |                                            normalize=True)
369 |                     self.writer.add_image('DomainT/boundaryT', grid_image, iteration)
370 | 
371 |             self.writer.add_scalar('train_adv/loss_adv_diff', loss_adv_diff_data, iteration)
372 |             self.writer.add_scalar('train_dis/loss_D_same', loss_D_same_data, iteration)
373 |             self.writer.add_scalar('train_dis/loss_D_diff', loss_D_diff_data, iteration)
374 |             self.writer.add_scalar('train_gen/loss_seg', loss_seg_data, iteration)
375 | 
376 |             metrics.append((loss_seg_data, loss_adv_diff_data, loss_D_same_data, loss_D_diff_data))
377 |             metrics = np.mean(metrics, axis=0)
378 | 
379 |             with open(osp.join(self.out, 'log.csv'), 'a') as f:
380 |                 elapsed_time = (
381 |                     datetime.now(pytz.timezone(self.time_zone)) -
382 |                     self.timestamp_start).total_seconds()
383 |                 log = [self.epoch, self.iteration, metrics[0], '', '', metrics[1],
384 |                        metrics[2], metrics[3]] + [''] * 3 + [elapsed_time]  # follow the self.log_headers column order
385 |                 log = map(str, log)
386 |                 f.write(','.join(log) + '\n')
387 | 
388 |         self.running_seg_loss /= len(self.domain_loaderS)
389 |         self.running_adv_diff_loss /= len(self.domain_loaderS)
390 |         self.running_dis_same_loss /= len(self.domain_loaderS)
391 |         self.running_dis_diff_loss /= len(self.domain_loaderS)
392 | 
393 |         stop_time = timeit.default_timer()
394 | 
395 |         print('\n[Epoch: %d] lr:%f, Average segLoss: %f, '
396 |               'Average advLoss: %f, Average dis_same_Loss: %f, '
397 |               'Average dis_diff_Loss: %f, '
398 |               'Execution time: %.5f' %
399 |               (self.epoch, get_lr(self.optim_gen), self.running_seg_loss,
400 |                self.running_adv_diff_loss,
401 |                self.running_dis_same_loss, self.running_dis_diff_loss, stop_time - start_time))
402 | 
403 | 
404 |     def train(self):
405 |         for epoch in tqdm.trange(self.epoch, self.max_epoch,
406 |                                  desc='Train', ncols=80):
407 |             self.epoch = epoch
408 |             self.train_epoch()
409 |             if self.stop_epoch == self.epoch:
410 |                 print('Stop epoch at %d' % self.stop_epoch)
411 |                 break
412 | 
413 |             if (epoch + 1) % 100 == 0:
414 |                 _lr_gen = self.lr_gen * 0.2  # one-shot decay: resets the lr to 0.2 x the base rate
415 |                 for param_group in self.optim_gen.param_groups:
416 |                     param_group['lr'] = _lr_gen
417 |             self.writer.add_scalar('lr_gen', get_lr(self.optim_gen), self.epoch * (len(self.domain_loaderS)))
418 |             if (self.epoch + 1) % self.interval_validate == 0:
419 |                 self.validate()
420 |         self.writer.close()
421 | 
422 | 
423 | 
--------------------------------------------------------------------------------
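A note on the adversarial steps above: both discriminators are driven by the generator's own outputs, namely the sigmoid boundary map and an entropy-style uncertainty map derived from the mask logits. The sketch below isolates that uncertainty computation. It is a minimal illustration assuming only PyTorch; `uncertainty_map` is a hypothetical helper name for this sketch, not a function in the repository, and the tensor shapes are illustrative.

```python
import torch

def uncertainty_map(logits, smooth=1e-7):
    # Entropy-style per-pixel uncertainty as used in train_epoch():
    # only the -p * log(p) term of the binary entropy, computed per channel.
    # `smooth` guards against log(0).
    p = torch.sigmoid(logits)
    return -1.0 * p * torch.log(p + smooth)

# Illustrative shapes: batch of 2, two channels (cup, disc), 16x16 patches.
logits = torch.randn(2, 2, 16, 16)
u = uncertainty_map(logits)
print(u.shape)  # torch.Size([2, 2, 16, 16])
```

Note that this is only the first term of the full binary entropy -p*log(p) - (1-p)*log(1-p); the trainer uses the same truncation, which still peaks where the sigmoid probability is least confident.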
/train_process/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emma-sjwang/BEAL/945cad38a354605b8bca5bc01ae1b65848d605e1/train_process/__init__.py
--------------------------------------------------------------------------------
/utils/Utils.py:
--------------------------------------------------------------------------------
1 | 
2 | from scipy.misc import imsave
3 | import os.path as osp
4 | import numpy as np
5 | import os
6 | import cv2
7 | from skimage import morphology
8 | import scipy
9 | from PIL import Image
10 | 
11 | # from keras.preprocessing import image
12 | from skimage.measure import label, regionprops
13 | from skimage.transform import rotate, resize
14 | from skimage import measure, draw
15 | 
16 | import matplotlib.pyplot as plt
17 | plt.switch_backend('agg')
18 | 
19 | 
20 | from utils.metrics import *
21 | 
22 | 
23 | 
24 | def construct_color_img(prob_per_slice):
25 |     shape = prob_per_slice.shape
26 |     img = np.zeros((shape[0], shape[1], 3), dtype=np.uint8)
27 |     img[:, :, 0] = prob_per_slice * 255
28 |     img[:, :, 1] = prob_per_slice * 255
29 |     img[:, :, 2] = prob_per_slice * 255
30 | 
31 |     im_color = cv2.applyColorMap(img, cv2.COLORMAP_JET)
32 |     return im_color
33 | 
34 | 
35 | def normalize_ent(ent):
36 |     '''
37 |     Normalize ent to the range [0, 1]
38 |     :param ent: entropy map
39 |     :return: normalized entropy map
40 |     '''
41 |     min_val = np.amin(ent)
42 |     return (ent - min_val) / 0.4  # 0.4 slightly exceeds max(-p*log(p)) ~= 0.368, so values stay below 1
43 | 
44 | 
45 | def draw_ent(prediction, save_root, name):
46 |     '''
47 |     Draw the entropy information for each img and save them to the save path
48 |     :param prediction: [2, h, w] numpy
49 |     :param save_root: output folder; images go into its 'cup' and 'disc' sub-folders
50 |     :return: None
51 |     '''
52 |     if not os.path.exists(os.path.join(save_root, 'disc')):
53 |         os.makedirs(os.path.join(save_root, 'disc'))
54 |     if not os.path.exists(os.path.join(save_root, 'cup')):
55 |         os.makedirs(os.path.join(save_root, 'cup'))
56 |     smooth = 1e-8
57 |     cup = prediction[0]
58 |     disc = prediction[1]
59 |     cup_ent = - cup * np.log(cup + smooth)
60 |     disc_ent = - disc * np.log(disc + smooth)
61 |     cup_ent = normalize_ent(cup_ent)
62 |     disc_ent = normalize_ent(disc_ent)
63 |     disc = construct_color_img(disc_ent)
64 |     cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc)
65 |     cup = construct_color_img(cup_ent)
66 |     cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup)
67 | 
68 | 
69 | def draw_mask(prediction, save_root, name):
70 |     '''
71 |     Draw the mask probability for each img and save them to the save path
72 |     :param prediction: [2, h, w] numpy
73 |     :param save_root: output folder; images go into its 'cup' and 'disc' sub-folders
74 |     :return: None
75 |     '''
76 |     if not os.path.exists(os.path.join(save_root, 'disc')):
77 |         os.makedirs(os.path.join(save_root, 'disc'))
78 |     if not os.path.exists(os.path.join(save_root, 'cup')):
79 |         os.makedirs(os.path.join(save_root, 'cup'))
80 |     cup = prediction[0]
81 |     disc = prediction[1]
82 | 
83 |     disc = construct_color_img(disc)
84 |     cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc)
85 |     cup = construct_color_img(cup)
86 |     cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup)
87 | 
88 | def draw_boundary(prediction, save_root, name):
89 |     '''
90 |     Draw the boundary probability for each img and save them to the save path
91 |     :param prediction: [2, h, w] numpy
92 |     :param save_root: output folder; images go into its 'boundary' sub-folder
93 |     :return: None
94 |     '''
95 |     if not
os.path.exists(os.path.join(save_root, 'boundary')): 96 | os.makedirs(os.path.join(save_root, 'boundary')) 97 | boundary = prediction[0] 98 | boundary = construct_color_img(boundary) 99 | cv2.imwrite(os.path.join(save_root, 'boundary', name.split('.')[0]) + '.png', boundary) 100 | 101 | 102 | def get_largest_fillhole(binary): 103 | label_image = label(binary) 104 | regions = regionprops(label_image) 105 | area_list = [] 106 | for region in regions: 107 | area_list.append(region.area) 108 | if area_list: 109 | idx_max = np.argmax(area_list) 110 | binary[label_image != idx_max + 1] = 0 111 | return scipy.ndimage.binary_fill_holes(np.asarray(binary).astype(int)) 112 | 113 | def postprocessing(prediction, threshold=0.75, dataset='G'): 114 | if dataset[0] == 'D': 115 | prediction = prediction.numpy() 116 | prediction_copy = np.copy(prediction) 117 | disc_mask = prediction[1] 118 | cup_mask = prediction[0] 119 | disc_mask = (disc_mask > 0.5) # return binary mask 120 | cup_mask = (cup_mask > 0.1) # return binary mask 121 | disc_mask = disc_mask.astype(np.uint8) 122 | cup_mask = cup_mask.astype(np.uint8) 123 | for i in range(5): 124 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 125 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 126 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 127 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 128 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 129 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 130 | prediction_copy[0] = cup_mask 131 | prediction_copy[1] = disc_mask 132 | return prediction_copy 133 | else: 134 | prediction = prediction.numpy() 135 | prediction = (prediction > threshold) # return binary mask 136 | prediction = prediction.astype(np.uint8) 137 | prediction_copy = np.copy(prediction) 138 | disc_mask = prediction[1] 139 | cup_mask = prediction[0] 140 | for i in range(5): 141 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 142 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 143 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 144 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 145 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 146 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 147 | prediction_copy[0] = cup_mask 148 | prediction_copy[1] = disc_mask 149 | return prediction_copy 150 | 151 | 152 | def joint_val_image(image, prediction, mask): 153 | ratio = 0.5 154 | _pred_cup = np.zeros([mask.shape[-2], mask.shape[-1], 3]) 155 | _pred_disc = np.zeros([mask.shape[-2], mask.shape[-1], 3]) 156 | _mask = np.zeros([mask.shape[-2], mask.shape[-1], 3]) 157 | image = np.transpose(image, (1, 2, 0)) 158 | 159 | _pred_cup[:, :, 0] = prediction[0] 160 | _pred_cup[:, :, 1] = prediction[0] 161 | _pred_cup[:, :, 2] = prediction[0] 162 | _pred_disc[:, :, 0] = prediction[1] 163 | _pred_disc[:, :, 1] = prediction[1] 164 | _pred_disc[:, :, 2] = prediction[1] 165 | _mask[:,:,0] = mask[0] 166 | _mask[:,:,1] = mask[1] 167 | 168 | pred_cup = np.add(ratio * image, (1 - ratio) * _pred_cup) 169 | pred_disc = np.add(ratio * image, (1 - ratio) * _pred_disc) 170 | mask_img = np.add(ratio * image, (1 - ratio) * _mask) 171 | 172 | joint_img = np.concatenate([image, mask_img, pred_cup, pred_disc], axis=1) 173 | return joint_img 174 | 175 | 176 | def save_val_img(path, epoch, 
img): 177 | name = osp.join(path, "visualization", "epoch_%d.png" % epoch) 178 | out = osp.join(path, "visualization") 179 | if not osp.exists(out): 180 | os.makedirs(out) 181 | img_shape = img[0].shape 182 | stack_image = np.zeros([len(img) * img_shape[0], img_shape[1], img_shape[2]]) 183 | for i in range(len(img)): 184 | stack_image[i * img_shape[0] : (i + 1) * img_shape[0], :, : ] = img[i] 185 | imsave(name, stack_image) 186 | 187 | 188 | 189 | 190 | def save_per_img(patch_image, data_save_path, img_name, prob_map, mask_path=None, ext="bmp"): 191 | path1 = os.path.join(data_save_path, 'overlay', img_name.split('.')[0]+'.png') 192 | path0 = os.path.join(data_save_path, 'original_image', img_name.split('.')[0]+'.png') 193 | if not os.path.exists(os.path.dirname(path0)): 194 | os.makedirs(os.path.dirname(path0)) 195 | if not os.path.exists(os.path.dirname(path1)): 196 | os.makedirs(os.path.dirname(path1)) 197 | 198 | disc_map = prob_map[0] 199 | cup_map = prob_map[1] 200 | size = disc_map.shape 201 | disc_map[:, 0] = np.zeros(size[0]) 202 | disc_map[:, size[1] - 1] = np.zeros(size[0]) 203 | disc_map[0, :] = np.zeros(size[1]) 204 | disc_map[size[0] - 1, :] = np.zeros(size[1]) 205 | size = cup_map.shape 206 | cup_map[:, 0] = np.zeros(size[0]) 207 | cup_map[:, size[1] - 1] = np.zeros(size[0]) 208 | cup_map[0, :] = np.zeros(size[1]) 209 | cup_map[size[0] - 1, :] = np.zeros(size[1]) 210 | 211 | disc_mask = (disc_map > 0.75) # return binary mask 212 | cup_mask = (cup_map > 0.75) 213 | disc_mask = disc_mask.astype(np.uint8) 214 | cup_mask = cup_mask.astype(np.uint8) 215 | 216 | for i in range(5): 217 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 218 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 219 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 220 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 221 | disc_mask = get_largest_fillhole(disc_mask) 222 | cup_mask = get_largest_fillhole(cup_mask) 223 | 224 | disc_mask = morphology.binary_dilation(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 225 | cup_mask = morphology.binary_dilation(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 226 | 227 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 228 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 229 | 230 | 231 | contours_disc = measure.find_contours(disc_mask, 0.5) 232 | contours_cup = measure.find_contours(cup_mask, 0.5) 233 | 234 | patch_image2 = patch_image.astype(np.uint8) 235 | patch_image2 = Image.fromarray(patch_image2) 236 | 237 | patch_image2.save(path0) 238 | 239 | for n, contour in enumerate(contours_cup): 240 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 241 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 242 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0] 243 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0] 244 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 245 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0] 246 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0] 247 | 248 | for n, contour in enumerate(contours_disc): 249 | patch_image[contour[:, 0].astype(int), contour[:, 
1].astype(int), :] = [0, 0, 255]
250 |         patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255]
251 |         patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255]
252 |         patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255]
253 |         patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255]
254 |         patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255]
255 |         patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255]
256 | 
257 |     patch_image = patch_image.astype(np.uint8)
258 |     patch_image = Image.fromarray(patch_image)
259 | 
260 |     patch_image.save(path1)
261 | 
262 | def untransform(img, lt):
263 |     img = (img + 1) * 127.5
264 |     lt = lt * 128
265 |     return img, lt
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/emma-sjwang/BEAL/945cad38a354605b8bca5bc01ae1b65848d605e1/utils/__init__.py
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | bce = torch.nn.BCEWithLogitsLoss(reduction='none')
5 | 
6 | def _upscan(f):  # distance-transform helper; unused within this module
7 |     for i, fi in enumerate(f):
8 |         if fi == np.inf: continue
9 |         for j in range(1, i + 1):
10 |             x = fi + j * j
11 |             if f[i - j] < x: break
12 |             f[i - j] = x
13 | 
14 | 
15 | def dice_coefficient_numpy(binary_segmentation, binary_gt_label):
16 |     '''
17 |     Compute the Dice coefficient between two binary segmentations.
18 |     Dice coefficient is defined as here: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
19 |     Input:
20 |         binary_segmentation: binary 2D numpy array representing the region of interest as segmented by the algorithm
21 |         binary_gt_label: binary 2D numpy array representing the region of interest as provided in the database
22 |     Output:
23 |         dice_value: Dice coefficient between the segmentation and the ground truth
24 |     '''
25 | 
26 |     # turn all variables to booleans, just in case
27 |     binary_segmentation = np.asarray(binary_segmentation, dtype=bool)
28 |     binary_gt_label = np.asarray(binary_gt_label, dtype=bool)
29 | 
30 |     # compute the intersection
31 |     intersection = np.logical_and(binary_segmentation, binary_gt_label)
32 | 
33 |     # count the number of True pixels in the binary segmentation
34 |     segmentation_pixels = float(np.sum(binary_segmentation.flatten()))
35 |     # same for the ground truth
36 |     gt_label_pixels = float(np.sum(binary_gt_label.flatten()))
37 |     # same for the intersection
38 |     intersection = float(np.sum(intersection.flatten()))
39 | 
40 |     # compute the Dice coefficient; the +1.0 smoothing avoids division by zero on empty masks
41 |     dice_value = (2 * intersection + 1.0) / (1.0 + segmentation_pixels + gt_label_pixels)
42 | 
43 |     # return it
44 |     return dice_value
45 | 
46 | 
47 | def dice_coeff(pred, target):
48 |     """Hard Dice score for single-channel predictions: a sigmoid is applied to
49 |     pred and the result thresholded at 0.5, so this is an evaluation metric, not a differentiable loss.
50 |     pred: tensor with first dimension as batch
51 |     target: tensor with first dimension as batch
52 |     """
53 | 
54 |     target = target.data.cpu()
55 |     pred = torch.sigmoid(pred)
56 |     pred = pred.data.cpu()
57 |     pred[pred > 0.5] = 1
58 |     pred[pred <= 0.5] = 0
59 | 
60 |     return dice_coefficient_numpy(pred, target)
61 | 
62 | def dice_coeff_2label(pred, target):
63 |     """Hard Dice scores for two-channel (cup, disc) predictions: sigmoid
64 |     probabilities are thresholded at 0.75; returns (cup_dice, disc_dice). Not differentiable.
65 |     pred: tensor with first dimension as batch
66 |     target: tensor with first dimension as batch
67 |     """
68 | 
69 |     target = target.data.cpu()
70 |     pred = torch.sigmoid(pred)
71 |     pred = pred.data.cpu()
72 |     pred[pred > 0.75] = 1
73 |     pred[pred <= 0.75] = 0
74 | 
75 | 
76 |     return dice_coefficient_numpy(pred[:, 0, ...], target[:, 0, ...]), dice_coefficient_numpy(pred[:, 1, ...], target[:, 1, ...])
77 | 
78 | 
79 | def DiceLoss(input, target):
80 |     '''
81 |     Soft Dice loss in tensor format: differentiable, expects probabilities.
82 |     :param input: predicted probabilities, any shape
83 |     :param target: binary ground truth, same shape as input
84 |     :return: 1 - soft Dice coefficient
85 |     '''
86 |     smooth = 1.
87 |     iflat = input.contiguous().view(-1)
88 |     tflat = target.contiguous().view(-1)
89 |     intersection = (iflat * tflat).sum()
90 | 
91 |     return 1 - ((2. * intersection + smooth) /
92 |                 (iflat.sum() + tflat.sum() + smooth))
93 | 
--------------------------------------------------------------------------------
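As a closing reference, here is a small usage sketch for the metric helpers above. It is a minimal illustration with assumed tensor shapes (the [N, 2, H, W] cup/disc layout the trainer uses, which is not a documented API contract); run it from the repository root so `utils.metrics` is importable.

```python
import torch
from utils.metrics import dice_coeff_2label, DiceLoss

# Dummy logits and binary targets in the assumed [N, 2, H, W] layout
# (channel 0 = cup, channel 1 = disc).
pred = torch.randn(1, 2, 64, 64)                    # raw logits, as produced by the generator
target = (torch.rand(1, 2, 64, 64) > 0.5).float()   # binary ground-truth maps

cup_dice, disc_dice = dice_coeff_2label(pred, target)  # hard Dice per structure (0.75 threshold)
loss = DiceLoss(torch.sigmoid(pred), target)           # soft Dice loss on probabilities
print(cup_dice, disc_dice, loss.item())
```

`DiceLoss` stays in differentiable tensor operations and can therefore serve as a training loss, whereas the `dice_coeff*` helpers threshold the predictions and drop to NumPy, so they are suitable for evaluation only.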