├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── custom_transforms.py
│   ├── fundus_dataloader.py
│   ├── label_to_colormap.py
│   ├── mypath.py
│   ├── net.py
│   └── utils.py
├── generate_pseudo_bound.py
├── generate_pseudo_label.py
├── metrics.py
├── mypath.py
├── networks
│   ├── GAN.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── GAN.cpython-37.pyc
│   │   ├── GAN.cpython-38.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── aspp.cpython-37.pyc
│   │   ├── aspp.cpython-38.pyc
│   │   ├── aspp_eval.cpython-37.pyc
│   │   ├── aspp_eval.cpython-38.pyc
│   │   ├── decoder.cpython-37.pyc
│   │   ├── decoder.cpython-38.pyc
│   │   ├── deeplabv3.cpython-37.pyc
│   │   ├── deeplabv3.cpython-38.pyc
│   │   ├── deeplabv3_eval.cpython-37.pyc
│   │   ├── deeplabv3_eval.cpython-38.pyc
│   │   └── utils.cpython-38.pyc
│   ├── aspp.py
│   ├── aspp_eval.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── drn.cpython-37.pyc
│   │   │   ├── drn.cpython-38.pyc
│   │   │   ├── mobilenet.cpython-37.pyc
│   │   │   ├── mobilenet.cpython-38.pyc
│   │   │   ├── resnet.cpython-37.pyc
│   │   │   ├── resnet.cpython-38.pyc
│   │   │   ├── xception.cpython-37.pyc
│   │   │   └── xception.cpython-38.pyc
│   │   ├── drn.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── xception.py
│   ├── decoder.py
│   ├── deeplabv3.py
│   ├── deeplabv3_eval.py
│   ├── layers.py
│   ├── models.py
│   ├── sync_batchnorm
│   │   ├── __pycache__
│   │   │   ├── batchnorm.cpython-37.pyc
│   │   │   ├── batchnorm.cpython-38.pyc
│   │   │   ├── comm.cpython-37.pyc
│   │   │   └── comm.cpython-38.pyc
│   │   ├── batchnorm.py
│   │   └── comm.py
│   └── utils.py
├── train_process
│   ├── Trainer.py
│   ├── __init__.py
│   └── __pycache__
│       ├── Trainer.cpython-37.pyc
│       ├── Trainer.cpython-38.pyc
│       ├── Trainer_fgsm.cpython-38.pyc
│       ├── __init__.cpython-37.pyc
│       └── __init__.cpython-38.pyc
├── train_source.py
├── train_target.py
└── utils
    ├── Utils.py
    ├── __init__.py
    ├── __pycache__
    │   ├── Utils.cpython-37.pyc
    │   ├── Utils.cpython-38.pyc
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   └── metrics.cpython-38.pyc
    └── losses.py

/README.md:
--------------------------------------------------------------------------------
1 | # Robust Source-Free Domain Adaptation for Fundus Image Segmentation
2 | Datasets & code for the WACV 2024 paper 'Robust Source-Free Domain Adaptation for Fundus Image Segmentation' [Paper](https://arxiv.org/abs/2310.16665).
3 | 
4 | In this study, we propose a two-stage training strategy for robust domain adaptation. In the source training stage, we utilize adversarial sample augmentation to enhance the robustness and generalization capability of the source model. In the target training stage, we propose a novel robust pseudo-label and pseudo-boundary (PLPB) method, which effectively utilizes unlabeled target data to generate pseudo labels and pseudo boundaries, enabling model self-adaptation without requiring source data. Extensive experimental results on cross-domain fundus image segmentation confirm the effectiveness and versatility of our method.
5 | 
6 | ## Paper
7 | [Robust Source-Free Domain Adaptation for Fundus Image Segmentation](https://arxiv.org/abs/2310.16665) WACV 2024
8 | ![image](https://github.com/LinGrayy/PLPB/assets/49065934/84cfe4bd-d584-4742-8f4d-311bd2929928)
9 | 
10 | ## PyTorch implementation of our method PLPB
11 | 
12 | ## Installation
13 | * Install PyTorch 0.4.1 and CUDA 9.0
14 | * Clone this repo
15 | ```
16 | git clone https://github.com/LinGrayy/PLPB
17 | cd PLPB
18 | ```
19 | 
20 | ## Train
21 | * Download datasets from [here](https://drive.google.com/file/d/1B7ArHRBjt2Dx29a3A6X_lGhD0vDVr3sy/view).
22 | * Use the provided source domain model `./logs/source/robust-checkpoint.pth.tar` as the robust model,
23 | or specify the data path in `./train_source.py` and then run `./train_source.py` to train one yourself.
24 | * Save the source domain model into the folder `./logs/source`.
25 | 
26 | * Specify the model path and data path in `./generate_pseudo_label.py` and then run `./generate_pseudo_label.py` to obtain the standard pseudo labels.
27 | * Save the generated pseudo labels into the folder `./generate_pseudo/mask`.
28 | * Specify the model path and data path in `./generate_pseudo_bound.py` and then run `./generate_pseudo_bound.py` to obtain the standard pseudo boundaries.
29 | * Save the generated pseudo boundaries into the folder `./generate_pseudo/bound`.
30 | 
31 | * Run `./train_target.py` to start the target domain training process.
32 | 
33 | ## Acknowledgement
34 | The code for source domain training is modified from [BEAL](https://github.com/emma-sjwang/BEAL) and [DPL](https://github.com/cchen-cc/SFDA-DPL).
35 | 
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/dataloaders/fundus_dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from pathlib import Path
7 | from glob import glob
8 | import random
9 | import torch
10 | import torchvision.transforms as transforms
11 | import torchvision.transforms.functional as Ft
12 | import imgaug.augmenters as iaa
13 | import dataloaders.net as net
14 | from dataloaders.utils import *
15 | from dataloaders.mypath import MYPath
16 | import cv2
17 | from torch.utils.data import DataLoader
18 | import torch.nn as nn
19 | 
20 | class FundusSegmentation(Dataset):
21 |     """
22 |     Fundus segmentation dataset
23 |     including 5 domain datasets;
24 |     one is used for testing, the others for training
25 |     """
26 | 
27 |     def __init__(self,
28 |                  base_dir=MYPath.db_root_dir('fundus'),
29 |                  dataset='refuge',
30 |                  split='train',
31 |                  testid=None,
32 |                  transform=None
33 |                  ):
34 |         """
35 |         :param base_dir: path to the fundus dataset directory
36 |         :param split: train/val
37 |         :param transform: transform to apply
38 |         """
39 |         # super().__init__()
40 |         self._base_dir = base_dir
41 |         self.image_list = []
42 |         self.split = split
43 | 
44 |         self.image_pool = []
45 |         self.label_pool = []
46 |         self.img_name_pool = []
47 | 
48 |         self._image_dir = os.path.join(self._base_dir, dataset, split, 'image')
49 |         print(self._image_dir)
50 |         imagelist = glob(self._image_dir + "/*.png")  # png
51 |         for image_path in imagelist:
52 |             gt_path = image_path.replace('image/', 'mask/')
53 |             # gt_path = gt_path.replace('.tif', '-1.tif')  # RIGA
54 |             self.image_list.append({'image': image_path, 'label': gt_path, 'id': testid})
55 | 
56 |         self.transform = transform
57 |         # self._read_img_into_memory()
58 |         # Display stats
59 |         print('Number of images in {}: {:d}'.format(split, len(self.image_list)))
60 | 
61 |     def __len__(self):
62 |         return len(self.image_list)
63 | 
64 |     def __getitem__(self, index):
65 | 
66 |         _img = Image.open(self.image_list[index]['image']).convert('RGB')
67 |         _target = Image.open(self.image_list[index]['label'])
68 |         if _target.mode == 'RGB':
69 |             _target = _target.convert('L')
70 |         _img_name = self.image_list[index]['image'].split('/')[-1]
71 | 
72 |         # _img = self.image_pool[index]
73 |         # _target = self.label_pool[index]
74 |         # _img_name = self.img_name_pool[index]
75 |         anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name, 'image1': _img}
76 | 
77 |         if self.transform is not None:
78 |             anco_sample = self.transform(anco_sample)
79 | 
80 |         return anco_sample
81 | 
82 |     def _read_img_into_memory(self):
83 | 
84 |         img_num = len(self.image_list)
85 |         for index in range(img_num):
86 |             self.image_pool.append(Image.open(self.image_list[index]['image']).convert('RGB'))
87 |             _target = Image.open(self.image_list[index]['label'])
88 |             if _target.mode == 'RGB':
89 |                 _target = _target.convert('L')
90 |             self.label_pool.append(_target)
91 |             _img_name = self.image_list[index]['image'].split('/')[-1]
92 |             self.img_name_pool.append(_img_name)
93 | 
94 | 
95 |     def __str__(self):
96 |         return 'Fundus(split=' + str(self.split) + ')'
97 | 
98 | # load adversarial samples
99 | class FundusSegmentation_pgdtest(Dataset):
100 |     """
101 |     Fundus segmentation dataset
102 |     including 5 domain datasets;
103 |     one is used for testing, the others for training
104 |     """
105 | 
106 |     def __init__(self,
107 |                  base_dir=MYPath.db_root_dir('fundus'),
108 |                  dataset='refuge',
109 |                  split='train',
110 |                  testid=None,
111 |                  transform=None
112 |                  ):
113 |         """
114 |         :param base_dir: path to the fundus dataset directory
115 |         :param split: train/val
116 |         :param transform: transform to apply
117 |         """
118 |         # super().__init__()
119 |         self._base_dir = base_dir
120 |         self.image_list = []
121 |         self.split = split
122 | 
123 |         self.image_pool = []
124 |         self.label_pool = []
125 |         self.img_name_pool = []
126 | 
127 |         self._image_dir = os.path.join(self._base_dir, dataset, split, 'image')
128 |         print(self._image_dir + '/pgd')
129 |         imagelist = glob(self._image_dir + "/*.png")
130 | 
131 |         for image_path in imagelist:
132 |             gt_path = image_path.replace('image', 'mask')
133 |             p1_path = image_path.replace('Domain1/test/ROIs/image/', 'PGD/DPL/Domain1/')
134 |             # p1_path = image_path.replace('Domain1/test/ROIs/', 'PGD/OURS/Domain1/test/')
135 |             # print(p1_path)
136 |             self.image_list.append({'image': p1_path, 'label': gt_path, 'id': testid})
137 | 
138 | 
139 |         self.transform = transform
140 | 
141 |         print('Number of images in {}: {:d}'.format(split, len(self.image_list)))
142 | 
143 |     def __len__(self):
144 |         return len(self.image_list)
145 | 
146 |     def __getitem__(self, index):
147 |         _img = Image.open(self.image_list[index]['image']).convert('RGB')
148 |         _target = Image.open(self.image_list[index]['label'])
149 |         if _target.mode == 'RGB':
150 |             _target = _target.convert('L')
151 |         _img_name = self.image_list[index]['image'].split('/')[-1]
152 | 
153 | 
154 |         anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name}
155 | 
156 |         if self.transform is not None:
157 |             anco_sample = self.transform(anco_sample)
158 | 
159 |         return anco_sample
160 | 
161 | 
162 |     def __str__(self):
163 |         return 'Fundus(split=' + str(self.split) + ')'
164 | 
165 | 
166 | 
167 | if __name__ == '__main__':
168 |     data_dir = '/mnt/data1/llr_data/Fundus'
169 |     dataset = 'Domain3'
170 |     fundus_dataset = FundusSegmentation(base_dir=data_dir, dataset=dataset, split='test/')
171 | 
172 |     domain_loader = DataLoader(fundus_dataset, batch_size=4, shuffle=False, num_workers=0, pin_memory=True)
173 |     for batch_idx, (sample) in enumerate(domain_loader):
174 |         data, img_name = sample['image'], sample['img_name']
175 |         print(img_name)
176 | 
177 | 
178 | 
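For orientation, here is a minimal sketch of how `FundusSegmentation` is typically wired into a `DataLoader`, mirroring the scripts later in this repo; the `tr.*` transforms come from `dataloaders/custom_transforms.py` (not shown in this excerpt), and the data path is the default used by those scripts:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from dataloaders import fundus_dataloader as DL
from dataloaders import custom_transforms as tr

# Test-time transform pipeline used by generate_pseudo_label.py /
# generate_pseudo_bound.py; each transform operates on the whole sample dict.
composed_transforms_test = transforms.Compose([
    tr.Resize(512),
    tr.Normalize_tf1(),
    tr.ToTensor()
])

db_train = DL.FundusSegmentation(base_dir='/mnt/data1/llr_data/Fundus/',
                                 dataset='Domain2', split='train/ROIs',
                                 transform=composed_transforms_test)
train_loader = DataLoader(db_train, batch_size=8, shuffle=False, num_workers=1)

for sample in train_loader:
    data, img_name = sample['image'], sample['img_name']
    print(img_name, data.shape)  # after tr.ToTensor: torch.Size([8, 3, 512, 512])
```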
-------------------------------------------------------------------------------- /dataloaders/label_to_colormap.py: -------------------------------------------------------------------------------- 1 | # Lint as: python2, python3 2 | # Copyright 2018 The TensorFlow Authors All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Visualizes the segmentation results via specified color map. 17 | 18 | Visualizes the semantic segmentation results by the color map 19 | defined by the different datasets. Supported colormaps are: 20 | 21 | * ADE20K (http://groups.csail.mit.edu/vision/datasets/ADE20K/). 22 | 23 | * Cityscapes dataset (https://www.cityscapes-dataset.com). 24 | 25 | * Mapillary Vistas (https://research.mapillary.com). 26 | 27 | * PASCAL VOC 2012 (http://host.robots.ox.ac.uk/pascal/VOC/). 28 | """ 29 | 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | import numpy as np 34 | from six.moves import range 35 | 36 | # Dataset names. 37 | _ADE20K = 'ade20k' 38 | _CITYSCAPES = 'cityscapes' 39 | _MAPILLARY_VISTAS = 'mapillary_vistas' 40 | _PASCAL = 'pascal' 41 | 42 | # Max number of entries in the colormap for each dataset. 43 | _DATASET_MAX_ENTRIES = { 44 | _ADE20K: 151, 45 | _CITYSCAPES: 256, 46 | _MAPILLARY_VISTAS: 66, 47 | _PASCAL: 512, 48 | } 49 | 50 | def create_cityscapes_label_colormap(): 51 | """Creates a label colormap used in CITYSCAPES segmentation benchmark. 52 | 53 | Returns: 54 | A colormap for visualizing segmentation results. 55 | """ 56 | colormap = np.zeros((20, 3), dtype=np.uint8) 57 | colormap[0] = [128, 64, 128] 58 | colormap[1] = [244, 35, 232] 59 | colormap[2] = [70, 70, 70] 60 | colormap[3] = [102, 102, 156] 61 | colormap[4] = [190, 153, 153] 62 | colormap[5] = [153, 153, 153] 63 | colormap[6] = [250, 170, 30] 64 | colormap[7] = [220, 220, 0] 65 | colormap[8] = [107, 142, 35] 66 | colormap[9] = [152, 251, 152] 67 | colormap[10] = [70, 130, 180] 68 | colormap[11] = [220, 20, 60] 69 | colormap[12] = [255, 0, 0] 70 | colormap[13] = [0, 0, 142] 71 | colormap[14] = [0, 0, 70] 72 | colormap[15] = [0, 60, 100] 73 | colormap[16] = [0, 80, 100] 74 | colormap[17] = [0, 0, 230] 75 | colormap[18] = [119, 11, 32] 76 | colormap[19] = [0,0,0] # void class 77 | return colormap 78 | 79 | 80 | def get_cityscapes_name(): 81 | return _CITYSCAPES 82 | 83 | 84 | def bit_get(val, idx): 85 | """Gets the bit value. 86 | 87 | Args: 88 | val: Input value, int or numpy int array. 89 | idx: Which bit of the input val. 90 | 91 | Returns: 92 | The "idx"-th bit of input val. 93 | """ 94 | return (val >> idx) & 1 95 | 96 | 97 | def create_label_colormap(dataset=_PASCAL): 98 | return create_cityscapes_label_colormap() 99 | 100 | 101 | def label_to_color_image(label, dataset=_PASCAL): 102 | """Adds color defined by the dataset colormap to the label. 
103 | 104 | Args: 105 | label: A 2D array with integer type, storing the segmentation label. 106 | dataset: The colormap used in the dataset. 107 | 108 | Returns: 109 | result: A 2D array with floating type. The element of the array 110 | is the color indexed by the corresponding element in the input label 111 | to the dataset color map. 112 | 113 | Raises: 114 | ValueError: If label is not of rank 2 or its value is larger than color 115 | map maximum entry. 116 | """ 117 | if label.ndim != 2: 118 | raise ValueError('Expect 2-D input label. Got {}'.format(label.shape)) 119 | 120 | if np.max(label) >= _DATASET_MAX_ENTRIES[dataset]: 121 | raise ValueError( 122 | 'label value too large: {} >= {}.'.format( 123 | np.max(label), _DATASET_MAX_ENTRIES[dataset])) 124 | 125 | colormap = create_label_colormap(dataset) 126 | return colormap[label] 127 | 128 | 129 | def get_dataset_colormap_max_entries(dataset): 130 | return _DATASET_MAX_ENTRIES[dataset] 131 | -------------------------------------------------------------------------------- /dataloaders/mypath.py: -------------------------------------------------------------------------------- 1 | class MYPath(object): 2 | @staticmethod 3 | def db_root_dir(database): 4 | if database == 'fundus': 5 | return '../../../../data/disc_cup_split/' # foler that contains leftImg8bit/ 6 | if database =='cell': 7 | return '../../../../data/' 8 | else: 9 | print('Database {} not available.'.format(database)) 10 | raise NotImplementedError 11 | -------------------------------------------------------------------------------- /dataloaders/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from dataloaders.utils import adaptive_instance_normalization as adain# 4 | from dataloaders.utils import calc_mean_std 5 | 6 | decoder = nn.Sequential( 7 | nn.ReflectionPad2d((1, 1, 1, 1)), 8 | nn.Conv2d(512, 256, (3, 3)), 9 | nn.ReLU(), 10 | nn.Upsample(scale_factor=2, mode='nearest'), 11 | nn.ReflectionPad2d((1, 1, 1, 1)), 12 | nn.Conv2d(256, 256, (3, 3)), 13 | nn.ReLU(), 14 | nn.ReflectionPad2d((1, 1, 1, 1)), 15 | nn.Conv2d(256, 256, (3, 3)), 16 | nn.ReLU(), 17 | nn.ReflectionPad2d((1, 1, 1, 1)), 18 | nn.Conv2d(256, 256, (3, 3)), 19 | nn.ReLU(), 20 | nn.ReflectionPad2d((1, 1, 1, 1)), 21 | nn.Conv2d(256, 128, (3, 3)), 22 | nn.ReLU(), 23 | nn.Upsample(scale_factor=2, mode='nearest'), 24 | nn.ReflectionPad2d((1, 1, 1, 1)), 25 | nn.Conv2d(128, 128, (3, 3)), 26 | nn.ReLU(), 27 | nn.ReflectionPad2d((1, 1, 1, 1)), 28 | nn.Conv2d(128, 64, (3, 3)), 29 | nn.ReLU(), 30 | nn.Upsample(scale_factor=2, mode='nearest'), 31 | nn.ReflectionPad2d((1, 1, 1, 1)), 32 | nn.Conv2d(64, 64, (3, 3)), 33 | nn.ReLU(), 34 | nn.ReflectionPad2d((1, 1, 1, 1)), 35 | nn.Conv2d(64, 3, (3, 3)), 36 | ) 37 | 38 | vgg = nn.Sequential( 39 | nn.Conv2d(3, 3, (1, 1)), 40 | nn.ReflectionPad2d((1, 1, 1, 1)), 41 | nn.Conv2d(3, 64, (3, 3)), 42 | nn.ReLU(), # relu1-1 43 | nn.ReflectionPad2d((1, 1, 1, 1)), 44 | nn.Conv2d(64, 64, (3, 3)), 45 | nn.ReLU(), # relu1-2 46 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 47 | nn.ReflectionPad2d((1, 1, 1, 1)), 48 | nn.Conv2d(64, 128, (3, 3)), 49 | nn.ReLU(), # relu2-1 50 | nn.ReflectionPad2d((1, 1, 1, 1)), 51 | nn.Conv2d(128, 128, (3, 3)), 52 | nn.ReLU(), # relu2-2 53 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 54 | nn.ReflectionPad2d((1, 1, 1, 1)), 55 | nn.Conv2d(128, 256, (3, 3)), 56 | nn.ReLU(), # relu3-1 57 | nn.ReflectionPad2d((1, 1, 1, 1)), 58 | nn.Conv2d(256, 256, (3, 3)), 59 | nn.ReLU(), # 
relu3-2 60 | nn.ReflectionPad2d((1, 1, 1, 1)), 61 | nn.Conv2d(256, 256, (3, 3)), 62 | nn.ReLU(), # relu3-3 63 | nn.ReflectionPad2d((1, 1, 1, 1)), 64 | nn.Conv2d(256, 256, (3, 3)), 65 | nn.ReLU(), # relu3-4 66 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 67 | nn.ReflectionPad2d((1, 1, 1, 1)), 68 | nn.Conv2d(256, 512, (3, 3)), 69 | nn.ReLU(), # relu4-1, this is the last layer used 70 | nn.ReflectionPad2d((1, 1, 1, 1)), 71 | nn.Conv2d(512, 512, (3, 3)), 72 | nn.ReLU(), # relu4-2 73 | nn.ReflectionPad2d((1, 1, 1, 1)), 74 | nn.Conv2d(512, 512, (3, 3)), 75 | nn.ReLU(), # relu4-3 76 | nn.ReflectionPad2d((1, 1, 1, 1)), 77 | nn.Conv2d(512, 512, (3, 3)), 78 | nn.ReLU(), # relu4-4 79 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 80 | nn.ReflectionPad2d((1, 1, 1, 1)), 81 | nn.Conv2d(512, 512, (3, 3)), 82 | nn.ReLU(), # relu5-1 83 | nn.ReflectionPad2d((1, 1, 1, 1)), 84 | nn.Conv2d(512, 512, (3, 3)), 85 | nn.ReLU(), # relu5-2 86 | nn.ReflectionPad2d((1, 1, 1, 1)), 87 | nn.Conv2d(512, 512, (3, 3)), 88 | nn.ReLU(), # relu5-3 89 | nn.ReflectionPad2d((1, 1, 1, 1)), 90 | nn.Conv2d(512, 512, (3, 3)), 91 | nn.ReLU() # relu5-4 92 | ) 93 | 94 | 95 | class Net(nn.Module): 96 | def __init__(self, encoder, decoder): 97 | super(Net, self).__init__() 98 | enc_layers = list(encoder.children()) 99 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 100 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 101 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 102 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 103 | self.decoder = decoder 104 | self.mse_loss = nn.MSELoss() 105 | 106 | # fix the encoder 107 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']: 108 | for param in getattr(self, name).parameters(): 109 | param.requires_grad = False 110 | 111 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 112 | def encode_with_intermediate(self, input): 113 | results = [input] 114 | for i in range(4): 115 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 116 | results.append(func(results[-1])) 117 | return results[1:] 118 | 119 | # extract relu4_1 from input image 120 | def encode(self, input): 121 | for i in range(4): 122 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input) 123 | return input 124 | 125 | def calc_content_loss(self, input, target): 126 | assert (input.size() == target.size()) 127 | assert (target.requires_grad is False) 128 | return self.mse_loss(input, target) 129 | 130 | def calc_style_loss(self, input, target): 131 | assert (input.size() == target.size()) 132 | assert (target.requires_grad is False) 133 | input_mean, input_std = calc_mean_std(input) 134 | target_mean, target_std = calc_mean_std(target) 135 | return self.mse_loss(input_mean, target_mean) + \ 136 | self.mse_loss(input_std, target_std) 137 | 138 | def forward(self, content, style, alpha=1.0): 139 | assert 0 <= alpha <= 1 140 | style_feats = self.encode_with_intermediate(style) 141 | content_feat = self.encode(content) 142 | t = adain(content_feat, style_feats[-1]) 143 | t = alpha * t + (1 - alpha) * content_feat 144 | 145 | g_t = self.decoder(t) 146 | g_t_feats = self.encode_with_intermediate(g_t) 147 | 148 | loss_c = self.calc_content_loss(g_t_feats[-1], t) 149 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 150 | for i in range(1, 4): 151 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 152 | return loss_c, loss_s 153 | -------------------------------------------------------------------------------- 
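For context, the `Net` module above is only used to train the AdaIN decoder; at inference time the same pieces are combined by the `style_transfer` helper in `dataloaders/utils.py`. A minimal sketch of that path, assuming pretrained AdaIN weights are available locally (the checkpoint paths below are hypothetical placeholders, not files shipped with this repo):

```python
import torch
import dataloaders.net as net
from dataloaders.utils import style_transfer

vgg = net.vgg
decoder = net.decoder
# Hypothetical checkpoint paths -- substitute your own AdaIN weights.
vgg.load_state_dict(torch.load('models/vgg_normalised.pth'))
# Truncate the encoder at relu4_1 (first 31 children), matching Net.enc_1..enc_4,
# so its 512-channel output matches the decoder's first Conv2d(512, 256, ...).
vgg = torch.nn.Sequential(*list(vgg.children())[:31])
decoder.load_state_dict(torch.load('models/decoder.pth'))
vgg.eval()
decoder.eval()

with torch.no_grad():
    content = torch.rand(1, 3, 512, 512)  # stand-in for a fundus image batch
    style = torch.rand(1, 3, 512, 512)    # stand-in for a style image
    # alpha trades off stylization strength against content preservation
    stylized = style_transfer(vgg, decoder, content, style, alpha=0.5)
print(stylized.shape)  # torch.Size([1, 3, 512, 512])
```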
/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import matplotlib.pyplot as plt 4 | from torch.autograd import Variable 5 | from mpl_toolkits.axes_grid1 import ImageGrid 6 | from torchvision.transforms import Compose, ToTensor 7 | from torchvision import transforms 8 | import torch.nn.functional as F 9 | import sys 10 | import numpy as np 11 | import torchvision 12 | from dataloaders.label_to_colormap import create_cityscapes_label_colormap 13 | from sklearn.metrics import confusion_matrix 14 | from PIL import Image 15 | 16 | def bytescale(data, cmin=None, cmax=None, high=255, low=0): 17 | """ 18 | Byte scales an array (image). 19 | Byte scaling means converting the input image to uint8 dtype and scaling 20 | the range to ``(low, high)`` (default 0-255). 21 | If the input image already has dtype uint8, no scaling is done. 22 | This function is only available if Python Imaging Library (PIL) is installed. 23 | Parameters 24 | ---------- 25 | data : ndarray 26 | PIL image data array. 27 | cmin : scalar, optional 28 | Bias scaling of small values. Default is ``data.min()``. 29 | cmax : scalar, optional 30 | Bias scaling of large values. Default is ``data.max()``. 31 | high : scalar, optional 32 | Scale max value to `high`. Default is 255. 33 | low : scalar, optional 34 | Scale min value to `low`. Default is 0. 35 | Returns 36 | ------- 37 | img_array : uint8 ndarray 38 | The byte-scaled array. 39 | Examples 40 | -------- 41 | """ 42 | if data.dtype == np.uint8: 43 | return data 44 | 45 | if high > 255: 46 | raise ValueError("`high` should be less than or equal to 255.") 47 | if low < 0: 48 | raise ValueError("`low` should be greater than or equal to 0.") 49 | if high < low: 50 | raise ValueError("`high` should be greater than or equal to `low`.") 51 | 52 | if cmin is None: 53 | cmin = data.min() 54 | if cmax is None: 55 | cmax = data.max() 56 | 57 | cscale = cmax - cmin 58 | if cscale < 0: 59 | raise ValueError("`cmax` should be larger than `cmin`.") 60 | elif cscale == 0: 61 | cscale = 1 62 | 63 | scale = float(high - low) / cscale 64 | bytedata = (data - cmin) * scale + low 65 | return (bytedata.clip(low, high) + 0.5).astype(np.uint8) 66 | 67 | 68 | def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None, 69 | mode=None, channel_axis=None): 70 | """Takes a numpy array and returns a PIL image. 71 | This function is only available if Python Imaging Library (PIL) is installed. 72 | The mode of the PIL image depends on the array shape and the `pal` and 73 | `mode` keywords. 74 | For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values 75 | (from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode 76 | is given as 'F' or 'I' in which case a float and/or integer array is made. 77 | .. warning:: 78 | This function uses `bytescale` under the hood to rescale images to use 79 | the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``. 80 | It will also cast data for 2-D images to ``uint32`` for ``mode=None`` 81 | (which is the default). 82 | Notes 83 | ----- 84 | For 3-D arrays, the `channel_axis` argument tells which dimension of the 85 | array holds the channel data. 86 | For 3-D arrays if one of the dimensions is 3, the mode is 'RGB' 87 | by default or 'YCbCr' if selected. 88 | The numpy array must be either 2 dimensional or 3 dimensional. 
89 | """ 90 | data = np.asarray(arr) 91 | if np.iscomplexobj(data): 92 | raise ValueError("Cannot convert a complex-valued array.") 93 | shape = list(data.shape) 94 | valid = len(shape) == 2 or ((len(shape) == 3) and 95 | ((3 in shape) or (4 in shape))) 96 | if not valid: 97 | raise ValueError("'arr' does not have a suitable array shape for " 98 | "any mode.") 99 | if len(shape) == 2: 100 | shape = (shape[1], shape[0]) # columns show up first 101 | if mode == 'F': 102 | data32 = data.astype(np.float32) 103 | image = Image.frombytes(mode, shape, data32.tostring()) 104 | return image 105 | if mode in [None, 'L', 'P']: 106 | bytedata = bytescale(data, high=high, low=low, 107 | cmin=cmin, cmax=cmax) 108 | image = Image.frombytes('L', shape, bytedata.tostring()) 109 | if pal is not None: 110 | image.putpalette(np.asarray(pal, dtype=np.uint8).tostring()) 111 | # Becomes a mode='P' automagically. 112 | elif mode == 'P': # default gray-scale 113 | pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] * 114 | np.ones((3,), dtype=np.uint8)[np.newaxis, :]) 115 | image.putpalette(np.asarray(pal, dtype=np.uint8).tostring()) 116 | return image 117 | if mode == '1': # high input gives threshold for 1 118 | bytedata = (data > high) 119 | image = Image.frombytes('1', shape, bytedata.tostring()) 120 | return image 121 | if cmin is None: 122 | cmin = np.amin(np.ravel(data)) 123 | if cmax is None: 124 | cmax = np.amax(np.ravel(data)) 125 | data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low 126 | if mode == 'I': 127 | data32 = data.astype(np.uint32) 128 | image = Image.frombytes(mode, shape, data32.tostring()) 129 | else: 130 | raise ValueError(_errstr) 131 | return image 132 | 133 | # if here then 3-d array with a 3 or a 4 in the shape length. 134 | # Check for 3 in datacube shape --- 'RGB' or 'YCbCr' 135 | if channel_axis is None: 136 | if (3 in shape): 137 | ca = np.flatnonzero(np.asarray(shape) == 3)[0] 138 | else: 139 | ca = np.flatnonzero(np.asarray(shape) == 4) 140 | if len(ca): 141 | ca = ca[0] 142 | else: 143 | raise ValueError("Could not find channel dimension.") 144 | else: 145 | ca = channel_axis 146 | 147 | numch = shape[ca] 148 | if numch not in [3, 4]: 149 | raise ValueError("Channel axis dimension is not valid.") 150 | 151 | bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax) 152 | if ca == 2: 153 | strdata = bytedata.tostring() 154 | shape = (shape[1], shape[0]) 155 | elif ca == 1: 156 | strdata = np.transpose(bytedata, (0, 2, 1)).tostring() 157 | shape = (shape[2], shape[0]) 158 | elif ca == 0: 159 | strdata = np.transpose(bytedata, (1, 2, 0)).tostring() 160 | shape = (shape[2], shape[1]) 161 | if mode is None: 162 | if numch == 3: 163 | mode = 'RGB' 164 | else: 165 | mode = 'RGBA' 166 | 167 | if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']: 168 | raise ValueError(_errstr) 169 | 170 | if mode in ['RGB', 'YCbCr']: 171 | if numch != 3: 172 | raise ValueError("Invalid array shape for mode.") 173 | if mode in ['RGBA', 'CMYK']: 174 | if numch != 4: 175 | raise ValueError("Invalid array shape for mode.") 176 | 177 | # Here we know data and mode is correct 178 | image = Image.frombytes(mode, shape, strdata) 179 | return image 180 | 181 | def label2Color(label): 182 | label = np.asarray(label, dtype=np.uint8) 183 | colormap = create_cityscapes_label_colormap() 184 | image = np.zeros((label.shape[0],label.shape[1],3), dtype=np.uint8) 185 | for i in range(label.shape[0]): 186 | for j in range(label.shape[1]): 187 | if(label[i,j] > 19): 188 | label[i,j] = 19 189 | image[i,j] 
= colormap[label[i,j]] 190 | return image 191 | 192 | def segMap3(rgb, img_label, pred): 193 | # plotting for 0th batch only 194 | rgb, img_label, pred = rgb[0], img_label[0], pred[0] 195 | rgb = rgb.permute(1,2,0) 196 | 197 | pred = F.softmax(pred, dim=0) 198 | pred = torch.argmax(pred, dim=0) 199 | 200 | img_label = img_label.cpu() 201 | rgb = rgb.cpu() 202 | pred = pred.cpu() 203 | 204 | img_label = label2Color(img_label) 205 | pred = label2Color(pred) 206 | rgb = np.asarray(rgb, dtype=np.uint8) 207 | IMG_MEAN = np.array((104, 116, 122), dtype=np.uint8) 208 | rgb += IMG_MEAN 209 | rgb = rgb[:, :, : : -1] 210 | 211 | grid = torch.from_numpy(np.asarray([rgb, img_label, pred])) 212 | grid = grid.permute(0,3,1,2) 213 | grid = torchvision.utils.make_grid(grid) 214 | 215 | return grid 216 | 217 | 218 | ##### AdaIN helper functions 219 | def calc_mean_std(feat, eps=1e-5): 220 | # eps is a small value added to the variance to avoid divide-by-zero. 221 | size = feat.size() 222 | assert (len(size) == 4) 223 | N, C = size[:2] 224 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 225 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 226 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 227 | return feat_mean, feat_std 228 | 229 | 230 | def adaptive_instance_normalization(content_feat, style_feat): 231 | assert (content_feat.size()[:2] == style_feat.size()[:2]) 232 | size = content_feat.size() 233 | style_mean, style_std = calc_mean_std(style_feat) 234 | content_mean, content_std = calc_mean_std(content_feat) 235 | 236 | normalized_feat = (content_feat - content_mean.expand( 237 | size)) / content_std.expand(size) 238 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 239 | 240 | 241 | def _calc_feat_flatten_mean_std(feat): 242 | # takes 3D feat (C, H, W), return mean and std of array within channels 243 | assert (feat.size()[0] == 3) 244 | assert (isinstance(feat, torch.FloatTensor)) 245 | feat_flatten = feat.view(3, -1) 246 | mean = feat_flatten.mean(dim=-1, keepdim=True) 247 | std = feat_flatten.std(dim=-1, keepdim=True) 248 | return feat_flatten, mean, std 249 | 250 | 251 | def _mat_sqrt(x): 252 | U, D, V = torch.svd(x) 253 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 254 | 255 | 256 | def test_transform(size, crop): 257 | transform_list = [] 258 | if size != 0: 259 | transform_list.append(transforms.Resize(size)) 260 | if crop: 261 | transform_list.append(transforms.CenterCrop(size)) 262 | transform_list.append(transforms.ToTensor()) 263 | transform = transforms.Compose(transform_list) 264 | return transform 265 | 266 | # only for single style image and single content image 267 | def style_transfer(vgg, decoder, content, style, alpha=1.0): 268 | assert (0.0 <= alpha <= 1.0) 269 | content_f = vgg(content) 270 | style_f = vgg(style) 271 | feat = adaptive_instance_normalization(content_f, style_f) 272 | feat = feat * alpha + content_f * (1 - alpha) 273 | return decoder(feat) 274 | -------------------------------------------------------------------------------- /generate_pseudo_bound.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | 4 | import argparse 5 | import os 6 | import os.path as osp 7 | import torch.nn.functional as F 8 | 9 | import matplotlib 10 | matplotlib.use('TkAgg') 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | from torch.autograd import Variable 15 | import tqdm 16 | from dataloaders import fundus_dataloader as DL 17 | from torch.utils.data import 
DataLoader 18 | from dataloaders import custom_transforms as tr 19 | from torchvision import transforms 20 | 21 | from matplotlib.pyplot import imsave 22 | from utils.Utils import * 23 | from metrics import * 24 | from datetime import datetime 25 | import pytz 26 | from networks.deeplabv3 import * 27 | import cv2 28 | import torch.backends.cudnn as cudnn 29 | import random 30 | 31 | bceloss = torch.nn.BCELoss() 32 | seed = 3377 33 | savefig = False 34 | get_hd = False 35 | if True: 36 | cudnn.benchmark = False 37 | cudnn.deterministic = True 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | torch.manual_seed(seed) 41 | torch.cuda.manual_seed(seed) 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--model-file', type=str, default='./logs/source/source_model.pth.tar') 47 | parser.add_argument('--dataset', type=str, default='Domain2') 48 | parser.add_argument('--batchsize', type=int, default=8) 49 | parser.add_argument('--source', type=str, default='Domain3') 50 | parser.add_argument('-g', '--gpu', type=int, default=0) 51 | parser.add_argument('--data-dir', default='/mnt/data1/llr_data/Fundus/') 52 | parser.add_argument('--out-stride',type=int,default=16) 53 | parser.add_argument('--save-root-ent',type=str,default='./results/ent/') 54 | parser.add_argument('--save-root-mask',type=str,default='./results/mask/') 55 | parser.add_argument('--sync-bn',type=bool,default=True) 56 | parser.add_argument('--freeze-bn',type=bool,default=False) 57 | parser.add_argument('--test-prediction-save-path', type=str,default='./results/baseline/') 58 | 59 | args = parser.parse_args() 60 | 61 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 62 | model_file = args.model_file 63 | 64 | # 1. dataset 65 | composed_transforms_test = transforms.Compose([ 66 | tr.Resize(512), 67 | tr.Normalize_tf1(), 68 | tr.ToTensor() 69 | ]) 70 | db_train = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='train/ROIs', transform=composed_transforms_test) 71 | db_test = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='test/ROIs', transform=composed_transforms_test) 72 | db_source = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.source, split='train/ROIs', transform=composed_transforms_test) 73 | 74 | train_loader = DataLoader(db_train, batch_size=args.batchsize, shuffle=False, num_workers=1) 75 | test_loader = DataLoader(db_test, batch_size=args.batchsize, shuffle=False, num_workers=1) 76 | source_loader = DataLoader(db_source, batch_size=args.batchsize, shuffle=False, num_workers=1) 77 | 78 | # 2. 
model
79 |     model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn)
80 | 
81 |     if torch.cuda.is_available():
82 |         model = model.cuda()
83 |     print('==> Loading %s model file: %s' %
84 |           (model.__class__.__name__, model_file))
85 |     checkpoint = torch.load(model_file)
86 | 
87 |     model.load_state_dict(checkpoint['model_state_dict'])
88 |     # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_file).items()})
89 | 
90 |     model.train()
91 | 
92 |     pseudo_bound_dic = {}
93 | 
94 |     with torch.no_grad():
95 |         for batch_idx, (sample) in tqdm.tqdm(enumerate(train_loader),
96 |                                              total=len(train_loader),
97 |                                              ncols=80, leave=False):
98 |             data, target, img_name = sample['image'], sample['map'], sample['img_name']
99 |             if torch.cuda.is_available():
100 |                 data, target = data.cuda(), target.cuda()
101 |             data, target = Variable(data), Variable(target)
102 | 
103 |             preds = torch.zeros([10, data.shape[0], 2, data.shape[2], data.shape[3]]).cuda()
104 |             features = torch.zeros([10, data.shape[0], 305, 128, 128]).cuda()
105 |             boundary = torch.zeros([10, data.shape[0], 1, data.shape[2], data.shape[3]]).cuda()
106 |             for i in range(10):
107 |                 with torch.no_grad():
108 |                     preds[i,...], boundary[i,...], features[i,...] = model(data)
109 |             preds1 = torch.sigmoid(preds)
110 |             preds = torch.sigmoid(preds/2.0)
111 |             prediction = torch.mean(preds1, dim=0)
112 |             pseudo_label = prediction.clone()
113 | 
114 |             b1 = torch.sigmoid(boundary)
115 |             b = torch.mean(b1, dim=0)
116 |             pseudo_bound = b.clone()
117 | 
118 |             feature = torch.mean(features, dim=0)
119 |             pseudo_bound = pseudo_bound.detach().cpu().numpy()
120 |             for i in range(prediction.shape[0]):
121 |                 pseudo_bound_dic[img_name[i]] = pseudo_bound[i]
122 | 
123 | 
124 |     if args.dataset == "Domain1":  # pseudolabel_D1
125 |         np.savez('./results/bound/r-bound_D1', pseudo_bound_dic)
126 | 
127 |     elif args.dataset == "Domain2":
128 |         np.savez('./results/bound/r-bound_D2', pseudo_bound_dic)
129 |     elif args.dataset == "RIGA":
130 |         np.savez('/mnt/data1/llr_data/results/bound/test-bound_D6', pseudo_bound_dic)
131 | 
132 | 
133 | 
134 | 
--------------------------------------------------------------------------------
/generate_pseudo_label.py:
--------------------------------------------------------------------------------
1 | 
2 | #!/usr/bin/env python
3 | 
4 | import argparse
5 | import os
6 | import os.path as osp
7 | import torch.nn.functional as F
8 | 
9 | import matplotlib
10 | matplotlib.use('TkAgg')
11 | import matplotlib.pyplot as plt
12 | 
13 | import torch
14 | from torch.autograd import Variable
15 | import tqdm
16 | from dataloaders import fundus_dataloader as DL
17 | from torch.utils.data import DataLoader
18 | from dataloaders import custom_transforms as tr
19 | from torchvision import transforms
20 | 
21 | from matplotlib.pyplot import imsave
22 | from utils.Utils import *
23 | from metrics import *
24 | from datetime import datetime
25 | import pytz
26 | from networks.deeplabv3 import *
27 | import cv2
28 | import torch.backends.cudnn as cudnn
29 | import random
30 | 
31 | bceloss = torch.nn.BCELoss()
32 | seed = 3377
33 | savefig = False
34 | get_hd = False
35 | if True:
36 |     cudnn.benchmark = False
37 |     cudnn.deterministic = True
38 |     random.seed(seed)
39 |     np.random.seed(seed)
40 |     torch.manual_seed(seed)
41 |     torch.cuda.manual_seed(seed)
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     parser = argparse.ArgumentParser()
46 |     parser.add_argument('--model-file', type=str, default='./logs/Domain2/1/checkpoint_200.pth.tar')
47 |     parser.add_argument('--dataset', type=str, default='Domain2')
48 |     parser.add_argument('--batchsize', type=int, default=8)
49 |     parser.add_argument('--source', type=str, default='Domain3')
50 |     parser.add_argument('-g', '--gpu', type=int, default=0)
51 |     parser.add_argument('--data-dir', default='/mnt/data1/llr_data/Fundus/')
52 |     parser.add_argument('--out-stride', type=int, default=16)
53 |     parser.add_argument('--save-root-ent', type=str, default='./results/ent/')
54 |     parser.add_argument('--save-root-mask', type=str, default='./results/mask/')
55 |     parser.add_argument('--sync-bn', type=bool, default=True)
56 |     parser.add_argument('--freeze-bn', type=bool, default=False)
57 |     parser.add_argument('--test-prediction-save-path', type=str, default='./results/baseline/')
58 | 
59 |     args = parser.parse_args()
60 | 
61 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
62 |     model_file = args.model_file
63 | 
64 |     # 1. dataset
65 |     composed_transforms_test = transforms.Compose([
66 |         tr.Resize(512),
67 |         tr.Normalize_tf1(),
68 |         tr.ToTensor()
69 |     ])
70 |     db_train = DL.FundusSegmentationRIGA(base_dir=args.data_dir, dataset=args.dataset, split='train/', transform=composed_transforms_test)
71 |     db_test = DL.FundusSegmentationRIGA(base_dir=args.data_dir, dataset=args.dataset, split='test/', transform=composed_transforms_test)
72 |     db_source = DL.FundusSegmentationRIGA(base_dir=args.data_dir, dataset=args.source, split='train/ROIs', transform=composed_transforms_test)
73 | 
74 |     train_loader = DataLoader(db_train, batch_size=args.batchsize, shuffle=False, num_workers=1, drop_last=True)
75 |     test_loader = DataLoader(db_test, batch_size=args.batchsize, shuffle=False, num_workers=1)
76 |     source_loader = DataLoader(db_source, batch_size=args.batchsize, shuffle=False, num_workers=1)
77 | 
78 |     # 2. model
79 |     model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn)
80 | 
81 |     if torch.cuda.is_available():
82 |         model = model.cuda()
83 |     print('==> Loading %s model file: %s' %
84 |           (model.__class__.__name__, model_file))
85 |     checkpoint = torch.load(model_file)
86 | 
87 |     model.load_state_dict(checkpoint['model_state_dict'])
88 |     # model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_file).items()})
89 | 
90 |     model.train()
91 | 
92 |     pseudo_label_dic = {}
93 |     uncertain_dic = {}
94 |     proto_pseudo_dic = {}
95 |     distance_0_obj_dic = {}
96 |     distance_0_bck_dic = {}
97 |     distance_1_bck_dic = {}
98 |     distance_1_obj_dic = {}
99 |     centroid_0_obj_dic = {}
100 |     centroid_0_bck_dic = {}
101 |     centroid_1_obj_dic = {}
102 |     centroid_1_bck_dic = {}
103 |     # fundus
104 |     with torch.no_grad():
105 |         for batch_idx, (sample) in tqdm.tqdm(enumerate(test_loader),
106 |                                              total=len(test_loader),
107 |                                              ncols=80, leave=False):
108 |             data, target, img_name = sample['image'], sample['map'], sample['img_name']
109 |             if torch.cuda.is_available():
110 |                 data, target = data.cuda(), target.cuda()
111 |             data, target = Variable(data), Variable(target)
112 | 
113 |             preds = torch.zeros([10, data.shape[0], 2, data.shape[2], data.shape[3]]).cuda()
114 |             features = torch.zeros([10, data.shape[0], 305, 128, 128]).cuda()
115 |             for i in range(10):
116 |                 with torch.no_grad():
117 |                     preds[i,...], _, features[i,...] 
= model(data) 118 | preds1 = torch.sigmoid(preds) 119 | preds = torch.sigmoid(preds/2.0) 120 | std_map = torch.std(preds,dim=0) 121 | 122 | prediction=torch.mean(preds1,dim=0) 123 | 124 | pseudo_label = prediction.clone() 125 | pseudo_label[pseudo_label > 0.75] = 1.0; pseudo_label[pseudo_label <= 0.75] = 0.0 126 | 127 | feature = torch.mean(features,dim=0) 128 | 129 | target_0_obj = F.interpolate(pseudo_label[:,0:1,...], size=feature.size()[2:], mode='nearest') 130 | target_1_obj = F.interpolate(pseudo_label[:, 1:, ...], size=feature.size()[2:], mode='nearest') 131 | prediction_small = F.interpolate(prediction, size=feature.size()[2:], mode='bilinear', align_corners=True) 132 | std_map_small = F.interpolate(std_map, size=feature.size()[2:], mode='bilinear', align_corners=True) 133 | target_0_bck = 1.0 - target_0_obj;target_1_bck = 1.0 - target_1_obj 134 | 135 | mask_0_obj = torch.zeros([std_map_small.shape[0], 1, std_map_small.shape[2], std_map_small.shape[3]]).cuda() 136 | mask_0_bck = torch.zeros([std_map_small.shape[0], 1, std_map_small.shape[2], std_map_small.shape[3]]).cuda() 137 | mask_1_obj = torch.zeros([std_map_small.shape[0], 1, std_map_small.shape[2], std_map_small.shape[3]]).cuda() 138 | mask_1_bck = torch.zeros([std_map_small.shape[0], 1, std_map_small.shape[2], std_map_small.shape[3]]).cuda() 139 | mask_0_obj[std_map_small[:, 0:1, ...] < 0.05] = 1.0 140 | mask_0_bck[std_map_small[:, 0:1, ...] < 0.05] = 1.0 141 | mask_1_obj[std_map_small[:, 1:, ...] < 0.05] = 1.0 142 | mask_1_bck[std_map_small[:, 1:, ...] < 0.05] = 1.0 143 | mask_0 = mask_0_obj + mask_0_bck 144 | mask_1 = mask_1_obj + mask_1_bck 145 | mask = torch.cat((mask_0, mask_1), dim=1) 146 | 147 | feature_0_obj = feature * target_0_obj*mask_0_obj;feature_1_obj = feature * target_1_obj*mask_1_obj 148 | feature_0_bck = feature * target_0_bck*mask_0_bck;feature_1_bck = feature * target_1_bck*mask_1_bck 149 | 150 | centroid_0_obj = torch.sum(feature_0_obj*prediction_small[:,0:1,...], dim=[0,2,3], keepdim=True) 151 | centroid_1_obj = torch.sum(feature_1_obj*prediction_small[:,1:,...], dim=[0,2,3], keepdim=True) 152 | centroid_0_bck = torch.sum(feature_0_bck*(1.0-prediction_small[:,0:1,...]), dim=[0,2,3], keepdim=True) 153 | centroid_1_bck = torch.sum(feature_1_bck*(1.0-prediction_small[:,1:,...]), dim=[0,2,3], keepdim=True) 154 | target_0_obj_cnt = torch.sum(mask_0_obj*target_0_obj*prediction_small[:,0:1,...], dim=[0,2,3], keepdim=True) 155 | target_1_obj_cnt = torch.sum(mask_1_obj*target_1_obj*prediction_small[:,1:,...], dim=[0,2,3], keepdim=True) 156 | target_0_bck_cnt = torch.sum(mask_0_bck*target_0_bck*(1.0-prediction_small[:,0:1,...]), dim=[0,2,3], keepdim=True) 157 | target_1_bck_cnt = torch.sum(mask_1_bck*target_1_bck*(1.0-prediction_small[:,1:,...]), dim=[0,2,3], keepdim=True) 158 | 159 | centroid_0_obj /= target_0_obj_cnt; centroid_1_obj /= target_1_obj_cnt 160 | centroid_0_bck /= target_0_bck_cnt; centroid_1_bck /= target_1_bck_cnt 161 | 162 | distance_0_obj = torch.sum(torch.pow(feature - centroid_0_obj, 2), dim=1, keepdim=True) 163 | distance_0_bck = torch.sum(torch.pow(feature - centroid_0_bck, 2), dim=1, keepdim=True) 164 | distance_1_obj = torch.sum(torch.pow(feature - centroid_1_obj, 2), dim=1, keepdim=True) 165 | distance_1_bck = torch.sum(torch.pow(feature - centroid_1_bck, 2), dim=1, keepdim=True) 166 | 167 | proto_pseudo_0 = torch.zeros([data.shape[0], 1, feature.shape[2], feature.shape[3]]).cuda() 168 | proto_pseudo_1 = torch.zeros([data.shape[0], 1, feature.shape[2], feature.shape[3]]).cuda() 169 | 170 
|             proto_pseudo_0[distance_0_obj < distance_0_bck] = 1.0
171 |             proto_pseudo_1[distance_1_obj < distance_1_bck] = 1.0
172 |             proto_pseudo = torch.cat((proto_pseudo_0, proto_pseudo_1), dim=1)
173 |             proto_pseudo = F.interpolate(proto_pseudo, size=data.size()[2:], mode='nearest')
174 | 
175 |             debugc = 1
176 | 
177 |             pseudo_label = pseudo_label.detach().cpu().numpy()
178 | 
179 |             # save pseudo label image
180 |             # mask_oc = construct_color_img(pseudo_label[0][0,:,:])
181 |             # mask_oc = cv2.cvtColor(mask_oc, cv2.COLOR_BGR2GRAY)
182 |             # print(pseudo_label.shape)
183 |             # mask_od = construct_color_img(pseudo_label[0][1,:,:])
184 |             # mask_od = cv2.cvtColor(mask_od, cv2.COLOR_BGR2GRAY)
185 |             # cv2.imwrite(os.path.join('/mnt/data1/llr_data/results/pseudolabel/boundary', args.dataset, 'oc', sample['img_name'][0]), mask_oc)
186 |             # cv2.imwrite(os.path.join('/mnt/data1/llr_data/results/pseudolabel/boundary', args.dataset, 'od', sample['img_name'][0]), mask_od)
187 | 
188 |             # save pseudo boundary image
189 |             # mask_oc = construct_color_img(pseudo_label[0][0,:,:])
190 |             # mask_oc = cv2.cvtColor(mask_oc, cv2.COLOR_BGR2GRAY)
191 |             # img = np.zeros((512,512,3), np.uint8)
192 |             # #img = Image.fromarray(img)
193 |             # ret, thresh = cv2.threshold(mask_oc, 255, 255, 0)
194 |             # contours, im = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)  # the first return value is the contours
195 |             # result_oc = cv2.drawContours(image=img, contours=contours, contourIdx=-1, color=(255, 255, 255), thickness=5)
196 | 
197 |             # mask_od = cv2.imread(os.path.join('/mnt/data1/llr_data/results/pseudolabel/Domain2/oc', sample['img_name'][0]))
198 |             # mask_od = cv2.cvtColor(mask_od, cv2.COLOR_BGR2GRAY)
199 |             # ret, thresh = cv2.threshold(mask_od, 127, 255, 0)
200 |             # contours, im = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)  # the first return value is the contours
201 |             # result_od = cv2.drawContours(image=img, contours=contours, contourIdx=-1, color=(255, 255, 255), thickness=5)
202 |             # cv2.imwrite(os.path.join('/mnt/data1/llr_data/results/pseudolabel/Domain2/boundary/oc', sample['img_name'][0]), result_oc)
203 |             # cv2.imwrite(os.path.join('/mnt/data1/llr_data/results/pseudolabel/Domain2/boundary/od', sample['img_name'][0]), result_od)
204 | 
205 | 
206 |             std_map = std_map.detach().cpu().numpy()
207 |             proto_pseudo = proto_pseudo.detach().cpu().numpy()
208 |             distance_0_obj = distance_0_obj.detach().cpu().numpy()
209 |             distance_0_bck = distance_0_bck.detach().cpu().numpy()
210 |             distance_1_obj = distance_1_obj.detach().cpu().numpy()
211 |             distance_1_bck = distance_1_bck.detach().cpu().numpy()
212 |             centroid_0_obj = centroid_0_obj.detach().cpu().numpy()
213 |             centroid_0_bck = centroid_0_bck.detach().cpu().numpy()
214 |             centroid_1_obj = centroid_1_obj.detach().cpu().numpy()
215 |             centroid_1_bck = centroid_1_bck.detach().cpu().numpy()
216 |             for i in range(prediction.shape[0]):
217 |                 pseudo_label_dic[img_name[i]] = pseudo_label[i]
218 |                 uncertain_dic[img_name[i]] = std_map[i]
219 |                 proto_pseudo_dic[img_name[i]] = proto_pseudo[i]
220 |                 distance_0_obj_dic[img_name[i]] = distance_0_obj[i]
221 |                 distance_0_bck_dic[img_name[i]] = distance_0_bck[i]
222 |                 distance_1_obj_dic[img_name[i]] = distance_1_obj[i]
223 |                 distance_1_bck_dic[img_name[i]] = distance_1_bck[i]
224 |                 centroid_0_obj_dic[img_name[i]] = centroid_0_obj
225 |                 centroid_0_bck_dic[img_name[i]] = centroid_0_bck
226 |                 centroid_1_obj_dic[img_name[i]] = centroid_1_obj
227 |                 centroid_1_bck_dic[img_name[i]] = centroid_1_bck
228 | 
229 |     if args.dataset == "Domain1":  # pseudolabel_D1
230 |         np.savez('/mnt/data1/llr_data/results/cell/pseudolabel_D1', pseudo_label_dic, uncertain_dic, proto_pseudo_dic,
231 |                  distance_0_obj_dic, distance_0_bck_dic, distance_1_obj_dic, distance_1_bck_dic,
232 |                  centroid_0_obj_dic, centroid_0_bck_dic, centroid_1_obj_dic, centroid_1_bck_dic
233 |                  )
234 | 
235 |     elif args.dataset == "Domain2":
236 |         np.savez('/mnt/data1/llr_data/results/pseudolabel_D2', pseudo_label_dic, uncertain_dic, proto_pseudo_dic,
237 |                  distance_0_obj_dic, distance_0_bck_dic, distance_1_obj_dic, distance_1_bck_dic,
238 |                  centroid_0_obj_dic, centroid_0_bck_dic, centroid_1_obj_dic, centroid_1_bck_dic
239 |                  )
240 |     elif args.dataset == "RIGA":
241 |         np.savez('/mnt/data1/llr_data/results/prototype/test_D6', pseudo_label_dic, uncertain_dic, proto_pseudo_dic,
242 |                  distance_0_obj_dic, distance_0_bck_dic, distance_1_obj_dic, distance_1_bck_dic,
243 |                  centroid_0_obj_dic, centroid_0_bck_dic, centroid_1_obj_dic, centroid_1_bck_dic
244 |                  )
245 | 
246 | 
247 | 
248 | 
249 | 
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import medpy.metric.binary as medmetric
4 | 
5 | bce = torch.nn.BCEWithLogitsLoss(reduction='none')
6 | 
7 | def _upscan(f):
8 |     for i, fi in enumerate(f):
9 |         if fi == np.inf: continue
10 |         for j in range(1, i+1):
11 |             x = fi + j*j
12 |             if f[i-j] < x: break
13 |             f[i-j] = x
14 | 
15 | 
16 | def dice_coefficient_numpy(binary_segmentation, binary_gt_label):
17 |     '''
18 |     Compute the Dice coefficient between two binary segmentations.
19 |     The Dice coefficient is defined here: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
20 |     Input:
21 |         binary_segmentation: binary 2D numpy array representing the region of interest as segmented by the algorithm
22 |         binary_gt_label: binary 2D numpy array representing the region of interest as provided in the database
23 |     Output:
24 |         dice_value: Dice coefficient between the segmentation and the ground truth
25 |     '''
26 | 
27 |     # turn all variables to booleans, just in case
28 |     binary_segmentation = np.asarray(binary_segmentation, dtype=bool)
29 |     binary_gt_label = np.asarray(binary_gt_label, dtype=bool)
30 | 
31 |     # compute the intersection
32 |     intersection = np.logical_and(binary_segmentation, binary_gt_label)
33 | 
34 |     # count the number of True pixels in the binary segmentation
35 |     # segmentation_pixels = float(np.sum(binary_segmentation.flatten()))
36 |     segmentation_pixels = np.sum(binary_segmentation.astype(float), axis=(1, 2))
37 |     # same for the ground truth
38 |     # gt_label_pixels = float(np.sum(binary_gt_label.flatten()))
39 |     gt_label_pixels = np.sum(binary_gt_label.astype(float), axis=(1, 2))
40 |     # same for the intersection
41 |     intersection = np.sum(intersection.astype(float), axis=(1, 2))
42 | 
43 |     # compute the Dice coefficient
44 |     dice_value = (2 * intersection + 1.0) / (1.0 + segmentation_pixels + gt_label_pixels)
45 | 
46 |     # return it
47 |     return dice_value
48 | 
49 | 
50 | def dice_numpy_medpy(binary_segmentation, binary_gt_label):
51 | 
52 |     # turn all variables to booleans, just in case
53 |     binary_segmentation = np.asarray(binary_segmentation)
54 |     binary_gt_label = np.asarray(binary_gt_label)
55 | 
56 |     return medmetric.dc(binary_segmentation, binary_gt_label)
57 | 
58 | 
59 | # if get_hd:
60 | #     if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0:
61 | #         return medmetric.assd(binary_segmentation, binary_gt_label)
62 | #         # return 
medmetric.hd(binary_segmentation, binary_gt_label) 63 | # else: 64 | # return np.nan 65 | # else: 66 | # return 0.0 67 | 68 | 69 | def hd_numpy(binary_segmentation, binary_gt_label, get_hd): 70 | 71 | # turn all variables to booleans, just in case 72 | binary_segmentation = np.asarray(binary_segmentation) 73 | binary_gt_label = np.asarray(binary_gt_label) 74 | 75 | if get_hd: 76 | if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0: 77 | return medmetric.assd(binary_segmentation, binary_gt_label) 78 | # return medmetric.hd(binary_segmentation, binary_gt_label) 79 | else: 80 | return np.nan 81 | else: 82 | return 0.0 83 | 84 | 85 | def dice_coeff(pred, target): 86 | """This definition generalize to real valued pred and target vector. 87 | This should be differentiable. 88 | pred: tensor with first dimension as batch 89 | target: tensor with first dimension as batch 90 | """ 91 | 92 | target = target.data.cpu() 93 | pred = torch.sigmoid(pred) 94 | pred = pred.data.cpu() 95 | pred[pred > 0.5] = 1 96 | pred[pred <= 0.5] = 0 97 | 98 | return dice_coefficient_numpy(pred, target) 99 | 100 | def dice_coeff_2label(pred, target): 101 | """This definition generalize to real valued pred and target vector. 102 | This should be differentiable. 103 | pred: tensor with first dimension as batch 104 | target: tensor with first dimension as batch 105 | """ 106 | 107 | target = target.data.cpu() 108 | pred = torch.sigmoid(pred) 109 | pred = pred.data.cpu() 110 | pred[pred > 0.75] = 1 111 | pred[pred <= 0.75] = 0 112 | # print target.shape 113 | # print pred.shape 114 | # return dice_coefficient_numpy(pred[:, 0, ...], target[:, 0, ...]), dice_coefficient_numpy(pred[:, 1, ...], target[:, 1, ...]) 115 | return dice_coefficient_numpy(pred[:, 0, ...], target[:, 0, ...]) 116 | 117 | 118 | def DiceLoss(input, target): 119 | ''' 120 | in tensor fomate 121 | :param input: 122 | :param target: 123 | :return: 124 | ''' 125 | smooth = 1. 126 | iflat = input.contiguous().view(-1) 127 | tflat = target.contiguous().view(-1) 128 | intersection = (iflat * tflat).sum() 129 | 130 | return 1 - ((2. 
* intersection + smooth) / 131 | (iflat.sum() + tflat.sum() + smooth)) 132 | -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | class Path(object): 2 | @staticmethod 3 | def db_root_dir(database): 4 | if database == 'fundus': 5 | return '../../../../data/disc_cup_split/' # foler that contains leftImg8bit/ 6 | else: 7 | print('Database {} not available.'.format(database)) 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /networks/GAN.py: -------------------------------------------------------------------------------- 1 | # camera-ready 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, ): 10 | super(Discriminator, self).__init__() 11 | 12 | filter_num_list = [4096, 2048, 1024, 1] 13 | 14 | self.fc1 = nn.Linear(24576, filter_num_list[0]) 15 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 16 | self.fc2 = nn.Linear(filter_num_list[0], filter_num_list[1]) 17 | self.fc3 = nn.Linear(filter_num_list[1], filter_num_list[2]) 18 | self.fc4 = nn.Linear(filter_num_list[2], filter_num_list[3]) 19 | 20 | # self.sigmoid = nn.Sigmoid() 21 | self._initialize_weights() 22 | 23 | 24 | def _initialize_weights(self): 25 | 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | m.weight.data.normal_(0.0, 0.02) 29 | if m.bias is not None: 30 | m.bias.data.zero_() 31 | 32 | if isinstance(m, nn.ConvTranspose2d): 33 | m.weight.data.normal_(0.0, 0.02) 34 | if m.bias is not None: 35 | m.bias.data.zero_() 36 | 37 | if isinstance(m, nn.Linear): 38 | m.weight.data.normal_(0.0, 0.02) 39 | if m.bias is not None: 40 | # m.bias.data.copy_(1.0) 41 | m.bias.data.zero_() 42 | 43 | 44 | def forward(self, x): 45 | 46 | x = self.leakyrelu(self.fc1(x)) 47 | x = self.leakyrelu(self.fc2(x)) 48 | x = self.leakyrelu(self.fc3(x)) 49 | x = self.fc4(x) 50 | return x 51 | 52 | 53 | class OutputDiscriminator(nn.Module): 54 | def __init__(self, ): 55 | super(OutputDiscriminator, self).__init__() 56 | 57 | filter_num_list = [64, 128, 256, 512, 1] 58 | 59 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 60 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 61 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 62 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 63 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 64 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 65 | # self.sigmoid = nn.Sigmoid() 66 | self._initialize_weights() 67 | 68 | 69 | def _initialize_weights(self): 70 | for m in self.modules(): 71 | if isinstance(m, nn.Conv2d): 72 | m.weight.data.normal_(0.0, 0.02) 73 | if m.bias is not None: 74 | m.bias.data.zero_() 75 | 76 | 77 | def forward(self, x): 78 | x = self.leakyrelu(self.conv1(x)) 79 | x = self.leakyrelu(self.conv2(x)) 80 | x = self.leakyrelu(self.conv3(x)) 81 | x = self.leakyrelu(self.conv4(x)) 82 | x = self.conv5(x) 83 | return x 84 | 85 | 86 | class UncertaintyDiscriminator(nn.Module): 87 | def __init__(self, ): 88 | super(UncertaintyDiscriminator, self).__init__() 89 | 90 | filter_num_list = [64, 128, 256, 512, 1] 91 | 92 
| self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 93 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 94 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 95 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 96 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 97 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 98 | # self.sigmoid = nn.Sigmoid() 99 | self._initialize_weights() 100 | 101 | 102 | def _initialize_weights(self): 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | m.weight.data.normal_(0.0, 0.02) 106 | if m.bias is not None: 107 | m.bias.data.zero_() 108 | 109 | 110 | def forward(self, x): 111 | x = self.leakyrelu(self.conv1(x)) 112 | x = self.leakyrelu(self.conv2(x)) 113 | x = self.leakyrelu(self.conv3(x)) 114 | x = self.leakyrelu(self.conv4(x)) 115 | x = self.conv5(x) 116 | return x 117 | 118 | class BoundaryDiscriminator(nn.Module): 119 | def __init__(self, ): 120 | super(BoundaryDiscriminator, self).__init__() 121 | 122 | filter_num_list = [64, 128, 256, 512, 1] 123 | 124 | self.conv1 = nn.Conv2d(1, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 125 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 126 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 127 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 128 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 129 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 130 | # self.sigmoid = nn.Sigmoid() 131 | self._initialize_weights() 132 | 133 | 134 | def _initialize_weights(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | m.weight.data.normal_(0.0, 0.02) 138 | if m.bias is not None: 139 | m.bias.data.zero_() 140 | 141 | 142 | def forward(self, x): 143 | x = self.leakyrelu(self.conv1(x)) 144 | x = self.leakyrelu(self.conv2(x)) 145 | x = self.leakyrelu(self.conv3(x)) 146 | x = self.leakyrelu(self.conv4(x)) 147 | x = self.conv5(x) 148 | return x 149 | 150 | class BoundaryEntDiscriminator(nn.Module): 151 | def __init__(self, ): 152 | super(BoundaryEntDiscriminator, self).__init__() 153 | 154 | filter_num_list = [64, 128, 256, 512, 1] 155 | 156 | self.conv1 = nn.Conv2d(3, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 157 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 158 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 159 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 160 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 161 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 162 | # self.sigmoid = nn.Sigmoid() 163 | self._initialize_weights() 164 | 165 | 166 | def _initialize_weights(self): 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | m.weight.data.normal_(0.0, 0.02) 170 | if m.bias is not None: 171 | 
                    m.bias.data.zero_()


    def forward(self, x):
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.conv2(x))
        x = self.leakyrelu(self.conv3(x))
        x = self.leakyrelu(self.conv4(x))
        x = self.conv5(x)
        return x

-------------------------------------------------------------------------------- /networks/__init__.py: --------------------------------------------------------------------------------

-------------------------------------------------------------------------------- /networks/aspp.py: --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                     stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP(nn.Module):
    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP, self).__init__()
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048
        if
output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | # self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 60 | # nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 61 | # nn.ReLU()) 62 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 63 | self.bn1 = BatchNorm(256) 64 | self.relu = nn.ReLU() 65 | self.dropout = nn.Dropout(0.5) 66 | self._init_weight() 67 | 68 | def forward(self, x): 69 | x1 = self.aspp1(x) 70 | x2 = self.aspp2(x) 71 | x3 = self.aspp3(x) 72 | x4 = self.aspp4(x) 73 | x5 = self.global_avg_pool(x) 74 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 75 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 76 | 77 | x = self.conv1(x) 78 | x = self.bn1(x) 79 | x = self.relu(x) 80 | 81 | return self.dropout(x) 82 | 83 | def _init_weight(self): 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 87 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 88 | torch.nn.init.kaiming_normal_(m.weight) 89 | elif isinstance(m, SynchronizedBatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | elif isinstance(m, nn.BatchNorm2d): 93 | m.weight.data.fill_(1) 94 | m.bias.data.zero_() 95 | 96 | 97 | def build_aspp(backbone, output_stride, BatchNorm): 98 | return ASPP(backbone, output_stride, BatchNorm) 99 | -------------------------------------------------------------------------------- /networks/aspp_eval.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if 
output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        # self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
        #                                      nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
        #                                      BatchNorm(256),
        #                                      nn.ReLU())
        self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                             nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
                                             nn.ReLU())
        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.global_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


def build_aspp(backbone, output_stride, BatchNorm):
    return ASPP(backbone, output_stride, BatchNorm)

-------------------------------------------------------------------------------- /networks/backbone/__init__.py: --------------------------------------------------------------------------------
from networks.backbone import resnet, xception, drn, mobilenet

def build_backbone(backbone, output_stride, BatchNorm):
    if backbone == 'resnet':
        return resnet.ResNet101(output_stride, BatchNorm)
    elif backbone == 'xception':
        return xception.AlignedXception(output_stride, BatchNorm)
    elif backbone == 'drn':
        return drn.drn_d_54(BatchNorm)
    elif backbone == 'mobilenet':
        return mobilenet.MobileNetV2(output_stride, BatchNorm)
    else:
        raise NotImplementedError

-------------------------------------------------------------------------------- /networks/backbone/drn.py: --------------------------------------------------------------------------------
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/'

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
    'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
    'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
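    # note: the drn_d_24 and drn_d_40 constructors further below look up
    # 'drn-d-24' and 'drn-d-40', which are missing from this dict, so calling
    # them with pretrained=True raises a KeyError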
'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 17 | } 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=padding, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, 29 | dilation=(1, 1), residual=True, BatchNorm=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride, 32 | padding=dilation[0], dilation=dilation[0]) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes, 36 | padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, 65 | dilation=(1, 1), residual=True, BatchNorm=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BatchNorm(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=dilation[1], bias=False, 71 | dilation=dilation[1]) 72 | self.bn2 = BatchNorm(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = BatchNorm(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class DRN(nn.Module): 103 | 104 | def __init__(self, block, layers, arch='D', 105 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 106 | BatchNorm=None): 107 | super(DRN, self).__init__() 108 | self.inplanes = channels[0] 109 | self.out_dim = channels[-1] 110 | self.arch = arch 111 | 112 | if arch == 'C': 113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 114 | padding=3, bias=False) 115 | self.bn1 = BatchNorm(channels[0]) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | self.layer1 = self._make_layer( 119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 120 | self.layer2 = self._make_layer( 121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 122 | 123 | elif arch == 'D': 124 | self.layer0 = nn.Sequential( 125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 126 | bias=False), 127 | BatchNorm(channels[0]), 128 | nn.ReLU(inplace=True) 129 | ) 130 | 131 | self.layer1 = self._make_conv_layers( 132 | 
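                # arch 'D' replaces the first two residual stages with plain
                # conv stacks (see _make_conv_layers below)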
channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 133 | self.layer2 = self._make_conv_layers( 134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 135 | 136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 138 | self.layer5 = self._make_layer(block, channels[4], layers[4], 139 | dilation=2, new_level=False, BatchNorm=BatchNorm) 140 | self.layer6 = None if layers[5] == 0 else \ 141 | self._make_layer(block, channels[5], layers[5], dilation=4, 142 | new_level=False, BatchNorm=BatchNorm) 143 | 144 | if arch == 'C': 145 | self.layer7 = None if layers[6] == 0 else \ 146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 197 | modules = [] 198 | for i in range(convs): 199 | modules.extend([ 200 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 201 | stride=stride if i == 0 else 1, 202 | padding=dilation, bias=False, dilation=dilation), 203 | BatchNorm(channels), 204 | nn.ReLU(inplace=True)]) 205 | self.inplanes = channels 206 | return nn.Sequential(*modules) 207 | 208 | def forward(self, x): 209 | if self.arch == 'C': 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | elif self.arch == 'D': 214 | x = self.layer0(x) 215 | 216 | x = self.layer1(x) 217 | x = self.layer2(x) 218 | 219 | x = self.layer3(x) 220 | low_level_feat = x 221 | 222 | x = self.layer4(x) 223 | x = self.layer5(x) 
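        # layers 5-8 use dilation instead of striding, so the spatial
        # resolution fixed after layer4 is preserved while the receptive
        # field keeps growing -- the core dilated-residual-network idea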
224 | 225 | if self.layer6 is not None: 226 | x = self.layer6(x) 227 | 228 | if self.layer7 is not None: 229 | x = self.layer7(x) 230 | 231 | if self.layer8 is not None: 232 | x = self.layer8(x) 233 | 234 | return x, low_level_feat 235 | 236 | 237 | class DRN_A(nn.Module): 238 | 239 | def __init__(self, block, layers, BatchNorm=None): 240 | self.inplanes = 64 241 | super(DRN_A, self).__init__() 242 | self.out_dim = 512 * block.expansion 243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 244 | bias=False) 245 | self.bn1 = BatchNorm(64) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 248 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm) 249 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm) 250 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 251 | dilation=2, BatchNorm=BatchNorm) 252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 253 | dilation=4, BatchNorm=BatchNorm) 254 | 255 | self._init_weight() 256 | 257 | def _init_weight(self): 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 261 | m.weight.data.normal_(0, math.sqrt(2. / n)) 262 | elif isinstance(m, SynchronizedBatchNorm2d): 263 | m.weight.data.fill_(1) 264 | m.bias.data.zero_() 265 | elif isinstance(m, nn.BatchNorm2d): 266 | m.weight.data.fill_(1) 267 | m.bias.data.zero_() 268 | 269 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 270 | downsample = None 271 | if stride != 1 or self.inplanes != planes * block.expansion: 272 | downsample = nn.Sequential( 273 | nn.Conv2d(self.inplanes, planes * block.expansion, 274 | kernel_size=1, stride=stride, bias=False), 275 | BatchNorm(planes * block.expansion), 276 | ) 277 | 278 | layers = [] 279 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 280 | self.inplanes = planes * block.expansion 281 | for i in range(1, blocks): 282 | layers.append(block(self.inplanes, planes, 283 | dilation=(dilation, dilation, ), BatchNorm=BatchNorm)) 284 | 285 | return nn.Sequential(*layers) 286 | 287 | def forward(self, x): 288 | x = self.conv1(x) 289 | x = self.bn1(x) 290 | x = self.relu(x) 291 | x = self.maxpool(x) 292 | 293 | x = self.layer1(x) 294 | x = self.layer2(x) 295 | x = self.layer3(x) 296 | x = self.layer4(x) 297 | 298 | return x 299 | 300 | def drn_a_50(BatchNorm, pretrained=True): 301 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 304 | return model 305 | 306 | 307 | def drn_c_26(BatchNorm, pretrained=True): 308 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm) 309 | if pretrained: 310 | pretrained = model_zoo.load_url(model_urls['drn-c-26']) 311 | del pretrained['fc.weight'] 312 | del pretrained['fc.bias'] 313 | model.load_state_dict(pretrained) 314 | return model 315 | 316 | 317 | def drn_c_42(BatchNorm, pretrained=True): 318 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 319 | if pretrained: 320 | pretrained = model_zoo.load_url(model_urls['drn-c-42']) 321 | del pretrained['fc.weight'] 322 | del pretrained['fc.bias'] 323 | model.load_state_dict(pretrained) 324 | return model 325 | 326 | 327 | def drn_c_58(BatchNorm, pretrained=True): 328 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], 
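                    # the 8-entry list gives the block count for layers 1-8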
arch='C', BatchNorm=BatchNorm) 329 | if pretrained: 330 | pretrained = model_zoo.load_url(model_urls['drn-c-58']) 331 | del pretrained['fc.weight'] 332 | del pretrained['fc.bias'] 333 | model.load_state_dict(pretrained) 334 | return model 335 | 336 | 337 | def drn_d_22(BatchNorm, pretrained=True): 338 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm) 339 | if pretrained: 340 | pretrained = model_zoo.load_url(model_urls['drn-d-22']) 341 | del pretrained['fc.weight'] 342 | del pretrained['fc.bias'] 343 | model.load_state_dict(pretrained) 344 | return model 345 | 346 | 347 | def drn_d_24(BatchNorm, pretrained=True): 348 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm) 349 | if pretrained: 350 | pretrained = model_zoo.load_url(model_urls['drn-d-24']) 351 | del pretrained['fc.weight'] 352 | del pretrained['fc.bias'] 353 | model.load_state_dict(pretrained) 354 | return model 355 | 356 | 357 | def drn_d_38(BatchNorm, pretrained=True): 358 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 359 | if pretrained: 360 | pretrained = model_zoo.load_url(model_urls['drn-d-38']) 361 | del pretrained['fc.weight'] 362 | del pretrained['fc.bias'] 363 | model.load_state_dict(pretrained) 364 | return model 365 | 366 | 367 | def drn_d_40(BatchNorm, pretrained=True): 368 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm) 369 | if pretrained: 370 | pretrained = model_zoo.load_url(model_urls['drn-d-40']) 371 | del pretrained['fc.weight'] 372 | del pretrained['fc.bias'] 373 | model.load_state_dict(pretrained) 374 | return model 375 | 376 | 377 | def drn_d_54(BatchNorm, pretrained=True): 378 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 379 | if pretrained: 380 | pretrained = model_zoo.load_url(model_urls['drn-d-54']) 381 | del pretrained['fc.weight'] 382 | del pretrained['fc.bias'] 383 | model.load_state_dict(pretrained) 384 | return model 385 | 386 | 387 | def drn_d_105(BatchNorm, pretrained=True): 388 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 389 | if pretrained: 390 | pretrained = model_zoo.load_url(model_urls['drn-d-105']) 391 | del pretrained['fc.weight'] 392 | del pretrained['fc.bias'] 393 | model.load_state_dict(pretrained) 394 | return model 395 | 396 | if __name__ == "__main__": 397 | import torch 398 | model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True) 399 | input = torch.rand(1, 3, 512, 512) 400 | output, low_level_feat = model(input) 401 | print(output.size()) 402 | print(low_level_feat.size()) 403 | -------------------------------------------------------------------------------- /networks/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=True) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, 
pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=True), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=True), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=True), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 138 | torch.nn.init.kaiming_normal_(m.weight) 139 | elif isinstance(m, SynchronizedBatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | elif isinstance(m, nn.BatchNorm2d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | 146 | if __name__ == "__main__": 147 | input = torch.rand(1, 3, 512, 512) 148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 149 | output, low_level_feat = model(input) 150 | print(output.size()) 151 | print(low_level_feat.size()) 152 | -------------------------------------------------------------------------------- /networks/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = BatchNorm(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | self.dilation = dilation 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | if pretrained: 75 | self._load_pretrained_model() 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 
BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=True): 149 | """Constructs a ResNet-101 model. 
150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 154 | return model 155 | 156 | if __name__ == "__main__": 157 | import torch 158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 159 | input = torch.rand(1, 3, 512, 512) 160 | output, low_level_feat = model(input) 161 | print(output.size()) 162 | print(low_level_feat.size()) 163 | -------------------------------------------------------------------------------- /networks/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | 
            skip = inp

        x = x + skip

        return x


class AlignedXception(nn.Module):
    """
    Modified Aligned Xception
    """
    def __init__(self, output_stride, BatchNorm,
                 pretrained=True):
        super(AlignedXception, self).__init__()

        if output_stride == 16:
            entry_block3_stride = 2
            middle_block_dilation = 1
            exit_block_dilations = (1, 2)
        elif output_stride == 8:
            entry_block3_stride = 1
            middle_block_dilation = 2
            exit_block_dilations = (2, 4)
        else:
            raise NotImplementedError


        # Entry flow
        self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
        self.bn1 = BatchNorm(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
        self.bn2 = BatchNorm(64)

        self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
        self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
                            grow_first=True)
        self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
                            start_with_relu=True, grow_first=True, is_last=True)

        # Middle flow
        self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                            BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
        self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
                             BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
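        # blocks 4-19 form the Xception middle flow: sixteen identical
        # residual units of three 3x3 separable convs at 728 channels,
        # dilated when output_stride == 8 so resolution is kept for dense
        # prediction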
159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
        model_dict = {}
        state_dict = self.state_dict()

        for k, v in pretrain_dict.items():
            # membership must be checked against the model's own state_dict;
            # testing the still-empty model_dict would silently skip every
            # pretrained weight
            if k in state_dict:
                if 'pointwise' in k:
                    v = v.unsqueeze(-1).unsqueeze(-1)
                if k.startswith('block11'):
                    model_dict[k] = v
                    model_dict[k.replace('block11', 'block12')] = v
                    model_dict[k.replace('block11', 'block13')] = v
                    model_dict[k.replace('block11', 'block14')] = v
                    model_dict[k.replace('block11', 'block15')] = v
                    model_dict[k.replace('block11', 'block16')] = v
                    model_dict[k.replace('block11', 'block17')] = v
                    model_dict[k.replace('block11', 'block18')] = v
                    model_dict[k.replace('block11', 'block19')] = v
                elif k.startswith('block12'):
                    model_dict[k.replace('block12', 'block20')] = v
                elif k.startswith('bn3'):
                    model_dict[k] = v
                    model_dict[k.replace('bn3', 'bn4')] = v
                elif k.startswith('conv4'):
                    model_dict[k.replace('conv4', 'conv5')] = v
                elif k.startswith('bn4'):
                    model_dict[k.replace('bn4', 'bn5')] = v
                else:
                    model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)



if __name__ == "__main__":
    import torch
    model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
    input = torch.rand(1, 3, 512, 512)
    output, low_level_feat = model(input)
    print(output.size())
    print(low_level_feat.size())

-------------------------------------------------------------------------------- /networks/decoder.py: --------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

class Decoder(nn.Module):
    def __init__(self, num_classes, backbone, BatchNorm):
        super(Decoder, self).__init__()
        if backbone == 'resnet' or backbone == 'drn':
            low_level_inplanes = 256
        elif backbone == 'xception':
            low_level_inplanes = 128
        elif backbone == 'mobilenet':
            low_level_inplanes = 24
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.bn1 = BatchNorm(48)
        self.relu = nn.ReLU()
        self.last_conv = nn.Sequential(
            # nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
            # BatchNorm(256),
            # nn.ReLU(),
            # nn.Dropout(0.5),
            # nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            BatchNorm(305),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(305, num_classes, kernel_size=1, stride=1))
        self.last_conv_boundary = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                                BatchNorm(256),
                                                nn.ReLU(),
                                                nn.Dropout(0.5),
                                                nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                                BatchNorm(256),
                                                nn.ReLU(),
                                                nn.Dropout(0.1),
                                                nn.Conv2d(256, 1, kernel_size=1, stride=1))
        self._init_weight()


    def forward(self, x, low_level_feat):
        low_level_feat = self.conv1(low_level_feat)
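        # the 1x1 conv above compresses the low-level backbone features to
        # 48 channels so they do not overwhelm the 256-channel ASPP output
        # when the two are concatenated (standard DeepLabV3+ design)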
46 |         low_level_feat = self.bn1(low_level_feat)
47 |         low_level_feat = self.relu(low_level_feat)
48 | 
49 |         x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
50 |         x = torch.cat((x, low_level_feat), dim=1)  # 256 ASPP channels + 48 low-level channels = 304
51 |         boundary = self.last_conv_boundary(x)
52 |         x = torch.cat([x, boundary], 1)  # appending the 1-channel boundary map gives the 305 channels expected by last_conv
53 |         x1 = self.last_conv(x)
54 | 
55 |         return x1, boundary, x
56 | 
57 |     def _init_weight(self):
58 |         for m in self.modules():
59 |             if isinstance(m, nn.Conv2d):
60 |                 torch.nn.init.kaiming_normal_(m.weight)
61 |             elif isinstance(m, SynchronizedBatchNorm2d):
62 |                 m.weight.data.fill_(1)
63 |                 m.bias.data.zero_()
64 |             elif isinstance(m, nn.BatchNorm2d):
65 |                 m.weight.data.fill_(1)
66 |                 m.bias.data.zero_()
67 | 
68 | def build_decoder(num_classes, backbone, BatchNorm):
69 |     return Decoder(num_classes, backbone, BatchNorm)
70 | 
--------------------------------------------------------------------------------
/networks/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from networks.aspp import build_aspp
6 | from networks.decoder import build_decoder
7 | from networks.backbone import build_backbone
8 | 
9 | 
10 | class DeepLab(nn.Module):
11 |     def __init__(self, backbone='resnet', output_stride=16, num_classes=2,
12 |                  sync_bn=True, freeze_bn=False):
13 |         super(DeepLab, self).__init__()
14 |         if backbone == 'drn':
15 |             output_stride = 8
16 | 
17 |         if sync_bn == True:
18 |             BatchNorm = SynchronizedBatchNorm2d
19 |         else:
20 |             BatchNorm = nn.BatchNorm2d
21 | 
22 |         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
23 |         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
24 |         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
25 | 
26 |         if freeze_bn:
27 |             self.freeze_bn()
28 | 
29 |     def forward(self, input):
30 |         x, low_level_feat = self.backbone(input)
31 |         x = self.aspp(x)
32 |         feature = x
33 |         x1, x2, feature_last = self.decoder(x, low_level_feat)
34 | 
35 |         x2 = F.interpolate(x2, size=input.size()[2:], mode='bilinear', align_corners=True)
36 |         x1 = F.interpolate(x1, size=input.size()[2:], mode='bilinear', align_corners=True)
37 | 
38 |         return x1, x2, feature_last
39 | 
40 |     def freeze_bn(self):
41 |         for m in self.modules():
42 |             if isinstance(m, SynchronizedBatchNorm2d):
43 |                 m.eval()
44 |             elif isinstance(m, nn.BatchNorm2d):
45 |                 m.eval()
46 | 
47 |     def get_1x_lr_params(self):
48 |         modules = [self.backbone]
49 |         for i in range(len(modules)):
50 |             for m in modules[i].named_modules():
51 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
52 |                         or isinstance(m[1], nn.BatchNorm2d):
53 |                     for p in m[1].parameters():
54 |                         if p.requires_grad:
55 |                             yield p
56 | 
57 |     def get_10x_lr_params(self):
58 |         modules = [self.aspp, self.decoder]
59 |         for i in range(len(modules)):
60 |             for m in modules[i].named_modules():
61 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
62 |                         or isinstance(m[1], nn.BatchNorm2d):
63 |                     for p in m[1].parameters():
64 |                         if p.requires_grad:
65 |                             yield p
66 | 
67 | 
68 | if __name__ == "__main__":
69 |     model = DeepLab(backbone='mobilenet', output_stride=16)
70 |     model.eval()
71 |     input = torch.rand(1, 3, 513, 513)
72 |     x1, x2, feature_last = model(input)  # forward returns (segmentation, boundary, fused features), not a single tensor
73 |     print(x1.size())
74 | 
75 | 
76 | 
--------------------------------------------------------------------------------
/networks/deeplabv3_eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from networks.aspp_eval import build_aspp
6 | from networks.decoder import build_decoder
7 | from networks.backbone import build_backbone
8 | 
9 | 
10 | class DeepLab(nn.Module):
11 |     def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
12 |                  sync_bn=True, freeze_bn=False):
13 |         super(DeepLab, self).__init__()
14 |         if backbone == 'drn':
15 |             output_stride = 8
16 | 
17 |         if sync_bn == True:
18 |             BatchNorm = SynchronizedBatchNorm2d
19 |         else:
20 |             BatchNorm = nn.BatchNorm2d
21 | 
22 |         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
23 |         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
24 |         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
25 | 
26 |         if freeze_bn:
27 |             self.freeze_bn()
28 | 
29 |     def forward(self, input):
30 |         x, low_level_feat = self.backbone(input)
31 |         x = self.aspp(x)
32 |         feature = x
33 |         x1, x2, feature_last = self.decoder(x, low_level_feat)
34 | 
35 |         x2 = F.interpolate(x2, size=input.size()[2:], mode='bilinear', align_corners=True)
36 |         x1 = F.interpolate(x1, size=input.size()[2:], mode='bilinear', align_corners=True)
37 |         return x1, x2, feature_last
38 | 
39 |     def freeze_bn(self):
40 |         for m in self.modules():
41 |             if isinstance(m, SynchronizedBatchNorm2d):
42 |                 m.eval()
43 |             elif isinstance(m, nn.BatchNorm2d):
44 |                 m.eval()
45 | 
46 |     def get_1x_lr_params(self):
47 |         modules = [self.backbone]
48 |         for i in range(len(modules)):
49 |             for m in modules[i].named_modules():
50 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
51 |                         or isinstance(m[1], nn.BatchNorm2d):
52 |                     for p in m[1].parameters():
53 |                         if p.requires_grad:
54 |                             yield p
55 | 
56 |     def get_10x_lr_params(self):
57 |         modules = [self.aspp, self.decoder]
58 |         for i in range(len(modules)):
59 |             for m in modules[i].named_modules():
60 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
61 |                         or isinstance(m[1], nn.BatchNorm2d):
62 |                     for p in m[1].parameters():
63 |                         if p.requires_grad:
64 |                             yield p
65 | 
66 | 
67 | if __name__ == "__main__":
68 |     model = DeepLab(backbone='mobilenet', output_stride=16)
69 |     model.eval()
70 |     input = torch.rand(1, 3, 513, 513)
71 |     x1, x2, feature_last = model(input)  # as in deeplabv3.py, forward returns a 3-tuple
72 |     print(x1.size())
73 | 
74 | 
75 | 
--------------------------------------------------------------------------------
/networks/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Wrappers for the operations to take the meta-learning gradient
3 | updates into account.
4 | """ 5 | import torch.autograd as autograd 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | def linear(inputs, weight, bias, meta_step_size=0.001, meta_loss=None, stop_gradient=False): 11 | inputs = inputs.cuda() 12 | weight = weight.cuda() 13 | bias = bias.cuda() 14 | 15 | if meta_loss is not None: 16 | 17 | if not stop_gradient: 18 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True)[0] 19 | 20 | if bias is not None: 21 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True)[0] 22 | bias_adapt = bias - grad_bias * meta_step_size 23 | else: 24 | bias_adapt = bias 25 | 26 | else: 27 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True)[0].data, requires_grad=False) 28 | 29 | if bias is not None: 30 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True)[0].data, requires_grad=False) 31 | bias_adapt = bias - grad_bias * meta_step_size 32 | else: 33 | bias_adapt = bias 34 | 35 | return F.linear(inputs, 36 | weight - grad_weight * meta_step_size, 37 | bias_adapt) 38 | else: 39 | return F.linear(inputs, weight, bias) 40 | 41 | def conv2d(inputs, weight, bias, stride=1, padding=1, dilation=1, groups=1, kernel_size=3): 42 | 43 | inputs = inputs.cuda() 44 | weight = weight.cuda() 45 | bias = bias.cuda() 46 | 47 | return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups) 48 | 49 | 50 | def deconv2d(inputs, weight, bias, stride=2, padding=0, dilation=0, groups=1, kernel_size=None): 51 | 52 | inputs = inputs.cuda() 53 | weight = weight.cuda() 54 | bias = bias.cuda() 55 | 56 | return F.conv_transpose2d(inputs, weight, bias, stride, padding, dilation, groups) 57 | 58 | def relu(inputs): 59 | return F.relu(inputs, inplace=True) 60 | 61 | 62 | def maxpool(inputs, kernel_size, stride=None, padding=0): 63 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding) 64 | 65 | 66 | def dropout(inputs): 67 | return F.dropout(inputs, p=0.5, training=False, inplace=False) 68 | 69 | def batchnorm(inputs, running_mean, running_var): 70 | return F.batch_norm(inputs, running_mean, running_var) 71 | 72 | 73 | """ 74 | The following are the new methods for 2D-Unet: 75 | Conv2d, batchnorm2d, GroupNorm, InstanceNorm2d, MaxPool2d, UpSample 76 | """ 77 | #as per the 2D Unet: kernel_size, stride, padding 78 | 79 | def instancenorm(input): 80 | return F.instance_norm(input) 81 | 82 | def groupnorm(input): 83 | return F.group_norm(input) 84 | 85 | def dropout2D(inputs): 86 | return F.dropout2d(inputs, p=0.5, training=False, inplace=False) 87 | 88 | def maxpool2D(inputs, kernel_size, stride=None, padding=0): 89 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding) 90 | 91 | def upsample(input): 92 | return F.upsample(input, scale_factor=2, mode='bilinear', align_corners=False) 93 | -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/networks/sync_batchnorm/__pycache__/batchnorm.cpython-37.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/networks/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/networks/sync_batchnorm/__pycache__/comm.cpython-37.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/__pycache__/comm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/networks/sync_batchnorm/__pycache__/comm.cpython-38.pyc -------------------------------------------------------------------------------- /networks/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 
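Before the affine output computation below, a quick self-contained check (illustration only) that the formula used here, y = (x - mean) * inv_std * gamma + beta, matches PyTorch's reference batch norm on a single device. Note the module computes inv_std as `bias_var.clamp(eps) ** -0.5`, which differs from the `(var + eps) ** -0.5` used by the reference only in how eps enters:

```
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 8, 8)
gamma, beta = torch.rand(3) + 0.5, torch.randn(3)

mean = x.mean(dim=(0, 2, 3))                   # per-channel statistics over (N, H, W)
var = x.var(dim=(0, 2, 3), unbiased=False)
inv_std = (var + 1e-5) ** -0.5

manual = (x - mean[None, :, None, None]) * (inv_std * gamma)[None, :, None, None] + beta[None, :, None, None]
reference = F.batch_norm(x, None, None, gamma, beta, training=True, eps=1e-5)
print(torch.allclose(manual, reference, atol=1e-5))  # True
```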
71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalize the tensor on each device using 137 | the statistics only on that device, which accelerated the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 142 | as the built-in PyTorch implementation. 143 | The mean and standard-deviation are calculated per-dimension over 144 | the mini-batches and gamma and beta are learnable parameter vectors 145 | of size C (where C is the input size). 
146 | During training, this layer keeps a running estimate of its computed mean 147 | and variance. The running sum is kept with a default momentum of 0.1. 148 | During evaluation, this running mean/variance is used for normalization. 149 | Because the BatchNorm is done over the `C` dimension, computing statistics 150 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm 151 | Args: 152 | num_features: num_features from an expected input of size 153 | `batch_size x num_features [x width]` 154 | eps: a value added to the denominator for numerical stability. 155 | Default: 1e-5 156 | momentum: the value used for the running_mean and running_var 157 | computation. Default: 0.1 158 | affine: a boolean value that when set to ``True``, gives the layer learnable 159 | affine parameters. Default: ``True`` 160 | Shape: 161 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 162 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 163 | Examples: 164 | >>> # With Learnable Parameters 165 | >>> m = SynchronizedBatchNorm1d(100) 166 | >>> # Without Learnable Parameters 167 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 168 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 169 | >>> output = m(input) 170 | """ 171 | 172 | def _check_input_dim(self, input): 173 | if input.dim() != 2 and input.dim() != 3: 174 | raise ValueError('expected 2D or 3D input (got {}D input)' 175 | .format(input.dim())) 176 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 177 | 178 | 179 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 180 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 181 | of 3d inputs 182 | .. math:: 183 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 184 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 185 | standard-deviation are reduced across all devices during training. 186 | For example, when one uses `nn.DataParallel` to wrap the network during 187 | training, PyTorch's implementation normalize the tensor on each device using 188 | the statistics only on that device, which accelerated the computation and 189 | is also easy to implement, but the statistics might be inaccurate. 190 | Instead, in this synchronized version, the statistics will be computed 191 | over all training samples distributed on multiple devices. 192 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 193 | as the built-in PyTorch implementation. 194 | The mean and standard-deviation are calculated per-dimension over 195 | the mini-batches and gamma and beta are learnable parameter vectors 196 | of size C (where C is the input size). 197 | During training, this layer keeps a running estimate of its computed mean 198 | and variance. The running sum is kept with a default momentum of 0.1. 199 | During evaluation, this running mean/variance is used for normalization. 200 | Because the BatchNorm is done over the `C` dimension, computing statistics 201 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm 202 | Args: 203 | num_features: num_features from an expected input of 204 | size batch_size x num_features x height x width 205 | eps: a value added to the denominator for numerical stability. 206 | Default: 1e-5 207 | momentum: the value used for the running_mean and running_var 208 | computation. Default: 0.1 209 | affine: a boolean value that when set to ``True``, gives the layer learnable 210 | affine parameters. 
Default: ``True`` 211 | Shape: 212 | - Input: :math:`(N, C, H, W)` 213 | - Output: :math:`(N, C, H, W)` (same shape as input) 214 | Examples: 215 | >>> # With Learnable Parameters 216 | >>> m = SynchronizedBatchNorm2d(100) 217 | >>> # Without Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 219 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 220 | >>> output = m(input) 221 | """ 222 | 223 | def _check_input_dim(self, input): 224 | if input.dim() != 4: 225 | raise ValueError('expected 4D input (got {}D input)' 226 | .format(input.dim())) 227 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 228 | 229 | 230 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 231 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 232 | of 4d inputs 233 | .. math:: 234 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 235 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 236 | standard-deviation are reduced across all devices during training. 237 | For example, when one uses `nn.DataParallel` to wrap the network during 238 | training, PyTorch's implementation normalize the tensor on each device using 239 | the statistics only on that device, which accelerated the computation and 240 | is also easy to implement, but the statistics might be inaccurate. 241 | Instead, in this synchronized version, the statistics will be computed 242 | over all training samples distributed on multiple devices. 243 | Note that, for one-GPU or CPU-only case, this module behaves exactly same 244 | as the built-in PyTorch implementation. 245 | The mean and standard-deviation are calculated per-dimension over 246 | the mini-batches and gamma and beta are learnable parameter vectors 247 | of size C (where C is the input size). 248 | During training, this layer keeps a running estimate of its computed mean 249 | and variance. The running sum is kept with a default momentum of 0.1. 250 | During evaluation, this running mean/variance is used for normalization. 251 | Because the BatchNorm is done over the `C` dimension, computing statistics 252 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 253 | or Spatio-temporal BatchNorm 254 | Args: 255 | num_features: num_features from an expected input of 256 | size batch_size x num_features x depth x height x width 257 | eps: a value added to the denominator for numerical stability. 258 | Default: 1e-5 259 | momentum: the value used for the running_mean and running_var 260 | computation. Default: 0.1 261 | affine: a boolean value that when set to ``True``, gives the layer learnable 262 | affine parameters. 
Default: ``True``
263 |     Shape:
264 |         - Input: :math:`(N, C, D, H, W)`
265 |         - Output: :math:`(N, C, D, H, W)` (same shape as input)
266 |     Examples:
267 |         >>> # With Learnable Parameters
268 |         >>> m = SynchronizedBatchNorm3d(100)
269 |         >>> # Without Learnable Parameters
270 |         >>> m = SynchronizedBatchNorm3d(100, affine=False)
271 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
272 |         >>> output = m(input)
273 |     """
274 | 
275 |     def _check_input_dim(self, input):
276 |         if input.dim() != 5:
277 |             raise ValueError('expected 5D input (got {}D input)'
278 |                              .format(input.dim()))
279 |         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as a one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
59 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected
61 |       and passed to a registered callback.
62 |     - After receiving the messages, the master device should gather the information and determine the message to be
63 |       passed back to each slave device.
64 |     """
65 | 
66 |     def __init__(self, master_callback):
67 |         """
68 |         Args:
69 |             master_callback: a callback to be invoked after having collected messages from slave devices.
70 |         """
71 |         self._master_callback = master_callback
72 |         self._queue = queue.Queue()
73 |         self._registry = collections.OrderedDict()
74 |         self._activated = False
75 | 
76 |     def __getstate__(self):
77 |         return {'master_callback': self._master_callback}
78 | 
79 |     def __setstate__(self, state):
80 |         self.__init__(state['master_callback'])
81 | 
82 |     def register_slave(self, identifier):
83 |         """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually the device id.
87 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 |         """
89 |         if self._activated:
90 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 |             self._activated = False
92 |             self._registry.clear()
93 |         future = FutureResult()
94 |         self._registry[identifier] = _MasterRegistry(future)
95 |         return SlavePipe(identifier, self._queue, future)
96 | 
97 |     def run_master(self, master_msg):
98 |         """
99 |         Main entry for the master device in each forward pass.
100 |         The messages are first collected from each device (including the master device), and then
101 |         a callback is invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |         Returns: the message to be sent back to the master device.
107 |         """
108 |         self._activated = True
109 | 
110 |         intermediates = [(0, master_msg)]
111 |         for i in range(self.nr_slaves):
112 |             intermediates.append(self._queue.get())
113 | 
114 |         results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 
117 |         for i, res in results:
118 |             if i == 0:
119 |                 continue
120 |             self._registry[i].result.put(res)
121 | 
122 |         for i in range(self.nr_slaves):
123 |             assert self._queue.get() is True
124 | 
125 |         return results[0][1]
126 | 
127 |     @property
128 |     def nr_slaves(self):
129 |         return len(self._registry)
130 | 
--------------------------------------------------------------------------------
/networks/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.nn as nn
3 | 
4 | class qkv_transform(nn.Conv1d):
5 |     """Conv1d for qkv_transform"""
6 | 
7 | def str2bool(v):
8 |     if v.lower() in ['true', '1']:  # compare against strings; argparse passes values in as str, so the original int 1/0 never matched
9 |         return True
10 |     elif v.lower() in ['false', '0']:
11 |         return False
12 |     else:
13 |         raise argparse.ArgumentTypeError('Boolean value expected.')
14 | 
15 | 
16 | def count_params(model):
17 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
18 | 
19 | 
20 | class AverageMeter(object):
21 |     """Computes and stores the average and current value"""
22 | 
23 |     def __init__(self):
24 |         self.reset()
25 | 
26 |     def reset(self):
27 |         self.val = 0
28 |         self.avg = 0
29 |         self.sum = 0
30 |         self.count = 0
31 | 
32 |     def update(self, val, n=1):
33 |         self.val = val
34 |         self.sum += val * n
35 |         self.count += n
36 |         self.avg = self.sum / self.count
37 | 
--------------------------------------------------------------------------------
/train_process/Trainer.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import os.path as osp
4 | import timeit
5 | from torchvision.utils import make_grid
6 | import time
7 | 
8 | import numpy as np
9 | import pytz
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 | 
14 | from tensorboardX import SummaryWriter
15 | 
16 | import tqdm
17 | import socket
18 | from metrics import *
19 | from utils.Utils import *
20 | import scipy.optimize as sopt
21 | 
22 | bceloss = torch.nn.BCELoss()
23 | mseloss = torch.nn.MSELoss()
24 | 
25 | def get_lr(optimizer):
26 |     for param_group in optimizer.param_groups:
27 | 
return param_group['lr']
28 | 
29 | 
30 | class Trainer(object):
31 | 
32 |     def __init__(self, cuda, model_gen, model_dis, model_uncertainty_dis, optimizer_gen, optimizer_dis, optimizer_uncertainty_dis,
33 |                  val_loader, domain_loaderS, domain_loaderT, out, max_epoch, stop_epoch=None,
34 |                  lr_gen=1e-3, lr_dis=1e-3, lr_decrease_rate=0.1, interval_validate=None, batch_size=8, warmup_epoch=10):
35 |         self.cuda = cuda
36 |         self.warmup_epoch = warmup_epoch
37 |         self.model_gen = model_gen
38 |         self.model_dis2 = model_uncertainty_dis
39 |         self.model_dis = model_dis
40 |         self.optim_gen = optimizer_gen
41 |         self.optim_dis = optimizer_dis
42 |         self.optim_dis2 = optimizer_uncertainty_dis
43 |         self.lr_gen = lr_gen
44 |         self.lr_dis = lr_dis
45 |         self.lr_decrease_rate = lr_decrease_rate
46 |         self.batch_size = batch_size
47 | 
48 |         self.val_loader = val_loader
49 |         self.domain_loaderS = domain_loaderS
50 |         self.domain_loaderT = domain_loaderT
51 |         self.time_zone = 'Asia/Hong_Kong'
52 |         self.timestamp_start = \
53 |             datetime.now(pytz.timezone(self.time_zone))
54 | 
55 |         if interval_validate is None:
56 |             self.interval_validate = 10
57 |         else:
58 |             self.interval_validate = interval_validate
59 | 
60 |         self.out = out
61 |         if not osp.exists(self.out):
62 |             os.makedirs(self.out)
63 | 
64 |         self.log_headers = [
65 |             'epoch',
66 |             'iteration',
67 |             'train/loss_seg',
68 |             'train/cup_dice',
69 |             'train/disc_dice',
70 |             'train/loss_adv',
71 |             'train/loss_D_same',
72 |             'train/loss_D_diff',
73 |             'valid/loss_CE',
74 |             'valid/cup_dice',
75 |             'valid/disc_dice',
76 |             'elapsed_time',
77 |         ]
78 |         if not osp.exists(osp.join(self.out, 'log.csv')):
79 |             with open(osp.join(self.out, 'log.csv'), 'w') as f:
80 |                 f.write(','.join(self.log_headers) + '\n')
81 | 
82 |         log_dir = os.path.join(self.out, 'tensorboard',
83 |                                datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
84 |         self.writer = SummaryWriter(log_dir=log_dir)
85 | 
86 |         self.epoch = 0
87 |         self.iteration = 0
88 |         self.max_epoch = max_epoch
89 |         self.stop_epoch = stop_epoch if stop_epoch is not None else max_epoch
90 |         self.best_disc_dice = 0.0
91 |         self.running_loss_tr = 0.0
92 |         self.running_adv_diff_loss = 0.0
93 |         self.running_adv_same_loss = 0.0
94 |         self.best_mean_dice = 0.0
95 |         self.best_epoch = -1
96 | 
97 | 
98 |     def validate(self):
99 |         training = self.model_gen.training
100 |         self.model_gen.eval()
101 | 
102 |         val_loss = 0.0
103 |         val_cup_dice = 0.0
104 |         val_disc_dice = 0.0
105 |         datanum_cnt = 0.0
106 |         metrics = []
107 |         with torch.no_grad():
108 | 
109 |             for batch_idx, sample in tqdm.tqdm(
110 |                     enumerate(self.val_loader), total=len(self.val_loader),
111 |                     desc='Valid iteration=%d' % self.iteration, ncols=80,
112 |                     leave=False):
113 |                 data = sample['image']
114 |                 target_map = sample['map']
115 |                 target_boundary = sample['boundary']
116 | 
117 | 
118 |                 if self.cuda:
119 |                     data, target_map, target_boundary = data.cuda(), target_map.cuda(), target_boundary.cuda()
120 |                 predictions, boundary, _ = self.model_gen(data)
121 |                 # a single forward pass suffices; the whole loop already runs
122 |                 # inside torch.no_grad() above
123 | 
124 |                 loss = F.binary_cross_entropy_with_logits(predictions, target_map)
125 | 
126 |                 loss_data = loss.data.item()
127 |                 if np.isnan(loss_data):
128 |                     raise ValueError('loss is nan while validating')
129 |                 val_loss += loss_data
130 | 
131 |                 dice_cup = dice_coeff_2label(predictions, target_map)
132 |                 dice_disc = dice_coeff_2label(predictions, target_map)  # note: computed with the same inputs as dice_cup
133 |                 val_cup_dice += np.sum(dice_cup)
134 |                 val_disc_dice += np.sum(dice_disc)
135 |                 datanum_cnt += 
float(dice_cup.shape[0]) 136 | val_loss /= datanum_cnt 137 | val_cup_dice /= datanum_cnt 138 | val_disc_dice /= datanum_cnt 139 | metrics.append((val_loss, val_cup_dice)) 140 | self.writer.add_scalar('val_data/loss_CE', val_loss, self.epoch * (len(self.domain_loaderS))) 141 | self.writer.add_scalar('val_data/val_CUP_dice', val_cup_dice, self.epoch * (len(self.domain_loaderS))) 142 | self.writer.add_scalar('val_data/val_DISC_dice', val_disc_dice, self.epoch * (len(self.domain_loaderS))) 143 | 144 | mean_dice = val_cup_dice #+ val_disc_dice 145 | is_best = mean_dice > self.best_mean_dice 146 | if is_best: 147 | self.best_epoch = self.epoch + 1 148 | self.best_mean_dice = mean_dice 149 | 150 | torch.save({ 151 | 'epoch': self.epoch, 152 | 'iteration': self.iteration, 153 | 'arch': self.model_gen.__class__.__name__, 154 | 'optim_state_dict': self.optim_gen.state_dict(), 155 | 'optim_dis_state_dict': self.optim_dis.state_dict(), 156 | 'optim_dis2_state_dict': self.optim_dis2.state_dict(), 157 | 'model_state_dict': self.model_gen.state_dict(), 158 | 'model_dis_state_dict': self.model_dis.state_dict(), 159 | 'model_dis2_state_dict': self.model_dis2.state_dict(), 160 | 'learning_rate_gen': get_lr(self.optim_gen), 161 | 'learning_rate_dis': get_lr(self.optim_dis), 162 | 'learning_rate_dis2': get_lr(self.optim_dis2), 163 | 'best_mean_dice': self.best_mean_dice, 164 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % self.best_epoch)) 165 | else: 166 | if (self.epoch + 1) >150 :#% 10 == 0: 167 | torch.save({ 168 | 'epoch': self.epoch, 169 | 'iteration': self.iteration, 170 | 'arch': self.model_gen.__class__.__name__, 171 | 'optim_state_dict': self.optim_gen.state_dict(), 172 | 'optim_dis_state_dict': self.optim_dis.state_dict(), 173 | 'optim_dis2_state_dict': self.optim_dis2.state_dict(), 174 | 'model_state_dict': self.model_gen.state_dict(), 175 | 'model_dis_state_dict': self.model_dis.state_dict(), 176 | 'model_dis2_state_dict': self.model_dis2.state_dict(), 177 | 'learning_rate_gen': get_lr(self.optim_gen), 178 | 'learning_rate_dis': get_lr(self.optim_dis), 179 | 'learning_rate_dis2': get_lr(self.optim_dis2), 180 | 'best_mean_dice': self.best_mean_dice, 181 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % (self.epoch + 1))) 182 | 183 | 184 | with open(osp.join(self.out, 'log.csv'), 'a') as f: 185 | elapsed_time = ( 186 | datetime.now(pytz.timezone(self.time_zone)) - 187 | self.timestamp_start).total_seconds() 188 | log = [self.epoch, self.iteration] + [''] * 5 + \ 189 | list(metrics) + [elapsed_time] + ['best model epoch: %d' % self.best_epoch] 190 | log = map(str, log) 191 | f.write(','.join(log) + '\n') 192 | self.writer.add_scalar('best_model_epoch', self.best_epoch, self.epoch * (len(self.domain_loaderS))) 193 | if training: 194 | self.model_gen.train() 195 | self.model_dis.train() 196 | self.model_dis2.train() 197 | 198 | 199 | 200 | def train_epoch(self): 201 | source_domain_label = 1 202 | target_domain_label = 0 203 | smooth = 1e-7 204 | self.model_gen.train() 205 | self.model_dis.train() 206 | self.model_dis2.train() 207 | self.running_seg_loss = 0.0 208 | self.running_adv_loss = 0.0 209 | self.running_dis_diff_loss = 0.0 210 | self.running_dis_same_loss = 0.0 211 | self.running_total_loss = 0.0 212 | self.running_cup_dice_tr = 0.0 213 | self.running_disc_dice_tr = 0.0 214 | loss_adv_diff_data = 0 215 | loss_D_same_data = 0 216 | loss_D_diff_data = 0 217 | 218 | domain_t_loader = enumerate(self.domain_loaderT) 219 | start_time = timeit.default_timer() 220 | for batch_idx, sampleS in 
tqdm.tqdm(
221 |                 enumerate(self.domain_loaderS), total=len(self.domain_loaderS),
222 |                 desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
223 | 
224 |             metrics = []
225 | 
226 |             iteration = batch_idx + self.epoch * len(self.domain_loaderS)
227 |             self.iteration = iteration
228 | 
229 |             assert self.model_gen.training
230 |             assert self.model_dis.training
231 |             assert self.model_dis2.training
232 | 
233 |             self.optim_gen.zero_grad()
234 |             self.optim_dis.zero_grad()
235 |             self.optim_dis2.zero_grad()
236 | 
237 |             # 1. train generator with random images
238 |             for param in self.model_dis.parameters():
239 |                 param.requires_grad = False
240 |             for param in self.model_dis2.parameters():
241 |                 param.requires_grad = False
242 |             for param in self.model_gen.parameters():
243 |                 param.requires_grad = True
244 | 
245 |             imageS = sampleS['image'].cuda()
246 |             target_map = sampleS['map'].cuda()
247 |             target_boundary = sampleS['boundary'].cuda()
248 | 
249 |             oS, boundaryS, feature = self.model_gen(imageS)
250 | 
251 |             loss_seg1 = bceloss(torch.sigmoid(oS), target_map)
252 |             loss_seg2 = mseloss(torch.sigmoid(boundaryS), target_boundary)
253 | 
254 |             loss_seg = loss_seg1 + loss_seg2  # + global_loss
255 | 
256 |             self.running_seg_loss += loss_seg.item()
257 |             loss_seg_data = loss_seg.data.item()
258 |             if np.isnan(loss_seg_data):
259 |                 raise ValueError('loss is nan while training')
260 | 
261 |             loss_seg.backward()
262 |             self.optim_gen.step()
263 | 
264 |             # write image log
265 |             if iteration % 30 == 0:
266 |                 grid_image = make_grid(
267 |                     imageS[0, ...].clone().cpu().data, 1, normalize=True)
268 |                 self.writer.add_image('DomainS/image', grid_image, iteration)
269 |                 grid_image = make_grid(
270 |                     target_map[0, 0, ...].clone().cpu().data, 1, normalize=True)
271 |                 self.writer.add_image('DomainS/target_cup', grid_image, iteration)
272 |                 # grid_image = make_grid(
273 |                 #     target_map[0, 1, ...].clone().cpu().data, 1, normalize=True)
274 |                 # self.writer.add_image('DomainS/target_disc', grid_image, iteration)  # disabled along with its grid above; otherwise the cup grid is written under the disc tag
275 |                 grid_image = make_grid(
276 |                     target_boundary[0, 0, ...].clone().cpu().data, 1, normalize=True)
277 |                 self.writer.add_image('DomainS/target_boundary', grid_image, iteration)
278 |                 grid_image = make_grid(torch.sigmoid(oS)[0, 0, ...].clone().cpu().data, 1, normalize=True)
279 |                 self.writer.add_image('DomainS/prediction_cup', grid_image, iteration)
280 |                 # grid_image = make_grid(torch.sigmoid(oS)[0, 1, ...].clone().cpu().data, 1, normalize=True)
281 |                 # self.writer.add_image('DomainS/prediction_disc', grid_image, iteration)
282 |                 grid_image = make_grid(torch.sigmoid(boundaryS)[0, 0, ...].clone().cpu().data, 1, normalize=True)
283 |                 self.writer.add_image('DomainS/prediction_boundary', grid_image, iteration)
284 | 
285 |             self.writer.add_scalar('train_gen/loss_seg', loss_seg_data, iteration)
286 | 
287 |             metrics.append((loss_seg_data, loss_adv_diff_data, loss_D_same_data, loss_D_diff_data))
288 |             metrics = np.mean(metrics, axis=0)
289 | 
290 |             with open(osp.join(self.out, 'log.csv'), 'a') as f:
291 |                 elapsed_time = (
292 |                     datetime.now(pytz.timezone(self.time_zone)) -
293 |                     self.timestamp_start).total_seconds()
294 |                 log = [self.epoch, self.iteration] + \
295 |                     metrics.tolist() + [''] * 5 + [elapsed_time]
296 |                 log = map(str, log)
297 |                 f.write(','.join(log) + '\n')
298 | 
299 |         self.running_seg_loss /= len(self.domain_loaderS)
300 |         self.running_adv_diff_loss /= len(self.domain_loaderS)
301 |         self.running_dis_same_loss /= len(self.domain_loaderS)
302 |         self.running_dis_diff_loss /= len(self.domain_loaderS)
303 | 
304 |         stop_time = 
timeit.default_timer()
305 | 
306 |         print('\n[Epoch: %d] lr:%f, Average segLoss: %f, '
307 |               ' Average advLoss: %f, Average dis_same_Loss: %f, '
308 |               'Average dis_diff_Loss: %f,'
309 |               'Execution time: %.5f' %
310 |               (self.epoch, get_lr(self.optim_gen), self.running_seg_loss,
311 |                self.running_adv_diff_loss,
312 |                self.running_dis_same_loss, self.running_dis_diff_loss, stop_time - start_time))
313 | 
314 | 
315 |     def train(self):
316 |         for epoch in tqdm.trange(self.epoch, self.max_epoch,
317 |                                  desc='Train', ncols=80):
318 |             self.epoch = epoch
319 |             self.train_epoch()
320 |             if self.stop_epoch == self.epoch:
321 |                 print('Stop epoch at %d' % self.stop_epoch)
322 |                 break
323 | 
324 |             if (epoch+1) % 100 == 0:
325 |                 _lr_gen = self.lr_gen * 0.2  # note: scales the fixed initial lr, so the value is identical at every 100-epoch mark
326 |                 for param_group in self.optim_gen.param_groups:
327 |                     param_group['lr'] = _lr_gen
328 |             self.writer.add_scalar('lr_gen', get_lr(self.optim_gen), self.epoch * (len(self.domain_loaderS)))
329 |             # if (self.epoch+1) % self.interval_validate == 0:
330 |             if (self.epoch + 1) % 5 == 0:
331 |                 self.validate()
332 |         self.writer.close()
333 | 
334 | 
335 | 
336 | 
--------------------------------------------------------------------------------
/train_process/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/train_process/__pycache__/Trainer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/train_process/__pycache__/Trainer.cpython-37.pyc
--------------------------------------------------------------------------------
/train_process/__pycache__/Trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/train_process/__pycache__/Trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/train_process/__pycache__/Trainer_fgsm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/train_process/__pycache__/Trainer_fgsm.cpython-38.pyc
--------------------------------------------------------------------------------
/train_process/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/train_process/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/train_process/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/train_process/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/train_source.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import os.path as osp
4 | 
5 | # PyTorch includes
6 | import torch
7 | from torchvision import transforms
8 | from torch.utils.data import DataLoader
9 | import argparse
10 | import yaml
11 | from train_process import Trainer
12 | 
13 | # Custom includes
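The custom imports continue just below. One detail of `Trainer.train()` above worth noting: the decay multiplies the fixed initial `lr_gen`, so the learning rate takes the same value at every 100-epoch mark rather than compounding. A standalone sketch of the schedule as written (optimizer and parameter are placeholders):

```
import torch

param = torch.nn.Parameter(torch.zeros(1))
optim_gen = torch.optim.Adam([param], lr=1e-3)
initial_lr_gen = 1e-3

for epoch in range(300):
    if (epoch + 1) % 100 == 0:
        _lr_gen = initial_lr_gen * 0.2     # 2e-4 at epoch 100 and again at epoch 200
        for param_group in optim_gen.param_groups:
            param_group['lr'] = _lr_gen

print(optim_gen.param_groups[0]['lr'])     # 0.0002
```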
14 | from dataloaders import fundus_dataloader as DL
15 | from dataloaders import custom_transforms as tr
16 | from networks.deeplabv3 import *
17 | from networks.GAN import BoundaryDiscriminator, UncertaintyDiscriminator
18 | 
19 | 
20 | here = osp.dirname(osp.abspath(__file__))
21 | 
22 | def main():
23 |     parser = argparse.ArgumentParser(
24 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
25 |     )
26 |     parser.add_argument('-g', '--gpu', type=int, default=0, help='gpu id')
27 |     parser.add_argument('--resume', default=None, help='checkpoint path')
28 | 
29 |     # configurations (same configuration as original work)
30 |     # https://github.com/shelhamer/fcn.berkeleyvision.org
31 |     parser.add_argument(
32 |         '--datasetS', type=str, default='Domain3', help='source domain folder containing the training image ROIs'
33 |     )
34 |     parser.add_argument(
35 |         '--datasetT', type=str, default='Domain1', help='target domain: refuge / Drishti-GS / RIM-ONE_r3'
36 |     )
37 |     parser.add_argument(
38 |         '--batch-size', type=int, default=8, help='batch size for training the model'
39 |     )
40 |     parser.add_argument(
41 |         '--group-num', type=int, default=1, help='group number for group normalization'
42 |     )
43 |     parser.add_argument(
44 |         '--max-epoch', type=int, default=200, help='maximum number of training epochs'
45 |     )
46 |     parser.add_argument(
47 |         '--stop-epoch', type=int, default=200, help='epoch at which to stop training early'
48 |     )
49 |     parser.add_argument(
50 |         '--warmup-epoch', type=int, default=-1, help='epoch at which GAN training begins'
51 |     )
52 | 
53 |     parser.add_argument(
54 |         '--interval-validate', type=int, default=10, help='interval epoch number to validate the model'
55 |     )
56 |     parser.add_argument(
57 |         '--lr-gen', type=float, default=1e-3, help='learning rate',
58 |     )
59 |     parser.add_argument(
60 |         '--lr-dis', type=float, default=2.5e-5, help='learning rate',
61 |     )
62 |     parser.add_argument(
63 |         '--lr-decrease-rate', type=float, default=0.1, help='ratio multiplied to initial lr',
64 |     )
65 |     parser.add_argument(
66 |         '--weight-decay', type=float, default=0.0005, help='weight decay',
67 |     )
68 |     parser.add_argument(
69 |         '--momentum', type=float, default=0.99, help='momentum',
70 |     )
71 |     parser.add_argument(
72 |         '--data-dir',
73 |         default='/mnt/data1/llr_data/Fundus',
74 |         help='data root path'
75 |     )
76 |     parser.add_argument(
77 |         '--out-stride',
78 |         type=int,
79 |         default=16,
80 |         help='out-stride of deeplabv3+',
81 |     )
82 |     parser.add_argument(
83 |         '--sync-bn',
84 |         type=bool,
85 |         default=True,
86 |         help='sync-bn in deeplabv3+',
87 |     )
88 |     parser.add_argument(
89 |         '--freeze-bn',
90 |         type=bool,
91 |         default=False,
92 |         help='freeze batch normalization of deeplabv3+',
93 |     )
94 | 
95 |     args = parser.parse_args()
96 | 
97 |     args.model = 'FCN8s'
98 | 
99 |     now = datetime.now()
100 |     args.out = osp.join(here, 'logs/', args.datasetT, now.strftime('%Y%m%d_%H%M%S.%f'))
101 | 
102 |     os.makedirs(args.out)
103 |     with open(osp.join(args.out, 'config.yaml'), 'w') as f:
104 |         yaml.safe_dump(args.__dict__, f, default_flow_style=False)
105 | 
106 | 
107 |     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
108 |     cuda = torch.cuda.is_available()
109 | 
110 |     torch.manual_seed(1337)
111 |     if cuda:
112 |         torch.cuda.manual_seed(1337)
113 | 
114 |     # 1.
dataset 115 | # cartoon_synth_dataset = DL.FundusSegmentation_aug( 116 | # base_dir=args.data_dir, dataset=args.datasetS, split='train/ROIs', 117 | # mirror=True, 118 | # #crop=crop, 119 | # imgaug='cartoon' 120 | # ) 121 | 122 | 123 | # domain_loader = DataLoader(cartoon_synth_dataset, batch_size=4, shuffle=False, num_workers=0, pin_memory=True) 124 | # for batch_idx, (sample) in enumerate(domain_loader): 125 | # data, img_name = sample['image'], sample['img_name'] 126 | composed_transforms_tr = transforms.Compose([ 127 | tr.RandomScaleCrop(512), 128 | #tr.RandomRotate(), 129 | #tr.RandomFlip(), 130 | tr.elastic_transform(), 131 | tr.add_salt_pepper_noise(), 132 | tr.adjust_light(), 133 | tr.eraser(), 134 | tr.Normalize_tf(), 135 | tr.ToTensor() 136 | ]) 137 | 138 | composed_transforms_ts = transforms.Compose([ 139 | # tr.RandomCrop(512), 140 | tr.Resize(512), 141 | tr.Normalize_tf(), 142 | tr.ToTensor() 143 | ]) 144 | 145 | domain = DL.FundusSegmentation_pgd(base_dir=args.data_dir, dataset=args.datasetS, split='train/ROIs/', transform=composed_transforms_tr) 146 | domain_loaderS = DataLoader(domain, batch_size=args.batch_size, shuffle=True, num_workers=0, pin_memory=True) 147 | 148 | domain_T = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='train/ROIs/', transform=composed_transforms_tr) 149 | domain_loaderT = DataLoader(domain_T, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True) 150 | 151 | domain_val = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.datasetT, split='test/ROIs/', transform=composed_transforms_ts) 152 | domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True) 153 | #domain = DL.sourceDataSet(base_dir=args.data_dir, dataset=args.datasetS, split='train/', transform=composed_transforms_tr) 154 | # domain = DL.sourceDataSet(root_img='/mnt/data1/llr_data/EM/VNC3/training/', 155 | # root_label='/mnt/data1/llr_data/EM/VNC3/training_groundtruth/', 156 | # list_path='/mnt/data1/llr_data/EM/VNC3/train.txt', 157 | # crop_size=(512, 512), 158 | # stride=1) 159 | # domain_loaderS = DataLoader(domain, batch_size=args.batch_size, shuffle=True, num_workers=8, pin_memory=True) 160 | 161 | # #domain_T = DL.sourceDataSet(base_dir=args.data_dir, dataset=args.datasetT, split='train/', transform=composed_transforms_tr) 162 | # domain_T = DL.sourceDataSet(root_img='/mnt/data1/llr_data/EM/Lucchi/training/', 163 | # root_label='/mnt/data1/llr_data/EM/Lucchi/training_groundtruth/', 164 | # list_path='/mnt/data1/llr_data/EM/Lucchi/train.txt', 165 | # crop_size=(512, 512), 166 | # stride=1) 167 | # domain_loaderT = DataLoader(domain_T, batch_size=args.batch_size, shuffle=False, num_workers=8, pin_memory=True) 168 | 169 | # #domain_val = DL.sourceDataSet(base_dir=args.data_dir, dataset=args.datasetT, split='test/', transform=composed_transforms_ts) 170 | # domain_val = DL.sourceDataSet(root_img='/mnt/data1/llr_data/EM/Lucchi/testing/', 171 | # root_label='/mnt/data1/llr_data/EM/Lucchi/testing_groundtruth/', 172 | # list_path='/mnt/data1/llr_data/EM/Lucchi/testing.txt', 173 | # crop_size=(512, 512), 174 | # stride=1) 175 | #domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True) 176 | 177 | # 2. 
model 178 | model_gen = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, 179 | sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda() 180 | 181 | model_dis = BoundaryDiscriminator().cuda() 182 | model_dis2 = UncertaintyDiscriminator().cuda() 183 | 184 | start_epoch = 0 185 | start_iteration = 0 186 | 187 | # 3. optimizer 188 | 189 | optim_gen = torch.optim.Adam( 190 | model_gen.parameters(), 191 | lr=args.lr_gen, 192 | betas=(0.9, 0.99) 193 | ) 194 | optim_dis = torch.optim.SGD( 195 | model_dis.parameters(), 196 | lr=args.lr_dis, 197 | momentum=args.momentum, 198 | weight_decay=args.weight_decay 199 | ) 200 | optim_dis2 = torch.optim.SGD( 201 | model_dis2.parameters(), 202 | lr=args.lr_dis, 203 | momentum=args.momentum, 204 | weight_decay=args.weight_decay 205 | ) 206 | 207 | if args.resume: 208 | checkpoint = torch.load(args.resume) 209 | pretrained_dict = checkpoint['model_state_dict'] 210 | model_dict = model_gen.state_dict() 211 | # 1. filter out unnecessary keys 212 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 213 | # 2. overwrite entries in the existing state dict 214 | model_dict.update(pretrained_dict) 215 | # 3. load the new state dict 216 | model_gen.load_state_dict(model_dict) 217 | 218 | pretrained_dict = checkpoint['model_dis_state_dict'] 219 | model_dict = model_dis.state_dict() 220 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 221 | model_dict.update(pretrained_dict) 222 | model_dis.load_state_dict(model_dict) 223 | 224 | pretrained_dict = checkpoint['model_dis2_state_dict'] 225 | model_dict = model_dis2.state_dict() 226 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 227 | model_dict.update(pretrained_dict) 228 | model_dis2.load_state_dict(model_dict) 229 | 230 | 231 | start_epoch = checkpoint['epoch'] + 1 232 | start_iteration = checkpoint['iteration'] + 1 233 | optim_gen.load_state_dict(checkpoint['optim_state_dict']) 234 | optim_dis.load_state_dict(checkpoint['optim_dis_state_dict']) 235 | optim_dis2.load_state_dict(checkpoint['optim_dis2_state_dict']) 236 | 237 | trainer = Trainer.Trainer( 238 | cuda=cuda, 239 | model_gen=model_gen, 240 | model_dis=model_dis, 241 | model_uncertainty_dis=model_dis2, 242 | optimizer_gen=optim_gen, 243 | optimizer_dis=optim_dis, 244 | optimizer_uncertainty_dis=optim_dis2, 245 | lr_gen=args.lr_gen, 246 | lr_dis=args.lr_dis, 247 | lr_decrease_rate=args.lr_decrease_rate, 248 | val_loader=domain_loader_val, 249 | domain_loaderS=domain_loaderS, 250 | domain_loaderT=domain_loaderT, 251 | out=args.out, 252 | max_epoch=args.max_epoch, 253 | stop_epoch=args.stop_epoch, 254 | interval_validate=args.interval_validate, 255 | batch_size=args.batch_size, 256 | warmup_epoch=args.warmup_epoch, 257 | ) 258 | trainer.epoch = start_epoch 259 | trainer.iteration = start_iteration 260 | trainer.train() 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /train_target.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | import argparse 4 | import os 5 | import os.path as osp 6 | import torch.nn.functional as F 7 | 8 | import matplotlib.pyplot as plt 9 | from PIL import Image 10 | import torch 11 | from torch.autograd import Variable 12 | import tqdm 13 | from advent import dataset 14 | from dataloaders import fundus_dataloader as DL 15 | from torch.utils.data import DataLoader 16 
| from dataloaders import custom_transforms as tr 17 | from torchvision import transforms 18 | from matplotlib.pyplot import imsave 19 | from utils.Utils import * 20 | from metrics import * 21 | from datetime import datetime 22 | import pytz 23 | import networks.deeplabv3 as netd 24 | import networks.deeplabv3_eval as netd_eval 25 | import cv2 26 | import torch.backends.cudnn as cudnn 27 | import random 28 | from tensorboardX import SummaryWriter 29 | import torch.nn as nn 30 | from Lovaszloss import lovasz_hinge 31 | import imgaug.augmenters as iaa 32 | import torchattacks 33 | from scipy.ndimage import distance_transform_edt 34 | bceloss = torch.nn.BCELoss(reduction='none') 35 | 36 | def entropy_loss(v): 37 | """ 38 | Entropy loss for probabilistic prediction vectors 39 | input: batch_size x channels x h x w 40 | output: batch_size x 1 x h x w 41 | """ 42 | assert v.dim() == 4 43 | n, c, h, w = v.size() 44 | return -torch.sum(torch.mul(v, torch.log2(v + 1e-30))) / (n * h * w * np.log2(c)) 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--model-file', type=str, default='./logs/Domain1/1/checkpoint_180.pth.tar') 50 | parser.add_argument('--dataset', type=str, default='Domain1') 51 | parser.add_argument('--dataset_open', type=str, default='Domain4') #open domain 52 | parser.add_argument('--source', type=str, default='Domain3') 53 | parser.add_argument('-g', '--gpu', type=int, default=0) 54 | parser.add_argument('--data-dir', default='/mnt/data1/llr_data/Fundus/') 55 | parser.add_argument('--out-stride',type=int,default=16) 56 | parser.add_argument('--sync-bn',type=bool,default=True) 57 | parser.add_argument('--freeze-bn',type=bool,default=False) 58 | parser.add_argument( 59 | '--save-root-ent', 60 | type=str, 61 | default='./results/ent/', 62 | help='path to save ent', 63 | ) 64 | parser.add_argument( 65 | '--save-root-mask', 66 | type=str, 67 | default='./results/mask/', 68 | help='path to save mask', 69 | ) 70 | args = parser.parse_args() 71 | 72 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 73 | model_file = args.model_file 74 | 75 | # 1. 
dataset 76 | composed_transforms_train = transforms.Compose([ 77 | tr.Resize(512), 78 | #tr.RandomFlip(), 79 | tr.add_salt_pepper_noise(), 80 | tr.adjust_light(), 81 | tr.eraser(), 82 | tr.Normalize_tf(), 83 | tr.ToTensor() 84 | ]) 85 | composed_transforms_test = transforms.Compose([ 86 | tr.Resize(512), 87 | #tr.RandomFlip(), 88 | tr.Normalize_tf(), 89 | tr.ToTensor() 90 | ]) 91 | composed_transforms_test1 = transforms.Compose([ 92 | tr.Resize(512), 93 | #tr.RandomFlip(), 94 | tr.Normalize_tf(), 95 | tr.ToTensor() 96 | ]) 97 | db_train = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='train/ROIs', transform=composed_transforms_train) 98 | db_test = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='test/ROIs', transform=composed_transforms_test) 99 | db_source = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.source, split='train/ROIs', transform=composed_transforms_test) 100 | db_open = DL.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset_open, split='test/ROIs', transform=composed_transforms_test) 101 | 102 | train_loader = DataLoader(db_train, batch_size=4, shuffle=False, num_workers=1) 103 | test_loader = DataLoader(db_test, batch_size=1, shuffle=False, num_workers=1) 104 | open_loader = DataLoader(db_open, batch_size=1, shuffle=False, num_workers=1) 105 | source_loader = DataLoader(db_source, batch_size=1, shuffle=False, num_workers=1) 106 | 107 | # 2. model 108 | model = netd.DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn) 109 | model_eval = netd_eval.DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn, freeze_bn=args.freeze_bn).cuda() 110 | 111 | if torch.cuda.is_available(): 112 | model = model.cuda() 113 | print('==> Loading %s model file: %s' % 114 | (model.__class__.__name__, model_file)) 115 | checkpoint = torch.load(model_file) 116 | model.load_state_dict(checkpoint['model_state_dict']) 117 | model.train() 118 | 119 | if args.dataset=="Domain2": 120 | npfilename = './results/mask/D2.npz' 121 | npfilename1 = './results/bound/bound_D2.npz' 122 | elif args.dataset=="Domain1": 123 | npfilename = './results/mask/D1.npz' 124 | npfilename1 = './results/bound/r-bound_D1.npz' 125 | elif args.dataset=="Domain4": # open domain 126 | npfilename = './results/mask/D1.npz' 127 | npfilename1 = './results/bound/bound_D1.npz' 128 | 129 | npdata = np.load(npfilename, allow_pickle=True) 130 | pseudo_label_dic = npdata['arr_0'].item() 131 | uncertain_dic = npdata['arr_1'].item() 132 | proto_pseudo_dic = npdata['arr_2'].item() 133 | 134 | npdata1 = np.load(npfilename1, allow_pickle=True) 135 | pseudo_bound_dic = npdata1['arr_0'].item() 136 | 137 | var_list = model.named_parameters() 138 | 139 | optim_gen = torch.optim.Adam(model.parameters(), lr=0.002, betas=(0.9, 0.99)) 140 | best_val_cup_dice = 0.0; 141 | best_val_disc_dice = 0.0; 142 | best_avg = 0.0 143 | 144 | iter_num = 0 145 | for epoch_num in tqdm.tqdm(range(2), ncols=70): 146 | model.train() 147 | for batch_idx, (sample) in enumerate(train_loader): 148 | data, target, img_name = sample['image'], sample['map'], sample['img_name'] 149 | target_boundary = sample['boundary'] 150 | if torch.cuda.is_available(): 151 | data, target, target_boundary = data.cuda(), target.cuda(), target_boundary.cuda() 152 | data, target, target_boundary = Variable(data), Variable(target), Variable(target_boundary) 153 | prediction, boundaryS, feature = model(data) 154 | 
155 | num_classes = 2 156 | pred_s = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_classes) 157 | pred_s_softmax = F.softmax(pred_s, -1) 158 | prediction = torch.sigmoid(prediction) 159 | 160 | pseudo_label = [pseudo_label_dic.get(key) for key in img_name] 161 | uncertain_map = [uncertain_dic.get(key) for key in img_name] 162 | proto_pseudo = [proto_pseudo_dic.get(key) for key in img_name] 163 | pseudo_bound = [pseudo_bound_dic.get(key) for key in img_name] 164 | 165 | pseudo_label = torch.from_numpy(np.asarray(pseudo_label)).float().cuda() 166 | uncertain_map = torch.from_numpy(np.asarray(uncertain_map)).float().cuda() 167 | proto_pseudo = torch.from_numpy(np.asarray(proto_pseudo)).float().cuda() 168 | pseudo_bound = torch.from_numpy(np.asarray(pseudo_bound)).float().cuda() 169 | 170 | # generate adversarial samples 171 | # atk = torchattacks.PGD(model, eps=4/255, alpha=2/255, steps=4) 172 | # adv_untargeted = atk(data, target) 173 | # #cv2.imwrite(os.path.join('/mnt/data1/llr_data/Fundus/',args.dataset,'/test/image/'+ sample['img_name'][0]), cv2.resize(255*adv_untargeted.data.cpu().numpy()[0,0,:,:],(800,800))) 174 | # cv2.imwrite('/mnt/data1/llr_data/Fundus/PGD/TENT/Domain6/'+ sample['img_name'][0], cv2.resize(adv_untargeted.data.cpu().numpy()[0,0,:,:],(800,800))) 175 | # cv2.imwrite('/mnt/data1/llr_data/Fundus/PGD/TENT/Domain6/'+ sample['img_name'][1], cv2.resize(adv_untargeted.data.cpu().numpy()[1,0,:,:],(800,800))) 176 | 177 | for param in model.parameters(): 178 | param.requires_grad = True 179 | optim_gen.zero_grad() 180 | 181 | target_0_obj = F.interpolate(pseudo_label[:,0:1,...], size=feature.size()[2:], mode='nearest') 182 | target_1_obj = F.interpolate(pseudo_label[:, 1:, ...], size=feature.size()[2:], mode='nearest') 183 | target_0_bck = 1.0 - target_0_obj;target_1_bck = 1.0 - target_1_obj 184 | 185 | mask_0_obj = torch.zeros([pseudo_label.shape[0], 1, pseudo_label.shape[2], pseudo_label.shape[3]]).cuda() 186 | mask_0_bck = torch.zeros([pseudo_label.shape[0], 1, pseudo_label.shape[2], pseudo_label.shape[3]]).cuda() 187 | mask_1_obj = torch.zeros([pseudo_label.shape[0], 1, pseudo_label.shape[2], pseudo_label.shape[3]]).cuda() 188 | mask_1_bck = torch.zeros([pseudo_label.shape[0], 1, pseudo_label.shape[2], pseudo_label.shape[3]]).cuda() 189 | mask_0_obj[uncertain_map[:, 0:1, ...] < 0.05] = 1.0 190 | mask_0_bck[uncertain_map[:, 0:1, ...] < 0.05] = 1.0 191 | mask_1_obj[uncertain_map[:, 1:, ...] < 0.05] = 1.0 192 | mask_1_bck[uncertain_map[:, 1:, ...] < 0.05] = 1.0 193 | mask = torch.cat((mask_0_obj*pseudo_label[:,0:1,...] + mask_0_bck*(1.0-pseudo_label[:,0:1,...]), mask_1_obj*pseudo_label[:,1:,...] + mask_1_bck*(1.0-pseudo_label[:,1:,...])), dim=1) 194 | 195 | mask_proto = torch.zeros([data.shape[0], 2, data.shape[2], data.shape[3]]).cuda() 196 | mask_proto[pseudo_label==proto_pseudo] = 1.0 197 | mask = mask*mask_proto 198 | 199 | # mask for pseudo boundary 200 | target_0_obj = F.interpolate(pseudo_bound[:,0:1,...], size=feature.size()[2:], mode='nearest') 201 | target_0_bck = 1.0 - target_0_obj;target_1_bck = 1.0 - target_1_obj 202 | mask_0_obj = torch.zeros([pseudo_bound.shape[0], 1, pseudo_bound.shape[2], pseudo_bound.shape[3]]).cuda() 203 | mask_0_bck = torch.zeros([pseudo_bound.shape[0], 1, pseudo_bound.shape[2], pseudo_bound.shape[3]]).cuda() 204 | mask_0_obj[uncertain_map[:, 0:1, ...] < 0.05] = 1.0 205 | mask_0_bck[uncertain_map[:, 0:1, ...] < 0.05] = 1.0 206 | mask1 = mask_0_obj*pseudo_bound[:,0:1,...] 
+ mask_0_bck*(1.0-pseudo_bound[:,0:1,...]) 207 | mask_proto1 = torch.zeros([data.shape[0], 1, data.shape[2], data.shape[3]]).cuda() 208 | mask_proto1[pseudo_label[:,0:1,...]==proto_pseudo[:,0:1,...]] = 1.0 209 | mask1 = mask1*mask_proto1 210 | 211 | sceloss = SCELoss() 212 | mseloss = torch.nn.MSELoss() 213 | 214 | loss_seg_pixel = bceloss(prediction, pseudo_label) 215 | loss_seg = torch.sum(loss_seg_pixel) / torch.sum(mask) 216 | loss_seg_bound = mseloss(torch.sigmoid(boundaryS), pseudo_bound) #Lb for pseudo boundary 217 | loss_ent = entropy_loss(prediction) 218 | loss = loss_seg + loss_seg_bound + 0.4*loss_ent 219 | 220 | loss.backward() 221 | optim_gen.step() 222 | iter_num = iter_num + 1 223 | 224 | #test 225 | model_eval.train() 226 | pretrained_dict = model.state_dict() 227 | model_dict = model_eval.state_dict() 228 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 229 | model_eval.load_state_dict(pretrained_dict) 230 | 231 | val_cup_dice = 0.0;val_disc_dice = 0.0;datanum_cnt = 0.0 232 | cup_hd = 0.0; disc_hd = 0.0;datanum_cnt_cup = 0.0;datanum_cnt_disc = 0.0 233 | with torch.no_grad(): 234 | for batch_idx, (sample) in enumerate(test_loader): #test_loader 235 | data, target, img_name = sample['image'], sample['map'], sample['img_name'] 236 | 237 | if torch.cuda.is_available(): 238 | data, target = data.cuda(), target.cuda() 239 | data, target = Variable(data), Variable(target) 240 | prediction, boundary, _ = model_eval(data) 241 | prediction = torch.sigmoid(prediction) 242 | 243 | target_numpy = target.data.cpu() 244 | prediction = prediction.data.cpu() 245 | prediction[prediction>0.75] = 1;prediction[prediction <= 0.75] = 0 246 | im = np.array(target_numpy[:,0, ...]).transpose(1,2,0) *255 247 | im1 = np.array(prediction[:,0, ...]).transpose(1,2,0) *255 248 | 249 | cup_dice = dice_coefficient_numpy(prediction[:,0, ...], target_numpy[:, 0, ...]) 250 | disc_dice = dice_coefficient_numpy(prediction[:,1, ...], target_numpy[:, 1, ...]) 251 | 252 | for i in range(prediction.shape[0]): 253 | hd_tmp = hd_numpy(prediction[i, 0, ...], target_numpy[i, 0, ...], get_hd) 254 | if np.isnan(hd_tmp): 255 | datanum_cnt_cup -= 1.0 256 | else: 257 | cup_hd += hd_tmp 258 | 259 | hd_tmp = hd_numpy(prediction[i, 1, ...], target_numpy[i, 1, ...], get_hd) 260 | if np.isnan(hd_tmp): 261 | datanum_cnt_disc -= 1.0 262 | else: 263 | disc_hd += hd_tmp 264 | 265 | val_cup_dice += np.sum(cup_dice) 266 | val_disc_dice += np.sum(disc_dice) 267 | 268 | datanum_cnt += float(prediction.shape[0]) 269 | datanum_cnt_cup += float(prediction.shape[0]) 270 | datanum_cnt_disc += float(prediction.shape[0]) 271 | 272 | val_cup_dice /= datanum_cnt 273 | val_disc_dice /= datanum_cnt 274 | cup_hd /= datanum_cnt_cup 275 | disc_hd /= datanum_cnt_disc 276 | if (val_cup_dice+val_disc_dice)/2.0>best_avg: 277 | best_val_cup_dice = val_cup_dice; best_val_disc_dice = val_disc_dice; best_avg = (val_cup_dice+val_disc_dice)/2.0 278 | best_cup_hd = cup_hd; best_disc_hd = disc_hd; best_avg_hd = (best_cup_hd+best_disc_hd)/2.0 279 | 280 | if not os.path.exists('./logs/train_target'): 281 | os.mkdir('./logs/train_target') 282 | if args.dataset == 'Domain1': 283 | savefile = './logs/train_target/' + 'D1_' + 'checkpoint_%d.pth.tar' % epoch_num 284 | elif args.dataset == 'Domain2': 285 | savefile = './logs/train_target/' + 'D2_' + 'checkpoint_%d.pth.tar' % epoch_num 286 | elif args.dataset == 'Domain4': 287 | savefile = './logs/train_target/' + 'D4_' + 'checkpoint_%d.pth.tar' % epoch_num 288 | if model_save: 289 | 
torch.save({
290 | 'model_state_dict': model.state_dict(),
291 | 'best_mean_dice': best_avg,
292 | 'best_cup_dice': best_val_cup_dice,
293 | 'best_disc_dice': best_val_disc_dice,
294 | }, savefile)
295 |
296 | print("cup dice: %.4f disc dice: %.4f cup hd: %.4f disc hd: %.4f " %
297 | (val_cup_dice, val_disc_dice, cup_hd, disc_hd))
298 | print("best disc dice: %.4f best disc hd: %.4f best cup dice: %.4f best cup hd: %.4f " %
299 | (best_val_disc_dice, best_disc_hd, best_val_cup_dice, best_cup_hd))
300 | model.train()
301 |
302 |
303 |
304 |
--------------------------------------------------------------------------------
/utils/Utils.py:
--------------------------------------------------------------------------------
1 |
2 | # from scipy.misc import imsave
3 | import os.path as osp
4 | import numpy as np
5 | import os
6 | import cv2
7 | from skimage import morphology
8 | import scipy.ndimage, scipy.signal  # imported explicitly: scipy.ndimage/scipy.signal are used below and a bare `import scipy` does not pull them in
9 | from PIL import Image
10 | from matplotlib.pyplot import imsave
11 | # from keras.preprocessing import image
12 | from skimage.measure import label, regionprops
13 | from skimage.transform import rotate, resize
14 | from skimage import measure, draw
15 |
16 | import matplotlib.pyplot as plt
17 | plt.switch_backend('agg')
18 |
19 | # from scipy.misc import imsave
20 | from metrics import *
21 | import cv2
22 |
23 |
24 | def construct_color_img(prob_per_slice):
25 | shape = prob_per_slice.shape
26 | img = np.zeros((shape[0], shape[1], 3), dtype=np.uint8)
27 | img[:, :, 0] = prob_per_slice * 255
28 | img[:, :, 1] = prob_per_slice * 255
29 | img[:, :, 2] = prob_per_slice * 255
30 |
31 | #im_color = cv2.applyColorMap(img, cv2.COLORMAP_JET)
32 | return img#im_color
33 |
34 |
35 | def normalize_ent(ent):
36 | '''
37 | Shift ent to start at 0 and rescale it by a fixed factor (0.4)
38 | :param ent: entropy map
39 | :return: rescaled entropy map
40 | '''
41 | ent_min = np.amin(ent)  # renamed from `min` to avoid shadowing the built-in
42 | return (ent - ent_min) / 0.4
43 |
44 |
45 | def draw_ent(prediction, save_root, name):
46 | '''
47 | Draw the entropy information for each img and save them to the save path
48 | :param prediction: [2, h, w] numpy
49 | :param save_root: directory to write into; the file name is derived from name
50 | :return: None
51 | '''
52 | if not os.path.exists(os.path.join(save_root, 'disc')):
53 | os.makedirs(os.path.join(save_root, 'disc'))
54 | if not os.path.exists(os.path.join(save_root, 'cup')):
55 | os.makedirs(os.path.join(save_root, 'cup'))
56 | smooth = 1e-8
57 | cup = prediction[0]
58 | disc = prediction[1]
59 | cup_ent = - cup * np.log(cup + smooth)
60 | disc_ent = - disc * np.log(disc + smooth)
61 | cup_ent = normalize_ent(cup_ent)
62 | disc_ent = normalize_ent(disc_ent)
63 | disc = construct_color_img(disc_ent)
64 | cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc)
65 | cup = construct_color_img(cup_ent)
66 | cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup)
67 |
68 |
69 | def draw_mask(prediction, save_root, name):
70 | '''
71 | Draw the mask probability for each img and save them to the save path
72 | :param prediction: [2, h, w] numpy
73 | :param save_root: directory to write into; the file name is derived from name
74 | :return: None
75 | '''
76 | if not os.path.exists(os.path.join(save_root, 'disc')):
77 | os.makedirs(os.path.join(save_root, 'disc'))
78 | if not os.path.exists(os.path.join(save_root, 'cup')):
79 | os.makedirs(os.path.join(save_root, 'cup'))
80 | cup = prediction[0]
81 | disc = prediction[1]
82 |
83 | disc = construct_color_img(disc)
84 | cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc)
85 | cup = construct_color_img(cup)
86 | cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup)
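# Usage sketch for the two helpers above (added annotation; the file name is
# hypothetical). pred is assumed to be a (2, h, w) numpy array of sigmoid
# probabilities with channel 0 = cup and channel 1 = disc:
#   draw_ent(pred, './results/ent', 'img001.png')    # entropy heat maps under ent/cup and ent/disc
#   draw_mask(pred, './results/prob', 'img001.png')  # probability maps under prob/cup and prob/disc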
87 |
88 | def draw_boundary(prediction, save_root, name):
89 | '''
90 | Draw the boundary probability map for each img and save it to the save path
91 | :param prediction: [2, h, w] numpy; channel 0 holds the boundary map
92 | :param save_root: directory to write into; the file name is derived from name
93 | :return: None
94 | '''
95 | if not os.path.exists(os.path.join(save_root, 'boundary')):
96 | os.makedirs(os.path.join(save_root, 'boundary'))
97 | boundary = prediction[0]
98 | boundary = construct_color_img(boundary)
99 | cv2.imwrite(os.path.join(save_root, 'boundary', name.split('.')[0]) + '.png', boundary)
100 |
101 |
102 | def get_largest_fillhole(binary):
103 | label_image = label(binary)
104 | regions = regionprops(label_image)
105 | area_list = []
106 | for region in regions:
107 | area_list.append(region.area)
108 | if area_list:
109 | idx_max = np.argmax(area_list)
110 | binary[label_image != idx_max + 1] = 0  # keep only the largest connected component
111 | return scipy.ndimage.binary_fill_holes(np.asarray(binary).astype(int))
112 |
113 | def postprocessing(prediction, threshold=0.75, dataset='G'):
114 | if dataset[0] == 'D':
115 | prediction = prediction.numpy()
116 | prediction_copy = np.copy(prediction)
117 | disc_mask = prediction[1]
118 | cup_mask = prediction[0]
119 | disc_mask = (disc_mask > 0.5) # return binary mask
120 | cup_mask = (cup_mask > 0.1) # return binary mask
121 | disc_mask = disc_mask.astype(np.uint8)
122 | cup_mask = cup_mask.astype(np.uint8)
123 | for i in range(5):
124 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7)
125 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7)
126 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
127 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
128 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1
129 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8)
130 | prediction_copy[0] = cup_mask
131 | prediction_copy[1] = disc_mask
132 | return prediction_copy
133 | else:
134 | prediction = prediction.numpy()
135 | prediction = (prediction > threshold) # return binary mask
136 | prediction = prediction.astype(np.uint8)
137 | prediction_copy = np.copy(prediction)
138 | disc_mask = prediction[1]
139 | cup_mask = prediction[0]
140 | # for i in range(5):
141 | # disc_mask = scipy.signal.medfilt2d(disc_mask, 7)
142 | # cup_mask = scipy.signal.medfilt2d(cup_mask, 7)
143 | # disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
144 | # cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
145 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1
146 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8)
147 | prediction_copy[0] = cup_mask
148 | prediction_copy[1] = disc_mask
149 | return prediction_copy
150 |
151 |
152 | def joint_val_image(image, prediction, mask):
153 | ratio = 0.5
154 | _pred_cup = np.zeros([mask.shape[-2], mask.shape[-1], 3])
155 | _pred_disc = np.zeros([mask.shape[-2], mask.shape[-1], 3])
156 | _mask = np.zeros([mask.shape[-2], mask.shape[-1], 3])
157 | image = np.transpose(image, (1, 2, 0))
158 |
159 | _pred_cup[:, :, 0] = prediction[0]
160 | _pred_cup[:, :, 1] = prediction[0]
161 | _pred_cup[:, :, 2] = prediction[0]
162 | _pred_disc[:, :, 0] = prediction[1]
163 | _pred_disc[:, :, 1] = prediction[1]
164 | _pred_disc[:, :, 2] = prediction[1]
165 | _mask[:,:,0] = mask[0]  # cup mask in the red channel
166 | _mask[:,:,1] = mask[1]  # disc mask in the green channel
167 |
168 | pred_cup = np.add(ratio * image, (1 - ratio) * 
_pred_cup) 169 | pred_disc = np.add(ratio * image, (1 - ratio) * _pred_disc) 170 | mask_img = np.add(ratio * image, (1 - ratio) * _mask) 171 | 172 | joint_img = np.concatenate([image, mask_img, pred_cup, pred_disc], axis=1) 173 | return joint_img 174 | 175 | 176 | def save_val_img(path, epoch, img): 177 | name = osp.join(path, "visualization", "epoch_%d.png" % epoch) 178 | out = osp.join(path, "visualization") 179 | if not osp.exists(out): 180 | os.makedirs(out) 181 | img_shape = img[0].shape 182 | stack_image = np.zeros([len(img) * img_shape[0], img_shape[1], img_shape[2]]) 183 | for i in range(len(img)): 184 | stack_image[i * img_shape[0] : (i + 1) * img_shape[0], :, : ] = img[i] 185 | imsave(name, stack_image) 186 | 187 | 188 | 189 | 190 | def save_per_img(patch_image, data_save_path, img_name, prob_map, mask_path=None, ext="bmp"): 191 | path1 = os.path.join(data_save_path, 'overlay', img_name.split('.')[0]+'.png') 192 | path0 = os.path.join(data_save_path, 'original_image', img_name.split('.')[0]+'.png') 193 | if not os.path.exists(os.path.dirname(path0)): 194 | os.makedirs(os.path.dirname(path0)) 195 | if not os.path.exists(os.path.dirname(path1)): 196 | os.makedirs(os.path.dirname(path1)) 197 | 198 | disc_map = prob_map[0] 199 | cup_map = prob_map[1] 200 | size = disc_map.shape 201 | disc_map[:, 0] = np.zeros(size[0]) 202 | disc_map[:, size[1] - 1] = np.zeros(size[0]) 203 | disc_map[0, :] = np.zeros(size[1]) 204 | disc_map[size[0] - 1, :] = np.zeros(size[1]) 205 | size = cup_map.shape 206 | cup_map[:, 0] = np.zeros(size[0]) 207 | cup_map[:, size[1] - 1] = np.zeros(size[0]) 208 | cup_map[0, :] = np.zeros(size[1]) 209 | cup_map[size[0] - 1, :] = np.zeros(size[1]) 210 | 211 | disc_mask = (disc_map > 0.75) # return binary mask 212 | cup_mask = (cup_map > 0.75) 213 | disc_mask = disc_mask.astype(np.uint8) 214 | cup_mask = cup_mask.astype(np.uint8) 215 | 216 | for i in range(5): 217 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 218 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 219 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 220 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 221 | disc_mask = get_largest_fillhole(disc_mask) 222 | cup_mask = get_largest_fillhole(cup_mask) 223 | 224 | disc_mask = morphology.binary_dilation(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 225 | cup_mask = morphology.binary_dilation(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 226 | 227 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 228 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 229 | 230 | 231 | contours_disc = measure.find_contours(disc_mask, 0.5) 232 | contours_cup = measure.find_contours(cup_mask, 0.5) 233 | 234 | patch_image2 = patch_image.astype(np.uint8) 235 | patch_image2 = Image.fromarray(patch_image2) 236 | 237 | patch_image2.save(path0) 238 | 239 | for n, contour in enumerate(contours_cup): 240 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 241 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 242 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0] 243 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0] 244 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 245 | 
patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0]
246 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0]
247 |
248 | for n, contour in enumerate(contours_disc):
249 | patch_image[contour[:, 0].astype(int), contour[:, 1].astype(int), :] = [0, 0, 255]
250 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255]
251 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255]
252 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255]
253 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255]
254 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255]
255 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255]
256 |
257 | patch_image = patch_image.astype(np.uint8)
258 | patch_image = Image.fromarray(patch_image)
259 |
260 | patch_image.save(path1)
261 |
262 | def untransform(img, lt):
263 | img = (img + 1) * 127.5  # invert Normalize_tf: map the image from [-1, 1] back to [0, 255]
264 | lt = lt * 128  # rescale the label map back to pixel range
265 | return img, lt
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utils/__pycache__/Utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/utils/__pycache__/Utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/Utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/utils/__pycache__/Utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LinGrayy/PLPB/d61d61fe33dd21c5c67d12edec2a47fb22aa6806/utils/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | from Lovaszloss import lovasz_hinge  # external helper; no Lovaszloss module ships in this repo's tree
6 |
7 | def entropy_loss(p, C=2):
8 | y1 = -1.0*torch.sum(p*torch.log(p+1e-6), dim=1)/math.log(C)  # math.log fixes a NameError (numpy was never imported here); dividing by log(C) scales the maximum entropy to 1
9 | ent = torch.mean(y1)
10 |
11 | return ent
12 |
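# Sanity sketch for entropy_loss above (added annotation, not part of the
# original file): a uniform C-class prediction has entropy log(C), so the
# division maps the mean per-pixel entropy into [0, 1]. Assuming p is a
# (B, C, H, W) tensor of probabilities:
#   entropy_loss(torch.full((1, 2, 4, 4), 0.5))   # maximally uncertain -> ~1.0
#   entropy_loss(torch.cat([torch.ones(1, 1, 4, 4), torch.zeros(1, 1, 4, 4)], dim=1))   # fully confident -> ~0.0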
13 | class LovaszHingeLoss(nn.Module):
14 | def __init__(self):
15 | super().__init__()
16 |
17 | def forward(self, input, target):
18 | input = input.squeeze(1)
19 | target = target.squeeze(1)
20 | loss = lovasz_hinge(input, target, per_image=True)
21 |
22 | return loss
23 |
24 |
25 | class CrossEntropyLoss(nn.CrossEntropyLoss):
26 | def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
27 | super().__init__(weight, size_average, ignore_index, reduce, reduction)
28 |
29 | def forward(self, logits: torch.Tensor, target: torch.Tensor, **kwargs):
30 | return super().forward(logits, target)
31 |
32 |
33 | class StochasticSegmentationNetworkLossMCIntegral(nn.Module):
34 | def __init__(self, num_mc_samples: int = 1):  # note: fixed_re_parametrization_trick below requires an even sample count
35 | super().__init__()
36 | self.num_mc_samples = num_mc_samples
37 |
38 | @staticmethod
39 | def fixed_re_parametrization_trick(dist, num_samples):  # antithetic sampling: mirror each draw about the mean to reduce MC variance
40 | assert num_samples % 2 == 0
41 | samples = dist.rsample((num_samples // 2,))
42 | mean = dist.mean.unsqueeze(0)
43 | samples = samples - mean
44 | return torch.cat([samples, -samples]) + mean
45 |
46 | def forward(self, logits, target, distribution, **kwargs):
47 | batch_size = logits.shape[0]
48 | num_classes = logits.shape[1]
49 | assert num_classes >= 2 # not implemented for binary case with implied background
50 | # logit_sample = distribution.rsample((self.num_mc_samples,))
51 | logit_sample = self.fixed_re_parametrization_trick(distribution, self.num_mc_samples)
52 | target = target.unsqueeze(1)
53 | target = target.expand((self.num_mc_samples,) + target.shape)
54 |
55 | flat_size = self.num_mc_samples * batch_size
56 | logit_sample = logit_sample.view((flat_size, num_classes, -1))
57 | target = target.reshape((flat_size, -1))
58 |
59 | # log_prob = -F.cross_entropy(logit_sample, target, reduction='none').view((self.num_mc_samples, batch_size, -1))
60 | log_prob = -F.binary_cross_entropy(torch.sigmoid(logit_sample), target, reduction='none').view((self.num_mc_samples, batch_size, -1))  # torch.sigmoid: F.sigmoid is deprecated
61 | loglikelihood = torch.mean(torch.logsumexp(torch.sum(log_prob, dim=-1), dim=0) - math.log(self.num_mc_samples))
62 | loss = -loglikelihood
63 | return loss
--------------------------------------------------------------------------------