├── utils
│   ├── __init__.py
│   ├── losses.py
│   ├── metrics.py
│   └── Utils.py
├── dataloaders
│   ├── __init__.py
│   ├── fundus_dataloader.py
│   └── custom_transforms.py
├── networks
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   ├── xception.py
│   │   └── drn.py
│   ├── decoder.py
│   ├── deeplabv3.py
│   ├── layers.py
│   ├── decoder_old.py
│   ├── aspp.py
│   ├── aspp_eval.py
│   ├── sync_batchnorm
│   │   ├── comm.py
│   │   └── batchnorm.py
│   ├── GAN.py
│   └── models.py
├── train_process
│   ├── __init__.py
│   └── Trainer.py
├── figure
│   └── framework.png
├── README.md
├── .gitignore
├── train_source.py
└── train_target.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/train_process/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/figure/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lloongx/SFDA-CBMT/HEAD/figure/framework.png
--------------------------------------------------------------------------------
/networks/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from networks.backbone import resnet, xception, drn, mobilenet
2 |
3 | def build_backbone(backbone, output_stride, BatchNorm):
4 | if backbone == 'resnet':
5 | return resnet.ResNet101(output_stride, BatchNorm)
6 | elif backbone == 'xception':
7 | return xception.AlignedXception(output_stride, BatchNorm)
8 | elif backbone == 'drn':
9 | return drn.drn_d_54(BatchNorm)
10 | elif backbone == 'mobilenet':
11 | return mobilenet.MobileNetV2(output_stride, BatchNorm)
12 | else:
13 | raise NotImplementedError
14 |
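15 | if __name__ == "__main__":
16 |     # Hedged usage sketch: build the mobilenet backbone (note: this downloads
17 |     # ImageNet-pretrained weights on first run) and inspect the two feature
18 |     # maps it returns. Shapes below are illustrative for a 512x512 input.
19 |     import torch
20 |     import torch.nn as nn
21 |     net = build_backbone('mobilenet', output_stride=16, BatchNorm=nn.BatchNorm2d)
22 |     x, low_level = net(torch.rand(1, 3, 512, 512))
23 |     print(x.shape, low_level.shape)  # (1, 320, 32, 32) and (1, 24, 128, 128)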
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Source-Free Domain Adaptive Fundus Image Segmentation with Class-Balanced Mean Teacher
2 |
3 | PyTorch implementation of the MICCAI'23 paper *Source-Free Domain Adaptive Fundus Image Segmentation with Class-Balanced Mean Teacher*.
4 |
5 | ![framework](figure/framework.png)
6 | 
7 | 
8 | 
9 | ## Installation
10 | * Install Python 3.10.5, PyTorch 1.12.0, CUDA 11.6, and other essential packages. (Note that using other package versions may affect performance.)
11 | * Clone this repo
12 | ```
13 | git clone https://github.com/lloongx/SFDA-CBMT.git
14 | cd SFDA-CBMT
15 | ```
16 |
17 | ## Training
18 | * Download datasets from [here](https://drive.google.com/file/d/1B7ArHRBjt2Dx29a3A6X_lGhD0vDVr3sy/view).
19 | * Download the source domain model from [here](https://drive.google.com/drive/folders/1L23mCg8prsdu1imEQI5ouuwvVL_FSiLY), or train your own by specifying `--data-dir` in `./train_source.py` and running it.
20 | * Save the source domain model into the folder `./logs_train/`.
21 | * Run `./train_target.py` with `--model-file` and `--data-dir` specified to start the SFDA training process (see the example commands below).
22 |
23 |
24 | ## Acknowledgement
25 | This repo benefits from [BEAL](https://github.com/emma-sjwang/BEAL) and [SFDA-DPL](https://github.com/cchen-cc/SFDA-DPL). Thanks for their wonderful works.
26 |
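27 | ## Example Commands
28 | A hypothetical end-to-end run (the source-model file name below is a placeholder; adjust paths to your setup):
29 | ```
30 | python train_source.py --dataset Domain3 --data-dir ../Datasets/Fundus
31 | python train_target.py --model-file ./logs_train/<source-model>.pth.tar --data-dir ../Datasets/Fundus
32 | ```
33 | 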
--------------------------------------------------------------------------------
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 |
6 | def entropy_loss(p, C=2):
7 |     y1 = -1.0 * torch.sum(p * torch.log(p + 1e-6), dim=1) / math.log(C)
8 | ent = torch.mean(y1)
9 |
10 | return ent
11 |
12 | class CrossEntropyLoss(nn.CrossEntropyLoss):
13 | def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
14 | super().__init__(weight, size_average, ignore_index, reduce, reduction)
15 |
16 |     def forward(self, logits: torch.Tensor, target: torch.Tensor, **kwargs):
17 | return super().forward(logits, target)
18 |
19 |
20 | class StochasticSegmentationNetworkLossMCIntegral(nn.Module):
21 | def __init__(self, num_mc_samples: int = 1):
22 | super().__init__()
23 | self.num_mc_samples = num_mc_samples
24 |
25 | @staticmethod
26 | def fixed_re_parametrization_trick(dist, num_samples):
27 | assert num_samples % 2 == 0
28 | samples = dist.rsample((num_samples // 2,))
29 | mean = dist.mean.unsqueeze(0)
30 | samples = samples - mean
31 | return torch.cat([samples, -samples]) + mean
32 |
33 | def forward(self, logits, target, distribution, **kwargs):
34 | batch_size = logits.shape[0]
35 | num_classes = logits.shape[1]
36 | assert num_classes >= 2 # not implemented for binary case with implied background
37 | # logit_sample = distribution.rsample((self.num_mc_samples,))
38 | logit_sample = self.fixed_re_parametrization_trick(distribution, self.num_mc_samples)
39 | target = target.unsqueeze(1)
40 | target = target.expand((self.num_mc_samples,) + target.shape)
41 |
42 | flat_size = self.num_mc_samples * batch_size
43 | logit_sample = logit_sample.view((flat_size, num_classes, -1))
44 | target = target.reshape((flat_size, -1))
45 |
46 | # log_prob = -F.cross_entropy(logit_sample, target, reduction='none').view((self.num_mc_samples, batch_size, -1))
47 |         log_prob = -F.binary_cross_entropy(torch.sigmoid(logit_sample), target, reduction='none').view((self.num_mc_samples, batch_size, -1))
48 | loglikelihood = torch.mean(torch.logsumexp(torch.sum(log_prob, dim=-1), dim=0) - math.log(self.num_mc_samples))
49 | loss = -loglikelihood
50 | return loss
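51 | 
52 | 
53 | if __name__ == "__main__":
54 |     # Hedged usage sketch (an addition, not from the original repo):
55 |     # entropy_loss expects per-pixel class probabilities along dim=1 and
56 |     # returns the mean entropy normalized by log(C), so values lie in [0, 1].
57 |     p = torch.softmax(torch.randn(2, 2, 64, 64), dim=1)
58 |     print(entropy_loss(p, C=2).item())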
--------------------------------------------------------------------------------
/networks/decoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, num_classes, backbone, BatchNorm):
9 | super(Decoder, self).__init__()
10 | if backbone == 'resnet' or backbone == 'drn':
11 | low_level_inplanes = 256
12 | elif backbone == 'xception':
13 | low_level_inplanes = 128
14 | elif backbone == 'mobilenet':
15 | low_level_inplanes = 24
16 | else:
17 | raise NotImplementedError
18 |
19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20 | self.bn1 = BatchNorm(48)
21 | self.relu = nn.ReLU()
22 | self.last_conv_1 = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
23 | BatchNorm(256),
24 | nn.ReLU(),
25 | nn.Dropout(0.5),
26 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
27 | BatchNorm(256),
28 | nn.ReLU())
29 | self.last_conv_2 = nn.Sequential(nn.Dropout(0.1),
30 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
31 | self._init_weight()
32 |
33 |
34 | def forward(self, x, low_level_feat):
35 | low_level_feat = self.conv1(low_level_feat)
36 | low_level_feat = self.bn1(low_level_feat)
37 | low_level_feat = self.relu(low_level_feat)
38 |
39 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
40 | x = torch.cat((x, low_level_feat), dim=1)
41 | x = self.last_conv_1(x)
42 | out = self.last_conv_2(x)
43 |
44 | return out, x
45 |
46 | def _init_weight(self):
47 | for m in self.modules():
48 | if isinstance(m, nn.Conv2d):
49 | torch.nn.init.kaiming_normal_(m.weight)
50 | elif isinstance(m, SynchronizedBatchNorm2d):
51 | m.weight.data.fill_(1)
52 | m.bias.data.zero_()
53 | elif isinstance(m, nn.BatchNorm2d):
54 | m.weight.data.fill_(1)
55 | m.bias.data.zero_()
56 |
57 | def build_decoder(num_classes, backbone, BatchNorm):
58 | return Decoder(num_classes, backbone, BatchNorm)
59 |
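60 | if __name__ == "__main__":
61 |     # Hedged usage sketch: feed the decoder dummy tensors shaped like
62 |     # mobilenet features -- a 256-channel ASPP output and 24-channel
63 |     # low-level features (shapes are illustrative assumptions).
64 |     decoder = build_decoder(num_classes=2, backbone='mobilenet', BatchNorm=nn.BatchNorm2d)
65 |     decoder.eval()
66 |     out, feat = decoder(torch.rand(1, 256, 32, 32), torch.rand(1, 24, 128, 128))
67 |     print(out.shape, feat.shape)  # (1, 2, 128, 128) and (1, 256, 128, 128)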
--------------------------------------------------------------------------------
/networks/deeplabv3.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from networks.aspp import build_aspp
6 | from networks.decoder import build_decoder
7 | from networks.backbone import build_backbone
8 |
9 |
10 | class DeepLab(nn.Module):
11 | def __init__(self, backbone='resnet', output_stride=16, num_classes=21,
12 | sync_bn=True, freeze_bn=False):
13 | super(DeepLab, self).__init__()
14 | if backbone == 'drn':
15 | output_stride = 8
16 |
17 |         if sync_bn:
18 | BatchNorm = SynchronizedBatchNorm2d
19 | else:
20 | BatchNorm = nn.BatchNorm2d
21 |
22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm)
23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm)
24 | self.decoder = build_decoder(num_classes, backbone, BatchNorm)
25 |
26 | if freeze_bn:
27 | self.freeze_bn()
28 |
29 | def forward(self, input):
30 | x, low_level_feat = self.backbone(input)
31 | x = self.aspp(x)
32 | x, features = self.decoder(x, low_level_feat)
33 |
34 | x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
35 | return x, features
36 |
37 | def freeze_bn(self):
38 | for m in self.modules():
39 | if isinstance(m, SynchronizedBatchNorm2d):
40 | m.eval()
41 | elif isinstance(m, nn.BatchNorm2d):
42 | m.eval()
43 |
44 | def get_1x_lr_params(self):
45 | modules = [self.backbone]
46 | for i in range(len(modules)):
47 | for m in modules[i].named_modules():
48 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
49 | or isinstance(m[1], nn.BatchNorm2d):
50 | for p in m[1].parameters():
51 | if p.requires_grad:
52 | yield p
53 |
54 | def get_10x_lr_params(self):
55 | modules = [self.aspp, self.decoder]
56 | for i in range(len(modules)):
57 | for m in modules[i].named_modules():
58 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
59 | or isinstance(m[1], nn.BatchNorm2d):
60 | for p in m[1].parameters():
61 | if p.requires_grad:
62 | yield p
63 |
64 |
65 | if __name__ == "__main__":
66 | model = DeepLab(backbone='mobilenet', output_stride=16)
67 | model.eval()
68 | input = torch.rand(1, 3, 513, 513)
69 |     output, features = model(input)
70 |     print(output.size())
71 |
72 |
73 |
--------------------------------------------------------------------------------
/networks/layers.py:
--------------------------------------------------------------------------------
1 | """
2 | Wrappers for the operations to take the meta-learning gradient
3 | updates into account.
4 | """
5 | import torch.autograd as autograd
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 |
10 | def linear(inputs, weight, bias, meta_step_size=0.001, meta_loss=None, stop_gradient=False):
11 | inputs = inputs.cuda()
12 | weight = weight.cuda()
13 | bias = bias.cuda()
14 |
15 | if meta_loss is not None:
16 |
17 | if not stop_gradient:
18 | grad_weight = autograd.grad(meta_loss, weight, create_graph=True)[0]
19 |
20 | if bias is not None:
21 | grad_bias = autograd.grad(meta_loss, bias, create_graph=True)[0]
22 | bias_adapt = bias - grad_bias * meta_step_size
23 | else:
24 | bias_adapt = bias
25 |
26 | else:
27 | grad_weight = Variable(autograd.grad(meta_loss, weight, create_graph=True)[0].data, requires_grad=False)
28 |
29 | if bias is not None:
30 | grad_bias = Variable(autograd.grad(meta_loss, bias, create_graph=True)[0].data, requires_grad=False)
31 | bias_adapt = bias - grad_bias * meta_step_size
32 | else:
33 | bias_adapt = bias
34 |
35 | return F.linear(inputs,
36 | weight - grad_weight * meta_step_size,
37 | bias_adapt)
38 | else:
39 | return F.linear(inputs, weight, bias)
40 |
41 | def conv2d(inputs, weight, bias, stride=1, padding=1, dilation=1, groups=1, kernel_size=3):
42 |
43 | inputs = inputs.cuda()
44 | weight = weight.cuda()
45 | bias = bias.cuda()
46 |
47 | return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)
48 |
49 |
50 | def deconv2d(inputs, weight, bias, stride=2, padding=0, dilation=1, groups=1, kernel_size=None):
51 | 
52 |     inputs = inputs.cuda()
53 |     weight = weight.cuda()
54 |     bias = bias.cuda()
55 | 
56 |     return F.conv_transpose2d(inputs, weight, bias, stride=stride, padding=padding, groups=groups, dilation=dilation)  # keyword args: positionally, the 6th argument of F.conv_transpose2d is output_padding, not dilation
57 |
58 | def relu(inputs):
59 | return F.relu(inputs, inplace=True)
60 |
61 |
62 | def maxpool(inputs, kernel_size, stride=None, padding=0):
63 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding)
64 |
65 |
66 | def dropout(inputs):
67 | return F.dropout(inputs, p=0.5, training=False, inplace=False)
68 |
69 | def batchnorm(inputs, running_mean, running_var):
70 | return F.batch_norm(inputs, running_mean, running_var)
71 |
72 |
73 | """
74 | The following are the new methods for 2D-Unet:
75 | Conv2d, batchnorm2d, GroupNorm, InstanceNorm2d, MaxPool2d, UpSample
76 | """
77 | #as per the 2D Unet: kernel_size, stride, padding
78 |
79 | def instancenorm(input):
80 | return F.instance_norm(input)
81 |
82 | def groupnorm(input, num_groups=1):
83 |     return F.group_norm(input, num_groups)  # F.group_norm requires num_groups; default of 1 is an assumption
84 |
85 | def dropout2D(inputs):
86 | return F.dropout2d(inputs, p=0.5, training=False, inplace=False)
87 |
88 | def maxpool2D(inputs, kernel_size, stride=None, padding=0):
89 | return F.max_pool2d(inputs, kernel_size, stride, padding=padding)
90 |
91 | def upsample(input):
92 |     return F.interpolate(input, scale_factor=2, mode='bilinear', align_corners=False)  # F.upsample is deprecated in favor of F.interpolate
93 |
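94 | if __name__ == "__main__":
95 |     # Hedged usage sketch (requires a GPU, since these wrappers call .cuda()):
96 |     # one meta-gradient step through `linear`. Inside, the weights are adapted
97 |     # to w - meta_step_size * dL/dw before the forward pass (MAML-style).
98 |     import torch
99 |     w = torch.randn(4, 8, requires_grad=True, device='cuda')
100 |     b = torch.zeros(4, requires_grad=True, device='cuda')
101 |     x = torch.randn(2, 8, device='cuda')
102 |     meta_loss = F.linear(x, w, b).pow(2).mean()
103 |     print(linear(x, w, b, meta_step_size=0.01, meta_loss=meta_loss).shape)  # torch.Size([2, 4])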
--------------------------------------------------------------------------------
/networks/decoder_old.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class Decoder(nn.Module):
8 | def __init__(self, num_classes, backbone, BatchNorm):
9 | super(Decoder, self).__init__()
10 | if backbone == 'resnet' or backbone == 'drn':
11 | low_level_inplanes = 256
12 | elif backbone == 'xception':
13 | low_level_inplanes = 128
14 | elif backbone == 'mobilenet':
15 | low_level_inplanes = 24
16 | else:
17 | raise NotImplementedError
18 |
19 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
20 | self.bn1 = BatchNorm(48)
21 | self.relu = nn.ReLU()
22 | self.last_conv = nn.Sequential(
23 | # nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
24 | # BatchNorm(256),
25 | # nn.ReLU(),
26 | # nn.Dropout(0.5),
27 | # nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
28 | BatchNorm(305),
29 | nn.ReLU(),
30 | nn.Dropout(0.1),
31 | nn.Conv2d(305, num_classes, kernel_size=1, stride=1))
32 | self.last_conv_boundary = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
33 | BatchNorm(256),
34 | nn.ReLU(),
35 | nn.Dropout(0.5),
36 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
37 | BatchNorm(256),
38 | nn.ReLU(),
39 | nn.Dropout(0.1),
40 | nn.Conv2d(256, 1, kernel_size=1, stride=1))
41 | self._init_weight()
42 |
43 |
44 | def forward(self, x, low_level_feat):
45 | low_level_feat = self.conv1(low_level_feat)
46 | low_level_feat = self.bn1(low_level_feat)
47 | low_level_feat = self.relu(low_level_feat)
48 |
49 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
50 | x = torch.cat((x, low_level_feat), dim=1)
51 | boundary = self.last_conv_boundary(x)
52 | x = torch.cat([x, boundary], 1)
53 | x1 = self.last_conv(x)
54 |
55 | return x1, boundary, x
56 |
57 | def _init_weight(self):
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | torch.nn.init.kaiming_normal_(m.weight)
61 | elif isinstance(m, SynchronizedBatchNorm2d):
62 | m.weight.data.fill_(1)
63 | m.bias.data.zero_()
64 | elif isinstance(m, nn.BatchNorm2d):
65 | m.weight.data.fill_(1)
66 | m.bias.data.zero_()
67 |
68 | def build_decoder(num_classes, backbone, BatchNorm):
69 | return Decoder(num_classes, backbone, BatchNorm)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/networks/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class _ASPPModule(nn.Module):
8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 | super(_ASPPModule, self).__init__()
10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 | stride=1, padding=padding, dilation=dilation, bias=False)
12 | self.bn = BatchNorm(planes)
13 | self.relu = nn.ReLU()
14 |
15 | self._init_weight()
16 |
17 | def forward(self, x):
18 | x = self.atrous_conv(x)
19 | x = self.bn(x)
20 |
21 | return self.relu(x)
22 |
23 | def _init_weight(self):
24 | for m in self.modules():
25 | if isinstance(m, nn.Conv2d):
26 | torch.nn.init.kaiming_normal_(m.weight)
27 | elif isinstance(m, SynchronizedBatchNorm2d):
28 | m.weight.data.fill_(1)
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.BatchNorm2d):
31 | m.weight.data.fill_(1)
32 | m.bias.data.zero_()
33 |
34 | class ASPP(nn.Module):
35 | def __init__(self, backbone, output_stride, BatchNorm):
36 | super(ASPP, self).__init__()
37 | if backbone == 'drn':
38 | inplanes = 512
39 | elif backbone == 'mobilenet':
40 | inplanes = 320
41 | else:
42 | inplanes = 2048
43 | if output_stride == 16:
44 | dilations = [1, 6, 12, 18]
45 | elif output_stride == 8:
46 | dilations = [1, 12, 24, 36]
47 | else:
48 | raise NotImplementedError
49 |
50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 |
55 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 | BatchNorm(256),
58 | nn.ReLU())
59 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
60 | self.bn1 = BatchNorm(256)
61 | self.relu = nn.ReLU()
62 | self.dropout = nn.Dropout(0.5)
63 | self._init_weight()
64 |
65 | def forward(self, x):
66 | x1 = self.aspp1(x)
67 | x2 = self.aspp2(x)
68 | x3 = self.aspp3(x)
69 | x4 = self.aspp4(x)
70 | x5 = self.global_avg_pool(x)
71 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
72 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
73 |
74 | x = self.conv1(x)
75 | x = self.bn1(x)
76 | x = self.relu(x)
77 |
78 | return self.dropout(x)
79 |
80 | def _init_weight(self):
81 | for m in self.modules():
82 | if isinstance(m, nn.Conv2d):
83 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
84 | # m.weight.data.normal_(0, math.sqrt(2. / n))
85 | torch.nn.init.kaiming_normal_(m.weight)
86 | elif isinstance(m, SynchronizedBatchNorm2d):
87 | m.weight.data.fill_(1)
88 | m.bias.data.zero_()
89 | elif isinstance(m, nn.BatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 |
93 |
94 | def build_aspp(backbone, output_stride, BatchNorm):
95 | return ASPP(backbone, output_stride, BatchNorm)
96 |
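97 | if __name__ == "__main__":
98 |     # Hedged usage sketch: ASPP over mobilenet-style features (320 channels)
99 |     # at output_stride 16; the five 256-channel branches concatenate to 1280
100 |     # and are projected back to 256 at the same spatial size.
101 |     aspp = build_aspp(backbone='mobilenet', output_stride=16, BatchNorm=nn.BatchNorm2d)
102 |     aspp.eval()
103 |     print(aspp(torch.rand(2, 320, 32, 32)).shape)  # torch.Size([2, 256, 32, 32])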
--------------------------------------------------------------------------------
/networks/aspp_eval.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 |
7 | class _ASPPModule(nn.Module):
8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 | super(_ASPPModule, self).__init__()
10 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 | stride=1, padding=padding, dilation=dilation, bias=False)
12 | self.bn = BatchNorm(planes)
13 | self.relu = nn.ReLU()
14 |
15 | self._init_weight()
16 |
17 | def forward(self, x):
18 | x = self.atrous_conv(x)
19 | x = self.bn(x)
20 |
21 | return self.relu(x)
22 |
23 | def _init_weight(self):
24 | for m in self.modules():
25 | if isinstance(m, nn.Conv2d):
26 | torch.nn.init.kaiming_normal_(m.weight)
27 | elif isinstance(m, SynchronizedBatchNorm2d):
28 | m.weight.data.fill_(1)
29 | m.bias.data.zero_()
30 | elif isinstance(m, nn.BatchNorm2d):
31 | m.weight.data.fill_(1)
32 | m.bias.data.zero_()
33 |
34 | class ASPP(nn.Module):
35 | def __init__(self, backbone, output_stride, BatchNorm):
36 | super(ASPP, self).__init__()
37 | if backbone == 'drn':
38 | inplanes = 512
39 | elif backbone == 'mobilenet':
40 | inplanes = 320
41 | else:
42 | inplanes = 2048
43 | if output_stride == 16:
44 | dilations = [1, 6, 12, 18]
45 | elif output_stride == 8:
46 | dilations = [1, 12, 24, 36]
47 | else:
48 | raise NotImplementedError
49 |
50 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
51 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
52 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
53 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
54 |
55 | # self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
56 | # nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
57 | # BatchNorm(256),
58 | # nn.ReLU())
59 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
60 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
61 | nn.ReLU())
62 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
63 | self.bn1 = BatchNorm(256)
64 | self.relu = nn.ReLU()
65 | self.dropout = nn.Dropout(0.5)
66 | self._init_weight()
67 |
68 | def forward(self, x):
69 | x1 = self.aspp1(x)
70 | x2 = self.aspp2(x)
71 | x3 = self.aspp3(x)
72 | x4 = self.aspp4(x)
73 | x5 = self.global_avg_pool(x)
74 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
75 | x = torch.cat((x1, x2, x3, x4, x5), dim=1)
76 |
77 | x = self.conv1(x)
78 | x = self.bn1(x)
79 | x = self.relu(x)
80 |
81 | return self.dropout(x)
82 |
83 | def _init_weight(self):
84 | for m in self.modules():
85 | if isinstance(m, nn.Conv2d):
86 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
87 | # m.weight.data.normal_(0, math.sqrt(2. / n))
88 | torch.nn.init.kaiming_normal_(m.weight)
89 | elif isinstance(m, SynchronizedBatchNorm2d):
90 | m.weight.data.fill_(1)
91 | m.bias.data.zero_()
92 | elif isinstance(m, nn.BatchNorm2d):
93 | m.weight.data.fill_(1)
94 | m.bias.data.zero_()
95 |
96 |
97 | def build_aspp(backbone, output_stride, BatchNorm):
98 | return ASPP(backbone, output_stride, BatchNorm)
99 |
--------------------------------------------------------------------------------
/networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
59 |     call `register_slave(id)` to obtain a `SlavePipe` for communicating with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices
61 |     are collected and passed to a registered callback.
62 |     - After receiving the messages, the master device gathers the information and determines the message
63 |     to be passed back to each slave device.
64 |     """
65 |
66 | def __init__(self, master_callback):
67 | """
68 | Args:
69 | master_callback: a callback to be invoked after having collected messages from slave devices.
70 | """
71 | self._master_callback = master_callback
72 | self._queue = queue.Queue()
73 | self._registry = collections.OrderedDict()
74 | self._activated = False
75 |
76 | def __getstate__(self):
77 | return {'master_callback': self._master_callback}
78 |
79 | def __setstate__(self, state):
80 | self.__init__(state['master_callback'])
81 |
82 | def register_slave(self, identifier):
83 | """
84 |         Register a slave device.
85 | Args:
86 | identifier: an identifier, usually is the device id.
87 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 | """
89 | if self._activated:
90 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 | self._activated = False
92 | self._registry.clear()
93 | future = FutureResult()
94 | self._registry[identifier] = _MasterRegistry(future)
95 | return SlavePipe(identifier, self._queue, future)
96 |
97 | def run_master(self, master_msg):
98 | """
99 | Main entry for the master device in each forward pass.
100 |         Messages are first collected from each device (including the master device), and then
101 |         a callback is invoked to compute the message to be sent back to each device
102 | (including the master device).
103 | Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 | Returns: the message to be sent back to the master device.
107 | """
108 | self._activated = True
109 |
110 | intermediates = [(0, master_msg)]
111 | for i in range(self.nr_slaves):
112 | intermediates.append(self._queue.get())
113 |
114 | results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 |
117 | for i, res in results:
118 | if i == 0:
119 | continue
120 | self._registry[i].result.put(res)
121 |
122 | for i in range(self.nr_slaves):
123 | assert self._queue.get() is True
124 |
125 | return results[0][1]
126 |
127 | @property
128 | def nr_slaves(self):
129 | return len(self._registry)
130 |
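131 | if __name__ == '__main__':
132 |     # Hedged single-process sketch: one master and one slave thread exchange
133 |     # messages through a SyncMaster whose callback sums the collected values
134 |     # and broadcasts the total back to every device.
135 |     def _sum_callback(intermediates):
136 |         total = sum(msg for _, msg in intermediates)
137 |         return [(i, total) for i, _ in intermediates]
138 | 
139 |     master = SyncMaster(_sum_callback)
140 |     pipe = master.register_slave(1)
141 |     t = threading.Thread(target=lambda: print('slave received:', pipe.run_slave(2)))
142 |     t.start()
143 |     print('master received:', master.run_master(1))  # 3 (= 1 + 2)
144 |     t.join()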
--------------------------------------------------------------------------------
/networks/backbone/mobilenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | import math
5 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | import torch.utils.model_zoo as model_zoo
7 |
8 | def conv_bn(inp, oup, stride, BatchNorm):
9 | return nn.Sequential(
10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
11 | BatchNorm(oup),
12 | nn.ReLU6(inplace=True)
13 | )
14 |
15 |
16 | def fixed_padding(inputs, kernel_size, dilation):
17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
18 | pad_total = kernel_size_effective - 1
19 | pad_beg = pad_total // 2
20 | pad_end = pad_total - pad_beg
21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
22 | return padded_inputs
23 |
24 |
25 | class InvertedResidual(nn.Module):
26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm):
27 | super(InvertedResidual, self).__init__()
28 | self.stride = stride
29 | assert stride in [1, 2]
30 |
31 | hidden_dim = round(inp * expand_ratio)
32 | self.use_res_connect = self.stride == 1 and inp == oup
33 | self.kernel_size = 3
34 | self.dilation = dilation
35 |
36 | if expand_ratio == 1:
37 | self.conv = nn.Sequential(
38 | # dw
39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
40 | BatchNorm(hidden_dim),
41 | nn.ReLU6(inplace=True),
42 | # pw-linear
43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False),
44 | BatchNorm(oup),
45 | )
46 | else:
47 | self.conv = nn.Sequential(
48 | # pw
49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False),
50 | BatchNorm(hidden_dim),
51 | nn.ReLU6(inplace=True),
52 | # dw
53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False),
54 | BatchNorm(hidden_dim),
55 | nn.ReLU6(inplace=True),
56 | # pw-linear
57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False),
58 | BatchNorm(oup),
59 | )
60 |
61 | def forward(self, x):
62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation)
63 | if self.use_res_connect:
64 | x = x + self.conv(x_pad)
65 | else:
66 | x = self.conv(x_pad)
67 | return x
68 |
69 |
70 | class MobileNetV2(nn.Module):
71 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
72 | super(MobileNetV2, self).__init__()
73 | block = InvertedResidual
74 | input_channel = 32
75 | current_stride = 1
76 | rate = 1
77 |         inverted_residual_setting = [
78 | # t, c, n, s
79 | [1, 16, 1, 1],
80 | [6, 24, 2, 2],
81 | [6, 32, 3, 2],
82 | [6, 64, 4, 2],
83 | [6, 96, 3, 1],
84 | [6, 160, 3, 2],
85 | [6, 320, 1, 1],
86 | ]
87 |
88 | # building first layer
89 | input_channel = int(input_channel * width_mult)
90 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
91 | current_stride *= 2
92 | # building inverted residual blocks
93 |         for t, c, n, s in inverted_residual_setting:
94 | if current_stride == output_stride:
95 | stride = 1
96 | dilation = rate
97 | rate *= s
98 | else:
99 | stride = s
100 | dilation = 1
101 | current_stride *= s
102 | output_channel = int(c * width_mult)
103 | for i in range(n):
104 | if i == 0:
105 | self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
106 | else:
107 | self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
108 | input_channel = output_channel
109 | self.features = nn.Sequential(*self.features)
110 | self._initialize_weights()
111 |
112 | if pretrained:
113 | self._load_pretrained_model()
114 |
115 | self.low_level_features = self.features[0:4]
116 | self.high_level_features = self.features[4:]
117 |
118 | def forward(self, x):
119 | low_level_feat = self.low_level_features(x)
120 | x = self.high_level_features(low_level_feat)
121 | return x, low_level_feat
122 |
123 | def _load_pretrained_model(self):
124 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
125 | model_dict = {}
126 | state_dict = self.state_dict()
127 | for k, v in pretrain_dict.items():
128 | if k in state_dict:
129 | model_dict[k] = v
130 | state_dict.update(model_dict)
131 | self.load_state_dict(state_dict)
132 |
133 | def _initialize_weights(self):
134 | for m in self.modules():
135 | if isinstance(m, nn.Conv2d):
136 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
137 | # m.weight.data.normal_(0, math.sqrt(2. / n))
138 | torch.nn.init.kaiming_normal_(m.weight)
139 | elif isinstance(m, SynchronizedBatchNorm2d):
140 | m.weight.data.fill_(1)
141 | m.bias.data.zero_()
142 | elif isinstance(m, nn.BatchNorm2d):
143 | m.weight.data.fill_(1)
144 | m.bias.data.zero_()
145 |
146 | if __name__ == "__main__":
147 | input = torch.rand(1, 3, 512, 512)
148 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d)
149 | output, low_level_feat = model(input)
150 | print(output.size())
151 | print(low_level_feat.size())
152 |
--------------------------------------------------------------------------------
/dataloaders/fundus_dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from glob import glob
7 | import random
8 |
9 |
10 | class FundusSegmentation(Dataset):
11 | """
12 |     Fundus segmentation dataset
13 |     comprising 5 domain datasets:
14 |     one for testing, the others for training
15 | """
16 |
17 | def __init__(self,
18 | base_dir,
19 | dataset='refuge',
20 | split='train',
21 | testid=None,
22 | transform=None
23 | ):
24 | """
25 |         :param base_dir: path to the fundus dataset root directory
26 | :param split: train/val
27 | :param transform: transform to apply
28 | """
29 | # super().__init__()
30 | self._base_dir = base_dir
31 | self.image_list = []
32 | self.split = split
33 |
34 | self.image_pool = []
35 | self.label_pool = []
36 | self.img_name_pool = []
37 |
38 | self._image_dir = os.path.join(self._base_dir, dataset, split, 'image')
39 | print(self._image_dir)
40 | imagelist = glob(self._image_dir + "/*.png")
41 | for image_path in imagelist:
42 | gt_path = image_path.replace('image', 'mask')
43 | self.image_list.append({'image': image_path, 'label': gt_path, 'id': testid})
44 |
45 | self.transform = transform
46 | # self._read_img_into_memory()
47 | # Display stats
48 | print('Number of images in {}: {:d}'.format(split, len(self.image_list)))
49 |
50 | def __len__(self):
51 | return len(self.image_list)
52 |
53 | def __getitem__(self, index):
54 |
55 | _img = Image.open(self.image_list[index]['image']).convert('RGB')
56 | _target = Image.open(self.image_list[index]['label'])
57 |         if _target.mode == 'RGB':
58 | _target = _target.convert('L')
59 | _img_name = self.image_list[index]['image'].split('/')[-1]
60 |
61 | # _img = self.image_pool[index]
62 | # _target = self.label_pool[index]
63 | # _img_name = self.img_name_pool[index]
64 | anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name}
65 |
66 | if self.transform is not None:
67 | anco_sample = self.transform(anco_sample)
68 |
69 | return anco_sample
70 |
71 | def _read_img_into_memory(self):
72 |
73 | img_num = len(self.image_list)
74 | for index in range(img_num):
75 | self.image_pool.append(Image.open(self.image_list[index]['image']).convert('RGB'))
76 | _target = Image.open(self.image_list[index]['label'])
77 |             if _target.mode == 'RGB':
78 | _target = _target.convert('L')
79 | self.label_pool.append(_target)
80 | _img_name = self.image_list[index]['image'].split('/')[-1]
81 | self.img_name_pool.append(_img_name)
82 |
83 |
84 | def __str__(self):
85 | return 'Fundus(split=' + str(self.split) + ')'
86 |
87 |
88 | class FundusSegmentation_2transform(Dataset):
89 | """
90 |     Fundus segmentation dataset
91 |     comprising 5 domain datasets:
92 |     one for testing, the others for training
93 | """
94 |
95 | def __init__(self,
96 | base_dir,
97 | dataset='refuge',
98 | split='train',
99 | testid=None,
100 | transform_weak=None,
101 | transform_strong=None
102 | ):
103 | """
104 |         :param base_dir: path to the fundus dataset root directory
105 | :param split: train/val
106 |         :param transform_weak, transform_strong: weak/strong transforms to apply
107 | """
108 | # super().__init__()
109 | self._base_dir = base_dir
110 | self.image_list = []
111 | self.split = split
112 |
113 | self.image_pool = []
114 | self.label_pool = []
115 | self.img_name_pool = []
116 |
117 | self._image_dir = os.path.join(self._base_dir, dataset, split, 'image')
118 | print(self._image_dir)
119 | imagelist = glob(self._image_dir + "/*.png")
120 | for image_path in imagelist:
121 | gt_path = image_path.replace('image', 'mask')
122 | self.image_list.append({'image': image_path, 'label': gt_path, 'id': testid})
123 |
124 | self.transform_weak = transform_weak
125 | self.transform_strong = transform_strong
126 | # self._read_img_into_memory()
127 | # Display stats
128 | print('Number of images in {}: {:d}'.format(split, len(self.image_list)))
129 |
130 | def __len__(self):
131 | return len(self.image_list)
132 |
133 | def __getitem__(self, index):
134 |
135 | _img = Image.open(self.image_list[index]['image']).convert('RGB')
136 | _target = Image.open(self.image_list[index]['label'])
137 |         if _target.mode == 'RGB':
138 | _target = _target.convert('L')
139 | _img_name = self.image_list[index]['image'].split('/')[-1]
140 |
141 | # _img = self.image_pool[index]
142 | # _target = self.label_pool[index]
143 | # _img_name = self.img_name_pool[index]
144 | anco_sample = {'image': _img, 'label': _target, 'img_name': _img_name}
145 |
146 | anco_sample_weak_aug = self.transform_weak(anco_sample)
147 |
148 | anco_sample_strong_aug = self.transform_strong(anco_sample)
149 |
150 | return anco_sample_weak_aug, anco_sample_strong_aug
151 |
152 | def _read_img_into_memory(self):
153 |
154 | img_num = len(self.image_list)
155 | for index in range(img_num):
156 | self.image_pool.append(Image.open(self.image_list[index]['image']).convert('RGB'))
157 | _target = Image.open(self.image_list[index]['label'])
158 |             if _target.mode == 'RGB':
159 | _target = _target.convert('L')
160 | self.label_pool.append(_target)
161 | _img_name = self.image_list[index]['image'].split('/')[-1]
162 | self.img_name_pool.append(_img_name)
163 |
164 |
165 | def __str__(self):
166 | return 'Fundus(split=' + str(self.split) + ')'
167 |
168 |
169 |
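170 | if __name__ == '__main__':
171 |     # Hedged usage sketch: the loader expects
172 |     # <base_dir>/<dataset>/<split>/image/*.png with matching masks under the
173 |     # sibling mask/ directory. The paths below are placeholders taken from
174 |     # train_source.py's defaults; adjust them to your data.
175 |     ds = FundusSegmentation(base_dir='../Datasets/Fundus', dataset='Domain3', split='train/ROIs')
176 |     sample = ds[0]
177 |     print(sample['img_name'], sample['image'].size, sample['label'].size)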
--------------------------------------------------------------------------------
/networks/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | import torch.utils.model_zoo as model_zoo
4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 | class Bottleneck(nn.Module):
7 | expansion = 4
8 |
9 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None):
10 | super(Bottleneck, self).__init__()
11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
12 | self.bn1 = BatchNorm(planes)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
14 | dilation=dilation, padding=dilation, bias=False)
15 | self.bn2 = BatchNorm(planes)
16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
17 | self.bn3 = BatchNorm(planes * 4)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.downsample = downsample
20 | self.stride = stride
21 | self.dilation = dilation
22 |
23 | def forward(self, x):
24 | residual = x
25 |
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 |
30 | out = self.conv2(out)
31 | out = self.bn2(out)
32 | out = self.relu(out)
33 |
34 | out = self.conv3(out)
35 | out = self.bn3(out)
36 |
37 | if self.downsample is not None:
38 | residual = self.downsample(x)
39 |
40 | out += residual
41 | out = self.relu(out)
42 |
43 | return out
44 |
45 | class ResNet(nn.Module):
46 |
47 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True):
48 | self.inplanes = 64
49 | super(ResNet, self).__init__()
50 | blocks = [1, 2, 4]
51 | if output_stride == 16:
52 | strides = [1, 2, 2, 1]
53 | dilations = [1, 1, 1, 2]
54 | elif output_stride == 8:
55 | strides = [1, 2, 1, 1]
56 | dilations = [1, 1, 2, 4]
57 | else:
58 | raise NotImplementedError
59 |
60 | # Modules
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = BatchNorm(64)
64 | self.relu = nn.ReLU(inplace=True)
65 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
66 |
67 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
68 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
69 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
70 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
71 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
72 | self._init_weight()
73 |
74 | if pretrained:
75 | self._load_pretrained_model()
76 |
77 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
78 | downsample = None
79 | if stride != 1 or self.inplanes != planes * block.expansion:
80 | downsample = nn.Sequential(
81 | nn.Conv2d(self.inplanes, planes * block.expansion,
82 | kernel_size=1, stride=stride, bias=False),
83 | BatchNorm(planes * block.expansion),
84 | )
85 |
86 | layers = []
87 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
88 | self.inplanes = planes * block.expansion
89 | for i in range(1, blocks):
90 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
91 |
92 | return nn.Sequential(*layers)
93 |
94 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
95 | downsample = None
96 | if stride != 1 or self.inplanes != planes * block.expansion:
97 | downsample = nn.Sequential(
98 | nn.Conv2d(self.inplanes, planes * block.expansion,
99 | kernel_size=1, stride=stride, bias=False),
100 | BatchNorm(planes * block.expansion),
101 | )
102 |
103 | layers = []
104 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
105 | downsample=downsample, BatchNorm=BatchNorm))
106 | self.inplanes = planes * block.expansion
107 | for i in range(1, len(blocks)):
108 | layers.append(block(self.inplanes, planes, stride=1,
109 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
110 |
111 | return nn.Sequential(*layers)
112 |
113 | def forward(self, input):
114 | x = self.conv1(input)
115 | x = self.bn1(x)
116 | x = self.relu(x)
117 | x = self.maxpool(x)
118 |
119 | x = self.layer1(x)
120 | low_level_feat = x
121 | x = self.layer2(x)
122 | x = self.layer3(x)
123 | x = self.layer4(x)
124 | return x, low_level_feat
125 |
126 | def _init_weight(self):
127 | for m in self.modules():
128 | if isinstance(m, nn.Conv2d):
129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
130 | m.weight.data.normal_(0, math.sqrt(2. / n))
131 | elif isinstance(m, SynchronizedBatchNorm2d):
132 | m.weight.data.fill_(1)
133 | m.bias.data.zero_()
134 | elif isinstance(m, nn.BatchNorm2d):
135 | m.weight.data.fill_(1)
136 | m.bias.data.zero_()
137 |
138 | def _load_pretrained_model(self):
139 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth')
140 | model_dict = {}
141 | state_dict = self.state_dict()
142 | for k, v in pretrain_dict.items():
143 | if k in state_dict:
144 | model_dict[k] = v
145 | state_dict.update(model_dict)
146 | self.load_state_dict(state_dict)
147 |
148 | def ResNet101(output_stride, BatchNorm, pretrained=True):
149 | """Constructs a ResNet-101 model.
150 | Args:
151 | pretrained (bool): If True, returns a model pre-trained on ImageNet
152 | """
153 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained)
154 | return model
155 |
156 | if __name__ == "__main__":
157 | import torch
158 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8)
159 | input = torch.rand(1, 3, 512, 512)
160 | output, low_level_feat = model(input)
161 | print(output.size())
162 | print(low_level_feat.size())
163 |
--------------------------------------------------------------------------------
/train_source.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(
3 | formatter_class=argparse.ArgumentDefaultsHelpFormatter,
4 | )
5 | parser.add_argument('-g', '--gpu', type=str, default='7', help='gpu id')
6 | parser.add_argument('--resume', default=None, help='checkpoint path')
7 | parser.add_argument(
8 |     '--dataset', type=str, default='Domain3', help='dataset folder containing the image ROIs'
9 | )
10 | parser.add_argument(
11 | '--model', type=str, default='Deeplab', help='Deeplab'
12 | )
13 | parser.add_argument(
14 | '--batch-size', type=int, default=8, help='batch size for training the model'
15 | )
16 | parser.add_argument(
17 | '--group-num', type=int, default=1, help='group number for group normalization'
18 | )
19 | parser.add_argument(
20 | '--max-epoch', type=int, default=200, help='max epoch'
21 | )
22 | parser.add_argument(
23 | '--stop-epoch', type=int, default=200, help='stop epoch'
24 | )
25 | parser.add_argument(
26 |     '--warmup-epoch', type=int, default=-1, help='warmup epoch at which GAN training begins'
27 | )
28 |
29 | parser.add_argument(
30 | '--interval-validate', type=int, default=5, help='interval epoch number to validate the model'
31 | )
32 | parser.add_argument(
33 | '--interval-save', type=int, default=10, help='interval epoch number to save the model'
34 | )
35 | parser.add_argument(
36 | '--lr', type=float, default=1e-3, help='learning rate',
37 | )
38 | parser.add_argument(
39 | '--lr-decrease-rate', type=float, default=0.98, help='ratio multiplied to initial lr',
40 | )
41 | parser.add_argument(
42 | '--lr-decrease-epoch', type=int, default=1, help='interval epoch number for lr decrease',
43 | )
44 | parser.add_argument(
45 | '--weight-decay', type=float, default=0, help='weight decay',
46 | )
47 | parser.add_argument(
48 | '--momentum', type=float, default=0.99, help='momentum',
49 | )
50 | parser.add_argument(
51 | '--data-dir',
52 | default='../Datasets/Fundus',
53 | help='data root path'
54 | )
55 | parser.add_argument(
56 | '--out-stride',
57 | type=int,
58 | default=16,
59 | help='out-stride of deeplabv3+',
60 | )
61 | parser.add_argument(
62 |     '--sync-bn',
63 |     type=lambda s: str(s).lower() in ('true', '1', 'yes'),
64 |     default=True,
65 |     help='sync-bn in deeplabv3+ (argparse type=bool would treat any non-empty string as True)',
66 | )
67 | parser.add_argument(
68 |     '--freeze-bn',
69 |     type=lambda s: str(s).lower() in ('true', '1', 'yes'),
70 |     default=False,
71 |     help='freeze batch normalization of deeplabv3+',
72 | )
73 | parser.add_argument('--no-augmentation', action='store_true')
74 |
75 | args = parser.parse_args()
76 |
77 | from datetime import datetime
78 | import os
79 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
80 | import os.path as osp
81 |
82 | # PyTorch includes
83 | import torch
84 | from torchvision import transforms
85 | from torch.utils.data import DataLoader
86 | import yaml
87 | from train_process import Trainer
88 |
89 | # Custom includes
90 | from dataloaders import fundus_dataloader
91 | from dataloaders import custom_transforms as trans
92 | from networks.deeplabv3 import *
93 |
94 |
95 | here = osp.dirname(osp.abspath(__file__))
96 |
97 | def main():
98 | now = datetime.now()
99 | args.out = osp.join(here, 'logs_train', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f'))
100 |
101 | os.makedirs(args.out)
102 | with open(osp.join(args.out, 'config.yaml'), 'w') as f:
103 | yaml.safe_dump(args.__dict__, f, default_flow_style=False)
104 |
105 | multiply_gpu = False
106 |     if ',' in args.gpu:
107 | multiply_gpu = True
108 | cuda = torch.cuda.is_available()
109 |
110 | torch.manual_seed(42)
111 | if cuda:
112 | torch.cuda.manual_seed(42)
113 |
114 | # 1. dataset
115 | composed_transforms_tr = transforms.Compose([
116 | trans.RandomScaleCrop(512),
117 | trans.RandomRotate(),
118 | trans.RandomFlip(),
119 | trans.elastic_transform(),
120 | trans.add_salt_pepper_noise(),
121 | trans.adjust_light(),
122 | trans.eraser(),
123 | trans.Normalize_tf(),
124 | trans.ToTensor()
125 | ])
126 |
127 | composed_transforms_ts = transforms.Compose([
128 | # fundus_trans.RandomCrop(512),
129 | trans.Resize(512),
130 |         trans.Normalize_tf(),  # separates the raw ground-truth mask into 2 masks
131 | trans.ToTensor()
132 | ])
133 |
134 | if args.no_augmentation:
135 | domain = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset,
136 | split='train/ROIs', transform=composed_transforms_ts)
137 | else:
138 | domain = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset,
139 | split='train/ROIs', transform=composed_transforms_tr)
140 | domain_loader = DataLoader(domain, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True)
141 |
142 | domain_val = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset,
143 | split='test/ROIs', transform=composed_transforms_ts)
144 | domain_loader_val = DataLoader(domain_val, batch_size=args.batch_size, shuffle=False, num_workers=2,
145 | pin_memory=True)
146 |
147 | # 2. model
148 | model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride,
149 | sync_bn=args.sync_bn, freeze_bn=args.freeze_bn)
150 | if cuda:
151 | model = model.cuda()
152 |
153 | start_epoch = 0
154 | start_iteration = 0
155 |
156 | # 3. optimizer
157 | if multiply_gpu:
158 | model = torch.nn.DataParallel(model, device_ids=[0, 1])
159 |
160 | optim = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.weight_decay)
161 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=args.lr_decrease_epoch, gamma=args.lr_decrease_rate)
162 |
163 | if args.resume:
164 | checkpoint = torch.load(args.resume)
165 | pretrained_dict = checkpoint['model_state_dict']
166 | model_dict = model.state_dict()
167 | # 1. filter out unnecessary keys
168 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
169 | # 2. overwrite entries in the existing state dict
170 | model_dict.update(pretrained_dict)
171 | # 3. load the new state dict
172 | model.load_state_dict(model_dict)
173 |
174 |
175 | start_epoch = checkpoint['epoch'] + 1
176 | start_iteration = checkpoint['iteration'] + 1
177 | optim.load_state_dict(checkpoint['optim_state_dict'])
178 |
179 | trainer = Trainer.Trainer(
180 | cuda=cuda,
181 | multiply_gpu=multiply_gpu,
182 | model=model,
183 | optimizer=optim,
184 | scheduler=scheduler,
185 | lr=args.lr,
186 | val_loader=domain_loader_val,
187 | domain_loader=domain_loader,
188 | out=args.out,
189 | max_epoch=args.max_epoch,
190 | stop_epoch=args.stop_epoch,
191 | interval_validate=args.interval_validate,
192 | interval_save=args.interval_save,
193 | batch_size=args.batch_size,
194 | warmup_epoch=args.warmup_epoch
195 | )
196 | trainer.epoch = start_epoch
197 | trainer.iteration = start_iteration
198 | trainer.train()
199 |
200 | if __name__ == '__main__':
201 | main()
202 |
--------------------------------------------------------------------------------
/networks/GAN.py:
--------------------------------------------------------------------------------
1 | # camera-ready
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch
6 |
7 |
8 | class Discriminator(nn.Module):
9 | def __init__(self, ):
10 | super(Discriminator, self).__init__()
11 |
12 | filter_num_list = [4096, 2048, 1024, 1]
13 |
14 | self.fc1 = nn.Linear(24576, filter_num_list[0])
15 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2)
16 | self.fc2 = nn.Linear(filter_num_list[0], filter_num_list[1])
17 | self.fc3 = nn.Linear(filter_num_list[1], filter_num_list[2])
18 | self.fc4 = nn.Linear(filter_num_list[2], filter_num_list[3])
19 |
20 | # self.sigmoid = nn.Sigmoid()
21 | self._initialize_weights()
22 |
23 |
24 | def _initialize_weights(self):
25 |
26 | for m in self.modules():
27 | if isinstance(m, nn.Conv2d):
28 | m.weight.data.normal_(0.0, 0.02)
29 | if m.bias is not None:
30 | m.bias.data.zero_()
31 |
32 | if isinstance(m, nn.ConvTranspose2d):
33 | m.weight.data.normal_(0.0, 0.02)
34 | if m.bias is not None:
35 | m.bias.data.zero_()
36 |
37 | if isinstance(m, nn.Linear):
38 | m.weight.data.normal_(0.0, 0.02)
39 | if m.bias is not None:
40 | # m.bias.data.copy_(1.0)
41 | m.bias.data.zero_()
42 |
43 |
44 | def forward(self, x):
45 |
46 | x = self.leakyrelu(self.fc1(x))
47 | x = self.leakyrelu(self.fc2(x))
48 | x = self.leakyrelu(self.fc3(x))
49 | x = self.fc4(x)
50 | return x
51 |
52 |
53 | class OutputDiscriminator(nn.Module):
54 |     def __init__(self):
55 | super(OutputDiscriminator, self).__init__()
56 |
57 | filter_num_list = [64, 128, 256, 512, 1]
58 |
59 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False)
60 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False)
61 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False)
62 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False)
63 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False)
64 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2)
65 | # self.sigmoid = nn.Sigmoid()
66 | self._initialize_weights()
67 |
68 |
69 | def _initialize_weights(self):
70 | for m in self.modules():
71 | if isinstance(m, nn.Conv2d):
72 | m.weight.data.normal_(0.0, 0.02)
73 | if m.bias is not None:
74 | m.bias.data.zero_()
75 |
76 |
77 | def forward(self, x):
78 | x = self.leakyrelu(self.conv1(x))
79 | x = self.leakyrelu(self.conv2(x))
80 | x = self.leakyrelu(self.conv3(x))
81 | x = self.leakyrelu(self.conv4(x))
82 | x = self.conv5(x)
83 | return x
84 |
85 |
86 | class UncertaintyDiscriminator(nn.Module):
87 |     def __init__(self):
88 | super(UncertaintyDiscriminator, self).__init__()
89 |
90 | filter_num_list = [64, 128, 256, 512, 1]
91 |
92 | self.conv1 = nn.Conv2d(2, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False)
93 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False)
94 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False)
95 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False)
96 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False)
97 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2)
98 | # self.sigmoid = nn.Sigmoid()
99 | self._initialize_weights()
100 |
101 |
102 | def _initialize_weights(self):
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | m.weight.data.normal_(0.0, 0.02)
106 | if m.bias is not None:
107 | m.bias.data.zero_()
108 |
109 |
110 | def forward(self, x):
111 | x = self.leakyrelu(self.conv1(x))
112 | x = self.leakyrelu(self.conv2(x))
113 | x = self.leakyrelu(self.conv3(x))
114 | x = self.leakyrelu(self.conv4(x))
115 | x = self.conv5(x)
116 | return x
117 |
118 | class BoundaryDiscriminator(nn.Module):
119 |     def __init__(self):
120 | super(BoundaryDiscriminator, self).__init__()
121 |
122 | filter_num_list = [64, 128, 256, 512, 1]
123 |
124 | self.conv1 = nn.Conv2d(1, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False)
125 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False)
126 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False)
127 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False)
128 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False)
129 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2)
130 | # self.sigmoid = nn.Sigmoid()
131 | self._initialize_weights()
132 |
133 |
134 | def _initialize_weights(self):
135 | for m in self.modules():
136 | if isinstance(m, nn.Conv2d):
137 | m.weight.data.normal_(0.0, 0.02)
138 | if m.bias is not None:
139 | m.bias.data.zero_()
140 |
141 |
142 | def forward(self, x):
143 | x = self.leakyrelu(self.conv1(x))
144 | x = self.leakyrelu(self.conv2(x))
145 | x = self.leakyrelu(self.conv3(x))
146 | x = self.leakyrelu(self.conv4(x))
147 | x = self.conv5(x)
148 | return x
149 |
150 | class BoundaryEntDiscriminator(nn.Module):
151 |     def __init__(self):
152 | super(BoundaryEntDiscriminator, self).__init__()
153 |
154 | filter_num_list = [64, 128, 256, 512, 1]
155 |
156 | self.conv1 = nn.Conv2d(3, filter_num_list[0], kernel_size=4, stride=2, padding=2, bias=False)
157 | self.conv2 = nn.Conv2d(filter_num_list[0], filter_num_list[1], kernel_size=4, stride=2, padding=2, bias=False)
158 | self.conv3 = nn.Conv2d(filter_num_list[1], filter_num_list[2], kernel_size=4, stride=2, padding=2, bias=False)
159 | self.conv4 = nn.Conv2d(filter_num_list[2], filter_num_list[3], kernel_size=4, stride=2, padding=2, bias=False)
160 | self.conv5 = nn.Conv2d(filter_num_list[3], filter_num_list[4], kernel_size=4, stride=2, padding=2, bias=False)
161 | self.leakyrelu = nn.LeakyReLU(negative_slope=0.2)
162 | # self.sigmoid = nn.Sigmoid()
163 | self._initialize_weights()
164 |
165 |
166 | def _initialize_weights(self):
167 | for m in self.modules():
168 | if isinstance(m, nn.Conv2d):
169 | m.weight.data.normal_(0.0, 0.02)
170 | if m.bias is not None:
171 | m.bias.data.zero_()
172 |
173 |
174 | def forward(self, x):
175 | x = self.leakyrelu(self.conv1(x))
176 | x = self.leakyrelu(self.conv2(x))
177 | x = self.leakyrelu(self.conv3(x))
178 | x = self.leakyrelu(self.conv4(x))
179 | x = self.conv5(x)
180 | return x
181 |
182 |
--------------------------------------------------------------------------------
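Note: every discriminator above is a fully convolutional critic that returns a spatial grid of raw logits (the `nn.Sigmoid` is deliberately commented out), so it pairs naturally with a logits-based BCE objective. A quick shape check, assuming the repo root is on `PYTHONPATH`:

```python
import torch
from networks.GAN import OutputDiscriminator

D = OutputDiscriminator()
prob_maps = torch.rand(1, 2, 512, 512)  # 2-channel cup/disc probability maps
logits = D(prob_maps)                   # five stride-2 convs (k=4, p=2) roughly halve H and W each time
print(logits.shape)                     # torch.Size([1, 1, 17, 17]) for a 512x512 input
```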
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import medpy.metric.binary as medmetric
4 |
5 | bce = torch.nn.BCEWithLogitsLoss(reduction='none')
6 |
7 | def _upscan(f):
8 | for i, fi in enumerate(f):
9 | if fi == np.inf: continue
10 | for j in range(1,i+1):
11 | x = fi+j*j
12 | if f[i-j] < x: break
13 | f[i-j] = x
14 |
15 |
16 | def dice_coefficient_numpy(binary_segmentation, binary_gt_label):
17 | '''
18 |     Compute the Dice coefficient between two binary segmentations.
19 |     The Dice coefficient is defined here: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
20 | Input:
21 | binary_segmentation: binary 2D numpy array representing the region of interest as segmented by the algorithm
22 | binary_gt_label: binary 2D numpy array representing the region of interest as provided in the database
23 | Output:
24 | dice_value: Dice coefficient between the segmentation and the ground truth
25 | '''
26 |
27 | # turn all variables to booleans, just in case
28 |     binary_segmentation = np.asarray(binary_segmentation, dtype=bool)
29 |     binary_gt_label = np.asarray(binary_gt_label, dtype=bool)
30 |
31 | # compute the intersection
32 | intersection = np.logical_and(binary_segmentation, binary_gt_label)
33 |
34 | # count the number of True pixels in the binary segmentation
35 | # segmentation_pixels = float(np.sum(binary_segmentation.flatten()))
36 | segmentation_pixels = np.sum(binary_segmentation.astype(float), axis=(1,2))
37 | # same for the ground truth
38 | # gt_label_pixels = float(np.sum(binary_gt_label.flatten()))
39 | gt_label_pixels = np.sum(binary_gt_label.astype(float), axis=(1,2))
40 | # same for the intersection
41 | intersection = np.sum(intersection.astype(float), axis=(1,2))
42 |
43 | # compute the Dice coefficient
44 | dice_value = (2 * intersection + 1.0) / (1.0 + segmentation_pixels + gt_label_pixels)
45 |
46 | # return it
47 | return dice_value
48 |
49 |
50 | def dice_coefficient_numpy_3D(binary_segmentation, binary_gt_label):
51 | '''
52 |     Compute the Dice coefficient between two binary segmentations.
53 |     The Dice coefficient is defined here: https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient
54 | Input:
55 | binary_segmentation: binary 3D numpy array representing the region of interest as segmented by the algorithm
56 | binary_gt_label: binary 3D numpy array representing the region of interest as provided in the database
57 | Output:
58 | dice_value: Dice coefficient between the segmentation and the ground truth
59 | '''
60 |
61 | # turn all variables to booleans, just in case
62 |     binary_segmentation = np.asarray(binary_segmentation, dtype=bool)
63 |     binary_gt_label = np.asarray(binary_gt_label, dtype=bool)
64 |
65 | # compute the intersection
66 | intersection = np.logical_and(binary_segmentation, binary_gt_label)
67 |
68 | # count the number of True pixels in the binary segmentation
69 | # segmentation_pixels = float(np.sum(binary_segmentation.flatten()))
70 | segmentation_pixels = np.sum(binary_segmentation.astype(float), axis=(0,1,2))
71 | # same for the ground truth
72 | # gt_label_pixels = float(np.sum(binary_gt_label.flatten()))
73 | gt_label_pixels = np.sum(binary_gt_label.astype(float), axis=(0,1,2))
74 | # same for the intersection
75 | intersection = np.sum(intersection.astype(float), axis=(0,1,2))
76 |
77 | # compute the Dice coefficient
78 | dice_value = (2 * intersection + 1.0) / (1.0 + segmentation_pixels + gt_label_pixels)
79 |
80 | # return it
81 | return dice_value
82 |
83 |
84 | def dice_numpy_medpy(binary_segmentation, binary_gt_label):
85 |
86 | # turn all variables to booleans, just in case
87 | binary_segmentation = np.asarray(binary_segmentation)
88 | binary_gt_label = np.asarray(binary_gt_label)
89 |
90 | return medmetric.dc(binary_segmentation, binary_gt_label)
91 |
92 |
93 | # if get_hd:
94 | # if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0:
95 | # return medmetric.assd(binary_segmentation, binary_gt_label)
96 | # # return medmetric.hd(binary_segmentation, binary_gt_label)
97 | # else:
98 | # return np.nan
99 | # else:
100 | # return 0.0
101 |
102 |
103 | def assd_numpy(binary_segmentation, binary_gt_label):
104 |
105 | # turn all variables to booleans, just in case
106 | binary_segmentation = np.asarray(binary_segmentation)
107 | binary_gt_label = np.asarray(binary_gt_label)
108 |
109 | if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0:
110 | return medmetric.assd(binary_segmentation, binary_gt_label)
111 | else:
112 | return -1
113 |
114 |
115 | def hd_numpy(binary_segmentation, binary_gt_label):
116 |
117 | # turn all variables to booleans, just in case
118 | binary_segmentation = np.asarray(binary_segmentation)
119 | binary_gt_label = np.asarray(binary_gt_label)
120 |
121 | if np.sum(binary_segmentation) > 0 and np.sum(binary_gt_label) > 0:
122 | return medmetric.hd(binary_segmentation, binary_gt_label)
123 | else:
124 | return -1
125 |
126 |
127 | def dice_coeff(pred, target):
128 |     """Compute the hard Dice score for real-valued pred and target tensors.
129 |     Note: the 0.5 thresholding below is not differentiable (evaluation only).
130 | pred: tensor with first dimension as batch
131 | target: tensor with first dimension as batch
132 | """
133 |
134 | target = target.data.cpu()
135 | pred = torch.sigmoid(pred)
136 | pred = pred.data.cpu()
137 | pred[pred > 0.5] = 1
138 | pred[pred <= 0.5] = 0
139 |
140 | return dice_coefficient_numpy(pred, target)
141 |
142 | def dice_coeff_2label(pred, target):
143 |     """Compute hard Dice scores (cup, disc) for real-valued pred and target tensors.
144 |     Note: the 0.75 thresholding below is not differentiable (evaluation only).
145 | pred: tensor with first dimension as batch
146 | target: tensor with first dimension as batch
147 | """
148 | target = target.data.cpu()
149 | pred = torch.sigmoid(pred)
150 | pred = pred.data.cpu()
151 | pred[pred > 0.75] = 1
152 | pred[pred <= 0.75] = 0
153 | return dice_coefficient_numpy(pred[:, 0, ...], target[:, 0, ...]), dice_coefficient_numpy(pred[:, 1, ...],
154 | target[:, 1, ...])
155 |
156 |
157 | def dice_coeff_4label(pred, target):
158 |     """Compute per-class hard Dice scores for a 4-class prediction.
159 |     Note: the argmax below is not differentiable (evaluation only).
160 | pred: tensor with first dimension as batch
161 | target: tensor with first dimension as batch
162 | """
163 | y_hat = torch.argmax(pred, dim=1, keepdim=True)
164 | pred_label = torch.zeros(pred.size()) # bs*4*W*H
165 | if torch.cuda.is_available():
166 | pred_label = pred_label.cuda()
167 | pred_label = pred_label.scatter_(1, y_hat, 1) # one-hot label
168 | target = target.data.cpu()
169 | pred_label = pred_label.data.cpu()
170 | return (dice_coefficient_numpy(pred_label[:, 0, ...], target[:, 0, ...]),
171 | dice_coefficient_numpy(pred_label[:, 1, ...], target[:, 1, ...]),
172 | dice_coefficient_numpy(pred_label[:, 2, ...], target[:, 2, ...]),
173 | dice_coefficient_numpy(pred_label[:, 3, ...], target[:, 3, ...]))
174 |
175 |
176 | def DiceLoss(input, target):
177 | '''
178 |     Soft Dice loss, in tensor format.
179 |     :param input: predicted probability tensor
180 |     :param target: ground-truth tensor of the same shape
181 |     :return: 1 - Dice (scalar)
182 | '''
183 | smooth = 1.
184 | iflat = input.contiguous().view(-1)
185 | tflat = target.contiguous().view(-1)
186 | intersection = (iflat * tflat).sum()
187 |
188 | return 1 - ((2. * intersection + smooth) /
189 | (iflat.sum() + tflat.sum() + smooth))
190 |
191 |
192 | def assd_compute(pred, target):
193 | target = target.data.cpu()
194 | pred = torch.sigmoid(pred)
195 | pred = pred.data.cpu()
196 | pred[pred > 0.75] = 1
197 | pred[pred <= 0.75] = 0
198 |
199 | assd = np.zeros([pred.shape[0], pred.shape[1]])
200 | for i in range(pred.shape[0]):
201 | for j in range(pred.shape[1]):
202 | assd[i][j] = assd_numpy(pred[i, j, ...], target[i, j, ...])
203 |
204 | return assd
205 |
--------------------------------------------------------------------------------
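Note: a worked example for `dice_coefficient_numpy` above, showing how the `+1.0` smoothing terms enter (they keep the score finite when both masks are empty):

```python
import numpy as np
from utils.metrics import dice_coefficient_numpy

pred = np.zeros((1, 4, 4), dtype=bool)  # batch of one 4x4 mask
gt = np.zeros((1, 4, 4), dtype=bool)
pred[0, :2, :2] = True                  # 4 predicted pixels
gt[0, :2, :] = True                     # 8 ground-truth pixels, 4 overlapping
# Dice = (2*4 + 1) / (1 + 4 + 8) = 9/13 ~= 0.692
print(dice_coefficient_numpy(pred, gt))  # [0.69230769]
```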
/train_process/Trainer.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import os.path as osp
4 | import timeit
5 | from torchvision.utils import make_grid
6 | import time
7 |
8 | import numpy as np
9 | import pytz
10 | import torch
11 | import torch.nn.functional as F
12 |
13 | from tensorboardX import SummaryWriter
14 |
15 | import tqdm
16 | import socket
17 | from utils.metrics import *
18 | from utils.Utils import *
19 |
20 | bceloss = torch.nn.BCELoss()
21 | mseloss = torch.nn.MSELoss()
22 |
23 | def get_lr(optimizer):
24 | for param_group in optimizer.param_groups:
25 | return param_group['lr']
26 |
27 | class Trainer(object):
28 |
29 | def __init__(self, cuda, multiply_gpu, model, optimizer, scheduler, val_loader, domain_loader, out, max_epoch, stop_epoch=None,
30 | lr=1e-3, interval_validate=10, interval_save=10, batch_size=8, warmup_epoch=10):
31 | self.cuda = cuda
32 | self.multiply_gpu = multiply_gpu
33 | self.warmup_epoch = warmup_epoch
34 | self.model = model
35 | self.optim = optimizer
36 | self.scheduler = scheduler
37 | self.lr = lr
38 | # self.lr_decrease_rate = lr_decrease_rate
39 | # self.lr_decrease_epoch = lr_decrease_epoch
40 | self.batch_size = batch_size
41 |
42 | self.val_loader = val_loader
43 | self.domain_loader = domain_loader
44 | self.time_zone = 'Asia/Shanghai'
45 | self.timestamp_start = datetime.now(pytz.timezone(self.time_zone))
46 |
47 | self.interval_validate = interval_validate
48 | self.interval_save = interval_save
49 |
50 | self.out = out
51 | if not osp.exists(self.out):
52 | os.makedirs(self.out)
53 |
54 | self.log_headers = [
55 | 'epoch',
56 | 'iteration',
57 | 'train/loss_seg',
58 | 'valid/loss_CE',
59 | 'valid/cup_dice',
60 | 'valid/disc_dice',
61 | 'elapsed_time',
62 | 'best_epoch'
63 | ]
64 | if not osp.exists(osp.join(self.out, 'log.csv')):
65 | with open(osp.join(self.out, 'log.csv'), 'w') as f:
66 | f.write(','.join(self.log_headers) + '\n')
67 |
68 | log_dir = os.path.join(self.out, 'tensorboard',
69 | datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
70 | self.writer = SummaryWriter(log_dir=log_dir)
71 |
72 | self.epoch = 0
73 | self.iteration = 0
74 | self.max_epoch = max_epoch
75 | self.stop_epoch = stop_epoch if stop_epoch is not None else max_epoch
76 | self.best_mean_dice = 0.0
77 | self.best_epoch = -1
78 |
79 |
80 | def validate_fundus(self):
81 | training = self.model.training
82 | self.model.eval()
83 |
84 | val_loss = 0.0
85 | val_fundus_dice = {'cup': 0.0, 'disc': 0.0}
86 | data_num_cnt = 0.0
87 | metrics = []
88 | with torch.no_grad():
89 | for batch_idx, sample in tqdm.tqdm(
90 | enumerate(self.val_loader), total=len(self.val_loader),
91 | desc='Valid iteration=%d' % self.iteration, ncols=80,
92 | leave=False):
93 | data = sample['image']
94 | target_map = sample['label']
95 | if self.cuda:
96 | data, target_map = data.cuda(), target_map.cuda()
97 |                 # model call is already inside the outer torch.no_grad() context
98 |                 predictions, _ = self.model(data)
99 |
100 | loss = F.binary_cross_entropy_with_logits(predictions, target_map)
101 | loss_data = loss.data.item()
102 | if np.isnan(loss_data):
103 | raise ValueError('loss is nan while validating')
104 | val_loss += loss_data
105 |
106 | dice_cup, dice_disc = dice_coeff_2label(predictions, target_map)
107 | val_fundus_dice['cup'] += np.sum(dice_cup)
108 | val_fundus_dice['disc'] += np.sum(dice_disc)
109 | data_num_cnt += float(dice_cup.shape[0])
110 |
111 | val_loss /= data_num_cnt
112 | val_fundus_dice['cup'] /= data_num_cnt
113 | val_fundus_dice['disc'] /= data_num_cnt
114 | metrics.append((val_loss, val_fundus_dice['cup'], val_fundus_dice['disc']))
115 |
116 | self.writer.add_scalar('val_data/loss_CE', val_loss, self.epoch * (len(self.domain_loader)))
117 | self.writer.add_scalar('val_data/val_CUP_dice', val_fundus_dice['cup'], self.epoch * (len(self.domain_loader)))
118 | self.writer.add_scalar('val_data/val_DISC_dice', val_fundus_dice['disc'], self.epoch * (len(self.domain_loader)))
119 |
120 | mean_dice = (val_fundus_dice['cup'] + val_fundus_dice['disc'])/2
121 | is_best = mean_dice > self.best_mean_dice
122 | if is_best:
123 | self.best_epoch = self.epoch + 1
124 | self.best_mean_dice = mean_dice
125 |
126 | torch.save({
127 | 'epoch': self.epoch,
128 | 'iteration': self.iteration,
129 | 'arch': self.model.__class__.__name__,
130 | 'optim_state_dict': self.optim.state_dict(),
131 | 'model_state_dict': self.model.module.state_dict() if self.multiply_gpu else self.model.state_dict(),
132 | 'learning_rate_gen': get_lr(self.optim),
133 | 'best_mean_dice': self.best_mean_dice,
134 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % self.best_epoch))
135 | else:
136 | if (self.epoch + 1) % self.interval_save == 0:
137 | torch.save({
138 | 'epoch': self.epoch,
139 | 'iteration': self.iteration,
140 | 'arch': self.model.__class__.__name__,
141 | 'optim_state_dict': self.optim.state_dict(),
142 | 'model_state_dict': self.model.module.state_dict() if self.multiply_gpu else self.model.state_dict(),
143 | 'learning_rate_gen': get_lr(self.optim),
144 | 'best_mean_dice': self.best_mean_dice,
145 | }, osp.join(self.out, 'checkpoint_%d.pth.tar' % (self.epoch + 1)))
146 |
147 | with open(osp.join(self.out, 'log.csv'), 'a') as f:
148 | elapsed_time = (
149 | datetime.now(pytz.timezone(self.time_zone)) -
150 | self.timestamp_start).total_seconds()
151 |             log = [self.epoch, self.iteration] + [''] + list(metrics[0]) + [elapsed_time] + [self.best_epoch]
152 | log = map(str, log)
153 | f.write(','.join(log) + '\n')
154 | self.writer.add_scalar('best_model_epoch', self.best_epoch, self.epoch * (len(self.domain_loader)))
155 | if training:
156 | self.model.train()
157 |
158 |
159 | def train_epoch(self):
160 | self.model.train()
161 | self.running_seg_loss = 0.0
162 |
163 | start_time = timeit.default_timer()
164 | for batch_idx, sample in tqdm.tqdm(
165 | enumerate(self.domain_loader), total=len(self.domain_loader),
166 | desc='Train epoch=%d' % self.epoch, ncols=80, leave=False):
167 |
168 | iteration = batch_idx + self.epoch * len(self.domain_loader)
169 | self.iteration = iteration
170 |
171 | assert self.model.training
172 |
173 | self.optim.zero_grad()
174 |
175 | # train
176 | for param in self.model.parameters():
177 | param.requires_grad = True
178 |
179 | image = sample['image'].cuda()
180 | target_map = sample['label'].cuda()
181 |
182 | pred, _ = self.model(image)
183 | pred = torch.sigmoid(pred)
184 | loss_seg = bceloss(pred, target_map)
185 |
186 |             self.running_seg_loss += loss_seg.item()
187 |             # averaged over the epoch when printed below, not on every iteration
188 |
189 | loss_seg_data = loss_seg.data.item()
190 | if np.isnan(loss_seg_data):
191 | raise ValueError('loss is nan while training')
192 |
193 | loss_seg.backward()
194 | self.optim.step()
195 |
196 | # write image log
197 | if iteration % 30 == 0:
198 | grid_image = make_grid(image[0, ...].clone().cpu().data, 1, normalize=True)
199 | self.writer.add_image('Domain/image', grid_image, iteration)
200 | grid_image = make_grid(target_map[0, 0, ...].clone().cpu().data, 1, normalize=True)
201 | self.writer.add_image('Domain/target_cup', grid_image, iteration)
202 | grid_image = make_grid(target_map[0, 1, ...].clone().cpu().data, 1, normalize=True)
203 | self.writer.add_image('Domain/target_disc', grid_image, iteration)
204 | grid_image = make_grid(pred[0, 0, ...].clone().cpu().data, 1, normalize=True)
205 | self.writer.add_image('Domain/prediction_cup', grid_image, iteration)
206 | grid_image = make_grid(pred[0, 1, ...].clone().cpu().data, 1, normalize=True)
207 | self.writer.add_image('Domain/prediction_disc', grid_image, iteration)
208 |
209 |
210 | self.writer.add_scalar('train/loss_seg', loss_seg_data, iteration)
211 |
212 | with open(osp.join(self.out, 'log.csv'), 'a') as f:
213 | elapsed_time = (
214 | datetime.now(pytz.timezone(self.time_zone)) -
215 | self.timestamp_start).total_seconds()
216 | log = [self.epoch, self.iteration] + [loss_seg_data] + [''] * 3 + [elapsed_time] + ['']
217 | log = map(str, log)
218 | f.write(','.join(log) + '\n')
219 |
220 | stop_time = timeit.default_timer()
221 |
222 | print('\n[Epoch: %d] lr:%f, Average segLoss: %f, Execution time: %.5f\n' %
223 |               (self.epoch, get_lr(self.optim), self.running_seg_loss / len(self.domain_loader), stop_time - start_time))
224 |
225 |
226 | def train(self):
227 | for epoch in tqdm.trange(self.epoch, self.max_epoch,
228 | desc='Train', ncols=80):
229 | self.epoch = epoch
230 | self.train_epoch()
231 | if self.stop_epoch == self.epoch:
232 | print('Stop epoch at %d' % self.stop_epoch)
233 | break
234 |
235 | self.scheduler.step()
236 | self.writer.add_scalar('lr', get_lr(self.optim), self.epoch * (len(self.domain_loader)))
237 |
238 | if (self.epoch+1) % self.interval_validate == 0:
239 | self.validate_fundus()
240 | self.writer.close()
241 |
242 |
243 |
244 |
--------------------------------------------------------------------------------
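Note: reduced to its core, the checkpointing policy in `validate_fundus` above is "save whenever the mean cup/disc Dice improves; otherwise save only every `interval_save` epochs". An illustrative helper (not part of the repo) capturing that rule:

```python
def update_best(mean_dice, state):
    """state carries 'epoch', 'best_mean_dice' and 'best_epoch'; returns True on improvement."""
    if mean_dice > state['best_mean_dice']:
        state['best_mean_dice'] = mean_dice
        state['best_epoch'] = state['epoch'] + 1
        return True   # Trainer then writes checkpoint_<best_epoch>.pth.tar
    return False      # otherwise a periodic save happens every interval_save epochs
```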
/utils/Utils.py:
--------------------------------------------------------------------------------
1 |
2 | # from scipy.misc import imsave
3 | import os.path as osp
4 | import numpy as np
5 | import os
6 | import cv2
7 | from skimage import morphology
8 | import scipy.ndimage, scipy.signal  # scipy submodules used below must be imported explicitly
9 | from PIL import Image
10 | from matplotlib.pyplot import imsave
11 | # from keras.preprocessing import image
12 | from skimage.measure import label, regionprops
13 | from skimage.transform import rotate, resize
14 | from skimage import measure, draw
15 |
16 | import matplotlib.pyplot as plt
17 | plt.switch_backend('agg')
18 |
19 | # from scipy.misc import imsave
20 | from utils.metrics import *
21 | import cv2
22 |
23 |
24 | def construct_color_img(prob_per_slice):
25 | shape = prob_per_slice.shape
26 | img = np.zeros((shape[0], shape[1], 3), dtype=np.uint8)
27 | img[:, :, 0] = prob_per_slice * 255
28 | img[:, :, 1] = prob_per_slice * 255
29 | img[:, :, 2] = prob_per_slice * 255
30 |
31 | im_color = cv2.applyColorMap(img, cv2.COLORMAP_JET)
32 | return im_color
33 |
34 |
35 | def normalize_ent(ent):
36 | '''
37 |     Normalize ent to roughly 0 - 1 (assumes an entropy spread of about 0.4)
38 | :param ent:
39 | :return:
40 | '''
41 | min = np.amin(ent)
42 | return (ent - min) / 0.4
43 |
44 |
45 | def draw_ent(prediction, save_root, name):
46 | '''
47 | Draw the entropy information for each img and save them to the save path
48 | :param prediction: [2, h, w] numpy
49 |     :param save_root, name: save directory and image file name
50 | :return: None
51 | '''
52 | if not os.path.exists(os.path.join(save_root, 'disc')):
53 | os.makedirs(os.path.join(save_root, 'disc'))
54 | if not os.path.exists(os.path.join(save_root, 'cup')):
55 | os.makedirs(os.path.join(save_root, 'cup'))
56 | smooth = 1e-8
57 | cup = prediction[0]
58 | disc = prediction[1]
59 | cup_ent = - cup * np.log(cup + smooth)
60 | disc_ent = - disc * np.log(disc + smooth)
61 | cup_ent = normalize_ent(cup_ent)
62 | disc_ent = normalize_ent(disc_ent)
63 | disc = construct_color_img(disc_ent)
64 | cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc)
65 | cup = construct_color_img(cup_ent)
66 | cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup)
67 |
68 |
69 | def draw_mask(prediction, save_root, name):
70 | '''
71 | Draw the mask probability for each img and save them to the save path
72 | :param prediction: [2, h, w] numpy
73 |     :param save_root, name: save directory and image file name
74 | :return: None
75 | '''
76 | if not os.path.exists(os.path.join(save_root, 'disc')):
77 | os.makedirs(os.path.join(save_root, 'disc'))
78 | if not os.path.exists(os.path.join(save_root, 'cup')):
79 | os.makedirs(os.path.join(save_root, 'cup'))
80 | cup = prediction[0]
81 | disc = prediction[1]
82 |
83 | disc = construct_color_img(disc)
84 | cv2.imwrite(os.path.join(save_root, 'disc', name.split('.')[0]) + '.png', disc)
85 | cup = construct_color_img(cup)
86 | cv2.imwrite(os.path.join(save_root, 'cup', name.split('.')[0]) + '.png', cup)
87 |
88 | def draw_boundary(prediction, save_root, name):
89 | '''
90 |     Draw the boundary probability for each img and save them to the save path
91 |     :param prediction: numpy array whose first channel is the boundary map
92 |     :param save_root, name: save directory and image file name
93 | :return: None
94 | '''
95 | if not os.path.exists(os.path.join(save_root, 'boundary')):
96 | os.makedirs(os.path.join(save_root, 'boundary'))
97 | boundary = prediction[0]
98 | boundary = construct_color_img(boundary)
99 | cv2.imwrite(os.path.join(save_root, 'boundary', name.split('.')[0]) + '.png', boundary)
100 |
101 |
102 | def get_largest_fillhole(binary):
103 | label_image = label(binary)
104 | regions = regionprops(label_image)
105 | area_list = []
106 | for region in regions:
107 | area_list.append(region.area)
108 | if area_list:
109 | idx_max = np.argmax(area_list)
110 | binary[label_image != idx_max + 1] = 0
111 | return scipy.ndimage.binary_fill_holes(np.asarray(binary).astype(int))
112 |
113 | def postprocessing(prediction, threshold=0.75, dataset='G'):
114 | if dataset[0] == 'D':
115 | prediction = prediction.numpy()
116 | prediction_copy = np.copy(prediction)
117 | disc_mask = prediction[1]
118 | cup_mask = prediction[0]
119 | disc_mask = (disc_mask > 0.5) # return binary mask
120 | cup_mask = (cup_mask > 0.1) # return binary mask
121 | disc_mask = disc_mask.astype(np.uint8)
122 | cup_mask = cup_mask.astype(np.uint8)
123 | for i in range(5):
124 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7)
125 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7)
126 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
127 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
128 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1
129 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8)
130 | prediction_copy[0] = cup_mask
131 | prediction_copy[1] = disc_mask
132 | return prediction_copy
133 | else:
134 | prediction = prediction.numpy()
135 | prediction = (prediction > threshold) # return binary mask
136 | prediction = prediction.astype(np.uint8)
137 | prediction_copy = np.copy(prediction)
138 | disc_mask = prediction[1]
139 | cup_mask = prediction[0]
140 | # for i in range(5):
141 | # disc_mask = scipy.signal.medfilt2d(disc_mask, 7)
142 | # cup_mask = scipy.signal.medfilt2d(cup_mask, 7)
143 | # disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
144 | # cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
145 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1
146 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8)
147 | prediction_copy[0] = cup_mask
148 | prediction_copy[1] = disc_mask
149 | return prediction_copy
150 |
151 |
152 | def joint_val_image(image, prediction, mask):
153 | ratio = 0.5
154 | _pred_cup = np.zeros([mask.shape[-2], mask.shape[-1], 3])
155 | _pred_disc = np.zeros([mask.shape[-2], mask.shape[-1], 3])
156 | _mask = np.zeros([mask.shape[-2], mask.shape[-1], 3])
157 | image = np.transpose(image, (1, 2, 0))
158 |
159 | _pred_cup[:, :, 0] = prediction[0]
160 | _pred_cup[:, :, 1] = prediction[0]
161 | _pred_cup[:, :, 2] = prediction[0]
162 | _pred_disc[:, :, 0] = prediction[1]
163 | _pred_disc[:, :, 1] = prediction[1]
164 | _pred_disc[:, :, 2] = prediction[1]
165 | _mask[:,:,0] = mask[0]
166 | _mask[:,:,1] = mask[1]
167 |
168 | pred_cup = np.add(ratio * image, (1 - ratio) * _pred_cup)
169 | pred_disc = np.add(ratio * image, (1 - ratio) * _pred_disc)
170 | mask_img = np.add(ratio * image, (1 - ratio) * _mask)
171 |
172 | joint_img = np.concatenate([image, mask_img, pred_cup, pred_disc], axis=1)
173 | return joint_img
174 |
175 |
176 | def save_val_img(path, epoch, img):
177 | name = osp.join(path, "visualization", "epoch_%d.png" % epoch)
178 | out = osp.join(path, "visualization")
179 | if not osp.exists(out):
180 | os.makedirs(out)
181 | img_shape = img[0].shape
182 | stack_image = np.zeros([len(img) * img_shape[0], img_shape[1], img_shape[2]])
183 | for i in range(len(img)):
184 | stack_image[i * img_shape[0] : (i + 1) * img_shape[0], :, : ] = img[i]
185 | imsave(name, stack_image)
186 |
187 |
188 |
189 |
190 | def save_per_img(patch_image, data_save_path, img_name, prob_map, mask_path=None, ext="bmp"):
191 | path1 = os.path.join(data_save_path, 'overlay', img_name.split('.')[0]+'.png')
192 | path0 = os.path.join(data_save_path, 'original_image', img_name.split('.')[0]+'.png')
193 | if not os.path.exists(os.path.dirname(path0)):
194 | os.makedirs(os.path.dirname(path0))
195 | if not os.path.exists(os.path.dirname(path1)):
196 | os.makedirs(os.path.dirname(path1))
197 |
198 | disc_map = prob_map[0]
199 | cup_map = prob_map[1]
200 | size = disc_map.shape
201 | disc_map[:, 0] = np.zeros(size[0])
202 | disc_map[:, size[1] - 1] = np.zeros(size[0])
203 | disc_map[0, :] = np.zeros(size[1])
204 | disc_map[size[0] - 1, :] = np.zeros(size[1])
205 | size = cup_map.shape
206 | cup_map[:, 0] = np.zeros(size[0])
207 | cup_map[:, size[1] - 1] = np.zeros(size[0])
208 | cup_map[0, :] = np.zeros(size[1])
209 | cup_map[size[0] - 1, :] = np.zeros(size[1])
210 |
211 | disc_mask = (disc_map > 0.75) # return binary mask
212 | cup_mask = (cup_map > 0.75)
213 | disc_mask = disc_mask.astype(np.uint8)
214 | cup_mask = cup_mask.astype(np.uint8)
215 |
216 | for i in range(5):
217 | disc_mask = scipy.signal.medfilt2d(disc_mask, 7)
218 | cup_mask = scipy.signal.medfilt2d(cup_mask, 7)
219 | disc_mask = morphology.binary_erosion(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
220 | cup_mask = morphology.binary_erosion(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
221 | disc_mask = get_largest_fillhole(disc_mask)
222 | cup_mask = get_largest_fillhole(cup_mask)
223 |
224 | disc_mask = morphology.binary_dilation(disc_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
225 | cup_mask = morphology.binary_dilation(cup_mask, morphology.diamond(7)).astype(np.uint8) # return 0,1
226 |
227 | disc_mask = get_largest_fillhole(disc_mask).astype(np.uint8) # return 0,1
228 | cup_mask = get_largest_fillhole(cup_mask).astype(np.uint8)
229 |
230 |
231 | contours_disc = measure.find_contours(disc_mask, 0.5)
232 | contours_cup = measure.find_contours(cup_mask, 0.5)
233 |
234 | patch_image2 = patch_image.astype(np.uint8)
235 | patch_image2 = Image.fromarray(patch_image2)
236 |
237 | patch_image2.save(path0)
238 |
239 | for n, contour in enumerate(contours_cup):
240 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0]
241 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0]
242 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0]
243 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 255, 0]
244 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 255, 0]
245 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0]
246 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 255, 0]
247 |
248 | for n, contour in enumerate(contours_disc):
249 | patch_image[contour[:, 0].astype(int), contour[:, 1].astype(int), :] = [0, 0, 255]
250 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255]
251 | patch_image[(contour[:, 0] + 1.0).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255]
252 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] + 1.0).astype(int), :] = [0, 0, 255]
253 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1]).astype(int), :] = [0, 0, 255]
254 | patch_image[(contour[:, 0] - 1.0).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255]
255 | patch_image[(contour[:, 0]).astype(int), (contour[:, 1] - 1.0).astype(int), :] = [0, 0, 255]
256 |
257 | patch_image = patch_image.astype(np.uint8)
258 | patch_image = Image.fromarray(patch_image)
259 |
260 | patch_image.save(path1)
261 |
262 | def untransform(img, lt):
263 | img = (img + 1) * 127.5
264 | lt = lt * 128
265 | return img, lt
--------------------------------------------------------------------------------
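Note: a small example of what `get_largest_fillhole` above does; `postprocessing` and `save_per_img` rely on it to keep only the largest connected component and fill its interior holes (assumes the repo root is on `PYTHONPATH`):

```python
import numpy as np
from utils.Utils import get_largest_fillhole

mask = np.zeros((8, 8), dtype=np.uint8)
mask[1:5, 1:5] = 1           # large blob ...
mask[2, 2] = 0               # ... with a one-pixel hole
mask[6, 6] = 1               # small spurious blob
out = get_largest_fillhole(mask)
print(out[2, 2], out[6, 6])  # True False -> hole filled, spurious blob removed
```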
/networks/backbone/xception.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.utils.model_zoo as model_zoo
6 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
7 |
8 | def fixed_padding(inputs, kernel_size, dilation):
9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
10 | pad_total = kernel_size_effective - 1
11 | pad_beg = pad_total // 2
12 | pad_end = pad_total - pad_beg
13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
14 | return padded_inputs
15 |
16 |
17 | class SeparableConv2d(nn.Module):
18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None):
19 | super(SeparableConv2d, self).__init__()
20 |
21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation,
22 | groups=inplanes, bias=bias)
23 | self.bn = BatchNorm(inplanes)
24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias)
25 |
26 | def forward(self, x):
27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0])
28 | x = self.conv1(x)
29 | x = self.bn(x)
30 | x = self.pointwise(x)
31 | return x
32 |
33 |
34 | class Block(nn.Module):
35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None,
36 | start_with_relu=True, grow_first=True, is_last=False):
37 | super(Block, self).__init__()
38 |
39 | if planes != inplanes or stride != 1:
40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False)
41 | self.skipbn = BatchNorm(planes)
42 | else:
43 | self.skip = None
44 |
45 | self.relu = nn.ReLU(inplace=True)
46 | rep = []
47 |
48 | filters = inplanes
49 | if grow_first:
50 | rep.append(self.relu)
51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
52 | rep.append(BatchNorm(planes))
53 | filters = planes
54 |
55 | for i in range(reps - 1):
56 | rep.append(self.relu)
57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm))
58 | rep.append(BatchNorm(filters))
59 |
60 | if not grow_first:
61 | rep.append(self.relu)
62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm))
63 | rep.append(BatchNorm(planes))
64 |
65 | if stride != 1:
66 | rep.append(self.relu)
67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm))
68 | rep.append(BatchNorm(planes))
69 |
70 | if stride == 1 and is_last:
71 | rep.append(self.relu)
72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm))
73 | rep.append(BatchNorm(planes))
74 |
75 | if not start_with_relu:
76 | rep = rep[1:]
77 |
78 | self.rep = nn.Sequential(*rep)
79 |
80 | def forward(self, inp):
81 | x = self.rep(inp)
82 |
83 | if self.skip is not None:
84 | skip = self.skip(inp)
85 | skip = self.skipbn(skip)
86 | else:
87 | skip = inp
88 |
89 | x = x + skip
90 |
91 | return x
92 |
93 |
94 | class AlignedXception(nn.Module):
95 | """
96 |     Modified Aligned Xception
97 | """
98 | def __init__(self, output_stride, BatchNorm,
99 | pretrained=True):
100 | super(AlignedXception, self).__init__()
101 |
102 | if output_stride == 16:
103 | entry_block3_stride = 2
104 | middle_block_dilation = 1
105 | exit_block_dilations = (1, 2)
106 | elif output_stride == 8:
107 | entry_block3_stride = 1
108 | middle_block_dilation = 2
109 | exit_block_dilations = (2, 4)
110 | else:
111 | raise NotImplementedError
112 |
113 |
114 | # Entry flow
115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False)
116 | self.bn1 = BatchNorm(32)
117 | self.relu = nn.ReLU(inplace=True)
118 |
119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False)
120 | self.bn2 = BatchNorm(64)
121 |
122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False)
123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False,
124 | grow_first=True)
125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm,
126 | start_with_relu=True, grow_first=True, is_last=True)
127 |
128 | # Middle flow
129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation,
160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True)
161 |
162 | # Exit flow
163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0],
164 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True)
165 |
166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
167 | self.bn3 = BatchNorm(1536)
168 |
169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
170 | self.bn4 = BatchNorm(1536)
171 |
172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm)
173 | self.bn5 = BatchNorm(2048)
174 |
175 | # Init weights
176 | self._init_weight()
177 |
178 | # Load pretrained model
179 | if pretrained:
180 | self._load_pretrained_model()
181 |
182 | def forward(self, x):
183 | # Entry flow
184 | x = self.conv1(x)
185 | x = self.bn1(x)
186 | x = self.relu(x)
187 |
188 | x = self.conv2(x)
189 | x = self.bn2(x)
190 | x = self.relu(x)
191 |
192 | x = self.block1(x)
193 | # add relu here
194 | x = self.relu(x)
195 | low_level_feat = x
196 | x = self.block2(x)
197 | x = self.block3(x)
198 |
199 | # Middle flow
200 | x = self.block4(x)
201 | x = self.block5(x)
202 | x = self.block6(x)
203 | x = self.block7(x)
204 | x = self.block8(x)
205 | x = self.block9(x)
206 | x = self.block10(x)
207 | x = self.block11(x)
208 | x = self.block12(x)
209 | x = self.block13(x)
210 | x = self.block14(x)
211 | x = self.block15(x)
212 | x = self.block16(x)
213 | x = self.block17(x)
214 | x = self.block18(x)
215 | x = self.block19(x)
216 |
217 | # Exit flow
218 | x = self.block20(x)
219 | x = self.relu(x)
220 | x = self.conv3(x)
221 | x = self.bn3(x)
222 | x = self.relu(x)
223 |
224 | x = self.conv4(x)
225 | x = self.bn4(x)
226 | x = self.relu(x)
227 |
228 | x = self.conv5(x)
229 | x = self.bn5(x)
230 | x = self.relu(x)
231 |
232 | return x, low_level_feat
233 |
234 | def _init_weight(self):
235 | for m in self.modules():
236 | if isinstance(m, nn.Conv2d):
237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
238 | m.weight.data.normal_(0, math.sqrt(2. / n))
239 | elif isinstance(m, SynchronizedBatchNorm2d):
240 | m.weight.data.fill_(1)
241 | m.bias.data.zero_()
242 | elif isinstance(m, nn.BatchNorm2d):
243 | m.weight.data.fill_(1)
244 | m.bias.data.zero_()
245 |
246 |
247 | def _load_pretrained_model(self):
248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth')
249 | model_dict = {}
250 | state_dict = self.state_dict()
251 |
252 | for k, v in pretrain_dict.items():
253 |             if k in state_dict:  # checking the (still empty) model_dict here would skip every pretrained weight
254 | if 'pointwise' in k:
255 | v = v.unsqueeze(-1).unsqueeze(-1)
256 | if k.startswith('block11'):
257 | model_dict[k] = v
258 | model_dict[k.replace('block11', 'block12')] = v
259 | model_dict[k.replace('block11', 'block13')] = v
260 | model_dict[k.replace('block11', 'block14')] = v
261 | model_dict[k.replace('block11', 'block15')] = v
262 | model_dict[k.replace('block11', 'block16')] = v
263 | model_dict[k.replace('block11', 'block17')] = v
264 | model_dict[k.replace('block11', 'block18')] = v
265 | model_dict[k.replace('block11', 'block19')] = v
266 | elif k.startswith('block12'):
267 | model_dict[k.replace('block12', 'block20')] = v
268 | elif k.startswith('bn3'):
269 | model_dict[k] = v
270 | model_dict[k.replace('bn3', 'bn4')] = v
271 | elif k.startswith('conv4'):
272 | model_dict[k.replace('conv4', 'conv5')] = v
273 | elif k.startswith('bn4'):
274 | model_dict[k.replace('bn4', 'bn5')] = v
275 | else:
276 | model_dict[k] = v
277 | state_dict.update(model_dict)
278 | self.load_state_dict(state_dict)
279 |
280 |
281 |
282 | if __name__ == "__main__":
283 | import torch
284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16)
285 | input = torch.rand(1, 3, 512, 512)
286 | output, low_level_feat = model(input)
287 | print(output.size())
288 | print(low_level_feat.size())
289 |
--------------------------------------------------------------------------------
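Note: `fixed_padding` above emulates TensorFlow-style 'SAME' padding for the strided/dilated depthwise convolutions: a kernel of size 3 with dilation 2 has an effective extent of 5, so 2 pixels are padded on each side. A quick check (assumes the repo root is on `PYTHONPATH`):

```python
import torch
from networks.backbone.xception import fixed_padding

x = torch.rand(1, 3, 65, 65)
y = fixed_padding(x, kernel_size=3, dilation=2)
print(y.shape)  # torch.Size([1, 3, 69, 69]): 65 + 2 + 2
```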
/networks/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input):
49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | return F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 |
55 | # Resize the input to (B, C, -1).
56 | input_shape = input.size()
57 | input = input.view(input.size(0), self.num_features, -1)
58 |
59 | # Compute the sum and square-sum.
60 | sum_size = input.size(0) * input.size(2)
61 | input_sum = _sum_ft(input)
62 | input_ssum = _sum_ft(input ** 2)
63 |
64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0:
66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 | else:
68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 |
70 | # Compute the output.
71 | if self.affine:
72 | # MJY:: Fuse the multiplication for speed.
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 | else:
75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 |
77 | # Reshape it.
78 | return output.view(input_shape)
79 |
80 | def __data_parallel_replicate__(self, ctx, copy_id):
81 | self._is_parallel = True
82 | self._parallel_id = copy_id
83 |
84 | # parallel_id == 0 means master device.
85 | if self._parallel_id == 0:
86 | ctx.sync_master = self._sync_master
87 | else:
88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 | def _data_parallel_master(self, intermediates):
91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 | # Always using same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 | to_reduce = [i[1][:2] for i in intermediates]
98 | to_reduce = [j for i in to_reduce for j in i] # flatten
99 | target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 | sum_size = sum([i[1].sum_size for i in intermediates])
102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 | outputs = []
108 | for i, rec in enumerate(intermediates):
109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 |
111 | return outputs
112 |
113 | def _compute_mean_std(self, sum_, ssum, size):
114 | """Compute the mean and standard-deviation with sum and square-sum. This method
115 | also maintains the moving average on the master device."""
116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 | mean = sum_ / size
118 | sumvar = ssum - sum_ * mean
119 | unbias_var = sumvar / (size - 1)
120 | bias_var = sumvar / size
121 |
122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 | return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 | mini-batch.
131 | .. math::
132 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 | standard-deviation are reduced across all devices during training.
135 | For example, when one uses `nn.DataParallel` to wrap the network during
136 | training, PyTorch's implementation normalizes the tensor on each device using
137 | the statistics only on that device, which accelerates the computation and
138 | is also easy to implement, but the statistics might be inaccurate.
139 | Instead, in this synchronized version, the statistics will be computed
140 | over all training samples distributed on multiple devices.
141 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
142 | as the built-in PyTorch implementation.
143 | The mean and standard-deviation are calculated per-dimension over
144 | the mini-batches and gamma and beta are learnable parameter vectors
145 | of size C (where C is the input size).
146 | During training, this layer keeps a running estimate of its computed mean
147 | and variance. The running sum is kept with a default momentum of 0.1.
148 | During evaluation, this running mean/variance is used for normalization.
149 | Because the BatchNorm is done over the `C` dimension, computing statistics
150 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
151 | Args:
152 | num_features: num_features from an expected input of size
153 | `batch_size x num_features [x width]`
154 | eps: a value added to the denominator for numerical stability.
155 | Default: 1e-5
156 | momentum: the value used for the running_mean and running_var
157 | computation. Default: 0.1
158 | affine: a boolean value that when set to ``True``, gives the layer learnable
159 | affine parameters. Default: ``True``
160 | Shape:
161 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
162 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
163 | Examples:
164 | >>> # With Learnable Parameters
165 | >>> m = SynchronizedBatchNorm1d(100)
166 | >>> # Without Learnable Parameters
167 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
168 | >>> input = torch.autograd.Variable(torch.randn(20, 100))
169 | >>> output = m(input)
170 | """
171 |
172 | def _check_input_dim(self, input):
173 | if input.dim() != 2 and input.dim() != 3:
174 | raise ValueError('expected 2D or 3D input (got {}D input)'
175 | .format(input.dim()))
176 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
177 |
178 |
179 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
180 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
181 | of 3d inputs
182 | .. math::
183 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
184 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
185 | standard-deviation are reduced across all devices during training.
186 | For example, when one uses `nn.DataParallel` to wrap the network during
187 | training, PyTorch's implementation normalizes the tensor on each device using
188 | the statistics only on that device, which accelerates the computation and
189 | is also easy to implement, but the statistics might be inaccurate.
190 | Instead, in this synchronized version, the statistics will be computed
191 | over all training samples distributed on multiple devices.
192 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
193 | as the built-in PyTorch implementation.
194 | The mean and standard-deviation are calculated per-dimension over
195 | the mini-batches and gamma and beta are learnable parameter vectors
196 | of size C (where C is the input size).
197 | During training, this layer keeps a running estimate of its computed mean
198 | and variance. The running sum is kept with a default momentum of 0.1.
199 | During evaluation, this running mean/variance is used for normalization.
200 | Because the BatchNorm is done over the `C` dimension, computing statistics
201 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
202 | Args:
203 | num_features: num_features from an expected input of
204 | size batch_size x num_features x height x width
205 | eps: a value added to the denominator for numerical stability.
206 | Default: 1e-5
207 | momentum: the value used for the running_mean and running_var
208 | computation. Default: 0.1
209 | affine: a boolean value that when set to ``True``, gives the layer learnable
210 | affine parameters. Default: ``True``
211 | Shape:
212 | - Input: :math:`(N, C, H, W)`
213 | - Output: :math:`(N, C, H, W)` (same shape as input)
214 | Examples:
215 | >>> # With Learnable Parameters
216 | >>> m = SynchronizedBatchNorm2d(100)
217 | >>> # Without Learnable Parameters
218 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
219 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
220 | >>> output = m(input)
221 | """
222 |
223 | def _check_input_dim(self, input):
224 | if input.dim() != 4:
225 | raise ValueError('expected 4D input (got {}D input)'
226 | .format(input.dim()))
227 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
228 |
229 |
230 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
231 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
232 | of 4d inputs
233 | .. math::
234 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
235 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
236 | standard-deviation are reduced across all devices during training.
237 | For example, when one uses `nn.DataParallel` to wrap the network during
238 | training, PyTorch's implementation normalizes the tensor on each device using
239 | the statistics only on that device, which accelerates the computation and
240 | is also easy to implement, but the statistics might be inaccurate.
241 | Instead, in this synchronized version, the statistics will be computed
242 | over all training samples distributed on multiple devices.
243 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
244 | as the built-in PyTorch implementation.
245 | The mean and standard-deviation are calculated per-dimension over
246 | the mini-batches and gamma and beta are learnable parameter vectors
247 | of size C (where C is the input size).
248 | During training, this layer keeps a running estimate of its computed mean
249 | and variance. The running sum is kept with a default momentum of 0.1.
250 | During evaluation, this running mean/variance is used for normalization.
251 | Because the BatchNorm is done over the `C` dimension, computing statistics
252 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
253 | or Spatio-temporal BatchNorm
254 | Args:
255 | num_features: num_features from an expected input of
256 | size batch_size x num_features x depth x height x width
257 | eps: a value added to the denominator for numerical stability.
258 | Default: 1e-5
259 | momentum: the value used for the running_mean and running_var
260 | computation. Default: 0.1
261 | affine: a boolean value that when set to ``True``, gives the layer learnable
262 | affine parameters. Default: ``True``
263 | Shape:
264 | - Input: :math:`(N, C, D, H, W)`
265 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
266 | Examples:
267 | >>> # With Learnable Parameters
268 | >>> m = SynchronizedBatchNorm3d(100)
269 | >>> # Without Learnable Parameters
270 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
271 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
272 | >>> output = m(input)
273 | """
274 |
275 | def _check_input_dim(self, input):
276 | if input.dim() != 5:
277 | raise ValueError('expected 5D input (got {}D input)'
278 | .format(input.dim()))
279 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/train_target.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser()
3 | parser.add_argument('-g', '--gpu', type=str, default='0')
4 | parser.add_argument('--model-file', type=str, default='./logs_train/Domain3/source_model.pth.tar')
5 | parser.add_argument('--model', type=str, default='Deeplab', help='Deeplab')
6 | parser.add_argument('--out-stride', type=int, default=16)
7 | parser.add_argument('--sync-bn', type=bool, default=True)  # caution: argparse's type=bool treats any non-empty string as True
8 | parser.add_argument('--freeze-bn', type=bool, default=False)  # same caveat as --sync-bn
9 | parser.add_argument('--epoch', type=int, default=20)
10 | parser.add_argument('--lr', type=float, default=5e-4)
11 | parser.add_argument('--lr-decrease-rate', type=float, default=0.9, help='ratio multiplied to initial lr')
12 | parser.add_argument('--lr-decrease-epoch', type=int, default=1, help='interval epoch number for lr decrease')
13 |
14 | parser.add_argument('--data-dir', default='../Datasets/Fundus')
15 | parser.add_argument('--dataset', type=str, default='Domain2')
16 | parser.add_argument('--model-source', type=str, default='Domain3')
17 | parser.add_argument('--batch-size', type=int, default=8)
18 |
19 | parser.add_argument('--model-ema-rate', type=float, default=0.98)
20 | parser.add_argument('--pseudo-label-threshold', type=float, default=0.75)
21 | parser.add_argument('--mean-loss-calc-bound-ratio', type=float, default=0.2)
22 |
23 | args = parser.parse_args()
24 |
25 | import os
26 |
27 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
28 |
29 | import os.path as osp
30 |
31 | import numpy as np
32 | import torch.nn.functional as F
33 |
34 | import torch
35 | from torch.autograd import Variable
36 | import tqdm
37 | from torch.utils.data import DataLoader
38 | from dataloaders import fundus_dataloader
39 | from dataloaders import custom_transforms as trans
40 | from torchvision import transforms
41 | # from scipy.misc import imsave
42 | from matplotlib.pyplot import imsave
43 | from utils.Utils import *
44 | from utils.metrics import *
45 | from datetime import datetime
46 | import pytz
47 | import networks.deeplabv3 as netd
48 | import cv2
49 | import torch.backends.cudnn as cudnn
50 | import random
51 | import glob
52 | import sys
53 |
54 | seed = 42
55 | savefig = False
56 | get_hd = True
57 | model_save = True
58 | cudnn.benchmark = False
59 | cudnn.deterministic = True
60 | random.seed(seed)
61 | np.random.seed(seed)
62 | torch.manual_seed(seed)
63 | torch.cuda.manual_seed(seed)
64 |
65 | def print_args(args):
66 | s = "==========================================\n"
67 | for arg, content in args.__dict__.items():
68 | s += "{}:{}\n".format(arg, content)
69 | return s
70 |
71 |
72 | def soft_label_to_hard(soft_pls, pseudo_label_threshold):
73 | pseudo_labels = torch.zeros(soft_pls.size())
74 | if torch.cuda.is_available():
75 | pseudo_labels = pseudo_labels.cuda()
76 | pseudo_labels[soft_pls > pseudo_label_threshold] = 1
77 | pseudo_labels[soft_pls <= pseudo_label_threshold] = 0
78 |
79 | return pseudo_labels
80 |
81 |
82 | def init_feature_pred_bank(model, loader):
83 | feature_bank = {}
84 | pred_bank = {}
85 |
86 | model.eval()
87 |
88 | with torch.no_grad():
89 | for sample in loader:
90 | data = sample['image']
91 | img_name = sample['img_name']
92 | data = data.cuda()
93 |
94 | pred, feat = model(data)
95 | pred = torch.sigmoid(pred)
96 |
97 | for i in range(data.size(0)):
98 | feature_bank[img_name[i]] = feat[i].detach().clone()
99 | pred_bank[img_name[i]] = pred[i].detach().clone()
100 |
101 | model.train()
102 |
103 | return feature_bank, pred_bank
104 |
105 |
106 | def adapt_epoch(model_t, model_s, optim, train_loader, args, feature_bank, pred_bank, loss_weight=None):
107 | for sample_w, sample_s in train_loader:
108 | imgs_w = sample_w['image']
109 | imgs_s = sample_s['image']
110 | img_name = sample_w['img_name']
111 | if torch.cuda.is_available():
112 | imgs_w = imgs_w.cuda()
113 | imgs_s = imgs_s.cuda()
114 |
115 | # model predict
116 | predictions_stu_s, features_stu_s = model_s(imgs_s)
117 | with torch.no_grad():
118 | predictions_tea_w, features_tea_w = model_t(imgs_w)
119 |
120 | predictions_stu_s_sigmoid = torch.sigmoid(predictions_stu_s)
121 | predictions_tea_w_sigmoid = torch.sigmoid(predictions_tea_w)
122 |
123 | # get hard pseudo label
124 | pseudo_labels = soft_label_to_hard(predictions_tea_w_sigmoid, args.pseudo_label_threshold)
125 |
126 |
127 | bceloss = torch.nn.BCELoss(reduction='none')
128 | loss_seg_pixel = bceloss(predictions_stu_s_sigmoid, pseudo_labels)
129 |
130 | mean_loss_weight_mask = torch.ones(pseudo_labels.size()).cuda()
131 | mean_loss_weight_mask[:, 0, ...][pseudo_labels[:, 0, ...] == 0] = loss_weight
132 | loss_mask = mean_loss_weight_mask
133 |
134 | loss = torch.sum(loss_seg_pixel * loss_mask) / torch.sum(loss_mask)
135 |
136 | loss.backward()
137 | optim.step()
138 | optim.zero_grad()
139 |
140 | # update teacher
141 | for param_s, param_t in zip(model_s.parameters(), model_t.parameters()):
142 | param_t.data = param_t.data.clone() * args.model_ema_rate + param_s.data.clone() * (1.0 - args.model_ema_rate)
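        # EMA update of the teacher: only parameters are averaged here; the
        # teacher's BatchNorm running statistics are refreshed by its own
        # forward passes, since it stays in train() mode.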
143 |
144 | # update feature/pred bank
145 | for idx in range(len(img_name)):
146 | feature_bank[img_name[idx]] = features_tea_w[idx].detach().clone()
147 | pred_bank[img_name[idx]] = predictions_tea_w_sigmoid[idx].detach().clone()
148 |
149 |
150 | def eval(model, data_loader):
151 | model.eval()
152 |
153 | val_dice = {'cup': np.array([]), 'disc': np.array([])}
154 | val_assd = {'cup': np.array([]), 'disc': np.array([])}
155 |
156 | with torch.no_grad():
157 | for batch_idx, sample in enumerate(data_loader):
158 | data = sample['image']
159 | target_map = sample['label']
160 | data = data.cuda()
161 | predictions, _ = model(data)
162 |
163 | dice_cup, dice_disc = dice_coeff_2label(predictions, target_map)
164 | val_dice['cup'] = np.append(val_dice['cup'], dice_cup)
165 | val_dice['disc'] = np.append(val_dice['disc'], dice_disc)
166 |
167 | assd = assd_compute(predictions, target_map)
168 | val_assd['cup'] = np.append(val_assd['cup'], assd[:, 0])
169 | val_assd['disc'] = np.append(val_assd['disc'], assd[:, 1])
170 |
171 | avg_dice = [0.0, 0.0, 0.0, 0.0]
172 | std_dice = [0.0, 0.0, 0.0, 0.0]
173 | avg_assd = [0.0, 0.0, 0.0, 0.0]
174 | std_assd = [0.0, 0.0, 0.0, 0.0]
175 | avg_dice[0] = np.mean(val_dice['cup'])
176 | avg_dice[1] = np.mean(val_dice['disc'])
177 | std_dice[0] = np.std(val_dice['cup'])
178 | std_dice[1] = np.std(val_dice['disc'])
179 | val_assd['cup'] = np.delete(val_assd['cup'], np.where(val_assd['cup'] == -1))
180 | val_assd['disc'] = np.delete(val_assd['disc'], np.where(val_assd['disc'] == -1))
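    # assd_compute appears to use -1 as a sentinel for undefined surface
    # distances (e.g. empty masks); such entries are dropped before averaging.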
181 | avg_assd[0] = np.mean(val_assd['cup'])
182 | avg_assd[1] = np.mean(val_assd['disc'])
183 | std_assd[0] = np.std(val_assd['cup'])
184 | std_assd[1] = np.std(val_assd['disc'])
185 |
186 | model.train()
187 |
188 | return avg_dice, std_dice, avg_assd, std_assd
189 |
190 |
191 | def main():
192 | now = datetime.now()
193 | here = osp.dirname(osp.abspath(__file__))
194 | args.out = osp.join(here, 'logs_target', args.dataset, now.strftime('%Y%m%d_%H%M%S.%f'))
195 | if not osp.exists(args.out):
196 | os.makedirs(args.out)
197 | args.out_file = open(osp.join(args.out, now.strftime('%Y%m%d_%H%M%S.%f')+'.txt'), 'w')
198 | args.out_file.write(' '.join(sys.argv) + '\n')
199 | args.out_file.write(print_args(args) + '\n')
200 | args.out_file.flush()
201 |
202 | # dataset
203 | composed_transforms_train = transforms.Compose([
204 | trans.Resize(512),
205 | trans.add_salt_pepper_noise(),
206 | trans.adjust_light(),
207 | trans.eraser(),
208 | trans.Normalize_tf(),
209 | trans.ToTensor()
210 | ])
211 | composed_transforms_test = transforms.Compose([
212 | trans.Resize(512),
213 | trans.Normalize_tf(),
214 | trans.ToTensor()
215 | ])
216 |
217 | dataset_train = fundus_dataloader.FundusSegmentation_2transform(base_dir=args.data_dir, dataset=args.dataset,
218 | split='train/ROIs',
219 | transform_weak=composed_transforms_test,
220 | transform_strong=composed_transforms_train)
221 | dataset_train_weak = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset,
222 | split='train/ROIs',
223 | transform=composed_transforms_test)
224 | dataset_test = fundus_dataloader.FundusSegmentation(base_dir=args.data_dir, dataset=args.dataset, split='test/ROIs',
225 | transform=composed_transforms_test)
226 |
227 | train_loader = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=2)
228 | train_loader_weak = DataLoader(dataset_train_weak, batch_size=args.batch_size, shuffle=False, num_workers=2)
229 | test_loader = DataLoader(dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=2)
230 |
231 | # model
232 | model_s = netd.DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn,
233 | freeze_bn=args.freeze_bn)
234 | model_t = netd.DeepLab(num_classes=2, backbone='mobilenet', output_stride=args.out_stride, sync_bn=args.sync_bn,
235 | freeze_bn=args.freeze_bn)
236 |
237 |
238 | if torch.cuda.is_available():
239 | model_s = model_s.cuda()
240 | model_t = model_t.cuda()
241 | log_str = '==> Loading %s model file: %s' % (model_s.__class__.__name__, args.model_file)
242 | print(log_str)
243 | args.out_file.write(log_str + '\n')
244 | args.out_file.flush()
245 | checkpoint = torch.load(args.model_file)
246 | model_s.load_state_dict(checkpoint['model_state_dict'])
247 | model_t.load_state_dict(checkpoint['model_state_dict'])
248 |
249 | if (args.gpu).find(',') != -1:
250 | model_s = torch.nn.DataParallel(model_s, device_ids=[0, 1])
251 | model_t = torch.nn.DataParallel(model_t, device_ids=[0, 1])
252 |
253 | optim = torch.optim.Adam(model_s.parameters(), lr=args.lr, betas=(0.9, 0.99))
254 | scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=args.lr_decrease_epoch, gamma=args.lr_decrease_rate)
255 |
256 | model_s.train()
257 | model_t.train()
258 | for param in model_t.parameters():
259 | param.requires_grad = False
260 |
261 |
262 | feature_bank, pred_bank = init_feature_pred_bank(model_s, train_loader_weak)
263 |
264 | avg_dice, std_dice, avg_assd, std_assd = eval(model_t, test_loader)
265 | log_str = ("initial dice: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f, assd: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f" % (
266 | avg_dice[0], std_dice[0], avg_dice[1], std_dice[1], (avg_dice[0] + avg_dice[1]) / 2.0,
267 | avg_assd[0], std_assd[0], avg_assd[1], std_assd[1], (avg_assd[0] + avg_assd[1]) / 2.0))
268 | print(log_str)
269 | args.out_file.write(log_str + '\n')
270 | args.out_file.flush()
271 |
272 | for epoch in range(args.epoch):
273 |
274 | log_str = '\nepoch {}/{}:'.format(epoch+1, args.epoch)
275 | print(log_str)
276 | args.out_file.write(log_str + '\n')
277 | args.out_file.flush()
278 |
279 | not_cup_loss_sum = torch.FloatTensor([0]).cuda()
280 | cup_loss_sum = torch.FloatTensor([0]).cuda()
281 | not_cup_loss_num = 0
282 | cup_loss_num = 0
283 | lower_bound = args.pseudo_label_threshold * args.mean_loss_calc_bound_ratio
284 | upper_bound = 1 - ((1 - args.pseudo_label_threshold) * args.mean_loss_calc_bound_ratio)
285 | for pred_i in pred_bank.values():
286 | not_cup_loss_sum += torch.sum(
287 | -torch.log(1 - pred_i[0, ...][(pred_i[0, ...] < args.pseudo_label_threshold) * (pred_i[0, ...] > lower_bound)]))
288 | not_cup_loss_num += torch.sum((pred_i[0, ...] < args.pseudo_label_threshold) * (pred_i[0, ...] > lower_bound))
289 | cup_loss_sum += torch.sum(-torch.log(pred_i[0, ...][(pred_i[0, ...] > args.pseudo_label_threshold) * (pred_i[0, ...] < upper_bound)]))
290 | cup_loss_num += torch.sum((pred_i[0, ...] > args.pseudo_label_threshold) * (pred_i[0, ...] < upper_bound))
291 | loss_weight = (cup_loss_sum.item() / cup_loss_num) / (not_cup_loss_sum.item() / not_cup_loss_num)
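        # loss_weight is the ratio of the mean cup-pixel loss to the mean
        # background-pixel loss, computed only over pixels whose teacher
        # probability lies inside the (lower_bound, upper_bound) band; it is
        # applied in adapt_epoch() to the cup channel's background pixels to
        # balance the two classes.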
292 |
293 | adapt_epoch(model_t, model_s, optim, train_loader, args, feature_bank, pred_bank, loss_weight=loss_weight)
294 |
295 | scheduler.step()
296 |
297 | avg_dice, std_dice, avg_assd, std_assd = eval(model_t, test_loader)
298 | log_str = ("teacher dice: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f, assd: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f" % (
299 | avg_dice[0], std_dice[0], avg_dice[1], std_dice[1], (avg_dice[0] + avg_dice[1]) / 2.0,
300 | avg_assd[0], std_assd[0], avg_assd[1], std_assd[1], (avg_assd[0] + avg_assd[1]) / 2.0))
301 | print(log_str)
302 | args.out_file.write(log_str + '\n')
303 | args.out_file.flush()
304 |
305 | avg_dice, std_dice, avg_assd, std_assd = eval(model_s, test_loader)
306 | log_str = ("student dice: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f, assd: cup: %.4f+-%.4f disc: %.4f+-%.4f avg: %.4f" % (
307 | avg_dice[0], std_dice[0], avg_dice[1], std_dice[1], (avg_dice[0] + avg_dice[1]) / 2.0,
308 | avg_assd[0], std_assd[0], avg_assd[1], std_assd[1], (avg_assd[0] + avg_assd[1]) / 2.0))
309 | print(log_str)
310 | args.out_file.write(log_str + '\n')
311 | args.out_file.flush()
312 |
313 | torch.save({'model_state_dict': model_t.state_dict()}, args.out + '/after_adaptation.pth.tar')
314 |
315 |
316 | if __name__ == '__main__':
317 | main()
318 |
319 |
--------------------------------------------------------------------------------
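A minimal sketch of the class-balanced mean-teacher loss computed above, with hypothetical random tensors standing in for the DeepLab outputs (shapes and the 0.5 weight are illustrative, not the repo's values):

import torch
import torch.nn.functional as F

thresh = 0.75                                    # args.pseudo_label_threshold
logits_t = torch.randn(2, 2, 64, 64)             # teacher logits on weakly augmented views
logits_s = torch.randn(2, 2, 64, 64, requires_grad=True)  # student logits on strongly augmented views

pseudo = (torch.sigmoid(logits_t) > thresh).float()        # hard pseudo labels
loss_pix = F.binary_cross_entropy(torch.sigmoid(logits_s), pseudo, reduction='none')

w = torch.ones_like(pseudo)
w[:, 0][pseudo[:, 0] == 0] = 0.5                 # class-balance weight on cup-background pixels
loss = (loss_pix * w).sum() / w.sum()            # weighted mean, as in adapt_epoch()
loss.backward()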
/networks/backbone/drn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 | from networks.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 |
6 | webroot = 'https://tigress-web.princeton.edu/~fy/drn/models/'
7 |
8 | model_urls = {
9 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
10 | 'drn-c-26': webroot + 'drn_c_26-ddedf421.pth',
11 | 'drn-c-42': webroot + 'drn_c_42-9d336e8c.pth',
12 | 'drn-c-58': webroot + 'drn_c_58-0a53a92c.pth',
13 | 'drn-d-22': webroot + 'drn_d_22-4bd2f8ea.pth',
14 | 'drn-d-38': webroot + 'drn_d_38-eebb45f0.pth',
15 | 'drn-d-54': webroot + 'drn_d_54-0e0534ff.pth',
16 | 'drn-d-105': webroot + 'drn_d_105-12b40979.pth'
17 | }
18 |
19 |
20 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=padding, bias=False, dilation=dilation)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None,
29 | dilation=(1, 1), residual=True, BatchNorm=None):
30 | super(BasicBlock, self).__init__()
31 | self.conv1 = conv3x3(inplanes, planes, stride,
32 | padding=dilation[0], dilation=dilation[0])
33 | self.bn1 = BatchNorm(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.conv2 = conv3x3(planes, planes,
36 | padding=dilation[1], dilation=dilation[1])
37 | self.bn2 = BatchNorm(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 | self.residual = residual
41 |
42 | def forward(self, x):
43 | residual = x
44 |
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 |
49 | out = self.conv2(out)
50 | out = self.bn2(out)
51 |
52 | if self.downsample is not None:
53 | residual = self.downsample(x)
54 | if self.residual:
55 | out += residual
56 | out = self.relu(out)
57 |
58 | return out
59 |
60 |
61 | class Bottleneck(nn.Module):
62 | expansion = 4
63 |
64 | def __init__(self, inplanes, planes, stride=1, downsample=None,
65 | dilation=(1, 1), residual=True, BatchNorm=None):
66 | super(Bottleneck, self).__init__()
67 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
68 | self.bn1 = BatchNorm(planes)
69 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
70 | padding=dilation[1], bias=False,
71 | dilation=dilation[1])
72 | self.bn2 = BatchNorm(planes)
73 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
74 | self.bn3 = BatchNorm(planes * 4)
75 | self.relu = nn.ReLU(inplace=True)
76 | self.downsample = downsample
77 | self.stride = stride
78 |
79 | def forward(self, x):
80 | residual = x
81 |
82 | out = self.conv1(x)
83 | out = self.bn1(out)
84 | out = self.relu(out)
85 |
86 | out = self.conv2(out)
87 | out = self.bn2(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv3(out)
91 | out = self.bn3(out)
92 |
93 | if self.downsample is not None:
94 | residual = self.downsample(x)
95 |
96 | out += residual
97 | out = self.relu(out)
98 |
99 | return out
100 |
101 |
102 | class DRN(nn.Module):
103 |
104 | def __init__(self, block, layers, arch='D',
105 | channels=(16, 32, 64, 128, 256, 512, 512, 512),
106 | BatchNorm=None):
107 | super(DRN, self).__init__()
108 | self.inplanes = channels[0]
109 | self.out_dim = channels[-1]
110 | self.arch = arch
111 |
112 | if arch == 'C':
113 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
114 | padding=3, bias=False)
115 | self.bn1 = BatchNorm(channels[0])
116 | self.relu = nn.ReLU(inplace=True)
117 |
118 | self.layer1 = self._make_layer(
119 | BasicBlock, channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
120 | self.layer2 = self._make_layer(
121 | BasicBlock, channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
122 |
123 | elif arch == 'D':
124 | self.layer0 = nn.Sequential(
125 | nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
126 | bias=False),
127 | BatchNorm(channels[0]),
128 | nn.ReLU(inplace=True)
129 | )
130 |
131 | self.layer1 = self._make_conv_layers(
132 | channels[0], layers[0], stride=1, BatchNorm=BatchNorm)
133 | self.layer2 = self._make_conv_layers(
134 | channels[1], layers[1], stride=2, BatchNorm=BatchNorm)
135 |
136 | self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2, BatchNorm=BatchNorm)
137 | self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2, BatchNorm=BatchNorm)
138 | self.layer5 = self._make_layer(block, channels[4], layers[4],
139 | dilation=2, new_level=False, BatchNorm=BatchNorm)
140 | self.layer6 = None if layers[5] == 0 else \
141 | self._make_layer(block, channels[5], layers[5], dilation=4,
142 | new_level=False, BatchNorm=BatchNorm)
143 |
144 | if arch == 'C':
145 | self.layer7 = None if layers[6] == 0 else \
146 | self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
147 | new_level=False, residual=False, BatchNorm=BatchNorm)
148 | self.layer8 = None if layers[7] == 0 else \
149 | self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
150 | new_level=False, residual=False, BatchNorm=BatchNorm)
151 | elif arch == 'D':
152 | self.layer7 = None if layers[6] == 0 else \
153 | self._make_conv_layers(channels[6], layers[6], dilation=2, BatchNorm=BatchNorm)
154 | self.layer8 = None if layers[7] == 0 else \
155 | self._make_conv_layers(channels[7], layers[7], dilation=1, BatchNorm=BatchNorm)
156 |
157 | self._init_weight()
158 |
159 | def _init_weight(self):
160 | for m in self.modules():
161 | if isinstance(m, nn.Conv2d):
162 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
163 | m.weight.data.normal_(0, math.sqrt(2. / n))
164 | elif isinstance(m, SynchronizedBatchNorm2d):
165 | m.weight.data.fill_(1)
166 | m.bias.data.zero_()
167 | elif isinstance(m, nn.BatchNorm2d):
168 | m.weight.data.fill_(1)
169 | m.bias.data.zero_()
170 |
171 |
172 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
173 | new_level=True, residual=True, BatchNorm=None):
174 | assert dilation == 1 or dilation % 2 == 0
175 | downsample = None
176 | if stride != 1 or self.inplanes != planes * block.expansion:
177 | downsample = nn.Sequential(
178 | nn.Conv2d(self.inplanes, planes * block.expansion,
179 | kernel_size=1, stride=stride, bias=False),
180 | BatchNorm(planes * block.expansion),
181 | )
182 |
183 | layers = list()
184 | layers.append(block(
185 | self.inplanes, planes, stride, downsample,
186 | dilation=(1, 1) if dilation == 1 else (
187 | dilation // 2 if new_level else dilation, dilation),
188 | residual=residual, BatchNorm=BatchNorm))
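        # When entering a new dilation level (new_level=True), the first conv of
        # the first block uses half the target dilation, so the dilation rate
        # grows gradually between stages; subsequent blocks use (dilation, dilation).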
189 | self.inplanes = planes * block.expansion
190 | for i in range(1, blocks):
191 | layers.append(block(self.inplanes, planes, residual=residual,
192 | dilation=(dilation, dilation), BatchNorm=BatchNorm))
193 |
194 | return nn.Sequential(*layers)
195 |
196 | def _make_conv_layers(self, channels, convs, stride=1, dilation=1, BatchNorm=None):
197 | modules = []
198 | for i in range(convs):
199 | modules.extend([
200 | nn.Conv2d(self.inplanes, channels, kernel_size=3,
201 | stride=stride if i == 0 else 1,
202 | padding=dilation, bias=False, dilation=dilation),
203 | BatchNorm(channels),
204 | nn.ReLU(inplace=True)])
205 | self.inplanes = channels
206 | return nn.Sequential(*modules)
207 |
208 | def forward(self, x):
209 | if self.arch == 'C':
210 | x = self.conv1(x)
211 | x = self.bn1(x)
212 | x = self.relu(x)
213 | elif self.arch == 'D':
214 | x = self.layer0(x)
215 |
216 | x = self.layer1(x)
217 | x = self.layer2(x)
218 |
219 | x = self.layer3(x)
220 | low_level_feat = x
221 |
222 | x = self.layer4(x)
223 | x = self.layer5(x)
224 |
225 | if self.layer6 is not None:
226 | x = self.layer6(x)
227 |
228 | if self.layer7 is not None:
229 | x = self.layer7(x)
230 |
231 | if self.layer8 is not None:
232 | x = self.layer8(x)
233 |
234 | return x, low_level_feat
235 |
236 |
237 | class DRN_A(nn.Module):
238 |
239 | def __init__(self, block, layers, BatchNorm=None):
240 | self.inplanes = 64
241 | super(DRN_A, self).__init__()
242 | self.out_dim = 512 * block.expansion
243 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
244 | bias=False)
245 | self.bn1 = BatchNorm(64)
246 | self.relu = nn.ReLU(inplace=True)
247 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
248 | self.layer1 = self._make_layer(block, 64, layers[0], BatchNorm=BatchNorm)
249 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, BatchNorm=BatchNorm)
250 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
251 | dilation=2, BatchNorm=BatchNorm)
252 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
253 | dilation=4, BatchNorm=BatchNorm)
254 |
255 | self._init_weight()
256 |
257 | def _init_weight(self):
258 | for m in self.modules():
259 | if isinstance(m, nn.Conv2d):
260 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
261 | m.weight.data.normal_(0, math.sqrt(2. / n))
262 | elif isinstance(m, SynchronizedBatchNorm2d):
263 | m.weight.data.fill_(1)
264 | m.bias.data.zero_()
265 | elif isinstance(m, nn.BatchNorm2d):
266 | m.weight.data.fill_(1)
267 | m.bias.data.zero_()
268 |
269 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
270 | downsample = None
271 | if stride != 1 or self.inplanes != planes * block.expansion:
272 | downsample = nn.Sequential(
273 | nn.Conv2d(self.inplanes, planes * block.expansion,
274 | kernel_size=1, stride=stride, bias=False),
275 | BatchNorm(planes * block.expansion),
276 | )
277 |
278 | layers = []
279 | layers.append(block(self.inplanes, planes, stride, downsample, BatchNorm=BatchNorm))
280 | self.inplanes = planes * block.expansion
281 | for i in range(1, blocks):
282 | layers.append(block(self.inplanes, planes,
283 | dilation=(dilation, dilation, ), BatchNorm=BatchNorm))
284 |
285 | return nn.Sequential(*layers)
286 |
287 | def forward(self, x):
288 | x = self.conv1(x)
289 | x = self.bn1(x)
290 | x = self.relu(x)
291 | x = self.maxpool(x)
292 |
293 | x = self.layer1(x)
294 | x = self.layer2(x)
295 | x = self.layer3(x)
296 | x = self.layer4(x)
297 |
298 | return x
299 |
300 | def drn_a_50(BatchNorm, pretrained=True):
301 | model = DRN_A(Bottleneck, [3, 4, 6, 3], BatchNorm=BatchNorm)
302 | if pretrained:
303 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
304 | return model
305 |
306 |
307 | def drn_c_26(BatchNorm, pretrained=True):
308 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', BatchNorm=BatchNorm)
309 | if pretrained:
310 | pretrained = model_zoo.load_url(model_urls['drn-c-26'])
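        # The checkpoint contains classifier weights ('fc.*') that this trimmed
        # DRN (which defines no classification head) does not have, so they are
        # removed before loading the state dict.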
311 | del pretrained['fc.weight']
312 | del pretrained['fc.bias']
313 | model.load_state_dict(pretrained)
314 | return model
315 |
316 |
317 | def drn_c_42(BatchNorm, pretrained=True):
318 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
319 | if pretrained:
320 | pretrained = model_zoo.load_url(model_urls['drn-c-42'])
321 | del pretrained['fc.weight']
322 | del pretrained['fc.bias']
323 | model.load_state_dict(pretrained)
324 | return model
325 |
326 |
327 | def drn_c_58(BatchNorm, pretrained=True):
328 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', BatchNorm=BatchNorm)
329 | if pretrained:
330 | pretrained = model_zoo.load_url(model_urls['drn-c-58'])
331 | del pretrained['fc.weight']
332 | del pretrained['fc.bias']
333 | model.load_state_dict(pretrained)
334 | return model
335 |
336 |
337 | def drn_d_22(BatchNorm, pretrained=True):
338 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', BatchNorm=BatchNorm)
339 | if pretrained:
340 | pretrained = model_zoo.load_url(model_urls['drn-d-22'])
341 | del pretrained['fc.weight']
342 | del pretrained['fc.bias']
343 | model.load_state_dict(pretrained)
344 | return model
345 |
346 |
347 | def drn_d_24(BatchNorm, pretrained=True):
348 | model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', BatchNorm=BatchNorm)
349 | if pretrained:
350 |         pretrained = model_zoo.load_url(model_urls['drn-d-24'])  # note: 'drn-d-24' has no entry in model_urls above, so pretrained loading raises KeyError
351 | del pretrained['fc.weight']
352 | del pretrained['fc.bias']
353 | model.load_state_dict(pretrained)
354 | return model
355 |
356 |
357 | def drn_d_38(BatchNorm, pretrained=True):
358 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
359 | if pretrained:
360 | pretrained = model_zoo.load_url(model_urls['drn-d-38'])
361 | del pretrained['fc.weight']
362 | del pretrained['fc.bias']
363 | model.load_state_dict(pretrained)
364 | return model
365 |
366 |
367 | def drn_d_40(BatchNorm, pretrained=True):
368 | model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', BatchNorm=BatchNorm)
369 | if pretrained:
370 |         pretrained = model_zoo.load_url(model_urls['drn-d-40'])  # note: 'drn-d-40' has no entry in model_urls above, so pretrained loading raises KeyError
371 | del pretrained['fc.weight']
372 | del pretrained['fc.bias']
373 | model.load_state_dict(pretrained)
374 | return model
375 |
376 |
377 | def drn_d_54(BatchNorm, pretrained=True):
378 | model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
379 | if pretrained:
380 | pretrained = model_zoo.load_url(model_urls['drn-d-54'])
381 | del pretrained['fc.weight']
382 | del pretrained['fc.bias']
383 | model.load_state_dict(pretrained)
384 | return model
385 |
386 |
387 | def drn_d_105(BatchNorm, pretrained=True):
388 | model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', BatchNorm=BatchNorm)
389 | if pretrained:
390 | pretrained = model_zoo.load_url(model_urls['drn-d-105'])
391 | del pretrained['fc.weight']
392 | del pretrained['fc.bias']
393 | model.load_state_dict(pretrained)
394 | return model
395 |
396 | if __name__ == "__main__":
397 | import torch
398 |     model = drn_d_54(BatchNorm=nn.BatchNorm2d, pretrained=True)  # drn_a_50 returns a single tensor; a DRN-D variant also returns the low-level features unpacked below
399 | input = torch.rand(1, 3, 512, 512)
400 | output, low_level_feat = model(input)
401 | print(output.size())
402 | print(low_level_feat.size())
403 |
--------------------------------------------------------------------------------
/networks/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 |
8 | ###############################################################################
9 | # Helper Functions
10 | ###############################################################################
11 |
12 |
13 | class Identity(nn.Module):
14 | def forward(self, x):
15 | return x
16 |
17 |
18 | def get_norm_layer(norm_type='instance'):
19 | """Return a normalization layer
20 |
21 | Parameters:
22 | norm_type (str) -- the name of the normalization layer: batch | instance | none
23 |
24 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
25 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
26 | """
27 | if norm_type == 'batch':
28 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
29 | elif norm_type == 'instance':
30 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
31 | elif norm_type == 'none':
32 | def norm_layer(x): return Identity()
33 | else:
34 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
35 | return norm_layer
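# Usage sketch: the returned factory is called like a layer class, e.g.
#   norm_layer = get_norm_layer('batch')
#   bn = norm_layer(64)  # nn.BatchNorm2d(64, affine=True, track_running_stats=True)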
36 |
37 |
38 | def get_scheduler(optimizer, opt):
39 | """Return a learning rate scheduler
40 |
41 | Parameters:
42 | optimizer -- the optimizer of the network
43 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
44 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
45 |
46 | For 'linear', we keep the same learning rate for the first epochs
47 | and linearly decay the rate to zero over the next epochs.
48 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
49 | See https://pytorch.org/docs/stable/optim.html for more details.
50 | """
51 | if opt.lr_policy == 'linear':
52 | def lambda_rule(epoch):
53 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
54 | return lr_l
55 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
56 | elif opt.lr_policy == 'step':
57 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
58 | elif opt.lr_policy == 'plateau':
59 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
60 | elif opt.lr_policy == 'cosine':
61 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
62 | else:
63 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
64 | return scheduler
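# Example: with pix2pix-style options epoch_count=1, n_epochs=100 and
# n_epochs_decay=100, the 'linear' policy keeps the learning rate constant for
# the first 100 epochs, then decays it linearly towards zero over the next 100.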
65 |
66 |
67 | def init_weights(net, init_type='normal', init_gain=0.02):
68 | """Initialize network weights.
69 |
70 | Parameters:
71 | net (network) -- network to be initialized
72 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
73 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
74 |
75 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
76 | work better for some applications. Feel free to try yourself.
77 | """
78 | def init_func(m): # define the initialization function
79 | classname = m.__class__.__name__
80 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
81 | if init_type == 'normal':
82 | init.normal_(m.weight.data, 0.0, init_gain)
83 | elif init_type == 'xavier':
84 | init.xavier_normal_(m.weight.data, gain=init_gain)
85 | elif init_type == 'kaiming':
86 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
87 | elif init_type == 'orthogonal':
88 | init.orthogonal_(m.weight.data, gain=init_gain)
89 | else:
90 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
91 | if hasattr(m, 'bias') and m.bias is not None:
92 | init.constant_(m.bias.data, 0.0)
93 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
94 | init.normal_(m.weight.data, 1.0, init_gain)
95 | init.constant_(m.bias.data, 0.0)
96 |
97 | print('initialize network with %s' % init_type)
98 | net.apply(init_func) # apply the initialization function
99 |
100 |
101 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
102 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
103 | Parameters:
104 | net (network) -- the network to be initialized
105 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
106 | gain (float) -- scaling factor for normal, xavier and orthogonal.
107 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
108 |
109 | Return an initialized network.
110 | """
111 | if len(gpu_ids) > 0:
112 | # assert(torch.cuda.is_available())
113 | net.to(gpu_ids[0])
114 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
115 | init_weights(net, init_type, init_gain=init_gain)
116 | return net
117 |
118 |
119 | def define_G(input_nc=3, output_nc=1, ngf=64, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
120 | """Create a generator
121 |
122 | Parameters:
123 | input_nc (int) -- the number of channels in input images
124 | output_nc (int) -- the number of channels in output images
125 | ngf (int) -- the number of filters in the last conv layer
126 |         netG             -- not a parameter here; this trimmed version always builds a U-Net generator
127 | norm (str) -- the name of normalization layers used in the network: batch | instance | none
128 | use_dropout (bool) -- if use dropout layers.
129 | init_type (str) -- the name of our initialization method.
130 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
131 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
132 |
133 | Returns a generator
134 |
135 |     This trimmed implementation always builds a U-Net generator:
136 |     U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
137 |     The original U-Net paper: https://arxiv.org/abs/1505.04597
138 |
139 |     The ResNet-based generators ([resnet_6blocks] and [resnet_9blocks]) of the
140 |     original pix2pix code, adapted from Justin Johnson's neural style transfer
141 |     project (https://github.com/jcjohnson/fast-neural-style), are not included here.
142 |
143 |
144 |     The generator has been initialized by init_net(). It uses ReLU for non-linearity.
145 | """
146 |
147 | norm_layer = get_norm_layer(norm_type=norm)
148 |
149 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
150 |
151 | return init_net(net, init_type, init_gain, gpu_ids)
152 |
153 |
154 | def define_D(input_nc=2, ndf=64, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
155 | """Create a discriminator
156 |
157 | Parameters:
158 | input_nc (int) -- the number of channels in input images
159 | ndf (int) -- the number of filters in the first conv layer
160 |         netD             -- not a parameter here; this trimmed version always builds the [n_layers] PatchGAN
161 |         n_layers_D (int) -- the number of conv layers in the discriminator
162 | norm (str) -- the type of normalization layers used in the network.
163 | init_type (str) -- the name of the initialization method.
164 | init_gain (float) -- scaling factor for normal, xavier and orthogonal.
165 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
166 |
167 | Returns a discriminator
168 |
169 |     Discriminator types in the original pix2pix code (this trimmed version always builds [n_layers]):
170 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
171 | It can classify whether 70×70 overlapping patches are real or fake.
172 | Such a patch-level discriminator architecture has fewer parameters
173 | than a full-image discriminator and can work on arbitrarily-sized images
174 | in a fully convolutional fashion.
175 |
176 | [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
177 | with the parameter (default=3 as used in [basic] (PatchGAN).)
178 |
179 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
180 | It encourages greater color diversity but has no effect on spatial statistics.
181 |
182 |     The discriminator has been initialized by init_net(). It uses Leaky ReLU for non-linearity.
183 | """
184 |
185 | norm_layer = get_norm_layer(norm_type=norm)
186 |
187 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
188 |
189 | return init_net(net, init_type, init_gain, gpu_ids)
190 |
191 |
192 | ##############################################################################
193 | # Classes
194 | ##############################################################################
195 | class GANLoss(nn.Module):
196 | """Define different GAN objectives.
197 |
198 | The GANLoss class abstracts away the need to create the target label tensor
199 | that has the same size as the input.
200 | """
201 |
202 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
203 | """ Initialize the GANLoss class.
204 |
205 | Parameters:
206 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
207 |             target_real_label (float) - - label for a real image
208 |             target_fake_label (float) - - label for a fake image
209 |
210 | Note: Do not use sigmoid as the last layer of Discriminator.
211 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
212 | """
213 | super(GANLoss, self).__init__()
214 | self.register_buffer('real_label', torch.tensor(target_real_label))
215 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
216 | self.gan_mode = gan_mode
217 | if gan_mode == 'lsgan':
218 | self.loss = nn.MSELoss()
219 | elif gan_mode == 'vanilla':
220 | self.loss = nn.BCEWithLogitsLoss()
221 | elif gan_mode in ['wgangp']:
222 | self.loss = None
223 | else:
224 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
225 |
226 | def get_target_tensor(self, prediction, target_is_real):
227 | """Create label tensors with the same size as the input.
228 |
229 | Parameters:
230 |             prediction (tensor) - - typically the prediction from a discriminator
231 | target_is_real (bool) - - if the ground truth label is for real images or fake images
232 |
233 | Returns:
234 | A label tensor filled with ground truth label, and with the size of the input
235 | """
236 |
237 | if target_is_real:
238 | target_tensor = self.real_label
239 | else:
240 | target_tensor = self.fake_label
241 | return target_tensor.expand_as(prediction)
242 |
243 | def __call__(self, prediction, target_is_real):
244 | """Calculate loss given Discriminator's output and grount truth labels.
245 |
246 | Parameters:
247 |             prediction (tensor) - - typically the prediction output from a discriminator
248 | target_is_real (bool) - - if the ground truth label is for real images or fake images
249 |
250 | Returns:
251 | the calculated loss.
252 | """
253 | if self.gan_mode in ['lsgan', 'vanilla']:
254 | target_tensor = self.get_target_tensor(prediction, target_is_real)
255 | loss = self.loss(prediction, target_tensor)
256 | elif self.gan_mode == 'wgangp':
257 | if target_is_real:
258 | loss = -prediction.mean()
259 | else:
260 | loss = prediction.mean()
261 | return loss
262 |
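# Usage sketch (hypothetical discriminator outputs pred_real / pred_fake of any shape):
#   criterion = GANLoss('lsgan')
#   loss_D = criterion(pred_real, True) + criterion(pred_fake, False)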
263 |
264 | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
265 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
266 |
267 | Arguments:
268 | netD (network) -- discriminator network
269 | real_data (tensor array) -- real images
270 | fake_data (tensor array) -- generated images from the generator
271 | device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
272 | type (str) -- if we mix real and fake data or not [real | fake | mixed].
273 | constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2
274 | lambda_gp (float) -- weight for this loss
275 |
276 | Returns the gradient penalty loss
277 | """
278 | if lambda_gp > 0.0:
279 | if type == 'real': # either use real images, fake images, or a linear interpolation of two.
280 | interpolatesv = real_data
281 | elif type == 'fake':
282 | interpolatesv = fake_data
283 | elif type == 'mixed':
284 | alpha = torch.rand(real_data.shape[0], 1, device=device)
285 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
286 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
287 | else:
288 | raise NotImplementedError('{} not implemented'.format(type))
289 | interpolatesv.requires_grad_(True)
290 | disc_interpolates = netD(interpolatesv)
291 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
292 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),
293 | create_graph=True, retain_graph=True, only_inputs=True)
294 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
295 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
296 | return gradient_penalty, gradients
297 | else:
298 | return 0.0, None
299 |
300 |
301 | class UnetGenerator(nn.Module):
302 | """Create a Unet-based generator"""
303 |
304 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
305 | """Construct a Unet generator
306 | Parameters:
307 | input_nc (int) -- the number of channels in input images
308 | output_nc (int) -- the number of channels in output images
309 |             num_downs (int) -- the number of downsamplings in UNet. For example,
310 |                                if |num_downs| == 7, an image of size 128x128 becomes 1x1 at the bottleneck
311 | ngf (int) -- the number of filters in the last conv layer
312 | norm_layer -- normalization layer
313 |
314 | We construct the U-Net from the innermost layer to the outermost layer.
315 | It is a recursive process.
316 | """
317 | super(UnetGenerator, self).__init__()
318 | # construct unet structure
319 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
320 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
321 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
322 | # gradually reduce the number of filters from ngf * 8 to ngf
323 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
324 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
325 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
326 | self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
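        # With num_downs=8 (the value passed by define_G above), a 256x256 input
        # reaches a 1x1 bottleneck, mirroring the docstring's num_downs=7 example.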
327 |
328 | def forward(self, input):
329 | """Standard forward"""
330 | return self.model(input)
331 |
332 |
333 | class UnetSkipConnectionBlock(nn.Module):
334 | """Defines the Unet submodule with skip connection.
335 | X -------------------identity----------------------
336 | |-- downsampling -- |submodule| -- upsampling --|
337 | """
338 |
339 | def __init__(self, outer_nc, inner_nc, input_nc=None,
340 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
341 | """Construct a Unet submodule with skip connections.
342 |
343 | Parameters:
344 | outer_nc (int) -- the number of filters in the outer conv layer
345 | inner_nc (int) -- the number of filters in the inner conv layer
346 | input_nc (int) -- the number of channels in input images/features
347 | submodule (UnetSkipConnectionBlock) -- previously defined submodules
348 | outermost (bool) -- if this module is the outermost module
349 | innermost (bool) -- if this module is the innermost module
350 | norm_layer -- normalization layer
351 | use_dropout (bool) -- if use dropout layers.
352 | """
353 | super(UnetSkipConnectionBlock, self).__init__()
354 | self.outermost = outermost
355 | if type(norm_layer) == functools.partial:
356 | use_bias = norm_layer.func == nn.InstanceNorm2d
357 | else:
358 | use_bias = norm_layer == nn.InstanceNorm2d
359 | if input_nc is None:
360 | input_nc = outer_nc
361 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
362 | stride=2, padding=1, bias=use_bias)
363 | downrelu = nn.LeakyReLU(0.2, True)
364 | downnorm = norm_layer(inner_nc)
365 | uprelu = nn.ReLU(True)
366 | upnorm = norm_layer(outer_nc)
367 |
368 | if outermost:
369 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
370 | kernel_size=4, stride=2,
371 | padding=1)
372 | down = [downconv]
373 | up = [uprelu, upconv, nn.Tanh()]
374 | model = down + [submodule] + up
375 | elif innermost:
376 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
377 | kernel_size=4, stride=2,
378 | padding=1, bias=use_bias)
379 | down = [downrelu, downconv]
380 | up = [uprelu, upconv, upnorm]
381 | model = down + up
382 | else:
383 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
384 | kernel_size=4, stride=2,
385 | padding=1, bias=use_bias)
386 | down = [downrelu, downconv, downnorm]
387 | up = [uprelu, upconv, upnorm]
388 |
389 | if use_dropout:
390 | model = down + [submodule] + up + [nn.Dropout(0.5)]
391 | else:
392 | model = down + [submodule] + up
393 |
394 | self.model = nn.Sequential(*model)
395 |
396 | def forward(self, x):
397 | if self.outermost:
398 | return self.model(x)
399 | else: # add skip connections
400 | return torch.cat([x, self.model(x)], 1)
401 |
402 |
403 | class NLayerDiscriminator(nn.Module):
404 | """Defines a PatchGAN discriminator"""
405 |
406 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
407 | """Construct a PatchGAN discriminator
408 |
409 | Parameters:
410 | input_nc (int) -- the number of channels in input images
411 | ndf (int) -- the number of filters in the last conv layer
412 | n_layers (int) -- the number of conv layers in the discriminator
413 | norm_layer -- normalization layer
414 | """
415 | super(NLayerDiscriminator, self).__init__()
416 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
417 | use_bias = norm_layer.func == nn.InstanceNorm2d
418 | else:
419 | use_bias = norm_layer == nn.InstanceNorm2d
420 |
421 | kw = 4
422 | padw = 1
423 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
424 | nf_mult = 1
425 | nf_mult_prev = 1
426 | for n in range(1, n_layers): # gradually increase the number of filters
427 | nf_mult_prev = nf_mult
428 | nf_mult = min(2 ** n, 8)
429 | sequence += [
430 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
431 | norm_layer(ndf * nf_mult),
432 | nn.LeakyReLU(0.2, True)
433 | ]
434 |
435 | nf_mult_prev = nf_mult
436 | nf_mult = min(2 ** n_layers, 8)
437 | sequence += [
438 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
439 | norm_layer(ndf * nf_mult),
440 | nn.LeakyReLU(0.2, True)
441 | ]
442 |
443 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
444 | self.model = nn.Sequential(*sequence)
445 |
446 | def forward(self, input):
447 | """Standard forward."""
448 | return self.model(input)
449 |
450 |
451 |
--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numbers
4 | import random
5 | import numpy as np
6 |
7 | from PIL import Image, ImageOps
8 | from scipy.ndimage import gaussian_filter
9 | from matplotlib.pyplot import imshow, imsave
10 | from scipy.ndimage import map_coordinates
11 | import cv2
12 | from scipy import ndimage
13 |
14 |
15 | def to_multilabel(pre_mask, classes=2): # notice classes para
16 | mask = np.zeros((pre_mask.shape[0], pre_mask.shape[1], classes))
17 | mask[pre_mask == 1] = [0, 1]
18 | mask[pre_mask == 2] = [1, 1]
19 | # 3 channel ver.
20 | # mask[pre_mask == 0] = [1, 0, 0]
21 | # mask[pre_mask == 1] = [0, 1, 0]
22 | # mask[pre_mask == 2] = [0, 0, 1]
23 | return mask
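# Channel convention: elsewhere in this repo (e.g. eval() in train_target.py)
# channel 0 is read as the optic-cup map and channel 1 as the optic-disc map;
# label 1 marks disc-only pixels, while label 2 marks cup pixels, which lie
# inside the disc and therefore activate both channels.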
24 |
25 |
26 | class add_salt_pepper_noise():
27 | def __call__(self, sample):
28 |
29 | image = np.array(sample['image']).astype(np.uint8)
30 | X_imgs_copy = image.copy()
31 | # row = image.shape[0]
32 | # col = image.shape[1]
33 | salt_vs_pepper = 0.2
34 | amount = 0.004
35 |
36 | num_salt = np.ceil(amount * X_imgs_copy.size * salt_vs_pepper)
37 | num_pepper = np.ceil(amount * X_imgs_copy.size * (1.0 - salt_vs_pepper))
38 |
39 | seed = random.random()
40 | if seed > 0.75:
41 | # Add Salt noise
42 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in X_imgs_copy.shape]
43 | X_imgs_copy[coords[0], coords[1], :] = 1
44 | elif seed > 0.5:
45 | # Add Pepper noise
46 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in X_imgs_copy.shape]
47 | X_imgs_copy[coords[0], coords[1], :] = 0
48 |
49 | return {'image': X_imgs_copy,
50 | 'label': sample['label'],
51 | 'img_name': sample['img_name']}
52 |
53 |
54 | class add_salt_pepper_noise_BraTS():
55 | def __call__(self, sample):
56 |
57 |
58 | images_0 = np.array(sample['image'][0]).astype(np.uint8)
59 | images_1 = np.array(sample['image'][1]).astype(np.uint8)
60 | images_2 = np.array(sample['image'][2]).astype(np.uint8)
61 | images_3 = np.array(sample['image'][3]).astype(np.uint8)
62 |         X_imgs_copy = np.zeros([images_0.shape[0], images_0.shape[1], 4], dtype=np.uint8)
63 | X_imgs_copy[:, :, 0] = images_0.copy()
64 | X_imgs_copy[:, :, 1] = images_1.copy()
65 | X_imgs_copy[:, :, 2] = images_2.copy()
66 | X_imgs_copy[:, :, 3] = images_3.copy()
67 | # row = image.shape[0]
68 | # col = image.shape[1]
69 | salt_vs_pepper = 0.2
70 | amount = 0.004
71 |
72 | num_salt = np.ceil(amount * X_imgs_copy.size * salt_vs_pepper)
73 | num_pepper = np.ceil(amount * X_imgs_copy.size * (1.0 - salt_vs_pepper))
74 |
75 | seed = random.random()
76 | if seed > 0.75:
77 | # Add Salt noise
78 | coords = [np.random.randint(0, i - 1, int(num_salt)) for i in X_imgs_copy.shape]
79 | X_imgs_copy[coords[0], coords[1], :] = 1
80 | elif seed > 0.5:
81 | # Add Pepper noise
82 | coords = [np.random.randint(0, i - 1, int(num_pepper)) for i in X_imgs_copy.shape]
83 | X_imgs_copy[coords[0], coords[1], :] = 0
84 |
85 | return {'image': X_imgs_copy,
86 | 'label': sample['label'],
87 | 'img_name': sample['img_name']}
88 |
89 |
90 | class adjust_light():
91 | def __call__(self, sample):
92 | image = sample['image']
93 | seed = random.random()
94 | if seed > 0.5:
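            # Random gamma correction: gamma ~ U[0.5, 3.5), applied via a uint8 lookup table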
95 | gamma = random.random() * 3 + 0.5
96 | invGamma = 1.0 / gamma
97 | table = np.array([((i / 255.0) ** invGamma) * 255 for i in np.arange(0, 256)]).astype(np.uint8)
98 | image = cv2.LUT(np.array(image).astype(np.uint8), table).astype(np.uint8)
99 | return {'image': image,
100 | 'label': sample['label'],
101 | 'img_name': sample['img_name']}
102 | else:
103 | return sample
104 |
105 |
106 | class eraser():
107 | def __call__(self, sample, s_l=0.02, s_h=0.06, r_1=0.3, r_2=0.6, v_l=0, v_h=255, pixel_level=False):
108 | image = sample['image']
109 | img_h, img_w, img_c = image.shape
110 |
111 |
112 | if random.random() > 0.5:
113 | return sample
114 |
115 | while True:
116 | s = np.random.uniform(s_l, s_h) * img_h * img_w
117 | r = np.random.uniform(r_1, r_2)
118 | w = int(np.sqrt(s / r))
119 | h = int(np.sqrt(s * r))
120 | left = np.random.randint(0, img_w)
121 | top = np.random.randint(0, img_h)
122 |
123 | if left + w <= img_w and top + h <= img_h:
124 | break
125 |
126 | if pixel_level:
127 | c = np.random.uniform(v_l, v_h, (h, w, img_c))
128 | else:
129 | c = np.random.uniform(v_l, v_h)
130 |
131 | image[top:top + h, left:left + w, :] = c
132 |
133 | return {'image': image,
134 | 'label': sample['label'],
135 | 'img_name': sample['img_name']}
136 |
137 | class elastic_transform():
138 | """Elastic deformation of images as described in [Simard2003]_.
139 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
140 | Convolutional Neural Networks applied to Visual Document Analysis", in
141 | Proc. of the International Conference on Document Analysis and
142 | Recognition, 2003.
143 | """
144 |
145 | # def __init__(self):
146 |
147 | def __call__(self, sample):
148 | image, label = sample['image'], sample['label']
149 | alpha = image.size[1] * 2
150 | sigma = image.size[1] * 0.08
151 | random_state = None
152 | seed = random.random()
153 | if seed > 0.5:
154 | # print(image.size)
155 | assert len(image.size) == 2
156 |
157 | if random_state is None:
158 | random_state = np.random.RandomState(None)
159 |
160 | shape = image.size[0:2]
161 | dx = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
162 | dy = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
163 |
164 | x, y = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
165 | indices = np.reshape(x + dx, (-1, 1)), np.reshape(y + dy, (-1, 1))
166 |
167 | transformed_image = np.zeros([image.size[0], image.size[1], 3])
168 | transformed_label = np.zeros([image.size[0], image.size[1]])
169 |
170 | for i in range(3):
171 | # print(i)
172 | transformed_image[:, :, i] = map_coordinates(np.array(image)[:, :, i], indices, order=1).reshape(shape)
173 | # break
174 | if label is not None:
175 | transformed_label[:, :] = map_coordinates(np.array(label)[:, :], indices, order=1, mode='nearest').reshape(shape)
176 | else:
177 | transformed_label = None
178 | transformed_image = transformed_image.astype(np.uint8)
179 |
180 | if label is not None:
181 | transformed_label = transformed_label.astype(np.uint8)
182 |
183 | return {'image': transformed_image,
184 | 'label': transformed_label,
185 | 'img_name': sample['img_name']}
186 | else:
187 | return {'image': np.array(sample['image']),
188 | 'label': np.array(sample['label']),
189 | 'img_name': sample['img_name']}
190 |
191 |
192 |
193 |
194 | class RandomCrop(object):
195 | def __init__(self, size, padding=0):
196 | if isinstance(size, numbers.Number):
197 | self.size = (int(size), int(size))
198 | else:
199 | self.size = size # h, w
200 | self.padding = padding
201 |
202 | def __call__(self, sample):
203 | img, mask = sample['image'], sample['label']
204 | w, h = img.size
205 | if self.padding > 0 or w < self.size[0] or h < self.size[1]:
206 |             padding = np.maximum(self.padding, np.maximum((self.size[0] - w) // 2 + 5, (self.size[1] - h) // 2 + 5))
207 | img = ImageOps.expand(img, border=padding, fill=0)
208 | mask = ImageOps.expand(mask, border=padding, fill=255)
209 |
210 | assert img.width == mask.width
211 | assert img.height == mask.height
212 | w, h = img.size
213 | th, tw = self.size # target size
214 | if w == tw and h == th:
215 | return {'image': img,
216 | 'label': mask,
217 | 'img_name': sample['img_name']}
218 | x1 = random.randint(0, w - tw)
219 | y1 = random.randint(0, h - th)
220 | img = img.crop((x1, y1, x1 + tw, y1 + th))
221 | mask = mask.crop((x1, y1, x1 + tw, y1 + th))
222 | return {'image': img,
223 | 'label': mask,
224 | 'img_name': sample['img_name']}
225 |
226 |
227 | class RandomCrop_BraTS(object):
228 | def __init__(self, size, padding=0):
229 | if isinstance(size, numbers.Number):
230 | self.size = (int(size), int(size))
231 | else:
232 | self.size = size # h, w
233 | self.padding = padding
234 |
235 | def __call__(self, sample):
236 | imgs, mask = sample['image'], sample['label']
237 | w, h = imgs[0].size
238 | if self.padding > 0 or w < self.size[0] or h < self.size[1]:
239 |             padding = np.maximum(self.padding, np.maximum((self.size[0] - w) // 2 + 5, (self.size[1] - h) // 2 + 5))
240 | imgs[0] = ImageOps.expand(imgs[0], border=padding, fill=0)
241 | imgs[1] = ImageOps.expand(imgs[1], border=padding, fill=0)
242 | imgs[2] = ImageOps.expand(imgs[2], border=padding, fill=0)
243 | imgs[3] = ImageOps.expand(imgs[3], border=padding, fill=0)
244 | mask = ImageOps.expand(mask, border=padding, fill=255)
245 |
246 | assert imgs[0].width == mask.width
247 | assert imgs[0].height == mask.height
248 | w, h = imgs[0].size
249 | th, tw = self.size # target size
250 | if w == tw and h == th:
251 | return {'image': imgs,
252 | 'label': mask,
253 | 'img_name': sample['img_name']}
254 | x1 = random.randint(0, w - tw)
255 | y1 = random.randint(0, h - th)
256 |         imgs = [im.crop((x1, y1, x1 + tw, y1 + th)) for im in imgs]  # identical crop window for every modality
260 | mask = mask.crop((x1, y1, x1 + tw, y1 + th))
261 | return {'image': imgs,
262 | 'label': mask,
263 | 'img_name': sample['img_name']}
264 |
265 |
266 | class CenterCrop(object):
267 | def __init__(self, size):
268 | if isinstance(size, numbers.Number):
269 | self.size = (int(size), int(size))
270 | else:
271 | self.size = size
272 |
273 | def __call__(self, sample):
274 | img = sample['image']
275 | mask = sample['label']
276 |
277 | w, h = img.size
278 | th, tw = self.size
279 | x1 = int(round((w - tw) / 2.))
280 | y1 = int(round((h - th) / 2.))
281 | img = img.crop((x1, y1, x1 + tw, y1 + th))
282 | mask = mask.crop((x1, y1, x1 + tw, y1 + th))
283 |
284 | return {'image': img,
285 | 'label': mask,
286 | 'img_name': sample['img_name']}
287 |
288 |
289 | class RandomFlip(object):
290 | def __call__(self, sample):
291 | img = sample['image']
292 | mask = sample['label']
293 | name = sample['img_name']
294 | if random.random() < 0.5:
295 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
296 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
297 | if random.random() < 0.5:
298 | img = img.transpose(Image.FLIP_TOP_BOTTOM)
299 | mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
300 |
301 | return {'image': img,
302 | 'label': mask,
303 | 'img_name': name
304 | }
305 |
306 | class RandomFlip_BraTS(object):
307 | def __call__(self, sample):
308 | imgs = sample['image']
309 | mask = sample['label']
310 | name = sample['img_name']
311 | if random.random() < 0.5:
312 |             imgs = [im.transpose(Image.FLIP_LEFT_RIGHT) for im in imgs]
316 |             mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
317 | if random.random() < 0.5:
318 |             imgs = [im.transpose(Image.FLIP_TOP_BOTTOM) for im in imgs]
322 |             mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
323 |
324 | return {'image': imgs,
325 | 'label': mask,
326 | 'img_name': name
327 | }
328 |
329 |
330 | class FixedResize(object):
331 | def __init__(self, size):
332 | self.size = tuple(reversed(size)) # size: (h, w)
333 |
334 | def __call__(self, sample):
335 | img = sample['image']
336 | mask = sample['label']
337 | name = sample['img_name']
338 |
339 | assert img.width == mask.width
340 | assert img.height == mask.height
341 | img = img.resize(self.size, Image.BILINEAR)
342 | mask = mask.resize(self.size, Image.NEAREST)
343 |
344 | return {'image': img,
345 | 'label': mask,
346 | 'img_name': name}
347 |
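# Note (added): throughout this file images are resampled with Image.BILINEAR
# while masks use Image.NEAREST -- bilinear smoothing of a label map would blend
# discrete values (e.g. 0/128/255) into intermediate numbers that belong to no class.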
348 |
349 | class Scale(object):
350 | def __init__(self, size):
351 | if isinstance(size, numbers.Number):
352 | self.size = (int(size), int(size))
353 | else:
354 | self.size = size
355 |
356 | def __call__(self, sample):
357 | img = sample['image']
358 | mask = sample['label']
359 | assert img.width == mask.width
360 | assert img.height == mask.height
361 | w, h = img.size
362 |
363 | if (w >= h and w == self.size[1]) or (h >= w and h == self.size[0]):
364 | return {'image': img,
365 | 'label': mask,
366 | 'img_name': sample['img_name']}
367 | oh, ow = self.size
368 | img = img.resize((ow, oh), Image.BILINEAR)
369 | mask = mask.resize((ow, oh), Image.NEAREST)
370 |
371 | return {'image': img,
372 | 'label': mask,
373 | 'img_name': sample['img_name']}
374 |
375 |
376 | class RandomSizedCrop(object):
377 | def __init__(self, size):
378 | self.size = size
379 |
380 | def __call__(self, sample):
381 | img = sample['image']
382 | mask = sample['label']
383 | name = sample['img_name']
384 | assert img.width == mask.width
385 | assert img.height == mask.height
386 | for attempt in range(10):
387 | area = img.size[0] * img.size[1]
388 | target_area = random.uniform(0.45, 1.0) * area
389 | aspect_ratio = random.uniform(0.5, 2)
390 |
391 | w = int(round(math.sqrt(target_area * aspect_ratio)))
392 | h = int(round(math.sqrt(target_area / aspect_ratio)))
393 |
394 | if random.random() < 0.5:
395 | w, h = h, w
396 |
397 | if w <= img.size[0] and h <= img.size[1]:
398 | x1 = random.randint(0, img.size[0] - w)
399 | y1 = random.randint(0, img.size[1] - h)
400 |
401 | img = img.crop((x1, y1, x1 + w, y1 + h))
402 | mask = mask.crop((x1, y1, x1 + w, y1 + h))
403 | assert (img.size == (w, h))
404 |
405 | img = img.resize((self.size, self.size), Image.BILINEAR)
406 | mask = mask.resize((self.size, self.size), Image.NEAREST)
407 |
408 | return {'image': img,
409 | 'label': mask,
410 | 'img_name': name}
411 |
412 | # Fallback
413 | scale = Scale(self.size)
414 | crop = CenterCrop(self.size)
415 | sample = crop(scale(sample))
416 | return sample
417 |
418 |
419 | class RandomRotate(object):
420 |     def __init__(self, size=512):
421 |         self.size = size
422 |
424 |     def __call__(self, sample):
425 |         img = sample['image']
426 |         mask = sample['label']
427 |
428 |         if random.random() > 0.5:
429 |             rotate_degree = random.randint(1, 4) * 90  # draw a fresh angle for every sample, not once per run
430 |             img = img.rotate(rotate_degree, Image.BILINEAR, expand=0)
431 |             mask = mask.rotate(rotate_degree, Image.NEAREST, expand=0)  # expand is a boolean flag; square inputs rotated by 90-degree multiples never need a larger canvas
433 |
434 | sample = {'image': img, 'label': mask, 'img_name': sample['img_name']}
435 | return sample
436 |
437 |
438 | class RandomRotate_BraTS(object):
439 |     def __init__(self, size=512):
440 |         self.size = size
441 |
443 |     def __call__(self, sample):
444 |         imgs = sample['image']
445 |         mask = sample['label']
446 |
447 |         if random.random() > 0.5:
448 |             rotate_degree = random.randint(1, 4) * 90  # per-call angle, matching RandomRotate above
450 |             imgs = [im.rotate(rotate_degree, Image.BILINEAR, expand=0) for im in imgs]
454 |             mask = mask.rotate(rotate_degree, Image.NEAREST, expand=0)
455 |
456 | sample = {'image': imgs, 'label': mask, 'img_name': sample['img_name']}
457 | return sample
458 |
459 |
460 | class RandomScaleCrop(object):
461 | def __init__(self, size):
462 | self.size = size
463 | self.crop = RandomCrop(self.size)
464 |
465 | def __call__(self, sample):
466 | img = sample['image']
467 | mask = sample['label']
468 | name = sample['img_name']
470 | assert img.width == mask.width
471 | assert img.height == mask.height
472 |
473 | seed = random.random()
474 | if seed > 0.5:
475 | w = int(random.uniform(0.5, 1.5) * img.size[0])
476 | h = int(random.uniform(0.5, 1.5) * img.size[1])
477 |
478 | img, mask = img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)
479 | sample = {'image': img, 'label': mask, 'img_name': name}
480 |
481 | return self.crop(sample)
482 |
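# --- Usage sketch (added; illustrative): RandomScaleCrop rescales by a random
# factor in [0.5, 1.5] with probability 0.5 and then delegates to RandomCrop,
# which pads whenever the rescaled image became smaller than the crop window,
# so the output size is always `size` x `size`:
#
#     aug = RandomScaleCrop(512)
#     out = aug({'image': Image.new('RGB', (512, 512)),
#                'label': Image.new('L', (512, 512)),
#                'img_name': 'demo.png'})
#     assert out['image'].size == (512, 512)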
483 |
484 | class RandomScaleCrop_BraTS(object):
485 | def __init__(self, size):
486 | self.size = size
487 | self.crop = RandomCrop_BraTS(self.size)
488 |
489 | def __call__(self, sample):
490 | imgs = sample['image']
491 | mask = sample['label']
492 | name = sample['img_name']
494 | assert imgs[0].width == mask.width
495 | assert imgs[0].height == mask.height
496 |
497 | seed = random.random()
498 | if seed > 0.5:
499 | w = int(random.uniform(0.5, 1.5) * imgs[0].size[0])
500 | h = int(random.uniform(0.5, 1.5) * imgs[0].size[1])
501 |
502 |             imgs = [im.resize((w, h), Image.BILINEAR) for im in imgs]
506 |             mask = mask.resize((w, h), Image.NEAREST)
507 | sample = {'image': imgs, 'label': mask, 'img_name': name}
508 |
509 | return self.crop(sample)
510 |
511 |
512 | class ResizeImg(object):
513 | def __init__(self, size):
514 | self.size = size
515 |
516 | def __call__(self, sample):
517 | img = sample['image']
518 | mask = sample['label']
519 | name = sample['img_name']
520 | assert img.width == mask.width
521 | assert img.height == mask.height
522 |
523 |         img = img.resize((self.size, self.size))  # only the image is resized; the mask keeps its original resolution
524 |
525 | sample = {'image': img, 'label': mask, 'img_name': name}
526 | return sample
527 |
528 |
529 | class Resize(object):
530 | def __init__(self, size):
531 | self.size = size
532 |
533 | def __call__(self, sample):
534 | img = sample['image']
535 | mask = sample['label']
536 | name = sample['img_name']
537 | assert img.width == mask.width
538 | assert img.height == mask.height
539 |
540 | img = img.resize((self.size, self.size), Image.BILINEAR)
541 | mask = mask.resize((self.size, self.size), Image.NEAREST)
542 |
543 | sample = {'image': img, 'label': mask, 'img_name': name}
544 | return sample
545 |
546 |
547 | class Resize_BraTS(object):
548 | def __init__(self, size):
549 | self.size = size
550 |
551 | def __call__(self, sample):
552 | imgs = sample['image']
553 | mask = sample['label']
554 | name = sample['img_name']
555 | assert imgs[0].width == mask.width
556 | assert imgs[0].height == mask.height
557 |
558 |         imgs = [im.resize((self.size, self.size), Image.BILINEAR) for im in imgs]
562 |         mask = mask.resize((self.size, self.size), Image.NEAREST)
563 |
564 | sample = {'image': imgs, 'label': mask, 'img_name': name}
565 | return sample
566 |
567 |
568 | class Normalize(object):
569 | """Normalize a tensor image with mean and standard deviation.
570 | Args:
571 | mean (tuple): means for each channel.
572 | std (tuple): standard deviations for each channel.
573 | """
574 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
575 | self.mean = mean
576 | self.std = std
577 |
578 | def __call__(self, sample):
579 | img = np.array(sample['image']).astype(np.float32)
580 | mask = np.array(sample['label']).astype(np.float32)
581 | img /= 255.0
582 | img -= self.mean
583 | img /= self.std
584 |
585 | return {'image': img,
586 | 'label': mask,
587 | 'img_name': sample['img_name']}
588 |
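# --- Worked example (added): with the defaults above (mean 0, std 1) Normalize
# only rescales to [0, 1]. With ImageNet statistics -- an assumption for
# illustration, not this repo's setting -- a red-channel value of 127 maps as:
#
#     127 / 255.0 = 0.498
#     (0.498 - 0.485) / 0.229 ~= 0.057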
589 |
590 | class GetBoundary(object):
591 |     def __init__(self, width=5):
592 |         self.width = width
593 |     def __call__(self, mask):
594 |         cup = mask[:, :, 0]
595 |         disc = mask[:, :, 1]
596 |         dila_cup = ndimage.binary_dilation(cup, iterations=self.width).astype(cup.dtype)
597 |         eros_cup = ndimage.binary_erosion(cup, iterations=self.width).astype(cup.dtype)
598 |         dila_disc = ndimage.binary_dilation(disc, iterations=self.width).astype(disc.dtype)
599 |         eros_disc = ndimage.binary_erosion(disc, iterations=self.width).astype(disc.dtype)
600 |         cup = dila_cup + eros_cup
601 |         disc = dila_disc + eros_disc
602 |         cup[cup == 2] = 0
603 |         disc[disc == 2] = 0
604 | boundary = (cup + disc) > 0
605 | return boundary.astype(np.uint8)
606 |
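# --- Illustration (added): in GetBoundary, dilation + erosion equals 2 strictly
# inside a structure and 1 in the band where the two disagree; zeroing the 2s
# leaves a boundary ring of roughly `width` pixels. Minimal sketch:
#
#     m = np.zeros((9, 9), dtype=np.uint8)
#     m[3:6, 3:6] = 1
#     band = (ndimage.binary_dilation(m).astype(np.uint8)
#             + ndimage.binary_erosion(m).astype(np.uint8))
#     band[band == 2] = 0          # band == 1 marks the contour neighbourhood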
607 |
608 | class Normalize_tf(object):
609 | """Normalize a tensor image with mean and standard deviation.
610 | Args:
611 | mean (tuple): means for each channel.
612 | std (tuple): standard deviations for each channel.
613 | """
614 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
615 | self.mean = mean
616 | self.std = std
617 | self.get_boundary = GetBoundary()
618 |
619 | def __call__(self, sample):
620 | img = np.array(sample['image']).astype(np.float32)
621 | __mask = np.array(sample['label']).astype(np.uint8)
622 | name = sample['img_name']
623 | img /= 127.5
624 | img -= 1.0
625 | _mask = np.zeros([__mask.shape[0], __mask.shape[1]])
626 | _mask[__mask > 200] = 255
627 | _mask[(__mask > 50) & (__mask < 201)] = 128
628 |
629 | __mask[_mask == 0] = 2
630 | __mask[_mask == 255] = 0
631 | __mask[_mask == 128] = 1
632 |
633 | mask = to_multilabel(__mask)
634 | # boundary = self.get_boundary(mask) * 255
635 | # boundary = ndimage.gaussian_filter(boundary, sigma=3) / 255.0
636 | # boundary = np.expand_dims(boundary, -1)
637 |
638 | return {'image': img,
639 | 'label': mask,
640 | # 'boundary': boundary,
641 | 'img_name': name
642 | }
643 |
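# --- Note (added): Normalize_tf maps the mask's gray levels to class indices
# before the multi-channel expansion, with thresholds at 50/200 absorbing
# interpolation noise:
#
#     gray value    ~255    ~128    ~0
#     class index     0       1      2
#
# (to_multilabel, used above, then turns the index map into per-class channels.)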
644 |
645 | class Normalize_cityscapes(object):
646 | """Normalize a tensor image with mean and standard deviation.
647 | Args:
648 | mean (tuple): means for each channel.
649 | std (tuple): standard deviations for each channel.
650 | """
651 | def __init__(self, mean=(0., 0., 0.)):
652 | self.mean = mean
653 |
654 | def __call__(self, sample):
655 | img = np.array(sample['image']).astype(np.float32)
656 | mask = np.array(sample['label']).astype(np.float32)
657 | img -= self.mean
658 | img /= 255.0
659 |
660 | return {'image': img,
661 | 'label': mask,
662 | 'img_name': sample['img_name']}
663 |
664 |
665 | class Normalize_CMR(object):
666 | """
667 |     Scale image intensities from [0, 255] to [0, 1]; the mask is passed through unchanged.
668 | """
669 | def __call__(self, sample):
670 | img = np.array(sample['image']).astype(np.float32)
671 | mask = np.array(sample['label']).astype(np.float32)
672 | img /= 255.0
673 |
674 | return {'image': img,
675 | 'label': mask,
676 | 'img_name': sample['img_name']}
677 |
678 |
679 | class Normalize_BraTS(object):
680 | """
681 |     Scale image intensities to [0, 1]; list inputs (the four BraTS modalities) are first stacked into one H x W x 4 array.
682 | """
683 | def __call__(self, sample):
684 |         if isinstance(sample['image'], list):
685 |             # stack the four modality images into a single H x W x 4 array
686 |             img = np.stack([np.array(im, dtype=np.float32) for im in sample['image']], axis=-1)
694 | else:
695 | img = np.array(sample['image']).astype(np.float32)
696 | img /= 255.0
697 |
698 | return {'image': img,
699 | 'label': sample['label'],
700 | 'img_name': sample['img_name']}
701 |
702 |
703 | class ToTensor(object):
704 | """Convert ndarrays in sample to Tensors."""
705 |
706 | def __call__(self, sample):
707 | # swap color axis because
708 | # numpy image: H x W x C
709 | # torch image: C X H X W
710 |         img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
711 |         label = np.array(sample['label']).astype(np.uint8).transpose((2, 0, 1))
712 |         # boundary = np.array(sample['boundary']).astype(np.float32).transpose((2, 0, 1))
713 |         name = sample['img_name']
714 |         img = torch.from_numpy(img).float()
715 |         label = torch.from_numpy(label).float()
716 |         # boundary = torch.from_numpy(boundary).float()
717 |
718 |         return {'image': img,
719 |                 'label': label,
720 |                 # 'boundary': boundary,
721 |                 'img_name': name}
722 |
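# --- Illustration (added): the transpose swaps numpy's H x W x C layout to the
# C x H x W layout torch expects:
#
#     a = np.zeros((512, 512, 3), dtype=np.float32)
#     t = torch.from_numpy(a.transpose((2, 0, 1)))   # t.shape == (3, 512, 512)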
723 |
724 | class LabelOneHotAndToTensor(object):
725 | """
726 |     Turn gray-level (0, 80, 160, 240) labels into one-hot labels.
727 | Convert PIL.Image in sample to Tensors.
728 | """
729 | def __call__(self, sample):
730 | img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
731 | img = torch.from_numpy(img).type(dtype=torch.FloatTensor)
732 |
733 |         label_idx = np.array(sample['label'])
734 |         label_idx = (label_idx == 80) * 1 + (label_idx == 160) * 2 + (label_idx == 240) * 3
735 |         label_idx = torch.from_numpy(label_idx).type(dtype=torch.LongTensor)
736 |         label_onehot = torch.FloatTensor(4, label_idx.size()[0], label_idx.size()[1])
737 |         label_onehot.zero_()
738 |         label_onehot.scatter_(0, label_idx.unsqueeze(dim=0), 1)
739 |
740 | name = sample['img_name']
741 |
742 | return {'image': img,
743 | 'label': label_onehot,
744 | 'img_name': name}
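# --- Usage sketch (added; illustrative only -- the actual pipelines are built in
# this repo's dataloaders and training scripts): every transform here reads and
# returns the shared sample dict, so they compose directly with torchvision:
#
#     from torchvision import transforms
#     from PIL import Image
#
#     pipeline = transforms.Compose([Resize(512), RandomFlip(), Normalize_tf(), ToTensor()])
#     out = pipeline({'image': Image.new('RGB', (600, 600)),
#                     'label': Image.new('L', (600, 600), color=255),
#                     'img_name': 'demo.png'})
#     # out['image']: FloatTensor of shape (3, 512, 512), intensities in [-1, 1]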
--------------------------------------------------------------------------------