├── .gitignore ├── README.md ├── monte_carlo_dropout ├── __init__.py ├── alexnet.py ├── mc_dropout.py ├── model_meta.py ├── resnet.py └── unet_learner.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | **/*.rs.bk 3 | Cargo.lock 4 | .DS_Stores 5 | .idea/ 6 | .cache/ 7 | .vscode/ 8 | *.pyc 9 | __pycache__/ 10 | .python-version 11 | /chariots.egg-info/ 12 | /.tox/ 13 | /.eggs/ 14 | .pytest_cache/ 15 | .ipynb_checkpoints/ 16 | .DS_Store 17 | node_modules 18 | /dist 19 | 20 | # local env files 21 | .env.local 22 | .env.*.local 23 | 24 | # Log files 25 | npm-debug.log* 26 | yarn-debug.log* 27 | yarn-error.log* 28 | 29 | # Editor directories and files 30 | .idea 31 | .vscode 32 | *.suo 33 | *.ntvs* 34 | *.njsproj 35 | *.sln 36 | *.sw? 37 | .editorconfig 38 | yarn.lock 39 | 40 | # doc 41 | docs/_build/* 42 | 43 | # notebooks 44 | *.ipynb 45 | !project_template/{{cookiecutter.project_name}}/notebooks/*.ipynb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Monte_carlo_dropout 2 | Using Monte Carlo dropout to estimate prediction uncertainty. 3 | 4 | ### Installation 5 | ``` 6 | cd monte_carlo_dropout 7 | pip install -e ./ 8 | ``` 9 | 10 | ### Usage 11 | 12 | Calling the `unet_learner` function gives you the modified U-Net with dropout. 13 | Wrapping an AlexNet in the `DropOutAlexnet` class gives you the AlexNet architecture with dropout added. 
14 | 15 | ### credits: 16 | __Fastai online resources:__ 17 | - Complete code for image segmentation with the One Hundred Layers Tiramisu (FC-DenseNet model) https://github.com/fastai/course-v3/blob/master/nbs/dl1/lesson3-camvid-tiramisu.ipynb 18 | - Research papers used 19 | - Dropout as a Bayesian Approximation: Representing Model: https://arxiv.org/abs/1506.02142 20 | - Bayesian Convolutional Neural Networks with Bernoulli Approximate Variational Inference: https://arxiv.org/abs/1506.02158 21 | - Nature paper : https://www.nature.com/articles/s41598-019-50587-1?fbclid=IwAR3vS2Jsa16NtOdFgp-I_deIwT8ipsK0isY6oIzBeaPHjOllhDSv1FfAVGg 22 | 23 | 24 | -------------------------------------------------------------------------------- /monte_carlo_dropout/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aredier/monte_carlo_dropout/5271f0b45c855958a243e997c1d6860dc6ee874e/monte_carlo_dropout/__init__.py -------------------------------------------------------------------------------- /monte_carlo_dropout/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['AlexNet', 'alexnet'] 6 | 7 | from monte_carlo_dropout.mc_dropout import MCDropout 8 | 9 | model_urls = { 10 | 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 11 | } 12 | 13 | 14 | class AlexNet(nn.Module): 15 | 16 | def __init__(self, num_classes=1000): 17 | super(AlexNet, self).__init__() 18 | self.features = nn.Sequential( 19 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 20 | nn.ReLU(inplace=True), 21 | nn.MaxPool2d(kernel_size=3, stride=2), 22 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 23 | nn.ReLU(inplace=True), 24 | nn.MaxPool2d(kernel_size=3, stride=2), 25 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 28 | 
class DropOutAlexnet(nn.Module):
    """AlexNet with Monte Carlo dropout inserted between its layers.

    The convolution, pooling and linear layers are taken (not copied) from
    ``og_alex``, so both models share the same parameters. ``MCDropout``
    layers are inserted after the first two pooling layers and after the
    third and fourth activations of the feature extractor, and they replace
    the two ``nn.Dropout`` layers of the original classifier.

    Args:
        og_alex: source network; the children of its ``features`` and
            ``classifier`` containers are picked by position, so it must
            follow the layout of ``AlexNet`` above.
        conv_dropout: dropout probability used in the feature extractor.
        dense_dropout: dropout probability used in the classifier.
        force_dropout: when true, the dropout layers stay active even when
            the module is in eval mode.
    """

    def __init__(self, og_alex: AlexNet, conv_dropout: float, dense_dropout: float, force_dropout: bool):

        super().__init__()
        self.conv_dropout = conv_dropout
        self.dense_dropout = dense_dropout
        self.force_dropout = force_dropout

        # features
        feature_layers = list(og_alex.features.children())
        self.conv_1 = feature_layers[0]
        self.relu_1 = feature_layers[1]
        self.pool_1 = feature_layers[2]
        self.dropout_1 = MCDropout(conv_dropout, force_dropout=force_dropout)
        self.conv_2 = feature_layers[3]
        self.relu_2 = feature_layers[4]
        self.pool_2 = feature_layers[5]
        self.dropout_2 = MCDropout(conv_dropout, force_dropout=force_dropout)
        self.conv_3 = feature_layers[6]
        self.relu_3 = feature_layers[7]
        self.dropout_3 = MCDropout(conv_dropout, force_dropout=force_dropout)
        self.conv_4 = feature_layers[8]
        self.relu_4 = feature_layers[9]
        self.dropout_4 = MCDropout(conv_dropout, force_dropout=force_dropout)
        self.conv_5 = feature_layers[10]
        self.relu_5 = feature_layers[11]
        self.pool_3 = feature_layers[12]

        self.avgpool = og_alex.avgpool

        # classifier: indices 0 and 3 are the original nn.Dropout layers,
        # which are replaced by the MCDropout layers below.
        classifier_layers = list(og_alex.classifier.children())
        self.dropout_5 = MCDropout(dense_dropout, force_dropout=force_dropout)
        self.dense_1 = classifier_layers[1]
        self.relu_6 = classifier_layers[2]
        self.dropout_6 = MCDropout(dense_dropout, force_dropout=force_dropout)
        self.dense_2 = classifier_layers[4]
        self.relu_7 = classifier_layers[5]
        self.dense_3 = classifier_layers[6]

    def set_force_dropout(self, force_dropout: bool):
        """Switch forced dropout on or off for every ``MCDropout`` layer."""
        self.force_dropout = force_dropout
        for module in self.modules():
            if isinstance(module, MCDropout):
                module.force_dropout = force_dropout

    def forward(self, x):
        """Run the feature extractor, pooling and classifier on ``x``."""
        out = self.conv_1(x)
        out = self.relu_1(out)
        out = self.pool_1(out)
        out = self.dropout_1(out)
        out = self.conv_2(out)
        out = self.relu_2(out)
        out = self.pool_2(out)
        out = self.dropout_2(out)
        out = self.conv_3(out)
        out = self.relu_3(out)
        out = self.dropout_3(out)
        out = self.conv_4(out)
        out = self.relu_4(out)
        out = self.dropout_4(out)
        out = self.conv_5(out)
        out = self.relu_5(out)
        out = self.pool_3(out)

        out = self.avgpool(out)
        # The linear layers expect one feature vector per sample, exactly as
        # in AlexNet.forward; without this the first nn.Linear receives a
        # 4-D tensor and fails on the shape mismatch.
        out = torch.flatten(out, 1)

        out = self.dropout_5(out)
        out = self.dense_1(out)
        out = self.relu_6(out)
        out = self.dropout_6(out)
        out = self.dense_2(out)
        out = self.relu_7(out)
        out = self.dense_3(out)

        return out
_vgg_split(m:nn.Module): return (m[0][0][22],m[1]) 11 | def _alexnet_split(m:nn.Module): return (m[0][0][6],m[1]) 12 | 13 | _default_meta = {'cut':None, 'split':_default_split} 14 | _resnet_meta = {'cut':-2, 'split':_resnet_split } 15 | _squeezenet_meta = {'cut':-1, 'split': _squeezenet_split} 16 | _densenet_meta = {'cut':-1, 'split':_densenet_split} 17 | _vgg_meta = {'cut':-1, 'split':_vgg_split} 18 | _alexnet_meta = {'cut':-1, 'split':_alexnet_split} 19 | 20 | model_meta = { 21 | models.resnet18 :{**_resnet_meta}, models.resnet34: {**_resnet_meta}, 22 | models.resnet50 :{**_resnet_meta}, models.resnet101:{**_resnet_meta}, 23 | models.resnet152:{**_resnet_meta}, 24 | 25 | models.squeezenet1_0:{**_squeezenet_meta}, 26 | models.squeezenet1_1:{**_squeezenet_meta}, 27 | 28 | models.densenet121:{**_densenet_meta}, models.densenet169:{**_densenet_meta}, 29 | models.densenet201:{**_densenet_meta}, models.densenet161:{**_densenet_meta}, 30 | models.vgg11_bn:{**_vgg_meta}, models.vgg13_bn:{**_vgg_meta}, models.vgg16_bn:{**_vgg_meta}, models.vgg19_bn:{**_vgg_meta}, 31 | models.alexnet:{**_alexnet_meta}} -------------------------------------------------------------------------------- /monte_carlo_dropout/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models.utils import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class DropoutBlock(nn.Module):
    """
    Same as a basic block but with dropout added.

    The convolution, batch-norm, activation and downsample modules are taken
    from ``basic_block`` (not copied), so the wrapped block and this block
    share their parameters. Dropout is applied after the first activation and
    again to the residual branch just before the identity branch is added.

    Dropout is active while the module is in training mode or while the
    ``force_dropout`` attribute is true, which allows stochastic forward
    passes in eval mode.

    Args:
        basic_block: two-convolution residual block to wrap.
        dropout_rate: probability of zeroing an activation.

    Raises:
        TypeError: if ``basic_block`` has a ``conv3`` attribute (as
            ``Bottleneck`` does). Only ``conv1``/``conv2`` are reused here,
            so such a block would lose its last convolution and fail later
            in ``forward`` with a channel mismatch.
    """

    def __init__(self, basic_block: BasicBlock, dropout_rate: float = 0.):
        super(DropoutBlock, self).__init__()
        if hasattr(basic_block, 'conv3'):
            raise TypeError(
                'DropoutBlock only supports two-convolution BasicBlock-style blocks, '
                'got {} which has a conv3 layer'.format(type(basic_block).__name__)
            )
        self.conv1 = basic_block.conv1
        self.bn1 = basic_block.bn1
        self.relu = basic_block.relu
        self.conv2 = basic_block.conv2
        self.bn2 = basic_block.bn2
        self.downsample = basic_block.downsample
        self.stride = basic_block.stride
        self.force_dropout = False
        self.dropout_rate = dropout_rate

    def _dropout(self, tensor):
        """Apply dropout when training or when dropout is forced."""
        return torch.nn.functional.dropout(
            tensor, p=self.dropout_rate, training=self.training or self.force_dropout
        )

    def forward(self, x):
        """Run the residual block with dropout on the residual branch."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self._dropout(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self._dropout(out)

        out += identity
        out = self.relu(out)

        return out
def forward(self, x):
    """Apply the three conv/batch-norm stages and add the shortcut branch.

    The shortcut is ``x`` itself, or ``self.downsample(x)`` when a
    downsample module is set. A final activation follows the addition.
    """
    residual = self.relu(self.bn1(self.conv1(x)))
    residual = self.relu(self.bn2(self.conv2(residual)))
    residual = self.bn3(self.conv3(residual))

    shortcut = x if self.downsample is None else self.downsample(x)

    residual += shortcut
    return self.relu(residual)
188 | dilate=replace_stride_with_dilation[1]) 189 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 190 | dilate=replace_stride_with_dilation[2]) 191 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 192 | self.fc = nn.Linear(512 * block.expansion, num_classes) 193 | 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 197 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 198 | nn.init.constant_(m.weight, 1) 199 | nn.init.constant_(m.bias, 0) 200 | 201 | # Zero-initialize the last BN in each residual branch, 202 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 203 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 204 | if zero_init_residual: 205 | for m in self.modules(): 206 | if isinstance(m, Bottleneck): 207 | nn.init.constant_(m.bn3.weight, 0) 208 | elif isinstance(m, BasicBlock): 209 | nn.init.constant_(m.bn2.weight, 0) 210 | 211 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 212 | norm_layer = self._norm_layer 213 | downsample = None 214 | previous_dilation = self.dilation 215 | if dilate: 216 | self.dilation *= stride 217 | stride = 1 218 | if stride != 1 or self.inplanes != planes * block.expansion: 219 | downsample = nn.Sequential( 220 | conv1x1(self.inplanes, planes * block.expansion, stride), 221 | norm_layer(planes * block.expansion), 222 | ) 223 | 224 | layers = [] 225 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 226 | self.base_width, previous_dilation, norm_layer)) 227 | self.inplanes = planes * block.expansion 228 | for _ in range(1, blocks): 229 | layers.append(block(self.inplanes, planes, groups=self.groups, 230 | base_width=self.base_width, dilation=self.dilation, 231 | norm_layer=norm_layer)) 232 | 233 | return nn.Sequential(*layers) 234 | 235 | def _forward(self, x): 236 | x = self.conv1(x) 237 | 
class DropoutResnet(nn.Module):
    """Adds dropout to an existing resnet.

    The stem (``conv1``, ``bn1``, ``relu``, ``maxpool``), the average pooling
    and the final linear layer are taken (not copied) from ``source_resnet``.
    Every block of its four stages is wrapped in a ``DropoutBlock``, which
    reuses the block's layers as well, so both networks share parameters.

    Args:
        source_resnet: network whose layers are reused.
        dropout_rate: dropout probability handed to every ``DropoutBlock``.
    """

    def __init__(self, source_resnet: ResNet, dropout_rate: float = .0):

        super(DropoutResnet, self).__init__()
        self._norm_layer = source_resnet._norm_layer

        self.inplanes = source_resnet.inplanes
        self.dilation = source_resnet.dilation
        self.groups = source_resnet.groups
        self.base_width = source_resnet.base_width
        self.conv1 = source_resnet.conv1
        self.bn1 = source_resnet.bn1
        self.relu = source_resnet.relu
        # Must be the source max-pool, not its ReLU: otherwise the stem loses
        # its pooling step and the stages see a larger feature map than the
        # source network gives them.
        self.maxpool = source_resnet.maxpool
        self.layer1 = self._make_layer(source_resnet.layer1, dropout_rate)
        self.layer2 = self._make_layer(source_resnet.layer2, dropout_rate)
        self.layer3 = self._make_layer(source_resnet.layer3, dropout_rate)
        self.layer4 = self._make_layer(source_resnet.layer4, dropout_rate)
        self.avgpool = source_resnet.avgpool
        self.fc = source_resnet.fc

    @staticmethod
    def _set_force_dropout_on_layer(force_dropout: bool, layer: nn.Sequential):
        """Set ``force_dropout`` on every block of one stage."""
        for block in layer.children():
            block.force_dropout = force_dropout

    def set_force_dropout(self, force_dropout):
        """Switch forced dropout on or off for the blocks of all four stages."""
        self._set_force_dropout_on_layer(force_dropout, self.layer1)
        self._set_force_dropout_on_layer(force_dropout, self.layer2)
        self._set_force_dropout_on_layer(force_dropout, self.layer3)
        self._set_force_dropout_on_layer(force_dropout, self.layer4)

    def _make_layer(self, source_layer: nn.Sequential, dropout_rate):
        """Wrap every block of ``source_layer`` in a ``DropoutBlock``."""
        return nn.Sequential(*[DropoutBlock(block, dropout_rate) for block in source_layer.children()])

    def _forward(self, x):
        """Run stem, four stages, pooling and the final linear layer."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    # Allow for accessing forward method in a inherited class
    forward = _forward
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Always builds with ``groups=32`` and ``width_per_group=4``; values given
    for those two keys in ``kwargs`` are overridden.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs.update(groups=32, width_per_group=4)
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('resnext50_32x4d', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    standard_width = 64
    # Doubling the per-group width is what makes the bottleneck "wide";
    # a caller-supplied width_per_group is overridden.
    kwargs['width_per_group'] = standard_width * 2
    blocks_per_stage = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', Bottleneck, blocks_per_stage,
                   pretrained, progress, **kwargs)
437 | 438 | Args: 439 | pretrained (bool): If True, returns a model pre-trained on ImageNet 440 | progress (bool): If True, displays a progress bar of the download to stderr 441 | """ 442 | kwargs['width_per_group'] = 64 * 2 443 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 444 | pretrained, progress, **kwargs) 445 | -------------------------------------------------------------------------------- /monte_carlo_dropout/unet_learner.py: -------------------------------------------------------------------------------- 1 | from typing import List, Callable, Optional, Tuple, Union, Any 2 | 3 | import numpy as np 4 | 5 | from fastai.torch_core import * 6 | from fastai.basic_train import * 7 | from fastai.basic_data import * 8 | from fastai.layers import * 9 | from fastai.callbacks.hooks import * 10 | 11 | from monte_carlo_dropout.mc_dropout import MCDropout 12 | from monte_carlo_dropout.model_meta import model_meta, _default_meta 13 | 14 | 15 | def _get_sfs_idxs(sizes: Sizes) -> List[int]: 16 | "Get the indexes of the layers where the size of the activation changes." 17 | feature_szs = [size[-1] for size in sizes] 18 | sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0]) 19 | if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs 20 | return sfs_idxs 21 | 22 | 23 | class UnetBlock(Module): 24 | "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`." 
def forward(self, up_in: Tensor) -> Tensor:
    """Upsample ``up_in``, merge it with the hooked activation and convolve.

    The upsampled tensor is resized with nearest-neighbour interpolation
    when its last two dimensions differ from those of the stored
    activation. Both are then concatenated along dim 1, activated, passed
    through the two conv layers and finally through dropout.
    """
    skip = self.hook.stored
    upsampled = self.shuf(up_in)
    target_hw = skip.shape[-2:]
    if upsampled.shape[-2:] != target_hw:
        upsampled = F.interpolate(upsampled, target_hw, mode='nearest')
    merged = torch.cat([upsampled, self.bn(skip)], dim=1)
    activated = self.relu(merged)
    return self.dropout(self.conv2(self.conv1(activated)))
52 | 53 | def __init__(self, encoder: nn.Module, n_classes: int, block_drate: float, final_drate: float, 54 | img_size: Tuple[int,int] = (256, 256), blur: bool = False, blur_final = True, 55 | self_attention: bool = False, y_range: Optional[Tuple[float, float]] = None, last_cross: bool = True, 56 | bottle: bool = False, **kwargs): 57 | 58 | imsize = img_size 59 | sfs_szs = model_sizes(encoder, size=imsize) 60 | sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs))) 61 | self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False) 62 | x = dummy_eval(encoder, imsize).detach() 63 | 64 | ni = sfs_szs[-1][1] 65 | middle_conv = nn.Sequential(conv_layer(ni, ni*2, **kwargs), 66 | conv_layer(ni*2, ni, **kwargs)).eval() 67 | x = middle_conv(x) 68 | layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv] 69 | 70 | for i, idx in enumerate(sfs_idxs): 71 | not_final = i != len(sfs_idxs)-1 72 | up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1]) 73 | do_blur = blur and (not_final or blur_final) 74 | sa = self_attention and (i==len(sfs_idxs)-3) 75 | unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa, 76 | dropout_rate=block_drate, 77 | **kwargs).eval() 78 | layers.append(unet_block) 79 | x = unet_block(x) 80 | 81 | ni = x.shape[1] 82 | if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs)) 83 | x = PixelShuffle_ICNR(ni)(x) 84 | if imsize != x.shape[-2:]: layers.append(Lambda(lambda x: F.interpolate(x, imsize, mode='nearest'))) 85 | if last_cross: 86 | layers.append(MergeLayer(dense=True)) 87 | ni += in_channels(encoder) 88 | layers.append(res_block(ni, bottle=bottle, **kwargs)) 89 | layers += [MCDropout(final_drate)] 90 | layers += [conv_layer(ni, n_classes, ks=1, use_activ=False, **kwargs)] 91 | if y_range is not None: 92 | layers.append(SigmoidRange(*y_range)) 93 | super().__init__(*layers) 94 | 95 | def __del__(self): 96 | if hasattr(self, "sfs"): self.sfs.remove() 97 | 98 | @classmethod 
def get_mc_dropout_preds(learner, n_iter=10, dtype=torch.float32, **pred_args):
    """Average the predictions of ``n_iter`` calls to ``learner.get_preds``.

    NOTE(review): the passes only differ from one another if dropout is
    active while predicting; presumably the caller enables it beforehand
    (e.g. ``DynamicUnet.force_dropout(True)``) -- confirm at the call site.

    Args:
        learner: object exposing ``get_preds(**kwargs)``; the first element
            of its return value is used as the prediction tensor.
        n_iter: number of prediction passes to average, at least 1.
        dtype: dtype the predictions are converted to before accumulation.
        **pred_args: forwarded unchanged to every ``learner.get_preds`` call.

    Returns:
        The element-wise mean of the ``n_iter`` prediction tensors.

    Raises:
        ValueError: if ``n_iter`` is smaller than 1.
    """
    if n_iter < 1:
        raise ValueError('n_iter must be at least 1, got {}'.format(n_iter))

    total = learner.get_preds(**pred_args)[0].to(dtype)
    for _ in range(n_iter - 1):
        # Out-of-place add: ``.to(dtype)`` may return the very tensor that
        # get_preds produced, which should not be mutated here.
        total = total + learner.get_preds(**pred_args)[0].to(dtype)
    return total / n_iter
"""Packaging script for the ``monte_carlo_dropout`` library."""
from pathlib import Path

from setuptools import setup, find_packages

# Used when the project ships no requirements.txt: the third-party packages
# imported by the library modules. The fastai bound reflects the imports of
# ``fastai.basic_train`` / ``fastai.callbacks``.
# NOTE(review): confirm the exact supported versions.
FALLBACK_REQUIREMENTS = ['numpy', 'torch', 'torchvision', 'fastai>=1.0,<2.0']


def read_requirements(path):
    """Return the requirement specifiers listed in the file at ``path``.

    Blank lines and ``#`` comment lines are dropped and surrounding
    whitespace is stripped. When the file does not exist, a copy of
    ``FALLBACK_REQUIREMENTS`` is returned so that installation does not
    fail with ``FileNotFoundError``.
    """
    try:
        text = path.read_text()
    except FileNotFoundError:
        return list(FALLBACK_REQUIREMENTS)
    stripped_lines = (line.strip() for line in text.splitlines())
    return [line for line in stripped_lines if line and not line.startswith('#')]


# Resolve next to this script so the result does not depend on the current
# working directory.
requirements = read_requirements(Path(__file__).resolve().parent / 'requirements.txt')


setup(
    name='monte_carlo_dropout',
    version='0.1.0',
    url='https://github.com/aredier/monte_carlo_dropout.git',
    author='Author Name',
    author_email='author@gmail.com',
    description='using dropout to infer confidence over a NN output',
    packages=find_packages(),
    install_requires=requirements,
)
--------------------------------------------------------------------------------