├── utils ├── __init__.py ├── losses.py ├── metrics.py └── Utils.py ├── dataloaders ├── __init__.py ├── fundus_dataloader.py └── custom_transforms.py ├── networks ├── __init__.py ├── backbone │ ├── __init__.py │ ├── mobilenet.py │ ├── resnet.py │ ├── xception.py │ └── drn.py ├── decoder.py ├── deeplabv3.py ├── layers.py ├── decoder_old.py ├── aspp.py ├── aspp_eval.py ├── sync_batchnorm │ ├── comm.py │ └── batchnorm.py ├── GAN.py └── models.py ├── train_process ├── __init__.py └── Trainer.py ├── figure └── framework.png ├── README.md ├── .gitignore ├── train_source.py └── train_target.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /train_process/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /figure/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lloongx/SFDA-CBMT/HEAD/figure/framework.png -------------------------------------------------------------------------------- /networks/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | else: 13 | raise NotImplementedError 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Source-Free Domain Adaptive Fundus Image Segmentation with Class-Balanced Mean Teacher 2 | 3 | Pytorch implementation of MICCAI'23 paper *Source-Free Domain Adaptive Fundus Image Segmentation with Class-Balanced Mean Teacher*. 4 | 5 |
![Framework overview](figure/framework.png)
6 | 
7 | 
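*A quick sketch of the full workflow (details under Training below; the GPU id, dataset name, and checkpoint filename here are placeholders, and the `train_target.py` flags follow the Training notes rather than a verified run):*

```
# Stage 1: train a source-domain model (or download one instead)
python train_source.py -g 0 --dataset Domain3 --data-dir ../Datasets/Fundus
# Stage 2: source-free adaptation with the class-balanced mean teacher
python train_target.py --model-file ./logs_train/<source-model> --data-dir ../Datasets/Fundus
```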
8 | 
9 | ## Installation
10 | * Install Python 3.10.5, PyTorch 1.12.0, CUDA 11.6 and the other essential packages. (Note that using other package versions may affect performance.)
11 | * Clone this repo
12 | ```
13 | git clone https://github.com/lloongx/SFDA-CBMT.git
14 | cd SFDA-CBMT
15 | ```
16 | 
17 | ## Training
18 | * Download the datasets from [here](https://drive.google.com/file/d/1B7ArHRBjt2Dx29a3A6X_lGhD0vDVr3sy/view).
19 | * Download the source-domain model from [here](https://drive.google.com/drive/folders/1L23mCg8prsdu1imEQI5ouuwvVL_FSiLY), or train one yourself by specifying `--data-dir` in `./train_source.py` and running it.
20 | * Save the source-domain model into the folder `./logs_train/`.
21 | * Run `./train_target.py` with `--model-file` and `--data-dir` specified to start the SFDA training process.
22 | 
23 | 
24 | ## Acknowledgement
25 | This repo benefits from [BEAL](https://github.com/emma-sjwang/BEAL) and [SFDA-DPL](https://github.com/cchen-cc/SFDA-DPL). Thanks for their wonderful work.
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | 
6 | def entropy_loss(p, C=2):  # p: softmax probabilities (N, C, ...); returns mean entropy normalized by log(C)
7 |     y1 = -1.0 * torch.sum(p * torch.log(p + 1e-6), dim=1) / math.log(C)
8 |     ent = torch.mean(y1)
9 | 
10 |     return ent
11 | 
12 | class CrossEntropyLoss(nn.CrossEntropyLoss):
13 |     def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
14 |         super().__init__(weight, size_average, ignore_index, reduce, reduction)
15 | 
16 |     def forward(self, logits: torch.Tensor, target: torch.Tensor, **kwargs):
17 |         return super().forward(logits, target)
18 | 
19 | 
20 | class StochasticSegmentationNetworkLossMCIntegral(nn.Module):
21 |     def __init__(self, num_mc_samples: int = 1):
22 |         super().__init__()
23 |         self.num_mc_samples = num_mc_samples
24 | 
25 |     @staticmethod
26 |     def fixed_re_parametrization_trick(dist, num_samples):
27 |         assert num_samples % 2 == 0
28 |         samples = dist.rsample((num_samples // 2,))
29 |         mean = dist.mean.unsqueeze(0)
30 |         samples = samples - mean
31 |         return torch.cat([samples, -samples]) + mean  # antithetic sampling around the mean
32 | 
33 |     def forward(self, logits, target, distribution, **kwargs):
34 |         batch_size = logits.shape[0]
35 |         num_classes = logits.shape[1]
36 |         assert num_classes >= 2  # not implemented for binary case with implied background
37 |         # logit_sample = distribution.rsample((self.num_mc_samples,))
38 |         logit_sample = self.fixed_re_parametrization_trick(distribution, self.num_mc_samples)
39 |         target = target.unsqueeze(1)
40 |         target = target.expand((self.num_mc_samples,) + target.shape)
41 | 
42 |         flat_size = self.num_mc_samples * batch_size
43 |         logit_sample = logit_sample.view((flat_size, num_classes, -1))
44 |         target = target.reshape((flat_size, -1))
45 | 
46 |         # log_prob = -F.cross_entropy(logit_sample, target, reduction='none').view((self.num_mc_samples, batch_size, -1))
47 |         log_prob = -F.binary_cross_entropy(torch.sigmoid(logit_sample), target, reduction='none').view((self.num_mc_samples, batch_size, -1))
48 |         loglikelihood = torch.mean(torch.logsumexp(torch.sum(log_prob, dim=-1), dim=0) - math.log(self.num_mc_samples))
49 |         loss = -loglikelihood
50 |         return loss
--------------------------------------------------------------------------------
/networks/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import
torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | else: 17 | raise NotImplementedError 18 | 19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 20 | self.bn1 = BatchNorm(48) 21 | self.relu = nn.ReLU() 22 | self.last_conv_1 = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 23 | BatchNorm(256), 24 | nn.ReLU(), 25 | nn.Dropout(0.5), 26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 27 | BatchNorm(256), 28 | nn.ReLU()) 29 | self.last_conv_2 = nn.Sequential(nn.Dropout(0.1), 30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 31 | self._init_weight() 32 | 33 | 34 | def forward(self, x, low_level_feat): 35 | low_level_feat = self.conv1(low_level_feat) 36 | low_level_feat = self.bn1(low_level_feat) 37 | low_level_feat = self.relu(low_level_feat) 38 | 39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 40 | x = torch.cat((x, low_level_feat), dim=1) 41 | x = self.last_conv_1(x) 42 | out = self.last_conv_2(x) 43 | 44 | return out, x 45 | 46 | def _init_weight(self): 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | torch.nn.init.kaiming_normal_(m.weight) 50 | elif isinstance(m, SynchronizedBatchNorm2d): 51 | m.weight.data.fill_(1) 52 | m.bias.data.zero_() 53 | elif isinstance(m, nn.BatchNorm2d): 54 | m.weight.data.fill_(1) 55 | m.bias.data.zero_() 56 | 57 | def build_decoder(num_classes, backbone, BatchNorm): 58 | return Decoder(num_classes, backbone, BatchNorm) 59 | -------------------------------------------------------------------------------- /networks/deeplabv3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | from networks.aspp import build_aspp 6 | from networks.decoder import build_decoder 7 | from networks.backbone import build_backbone 8 | 9 | 10 | class DeepLab(nn.Module): 11 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21, 12 | sync_bn=True, freeze_bn=False): 13 | super(DeepLab, self).__init__() 14 | if backbone == 'drn': 15 | output_stride = 8 16 | 17 | if sync_bn == True: 18 | BatchNorm = SynchronizedBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 24 | self.decoder = build_decoder(num_classes, backbone, BatchNorm) 25 | 26 | if freeze_bn: 27 | self.freeze_bn() 28 | 29 | def forward(self, input): 30 | x, low_level_feat = self.backbone(input) 31 | x = self.aspp(x) 32 | x, features = self.decoder(x, low_level_feat) 33 | 34 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True) 35 | return x, features 36 | 37 | def freeze_bn(self): 38 | for m in self.modules(): 39 | if isinstance(m, SynchronizedBatchNorm2d): 40 | m.eval() 41 | elif isinstance(m, nn.BatchNorm2d): 42 | m.eval() 43 | 44 | def get_1x_lr_params(self): 45 | modules = 
[self.backbone]
46 |         for i in range(len(modules)):
47 |             for m in modules[i].named_modules():
48 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
49 |                         or isinstance(m[1], nn.BatchNorm2d):
50 |                     for p in m[1].parameters():
51 |                         if p.requires_grad:
52 |                             yield p
53 | 
54 |     def get_10x_lr_params(self):
55 |         modules = [self.aspp, self.decoder]
56 |         for i in range(len(modules)):
57 |             for m in modules[i].named_modules():
58 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
59 |                         or isinstance(m[1], nn.BatchNorm2d):
60 |                     for p in m[1].parameters():
61 |                         if p.requires_grad:
62 |                             yield p
63 | 
64 | 
65 | if __name__ == "__main__":
66 |     model = DeepLab(backbone='mobilenet', output_stride=16)
67 |     model.eval()
68 |     input = torch.rand(1, 3, 513, 513)
69 |     output, features = model(input)  # forward() returns (logits, decoder features)
70 |     print(output.size())
71 | 
72 | 
73 | 
--------------------------------------------------------------------------------
/networks/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Wrappers for the operations to take the meta-learning gradient
3 | updates into account.
4 | """
5 | import torch.autograd as autograd
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | 
9 | 
10 | def linear(inputs, weight, bias, meta_step_size=0.001, meta_loss=None, stop_gradient=False):
11 |     inputs = inputs.cuda()
12 |     weight = weight.cuda()
13 |     bias = bias.cuda()
14 | 
15 |     if meta_loss is not None:
16 | 
17 |         if not stop_gradient:
18 |             grad_weight = autograd.grad(meta_loss, weight, create_graph=True)[0]
19 | 
20 |             if bias is not None:
21 |                 grad_bias = autograd.grad(meta_loss, bias, create_graph=True)[0]
22 |                 bias_adapt = bias - grad_bias * meta_step_size
23 |             else:
24 |                 bias_adapt = bias
25 | 
26 |         else:
27 |             grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True)[0].data, requires_grad=False)
28 | 
29 |             if bias is not None:
30 |                 grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True)[0].data, requires_grad=False)
31 |                 bias_adapt = bias - grad_bias * meta_step_size
32 |             else:
33 |                 bias_adapt = bias
34 | 
35 |         return F.linear(inputs,
36 |                         weight - grad_weight * meta_step_size,
37 |                         bias_adapt)
38 |     else:
39 |         return F.linear(inputs, weight, bias)
40 | 
41 | def conv2d(inputs, weight, bias, stride=1, padding=1, dilation=1, groups=1, kernel_size=3):
42 | 
43 |     inputs = inputs.cuda()
44 |     weight = weight.cuda()
45 |     bias = bias.cuda()
46 | 
47 |     return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)
48 | 
49 | 
50 | def deconv2d(inputs, weight, bias, stride=2, padding=0, dilation=0, groups=1, kernel_size=None):
51 | 
52 |     inputs = inputs.cuda()
53 |     weight = weight.cuda()
54 |     bias = bias.cuda()
55 | 
56 |     return F.conv_transpose2d(inputs, weight, bias, stride, padding, output_padding=dilation, groups=groups)  # the sixth positional arg of conv_transpose2d is output_padding, not dilation
57 | 
58 | def relu(inputs):
59 |     return F.relu(inputs, inplace=True)
60 | 
61 | 
62 | def maxpool(inputs, kernel_size, stride=None, padding=0):
63 |     return F.max_pool2d(inputs, kernel_size, stride, padding=padding)
64 | 
65 | 
66 | def dropout(inputs):
67 |     return F.dropout(inputs, p=0.5, training=False, inplace=False)
68 | 
69 | def batchnorm(inputs, running_mean, running_var):
70 |     return F.batch_norm(inputs, running_mean, running_var)
71 | 
72 | 
73 | """
74 | The following are the new methods for 2D-Unet:
75 | Conv2d, batchnorm2d, GroupNorm, InstanceNorm2d, MaxPool2d, UpSample
76 | """
77 | # as per the 2D Unet: kernel_size, stride, padding
78 | 
79 | def instancenorm(input):
80 |     return F.instance_norm(input)
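# (Added illustration, not part of the original file.) With meta_loss set, the
# linear() wrapper above evaluates F.linear using weights adapted by one
# MAML-style inner-loop step, w' = w - meta_step_size * d(meta_loss)/dw, e.g.:
#     out = linear(x, fc.weight, fc.bias, meta_step_size=1e-3, meta_loss=loss)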
81 | 
82 | def groupnorm(input, num_groups=32):  # F.group_norm requires num_groups; 32 is an assumed default
83 |     return F.group_norm(input, num_groups)
84 | 
85 | def dropout2D(inputs):
86 |     return F.dropout2d(inputs, p=0.5, training=False, inplace=False)
87 | 
88 | def maxpool2D(inputs, kernel_size, stride=None, padding=0):
89 |     return F.max_pool2d(inputs, kernel_size, stride, padding=padding)
90 | 
91 | def upsample(input):
92 |     return F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)  # F.upsample is deprecated in favor of F.interpolate
93 | 
--------------------------------------------------------------------------------
/networks/decoder_old.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | class Decoder(nn.Module):
8 |     def __init__(self, num_classes, backbone, BatchNorm):
9 |         super(Decoder, self).__init__()
10 |         if backbone == 'resnet' or backbone == 'drn':
11 |             low_level_inplanes = 256
12 |         elif backbone == 'xception':
13 |             low_level_inplanes = 128
14 |         elif backbone == 'mobilenet':
15 |             low_level_inplanes = 24
16 |         else:
17 |             raise NotImplementedError
18 | 
19 |         self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20 |         self.bn1 = BatchNorm(48)
21 |         self.relu = nn.ReLU()
22 |         self.last_conv = nn.Sequential(
23 |             # nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
24 |             # BatchNorm(256),
25 |             # nn.ReLU(),
26 |             # nn.Dropout(0.5),
27 |             # nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
28 |             BatchNorm(305),
29 |             nn.ReLU(),
30 |             nn.Dropout(0.1),
31 |             nn.Conv2d(305, num_classes, kernel_size=1, stride=1))
32 |         self.last_conv_boundary = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
33 |                                                 BatchNorm(256),
34 |                                                 nn.ReLU(),
35 |                                                 nn.Dropout(0.5),
36 |                                                 nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
37 |                                                 BatchNorm(256),
38 |                                                 nn.ReLU(),
39 |                                                 nn.Dropout(0.1),
40 |                                                 nn.Conv2d(256, 1, kernel_size=1, stride=1))
41 |         self._init_weight()
42 | 
43 | 
44 |     def forward(self, x, low_level_feat):
45 |         low_level_feat = self.conv1(low_level_feat)
46 |         low_level_feat = self.bn1(low_level_feat)
47 |         low_level_feat = self.relu(low_level_feat)
48 | 
49 |         x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
50 |         x = torch.cat((x, low_level_feat), dim=1)
51 |         boundary = self.last_conv_boundary(x)
52 |         x = torch.cat([x, boundary], 1)  # 304 decoder channels + 1 boundary channel = 305
53 |         x1 = self.last_conv(x)
54 | 
55 |         return x1, boundary, x
56 | 
57 |     def _init_weight(self):
58 |         for m in self.modules():
59 |             if isinstance(m, nn.Conv2d):
60 |                 torch.nn.init.kaiming_normal_(m.weight)
61 |             elif isinstance(m, SynchronizedBatchNorm2d):
62 |                 m.weight.data.fill_(1)
63 |                 m.bias.data.zero_()
64 |             elif isinstance(m, nn.BatchNorm2d):
65 |                 m.weight.data.fill_(1)
66 |                 m.bias.data.zero_()
67 | 
68 | def build_decoder(num_classes, backbone, BatchNorm):
69 |     return Decoder(num_classes, backbone, BatchNorm)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | 
.installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /networks/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | BatchNorm(256), 58 | nn.ReLU()) 59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 60 | self.bn1 = BatchNorm(256) 61 | self.relu = nn.ReLU() 62 | self.dropout = nn.Dropout(0.5) 63 | self._init_weight() 64 | 65 | def forward(self, x): 66 | x1 = self.aspp1(x) 67 | x2 = self.aspp2(x) 68 | x3 = self.aspp3(x) 69 | x4 = self.aspp4(x) 70 | x5 = self.global_avg_pool(x) 71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 73 | 74 | x = self.conv1(x) 75 | x = self.bn1(x) 76 | x = self.relu(x) 77 | 78 | return self.dropout(x) 79 | 80 | def _init_weight(self): 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 85 | torch.nn.init.kaiming_normal_(m.weight) 86 | elif isinstance(m, SynchronizedBatchNorm2d): 87 | m.weight.data.fill_(1) 88 | m.bias.data.zero_() 89 | elif isinstance(m, nn.BatchNorm2d): 90 | m.weight.data.fill_(1) 91 | m.bias.data.zero_() 92 | 93 | 94 | def build_aspp(backbone, output_stride, BatchNorm): 95 | return ASPP(backbone, output_stride, BatchNorm) 96 | -------------------------------------------------------------------------------- /networks/aspp_eval.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 9 | super(_ASPPModule, self).__init__() 10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 11 | stride=1, padding=padding, dilation=dilation, bias=False) 12 | self.bn = BatchNorm(planes) 13 | self.relu = nn.ReLU() 14 | 15 | self._init_weight() 16 | 17 | def forward(self, x): 18 | x = self.atrous_conv(x) 19 | x = self.bn(x) 20 | 21 | return self.relu(x) 22 | 23 | def _init_weight(self): 24 | for m in self.modules(): 25 | if isinstance(m, nn.Conv2d): 26 | torch.nn.init.kaiming_normal_(m.weight) 27 | elif isinstance(m, SynchronizedBatchNorm2d): 28 | m.weight.data.fill_(1) 29 | m.bias.data.zero_() 30 | elif isinstance(m, nn.BatchNorm2d): 31 | m.weight.data.fill_(1) 32 | m.bias.data.zero_() 33 | 34 | class ASPP(nn.Module): 35 | def __init__(self, backbone, output_stride, BatchNorm): 36 | super(ASPP, self).__init__() 37 | if backbone == 'drn': 38 | inplanes = 512 39 | elif backbone == 'mobilenet': 40 | inplanes = 320 41 | else: 42 | inplanes = 2048 43 | if output_stride == 16: 44 | dilations = [1, 6, 12, 18] 45 | elif output_stride == 8: 46 | dilations = [1, 12, 24, 36] 47 | else: 48 | raise NotImplementedError 49 | 50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 54 | 55 | # self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 56 | # nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 57 | # BatchNorm(256), 58 | # nn.ReLU()) 59 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 60 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 61 | nn.ReLU()) 62 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 63 | self.bn1 = BatchNorm(256) 64 | self.relu = nn.ReLU() 65 | self.dropout = nn.Dropout(0.5) 66 | self._init_weight() 67 | 68 | def forward(self, x): 69 | x1 = self.aspp1(x) 70 | x2 = self.aspp2(x) 71 | x3 = self.aspp3(x) 72 | x4 = self.aspp4(x) 73 | x5 = self.global_avg_pool(x) 74 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 75 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 76 | 77 | x = self.conv1(x) 78 | x = self.bn1(x) 79 | x = self.relu(x) 80 | 81 | return self.dropout(x) 82 | 83 | def _init_weight(self): 84 | for m in self.modules(): 85 | if isinstance(m, nn.Conv2d): 86 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 87 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n))
88 |                 torch.nn.init.kaiming_normal_(m.weight)
89 |             elif isinstance(m, SynchronizedBatchNorm2d):
90 |                 m.weight.data.fill_(1)
91 |                 m.bias.data.zero_()
92 |             elif isinstance(m, nn.BatchNorm2d):
93 |                 m.weight.data.fill_(1)
94 |                 m.bias.data.zero_()
95 | 
96 | 
97 | def build_aspp(backbone, output_stride, BatchNorm):
98 |     return ASPP(backbone, output_stride, BatchNorm)
99 | 
--------------------------------------------------------------------------------
/networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : comm.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as a one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
59 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices will be
61 |       collected and passed to a registered callback.
62 |     - After receiving the messages, the master device should gather the information and determine the message to be
63 |       passed back to each slave device.
64 |     """
65 | 
66 |     def __init__(self, master_callback):
67 |         """
68 |         Args:
69 |             master_callback: a callback to be invoked after having collected messages from slave devices.
70 |         """
71 |         self._master_callback = master_callback
72 |         self._queue = queue.Queue()
73 |         self._registry = collections.OrderedDict()
74 |         self._activated = False
75 | 
76 |     def __getstate__(self):
77 |         return {'master_callback': self._master_callback}
78 | 
79 |     def __setstate__(self, state):
80 |         self.__init__(state['master_callback'])
81 | 
82 |     def register_slave(self, identifier):
83 |         """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually the device id.
87 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 |         """
89 |         if self._activated:
90 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
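            # (Added note, not in the original file.) A new forward pass resets the
            # registry below; each replica then re-registers and receives a fresh
            # FutureResult, forming the slave -> master -> slave round trip that
            # run_master() drives.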
91 |             self._activated = False
92 |             self._registry.clear()
93 |         future = FutureResult()
94 |         self._registry[identifier] = _MasterRegistry(future)
95 |         return SlavePipe(identifier, self._queue, future)
96 | 
97 |     def run_master(self, master_msg):
98 |         """
99 |         Main entry for the master device in each forward pass.
100 |         Messages are first collected from each device (including the master device), and then
101 |         a callback is invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |         Returns: the message to be sent back to the master device.
107 |         """
108 |         self._activated = True
109 | 
110 |         intermediates = [(0, master_msg)]
111 |         for i in range(self.nr_slaves):
112 |             intermediates.append(self._queue.get())
113 | 
114 |         results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 
117 |         for i, res in results:
118 |             if i == 0:
119 |                 continue
120 |             self._registry[i].result.put(res)
121 | 
122 |         for i in range(self.nr_slaves):
123 |             assert self._queue.get() is True
124 | 
125 |         return results[0][1]
126 | 
127 |     @property
128 |     def nr_slaves(self):
129 |         return len(self._registry)
130 | 
--------------------------------------------------------------------------------
/networks/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | import torch.utils.model_zoo as model_zoo
7 | 
8 | def conv_bn(inp, oup, stride, BatchNorm):
9 |     return nn.Sequential(
10 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11 |         BatchNorm(oup),
12 |         nn.ReLU6(inplace=True)
13 |     )
14 | 
15 | 
16 | def fixed_padding(inputs, kernel_size, dilation):
17 |     kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
18 |     pad_total = kernel_size_effective - 1
19 |     pad_beg = pad_total // 2
20 |     pad_end = pad_total - pad_beg
21 |     padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
22 |     return padded_inputs
23 | 
24 | 
25 | class InvertedResidual(nn.Module):
26 |     def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
27 |         super(InvertedResidual, self).__init__()
28 |         self.stride = stride
29 |         assert stride in [1, 2]
30 | 
31 |         hidden_dim = round(inp * expand_ratio)
32 |         self.use_res_connect = self.stride == 1 and inp == oup
33 |         self.kernel_size = 3
34 |         self.dilation = dilation
35 | 
36 |         if expand_ratio == 1:
37 |             self.conv = nn.Sequential(
38 |                 # dw
39 |                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
40 |                 BatchNorm(hidden_dim),
41 |                 nn.ReLU6(inplace=True),
42 |                 # pw-linear
43 |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
44 |                 BatchNorm(oup),
45 |             )
46 |         else:
47 |             self.conv = nn.Sequential(
48 |                 # pw
49 |                 nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
50 |                 BatchNorm(hidden_dim),
51 |                 nn.ReLU6(inplace=True),
52 |                 # dw
53 |                 nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
54 |                 BatchNorm(hidden_dim),
55 |                 nn.ReLU6(inplace=True),
56 |                 # pw-linear
57 |                 nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
58 |                 BatchNorm(oup),
59 |             )
60 | 
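    # (Added note, not in the original file.) The depthwise 3x3 convolutions above
    # are created with padding=0; forward() below compensates by calling
    # fixed_padding() first, e.g. kernel 3 with dilation 2 has an effective size
    # of 5, so 2 pixels are padded on each side, reproducing 'same'-style padding.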
61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class MobileNetV2(nn.Module): 71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True): 72 | super(MobileNetV2, self).__init__() 73 | block = InvertedResidual 74 | input_channel = 32 75 | current_stride = 1 76 | rate = 1 77 | interverted_residual_setting = [ 78 | # t, c, n, s 79 | [1, 16, 1, 1], 80 | [6, 24, 2, 2], 81 | [6, 32, 3, 2], 82 | [6, 64, 4, 2], 83 | [6, 96, 3, 1], 84 | [6, 160, 3, 2], 85 | [6, 320, 1, 1], 86 | ] 87 | 88 | # building first layer 89 | input_channel = int(input_channel * width_mult) 90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 91 | current_stride *= 2 92 | # building inverted residual blocks 93 | for t, c, n, s in interverted_residual_setting: 94 | if current_stride == output_stride: 95 | stride = 1 96 | dilation = rate 97 | rate *= s 98 | else: 99 | stride = s 100 | dilation = 1 101 | current_stride *= s 102 | output_channel = int(c * width_mult) 103 | for i in range(n): 104 | if i == 0: 105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm)) 106 | else: 107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*self.features) 110 | self._initialize_weights() 111 | 112 | if pretrained: 113 | self._load_pretrained_model() 114 | 115 | self.low_level_features = self.features[0:4] 116 | self.high_level_features = self.features[4:] 117 | 118 | def forward(self, x): 119 | low_level_feat = self.low_level_features(x) 120 | x = self.high_level_features(low_level_feat) 121 | return x, low_level_feat 122 | 123 | def _load_pretrained_model(self): 124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 125 | model_dict = {} 126 | state_dict = self.state_dict() 127 | for k, v in pretrain_dict.items(): 128 | if k in state_dict: 129 | model_dict[k] = v 130 | state_dict.update(model_dict) 131 | self.load_state_dict(state_dict) 132 | 133 | def _initialize_weights(self): 134 | for m in self.modules(): 135 | if isinstance(m, nn.Conv2d): 136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 137 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n))
138 |                 torch.nn.init.kaiming_normal_(m.weight)
139 |             elif isinstance(m, SynchronizedBatchNorm2d):
140 |                 m.weight.data.fill_(1)
141 |                 m.bias.data.zero_()
142 |             elif isinstance(m, nn.BatchNorm2d):
143 |                 m.weight.data.fill_(1)
144 |                 m.bias.data.zero_()
145 | 
146 | if __name__ == "__main__":
147 |     input = torch.rand(1, 3, 512, 512)
148 |     model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
149 |     output, low_level_feat = model(input)
150 |     print(output.size())
151 |     print(low_level_feat.size())
152 | 
--------------------------------------------------------------------------------
/dataloaders/fundus_dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from glob import glob
7 | import random
8 | 
9 | 
10 | class FundusSegmentation(Dataset):
11 |     """
12 |     Fundus segmentation dataset
13 |     covering 5 domain datasets:
14 |     one is used for testing, the others for training
15 |     """
16 | 
17 |     def __init__(self,
18 |                  base_dir,
19 |                  dataset='refuge',
20 |                  split='train',
21 |                  testid=None,
22 |                  transform=None
23 |                  ):
24 |         """
25 |         :param base_dir: path to the fundus dataset directory
26 |         :param split: train/val
27 |         :param transform: transform to apply
28 |         """
29 |         # super().__init__()
30 |         self._base_dir = base_dir
31 |         self.image_list = []
32 |         self.split = split
33 | 
34 |         self.image_pool = []
35 |         self.label_pool = []
36 |         self.img_name_pool = []
37 | 
38 |         self._image_dir = os.path.join(self._base_dir, dataset, split, 'image')
39 |         print(self._image_dir)
40 |         imagelist = glob(self._image_dir + "/*.png")
41 |         for image_path in imagelist:
42 |             gt_path = image_path.replace('image', 'mask')
43 |             self.image_list.append({'image': image_path, 'label': gt_path, 'id': testid})
44 | 
45 |         self.transform = transform
46 |         # self._read_img_into_memory()
47 |         # Display stats
48 |         print('Number of images in {}: {:d}'.format(split, len(self.image_list)))
49 | 
50 |     def __len__(self):
51 |         return len(self.image_list)
52 | 
53 |     def __getitem__(self, index):
54 | 
55 |         _img = Image.open(self.image_list[index]['image']).convert('RGB')
56 |         _target = Image.open(self.image_list[index]['label'])
57 |         if _target.mode == 'RGB':
58 |             _target = _target.convert('L')
59 |         _img_name = self.image_list[index]['image'].split('/')[-1]
60 | 
61 |         # _img = self.image_pool[index]
62 |         # _target = self.label_pool[index]
63 |         # _img_name = self.img_name_pool[index]
64 |         anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name}
65 | 
66 |         if self.transform is not None:
67 |             anco_sample = self.transform(anco_sample)
68 | 
69 |         return anco_sample
70 | 
71 |     def _read_img_into_memory(self):
72 | 
73 |         img_num = len(self.image_list)
74 |         for index in range(img_num):
75 |             self.image_pool.append(Image.open(self.image_list[index]['image']).convert('RGB'))
76 |             _target = Image.open(self.image_list[index]['label'])
77 |             if _target.mode == 'RGB':
78 |                 _target = _target.convert('L')
79 |             self.label_pool.append(_target)
80 |             _img_name = self.image_list[index]['image'].split('/')[-1]
81 |             self.img_name_pool.append(_img_name)
82 | 
83 | 
84 |     def __str__(self):
85 |         return 'Fundus(split=' + str(self.split) + ')'
86 | 
87 | 
88 | class FundusSegmentation_2transform(Dataset):
89 |     """
90 |     Fundus segmentation dataset
91 |     covering 5 domain datasets:
92 |     one is used for testing, the others for training
93 |     """
94 | 
95 |     def __init__(self,
96 |                  base_dir,
97 |                  dataset='refuge',
98 |                  split='train',
99 |                  testid=None,
100 |                  transform_weak=None,
101 |                  transform_strong=None
102 |                  ):
103 |         """
104 |         :param base_dir: path to the fundus dataset directory
105 |         :param split: train/val
106 |         :param transform_weak, transform_strong: weak/strong augmentations to apply
107 |         """
108 |         # super().__init__()
109 |         self._base_dir = base_dir
110 |         self.image_list = []
111 |         self.split = split
112 | 
113 |         self.image_pool = []
114 |         self.label_pool = []
115 |         self.img_name_pool = []
116 | 
117 |         self._image_dir = os.path.join(self._base_dir, dataset, split, 'image')
118 |         print(self._image_dir)
119 |         imagelist = glob(self._image_dir + "/*.png")
120 |         for image_path in imagelist:
121 |             gt_path = image_path.replace('image', 'mask')
122 |             self.image_list.append({'image': image_path, 'label': gt_path, 'id': testid})
123 | 
124 |         self.transform_weak = transform_weak
125 |         self.transform_strong = transform_strong
126 |         # self._read_img_into_memory()
127 |         # Display stats
128 |         print('Number of images in {}: {:d}'.format(split, len(self.image_list)))
129 | 
130 |     def __len__(self):
131 |         return len(self.image_list)
132 | 
133 |     def __getitem__(self, index):
134 | 
135 |         _img = Image.open(self.image_list[index]['image']).convert('RGB')
136 |         _target = Image.open(self.image_list[index]['label'])
137 |         if _target.mode == 'RGB':
138 |             _target = _target.convert('L')
139 |         _img_name = self.image_list[index]['image'].split('/')[-1]
140 | 
141 |         # _img = self.image_pool[index]
142 |         # _target = self.label_pool[index]
143 |         # _img_name = self.img_name_pool[index]
144 |         anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name}
145 | 
146 |         anco_sample_weak_aug = self.transform_weak(anco_sample)
147 | 
148 |         anco_sample_strong_aug = self.transform_strong(anco_sample)
149 | 
150 |         return anco_sample_weak_aug, anco_sample_strong_aug
151 | 
152 |     def _read_img_into_memory(self):
153 | 
154 |         img_num = len(self.image_list)
155 |         for index in range(img_num):
156 |             self.image_pool.append(Image.open(self.image_list[index]['image']).convert('RGB'))
157 |             _target = Image.open(self.image_list[index]['label'])
158 |             if _target.mode == 'RGB':
159 |                 _target = _target.convert('L')
160 |             self.label_pool.append(_target)
161 |             _img_name = self.image_list[index]['image'].split('/')[-1]
162 |             self.img_name_pool.append(_img_name)
163 | 
164 | 
165 |     def __str__(self):
166 |         return 'Fundus(split=' + str(self.split) + ')'
167 | 
168 | 
169 | 
--------------------------------------------------------------------------------
/networks/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | 
6 | class Bottleneck(nn.Module):
7 |     expansion = 4
8 | 
9 |     def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
10 |         super(Bottleneck, self).__init__()
11 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
12 |         self.bn1 = BatchNorm(planes)
13 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
14 |                                dilation=dilation, padding=dilation, bias=False)
15 |         self.bn2 = BatchNorm(planes)
16 |         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
17 |         self.bn3 = BatchNorm(planes * 4)
18 |         self.relu = nn.ReLU(inplace=True)
19 |         self.downsample = downsample
20 |         self.stride = stride
21 |         self.dilation = dilation
22 | 
23 |     def forward(self, x):
24 |         residual = x
25 | 
26 |         out = self.conv1(x)
27 |         out =
self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | 34 | out = self.conv3(out) 35 | out = self.bn3(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | class ResNet(nn.Module): 46 | 47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 48 | self.inplanes = 64 49 | super(ResNet, self).__init__() 50 | blocks = [1, 2, 4] 51 | if output_stride == 16: 52 | strides = [1, 2, 2, 1] 53 | dilations = [1, 1, 1, 2] 54 | elif output_stride == 8: 55 | strides = [1, 2, 1, 1] 56 | dilations = [1, 1, 2, 4] 57 | else: 58 | raise NotImplementedError 59 | 60 | # Modules 61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 62 | bias=False) 63 | self.bn1 = BatchNorm(64) 64 | self.relu = nn.ReLU(inplace=True) 65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 66 | 67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 72 | self._init_weight() 73 | 74 | if pretrained: 75 | self._load_pretrained_model() 76 | 77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 78 | downsample = None 79 | if stride != 1 or self.inplanes != planes * block.expansion: 80 | downsample = nn.Sequential( 81 | nn.Conv2d(self.inplanes, planes * block.expansion, 82 | kernel_size=1, stride=stride, bias=False), 83 | BatchNorm(planes * block.expansion), 84 | ) 85 | 86 | layers = [] 87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 88 | self.inplanes = planes * block.expansion 89 | for i in range(1, blocks): 90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 91 | 92 | return nn.Sequential(*layers) 93 | 94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 95 | downsample = None 96 | if stride != 1 or self.inplanes != planes * block.expansion: 97 | downsample = nn.Sequential( 98 | nn.Conv2d(self.inplanes, planes * block.expansion, 99 | kernel_size=1, stride=stride, bias=False), 100 | BatchNorm(planes * block.expansion), 101 | ) 102 | 103 | layers = [] 104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 105 | downsample=downsample, BatchNorm=BatchNorm)) 106 | self.inplanes = planes * block.expansion 107 | for i in range(1, len(blocks)): 108 | layers.append(block(self.inplanes, planes, stride=1, 109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 110 | 111 | return nn.Sequential(*layers) 112 | 113 | def forward(self, input): 114 | x = self.conv1(input) 115 | x = self.bn1(x) 116 | x = self.relu(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | low_level_feat = x 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | return x, low_level_feat 125 | 126 | def _init_weight(self): 127 | for m in self.modules(): 128 | if 
isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. / n)) 131 | elif isinstance(m, SynchronizedBatchNorm2d): 132 | m.weight.data.fill_(1) 133 | m.bias.data.zero_() 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _load_pretrained_model(self): 139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 140 | model_dict = {} 141 | state_dict = self.state_dict() 142 | for k, v in pretrain_dict.items(): 143 | if k in state_dict: 144 | model_dict[k] = v 145 | state_dict.update(model_dict) 146 | self.load_state_dict(state_dict) 147 | 148 | def ResNet101(output_stride, BatchNorm, pretrained=True): 149 | """Constructs a ResNet-101 model. 150 | Args: 151 | pretrained (bool): If True, returns a model pre-trained on ImageNet 152 | """ 153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 154 | return model 155 | 156 | if __name__ == "__main__": 157 | import torch 158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 159 | input = torch.rand(1, 3, 512, 512) 160 | output, low_level_feat = model(input) 161 | print(output.size()) 162 | print(low_level_feat.size()) 163 | -------------------------------------------------------------------------------- /train_source.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser( 3 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 4 | ) 5 | parser.add_argument('-g', '--gpu', type=str, default='7', help='gpu id') 6 | parser.add_argument('--resume', default=None, help='checkpoint path') 7 | parser.add_argument( 8 | '--dataset', type=str, default='Domain3', help='test folder id contain images ROIs to test' 9 | ) 10 | parser.add_argument( 11 | '--model', type=str, default='Deeplab', help='Deeplab' 12 | ) 13 | parser.add_argument( 14 | '--batch-size', type=int, default=8, help='batch size for training the model' 15 | ) 16 | parser.add_argument( 17 | '--group-num', type=int, default=1, help='group number for group normalization' 18 | ) 19 | parser.add_argument( 20 | '--max-epoch', type=int, default=200, help='max epoch' 21 | ) 22 | parser.add_argument( 23 | '--stop-epoch', type=int, default=200, help='stop epoch' 24 | ) 25 | parser.add_argument( 26 | '--warmup-epoch', type=int, default=-1, help='warmup epoch begin train GAN' 27 | ) 28 | 29 | parser.add_argument( 30 | '--interval-validate', type=int, default=5, help='interval epoch number to validate the model' 31 | ) 32 | parser.add_argument( 33 | '--interval-save', type=int, default=10, help='interval epoch number to save the model' 34 | ) 35 | parser.add_argument( 36 | '--lr', type=float, default=1e-3, help='learning rate', 37 | ) 38 | parser.add_argument( 39 | '--lr-decrease-rate', type=float, default=0.98, help='ratio multiplied to initial lr', 40 | ) 41 | parser.add_argument( 42 | '--lr-decrease-epoch', type=int, default=1, help='interval epoch number for lr decrease', 43 | ) 44 | parser.add_argument( 45 | '--weight-decay', type=float, default=0, help='weight decay', 46 | ) 47 | parser.add_argument( 48 | '--momentum', type=float, default=0.99, help='momentum', 49 | ) 50 | parser.add_argument( 51 | '--data-dir', 52 | default='../Datasets/Fundus', 53 | help='data root path' 54 | ) 55 | parser.add_argument( 56 | '--out-stride', 57 | type=int, 58 | default=16, 59 | 
help='out-stride of deeplabv3+', 60 | ) 61 | parser.add_argument( 62 | '--sync-bn', 63 | type=bool, 64 | default=True, 65 | help='sync-bn in deeplabv3+', 66 | ) 67 | parser.add_argument( 68 | '--freeze-bn', 69 | type=bool, 70 | default=False, 71 | help='freeze batch normalization of deeplabv3+', 72 | ) 73 | parser.add_argument('--no-augmentation', action='store_true') 74 | 75 | args = parser.parse_args() 76 | 77 | from datetime import datetime 78 | import os 79 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 80 | import os.path as osp 81 | 82 | # PyTorch includes 83 | import torch 84 | from torchvision import transforms 85 | from torch.utils.data import DataLoader 86 | import yaml 87 | from train_process import Trainer 88 | 89 | # Custom includes 90 | from dataloaders import fundus_dataloader 91 | from dataloaders import custom_transforms as trans 92 | from networks.deeplabv3 import * 93 | 94 | 95 | here = osp.dirname(osp.abspath(__file__)) 96 | 97 | def main(): 98 | now = datetime.now() 99 | args.out = osp.join(here, 'logs_train', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f')) 100 | 101 | os.makedirs(args.out) 102 | with open(osp.join(args.out, 'config.yaml'), 'w') as f: 103 | yaml.safe_dump(args.__dict__, f, default_flow_style=False) 104 | 105 | multiply_gpu = False 106 | if (args.gpu).find(',') != -1: 107 | multiply_gpu = True 108 | cuda = torch.cuda.is_available() 109 | 110 | torch.manual_seed(42) 111 | if cuda: 112 | torch.cuda.manual_seed(42) 113 | 114 | # 1. dataset 115 | composed_transforms_tr = transforms.Compose([ 116 | trans.RandomScaleCrop(512), 117 | trans.RandomRotate(), 118 | trans.RandomFlip(), 119 | trans.elastic_transform(), 120 | trans.add_salt_pepper_noise(), 121 | trans.adjust_light(), 122 | trans.eraser(), 123 | trans.Normalize_tf(), 124 | trans.ToTensor() 125 | ]) 126 | 127 | composed_transforms_ts = transforms.Compose([ 128 | # fundus_trans.RandomCrop(512), 129 | trans.Resize(512), 130 | trans.Normalize_tf(), # this function separates (raw ground truth mask) into (2 masks) 131 | trans.ToTensor() 132 | ]) 133 | 134 | if args.no_augmentation: 135 | domain = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, 136 | split='train/ROIs', transform=composed_transforms_ts) 137 | else: 138 | domain = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, 139 | split='train/ROIs', transform=composed_transforms_tr) 140 | domain_loader = DataLoader(domain, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True) 141 | 142 | domain_val = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, 143 | split='test/ROIs', transform=composed_transforms_ts) 144 | domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2, 145 | pin_memory=True) 146 | 147 | # 2. model 148 | model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, 149 | sync_bn=args.sync_bn, freeze_bn=args.freeze_bn) 150 | if cuda: 151 | model = model.cuda() 152 | 153 | start_epoch = 0 154 | start_iteration = 0 155 | 156 | # 3. 
optimizer 157 | if multiply_gpu: 158 | model = torch.nn.DataParallel(model, device_ids=[0, 1]) 159 | 160 | optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay) 161 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=args.lr_decrease_epoch, gamma=args.lr_decrease_rate) 162 | 163 | if args.resume: 164 | checkpoint = torch.load(args.resume) 165 | pretrained_dict = checkpoint['model_state_dict'] 166 | model_dict = model.state_dict() 167 | # 1. filter out unnecessary keys 168 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 169 | # 2. overwrite entries in the existing state dict 170 | model_dict.update(pretrained_dict) 171 | # 3. load the new state dict 172 | model.load_state_dict(model_dict) 173 | 174 | 175 | start_epoch = checkpoint['epoch'] + 1 176 | start_iteration = checkpoint['iteration'] + 1 177 | optim.load_state_dict(checkpoint['optim_state_dict']) 178 | 179 | trainer = Trainer.Trainer( 180 | cuda=cuda, 181 | multiply_gpu=multiply_gpu, 182 | model=model, 183 | optimizer=optim, 184 | scheduler=scheduler, 185 | lr=args.lr, 186 | val_loader=domain_loader_val, 187 | domain_loader=domain_loader, 188 | out=args.out, 189 | max_epoch=args.max_epoch, 190 | stop_epoch=args.stop_epoch, 191 | interval_validate=args.interval_validate, 192 | interval_save=args.interval_save, 193 | batch_size=args.batch_size, 194 | warmup_epoch=args.warmup_epoch 195 | ) 196 | trainer.epoch = start_epoch 197 | trainer.iteration = start_iteration 198 | trainer.train() 199 | 200 | if __name__ == '__main__': 201 | main() 202 | -------------------------------------------------------------------------------- /networks/GAN.py: -------------------------------------------------------------------------------- 1 | # camera-ready 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | class Discriminator(nn.Module): 9 | def __init__(self, ): 10 | super(Discriminator, self).__init__() 11 | 12 | filter_num_list = [4096, 2048, 1024, 1] 13 | 14 | self.fc1 = nn.Linear(24576, filter_num_list[0]) 15 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 16 | self.fc2 = nn.Linear(filter_num_list[0], filter_num_list[1]) 17 | self.fc3 = nn.Linear(filter_num_list[1], filter_num_list[2]) 18 | self.fc4 = nn.Linear(filter_num_list[2], filter_num_list[3]) 19 | 20 | # self.sigmoid = nn.Sigmoid() 21 | self._initialize_weights() 22 | 23 | 24 | def _initialize_weights(self): 25 | 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | m.weight.data.normal_(0.0, 0.02) 29 | if m.bias is not None: 30 | m.bias.data.zero_() 31 | 32 | if isinstance(m, nn.ConvTranspose2d): 33 | m.weight.data.normal_(0.0, 0.02) 34 | if m.bias is not None: 35 | m.bias.data.zero_() 36 | 37 | if isinstance(m, nn.Linear): 38 | m.weight.data.normal_(0.0, 0.02) 39 | if m.bias is not None: 40 | # m.bias.data.copy_(1.0) 41 | m.bias.data.zero_() 42 | 43 | 44 | def forward(self, x): 45 | 46 | x = self.leakyrelu(self.fc1(x)) 47 | x = self.leakyrelu(self.fc2(x)) 48 | x = self.leakyrelu(self.fc3(x)) 49 | x = self.fc4(x) 50 | return x 51 | 52 | 53 | class OutputDiscriminator(nn.Module): 54 | def __init__(self, ): 55 | super(OutputDiscriminator, self).__init__() 56 | 57 | filter_num_list = [64, 128, 256, 512, 1] 58 | 59 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 60 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 61 | 
self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 62 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 63 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 64 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 65 | # self.sigmoid = nn.Sigmoid() 66 | self._initialize_weights() 67 | 68 | 69 | def _initialize_weights(self): 70 | for m in self.modules(): 71 | if isinstance(m, nn.Conv2d): 72 | m.weight.data.normal_(0.0, 0.02) 73 | if m.bias is not None: 74 | m.bias.data.zero_() 75 | 76 | 77 | def forward(self, x): 78 | x = self.leakyrelu(self.conv1(x)) 79 | x = self.leakyrelu(self.conv2(x)) 80 | x = self.leakyrelu(self.conv3(x)) 81 | x = self.leakyrelu(self.conv4(x)) 82 | x = self.conv5(x) 83 | return x 84 | 85 | 86 | class UncertaintyDiscriminator(nn.Module): 87 | def __init__(self, ): 88 | super(UncertaintyDiscriminator, self).__init__() 89 | 90 | filter_num_list = [64, 128, 256, 512, 1] 91 | 92 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 93 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 94 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 95 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 96 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 97 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 98 | # self.sigmoid = nn.Sigmoid() 99 | self._initialize_weights() 100 | 101 | 102 | def _initialize_weights(self): 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | m.weight.data.normal_(0.0, 0.02) 106 | if m.bias is not None: 107 | m.bias.data.zero_() 108 | 109 | 110 | def forward(self, x): 111 | x = self.leakyrelu(self.conv1(x)) 112 | x = self.leakyrelu(self.conv2(x)) 113 | x = self.leakyrelu(self.conv3(x)) 114 | x = self.leakyrelu(self.conv4(x)) 115 | x = self.conv5(x) 116 | return x 117 | 118 | class BoundaryDiscriminator(nn.Module): 119 | def __init__(self, ): 120 | super(BoundaryDiscriminator, self).__init__() 121 | 122 | filter_num_list = [64, 128, 256, 512, 1] 123 | 124 | self.conv1 = nn.Conv2d(1, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 125 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 126 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 127 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 128 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 129 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 130 | # self.sigmoid = nn.Sigmoid() 131 | self._initialize_weights() 132 | 133 | 134 | def _initialize_weights(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | m.weight.data.normal_(0.0, 0.02) 138 | if m.bias is not None: 139 | m.bias.data.zero_() 140 | 141 | 142 | def forward(self, x): 143 | x = self.leakyrelu(self.conv1(x)) 144 | x = self.leakyrelu(self.conv2(x)) 145 | x = self.leakyrelu(self.conv3(x)) 146 | x = self.leakyrelu(self.conv4(x)) 147 | x = self.conv5(x) 148 | return 
x 149 | 150 | class BoundaryEntDiscriminator(nn.Module): 151 | def __init__(self, ): 152 | super(BoundaryEntDiscriminator, self).__init__() 153 | 154 | filter_num_list = [64, 128, 256, 512, 1] 155 | 156 | self.conv1 = nn.Conv2d(3, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False) 157 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False) 158 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False) 159 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False) 160 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False) 161 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2) 162 | # self.sigmoid = nn.Sigmoid() 163 | self._initialize_weights() 164 | 165 | 166 | def _initialize_weights(self): 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | m.weight.data.normal_(0.0, 0.02) 170 | if m.bias is not None: 171 | m.bias.data.zero_() 172 | 173 | 174 | def forward(self, x): 175 | x = self.leakyrelu(self.conv1(x)) 176 | x = self.leakyrelu(self.conv2(x)) 177 | x = self.leakyrelu(self.conv3(x)) 178 | x = self.leakyrelu(self.conv4(x)) 179 | x = self.conv5(x) 180 | return x 181 | 182 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import medpy.metric.binary as medmetric 4 | 5 | bce = torch.nn.BCEWithLogitsLoss(reduction='none') 6 | 7 | def _upscan(f): 8 | for i, fi in enumerate(f): 9 | if fi == np.inf: continue 10 | for j in range(1,i+1): 11 | x = fi+j*j 12 | if f[i-j] < x: break 13 | f[i-j] = x 14 | 15 | 16 | def dice_coefficient_numpy(binary_segmentation, binary_gt_label): 17 | ''' 18 | Compute the Dice coefficient between two binary segmentations. 
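Both inputs are batched with shape (B, H, W); the coefficient is computed per sample, and the +1 smoothing term in numerator and denominator keeps empty masks from dividing by zero.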
19 | Dice coefficient is defined as here: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient 20 | Input: 21 | binary_segmentation: binary 2D numpy array representing the region of interest as segmented by the algorithm 22 | binary_gt_label: binary 2D numpy array representing the region of interest as provided in the database 23 | Output: 24 | dice_value: Dice coefficient between the segmentation and the ground truth 25 | ''' 26 | 27 | # turn all variables to booleans, just in case 28 | binary_segmentation = np.asarray(binary_segmentation, dtype=bool) # np.bool was removed in NumPy 1.24+; use the builtin bool 29 | binary_gt_label = np.asarray(binary_gt_label, dtype=bool) 30 | 31 | # compute the intersection 32 | intersection = np.logical_and(binary_segmentation, binary_gt_label) 33 | 34 | # count the number of True pixels in the binary segmentation 35 | # segmentation_pixels = float(np.sum(binary_segmentation.flatten())) 36 | segmentation_pixels = np.sum(binary_segmentation.astype(float), axis=(1,2)) 37 | # same for the ground truth 38 | # gt_label_pixels = float(np.sum(binary_gt_label.flatten())) 39 | gt_label_pixels = np.sum(binary_gt_label.astype(float), axis=(1,2)) 40 | # same for the intersection 41 | intersection = np.sum(intersection.astype(float), axis=(1,2)) 42 | 43 | # compute the Dice coefficient 44 | dice_value = (2 * intersection + 1.0) / (1.0 + segmentation_pixels + gt_label_pixels) 45 | 46 | # return it 47 | return dice_value 48 | 49 | 50 | def dice_coefficient_numpy_3D(binary_segmentation, binary_gt_label): 51 | ''' 52 | Compute the Dice coefficient between two binary segmentations. 53 | Dice coefficient is defined as here: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient 54 | Input: 55 | binary_segmentation: binary 3D numpy array representing the region of interest as segmented by the algorithm 56 | binary_gt_label: binary 3D numpy array representing the region of interest as provided in the database 57 | Output: 58 | dice_value: Dice coefficient between the segmentation and the ground truth 59 | ''' 60 | 61 | # turn all variables to booleans, just in case 62 | binary_segmentation = np.asarray(binary_segmentation, dtype=bool) 63 | binary_gt_label = np.asarray(binary_gt_label, dtype=bool) 64 | 65 | # compute the intersection 66 | intersection = np.logical_and(binary_segmentation, binary_gt_label) 67 | 68 | # count the number of True pixels in the binary segmentation 69 | # segmentation_pixels = float(np.sum(binary_segmentation.flatten())) 70 | segmentation_pixels = np.sum(binary_segmentation.astype(float), axis=(0,1,2)) 71 | # same for the ground truth 72 | # gt_label_pixels = float(np.sum(binary_gt_label.flatten())) 73 | gt_label_pixels = np.sum(binary_gt_label.astype(float), axis=(0,1,2)) 74 | # same for the intersection 75 | intersection = np.sum(intersection.astype(float), axis=(0,1,2)) 76 | 77 | # compute the Dice coefficient 78 | dice_value = (2 * intersection + 1.0) / (1.0 + segmentation_pixels + gt_label_pixels) 79 | 80 | # return it 81 | return dice_value 82 | 83 | 84 | def dice_numpy_medpy(binary_segmentation, binary_gt_label): 85 | 86 | # turn all variables to booleans, just in case 87 | binary_segmentation = np.asarray(binary_segmentation) 88 | binary_gt_label = np.asarray(binary_gt_label) 89 | 90 | return medmetric.dc(binary_segmentation, binary_gt_label) 91 | 92 | 93 | # if get_hd: 94 | # if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0: 95 | # return medmetric.assd(binary_segmentation, binary_gt_label) 96 | # # return medmetric.hd(binary_segmentation, 
binary_gt_label) 97 | # else: 98 | # return np.nan 99 | # else: 100 | # return 0.0 101 | 102 | 103 | def assd_numpy(binary_segmentation, binary_gt_label): 104 | 105 | # turn all variables to booleans, just in case 106 | binary_segmentation = np.asarray(binary_segmentation) 107 | binary_gt_label = np.asarray(binary_gt_label) 108 | 109 | if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0: 110 | return medmetric.assd(binary_segmentation, binary_gt_label) 111 | else: 112 | return -1 113 | 114 | 115 | def hd_numpy(binary_segmentation, binary_gt_label): 116 | 117 | # turn all variables to booleans, just in case 118 | binary_segmentation = np.asarray(binary_segmentation) 119 | binary_gt_label = np.asarray(binary_gt_label) 120 | 121 | if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0: 122 | return medmetric.hd(binary_segmentation, binary_gt_label) 123 | else: 124 | return -1 125 | 126 | 127 | def dice_coeff(pred, target): 128 | """Dice coefficient for batched predictions and targets. 129 | Note: the prediction is sigmoid-ed and thresholded at 0.5, so this is an evaluation metric, not a differentiable loss. 130 | pred: tensor with first dimension as batch 131 | target: tensor with first dimension as batch 132 | """ 133 | 134 | target = target.data.cpu() 135 | pred = torch.sigmoid(pred) 136 | pred = pred.data.cpu() 137 | pred[pred > 0.5] = 1 138 | pred[pred <= 0.5] = 0 139 | 140 | return dice_coefficient_numpy(pred, target) 141 | 142 | def dice_coeff_2label(pred, target): 143 | """Dice coefficients for the two-channel (cup, disc) case. 144 | Note: the prediction is sigmoid-ed and thresholded at 0.75, so this is an evaluation metric, not a differentiable loss. 145 | pred: tensor with first dimension as batch 146 | target: tensor with first dimension as batch 147 | """ 148 | target = target.data.cpu() 149 | pred = torch.sigmoid(pred) 150 | pred = pred.data.cpu() 151 | pred[pred > 0.75] = 1 152 | pred[pred <= 0.75] = 0 153 | return dice_coefficient_numpy(pred[:, 0, ...], target[:, 0, ...]), dice_coefficient_numpy(pred[:, 1, ...], 154 | target[:, 1, ...]) 155 | 156 | 157 | def dice_coeff_4label(pred, target): 158 | """Dice coefficients for the four-class case. 159 | Note: the prediction is converted to a one-hot label via argmax, so this is an evaluation metric, not a differentiable loss. 160 | pred: tensor with first dimension as batch 161 | target: tensor with first dimension as batch 162 | """ 163 | y_hat = torch.argmax(pred, dim=1, keepdim=True) 164 | pred_label = torch.zeros(pred.size()) # bs*4*W*H 165 | if torch.cuda.is_available(): 166 | pred_label = pred_label.cuda() 167 | pred_label = pred_label.scatter_(1, y_hat, 1) # one-hot label 168 | target = target.data.cpu() 169 | pred_label = pred_label.data.cpu() 170 | return (dice_coefficient_numpy(pred_label[:, 0, ...], target[:, 0, ...]), 171 | dice_coefficient_numpy(pred_label[:, 1, ...], target[:, 1, ...]), 172 | dice_coefficient_numpy(pred_label[:, 2, ...], target[:, 2, ...]), 173 | dice_coefficient_numpy(pred_label[:, 3, ...], target[:, 3, ...])) 174 | 175 | 176 | def DiceLoss(input, target): 177 | ''' 178 | Soft Dice loss in tensor format. 179 | :param input: predicted probabilities 180 | :param target: ground-truth mask of the same shape 181 | :return: 1 - soft Dice coefficient 182 | ''' 183 | smooth = 1. 184 | iflat = input.contiguous().view(-1) 185 | tflat = target.contiguous().view(-1) 186 | intersection = (iflat * tflat).sum() 187 | 188 | return 1 - ((2. 
* intersection + smooth) / 189 | (iflat.sum() + tflat.sum() + smooth)) 190 | 191 | 192 | def assd_compute(pred, target): 193 | target = target.data.cpu() 194 | pred = torch.sigmoid(pred) 195 | pred = pred.data.cpu() 196 | pred[pred > 0.75] = 1 197 | pred[pred <= 0.75] = 0 198 | 199 | assd = np.zeros([pred.shape[0], pred.shape[1]]) 200 | for i in range(pred.shape[0]): 201 | for j in range(pred.shape[1]): 202 | assd[i][j] = assd_numpy(pred[i, j, ...], target[i, j, ...]) 203 | 204 | return assd 205 | -------------------------------------------------------------------------------- /train_process/Trainer.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import os.path as osp 4 | import timeit 5 | from torchvision.utils import make_grid 6 | import time 7 | 8 | import numpy as np 9 | import pytz 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | from tensorboardX import SummaryWriter 14 | 15 | import tqdm 16 | import socket 17 | from utils.metrics import * 18 | from utils.Utils import * 19 | 20 | bceloss = torch.nn.BCELoss() 21 | mseloss = torch.nn.MSELoss() 22 | 23 | def get_lr(optimizer): 24 | for param_group in optimizer.param_groups: 25 | return param_group['lr'] 26 | 27 | class Trainer(object): 28 | 29 | def __init__(self, cuda, multiply_gpu, model, optimizer, scheduler, val_loader, domain_loader, out, max_epoch, stop_epoch=None, 30 | lr=1e-3, interval_validate=10, interval_save=10, batch_size=8, warmup_epoch=10): 31 | self.cuda = cuda 32 | self.multiply_gpu = multiply_gpu 33 | self.warmup_epoch = warmup_epoch 34 | self.model = model 35 | self.optim = optimizer 36 | self.scheduler = scheduler 37 | self.lr = lr 38 | # self.lr_decrease_rate = lr_decrease_rate 39 | # self.lr_decrease_epoch = lr_decrease_epoch 40 | self.batch_size = batch_size 41 | 42 | self.val_loader = val_loader 43 | self.domain_loader = domain_loader 44 | self.time_zone = 'Asia/Shanghai' 45 | self.timestamp_start = datetime.now(pytz.timezone(self.time_zone)) 46 | 47 | self.interval_validate = interval_validate 48 | self.interval_save = interval_save 49 | 50 | self.out = out 51 | if not osp.exists(self.out): 52 | os.makedirs(self.out) 53 | 54 | self.log_headers = [ 55 | 'epoch', 56 | 'iteration', 57 | 'train/loss_seg', 58 | 'valid/loss_CE', 59 | 'valid/cup_dice', 60 | 'valid/disc_dice', 61 | 'elapsed_time', 62 | 'best_epoch' 63 | ] 64 | if not osp.exists(osp.join(self.out, 'log.csv')): 65 | with open(osp.join(self.out, 'log.csv'), 'w') as f: 66 | f.write(','.join(self.log_headers) + '\n') 67 | 68 | log_dir = os.path.join(self.out, 'tensorboard', 69 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) 70 | self.writer = SummaryWriter(log_dir=log_dir) 71 | 72 | self.epoch = 0 73 | self.iteration = 0 74 | self.max_epoch = max_epoch 75 | self.stop_epoch = stop_epoch if stop_epoch is not None else max_epoch 76 | self.best_mean_dice = 0.0 77 | self.best_epoch = -1 78 | 79 | 80 | def validate_fundus(self): 81 | training = self.model.training 82 | self.model.eval() 83 | 84 | val_loss = 0.0 85 | val_fundus_dice = {'cup': 0.0, 'disc': 0.0} 86 | data_num_cnt = 0.0 87 | metrics = [] 88 | with torch.no_grad(): 89 | for batch_idx, sample in tqdm.tqdm( 90 | enumerate(self.val_loader), total=len(self.val_loader), 91 | desc='Valid iteration=%d' % self.iteration, ncols=80, 92 | leave=False): 93 | data = sample['image'] 94 | target_map = sample['label'] 95 | if self.cuda: 96 | data, target_map = data.cuda(), target_map.cuda() 
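# note: everything in this loop already runs under the torch.no_grad() context opened above, so the inner no_grad below is redundant (though harmless).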
97 | with torch.no_grad(): 98 | predictions, _ = self.model(data) 99 | 100 | loss = F.binary_cross_entropy_with_logits(predictions, target_map) 101 | loss_data = loss.data.item() 102 | if np.isnan(loss_data): 103 | raise ValueError('loss is nan while validating') 104 | val_loss += loss_data 105 | 106 | dice_cup, dice_disc = dice_coeff_2label(predictions, target_map) 107 | val_fundus_dice['cup'] += np.sum(dice_cup) 108 | val_fundus_dice['disc'] += np.sum(dice_disc) 109 | data_num_cnt += float(dice_cup.shape[0]) 110 | 111 | val_loss /= data_num_cnt 112 | val_fundus_dice['cup'] /= data_num_cnt 113 | val_fundus_dice['disc'] /= data_num_cnt 114 | metrics.append((val_loss, val_fundus_dice['cup'], val_fundus_dice['disc'])) 115 | 116 | self.writer.add_scalar('val_data/loss_CE', val_loss, self.epoch * (len(self.domain_loader))) 117 | self.writer.add_scalar('val_data/val_CUP_dice', val_fundus_dice['cup'], self.epoch * (len(self.domain_loader))) 118 | self.writer.add_scalar('val_data/val_DISC_dice', val_fundus_dice['disc'], self.epoch * (len(self.domain_loader))) 119 | 120 | mean_dice = (val_fundus_dice['cup'] + val_fundus_dice['disc'])/2 121 | is_best = mean_dice > self.best_mean_dice 122 | if is_best: 123 | self.best_epoch = self.epoch + 1 124 | self.best_mean_dice = mean_dice 125 | 126 | torch.save({ 127 | 'epoch': self.epoch, 128 | 'iteration': self.iteration, 129 | 'arch': self.model.__class__.__name__, 130 | 'optim_state_dict': self.optim.state_dict(), 131 | 'model_state_dict': self.model.module.state_dict() if self.multiply_gpu else self.model.state_dict(), 132 | 'learning_rate_gen': get_lr(self.optim), 133 | 'best_mean_dice': self.best_mean_dice, 134 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % self.best_epoch)) 135 | else: 136 | if (self.epoch + 1) % self.interval_save == 0: 137 | torch.save({ 138 | 'epoch': self.epoch, 139 | 'iteration': self.iteration, 140 | 'arch': self.model.__class__.__name__, 141 | 'optim_state_dict': self.optim.state_dict(), 142 | 'model_state_dict': self.model.module.state_dict() if self.multiply_gpu else self.model.state_dict(), 143 | 'learning_rate_gen': get_lr(self.optim), 144 | 'best_mean_dice': self.best_mean_dice, 145 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % (self.epoch + 1))) 146 | 147 | with open(osp.join(self.out, 'log.csv'), 'a') as f: 148 | elapsed_time = ( 149 | datetime.now(pytz.timezone(self.time_zone)) - 150 | self.timestamp_start).total_seconds() 151 | log = [self.epoch, self.iteration] + [''] + list(metrics[0]) + [elapsed_time] + [self.best_epoch] # metrics holds one (loss, cup_dice, disc_dice) tuple; unpack it so the CSV gets three separate columns 152 | log = map(str, log) 153 | f.write(','.join(log) + '\n') 154 | self.writer.add_scalar('best_model_epoch', self.best_epoch, self.epoch * (len(self.domain_loader))) 155 | if training: 156 | self.model.train() 157 | 158 | 159 | def train_epoch(self): 160 | self.model.train() 161 | self.running_seg_loss = 0.0 162 | 163 | start_time = timeit.default_timer() 164 | for batch_idx, sample in tqdm.tqdm( 165 | enumerate(self.domain_loader), total=len(self.domain_loader), 166 | desc='Train epoch=%d' % self.epoch, ncols=80, leave=False): 167 | 168 | iteration = batch_idx + self.epoch * len(self.domain_loader) 169 | self.iteration = iteration 170 | 171 | assert self.model.training 172 | 173 | self.optim.zero_grad() 174 | 175 | # train 176 | for param in self.model.parameters(): 177 | param.requires_grad = True 178 | 179 | image = sample['image'].cuda() 180 | target_map = sample['label'].cuda() 181 | 182 | pred, _ = self.model(image) 183 | pred = torch.sigmoid(pred) 184 | loss_seg = bceloss(pred, 
target_map) 185 | 186 | self.running_seg_loss += loss_seg.item() 187 | 188 | 189 | loss_seg_data = loss_seg.data.item() 190 | if np.isnan(loss_seg_data): 191 | raise ValueError('loss is nan while training') 192 | 193 | loss_seg.backward() 194 | self.optim.step() 195 | 196 | # write image log 197 | if iteration % 30 == 0: 198 | grid_image = make_grid(image[0, ...].clone().cpu().data, 1, normalize=True) 199 | self.writer.add_image('Domain/image', grid_image, iteration) 200 | grid_image = make_grid(target_map[0, 0, ...].clone().cpu().data, 1, normalize=True) 201 | self.writer.add_image('Domain/target_cup', grid_image, iteration) 202 | grid_image = make_grid(target_map[0, 1, ...].clone().cpu().data, 1, normalize=True) 203 | self.writer.add_image('Domain/target_disc', grid_image, iteration) 204 | grid_image = make_grid(pred[0, 0, ...].clone().cpu().data, 1, normalize=True) 205 | self.writer.add_image('Domain/prediction_cup', grid_image, iteration) 206 | grid_image = make_grid(pred[0, 1, ...].clone().cpu().data, 1, normalize=True) 207 | self.writer.add_image('Domain/prediction_disc', grid_image, iteration) 208 | 209 | 210 | self.writer.add_scalar('train/loss_seg', loss_seg_data, iteration) 211 | 212 | with open(osp.join(self.out, 'log.csv'), 'a') as f: 213 | elapsed_time = ( 214 | datetime.now(pytz.timezone(self.time_zone)) - 215 | self.timestamp_start).total_seconds() 216 | log = [self.epoch, self.iteration] + [loss_seg_data] + [''] * 3 + [elapsed_time] + [''] 217 | log = map(str, log) 218 | f.write(','.join(log) + '\n') 219 | 220 | stop_time = timeit.default_timer() 221 | self.running_seg_loss /= len(self.domain_loader) # average the accumulated loss once over the whole epoch 222 | print('\n[Epoch: %d] lr:%f, Average segLoss: %f, Execution time: %.5f\n' % 223 | (self.epoch, get_lr(self.optim), self.running_seg_loss, stop_time - start_time)) 224 | 225 | 226 | def train(self): 227 | for epoch in tqdm.trange(self.epoch, self.max_epoch, 228 | desc='Train', ncols=80): 229 | self.epoch = epoch 230 | self.train_epoch() 231 | if self.stop_epoch == self.epoch: 232 | print('Stop epoch at %d' % self.stop_epoch) 233 | break 234 | 235 | self.scheduler.step() 236 | self.writer.add_scalar('lr', get_lr(self.optim), self.epoch * (len(self.domain_loader))) 237 | 238 | if (self.epoch+1) % self.interval_validate == 0: 239 | self.validate_fundus() 240 | self.writer.close() 241 | 242 | 243 | 244 | -------------------------------------------------------------------------------- /utils/Utils.py: -------------------------------------------------------------------------------- 1 | 2 | # from scipy.misc import imsave 3 | import os.path as osp 4 | import numpy as np 5 | import os 6 | import cv2 7 | from skimage import morphology 8 | import scipy.signal, scipy.ndimage # explicit submodule imports for medfilt2d / binary_fill_holes; a plain 'import scipy' does not guarantee these on older SciPy versions 9 | from PIL import Image 10 | from matplotlib.pyplot import imsave 11 | # from keras.preprocessing import image 12 | from skimage.measure import label, regionprops 13 | from skimage.transform import rotate, resize 14 | from skimage import measure, draw 15 | 16 | import matplotlib.pyplot as plt 17 | plt.switch_backend('agg') 18 | 19 | # from scipy.misc import imsave 20 | from utils.metrics import * 21 | import cv2 22 | 23 | 24 | def construct_color_img(prob_per_slice): 25 | shape = prob_per_slice.shape 26 | img = np.zeros((shape[0], shape[1], 3), dtype=np.uint8) 27 | img[:, :, 0] = prob_per_slice * 255 28 | img[:, :, 1] = prob_per_slice * 255 29 | img[:, :, 2] = prob_per_slice * 255 30 | 31 | im_color = cv2.applyColorMap(img, cv2.COLORMAP_JET) 32 | return im_color 33 | 34 | 35 | def normalize_ent(ent): 36 | ''' 37 | Normalize ent to 
0 - 1 (approximately: the range is scaled by a fixed constant 0.4 rather than the true maximum, so values can exceed 1) 38 | :param ent: entropy map 39 | :return: rescaled entropy map 40 | ''' 41 | min = np.amin(ent) 42 | return (ent - min) / 0.4 43 | 44 | 45 | def draw_ent(prediction, save_root, name): 46 | ''' 47 | Draw the entropy map for each image and save it under the save root 48 | :param prediction: [2, h, w] numpy 49 | :param save_root: root directory to save into (the output file name is derived from name) 50 | :return: None 51 | ''' 52 | if not os.path.exists(os.path.join(save_root, 'disc')): 53 | os.makedirs(os.path.join(save_root, 'disc')) 54 | if not os.path.exists(os.path.join(save_root, 'cup')): 55 | os.makedirs(os.path.join(save_root, 'cup')) 56 | smooth = 1e-8 57 | cup = prediction[0] 58 | disc = prediction[1] 59 | cup_ent = - cup * np.log(cup + smooth) 60 | disc_ent = - disc * np.log(disc + smooth) 61 | cup_ent = normalize_ent(cup_ent) 62 | disc_ent = normalize_ent(disc_ent) 63 | disc = construct_color_img(disc_ent) 64 | cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc) 65 | cup = construct_color_img(cup_ent) 66 | cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup) 67 | 68 | 69 | def draw_mask(prediction, save_root, name): 70 | ''' 71 | Draw the mask probability for each image and save it under the save root 72 | :param prediction: [2, h, w] numpy 73 | :param save_root: root directory to save into (the output file name is derived from name) 74 | :return: None 75 | ''' 76 | if not os.path.exists(os.path.join(save_root, 'disc')): 77 | os.makedirs(os.path.join(save_root, 'disc')) 78 | if not os.path.exists(os.path.join(save_root, 'cup')): 79 | os.makedirs(os.path.join(save_root, 'cup')) 80 | cup = prediction[0] 81 | disc = prediction[1] 82 | 83 | disc = construct_color_img(disc) 84 | cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc) 85 | cup = construct_color_img(cup) 86 | cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup) 87 | 88 | def draw_boundary(prediction, save_root, name): 89 | ''' 90 | Draw the boundary probability for each image and save it under the save root 91 | :param prediction: [2, h, w] numpy 92 | :param save_root: root directory to save into (the output file name is derived from name) 93 | :return: None 94 | ''' 95 | if not os.path.exists(os.path.join(save_root, 'boundary')): 96 | os.makedirs(os.path.join(save_root, 'boundary')) 97 | boundary = prediction[0] 98 | boundary = construct_color_img(boundary) 99 | cv2.imwrite(os.path.join(save_root, 'boundary', name.split('.')[0]) + '.png', boundary) 100 | 101 | 102 | def get_largest_fillhole(binary): 103 | label_image = label(binary) 104 | regions = regionprops(label_image) 105 | area_list = [] 106 | for region in regions: 107 | area_list.append(region.area) 108 | if area_list: 109 | idx_max = np.argmax(area_list) 110 | binary[label_image != idx_max + 1] = 0 111 | return scipy.ndimage.binary_fill_holes(np.asarray(binary).astype(int)) 112 | 113 | def postprocessing(prediction, threshold=0.75, dataset='G'): 114 | if dataset[0] == 'D': 115 | prediction = prediction.numpy() 116 | prediction_copy = np.copy(prediction) 117 | disc_mask = prediction[1] 118 | cup_mask = prediction[0] 119 | disc_mask = (disc_mask > 0.5) # return binary mask 120 | cup_mask = (cup_mask > 0.1) # return binary mask 121 | disc_mask = disc_mask.astype(np.uint8) 122 | cup_mask = cup_mask.astype(np.uint8) 123 | for i in range(5): 124 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 125 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 126 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 127 | cup_mask = morphology.binary_erosion(cup_mask, 
morphology.diamond(7)).astype(np.uint8) # return 0,1 128 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 129 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 130 | prediction_copy[0] = cup_mask 131 | prediction_copy[1] = disc_mask 132 | return prediction_copy 133 | else: 134 | prediction = prediction.numpy() 135 | prediction = (prediction > threshold) # return binary mask 136 | prediction = prediction.astype(np.uint8) 137 | prediction_copy = np.copy(prediction) 138 | disc_mask = prediction[1] 139 | cup_mask = prediction[0] 140 | # for i in range(5): 141 | # disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 142 | # cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 143 | # disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 144 | # cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 145 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 146 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 147 | prediction_copy[0] = cup_mask 148 | prediction_copy[1] = disc_mask 149 | return prediction_copy 150 | 151 | 152 | def joint_val_image(image, prediction, mask): 153 | ratio = 0.5 154 | _pred_cup = np.zeros([mask.shape[-2], mask.shape[-1], 3]) 155 | _pred_disc = np.zeros([mask.shape[-2], mask.shape[-1], 3]) 156 | _mask = np.zeros([mask.shape[-2], mask.shape[-1], 3]) 157 | image = np.transpose(image, (1, 2, 0)) 158 | 159 | _pred_cup[:, :, 0] = prediction[0] 160 | _pred_cup[:, :, 1] = prediction[0] 161 | _pred_cup[:, :, 2] = prediction[0] 162 | _pred_disc[:, :, 0] = prediction[1] 163 | _pred_disc[:, :, 1] = prediction[1] 164 | _pred_disc[:, :, 2] = prediction[1] 165 | _mask[:,:,0] = mask[0] 166 | _mask[:,:,1] = mask[1] 167 | 168 | pred_cup = np.add(ratio * image, (1 - ratio) * _pred_cup) 169 | pred_disc = np.add(ratio * image, (1 - ratio) * _pred_disc) 170 | mask_img = np.add(ratio * image, (1 - ratio) * _mask) 171 | 172 | joint_img = np.concatenate([image, mask_img, pred_cup, pred_disc], axis=1) 173 | return joint_img 174 | 175 | 176 | def save_val_img(path, epoch, img): 177 | name = osp.join(path, "visualization", "epoch_%d.png" % epoch) 178 | out = osp.join(path, "visualization") 179 | if not osp.exists(out): 180 | os.makedirs(out) 181 | img_shape = img[0].shape 182 | stack_image = np.zeros([len(img) * img_shape[0], img_shape[1], img_shape[2]]) 183 | for i in range(len(img)): 184 | stack_image[i * img_shape[0] : (i + 1) * img_shape[0], :, : ] = img[i] 185 | imsave(name, stack_image) 186 | 187 | 188 | 189 | 190 | def save_per_img(patch_image, data_save_path, img_name, prob_map, mask_path=None, ext="bmp"): 191 | path1 = os.path.join(data_save_path, 'overlay', img_name.split('.')[0]+'.png') 192 | path0 = os.path.join(data_save_path, 'original_image', img_name.split('.')[0]+'.png') 193 | if not os.path.exists(os.path.dirname(path0)): 194 | os.makedirs(os.path.dirname(path0)) 195 | if not os.path.exists(os.path.dirname(path1)): 196 | os.makedirs(os.path.dirname(path1)) 197 | 198 | disc_map = prob_map[0] 199 | cup_map = prob_map[1] 200 | size = disc_map.shape 201 | disc_map[:, 0] = np.zeros(size[0]) 202 | disc_map[:, size[1] - 1] = np.zeros(size[0]) 203 | disc_map[0, :] = np.zeros(size[1]) 204 | disc_map[size[0] - 1, :] = np.zeros(size[1]) 205 | size = cup_map.shape 206 | cup_map[:, 0] = np.zeros(size[0]) 207 | cup_map[:, size[1] - 1] = np.zeros(size[0]) 208 | cup_map[0, :] = np.zeros(size[1]) 209 | cup_map[size[0] - 1, :] = 
np.zeros(size[1]) 210 | 211 | disc_mask = (disc_map > 0.75) # return binary mask 212 | cup_mask = (cup_map > 0.75) 213 | disc_mask = disc_mask.astype(np.uint8) 214 | cup_mask = cup_mask.astype(np.uint8) 215 | 216 | for i in range(5): 217 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7) 218 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7) 219 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 220 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 221 | disc_mask = get_largest_fillhole(disc_mask) 222 | cup_mask = get_largest_fillhole(cup_mask) 223 | 224 | disc_mask = morphology.binary_dilation(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 225 | cup_mask = morphology.binary_dilation(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1 226 | 227 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1 228 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8) 229 | 230 | 231 | contours_disc = measure.find_contours(disc_mask, 0.5) 232 | contours_cup = measure.find_contours(cup_mask, 0.5) 233 | 234 | patch_image2 = patch_image.astype(np.uint8) 235 | patch_image2 = Image.fromarray(patch_image2) 236 | 237 | patch_image2.save(path0) 238 | 239 | for n, contour in enumerate(contours_cup): 240 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 241 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 242 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0] 243 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0] 244 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0] 245 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0] 246 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0] 247 | 248 | for n, contour in enumerate(contours_disc): 249 | patch_image[contour[:, 0].astype(int), contour[:, 1].astype(int), :] = [0, 0, 255] 250 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255] 251 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255] 252 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255] 253 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255] 254 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255] 255 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255] 256 | 257 | patch_image = patch_image.astype(np.uint8) 258 | patch_image = Image.fromarray(patch_image) 259 | 260 | patch_image.save(path1) 261 | 262 | def untransform(img, lt): 263 | img = (img + 1) * 127.5 264 | lt = lt * 128 265 | return img, lt -------------------------------------------------------------------------------- /networks/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = 
kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | skip = inp 88 | 89 | x = x + skip 90 | 91 | return x 92 | 93 | 94 | class AlignedXception(nn.Module): 95 | """ 96 | Modified Aligned Xception 97 | """ 98 | def __init__(self, output_stride, BatchNorm, 99 | pretrained=True): 100 | super(AlignedXception, self).__init__() 101 | 102 | if output_stride == 16: 103 | entry_block3_stride = 2 104 | middle_block_dilation = 1 105 | exit_block_dilations = (1, 2) 106 | elif output_stride == 8: 107 | entry_block3_stride = 1 108 | middle_block_dilation = 2 109 | exit_block_dilations = (2, 4) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | # Entry flow 115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 116 | self.bn1 = BatchNorm(32) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 120 | self.bn2 = BatchNorm(64) 121 | 122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 123 | self.block2 = Block(128, 256, 
reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 124 | grow_first=True) 125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 126 | start_with_relu=True, grow_first=True, is_last=True) 127 | 128 | # Middle flow 129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # 
Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. / n)) 239 | elif isinstance(m, SynchronizedBatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | 246 | 247 | def _load_pretrained_model(self): 248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 249 | model_dict = {} 250 | state_dict = self.state_dict() 251 | 252 | for k, v in pretrain_dict.items(): 253 | if k in state_dict: # check against the model's own state dict; checking the (empty) model_dict here would silently skip every pretrained weight 254 | if 'pointwise' in k: 255 | v = v.unsqueeze(-1).unsqueeze(-1) 256 | if k.startswith('block11'): 257 | model_dict[k] = v 258 | model_dict[k.replace('block11', 'block12')] = v 259 | model_dict[k.replace('block11', 'block13')] = v 260 | model_dict[k.replace('block11', 'block14')] = v 261 | model_dict[k.replace('block11', 'block15')] = v 262 | model_dict[k.replace('block11', 'block16')] = v 263 | model_dict[k.replace('block11', 'block17')] = v 264 | model_dict[k.replace('block11', 'block18')] = v 265 | model_dict[k.replace('block11', 'block19')] = v 266 | elif k.startswith('block12'): 267 | model_dict[k.replace('block12', 'block20')] = v 268 | elif k.startswith('bn3'): 269 | model_dict[k] = v 270 | model_dict[k.replace('bn3', 'bn4')] = v 271 | elif k.startswith('conv4'): 272 | model_dict[k.replace('conv4', 'conv5')] = v 273 | elif k.startswith('bn4'): 274 | model_dict[k.replace('bn4', 'bn5')] = v 275 | else: 276 | model_dict[k] = v 277 | state_dict.update(model_dict) 278 | self.load_state_dict(state_dict) 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) 289 | -------------------------------------------------------------------------------- /networks/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using the same "device order" makes the ReduceAdd operation faster. 
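# (the sorted() call below establishes that fixed order, keyed by each message's device id, so the reduce always targets the same device)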
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation with sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | .. math:: 132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 134 | standard-deviation are reduced across all devices during training. 135 | For example, when one uses `nn.DataParallel` to wrap the network during 136 | training, PyTorch's implementation normalizes the tensor on each device using 137 | the statistics only on that device, which accelerates the computation and 138 | is also easy to implement, but the statistics might be inaccurate. 139 | Instead, in this synchronized version, the statistics will be computed 140 | over all training samples distributed on multiple devices. 141 | Note that, for one-GPU or CPU-only case, this module behaves exactly the same 142 | as the built-in PyTorch implementation. 143 | The mean and standard-deviation are calculated per-dimension over 144 | the mini-batches and gamma and beta are learnable parameter vectors 145 | of size C (where C is the input size). 146 | During training, this layer keeps a running estimate of its computed mean 147 | and variance. The running sum is kept with a default momentum of 0.1. 148 | During evaluation, this running mean/variance is used for normalization. 149 | Because the BatchNorm is done over the `C` dimension, computing statistics 150 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm. 151 | Args: 152 | num_features: num_features from an expected input of size 153 | `batch_size x num_features [x width]` 154 | eps: a value added to the denominator for numerical stability. 155 | Default: 1e-5 156 | momentum: the value used for the running_mean and running_var 157 | computation. Default: 0.1 158 | affine: a boolean value that when set to ``True``, gives the layer learnable 159 | affine parameters. 
Default: ``True`` 160 | Shape: 161 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 162 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 163 | Examples: 164 | >>> # With Learnable Parameters 165 | >>> m = SynchronizedBatchNorm1d(100) 166 | >>> # Without Learnable Parameters 167 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 168 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 169 | >>> output = m(input) 170 | """ 171 | 172 | def _check_input_dim(self, input): 173 | if input.dim() != 2 and input.dim() != 3: 174 | raise ValueError('expected 2D or 3D input (got {}D input)' 175 | .format(input.dim())) 176 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 177 | 178 | 179 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 180 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 181 | of 3d inputs 182 | .. math:: 183 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 184 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 185 | standard-deviation are reduced across all devices during training. 186 | For example, when one uses `nn.DataParallel` to wrap the network during 187 | training, PyTorch's implementation normalizes the tensor on each device using 188 | the statistics only on that device, which accelerates the computation and 189 | is also easy to implement, but the statistics might be inaccurate. 190 | Instead, in this synchronized version, the statistics will be computed 191 | over all training samples distributed on multiple devices. 192 | Note that, for one-GPU or CPU-only case, this module behaves exactly the same 193 | as the built-in PyTorch implementation. 194 | The mean and standard-deviation are calculated per-dimension over 195 | the mini-batches and gamma and beta are learnable parameter vectors 196 | of size C (where C is the input size). 197 | During training, this layer keeps a running estimate of its computed mean 198 | and variance. The running sum is kept with a default momentum of 0.1. 199 | During evaluation, this running mean/variance is used for normalization. 200 | Because the BatchNorm is done over the `C` dimension, computing statistics 201 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 202 | Args: 203 | num_features: num_features from an expected input of 204 | size batch_size x num_features x height x width 205 | eps: a value added to the denominator for numerical stability. 206 | Default: 1e-5 207 | momentum: the value used for the running_mean and running_var 208 | computation. Default: 0.1 209 | affine: a boolean value that when set to ``True``, gives the layer learnable 210 | affine parameters. 
Default: ``True`` 211 | Shape: 212 | - Input: :math:`(N, C, H, W)` 213 | - Output: :math:`(N, C, H, W)` (same shape as input) 214 | Examples: 215 | >>> # With Learnable Parameters 216 | >>> m = SynchronizedBatchNorm2d(100) 217 | >>> # Without Learnable Parameters 218 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 219 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 220 | >>> output = m(input) 221 | """ 222 | 223 | def _check_input_dim(self, input): 224 | if input.dim() != 4: 225 | raise ValueError('expected 4D input (got {}D input)' 226 | .format(input.dim())) 227 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 228 | 229 | 230 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 231 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 232 | of 4d inputs 233 | .. math:: 234 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 235 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 236 | standard-deviation are reduced across all devices during training. 237 | For example, when one uses `nn.DataParallel` to wrap the network during 238 | training, PyTorch's implementation normalizes the tensor on each device using 239 | the statistics only on that device, which accelerates the computation and 240 | is also easy to implement, but the statistics might be inaccurate. 241 | Instead, in this synchronized version, the statistics will be computed 242 | over all training samples distributed on multiple devices. 243 | Note that, for one-GPU or CPU-only case, this module behaves exactly the same 244 | as the built-in PyTorch implementation. 245 | The mean and standard-deviation are calculated per-dimension over 246 | the mini-batches and gamma and beta are learnable parameter vectors 247 | of size C (where C is the input size). 248 | During training, this layer keeps a running estimate of its computed mean 249 | and variance. The running sum is kept with a default momentum of 0.1. 250 | During evaluation, this running mean/variance is used for normalization. 251 | Because the BatchNorm is done over the `C` dimension, computing statistics 252 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 253 | or Spatio-temporal BatchNorm. 254 | Args: 255 | num_features: num_features from an expected input of 256 | size batch_size x num_features x depth x height x width 257 | eps: a value added to the denominator for numerical stability. 258 | Default: 1e-5 259 | momentum: the value used for the running_mean and running_var 260 | computation. Default: 0.1 261 | affine: a boolean value that when set to ``True``, gives the layer learnable 262 | affine parameters. 
Default: ``True`` 263 | Shape: 264 | - Input: :math:`(N, C, D, H, W)` 265 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 266 | Examples: 267 | >>> # With Learnable Parameters 268 | >>> m = SynchronizedBatchNorm3d(100) 269 | >>> # Without Learnable Parameters 270 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 271 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 272 | >>> output = m(input) 273 | """ 274 | 275 | def _check_input_dim(self, input): 276 | if input.dim() != 5: 277 | raise ValueError('expected 5D input (got {}D input)' 278 | .format(input.dim())) 279 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) -------------------------------------------------------------------------------- /train_target.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument('-g', '--gpu', type=str, default='0') 4 | parser.add_argument('--model-file', type=str, default='./logs_train/Domain3/source_model.pth.tar') 5 | parser.add_argument('--model', type=str, default='Deeplab', help='Deeplab') 6 | parser.add_argument('--out-stride', type=int, default=16) 7 | parser.add_argument('--sync-bn', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=True) # argparse's type=bool treats any non-empty string (even 'False') as True, so parse booleans explicitly 8 | parser.add_argument('--freeze-bn', type=lambda s: str(s).lower() in ('true', '1', 'yes'), default=False) 9 | parser.add_argument('--epoch', type=int, default=20) 10 | parser.add_argument('--lr', type=float, default=5e-4) 11 | parser.add_argument('--lr-decrease-rate', type=float, default=0.9, help='ratio multiplied to initial lr') 12 | parser.add_argument('--lr-decrease-epoch', type=int, default=1, help='interval epoch number for lr decrease') 13 | 14 | parser.add_argument('--data-dir', default='../Datasets/Fundus') 15 | parser.add_argument('--dataset', type=str, default='Domain2') 16 | parser.add_argument('--model-source', type=str, default='Domain3') 17 | parser.add_argument('--batch-size', type=int, default=8) 18 | 19 | parser.add_argument('--model-ema-rate', type=float, default=0.98) 20 | parser.add_argument('--pseudo-label-threshold', type=float, default=0.75) 21 | parser.add_argument('--mean-loss-calc-bound-ratio', type=float, default=0.2) 22 | 23 | args = parser.parse_args() 24 | 25 | import os 26 | 27 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 28 | 29 | import os.path as osp 30 | 31 | import numpy as np 32 | import torch.nn.functional as F 33 | 34 | import torch 35 | from torch.autograd import Variable 36 | import tqdm 37 | from torch.utils.data import DataLoader 38 | from dataloaders import fundus_dataloader 39 | from dataloaders import custom_transforms as trans 40 | from torchvision import transforms 41 | # from scipy.misc import imsave 42 | from matplotlib.pyplot import imsave 43 | from utils.Utils import * 44 | from utils.metrics import * 45 | from datetime import datetime 46 | import pytz 47 | import networks.deeplabv3 as netd 48 | import cv2 49 | import torch.backends.cudnn as cudnn 50 | import random 51 | import glob 52 | import sys 53 | 54 | seed = 42 55 | savefig = False 56 | get_hd = True 57 | model_save = True 58 | cudnn.benchmark = False 59 | cudnn.deterministic = True 60 | random.seed(seed) 61 | np.random.seed(seed) 62 | torch.manual_seed(seed) 63 | torch.cuda.manual_seed(seed) 64 | 65 | def print_args(args): 66 | s = "==========================================\n" 67 | for arg, content in args.__dict__.items(): 68 | s += "{}:{}\n".format(arg, content) 69 | return s 70 | 71 | 72 | def soft_label_to_hard(soft_pls, pseudo_label_threshold): 
pseudo_labels = torch.zeros(soft_pls.size()) 74 | if torch.cuda.is_available(): 75 | pseudo_labels = pseudo_labels.cuda() 76 | pseudo_labels[soft_pls > pseudo_label_threshold] = 1 77 | pseudo_labels[soft_pls <= pseudo_label_threshold] = 0 78 | 79 | return pseudo_labels 80 | 81 | 82 | def init_feature_pred_bank(model, loader): 83 | feature_bank = {} 84 | pred_bank = {} 85 | 86 | model.eval() 87 | 88 | with torch.no_grad(): 89 | for sample in loader: 90 | data = sample['image'] 91 | img_name = sample['img_name'] 92 | data = data.cuda() 93 | 94 | pred, feat = model(data) 95 | pred = torch.sigmoid(pred) 96 | 97 | for i in range(data.size(0)): 98 | feature_bank[img_name[i]] = feat[i].detach().clone() 99 | pred_bank[img_name[i]] = pred[i].detach().clone() 100 | 101 | model.train() 102 | 103 | return feature_bank, pred_bank 104 | 105 | 106 | def adapt_epoch(model_t, model_s, optim, train_loader, args, feature_bank, pred_bank, loss_weight=None): 107 | for sample_w, sample_s in train_loader: 108 | imgs_w = sample_w['image']  # weakly-augmented view, fed to the teacher 109 | imgs_s = sample_s['image']  # strongly-augmented view, fed to the student 110 | img_name = sample_w['img_name'] 111 | if torch.cuda.is_available(): 112 | imgs_w = imgs_w.cuda() 113 | imgs_s = imgs_s.cuda() 114 | 115 | # model predict 116 | predictions_stu_s, features_stu_s = model_s(imgs_s) 117 | with torch.no_grad(): 118 | predictions_tea_w, features_tea_w = model_t(imgs_w) 119 | 120 | predictions_stu_s_sigmoid = torch.sigmoid(predictions_stu_s) 121 | predictions_tea_w_sigmoid = torch.sigmoid(predictions_tea_w) 122 | 123 | # get hard pseudo label from the teacher's soft predictions 124 | pseudo_labels = soft_label_to_hard(predictions_tea_w_sigmoid, args.pseudo_label_threshold) 125 | 126 | 127 | bceloss = torch.nn.BCELoss(reduction='none') 128 | loss_seg_pixel = bceloss(predictions_stu_s_sigmoid, pseudo_labels) 129 | 130 | mean_loss_weight_mask = torch.ones(pseudo_labels.size()).cuda()  # class-balance: pixels pseudo-labelled as background on the cup channel are re-weighted by loss_weight 131 | mean_loss_weight_mask[:, 0, ...][pseudo_labels[:, 0, ...]
== 0] = loss_weight 132 | loss_mask = mean_loss_weight_mask 133 | 134 | loss = torch.sum(loss_seg_pixel * loss_mask) / torch.sum(loss_mask) 135 | 136 | loss.backward() 137 | optim.step() 138 | optim.zero_grad() 139 | 140 | # update teacher 141 | for param_s, param_t in zip(model_s.parameters(), model_t.parameters()): 142 | param_t.data = param_t.data.clone() * args.model_ema_rate + param_s.data.clone() * (1.0 - args.model_ema_rate) 143 | 144 | # update feature/pred bank 145 | for idx in range(len(img_name)): 146 | feature_bank[img_name[idx]] = features_tea_w[idx].detach().clone() 147 | pred_bank[img_name[idx]] = predictions_tea_w_sigmoid[idx].detach().clone() 148 | 149 | 150 | def eval(model, data_loader): 151 | model.eval() 152 | 153 | val_dice = {'cup': np.array([]), 'disc': np.array([])} 154 | val_assd = {'cup': np.array([]), 'disc': np.array([])} 155 | 156 | with torch.no_grad(): 157 | for batch_idx, sample in enumerate(data_loader): 158 | data = sample['image'] 159 | target_map = sample['label'] 160 | data = data.cuda() 161 | predictions, _ = model(data) 162 | 163 | dice_cup, dice_disc = dice_coeff_2label(predictions, target_map) 164 | val_dice['cup'] = np.append(val_dice['cup'], dice_cup) 165 | val_dice['disc'] = np.append(val_dice['disc'], dice_disc) 166 | 167 | assd = assd_compute(predictions, target_map) 168 | val_assd['cup'] = np.append(val_assd['cup'], assd[:, 0]) 169 | val_assd['disc'] = np.append(val_assd['disc'], assd[:, 1]) 170 | 171 | avg_dice = [0.0, 0.0, 0.0, 0.0] 172 | std_dice = [0.0, 0.0, 0.0, 0.0] 173 | avg_assd = [0.0, 0.0, 0.0, 0.0] 174 | std_assd = [0.0, 0.0, 0.0, 0.0] 175 | avg_dice[0] = np.mean(val_dice['cup']) 176 | avg_dice[1] = np.mean(val_dice['disc']) 177 | std_dice[0] = np.std(val_dice['cup']) 178 | std_dice[1] = np.std(val_dice['disc']) 179 | val_assd['cup'] = np.delete(val_assd['cup'], np.where(val_assd['cup'] == -1)) 180 | val_assd['disc'] = np.delete(val_assd['disc'], np.where(val_assd['disc'] == -1)) 181 | avg_assd[0] = np.mean(val_assd['cup']) 182 | avg_assd[1] = np.mean(val_assd['disc']) 183 | std_assd[0] = np.std(val_assd['cup']) 184 | std_assd[1] = np.std(val_assd['disc']) 185 | 186 | model.train() 187 | 188 | return avg_dice, std_dice, avg_assd, std_assd 189 | 190 | 191 | def main(): 192 | now = datetime.now() 193 | here = osp.dirname(osp.abspath(__file__)) 194 | args.out = osp.join(here, 'logs_target', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f')) 195 | if not osp.exists(args.out): 196 | os.makedirs(args.out) 197 | args.out_file = open(osp.join(args.out, now.strftime('%Y%m%d_%H%M%S.%f')+'.txt'), 'w') 198 | args.out_file.write(' '.join(sys.argv) + '\n') 199 | args.out_file.write(print_args(args) + '\n') 200 | args.out_file.flush() 201 | 202 | # dataset 203 | composed_transforms_train = transforms.Compose([ 204 | trans.Resize(512), 205 | trans.add_salt_pepper_noise(), 206 | trans.adjust_light(), 207 | trans.eraser(), 208 | trans.Normalize_tf(), 209 | trans.ToTensor() 210 | ]) 211 | composed_transforms_test = transforms.Compose([ 212 | trans.Resize(512), 213 | trans.Normalize_tf(), 214 | trans.ToTensor() 215 | ]) 216 | 217 | dataset_train = fundus_dataloader.FundusSegmentation_2transform(base_dir=args.data_dir, dataset=args.dataset, 218 | split='train/ROIs', 219 | transform_weak=composed_transforms_test, 220 | transform_strong=composed_transforms_train) 221 | dataset_train_weak = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, 222 | split='train/ROIs', 223 | transform=composed_transforms_test) 224 | dataset_test 
= fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='test/ROIs', 225 | transform=composed_transforms_test) 226 | 227 | train_loader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2) 228 | train_loader_weak = DataLoader(dataset_train_weak, batch_size=args.batch_size, shuffle=False, num_workers=2) 229 | test_loader = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=2) 230 | 231 | # model 232 | model_s = netd.DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn, 233 | freeze_bn=args.freeze_bn) 234 | model_t = netd.DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn, 235 | freeze_bn=args.freeze_bn) 236 | 237 | 238 | if torch.cuda.is_available(): 239 | model_s = model_s.cuda() 240 | model_t = model_t.cuda() 241 | log_str = '==> Loading %s model file: %s' % (model_s.__class__.__name__, args.model_file) 242 | print(log_str) 243 | args.out_file.write(log_str + '\n') 244 | args.out_file.flush() 245 | checkpoint = torch.load(args.model_file) 246 | model_s.load_state_dict(checkpoint['model_state_dict']) 247 | model_t.load_state_dict(checkpoint['model_state_dict']) 248 | 249 | if (args.gpu).find(',') != -1: 250 | model_s = torch.nn.DataParallel(model_s, device_ids=[0, 1]) 251 | model_t = torch.nn.DataParallel(model_t, device_ids=[0, 1]) 252 | 253 | optim = torch.optim.Adam(model_s.parameters(), lr=args.lr, betas=(0.9, 0.99)) 254 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=args.lr_decrease_epoch, gamma=args.lr_decrease_rate) 255 | 256 | model_s.train() 257 | model_t.train() 258 | for param in model_t.parameters(): 259 | param.requires_grad = False  # the teacher is never touched by the optimizer; it is updated only via the EMA of student weights 260 | 261 | 262 | feature_bank, pred_bank = init_feature_pred_bank(model_s, train_loader_weak) 263 | 264 | avg_dice, std_dice, avg_assd, std_assd = eval(model_t, test_loader) 265 | log_str = ("initial dice: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f, assd: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f" % ( 266 | avg_dice[0], std_dice[0], avg_dice[1], std_dice[1], (avg_dice[0] + avg_dice[1]) / 2.0, 267 | avg_assd[0], std_assd[0], avg_assd[1], std_assd[1], (avg_assd[0] + avg_assd[1]) / 2.0)) 268 | print(log_str) 269 | args.out_file.write(log_str + '\n') 270 | args.out_file.flush() 271 | 272 | for epoch in range(args.epoch): 273 | 274 | log_str = '\nepoch {}/{}:'.format(epoch+1, args.epoch) 275 | print(log_str) 276 | args.out_file.write(log_str + '\n') 277 | args.out_file.flush() 278 | 279 | not_cup_loss_sum = torch.FloatTensor([0]).cuda() 280 | cup_loss_sum = torch.FloatTensor([0]).cuda() 281 | not_cup_loss_num = 0 282 | cup_loss_num = 0 283 | lower_bound = args.pseudo_label_threshold * args.mean_loss_calc_bound_ratio  # predictions below this are treated as saturated background and excluded from the mean-loss estimate 284 | upper_bound = 1 - ((1 - args.pseudo_label_threshold) * args.mean_loss_calc_bound_ratio)  # predictions above this are treated as saturated cup and likewise excluded 285 | for pred_i in pred_bank.values(): 286 | not_cup_loss_sum += torch.sum( 287 | -torch.log(1 - pred_i[0, ...][(pred_i[0, ...] < args.pseudo_label_threshold) * (pred_i[0, ...] > lower_bound)])) 288 | not_cup_loss_num += torch.sum((pred_i[0, ...] < args.pseudo_label_threshold) * (pred_i[0, ...] > lower_bound)) 289 | cup_loss_sum += torch.sum(-torch.log(pred_i[0, ...][(pred_i[0, ...] > args.pseudo_label_threshold) * (pred_i[0, ...] < upper_bound)])) 290 | cup_loss_num += torch.sum((pred_i[0, ...] > args.pseudo_label_threshold) * (pred_i[0, ...]
< upper_bound)) 291 | loss_weight = (cup_loss_sum.item() / cup_loss_num) / (not_cup_loss_sum.item() / not_cup_loss_num)  # class-balance weight: mean cup-pixel loss over mean background-pixel loss 292 | 293 | adapt_epoch(model_t, model_s, optim, train_loader, args, feature_bank, pred_bank, loss_weight=loss_weight) 294 | 295 | scheduler.step() 296 | 297 | avg_dice, std_dice, avg_assd, std_assd = eval(model_t, test_loader) 298 | log_str = ("teacher dice: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f, assd: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f" % ( 299 | avg_dice[0], std_dice[0], avg_dice[1], std_dice[1], (avg_dice[0] + avg_dice[1]) / 2.0, 300 | avg_assd[0], std_assd[0], avg_assd[1], std_assd[1], (avg_assd[0] + avg_assd[1]) / 2.0)) 301 | print(log_str) 302 | args.out_file.write(log_str + '\n') 303 | args.out_file.flush() 304 | 305 | avg_dice, std_dice, avg_assd, std_assd = eval(model_s, test_loader) 306 | log_str = ("student dice: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f, assd: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f" % ( 307 | avg_dice[0], std_dice[0], avg_dice[1], std_dice[1], (avg_dice[0] + avg_dice[1]) / 2.0, 308 | avg_assd[0], std_assd[0], avg_assd[1], std_assd[1], (avg_assd[0] + avg_assd[1]) / 2.0)) 309 | print(log_str) 310 | args.out_file.write(log_str + '\n') 311 | args.out_file.flush() 312 | 313 | torch.save({'model_state_dict': model_t.state_dict()}, args.out + '/after_adaptation.pth.tar') 314 | 315 | 316 | if __name__ == '__main__': 317 | main() 318 | 319 | -------------------------------------------------------------------------------- /networks/backbone/drn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/' 7 | 8 | model_urls = { 9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 10 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth', 11 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth', 12 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth', 13 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth', 14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth', 15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth', 16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth' 17 | }  # note: there are no entries for 'drn-d-24' or 'drn-d-40', so drn_d_24/drn_d_40 below raise a KeyError when pretrained=True 18 | 19 | 20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=padding, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, 29 | dilation=(1, 1), residual=True, BatchNorm=None): 30 | super(BasicBlock, self).__init__() 31 | self.conv1 = conv3x3(inplanes, planes, stride, 32 | padding=dilation[0], dilation=dilation[0]) 33 | self.bn1 = BatchNorm(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes, 36 | padding=dilation[1], dilation=dilation[1]) 37 | self.bn2 = BatchNorm(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | self.residual = residual 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | if self.residual: 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 |
expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, 65 | dilation=(1, 1), residual=True, BatchNorm=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = BatchNorm(planes) 69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 70 | padding=dilation[1], bias=False, 71 | dilation=dilation[1]) 72 | self.bn2 = BatchNorm(planes) 73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.bn3 = BatchNorm(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class DRN(nn.Module): 103 | 104 | def __init__(self, block, layers, arch='D', 105 | channels=(16, 32, 64, 128, 256, 512, 512, 512), 106 | BatchNorm=None): 107 | super(DRN, self).__init__() 108 | self.inplanes = channels[0] 109 | self.out_dim = channels[-1] 110 | self.arch = arch 111 | 112 | if arch == 'C': 113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, 114 | padding=3, bias=False) 115 | self.bn1 = BatchNorm(channels[0]) 116 | self.relu = nn.ReLU(inplace=True) 117 | 118 | self.layer1 = self._make_layer( 119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 120 | self.layer2 = self._make_layer( 121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 122 | 123 | elif arch == 'D': 124 | self.layer0 = nn.Sequential( 125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3, 126 | bias=False), 127 | BatchNorm(channels[0]), 128 | nn.ReLU(inplace=True) 129 | ) 130 | 131 | self.layer1 = self._make_conv_layers( 132 | channels[0], layers[0], stride=1, BatchNorm=BatchNorm) 133 | self.layer2 = self._make_conv_layers( 134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm) 135 | 136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm) 137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm) 138 | self.layer5 = self._make_layer(block, channels[4], layers[4], 139 | dilation=2, new_level=False, BatchNorm=BatchNorm) 140 | self.layer6 = None if layers[5] == 0 else \ 141 | self._make_layer(block, channels[5], layers[5], dilation=4, 142 | new_level=False, BatchNorm=BatchNorm) 143 | 144 | if arch == 'C': 145 | self.layer7 = None if layers[6] == 0 else \ 146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2, 147 | new_level=False, residual=False, BatchNorm=BatchNorm) 148 | self.layer8 = None if layers[7] == 0 else \ 149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1, 150 | new_level=False, residual=False, BatchNorm=BatchNorm) 151 | elif arch == 'D': 152 | self.layer7 = None if layers[6] == 0 else \ 153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm) 154 | self.layer8 = None if layers[7] == 0 else \ 155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm) 156 | 157 | self._init_weight() 158 | 159 | def _init_weight(self): 160 | for m in self.modules(): 161 | if 
isinstance(m, nn.Conv2d): 162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | m.weight.data.normal_(0, math.sqrt(2. / n)) 164 | elif isinstance(m, SynchronizedBatchNorm2d): 165 | m.weight.data.fill_(1) 166 | m.bias.data.zero_() 167 | elif isinstance(m, nn.BatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | 171 | 172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, 173 | new_level=True, residual=True, BatchNorm=None): 174 | assert dilation == 1 or dilation % 2 == 0 175 | downsample = None 176 | if stride != 1 or self.inplanes != planes * block.expansion: 177 | downsample = nn.Sequential( 178 | nn.Conv2d(self.inplanes, planes * block.expansion, 179 | kernel_size=1, stride=stride, bias=False), 180 | BatchNorm(planes * block.expansion), 181 | ) 182 | 183 | layers = list() 184 | layers.append(block( 185 | self.inplanes, planes, stride, downsample, 186 | dilation=(1, 1) if dilation == 1 else ( 187 | dilation // 2 if new_level else dilation, dilation), 188 | residual=residual, BatchNorm=BatchNorm)) 189 | self.inplanes = planes * block.expansion 190 | for i in range(1, blocks): 191 | layers.append(block(self.inplanes, planes, residual=residual, 192 | dilation=(dilation, dilation), BatchNorm=BatchNorm)) 193 | 194 | return nn.Sequential(*layers) 195 | 196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None): 197 | modules = [] 198 | for i in range(convs): 199 | modules.extend([ 200 | nn.Conv2d(self.inplanes, channels, kernel_size=3, 201 | stride=stride if i == 0 else 1, 202 | padding=dilation, bias=False, dilation=dilation), 203 | BatchNorm(channels), 204 | nn.ReLU(inplace=True)]) 205 | self.inplanes = channels 206 | return nn.Sequential(*modules) 207 | 208 | def forward(self, x): 209 | if self.arch == 'C': 210 | x = self.conv1(x) 211 | x = self.bn1(x) 212 | x = self.relu(x) 213 | elif self.arch == 'D': 214 | x = self.layer0(x) 215 | 216 | x = self.layer1(x) 217 | x = self.layer2(x) 218 | 219 | x = self.layer3(x) 220 | low_level_feat = x 221 | 222 | x = self.layer4(x) 223 | x = self.layer5(x) 224 | 225 | if self.layer6 is not None: 226 | x = self.layer6(x) 227 | 228 | if self.layer7 is not None: 229 | x = self.layer7(x) 230 | 231 | if self.layer8 is not None: 232 | x = self.layer8(x) 233 | 234 | return x, low_level_feat 235 | 236 | 237 | class DRN_A(nn.Module): 238 | 239 | def __init__(self, block, layers, BatchNorm=None): 240 | self.inplanes = 64 241 | super(DRN_A, self).__init__() 242 | self.out_dim = 512 * block.expansion 243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 244 | bias=False) 245 | self.bn1 = BatchNorm(64) 246 | self.relu = nn.ReLU(inplace=True) 247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 248 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm) 249 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm) 250 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 251 | dilation=2, BatchNorm=BatchNorm) 252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 253 | dilation=4, BatchNorm=BatchNorm) 254 | 255 | self._init_weight() 256 | 257 | def _init_weight(self): 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 261 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 262 | elif isinstance(m, SynchronizedBatchNorm2d): 263 | m.weight.data.fill_(1) 264 | m.bias.data.zero_() 265 | elif isinstance(m, nn.BatchNorm2d): 266 | m.weight.data.fill_(1) 267 | m.bias.data.zero_() 268 | 269 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 270 | downsample = None 271 | if stride != 1 or self.inplanes != planes * block.expansion: 272 | downsample = nn.Sequential( 273 | nn.Conv2d(self.inplanes, planes * block.expansion, 274 | kernel_size=1, stride=stride, bias=False), 275 | BatchNorm(planes * block.expansion), 276 | ) 277 | 278 | layers = [] 279 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm)) 280 | self.inplanes = planes * block.expansion 281 | for i in range(1, blocks): 282 | layers.append(block(self.inplanes, planes, 283 | dilation=(dilation, dilation, ), BatchNorm=BatchNorm)) 284 | 285 | return nn.Sequential(*layers) 286 | 287 | def forward(self, x): 288 | x = self.conv1(x) 289 | x = self.bn1(x) 290 | x = self.relu(x) 291 | x = self.maxpool(x) 292 | 293 | x = self.layer1(x) 294 | x = self.layer2(x) 295 | x = self.layer3(x) 296 | x = self.layer4(x) 297 | 298 | return x 299 | 300 | def drn_a_50(BatchNorm, pretrained=True): 301 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm) 302 | if pretrained: 303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 304 | return model 305 | 306 | 307 | def drn_c_26(BatchNorm, pretrained=True): 308 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm) 309 | if pretrained: 310 | pretrained = model_zoo.load_url(model_urls['drn-c-26']) 311 | del pretrained['fc.weight'] 312 | del pretrained['fc.bias'] 313 | model.load_state_dict(pretrained) 314 | return model 315 | 316 | 317 | def drn_c_42(BatchNorm, pretrained=True): 318 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 319 | if pretrained: 320 | pretrained = model_zoo.load_url(model_urls['drn-c-42']) 321 | del pretrained['fc.weight'] 322 | del pretrained['fc.bias'] 323 | model.load_state_dict(pretrained) 324 | return model 325 | 326 | 327 | def drn_c_58(BatchNorm, pretrained=True): 328 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm) 329 | if pretrained: 330 | pretrained = model_zoo.load_url(model_urls['drn-c-58']) 331 | del pretrained['fc.weight'] 332 | del pretrained['fc.bias'] 333 | model.load_state_dict(pretrained) 334 | return model 335 | 336 | 337 | def drn_d_22(BatchNorm, pretrained=True): 338 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm) 339 | if pretrained: 340 | pretrained = model_zoo.load_url(model_urls['drn-d-22']) 341 | del pretrained['fc.weight'] 342 | del pretrained['fc.bias'] 343 | model.load_state_dict(pretrained) 344 | return model 345 | 346 | 347 | def drn_d_24(BatchNorm, pretrained=True): 348 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm) 349 | if pretrained: 350 | pretrained = model_zoo.load_url(model_urls['drn-d-24']) 351 | del pretrained['fc.weight'] 352 | del pretrained['fc.bias'] 353 | model.load_state_dict(pretrained) 354 | return model 355 | 356 | 357 | def drn_d_38(BatchNorm, pretrained=True): 358 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 359 | if pretrained: 360 | pretrained = model_zoo.load_url(model_urls['drn-d-38']) 361 | del pretrained['fc.weight'] 362 | del pretrained['fc.bias'] 363 | model.load_state_dict(pretrained) 364 | return model 
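# Editor's sketch (not part of the original drn.py): every drn_* constructor in this file
# repeats the same steps: download a checkpoint, delete the 'fc' classifier head, load the
# rest. A minimal helper capturing that pattern; the name `_load_pretrained` is hypothetical
# and nothing else in this repo calls it.
def _load_pretrained(model, url):
    state = model_zoo.load_url(url)  # fetch the checkpoint (cached after the first download)
    state.pop('fc.weight', None)     # the DRN classification head does not exist in the
    state.pop('fc.bias', None)       # segmentation backbone, so drop it before loading
    model.load_state_dict(state)
    return model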
365 | 366 | 367 | def drn_d_40(BatchNorm, pretrained=True): 368 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm) 369 | if pretrained: 370 | pretrained = model_zoo.load_url(model_urls['drn-d-40']) 371 | del pretrained['fc.weight'] 372 | del pretrained['fc.bias'] 373 | model.load_state_dict(pretrained) 374 | return model 375 | 376 | 377 | def drn_d_54(BatchNorm, pretrained=True): 378 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 379 | if pretrained: 380 | pretrained = model_zoo.load_url(model_urls['drn-d-54']) 381 | del pretrained['fc.weight'] 382 | del pretrained['fc.bias'] 383 | model.load_state_dict(pretrained) 384 | return model 385 | 386 | 387 | def drn_d_105(BatchNorm, pretrained=True): 388 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm) 389 | if pretrained: 390 | pretrained = model_zoo.load_url(model_urls['drn-d-105']) 391 | del pretrained['fc.weight'] 392 | del pretrained['fc.bias'] 393 | model.load_state_dict(pretrained) 394 | return model 395 | 396 | if __name__ == "__main__": 397 | import torch 398 | model = drn_a_50(BatchNorm=nn.BatchNorm2d, pretrained=True) 399 | input = torch.rand(1, 3, 512, 512) 400 | output, low_level_feat = model(input) 401 | print(output.size()) 402 | print(low_level_feat.size()) 403 | -------------------------------------------------------------------------------- /networks/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | 7 | 8 | ############################################################################### 9 | # Helper Functions 10 | ############################################################################### 11 | 12 | 13 | class Identity(nn.Module): 14 | def forward(self, x): 15 | return x 16 | 17 | 18 | def get_norm_layer(norm_type='instance'): 19 | """Return a normalization layer 20 | 21 | Parameters: 22 | norm_type (str) -- the name of the normalization layer: batch | instance | none 23 | 24 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 25 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 26 | """ 27 | if norm_type == 'batch': 28 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 29 | elif norm_type == 'instance': 30 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 31 | elif norm_type == 'none': 32 | def norm_layer(x): return Identity() 33 | else: 34 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 35 | return norm_layer 36 | 37 | 38 | def get_scheduler(optimizer, opt): 39 | """Return a learning rate scheduler 40 | 41 | Parameters: 42 | optimizer -- the optimizer of the network 43 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 44 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 45 | 46 | For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs 47 | and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs. 48 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 49 | See https://pytorch.org/docs/stable/optim.html for more details.
50 | """ 51 | if opt.lr_policy == 'linear': 52 | def lambda_rule(epoch): 53 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 54 | return lr_l 55 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 56 | elif opt.lr_policy == 'step': 57 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 58 | elif opt.lr_policy == 'plateau': 59 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 60 | elif opt.lr_policy == 'cosine': 61 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 62 | else: 63 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 64 | return scheduler 65 | 66 | 67 | def init_weights(net, init_type='normal', init_gain=0.02): 68 | """Initialize network weights. 69 | 70 | Parameters: 71 | net (network) -- network to be initialized 72 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 73 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 74 | 75 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 76 | work better for some applications. Feel free to try yourself. 77 | """ 78 | def init_func(m): # define the initialization function 79 | classname = m.__class__.__name__ 80 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 81 | if init_type == 'normal': 82 | init.normal_(m.weight.data, 0.0, init_gain) 83 | elif init_type == 'xavier': 84 | init.xavier_normal_(m.weight.data, gain=init_gain) 85 | elif init_type == 'kaiming': 86 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 87 | elif init_type == 'orthogonal': 88 | init.orthogonal_(m.weight.data, gain=init_gain) 89 | else: 90 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 91 | if hasattr(m, 'bias') and m.bias is not None: 92 | init.constant_(m.bias.data, 0.0) 93 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 94 | init.normal_(m.weight.data, 1.0, init_gain) 95 | init.constant_(m.bias.data, 0.0) 96 | 97 | print('initialize network with %s' % init_type) 98 | net.apply(init_func) # apply the initialization function 99 | 100 | 101 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 102 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 103 | Parameters: 104 | net (network) -- the network to be initialized 105 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 106 | gain (float) -- scaling factor for normal, xavier and orthogonal. 107 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 108 | 109 | Return an initialized network. 
110 | """ 111 | if len(gpu_ids) > 0: 112 | # assert(torch.cuda.is_available()) 113 | net.to(gpu_ids[0]) 114 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 115 | init_weights(net, init_type, init_gain=init_gain) 116 | return net 117 | 118 | 119 | def define_G(input_nc=3, output_nc=1, ngf=64, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 120 | """Create a generator 121 | 122 | Parameters: 123 | input_nc (int) -- the number of channels in input images 124 | output_nc (int) -- the number of channels in output images 125 | ngf (int) -- the number of filters in the last conv layer 126 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 127 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 128 | use_dropout (bool) -- if use dropout layers. 129 | init_type (str) -- the name of our initialization method. 130 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 131 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 132 | 133 | Returns a generator 134 | 135 | Our current implementation provides two types of generators: 136 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) 137 | The original U-Net paper: https://arxiv.org/abs/1505.04597 138 | 139 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) 140 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. 141 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 142 | 143 | 144 | The generator has been initialized by . It uses RELU for non-linearity. 145 | """ 146 | 147 | norm_layer = get_norm_layer(norm_type=norm) 148 | 149 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 150 | 151 | return init_net(net, init_type, init_gain, gpu_ids) 152 | 153 | 154 | def define_D(input_nc=2, ndf=64, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): 155 | """Create a discriminator 156 | 157 | Parameters: 158 | input_nc (int) -- the number of channels in input images 159 | ndf (int) -- the number of filters in the first conv layer 160 | netD (str) -- the architecture's name: basic | n_layers | pixel 161 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 162 | norm (str) -- the type of normalization layers used in the network. 163 | init_type (str) -- the name of the initialization method. 164 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 165 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 166 | 167 | Returns a discriminator 168 | 169 | Our current implementation provides three types of discriminators: 170 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 171 | It can classify whether 70×70 overlapping patches are real or fake. 172 | Such a patch-level discriminator architecture has fewer parameters 173 | than a full-image discriminator and can work on arbitrarily-sized images 174 | in a fully convolutional fashion. 175 | 176 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator 177 | with the parameter (default=3 as used in [basic] (PatchGAN).) 178 | 179 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 
180 | It encourages greater color diversity but has no effect on spatial statistics. 181 | 182 | The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity. 183 | """ 184 | 185 | norm_layer = get_norm_layer(norm_type=norm) 186 | 187 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 188 | 189 | return init_net(net, init_type, init_gain, gpu_ids) 190 | 191 | 192 | ############################################################################## 193 | # Classes 194 | ############################################################################## 195 | class GANLoss(nn.Module): 196 | """Define different GAN objectives. 197 | 198 | The GANLoss class abstracts away the need to create the target label tensor 199 | that has the same size as the input. 200 | """ 201 | 202 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): 203 | """ Initialize the GANLoss class. 204 | 205 | Parameters: 206 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 207 | target_real_label (float) - - label for a real image 208 | target_fake_label (float) - - label of a fake image 209 | 210 | Note: Do not use sigmoid as the last layer of Discriminator. 211 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 212 | """ 213 | super(GANLoss, self).__init__() 214 | self.register_buffer('real_label', torch.tensor(target_real_label)) 215 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 216 | self.gan_mode = gan_mode 217 | if gan_mode == 'lsgan': 218 | self.loss = nn.MSELoss() 219 | elif gan_mode == 'vanilla': 220 | self.loss = nn.BCEWithLogitsLoss() 221 | elif gan_mode in ['wgangp']: 222 | self.loss = None 223 | else: 224 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 225 | 226 | def get_target_tensor(self, prediction, target_is_real): 227 | """Create label tensors with the same size as the input. 228 | 229 | Parameters: 230 | prediction (tensor) - - typically the prediction from a discriminator 231 | target_is_real (bool) - - if the ground truth label is for real images or fake images 232 | 233 | Returns: 234 | A label tensor filled with the ground truth label, with the same size as the input 235 | """ 236 | 237 | if target_is_real: 238 | target_tensor = self.real_label 239 | else: 240 | target_tensor = self.fake_label 241 | return target_tensor.expand_as(prediction) 242 | 243 | def __call__(self, prediction, target_is_real): 244 | """Calculate loss given Discriminator's output and ground truth labels. 245 | 246 | Parameters: 247 | prediction (tensor) - - typically the prediction output from a discriminator 248 | target_is_real (bool) - - if the ground truth label is for real images or fake images 249 | 250 | Returns: 251 | the calculated loss.
252 | """ 253 | if self.gan_mode in ['lsgan', 'vanilla']: 254 | target_tensor = self.get_target_tensor(prediction, target_is_real) 255 | loss = self.loss(prediction, target_tensor) 256 | elif self.gan_mode == 'wgangp': 257 | if target_is_real: 258 | loss = -prediction.mean() 259 | else: 260 | loss = prediction.mean() 261 | return loss 262 | 263 | 264 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): 265 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 266 | 267 | Arguments: 268 | netD (network) -- discriminator network 269 | real_data (tensor array) -- real images 270 | fake_data (tensor array) -- generated images from the generator 271 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 272 | type (str) -- if we mix real and fake data or not [real | fake | mixed]. 273 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 274 | lambda_gp (float) -- weight for this loss 275 | 276 | Returns the gradient penalty loss 277 | """ 278 | if lambda_gp > 0.0: 279 | if type == 'real': # either use real images, fake images, or a linear interpolation of two. 280 | interpolatesv = real_data 281 | elif type == 'fake': 282 | interpolatesv = fake_data 283 | elif type == 'mixed': 284 | alpha = torch.rand(real_data.shape[0], 1, device=device) 285 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) 286 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) 287 | else: 288 | raise NotImplementedError('{} not implemented'.format(type)) 289 | interpolatesv.requires_grad_(True) 290 | disc_interpolates = netD(interpolatesv) 291 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, 292 | grad_outputs=torch.ones(disc_interpolates.size()).to(device), 293 | create_graph=True, retain_graph=True, only_inputs=True) 294 | gradients = gradients[0].view(real_data.size(0), -1) # flat the data 295 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps 296 | return gradient_penalty, gradients 297 | else: 298 | return 0.0, None 299 | 300 | 301 | class UnetGenerator(nn.Module): 302 | """Create a Unet-based generator""" 303 | 304 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): 305 | """Construct a Unet generator 306 | Parameters: 307 | input_nc (int) -- the number of channels in input images 308 | output_nc (int) -- the number of channels in output images 309 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 310 | image of size 128x128 will become of size 1x1 # at the bottleneck 311 | ngf (int) -- the number of filters in the last conv layer 312 | norm_layer -- normalization layer 313 | 314 | We construct the U-Net from the innermost layer to the outermost layer. 315 | It is a recursive process. 
316 | """ 317 | super(UnetGenerator, self).__init__() 318 | # construct unet structure 319 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer 320 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 321 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 322 | # gradually reduce the number of filters from ngf * 8 to ngf 323 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 324 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 325 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 326 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer 327 | 328 | def forward(self, input): 329 | """Standard forward""" 330 | return self.model(input) 331 | 332 | 333 | class UnetSkipConnectionBlock(nn.Module): 334 | """Defines the Unet submodule with skip connection. 335 | X -------------------identity---------------------- 336 | |-- downsampling -- |submodule| -- upsampling --| 337 | """ 338 | 339 | def __init__(self, outer_nc, inner_nc, input_nc=None, 340 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 341 | """Construct a Unet submodule with skip connections. 342 | 343 | Parameters: 344 | outer_nc (int) -- the number of filters in the outer conv layer 345 | inner_nc (int) -- the number of filters in the inner conv layer 346 | input_nc (int) -- the number of channels in input images/features 347 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 348 | outermost (bool) -- if this module is the outermost module 349 | innermost (bool) -- if this module is the innermost module 350 | norm_layer -- normalization layer 351 | use_dropout (bool) -- if use dropout layers. 
352 | """ 353 | super(UnetSkipConnectionBlock, self).__init__() 354 | self.outermost = outermost 355 | if type(norm_layer) == functools.partial: 356 | use_bias = norm_layer.func == nn.InstanceNorm2d 357 | else: 358 | use_bias = norm_layer == nn.InstanceNorm2d 359 | if input_nc is None: 360 | input_nc = outer_nc 361 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 362 | stride=2, padding=1, bias=use_bias) 363 | downrelu = nn.LeakyReLU(0.2, True) 364 | downnorm = norm_layer(inner_nc) 365 | uprelu = nn.ReLU(True) 366 | upnorm = norm_layer(outer_nc) 367 | 368 | if outermost: 369 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 370 | kernel_size=4, stride=2, 371 | padding=1) 372 | down = [downconv] 373 | up = [uprelu, upconv, nn.Tanh()] 374 | model = down + [submodule] + up 375 | elif innermost: 376 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 377 | kernel_size=4, stride=2, 378 | padding=1, bias=use_bias) 379 | down = [downrelu, downconv] 380 | up = [uprelu, upconv, upnorm] 381 | model = down + up 382 | else: 383 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 384 | kernel_size=4, stride=2, 385 | padding=1, bias=use_bias) 386 | down = [downrelu, downconv, downnorm] 387 | up = [uprelu, upconv, upnorm] 388 | 389 | if use_dropout: 390 | model = down + [submodule] + up + [nn.Dropout(0.5)] 391 | else: 392 | model = down + [submodule] + up 393 | 394 | self.model = nn.Sequential(*model) 395 | 396 | def forward(self, x): 397 | if self.outermost: 398 | return self.model(x) 399 | else: # add skip connections 400 | return torch.cat([x, self.model(x)], 1) 401 | 402 | 403 | class NLayerDiscriminator(nn.Module): 404 | """Defines a PatchGAN discriminator""" 405 | 406 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 407 | """Construct a PatchGAN discriminator 408 | 409 | Parameters: 410 | input_nc (int) -- the number of channels in input images 411 | ndf (int) -- the number of filters in the last conv layer 412 | n_layers (int) -- the number of conv layers in the discriminator 413 | norm_layer -- normalization layer 414 | """ 415 | super(NLayerDiscriminator, self).__init__() 416 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 417 | use_bias = norm_layer.func == nn.InstanceNorm2d 418 | else: 419 | use_bias = norm_layer == nn.InstanceNorm2d 420 | 421 | kw = 4 422 | padw = 1 423 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 424 | nf_mult = 1 425 | nf_mult_prev = 1 426 | for n in range(1, n_layers): # gradually increase the number of filters 427 | nf_mult_prev = nf_mult 428 | nf_mult = min(2 ** n, 8) 429 | sequence += [ 430 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 431 | norm_layer(ndf * nf_mult), 432 | nn.LeakyReLU(0.2, True) 433 | ] 434 | 435 | nf_mult_prev = nf_mult 436 | nf_mult = min(2 ** n_layers, 8) 437 | sequence += [ 438 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 439 | norm_layer(ndf * nf_mult), 440 | nn.LeakyReLU(0.2, True) 441 | ] 442 | 443 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 444 | self.model = nn.Sequential(*sequence) 445 | 446 | def forward(self, input): 447 | """Standard forward.""" 448 | return self.model(input) 449 | 450 | 451 | -------------------------------------------------------------------------------- 
/dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numbers 4 | import random 5 | import numpy as np 6 | 7 | from PIL import Image, ImageOps 8 | from scipy.ndimage.filters import gaussian_filter 9 | from matplotlib.pyplot import imshow, imsave 10 | from scipy.ndimage.interpolation import map_coordinates 11 | import cv2 12 | from scipy import ndimage 13 | 14 | 15 | def to_multilabel(pre_mask, classes=2): # note: hard-coded for the 2-class (cup, disc) multi-label encoding below 16 | mask = np.zeros((pre_mask.shape[0], pre_mask.shape[1], classes)) 17 | mask[pre_mask == 1] = [0, 1] 18 | mask[pre_mask == 2] = [1, 1] 19 | # 3 channel ver. 20 | # mask[pre_mask == 0] = [1, 0, 0] 21 | # mask[pre_mask == 1] = [0, 1, 0] 22 | # mask[pre_mask == 2] = [0, 0, 1] 23 | return mask 24 | 25 | 26 | class add_salt_pepper_noise(): 27 | def __call__(self, sample): 28 | 29 | image = np.array(sample['image']).astype(np.uint8) 30 | X_imgs_copy = image.copy() 31 | # row = image.shape[0] 32 | # col = image.shape[1] 33 | salt_vs_pepper = 0.2 34 | amount = 0.004 35 | 36 | num_salt = np.ceil(amount * X_imgs_copy.size * salt_vs_pepper) 37 | num_pepper = np.ceil(amount * X_imgs_copy.size * (1.0 - salt_vs_pepper)) 38 | 39 | seed = random.random() 40 | if seed > 0.75: 41 | # Add Salt noise 42 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in X_imgs_copy.shape] 43 | X_imgs_copy[coords[0], coords[1], :] = 255  # salt pixels should be white on a uint8 image 44 | elif seed > 0.5: 45 | # Add Pepper noise 46 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in X_imgs_copy.shape] 47 | X_imgs_copy[coords[0], coords[1], :] = 0 48 | 49 | return {'image': X_imgs_copy, 50 | 'label': sample['label'], 51 | 'img_name': sample['img_name']} 52 | 53 | 54 | class add_salt_pepper_noise_BraTS(): 55 | def __call__(self, sample): 56 | 57 | 58 | images_0 = np.array(sample['image'][0]).astype(np.uint8) 59 | images_1 = np.array(sample['image'][1]).astype(np.uint8) 60 | images_2 = np.array(sample['image'][2]).astype(np.uint8) 61 | images_3 = np.array(sample['image'][3]).astype(np.uint8) 62 | X_imgs_copy = np.zeros([images_0.shape[0], images_0.shape[1], 4], dtype=np.uint8)  # height x width x 4 modalities 63 | X_imgs_copy[:, :, 0] = images_0.copy() 64 | X_imgs_copy[:, :, 1] = images_1.copy() 65 | X_imgs_copy[:, :, 2] = images_2.copy() 66 | X_imgs_copy[:, :, 3] = images_3.copy() 67 | # row = image.shape[0] 68 | # col = image.shape[1] 69 | salt_vs_pepper = 0.2 70 | amount = 0.004 71 | 72 | num_salt = np.ceil(amount * X_imgs_copy.size * salt_vs_pepper) 73 | num_pepper = np.ceil(amount * X_imgs_copy.size * (1.0 - salt_vs_pepper)) 74 | 75 | seed = random.random() 76 | if seed > 0.75: 77 | # Add Salt noise 78 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in X_imgs_copy.shape] 79 | X_imgs_copy[coords[0], coords[1], :] = 255  # salt pixels should be white on a uint8 image 80 | elif seed > 0.5: 81 | # Add Pepper noise 82 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in X_imgs_copy.shape] 83 | X_imgs_copy[coords[0], coords[1], :] = 0 84 | 85 | return {'image': X_imgs_copy, 86 | 'label': sample['label'], 87 | 'img_name': sample['img_name']} 88 | 89 | 90 | class adjust_light(): 91 | def __call__(self, sample): 92 | image = sample['image'] 93 | seed = random.random() 94 | if seed > 0.5: 95 | gamma = random.random() * 3 + 0.5 96 | invGamma = 1.0 / gamma 97 | table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype(np.uint8) 98 | image = cv2.LUT(np.array(image).astype(np.uint8), table).astype(np.uint8) 99 | return {'image': image, 100 | 'label':
sample['label'], 101 | 'img_name': sample['img_name']} 102 | else: 103 | return sample 104 | 105 | 106 | class eraser(): 107 | def __call__(self, sample, s_l=0.02, s_h=0.06, r_1=0.3, r_2=0.6, v_l=0, v_h=255, pixel_level=False): 108 | image = sample['image'] 109 | img_h, img_w, img_c = image.shape 110 | 111 | 112 | if random.random() > 0.5: 113 | return sample 114 | 115 | while True: 116 | s = np.random.uniform(s_l, s_h) * img_h * img_w 117 | r = np.random.uniform(r_1, r_2) 118 | w = int(np.sqrt(s / r)) 119 | h = int(np.sqrt(s * r)) 120 | left = np.random.randint(0, img_w) 121 | top = np.random.randint(0, img_h) 122 | 123 | if left + w <= img_w and top + h <= img_h: 124 | break 125 | 126 | if pixel_level: 127 | c = np.random.uniform(v_l, v_h, (h, w, img_c)) 128 | else: 129 | c = np.random.uniform(v_l, v_h) 130 | 131 | image[top:top + h, left:left + w, :] = c 132 | 133 | return {'image': image, 134 | 'label': sample['label'], 135 | 'img_name': sample['img_name']} 136 | 137 | class elastic_transform(): 138 | """Elastic deformation of images as described in [Simard2003]_. 139 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 140 | Convolutional Neural Networks applied to Visual Document Analysis", in 141 | Proc. of the International Conference on Document Analysis and 142 | Recognition, 2003. 143 | """ 144 | 145 | # def __init__(self): 146 | 147 | def __call__(self, sample): 148 | image, label = sample['image'], sample['label'] 149 | alpha = image.size[1] * 2 150 | sigma = image.size[1] * 0.08 151 | random_state = None 152 | seed = random.random() 153 | if seed > 0.5: 154 | # print(image.size) 155 | assert len(image.size) == 2 156 | 157 | if random_state is None: 158 | random_state = np.random.RandomState(None) 159 | 160 | shape = image.size[0:2] 161 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 162 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha 163 | 164 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij') 165 | indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1)) 166 | 167 | transformed_image = np.zeros([image.size[0], image.size[1], 3]) 168 | transformed_label = np.zeros([image.size[0], image.size[1]]) 169 | 170 | for i in range(3): 171 | # print(i) 172 | transformed_image[:, :, i] = map_coordinates(np.array(image)[:, :, i], indices, order=1).reshape(shape) 173 | # break 174 | if label is not None: 175 | transformed_label[:, :] = map_coordinates(np.array(label)[:, :], indices, order=1, mode='nearest').reshape(shape) 176 | else: 177 | transformed_label = None 178 | transformed_image = transformed_image.astype(np.uint8) 179 | 180 | if label is not None: 181 | transformed_label = transformed_label.astype(np.uint8) 182 | 183 | return {'image': transformed_image, 184 | 'label': transformed_label, 185 | 'img_name': sample['img_name']} 186 | else: 187 | return {'image': np.array(sample['image']), 188 | 'label': np.array(sample['label']), 189 | 'img_name': sample['img_name']} 190 | 191 | 192 | 193 | 194 | class RandomCrop(object): 195 | def __init__(self, size, padding=0): 196 | if isinstance(size, numbers.Number): 197 | self.size = (int(size), int(size)) 198 | else: 199 | self.size = size # h, w 200 | self.padding = padding 201 | 202 | def __call__(self, sample): 203 | img, mask = sample['image'], sample['label'] 204 | w, h = img.size 205 | if self.padding > 0 or w < self.size[0] or h < self.size[1]: 206 | padding = 
np.maximum(self.padding,np.maximum((self.size[0]-w)//2+5,(self.size[1]-h)//2+5)) 207 | img = ImageOps.expand(img, border=padding, fill=0) 208 | mask = ImageOps.expand(mask, border=padding, fill=255) 209 | 210 | assert img.width == mask.width 211 | assert img.height == mask.height 212 | w, h = img.size 213 | th, tw = self.size # target size 214 | if w == tw and h == th: 215 | return {'image': img, 216 | 'label': mask, 217 | 'img_name': sample['img_name']} 218 | x1 = random.randint(0, w - tw) 219 | y1 = random.randint(0, h - th) 220 | img = img.crop((x1, y1, x1 + tw, y1 + th)) 221 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 222 | return {'image': img, 223 | 'label': mask, 224 | 'img_name': sample['img_name']} 225 | 226 | 227 | class RandomCrop_BraTS(object): 228 | def __init__(self, size, padding=0): 229 | if isinstance(size, numbers.Number): 230 | self.size = (int(size), int(size)) 231 | else: 232 | self.size = size # h, w 233 | self.padding = padding 234 | 235 | def __call__(self, sample): 236 | imgs, mask = sample['image'], sample['label'] 237 | w, h = imgs[0].size 238 | if self.padding > 0 or w < self.size[0] or h < self.size[1]: 239 | padding = np.maximum(self.padding,np.maximum((self.size[0]-w)//2+5,(self.size[1]-h)//2+5)) 240 | imgs[0] = ImageOps.expand(imgs[0], border=padding, fill=0) 241 | imgs[1] = ImageOps.expand(imgs[1], border=padding, fill=0) 242 | imgs[2] = ImageOps.expand(imgs[2], border=padding, fill=0) 243 | imgs[3] = ImageOps.expand(imgs[3], border=padding, fill=0) 244 | mask = ImageOps.expand(mask, border=padding, fill=255) 245 | 246 | assert imgs[0].width == mask.width 247 | assert imgs[0].height == mask.height 248 | w, h = imgs[0].size 249 | th, tw = self.size # target size 250 | if w == tw and h == th: 251 | return {'image': imgs, 252 | 'label': mask, 253 | 'img_name': sample['img_name']} 254 | x1 = random.randint(0, w - tw) 255 | y1 = random.randint(0, h - th) 256 | imgs[0] = imgs[0].crop((x1, y1, x1 + tw, y1 + th)) 257 | imgs[1] = imgs[1].crop((x1, y1, x1 + tw, y1 + th)) 258 | imgs[2] = imgs[2].crop((x1, y1, x1 + tw, y1 + th)) 259 | imgs[3] = imgs[3].crop((x1, y1, x1 + tw, y1 + th)) 260 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 261 | return {'image': imgs, 262 | 'label': mask, 263 | 'img_name': sample['img_name']} 264 | 265 | 266 | class CenterCrop(object): 267 | def __init__(self, size): 268 | if isinstance(size, numbers.Number): 269 | self.size = (int(size), int(size)) 270 | else: 271 | self.size = size 272 | 273 | def __call__(self, sample): 274 | img = sample['image'] 275 | mask = sample['label'] 276 | 277 | w, h = img.size 278 | th, tw = self.size 279 | x1 = int(round((w - tw) / 2.)) 280 | y1 = int(round((h - th) / 2.)) 281 | img = img.crop((x1, y1, x1 + tw, y1 + th)) 282 | mask = mask.crop((x1, y1, x1 + tw, y1 + th)) 283 | 284 | return {'image': img, 285 | 'label': mask, 286 | 'img_name': sample['img_name']} 287 | 288 | 289 | class RandomFlip(object): 290 | def __call__(self, sample): 291 | img = sample['image'] 292 | mask = sample['label'] 293 | name = sample['img_name'] 294 | if random.random() < 0.5: 295 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 296 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 297 | if random.random() < 0.5: 298 | img = img.transpose(Image.FLIP_TOP_BOTTOM) 299 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 300 | 301 | return {'image': img, 302 | 'label': mask, 303 | 'img_name': name 304 | } 305 | 306 | class RandomFlip_BraTS(object): 307 | def __call__(self, sample): 308 | imgs = sample['image'] 309 | mask = sample['label'] 310 
| name = sample['img_name'] 311 | if random.random() < 0.5: 312 | imgs[0] = imgs[0].transpose(Image.FLIP_LEFT_RIGHT) 313 | imgs[1] = imgs[1].transpose(Image.FLIP_LEFT_RIGHT) 314 | imgs[2] = imgs[2].transpose(Image.FLIP_LEFT_RIGHT) 315 | imgs[3] = imgs[3].transpose(Image.FLIP_LEFT_RIGHT) 316 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 317 | if random.random() < 0.5: 318 | imgs[0] = imgs[0].transpose(Image.FLIP_TOP_BOTTOM) 319 | imgs[1] = imgs[1].transpose(Image.FLIP_TOP_BOTTOM) 320 | imgs[2] = imgs[2].transpose(Image.FLIP_TOP_BOTTOM) 321 | imgs[3] = imgs[3].transpose(Image.FLIP_TOP_BOTTOM) 322 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM) 323 | 324 | return {'image': imgs, 325 | 'label': mask, 326 | 'img_name': name 327 | } 328 | 329 | 330 | class FixedResize(object): 331 | def __init__(self, size): 332 | self.size = tuple(reversed(size)) # size: (h, w) 333 | 334 | def __call__(self, sample): 335 | img = sample['image'] 336 | mask = sample['label'] 337 | name = sample['img_name'] 338 | 339 | assert img.width == mask.width 340 | assert img.height == mask.height 341 | img = img.resize(self.size, Image.BILINEAR) 342 | mask = mask.resize(self.size, Image.NEAREST) 343 | 344 | return {'image': img, 345 | 'label': mask, 346 | 'img_name': name} 347 | 348 | 349 | class Scale(object): 350 | def __init__(self, size): 351 | if isinstance(size, numbers.Number): 352 | self.size = (int(size), int(size)) 353 | else: 354 | self.size = size 355 | 356 | def __call__(self, sample): 357 | img = sample['image'] 358 | mask = sample['label'] 359 | assert img.width == mask.width 360 | assert img.height == mask.height 361 | w, h = img.size 362 | 363 | if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]): 364 | return {'image': img, 365 | 'label': mask, 366 | 'img_name': sample['img_name']} 367 | oh, ow = self.size 368 | img = img.resize((ow, oh), Image.BILINEAR) 369 | mask = mask.resize((ow, oh), Image.NEAREST) 370 | 371 | return {'image': img, 372 | 'label': mask, 373 | 'img_name': sample['img_name']} 374 | 375 | 376 | class RandomSizedCrop(object): 377 | def __init__(self, size): 378 | self.size = size 379 | 380 | def __call__(self, sample): 381 | img = sample['image'] 382 | mask = sample['label'] 383 | name = sample['img_name'] 384 | assert img.width == mask.width 385 | assert img.height == mask.height 386 | for attempt in range(10): 387 | area = img.size[0] * img.size[1] 388 | target_area = random.uniform(0.45, 1.0) * area 389 | aspect_ratio = random.uniform(0.5, 2) 390 | 391 | w = int(round(math.sqrt(target_area * aspect_ratio))) 392 | h = int(round(math.sqrt(target_area / aspect_ratio))) 393 | 394 | if random.random() < 0.5: 395 | w, h = h, w 396 | 397 | if w <= img.size[0] and h <= img.size[1]: 398 | x1 = random.randint(0, img.size[0] - w) 399 | y1 = random.randint(0, img.size[1] - h) 400 | 401 | img = img.crop((x1, y1, x1 + w, y1 + h)) 402 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 403 | assert (img.size == (w, h)) 404 | 405 | img = img.resize((self.size, self.size), Image.BILINEAR) 406 | mask = mask.resize((self.size, self.size), Image.NEAREST) 407 | 408 | return {'image': img, 409 | 'label': mask, 410 | 'img_name': name} 411 | 412 | # Fallback 413 | scale = Scale(self.size) 414 | crop = CenterCrop(self.size) 415 | sample = crop(scale(sample)) 416 | return sample 417 | 418 | 419 | class RandomRotate(object): 420 | def __init__(self, size=512): 421 | self.degree = random.randint(1, 4) * 90  # note: the angle is drawn once, at construction time, so every sample gets the same multiple of 90 degrees (RandomRotate_BraTS below behaves the same way) 422 | self.size = size 423 | 424 | def __call__(self, sample): 425 | img = sample['image'] 426 |
class RandomRotate(object):
    def __init__(self, size=512):
        self.size = size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']

        # With p = 0.5, rotate by a fresh multiple of 90 degrees per sample.
        # Axis-aligned rotations of square inputs need no canvas expansion.
        if random.random() > 0.5:
            rotate_degree = random.randint(1, 4) * 90
            img = img.rotate(rotate_degree, Image.BILINEAR, expand=0)
            mask = mask.rotate(rotate_degree, Image.NEAREST, expand=0)

        return {'image': img, 'label': mask, 'img_name': sample['img_name']}


class RandomRotate_BraTS(object):
    def __init__(self, size=512):
        self.size = size

    def __call__(self, sample):
        imgs = sample['image']
        mask = sample['label']

        if random.random() > 0.5:
            rotate_degree = random.randint(1, 4) * 90
            imgs = [im.rotate(rotate_degree, Image.BILINEAR, expand=0) for im in imgs]
            mask = mask.rotate(rotate_degree, Image.NEAREST, expand=0)

        return {'image': imgs, 'label': mask, 'img_name': sample['img_name']}


class RandomScaleCrop(object):
    def __init__(self, size):
        self.size = size
        self.crop = RandomCrop(self.size)

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        name = sample['img_name']
        assert img.width == mask.width
        assert img.height == mask.height

        # With p = 0.5, rescale both axes independently by a factor in [0.5, 1.5],
        # then let RandomCrop bring the sample back to the target size.
        if random.random() > 0.5:
            w = int(random.uniform(0.5, 1.5) * img.size[0])
            h = int(random.uniform(0.5, 1.5) * img.size[1])
            img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
            sample = {'image': img, 'label': mask, 'img_name': name}

        return self.crop(sample)


class RandomScaleCrop_BraTS(object):
    def __init__(self, size):
        self.size = size
        self.crop = RandomCrop_BraTS(self.size)

    def __call__(self, sample):
        imgs = sample['image']
        mask = sample['label']
        name = sample['img_name']
        assert imgs[0].width == mask.width
        assert imgs[0].height == mask.height

        if random.random() > 0.5:
            w = int(random.uniform(0.5, 1.5) * imgs[0].size[0])
            h = int(random.uniform(0.5, 1.5) * imgs[0].size[1])
            imgs = [im.resize((w, h), Image.BILINEAR) for im in imgs]
            mask = mask.resize((w, h), Image.NEAREST)
            sample = {'image': imgs, 'label': mask, 'img_name': name}

        return self.crop(sample)


class ResizeImg(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        name = sample['img_name']
        assert img.width == mask.width
        assert img.height == mask.height

        # Only the image is resized here; the label keeps its original resolution.
        img = img.resize((self.size, self.size))

        return {'image': img, 'label': mask, 'img_name': name}
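# A sketch of how these transforms chain together (this particular pipeline is
# an illustration; the concrete one lives in the dataloader code). Because each
# transform maps a sample dict to a sample dict, torchvision's Compose works
# directly:
#
#   from torchvision import transforms
#   train_transform = transforms.Compose([
#       RandomScaleCrop(512),
#       RandomFlip(),
#       RandomRotate(512),
#       Normalize_tf(),
#       ToTensor(),
#   ])
#   sample = train_transform(sample)
#   sample['image'].shape  # torch.Size([3, 512, 512]) for an RGB fundus image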
class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        name = sample['img_name']
        assert img.width == mask.width
        assert img.height == mask.height

        img = img.resize((self.size, self.size), Image.BILINEAR)
        mask = mask.resize((self.size, self.size), Image.NEAREST)

        return {'image': img, 'label': mask, 'img_name': name}


class Resize_BraTS(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, sample):
        imgs = sample['image']
        mask = sample['label']
        name = sample['img_name']
        assert imgs[0].width == mask.width
        assert imgs[0].height == mask.height

        # Resize all four MRI modalities with the same interpolation settings.
        imgs = [im.resize((self.size, self.size), Image.BILINEAR) for im in imgs]
        mask = mask.resize((self.size, self.size), Image.NEAREST)

        return {'image': imgs, 'label': mask, 'img_name': name}


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32)
        mask = np.array(sample['label']).astype(np.float32)
        # Scale to [0, 1], then standardize channel-wise.
        img /= 255.0
        img -= self.mean
        img /= self.std

        return {'image': img,
                'label': mask,
                'img_name': sample['img_name']}


class GetBoundary(object):
    def __init__(self, width=5):
        self.width = width

    def __call__(self, mask):
        # mask is (H, W, 2): channel 0 = optic cup, channel 1 = optic disc.
        cup = mask[:, :, 0]
        disc = mask[:, :, 1]
        dila_cup = ndimage.binary_dilation(cup, iterations=self.width).astype(cup.dtype)
        eros_cup = ndimage.binary_erosion(cup, iterations=self.width).astype(cup.dtype)
        dila_disc = ndimage.binary_dilation(disc, iterations=self.width).astype(disc.dtype)
        eros_disc = ndimage.binary_erosion(disc, iterations=self.width).astype(disc.dtype)
        # Pixels kept by both dilation and erosion sum to 2 (interior) and are
        # zeroed, leaving a band of roughly 2 * width around each contour.
        cup = dila_cup + eros_cup
        disc = dila_disc + eros_disc
        cup[cup == 2] = 0
        disc[disc == 2] = 0
        boundary = (cup + disc) > 0
        return boundary.astype(np.uint8)
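# Quick illustrative check of the boundary band above (shapes and values are
# made up for demonstration):
#
#   cup = np.zeros((64, 64), dtype=np.uint8)
#   cup[20:40, 20:40] = 1
#   band = GetBoundary(width=5)(np.stack([cup, cup], axis=-1))
#   # band is 1 on a ring around the 20x20 square, 0 in its interior and far field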
class Normalize_tf(object):
    """Scale the image to [-1, 1] (TensorFlow-style) and convert the grayscale
    fundus label into a multi-channel mask.
    Args:
        mean (tuple): means for each channel (kept for API compatibility; unused).
        std (tuple): standard deviations for each channel (kept for API compatibility; unused).
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std
        self.get_boundary = GetBoundary()

    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32)
        __mask = np.array(sample['label']).astype(np.uint8)
        name = sample['img_name']
        img /= 127.5
        img -= 1.0
        # Snap the grayscale label to its three nominal values:
        # 255 = background, 128 = optic disc, 0 = optic cup.
        _mask = np.zeros([__mask.shape[0], __mask.shape[1]])
        _mask[__mask > 200] = 255
        _mask[(__mask > 50) & (__mask < 201)] = 128

        # Remap to class ids: background -> 0, disc -> 1, cup -> 2.
        __mask[_mask == 0] = 2
        __mask[_mask == 255] = 0
        __mask[_mask == 128] = 1

        mask = to_multilabel(__mask)
        # boundary = self.get_boundary(mask) * 255
        # boundary = ndimage.gaussian_filter(boundary, sigma=3) / 255.0
        # boundary = np.expand_dims(boundary, -1)

        return {'image': img,
                'label': mask,
                # 'boundary': boundary,
                'img_name': name}


class Normalize_cityscapes(object):
    """Normalize a tensor image with a per-channel mean.
    Args:
        mean (tuple): means for each channel.
    """
    def __init__(self, mean=(0., 0., 0.)):
        self.mean = mean

    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32)
        mask = np.array(sample['label']).astype(np.float32)
        # The mean is subtracted in the 0-255 pixel range before rescaling.
        img -= self.mean
        img /= 255.0

        return {'image': img,
                'label': mask,
                'img_name': sample['img_name']}


class Normalize_CMR(object):
    """
    Normalize a tensor image to [0, 1].
    """
    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32)
        mask = np.array(sample['label']).astype(np.float32)
        img /= 255.0

        return {'image': img,
                'label': mask,
                'img_name': sample['img_name']}
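# The BraTS samples handled below carry one PIL image per MRI modality (four
# in total, e.g. T1/T1ce/T2/FLAIR under the usual BraTS convention; the exact
# ordering here is an assumption). Normalize_BraTS stacks them into a single
# (H, W, 4) float array:
#
#   sample = {'image': [t1, t1ce, t2, flair], 'label': mask, 'img_name': 'x.png'}
#   out = Normalize_BraTS()(sample)
#   out['image'].shape  # (H, W, 4), values scaled to [0, 1]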
class Normalize_BraTS(object):
    """
    Normalize a tensor image to [0, 1].
    """
    def __call__(self, sample):
        if isinstance(sample['image'], list):
            # Stack the four modality images into one (H, W, 4) array.
            img = np.stack([np.array(im).astype(np.float32) for im in sample['image']],
                           axis=-1)
        else:
            img = np.array(sample['image']).astype(np.float32)
        img /= 255.0

        return {'image': img,
                'label': sample['label'],
                'img_name': sample['img_name']}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        # Swap the color axis: numpy image H x W x C -> torch image C x H x W.
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        label = np.array(sample['label']).astype(np.uint8).transpose((2, 0, 1))
        # boundary = np.array(sample['boundary']).astype(np.float32).transpose((2, 0, 1))
        name = sample['img_name']
        img = torch.from_numpy(img).float()
        label = torch.from_numpy(label).float()
        # boundary = torch.from_numpy(boundary).float()

        return {'image': img,
                'label': label,
                # 'boundary': boundary,
                'img_name': name}


class LabelOneHotAndToTensor(object):
    """
    Turn a numeric (0, 80, 160, 240) label map into a one-hot label.
    Convert PIL.Image in sample to Tensors.
    """
    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        img = torch.from_numpy(img).type(dtype=torch.FloatTensor)

        # Map the gray values 0/80/160/240 to class ids 0/1/2/3.
        label = np.array(sample['label'])
        label = (label == 80) * 1 + (label == 160) * 2 + (label == 240) * 3
        label = torch.from_numpy(label).type(dtype=torch.LongTensor)
        # Scatter the class ids into a 4-channel one-hot tensor.
        label_onehot = torch.FloatTensor(4, label.size()[0], label.size()[1])
        label_onehot.zero_()
        label_onehot.scatter_(0, label.unsqueeze(dim=0), 1)

        name = sample['img_name']

        return {'image': img,
                'label': label_onehot,
                'img_name': name}
--------------------------------------------------------------------------------