├── LICENSE.txt ├── README.md ├── data ├── __init__.py └── dataset.py ├── models ├── __init__.py ├── cflow.png ├── cflownet.py ├── flows.py ├── layers.py ├── unet.py └── unet_blocks.py ├── nflib ├── __init__.py ├── flows.py ├── made.py ├── nets.py └── spline_flows.py ├── train_model.py └── utils ├── __init__.py ├── tools.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Raghavendra Selvan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # README # 2 | 3 | This is the official PyTorch implementation of 4 | "[Uncertainty quantification in medical 5 | image segmentation with Normalizing Flows](https://arxiv.org/abs/2006.02683)", Raghavendra Selvan et al. 2020 6 | 7 | ![cflow](models/cflow.png) 8 | ### What is this repository for? ### 9 | 10 | * Train the proposed model on the LIDC and Retina datasets 11 | * Reproduce the numbers reported in the paper 12 | * v1.0 13 | 14 | ### How do I get set up? ### 15 | 16 | * Requires PyTorch 17 | * Tested on PyTorch 1.3, Python 3.6 18 | * Download the preprocessed LIDC dataset [from here](https://sid.erda.dk/share_redirect/BgFgw4NMf4). **After downloading, rename the file so that it has a .zip extension.** 19 | 20 | ### Usage guidelines ### 21 | 22 | * Please cite our publication if you use any part of this code: 23 | ``` 24 | @inproceedings{raghav2020cFlowNet, 25 | title={Uncertainty quantification in medical image segmentation with Normalizing Flows}, 26 | author={Raghavendra Selvan and Frederik Faye and Jon Middleton and Akshay Pai}, 27 | booktitle={11th International Workshop on Machine Learning in Medical Imaging}, 28 | month={October}, 29 | year={2020}, 30 | url={https://arxiv.org/abs/2006.02683}} 31 | 32 | ``` 33 | 34 | ### Who do I talk to?
### 35 | 36 | * raghav@di.ku.dk 37 | 38 | ### Thanks 39 | Some parts of our implementation are based on: 40 | * [Prob.U-Net and LIDC data](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch) 41 | * [Planar flows](https://github.com/riannevdberg/sylvester-flows) 42 | * [Glow model](https://github.com/karpathy/pytorch-normalizing-flows) 43 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raghavian/cFlow/79caf3cc9ccc1b6a21a3bc157f432b40947ff217/data/__init__.py -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | from torch.utils.data import TensorDataset, DataLoader, Dataset 8 | import pdb 9 | 10 | class LIDC(Dataset): 11 | def __init__(self, rater=4, data_dir = '/datadrive/raghav/lidc/lidcSeg/raw/', transform=None): 12 | super().__init__() 13 | 14 | self.data_dir = data_dir 15 | self.rater = rater 16 | self.transform = transform 17 | self.data, self.targets = torch.load(data_dir+'lidcSeg.pt') 18 | 19 | def __len__(self): 20 | return len(self.targets) 21 | 22 | def __getitem__(self, index): 23 | 24 | image, label = self.data[index], self.targets[index] 25 | if self.transform is not None: 26 | image = self.transform(image) 27 | return image, label.type(torch.FloatTensor) 28 | 29 | 30 | class Drive(Dataset): 31 | def __init__(self, split='train', data_dir = '/datadrive/raghav/retinaDataset/', 32 | transform=None, target_transform=None): 33 | super().__init__() 34 | 35 | self.data_dir = data_dir 36 | self.transform = transform 37 | self.target_transform = target_transform 38 | self.data, self.targets = 
torch.load(data_dir+'retina_'+split+'_new.pt') 39 | 40 | def __len__(self): 41 | return len(self.targets) 42 | 43 | def __getitem__(self, index): 44 | 45 | image, label_mask = self.data[index], self.targets[index] 46 | if self.transform is not None: 47 | image = self.transform(image) 48 | label_mask = self.target_transform(label_mask) 49 | return image, label_mask[:2], label_mask[[2]] 50 | 51 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raghavian/cFlow/79caf3cc9ccc1b6a21a3bc157f432b40947ff217/models/__init__.py -------------------------------------------------------------------------------- /models/cflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raghavian/cFlow/79caf3cc9ccc1b6a21a3bc157f432b40947ff217/models/cflow.png -------------------------------------------------------------------------------- /models/cflownet.py: -------------------------------------------------------------------------------- 1 | #This code is based on: https://github.com/SimonKohl/probabilistic_unet 2 | 3 | from models.unet_blocks import * 4 | from models.unet import Unet 5 | from utils.utils import init_weights,init_weights_orthogonal_normal, l2_regularisation 6 | import torch.nn.functional as F 7 | from torch.distributions import Normal, Independent, kl 8 | import pdb 9 | from models import flows 10 | from nflib.flows import * 11 | from models.layers import * 12 | from utils.tools import dice_loss 13 | 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | class Encoder(nn.Module): 17 | """ 18 | A convolutional neural network, consisting of len(num_filters) times a block of no_convs_per_block convolutional layers, 19 | after each block a pooling operation is performed. 
And after each convolutional layer a non-linear (ReLU) activation function is applied. 20 | """ 21 | def __init__(self, input_channels, num_filters, no_convs_per_block, initializers, padding=True, posterior=False,norm=False): 22 | super(Encoder, self).__init__() 23 | self.contracting_path = nn.ModuleList() 24 | self.input_channels = input_channels 25 | self.num_filters = num_filters 26 | 27 | if posterior: 28 | #To accomodate for the mask that is concatenated at the channel axis, we increase the input_channels. 29 | self.input_channels += 1 30 | 31 | layers = [] 32 | for i in range(len(self.num_filters)): 33 | """ 34 | Determine input_dim and output_dim of conv layers in this block. The first layer is input x output, 35 | All the subsequent layers are output x output. 36 | """ 37 | input_dim = self.input_channels if i == 0 else output_dim 38 | output_dim = num_filters[i] 39 | 40 | if i != 0: 41 | layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)) 42 | 43 | layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=int(padding))) 44 | layers.append(nn.ReLU(inplace=True)) 45 | 46 | for _ in range(no_convs_per_block-1): 47 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=int(padding))) 48 | layers.append(nn.ReLU(inplace=True)) 49 | if i < len(self.num_filters)-1 and norm == True: 50 | layers.append(nn.BatchNorm2d(output_dim)) 51 | 52 | self.layers = nn.Sequential(*layers) 53 | 54 | self.layers.apply(init_weights) 55 | 56 | def forward(self, input): 57 | output = self.layers(input) 58 | return output 59 | 60 | class AxisAlignedConvGaussian(nn.Module): 61 | """ 62 | A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix. 
63 | """ 64 | def __init__(self, input_channels, num_filters, no_convs_per_block, latent_dim, 65 | initializers, posterior=False,norm=False): 66 | super(AxisAlignedConvGaussian, self).__init__() 67 | self.input_channels = input_channels 68 | self.channel_axis = 1 69 | self.num_filters = num_filters 70 | self.no_convs_per_block = no_convs_per_block 71 | self.latent_dim = latent_dim 72 | self.posterior = posterior 73 | if self.posterior: 74 | self.name = 'Posterior' 75 | else: 76 | self.name = 'Prior' 77 | self.encoder = Encoder(self.input_channels, self.num_filters, self.no_convs_per_block, initializers, 78 | posterior=self.posterior,norm=norm) 79 | self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1,1), stride=1) 80 | self.show_img = 0 81 | self.show_seg = 0 82 | self.show_concat = 0 83 | self.show_enc = 0 84 | self.sum_input = 0 85 | 86 | nn.init.kaiming_normal_(self.conv_layer.weight, mode='fan_in', nonlinearity='relu') 87 | nn.init.normal_(self.conv_layer.bias) 88 | 89 | def forward(self, input, segm=None): 90 | 91 | #If segmentation is not none, concatenate the mask to the channel axis of the input 92 | if segm is not None: 93 | self.show_img = input 94 | self.show_seg = segm 95 | input = torch.cat((input, segm), dim=1) 96 | self.show_concat = input 97 | self.sum_input = torch.sum(input) 98 | 99 | encoding = self.encoder(input) 100 | self.show_enc = encoding 101 | #We only want the mean of the resulting hxw image 102 | encoding = encoding.mean([2,3],True) 103 | 104 | #Convert encoding to 2 x latent dim and split up for mu and log_sigma 105 | #We squeeze the second dimension twice, since otherwise it won't work when batch size is equal to 1 106 | mu_log_sigma = (self.conv_layer(encoding)).squeeze(-1).squeeze(-1) 107 | 108 | mu, log_sigma = torch.chunk(mu_log_sigma,2,dim=1) 109 | #This is a multivariate normal with diagonal covariance matrix sigma 110 | #https://github.com/pytorch/pytorch/pull/11178 111 | dist = Independent(Normal(loc=mu, 
scale=torch.exp(log_sigma)),1) 112 | return encoding.squeeze(-1).squeeze(-1), dist 113 | 114 | class Fcomb(nn.Module): 115 | """ 116 | A function composed of no_convs_fcomb times a 1x1 convolution that combines the sample taken from the latent space, 117 | and output of the UNet (the feature map) by concatenating them along their channel axis. 118 | """ 119 | def __init__(self, num_filters, latent_dim, num_output_channels, num_classes, 120 | no_convs_fcomb, initializers, use_tile=True,norm=False): 121 | super(Fcomb, self).__init__() 122 | self.num_channels = num_output_channels #output channels 123 | self.num_classes = num_classes 124 | self.channel_axis = 1 125 | self.spatial_axes = [2,3] 126 | self.num_filters = num_filters 127 | self.latent_dim = latent_dim 128 | self.use_tile = use_tile 129 | self.no_convs_fcomb = no_convs_fcomb 130 | self.name = 'Fcomb' 131 | if not use_tile: 132 | self.latent_broadcast = nn.Sequential( 133 | GatedConvTranspose2d(self.latent_dim, 64, 32, 1, 0), 134 | GatedConvTranspose2d(64, 64, 5, 1, 2), 135 | GatedConvTranspose2d(64, 32, 5, 2, 2, 1), 136 | GatedConvTranspose2d(32, 32, 5, 1, 2), 137 | GatedConvTranspose2d(32, 32, 5, 2, 2, 1), 138 | GatedConvTranspose2d(32, self.latent_dim, 5, 1, 2) 139 | ) 140 | 141 | layers = [] 142 | 143 | #Decoder of N x a 1x1 convolution followed by a ReLU activation function except for the last layer 144 | layers.append(nn.Conv2d(self.num_filters[0]+self.latent_dim, self.num_filters[0], kernel_size=1)) 145 | layers.append(nn.ReLU(inplace=True)) 146 | 147 | for _ in range(no_convs_fcomb-2): 148 | if norm: 149 | layers.append(nn.BatchNorm2d(self.num_filters[0])) 150 | layers.append(nn.Conv2d(self.num_filters[0], self.num_filters[0], kernel_size=1)) 151 | layers.append(nn.ReLU(inplace=True)) 152 | 153 | self.layers = nn.Sequential(*layers) 154 | 155 | 156 | self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1) 157 | 158 | if initializers['w'] == 'orthogonal': 159 | 
self.layers.apply(init_weights_orthogonal_normal) 160 | self.last_layer.apply(init_weights_orthogonal_normal) 161 | else: 162 | self.layers.apply(init_weights) 163 | self.last_layer.apply(init_weights) 164 | 165 | def tile(self, a, dim, n_tile): 166 | """ 167 | This function is taken form PyTorch forum and mimics the behavior of tf.tile. 168 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 169 | """ 170 | init_dim = a.size(dim) 171 | repeat_idx = [1] * a.dim() 172 | repeat_idx[dim] = n_tile 173 | a = a.repeat(*(repeat_idx)) 174 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device) 175 | return torch.index_select(a, dim, order_index) 176 | 177 | def forward(self, feature_map, z): 178 | """ 179 | Z is batch_sizexlatent_dim and feature_map is batch_sizexno_channelsxHxW. 180 | So broadcast Z to batch_sizexlatent_dimxHxW. Behavior is exactly the same as tf.tile (verified) 181 | """ 182 | if self.use_tile: 183 | z = torch.unsqueeze(z,2) 184 | z = self.tile(z, 2, feature_map.shape[self.spatial_axes[0]]) 185 | z = torch.unsqueeze(z,3) 186 | z = self.tile(z, 3, feature_map.shape[self.spatial_axes[1]]) 187 | else: 188 | z = z.unsqueeze(2).unsqueeze(2) 189 | z = self.latent_broadcast(z) 190 | #Concatenate the feature map (output of the UNet) and the sample taken from the latent space 191 | feature_map = torch.cat((feature_map, z), dim=self.channel_axis) 192 | output = self.layers(feature_map) 193 | return self.last_layer(output) 194 | 195 | class glowDensity(nn.Module): 196 | """ 197 | A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix as 198 | the base distribution for a sequence of flow based transformations. 
199 | """ 200 | def __init__(self, num_flows, input_channels, num_filters, no_convs_per_block, 201 | latent_dim, initializers, posterior=False,norm=False): 202 | super(glowDensity, self).__init__() 203 | 204 | self.base_density = AxisAlignedConvGaussian(input_channels, num_filters, 205 | no_convs_per_block, latent_dim, initializers, posterior=True,norm=norm).to(device) 206 | 207 | # Initialize log-det-jacobian to zero 208 | self.log_det_j = 0. 209 | self.latent_dim = latent_dim 210 | # Flow parameters 211 | self.num_flows = num_flows 212 | nF_oP = num_flows * latent_dim 213 | 214 | 215 | # Normalizing flow layers 216 | self.norms = [CondActNorm(dim=latent_dim) for _ in range(num_flows)] 217 | self.InvConvs = [CondInvertible1x1Conv(dim=latent_dim) for i in range(num_flows)] 218 | self.couplings = [CondAffineHalfFlow(dim=latent_dim,latent_dim=num_filters[-1], 219 | parity=i%2, nh=4) for i in range(num_flows)] 220 | 221 | # Amortized flow parameters 222 | self.amor_W = nn.Sequential(nn.Linear(num_filters[-1], 4),nn.ReLU(), 223 | nn.Linear(4, num_flows * latent_dim**2),) 224 | self.amor_s = nn.Linear(num_filters[-1], nF_oP) 225 | self.amor_t = nn.Linear(num_filters[-1], num_flows) 226 | 227 | def forward(self, input, segm=None): 228 | 229 | """ 230 | Forward pass with planar flows for the transformation z_0 -> z_1 -> ... -> z_k. 231 | Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. 
232 | """ 233 | batch_size = input.shape[0] 234 | self.ldj = torch.zeros(batch_size).to(device) 235 | h, z0_density = self.base_density(input,segm) 236 | z = [z0_density.rsample()] 237 | W = (self.amor_W(h)).view(batch_size, self.num_flows, self.latent_dim,self.latent_dim) 238 | s = (self.amor_s(h)).view(batch_size, self.num_flows, self.latent_dim) 239 | t = self.amor_t(h).view(batch_size, self.num_flows, 1) 240 | 241 | 242 | # Normalizing flows 243 | for k in range(self.num_flows): 244 | z_k, ldj = self.norms[k](z[k], s[:,k,:], t[:,k,:]) 245 | self.ldj += ldj 246 | z_k, ldj = self.InvConvs[k](z_k, W[:,k,:,:]) 247 | self.ldj += ldj 248 | z_k, ldj = self.couplings[k](z_k,h) 249 | self.ldj += ldj 250 | z.append(z_k) 251 | 252 | return self.ldj, z[0], z[-1], z0_density 253 | 254 | 255 | class planarFlowDensity(nn.Module): 256 | """ 257 | A convolutional net that parametrizes a Gaussian distribution with axis aligned covariance matrix as 258 | the base distribution for a sequence of flow based transformations. 259 | """ 260 | def __init__(self, num_flows, input_channels, num_filters, no_convs_per_block, 261 | latent_dim, initializers, posterior=False,norm=False): 262 | super(planarFlowDensity, self).__init__() 263 | 264 | self.base_density = AxisAlignedConvGaussian(input_channels, num_filters, 265 | no_convs_per_block, latent_dim, initializers, posterior=True,norm=norm).to(device) 266 | 267 | # Initialize log-det-jacobian to zero 268 | self.log_det_j = 0. 
269 | self.latent_dim = latent_dim 270 | # Flow parameters 271 | flow = flows.Planar 272 | self.num_flows = num_flows 273 | nF_oP = num_flows * latent_dim 274 | # Amortized flow parameters 275 | self.amor_u = nn.Sequential(nn.Linear(num_filters[-1], nF_oP),nn.ReLU(), 276 | nn.Linear(nF_oP, nF_oP),nn.BatchNorm1d(nF_oP)) 277 | self.amor_w = nn.Sequential(nn.Linear(num_filters[-1], nF_oP),nn.ReLU(), 278 | nn.Linear(nF_oP, nF_oP),nn.BatchNorm1d(nF_oP)) 279 | self.amor_b = nn.Sequential(nn.Linear(num_filters[-1], num_flows), nn.ReLU(), 280 | nn.Linear(num_flows, num_flows),nn.BatchNorm1d(num_flows)) 281 | 282 | # Normalizing flow layers 283 | for k in range(num_flows): 284 | flow_k = flow().to(device) 285 | self.add_module('flow_' + str(k), flow_k) 286 | 287 | 288 | def forward(self, input, segm=None): 289 | 290 | """ 291 | Forward pass with planar flows for the transformation z_0 -> z_1 -> ... -> z_k. 292 | Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ]. 293 | """ 294 | batch_size = input.shape[0] 295 | self.log_det_j = 0. 
296 | h, z0_density = self.base_density(input,segm) 297 | z = [z0_density.rsample()] 298 | 299 | # return amortized u an w for all flows 300 | u = self.amor_u(h).view(batch_size, self.num_flows, self.latent_dim, 1) 301 | w = self.amor_w(h).view(batch_size, self.num_flows, 1, self.latent_dim) 302 | b = self.amor_b(h).view(batch_size, self.num_flows, 1, 1) 303 | 304 | # Normalizing flows 305 | for k in range(self.num_flows): 306 | flow_k = getattr(self, 'flow_' + str(k)) 307 | z_k, log_det_jacobian = flow_k(z[k], u[:, k, :, :], w[:, k, :, :], b[:, k, :, :]) 308 | z.append(z_k) 309 | self.log_det_j += log_det_jacobian 310 | 311 | return self.log_det_j, z[0], z[-1], z0_density 312 | 313 | class cFlowNet(nn.Module): 314 | """ 315 | input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) 316 | num_classes: the number of classes to predict 317 | num_filters: is a list consisint of the amount of filters layer 318 | latent_dim: dimension of the latent space 319 | no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior 320 | """ 321 | 322 | def __init__(self, input_channels=1, num_classes=1, num_filters=[32,64,128,256], 323 | latent_dim=6, no_convs_fcomb=4, beta=1.0, num_flows=4,norm=False,flow=False,glow=False): 324 | 325 | super(cFlowNet, self).__init__() 326 | self.input_channels = input_channels 327 | self.num_classes = num_classes 328 | self.num_filters = num_filters 329 | self.latent_dim = latent_dim 330 | self.no_convs_per_block = 3 331 | self.no_convs_fcomb = no_convs_fcomb 332 | self.initializers = {'w':'he_normal', 'b':'normal'} 333 | self.beta = beta 334 | self.z_prior_sample = 0 335 | self.flow = flow 336 | self.flow_steps = num_flows 337 | 338 | self.unet = Unet(self.input_channels, self.num_classes, self.num_filters, 339 | self.initializers, apply_last_layer=False, padding=True, norm=norm).to(device) 340 | self.prior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, 341 | 
self.no_convs_per_block, self.latent_dim, self.initializers,norm=norm).to(device) 342 | 343 | if flow: 344 | if glow: 345 | self.posterior = glowDensity(self.flow_steps, self.input_channels, self.num_filters, self.no_convs_per_block, 346 | self.latent_dim, self.initializers,posterior=True,norm=norm).to(device) 347 | else: 348 | self.posterior = planarFlowDensity(self.flow_steps, self.input_channels, self.num_filters, self.no_convs_per_block, 349 | self.latent_dim, self.initializers,posterior=True,norm=norm).to(device) 350 | else: 351 | self.posterior = AxisAlignedConvGaussian(self.input_channels, self.num_filters, 352 | self.no_convs_per_block, self.latent_dim, self.initializers, posterior=True,norm=norm).to(device) 353 | 354 | self.fcomb = Fcomb(self.num_filters, self.latent_dim, self.input_channels, 355 | self.num_classes, self.no_convs_fcomb, {'w':'orthogonal', 'b':'normal'}, use_tile=True,norm=norm).to(device) 356 | 357 | def forward(self, patch, segm, training=True): 358 | """ 359 | Construct prior latent space for patch and run patch through UNet, 360 | in case training is True also construct posterior latent space 361 | """ 362 | # pdb.set_trace() 363 | if training: 364 | if self.flow: 365 | self.log_det_j, self.z0, self.z, self.posterior_latent_space = self.posterior.forward(patch, segm) 366 | else: 367 | _, self.posterior_latent_space = self.posterior.forward(patch,segm) 368 | self.z = self.posterior_latent_space.rsample() 369 | self.z0 = self.z.clone() 370 | _, self.prior_latent_space = self.prior.forward(patch) 371 | self.unet_features = self.unet.forward(patch,False) 372 | 373 | def sample(self, testing=False): 374 | """ 375 | Sample a segmentation by reconstructing from a prior sample 376 | and combining this with UNet features 377 | """ 378 | if testing == False: 379 | z_prior = self.prior_latent_space.rsample() 380 | self.z_prior_sample = z_prior 381 | else: 382 | #You can choose whether you mean a sample or the mean here. 
For the GED it is important to take a sample. 383 | z_prior = self.prior_latent_space.sample() 384 | self.z_prior_sample = z_prior 385 | log_pz = self.prior_latent_space.log_prob(z_prior) 386 | log_qz = self.posterior_latent_space.log_prob(z_prior) 387 | return self.fcomb.forward(self.unet_features,z_prior), log_pz, log_qz 388 | 389 | 390 | def reconstruct(self, use_posterior_mean=False, calculate_posterior=False, z_posterior=None): 391 | """ 392 | Reconstruct a segmentation from a posterior sample (decoding a posterior sample) and UNet feature map 393 | use_posterior_mean: use posterior_mean instead of sampling z_q 394 | calculate_posterior: use a provided sample or sample from posterior latent space 395 | """ 396 | if use_posterior_mean: 397 | z_posterior = self.posterior_latent_space.loc 398 | else: 399 | if calculate_posterior: 400 | z_posterior = self.posterior_latent_space.rsample() 401 | return self.fcomb.forward(self.unet_features, z_posterior) 402 | 403 | def kl_divergence(self, analytic=True, calculate_posterior=False, z_posterior=None): 404 | """ 405 | Calculate the KL divergence between the posterior and prior KL(Q||P) 406 | analytic: calculate KL analytically or via sampling from the posterior 407 | calculate_posterior: if we use samapling to approximate KL we can sample here or supply a sample 408 | """ 409 | if analytic: 410 | #Neeed to add this to torch source code, see: https://github.com/pytorch/pytorch/issues/13545 411 | kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space).sum() 412 | 413 | else: 414 | log_posterior_prob = self.posterior_latent_space.log_prob(self.z) 415 | log_prior_prob = self.prior_latent_space.log_prob(self.z) 416 | kl_div = (log_posterior_prob - log_prior_prob).sum() 417 | if self.flow: 418 | kl_div = kl_div - self.log_det_j.sum() 419 | return kl_div 420 | 421 | def elbo(self, segm, mask=None,use_mask = True, analytic_kl=True, reconstruct_posterior_mean=False): 422 | """ 423 | Calculate the evidence 
lower bound of the log-likelihood of P(Y|X) 424 | """ 425 | batch_size = segm.shape[0] 426 | self.kl = (self.kl_divergence(analytic=analytic_kl, calculate_posterior=False)) 427 | 428 | #Here we use the posterior sample sampled above 429 | self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean, 430 | calculate_posterior=False, z_posterior=self.z) 431 | if use_mask: 432 | 433 | self.reconstruction = self.reconstruction*mask 434 | criterion = nn.BCEWithLogitsLoss(reduction='none') 435 | reconstruction_loss = criterion(input=self.reconstruction, target=segm) 436 | self.reconstruction_loss = torch.sum(reconstruction_loss) 437 | self.mean_reconstruction_loss = torch.mean(reconstruction_loss) 438 | 439 | return self.reconstruction, self.reconstruction_loss/batch_size, self.kl/batch_size,\ 440 | -(self.reconstruction_loss + self.beta * self.kl)/batch_size 441 | -------------------------------------------------------------------------------- /models/flows.py: -------------------------------------------------------------------------------- 1 | """ 2 | Collection of flow strategies 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | 12 | from models.layers import MaskedConv2d, MaskedLinear 13 | 14 | 15 | class Planar(nn.Module): 16 | """ 17 | PyTorch implementation of planar flows as presented in "Variational Inference with Normalizing Flows" 18 | by Danilo Jimenez Rezende, Shakir Mohamed. Model assumes amortized flow parameters. 19 | """ 20 | 21 | def __init__(self): 22 | 23 | super(Planar, self).__init__() 24 | 25 | self.h = nn.Tanh() 26 | self.softplus = nn.Softplus() 27 | 28 | def der_h(self, x): 29 | """ Derivative of tanh """ 30 | 31 | return 1 - self.h(x) ** 2 32 | 33 | def forward(self, zk, u, w, b): 34 | """ 35 | Forward pass. Assumes amortized u, w and b. 
Conditions on diagonals of u and w for invertibility 36 | will be be satisfied inside this function. Computes the following transformation: 37 | z' = z + u h( w^T z + b) 38 | or actually 39 | z'^T = z^T + h(z^T w + b)u^T 40 | Assumes the following input shapes: 41 | shape u = (batch_size, z_size, 1) 42 | shape w = (batch_size, 1, z_size) 43 | shape b = (batch_size, 1, 1) 44 | shape z = (batch_size, z_size). 45 | """ 46 | 47 | zk = zk.unsqueeze(2) 48 | 49 | # reparameterize u such that the flow becomes invertible (see appendix paper) 50 | uw = torch.bmm(w, u) 51 | m_uw = -1. + self.softplus(uw) 52 | w_norm_sq = torch.sum(w ** 2, dim=2, keepdim=True) 53 | u_hat = u + ((m_uw - uw) * w.transpose(2, 1) / w_norm_sq) 54 | 55 | # compute flow with u_hat 56 | wzb = torch.bmm(w, zk) + b 57 | z = zk + u_hat * self.h(wzb) 58 | z = z.squeeze(2) 59 | 60 | # compute logdetJ 61 | psi = w * self.der_h(wzb) 62 | log_det_jacobian = torch.log(torch.abs(1 + torch.bmm(psi, u_hat))) 63 | log_det_jacobian = log_det_jacobian.squeeze(2).squeeze(1) 64 | 65 | return z, log_det_jacobian 66 | 67 | 68 | class Sylvester(nn.Module): 69 | """ 70 | Sylvester normalizing flow. 71 | """ 72 | 73 | def __init__(self, num_ortho_vecs): 74 | 75 | super(Sylvester, self).__init__() 76 | 77 | self.num_ortho_vecs = num_ortho_vecs 78 | 79 | self.h = nn.Tanh() 80 | 81 | triu_mask = torch.triu(torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1).unsqueeze(0) 82 | diag_idx = torch.arange(0, num_ortho_vecs).long() 83 | 84 | self.register_buffer('triu_mask', Variable(triu_mask)) 85 | self.triu_mask.requires_grad = False 86 | self.register_buffer('diag_idx', diag_idx) 87 | 88 | def der_h(self, x): 89 | return self.der_tanh(x) 90 | 91 | def der_tanh(self, x): 92 | return 1 - self.h(x) ** 2 93 | 94 | def _forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): 95 | """ 96 | All flow parameters are amortized. Conditions on diagonals of R1 and R2 for invertibility need to be satisfied 97 | outside of this function. 
Computes the following transformation: 98 | z' = z + QR1 h( R2Q^T z + b) 99 | or actually 100 | z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T 101 | :param zk: shape: (batch_size, z_size) 102 | :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs) 103 | :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs) 104 | :param q_ortho: shape (batch_size, z_size , num_ortho_vecs) 105 | :param b: shape: (batch_size, 1, self.z_size) 106 | :return: z, log_det_j 107 | """ 108 | 109 | # Amortized flow parameters 110 | zk = zk.unsqueeze(1) 111 | 112 | # Save diagonals for log_det_j 113 | diag_r1 = r1[:, self.diag_idx, self.diag_idx] 114 | diag_r2 = r2[:, self.diag_idx, self.diag_idx] 115 | 116 | r1_hat = r1 117 | r2_hat = r2 118 | 119 | qr2 = torch.bmm(q_ortho, r2_hat.transpose(2, 1)) 120 | qr1 = torch.bmm(q_ortho, r1_hat) 121 | 122 | r2qzb = torch.bmm(zk, qr2) + b 123 | z = torch.bmm(self.h(r2qzb), qr1.transpose(2, 1)) + zk 124 | z = z.squeeze(1) 125 | 126 | # Compute log|det J| 127 | # Output log_det_j in shape (batch_size) instead of (batch_size,1) 128 | diag_j = diag_r1 * diag_r2 129 | diag_j = self.der_h(r2qzb).squeeze(1) * diag_j 130 | diag_j += 1. 131 | log_diag_j = diag_j.abs().log() 132 | 133 | if sum_ldj: 134 | log_det_j = log_diag_j.sum(-1) 135 | else: 136 | log_det_j = log_diag_j 137 | 138 | return z, log_det_j 139 | 140 | def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): 141 | 142 | return self._forward(zk, r1, r2, q_ortho, b, sum_ldj) 143 | 144 | 145 | class TriangularSylvester(nn.Module): 146 | """ 147 | Sylvester normalizing flow with Q=P or Q=I. 
148 | """ 149 | 150 | def __init__(self, z_size): 151 | 152 | super(TriangularSylvester, self).__init__() 153 | 154 | self.z_size = z_size 155 | self.h = nn.Tanh() 156 | 157 | diag_idx = torch.arange(0, z_size).long() 158 | self.register_buffer('diag_idx', diag_idx) 159 | 160 | def der_h(self, x): 161 | return self.der_tanh(x) 162 | 163 | def der_tanh(self, x): 164 | return 1 - self.h(x) ** 2 165 | 166 | def _forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True): 167 | """ 168 | All flow parameters are amortized. conditions on diagonals of R1 and R2 need to be satisfied 169 | outside of this function. 170 | Computes the following transformation: 171 | z' = z + QR1 h( R2Q^T z + b) 172 | or actually 173 | z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T 174 | with Q = P a permutation matrix (equal to identity matrix if permute_z=None) 175 | :param zk: shape: (batch_size, z_size) 176 | :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs). 177 | :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs). 178 | :param b: shape: (batch_size, 1, self.z_size) 179 | :return: z, log_det_j 180 | """ 181 | 182 | # Amortized flow parameters 183 | zk = zk.unsqueeze(1) 184 | 185 | # Save diagonals for log_det_j 186 | diag_r1 = r1[:, self.diag_idx, self.diag_idx] 187 | diag_r2 = r2[:, self.diag_idx, self.diag_idx] 188 | 189 | if permute_z is not None: 190 | # permute order of z 191 | z_per = zk[:, :, permute_z] 192 | else: 193 | z_per = zk 194 | 195 | r2qzb = torch.bmm(z_per, r2.transpose(2, 1)) + b 196 | z = torch.bmm(self.h(r2qzb), r1.transpose(2, 1)) 197 | 198 | if permute_z is not None: 199 | # permute order of z again back again 200 | z = z[:, :, permute_z] 201 | 202 | z += zk 203 | z = z.squeeze(1) 204 | 205 | # Compute log|det J| 206 | # Output log_det_j in shape (batch_size) instead of (batch_size,1) 207 | diag_j = diag_r1 * diag_r2 208 | diag_j = self.der_h(r2qzb).squeeze(1) * diag_j 209 | diag_j += 1. 
210 | log_diag_j = diag_j.abs().log() 211 | 212 | if sum_ldj: 213 | log_det_j = log_diag_j.sum(-1) 214 | else: 215 | log_det_j = log_diag_j 216 | 217 | return z, log_det_j 218 | 219 | def forward(self, zk, r1, r2, q_ortho, b, sum_ldj=True): 220 | 221 | return self._forward(zk, r1, r2, q_ortho, b, sum_ldj) 222 | 223 | 224 | class IAF(nn.Module): 225 | """ 226 | PyTorch implementation of inverse autoregressive flows as presented in 227 | "Improving Variational Inference with Inverse Autoregressive Flow" by Diederik P. Kingma, Tim Salimans, 228 | Rafal Jozefowicz, Xi Chen, Ilya Sutskever, Max Welling. 229 | Inverse Autoregressive Flow with either MADE MLPs or Pixel CNNs. Contains several flows. Each transformation 230 | takes as an input the previous stochastic z, and a context h. The structure of each flow is then as follows: 231 | z <- autoregressive_layer(z) + h, allow for diagonal connections 232 | z <- autoregressive_layer(z), allow for diagonal connections 233 | : 234 | z <- autoregressive_layer(z), do not allow for diagonal connections. 235 | 236 | Note that the size of h needs to be the same as h_size, which is the width of the MADE layers. 
237 | """ 238 | 239 | def __init__(self, z_size, num_flows=2, num_hidden=0, h_size=50, forget_bias=1., conv2d=False): 240 | super(IAF, self).__init__() 241 | self.z_size = z_size 242 | self.num_flows = num_flows 243 | self.num_hidden = num_hidden 244 | self.h_size = h_size 245 | self.conv2d = conv2d 246 | if not conv2d: 247 | ar_layer = MaskedLinear 248 | else: 249 | ar_layer = MaskedConv2d 250 | self.activation = torch.nn.ELU 251 | # self.activation = torch.nn.ReLU 252 | 253 | self.forget_bias = forget_bias 254 | self.flows = [] 255 | self.param_list = [] 256 | 257 | # For reordering z after each flow 258 | flip_idx = torch.arange(self.z_size - 1, -1, -1).long() 259 | self.register_buffer('flip_idx', flip_idx) 260 | 261 | for k in range(num_flows): 262 | arch_z = [ar_layer(z_size, h_size), self.activation()] 263 | self.param_list += list(arch_z[0].parameters()) 264 | z_feats = torch.nn.Sequential(*arch_z) 265 | arch_zh = [] 266 | for j in range(num_hidden): 267 | arch_zh += [ar_layer(h_size, h_size), self.activation()] 268 | self.param_list += list(arch_zh[-2].parameters()) 269 | zh_feats = torch.nn.Sequential(*arch_zh) 270 | linear_mean = ar_layer(h_size, z_size, diagonal_zeros=True) 271 | linear_std = ar_layer(h_size, z_size, diagonal_zeros=True) 272 | self.param_list += list(linear_mean.parameters()) 273 | self.param_list += list(linear_std.parameters()) 274 | 275 | if torch.cuda.is_available(): 276 | z_feats = z_feats.cuda() 277 | zh_feats = zh_feats.cuda() 278 | linear_mean = linear_mean.cuda() 279 | linear_std = linear_std.cuda() 280 | self.flows.append((z_feats, zh_feats, linear_mean, linear_std)) 281 | 282 | self.param_list = torch.nn.ParameterList(self.param_list) 283 | 284 | def forward(self, z, h_context): 285 | 286 | logdets = 0. 
287 | for i, flow in enumerate(self.flows): 288 | if (i + 1) % 2 == 0 and not self.conv2d: 289 | # reverse ordering to help mixing 290 | z = z[:, self.flip_idx] 291 | 292 | h = flow[0](z) 293 | h = h + h_context 294 | h = flow[1](h) 295 | mean = flow[2](h) 296 | gate = F.sigmoid(flow[3](h) + self.forget_bias) 297 | z = gate * z + (1 - gate) * mean 298 | logdets += torch.sum(gate.log().view(gate.size(0), -1), 1) 299 | return z, logdets 300 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | 8 | class Identity(nn.Module): 9 | def __init__(self): 10 | super(Identity, self).__init__() 11 | 12 | def forward(self, x): 13 | return x 14 | 15 | 16 | class GatedConv2d(nn.Module): 17 | def __init__(self, input_channels, output_channels, kernel_size, stride, padding, dilation=1, activation=None): 18 | super(GatedConv2d, self).__init__() 19 | 20 | self.activation = activation 21 | self.sigmoid = nn.Sigmoid() 22 | 23 | self.h = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation) 24 | self.g = nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding, dilation) 25 | 26 | def forward(self, x): 27 | if self.activation is None: 28 | h = self.h(x) 29 | else: 30 | h = self.activation(self.h(x)) 31 | 32 | g = self.sigmoid(self.g(x)) 33 | 34 | return h * g 35 | 36 | 37 | class GatedConvTranspose2d(nn.Module): 38 | def __init__(self, input_channels, output_channels, kernel_size, stride, padding, output_padding=0, dilation=1, 39 | activation=None): 40 | super(GatedConvTranspose2d, self).__init__() 41 | 42 | self.activation = activation 43 | self.sigmoid = nn.Sigmoid() 44 | 45 | self.h = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, 
padding, output_padding, 46 | dilation=dilation) 47 | self.g = nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride, padding, output_padding, 48 | dilation=dilation) 49 | 50 | def forward(self, x): 51 | if self.activation is None: 52 | h = self.h(x) 53 | else: 54 | h = self.activation(self.h(x)) 55 | 56 | g = self.sigmoid(self.g(x)) 57 | 58 | return h * g 59 | 60 | 61 | class MaskedLinear(nn.Module): 62 | """ 63 | Creates masked linear layer for MLP MADE. 64 | For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False. 65 | For hidden to output (y) layers: 66 | If output depends on input through y_i = f(x_{= n_in: 97 | k = n_out // n_in 98 | for i in range(n_in): 99 | mask[i + 1:, i * k:(i + 1) * k] = 0 100 | if self.diagonal_zeros: 101 | mask[i:i + 1, i * k:(i + 1) * k] = 0 102 | else: 103 | k = n_in // n_out 104 | for i in range(n_out): 105 | mask[(i + 1) * k:, i:i + 1] = 0 106 | if self.diagonal_zeros: 107 | mask[i * k:(i + 1) * k:, i:i + 1] = 0 108 | return mask 109 | 110 | def forward(self, x): 111 | output = x.mm(self.mask * self.weight) 112 | 113 | if self.bias is not None: 114 | return output.add(self.bias.expand_as(output)) 115 | else: 116 | return output 117 | 118 | def __repr__(self): 119 | if self.bias is not None: 120 | bias = True 121 | else: 122 | bias = False 123 | return self.__class__.__name__ + ' (' \ 124 | + str(self.in_features) + ' -> ' \ 125 | + str(self.out_features) + ', diagonal_zeros=' \ 126 | + str(self.diagonal_zeros) + ', bias=' \ 127 | + str(bias) + ')' 128 | 129 | 130 | class MaskedConv2d(nn.Module): 131 | """ 132 | Creates masked convolutional autoregressive layer for pixelCNN. 133 | For input (x) to hidden (h) or hidden to hidden layers choose diagonal_zeros = False. 
134 | For hidden to output (y) layers: 135 | If output depends on input through y_i = f(x_{= n_in: 174 | k = n_out // n_in 175 | for i in range(n_in): 176 | mask[i * k:(i + 1) * k, i + 1:, l, m] = 0 177 | if self.diagonal_zeros: 178 | mask[i * k:(i + 1) * k, i:i + 1, l, m] = 0 179 | else: 180 | k = n_in // n_out 181 | for i in range(n_out): 182 | mask[i:i + 1, (i + 1) * k:, l, m] = 0 183 | if self.diagonal_zeros: 184 | mask[i:i + 1, i * k:(i + 1) * k:, l, m] = 0 185 | 186 | return mask 187 | 188 | def forward(self, x): 189 | output = F.conv2d(x, self.mask * self.weight, bias=self.bias, padding=(1, 1)) 190 | return output 191 | 192 | def __repr__(self): 193 | if self.bias is not None: 194 | bias = True 195 | else: 196 | bias = False 197 | return self.__class__.__name__ + ' (' \ 198 | + str(self.in_features) + ' -> ' \ 199 | + str(self.out_features) + ', diagonal_zeros=' \ 200 | + str(self.diagonal_zeros) + ', bias=' \ 201 | + str(bias) + ', size_kernel=' \ 202 | + str(self.size_kernel) + ')' 203 | 204 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | from models.unet_blocks import * 2 | import torch.nn.functional as F 3 | import pdb 4 | 5 | class Unet(nn.Module): 6 | """ 7 | A UNet (https://arxiv.org/abs/1505.04597) implementation. 
8 | input_channels: the number of channels in the image (1 for greyscale and 3 for RGB) 9 | num_classes: the number of classes to predict 10 | num_filters: list with the amount of filters per layer 11 | apply_last_layer: boolean to apply last layer or not (not used in Probabilistic UNet) 12 | padidng: Boolean, if true we pad the images with 1 so that we keep the same dimensions 13 | """ 14 | 15 | def __init__(self, input_channels, num_classes, num_filters, initializers, apply_last_layer=True, padding=True,norm=False): 16 | super(Unet, self).__init__() 17 | self.input_channels = input_channels 18 | self.num_classes = num_classes 19 | self.num_filters = num_filters 20 | self.padding = padding 21 | self.activation_maps = [] 22 | self.apply_last_layer = apply_last_layer 23 | self.contracting_path = nn.ModuleList() 24 | 25 | for i in range(len(self.num_filters)): 26 | input = self.input_channels if i == 0 else output 27 | output = self.num_filters[i] 28 | 29 | if i == 0: 30 | pool = False 31 | else: 32 | pool = True 33 | # if i == len(self.num_filters)-1: 34 | # self.contracting_path.append(DownConvBlock(input, output, initializers, padding, pool=pool,norm=False)) 35 | self.contracting_path.append(DownConvBlock(input, output, initializers, padding, pool=pool,norm=norm)) 36 | 37 | 38 | self.upsampling_path = nn.ModuleList() 39 | 40 | n = len(self.num_filters) - 2 41 | # pdb.set_trace() 42 | for i in range(n, -1, -1): 43 | input = output + self.num_filters[i] 44 | output = self.num_filters[i] 45 | if i == 0: 46 | norm = False 47 | self.upsampling_path.append(UpConvBlock(input, output, initializers, padding,norm=norm)) 48 | 49 | if self.apply_last_layer: 50 | self.last_layer = nn.Conv2d(output, num_classes, kernel_size=1) 51 | #nn.init.kaiming_normal_(self.last_layer.weight, mode='fan_in',nonlinearity='relu') 52 | #nn.init.normal_(self.last_layer.bias) 53 | 54 | 55 | def forward(self, x, val): 56 | blocks = [] 57 | for i, down in enumerate(self.contracting_path): 58 | x = 
down(x) 59 | if i != len(self.contracting_path)-1: 60 | blocks.append(x) 61 | 62 | for i, up in enumerate(self.upsampling_path): 63 | x = up(x, blocks[-i-1]) 64 | 65 | del blocks 66 | 67 | #Used for saving the activations and plotting 68 | if val: 69 | self.activation_maps.append(x) 70 | 71 | if self.apply_last_layer: 72 | x = self.last_layer(x) 73 | 74 | return x 75 | -------------------------------------------------------------------------------- /models/unet_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from utils.utils import init_weights 6 | import pdb 7 | class DownConvBlock(nn.Module): 8 | """ 9 | A block of three convolutional layers where each layer is followed by a non-linear activation function 10 | Between each block we add a pooling operation. 11 | """ 12 | def __init__(self, input_dim, output_dim, initializers, padding, pool=True,norm=False): 13 | super(DownConvBlock, self).__init__() 14 | layers = [] 15 | 16 | if pool: 17 | layers.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True)) 18 | 19 | layers.append(nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=1, padding=int(padding))) 20 | layers.append(nn.ReLU(inplace=True)) 21 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding))) 22 | layers.append(nn.ReLU(inplace=True)) 23 | layers.append(nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=1, padding=int(padding))) 24 | layers.append(nn.ReLU(inplace=True)) 25 | 26 | if norm: 27 | layers.append(nn.BatchNorm2d(output_dim)) 28 | self.layers = nn.Sequential(*layers) 29 | 30 | self.layers.apply(init_weights) 31 | 32 | def forward(self, patch): 33 | return self.layers(patch) 34 | 35 | class UpConvBlock(nn.Module): 36 | """ 37 | A block consists of an upsampling layer followed by a convolutional layer to reduce the amount of channels 
and then a DownConvBlock 38 | If bilinear is set to false, we do a transposed convolution instead of upsampling 39 | """ 40 | def __init__(self, input_dim, output_dim, initializers, padding, bilinear=True,norm=False): 41 | super(UpConvBlock, self).__init__() 42 | self.bilinear = bilinear 43 | if not self.bilinear: 44 | self.upconv_layer = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=2, stride=2) 45 | self.upconv_layer.apply(init_weights) 46 | # pdb.set_trace() 47 | self.conv_block = DownConvBlock(input_dim, output_dim, initializers, padding, pool=False,norm=norm) 48 | 49 | def forward(self, x, bridge): 50 | if self.bilinear: 51 | up = nn.functional.interpolate(x, mode='bilinear', scale_factor=2, align_corners=True) 52 | else: 53 | up = self.upconv_layer(x) 54 | 55 | assert up.shape[3] == bridge.shape[3] 56 | out = torch.cat([up, bridge], 1) 57 | out = self.conv_block(out) 58 | 59 | return out 60 | -------------------------------------------------------------------------------- /nflib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raghavian/cFlow/79caf3cc9ccc1b6a21a3bc157f432b40947ff217/nflib/__init__.py -------------------------------------------------------------------------------- /nflib/flows.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements various flows. 3 | Each flow is invertible so it can be forward()ed and backward()ed. 4 | Notice that backward() is not backward as in backprop but simply inversion. 5 | Each flow also outputs its log det J "regularization" 6 | 7 | Reference: 8 | 9 | NICE: Non-linear Independent Components Estimation, Dinh et al. 2014 10 | https://arxiv.org/abs/1410.8516 11 | 12 | Variational Inference with Normalizing Flows, Rezende and Mohamed 2015 13 | https://arxiv.org/abs/1505.05770 14 | 15 | Density estimation using Real NVP, Dinh et al. 
May 2016 16 | https://arxiv.org/abs/1605.08803 17 | (Laurent's extension of NICE) 18 | 19 | Improved Variational Inference with Inverse Autoregressive Flow, Kingma et al June 2016 20 | https://arxiv.org/abs/1606.04934 21 | (IAF) 22 | 23 | Masked Autoregressive Flow for Density Estimation, Papamakarios et al. May 2017 24 | https://arxiv.org/abs/1705.07057 25 | "The advantage of Real NVP compared to MAF and IAF is that it can both generate data and estimate densities with one forward pass only, whereas MAF would need D passes to generate data and IAF would need D passes to estimate densities." 26 | (MAF) 27 | 28 | Glow: Generative Flow with Invertible 1x1 Convolutions, Kingma and Dhariwal, Jul 2018 29 | https://arxiv.org/abs/1807.03039 30 | 31 | "Normalizing Flows for Probabilistic Modeling and Inference" 32 | https://arxiv.org/abs/1912.02762 33 | (review paper) 34 | """ 35 | 36 | import numpy as np 37 | import torch 38 | import torch.nn.functional as F 39 | from torch import nn 40 | 41 | from nflib.nets import LeafParam, MLP, ARMLP 42 | 43 | class AffineConstantFlow(nn.Module): 44 | """ 45 | Scales + Shifts the flow by (learned) constants per dimension. 
46 | In NICE paper there is a Scaling layer which is a special case of this where t is None 47 | """ 48 | def __init__(self, dim, scale=True, shift=True): 49 | super().__init__() 50 | self.s = nn.Parameter(torch.randn(1, dim, requires_grad=True)) if scale else None 51 | self.t = nn.Parameter(torch.randn(1, dim, requires_grad=True)) if shift else None 52 | 53 | def forward(self, x): 54 | s = self.s if self.s is not None else x.new_zeros(x.size()) 55 | t = self.t if self.t is not None else x.new_zeros(x.size()) 56 | z = x * torch.exp(s) + t 57 | log_det = torch.sum(s, dim=1) 58 | return z, log_det 59 | 60 | def backward(self, z): 61 | s = self.s if self.s is not None else z.new_zeros(z.size()) 62 | t = self.t if self.t is not None else z.new_zeros(z.size()) 63 | x = (z - t) * torch.exp(-s) 64 | log_det = torch.sum(-s, dim=1) 65 | return x, log_det 66 | 67 | 68 | class ActNorm(AffineConstantFlow): 69 | """ 70 | Really an AffineConstantFlow but with a data-dependent initialization, 71 | where on the very first batch we clever initialize the s,t so that the output 72 | is unit gaussian. As described in Glow paper. 
73 | """ 74 | def __init__(self, *args, **kwargs): 75 | super().__init__(*args, **kwargs) 76 | self.data_dep_init_done = False 77 | 78 | def forward(self, x): 79 | # first batch is used for init 80 | if not self.data_dep_init_done: 81 | assert self.s is not None and self.t is not None # for now 82 | self.s.data = (-torch.log(x.std(dim=0, keepdim=True))).detach() 83 | self.t.data = (-(x * torch.exp(self.s)).mean(dim=0, keepdim=True)).detach() 84 | self.data_dep_init_done = True 85 | return super().forward(x) 86 | 87 | 88 | def backward(self, z): 89 | z0, z1 = z[:,::2], z[:,1::2] 90 | if self.parity: 91 | z0, z1 = z1, z0 92 | s = self.s_cond(z0) 93 | t = self.t_cond(z0) 94 | x0 = z0 # this was the same 95 | x1 = (z1 - t) * torch.exp(-s) # reverse the transform on this half 96 | if self.parity: 97 | x0, x1 = x1, x0 98 | x = torch.cat([x0, x1], dim=1) 99 | log_det = torch.sum(-s, dim=1) 100 | return x, log_det 101 | 102 | 103 | class CondActNorm(nn.Module): 104 | """ 105 | Scales + Shifts the flow by (learned) constants per dimension. 106 | In NICE paper there is a Scaling layer which is a special case of this where t is None 107 | """ 108 | def __init__(self, dim, scale=True, shift=True): 109 | super().__init__() 110 | self.dim = dim 111 | def forward(self, x, s, t): 112 | eps = 1e-6 113 | z = x * s + t 114 | log_det = torch.sum( torch.log(torch.abs(s)+eps), dim=1) 115 | return z, log_det 116 | 117 | def backward(self, z): 118 | s = self.s if self.s is not None else z.new_zeros(z.size()) 119 | t = self.t if self.t is not None else z.new_zeros(z.size()) 120 | x = (z - t) * torch.exp(-s) 121 | log_det = torch.sum(-s, dim=1) 122 | return x, log_det 123 | 124 | 125 | class CondInvertible1x1Conv(nn.Module): 126 | """ 127 | As introduced in Glow paper. 
128 | """ 129 | 130 | def __init__(self, dim): 131 | super().__init__() 132 | self.dim = dim 133 | 134 | def forward(self, x, W): 135 | eps=1e-6 136 | z = torch.bmm(x.unsqueeze(1), W).squeeze(1) 137 | log_det = torch.log(eps+torch.abs(torch.det(W))) 138 | 139 | return z, log_det 140 | 141 | def backward(self, z): 142 | W_inv = torch.inverse(W) 143 | x = z @ W_inv 144 | log_det = -torch.sum(torch.log(torch.abs(torch.det(W_inv)))) 145 | return x, log_det 146 | 147 | class CondAffineHalfFlow(nn.Module): 148 | """ 149 | As seen in RealNVP, affine autoregressive flow (z = x * exp(s) + t), where half of the 150 | dimensions in x are linearly scaled/transfromed as a function of the other half. 151 | Which half is which is determined by the parity bit. 152 | - RealNVP both scales and shifts (default) 153 | - NICE only shifts 154 | """ 155 | def __init__(self, dim, latent_dim, parity, net_class=MLP, nh=24, scale=True, shift=True): 156 | super().__init__() 157 | self.dim = dim 158 | self.latent_dim = latent_dim 159 | self.parity = parity 160 | self.s_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2) 161 | self.t_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2) 162 | if scale: 163 | self.s_cond = net_class(self.dim , self.dim // 2, nh) 164 | if shift: 165 | self.t_cond = net_class(self.dim, self.dim // 2, nh) 166 | self.condition = nn.Linear(latent_dim, self.dim //2) 167 | self.condition = self.condition.cuda() 168 | self.s_cond = self.s_cond.cuda() 169 | self.t_cond = self.t_cond.cuda() 170 | 171 | def forward(self, x, h): 172 | 173 | eps = 1e-6 174 | x0, x1 = x[:,::2], x[:,1::2] 175 | if self.parity: 176 | x0, x1 = x1, x0 177 | h = self.condition(h) 178 | log_s = self.s_cond(torch.cat((x0,h),dim=1)) 179 | s = torch.exp(log_s) 180 | t = self.t_cond(torch.cat((x0,h),dim=1)) 181 | z0 = x0 # untouched half 182 | z1 = s * x1 + t # transform this half as a function of the other 183 | if self.parity: 184 | z0, z1 = z1, z0 185 | z = torch.cat([z0, z1], dim=1) 186 | 
log_det = torch.sum(torch.log(s+eps), dim=1) 187 | return z, log_det 188 | 189 | class AffineHalfFlow(nn.Module): 190 | """ 191 | As seen in RealNVP, affine autoregressive flow (z = x * exp(s) + t), where half of the 192 | dimensions in x are linearly scaled/transfromed as a function of the other half. 193 | Which half is which is determined by the parity bit. 194 | - RealNVP both scales and shifts (default) 195 | - NICE only shifts 196 | """ 197 | def __init__(self, dim, parity, net_class=MLP, nh=24, scale=True, shift=True): 198 | super().__init__() 199 | self.dim = dim 200 | self.parity = parity 201 | self.s_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2) 202 | self.t_cond = lambda x: x.new_zeros(x.size(0), self.dim // 2) 203 | if scale: 204 | self.s_cond = net_class(self.dim // 2, self.dim // 2, nh) 205 | if shift: 206 | self.t_cond = net_class(self.dim // 2, self.dim // 2, nh) 207 | self.s_cond = self.s_cond.cuda() 208 | self.t_cond = self.t_cond.cuda() 209 | 210 | 211 | def forward(self, x): 212 | x0, x1 = x[:,::2], x[:,1::2] 213 | if self.parity: 214 | x0, x1 = x1, x0 215 | s = self.s_cond(x0) 216 | t = self.t_cond(x0) 217 | z0 = x0 # untouched half 218 | z1 = torch.exp(s) * x1 + t # transform this half as a function of the other 219 | if self.parity: 220 | z0, z1 = z1, z0 221 | z = torch.cat([z0, z1], dim=1) 222 | log_det = torch.sum(s, dim=1) 223 | return z, log_det 224 | 225 | def backward(self, z): 226 | z0, z1 = z[:,::2], z[:,1::2] 227 | if self.parity: 228 | z0, z1 = z1, z0 229 | s = self.s_cond(z0) 230 | t = self.t_cond(z0) 231 | x0 = z0 # this was the same 232 | x1 = (z1 - t) * torch.exp(-s) # reverse the transform on this half 233 | if self.parity: 234 | x0, x1 = x1, x0 235 | x = torch.cat([x0, x1], dim=1) 236 | log_det = torch.sum(-s, dim=1) 237 | return x, log_det 238 | 239 | 240 | class SlowMAF(nn.Module): 241 | """ 242 | Masked Autoregressive Flow, slow version with explicit networks per dim 243 | """ 244 | def __init__(self, dim, parity, 
net_class=MLP, nh=24): 245 | super().__init__() 246 | self.dim = dim 247 | self.layers = nn.ModuleDict() 248 | self.layers[str(0)] = LeafParam(2) 249 | for i in range(1, dim): 250 | self.layers[str(i)] = net_class(i, 2, nh) 251 | self.order = list(range(dim)) if parity else list(range(dim))[::-1] 252 | 253 | def forward(self, x): 254 | z = torch.zeros_like(x) 255 | log_det = torch.zeros(x.size(0)) 256 | for i in range(self.dim): 257 | st = self.layers[str(i)](x[:, :i]) 258 | s, t = st[:, 0], st[:, 1] 259 | z[:, self.order[i]] = x[:, i] * torch.exp(s) + t 260 | log_det += s 261 | return z, log_det 262 | 263 | def backward(self, z): 264 | x = torch.zeros_like(z) 265 | log_det = torch.zeros(z.size(0)) 266 | for i in range(self.dim): 267 | st = self.layers[str(i)](x[:, :i]) 268 | s, t = st[:, 0], st[:, 1] 269 | x[:, i] = (z[:, self.order[i]] - t) * torch.exp(-s) 270 | log_det += -s 271 | return x, log_det 272 | 273 | class MAF(nn.Module): 274 | """ Masked Autoregressive Flow that uses a MADE-style network for fast forward """ 275 | 276 | def __init__(self, dim, parity, net_class=ARMLP, nh=24): 277 | super().__init__() 278 | self.dim = dim 279 | self.net = net_class(dim, dim*2, nh) 280 | self.parity = parity 281 | 282 | def forward(self, x): 283 | # here we see that we are evaluating all of z in parallel, so density estimation will be fast 284 | st = self.net(x) 285 | s, t = st.split(self.dim, dim=1) 286 | z = x * torch.exp(s) + t 287 | # reverse order, so if we stack MAFs correct things happen 288 | z = z.flip(dims=(1,)) if self.parity else z 289 | log_det = torch.sum(s, dim=1) 290 | return z, log_det 291 | 292 | def backward(self, z): 293 | # we have to decode the x one at a time, sequentially 294 | x = torch.zeros_like(z) 295 | log_det = torch.zeros(z.size(0)) 296 | z = z.flip(dims=(1,)) if self.parity else z 297 | for i in range(self.dim): 298 | st = self.net(x.clone()) # clone to avoid in-place op errors if using IAF 299 | s, t = st.split(self.dim, dim=1) 300 | 
x[:, i] = (z[:, i] - t[:, i]) * torch.exp(-s[:, i]) 301 | log_det += -s[:, i] 302 | return x, log_det 303 | 304 | class IAF(MAF): 305 | def __init__(self, *args, **kwargs): 306 | super().__init__(*args, **kwargs) 307 | """ 308 | reverse the flow, giving an Inverse Autoregressive Flow (IAF) instead, 309 | where sampling will be fast but density estimation slow 310 | """ 311 | self.forward, self.backward = self.backward, self.forward 312 | 313 | 314 | class Invertible1x1Conv(nn.Module): 315 | """ 316 | As introduced in Glow paper. 317 | """ 318 | 319 | def __init__(self, dim): 320 | super().__init__() 321 | self.dim = dim 322 | Q = torch.nn.init.orthogonal_(torch.randn(dim, dim)) 323 | P, L, U = torch.lu_unpack(*Q.lu()) 324 | self.P = P # remains fixed during optimization 325 | self.L = nn.Parameter(L) # lower triangular portion 326 | self.S = nn.Parameter(U.diag()) # "crop out" the diagonal to its own parameter 327 | self.U = nn.Parameter(torch.triu(U, diagonal=1)) # "crop out" diagonal, stored in S 328 | 329 | 330 | def _assemble_W(self): 331 | """ assemble W from its pieces (P, L, U, S) """ 332 | L = torch.tril(self.L, diagonal=-1) + torch.diag(torch.ones(self.dim)) 333 | U = torch.triu(self.U, diagonal=1) 334 | W = self.P @ L @ (U + torch.diag(self.S)) 335 | return W 336 | 337 | def forward(self, x): 338 | W = self._assemble_W() 339 | z = x @ W 340 | log_det = torch.sum(torch.log(torch.abs(self.S))) 341 | return z, log_det 342 | 343 | def backward(self, z): 344 | W = self._assemble_W() 345 | W_inv = torch.inverse(W) 346 | x = z @ W_inv 347 | log_det = -torch.sum(torch.log(torch.abs(self.S))) 348 | return x, log_det 349 | 350 | # ------------------------------------------------------------------------ 351 | 352 | class NormalizingFlow(nn.Module): 353 | """ A sequence of Normalizing Flows is a Normalizing Flow """ 354 | 355 | def __init__(self, flows): 356 | super().__init__() 357 | self.flows = nn.ModuleList(flows) 358 | 359 | def forward(self, x): 360 | m, _ = 
x.shape 361 | log_det = torch.zeros(m) 362 | zs = [x] 363 | for flow in self.flows: 364 | x, ld = flow.forward(x) 365 | log_det += ld 366 | zs.append(x) 367 | return zs, log_det 368 | 369 | def backward(self, z): 370 | m, _ = z.shape 371 | log_det = torch.zeros(m) 372 | xs = [z] 373 | for flow in self.flows[::-1]: 374 | z, ld = flow.backward(z) 375 | log_det += ld 376 | xs.append(z) 377 | return xs, log_det 378 | 379 | class NormalizingFlowModel(nn.Module): 380 | """ A Normalizing Flow Model is a (prior, flow) pair """ 381 | 382 | def __init__(self, prior, flows): 383 | super().__init__() 384 | self.prior = prior 385 | self.flow = NormalizingFlow(flows) 386 | 387 | def forward(self, x): 388 | zs, log_det = self.flow.forward(x) 389 | prior_logprob = self.prior.log_prob(zs[-1]).view(x.size(0), -1).sum(1) 390 | return zs, prior_logprob, log_det 391 | 392 | def backward(self, z): 393 | xs, log_det = self.flow.backward(z) 394 | return xs, log_det 395 | 396 | def sample(self, num_samples): 397 | z = self.prior.sample((num_samples,)) 398 | xs, _ = self.flow.backward(z) 399 | return xs 400 | -------------------------------------------------------------------------------- /nflib/made.py: -------------------------------------------------------------------------------- 1 | """ 2 | # copy pasted from my earlier MADE implementation 3 | # https://github.com/karpathy/pytorch-made 4 | 5 | Implements a Masked Autoregressive MLP, where carefully constructed 6 | binary masks over weights ensure the autoregressive property. 
7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn 13 | 14 | class MaskedLinear(nn.Linear): 15 | """ same as Linear except has a configurable mask on the weights """ 16 | 17 | def __init__(self, in_features, out_features, bias=True): 18 | super().__init__(in_features, out_features, bias) 19 | self.register_buffer('mask', torch.ones(out_features, in_features)) 20 | 21 | def set_mask(self, mask): 22 | self.mask.data.copy_(torch.from_numpy(mask.astype(np.uint8).T)) 23 | 24 | def forward(self, input): 25 | return F.linear(input, self.mask * self.weight, self.bias) 26 | 27 | class MADE(nn.Module): 28 | def __init__(self, nin, hidden_sizes, nout, num_masks=1, natural_ordering=False): 29 | """ 30 | nin: integer; number of inputs 31 | hidden sizes: a list of integers; number of units in hidden layers 32 | nout: integer; number of outputs, which usually collectively parameterize some kind of 1D distribution 33 | note: if nout is e.g. 2x larger than nin (perhaps the mean and std), then the first nin 34 | will be all the means and the second nin will be stds. i.e. output dimensions depend on the 35 | same input dimensions in "chunks" and should be carefully decoded downstream appropriately. 36 | the output of running the tests for this file makes this a bit more clear with examples. 
37 | num_masks: can be used to train ensemble over orderings/connections 38 | natural_ordering: force natural ordering of dimensions, don't use random permutations 39 | """ 40 | 41 | super().__init__() 42 | self.nin = nin 43 | self.nout = nout 44 | self.hidden_sizes = hidden_sizes 45 | assert self.nout % self.nin == 0, "nout must be integer multiple of nin" 46 | 47 | # define a simple MLP neural net 48 | self.net = [] 49 | hs = [nin] + hidden_sizes + [nout] 50 | for h0,h1 in zip(hs, hs[1:]): 51 | self.net.extend([ 52 | MaskedLinear(h0, h1), 53 | nn.ReLU(), 54 | ]) 55 | self.net.pop() # pop the last ReLU for the output layer 56 | self.net = nn.Sequential(*self.net) 57 | 58 | # seeds for orders/connectivities of the model ensemble 59 | self.natural_ordering = natural_ordering 60 | self.num_masks = num_masks 61 | self.seed = 0 # for cycling through num_masks orderings 62 | 63 | self.m = {} 64 | self.update_masks() # builds the initial self.m connectivity 65 | # note, we could also precompute the masks and cache them, but this 66 | # could get memory expensive for large number of masks. 
67 | 68 | def update_masks(self): 69 | if self.m and self.num_masks == 1: return # only a single seed, skip for efficiency 70 | L = len(self.hidden_sizes) 71 | 72 | # fetch the next seed and construct a random stream 73 | rng = np.random.RandomState(self.seed) 74 | self.seed = (self.seed + 1) % self.num_masks 75 | 76 | # sample the order of the inputs and the connectivity of all neurons 77 | self.m[-1] = np.arange(self.nin) if self.natural_ordering else rng.permutation(self.nin) 78 | for l in range(L): 79 | self.m[l] = rng.randint(self.m[l-1].min(), self.nin-1, size=self.hidden_sizes[l]) 80 | 81 | # construct the mask matrices 82 | masks = [self.m[l-1][:,None] <= self.m[l][None,:] for l in range(L)] 83 | masks.append(self.m[L-1][:,None] < self.m[-1][None,:]) 84 | 85 | # handle the case where nout = nin * k, for integer k > 1 86 | if self.nout > self.nin: 87 | k = int(self.nout / self.nin) 88 | # replicate the mask across the other outputs 89 | masks[-1] = np.concatenate([masks[-1]]*k, axis=1) 90 | 91 | # set the masks in all MaskedLinear layers 92 | layers = [l for l in self.net.modules() if isinstance(l, MaskedLinear)] 93 | for l,m in zip(layers, masks): 94 | l.set_mask(m) 95 | 96 | def forward(self, x): 97 | return self.net(x) 98 | -------------------------------------------------------------------------------- /nflib/nets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various helper network modules 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from nflib.made import MADE 10 | 11 | class LeafParam(nn.Module): 12 | """ 13 | just ignores the input and outputs a parameter tensor, lol 14 | todo maybe this exists in PyTorch somewhere? 
15 | """ 16 | def __init__(self, n): 17 | super().__init__() 18 | self.p = nn.Parameter(torch.zeros(1,n)) 19 | 20 | def forward(self, x): 21 | return self.p.expand(x.size(0), self.p.size(1)) 22 | 23 | class PositionalEncoder(nn.Module): 24 | """ 25 | Each dimension of the input gets expanded out with sins/coses 26 | to "carve" out the space. Useful in low-dimensional cases with 27 | tightly "curled up" data. 28 | """ 29 | def __init__(self, freqs=(.5,1,2,4,8)): 30 | super().__init__() 31 | self.freqs = freqs 32 | 33 | def forward(self, x): 34 | sines = [torch.sin(x * f) for f in self.freqs] 35 | coses = [torch.cos(x * f) for f in self.freqs] 36 | out = torch.cat(sines + coses, dim=1) 37 | return out 38 | 39 | class MLP(nn.Module): 40 | """ a simple 4-layer MLP """ 41 | 42 | def __init__(self, nin, nout, nh): 43 | super().__init__() 44 | self.net = nn.Sequential( 45 | nn.Linear(nin, nh), 46 | nn.LeakyReLU(0.2), 47 | nn.Linear(nh, nh), 48 | nn.LeakyReLU(0.2), 49 | nn.Linear(nh, nh), 50 | nn.LeakyReLU(0.2), 51 | nn.Linear(nh, nout), 52 | ) 53 | def forward(self, x): 54 | return self.net(x) 55 | 56 | class PosEncMLP(nn.Module): 57 | """ 58 | Position Encoded MLP, where the first layer performs position encoding. 59 | Each dimension of the input gets transformed to len(freqs)*2 dimensions 60 | using a fixed transformation of sin/cos of given frequencies. 
61 | """ 62 | def __init__(self, nin, nout, nh, freqs=(.5,1,2,4,8)): 63 | super().__init__() 64 | self.net = nn.Sequential( 65 | PositionalEncoder(freqs), 66 | MLP(nin * len(freqs) * 2, nout, nh), 67 | ) 68 | def forward(self, x): 69 | return self.net(x) 70 | 71 | class ARMLP(nn.Module): 72 | """ a 4-layer auto-regressive MLP, wrapper around MADE net """ 73 | 74 | def __init__(self, nin, nout, nh): 75 | super().__init__() 76 | self.net = MADE(nin, [nh, nh, nh], nout, num_masks=1, natural_ordering=True) 77 | 78 | def forward(self, x): 79 | return self.net(x) 80 | -------------------------------------------------------------------------------- /nflib/spline_flows.py: -------------------------------------------------------------------------------- 1 | """ 2 | Neural Spline Flows, coupling and autoregressive 3 | 4 | Paper reference: Durkan et al https://arxiv.org/abs/1906.04032 5 | Code reference: slightly modified https://github.com/tonyduan/normalizing-flows/blob/master/nf/flows.py 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.init as init 12 | import torch.nn.functional as F 13 | 14 | from nflib.nets import MLP 15 | 16 | DEFAULT_MIN_BIN_WIDTH = 1e-3 17 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 18 | DEFAULT_MIN_DERIVATIVE = 1e-3 19 | 20 | def searchsorted(bin_locations, inputs, eps=1e-6): 21 | bin_locations[..., -1] += eps 22 | return torch.sum( 23 | inputs[..., None] >= bin_locations, 24 | dim=-1 25 | ) - 1 26 | 27 | def unconstrained_RQS(inputs, unnormalized_widths, unnormalized_heights, 28 | unnormalized_derivatives, inverse=False, 29 | tail_bound=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, 30 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 31 | min_derivative=DEFAULT_MIN_DERIVATIVE): 32 | inside_intvl_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 33 | outside_interval_mask = ~inside_intvl_mask 34 | 35 | outputs = torch.zeros_like(inputs) 36 | logabsdet = torch.zeros_like(inputs) 37 | 38 | unnormalized_derivatives = 
F.pad(unnormalized_derivatives, pad=(1, 1)) 39 | constant = np.log(np.exp(1 - min_derivative) - 1) 40 | unnormalized_derivatives[..., 0] = constant 41 | unnormalized_derivatives[..., -1] = constant 42 | 43 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 44 | logabsdet[outside_interval_mask] = 0 45 | 46 | outputs[inside_intvl_mask], logabsdet[inside_intvl_mask] = RQS( 47 | inputs=inputs[inside_intvl_mask], 48 | unnormalized_widths=unnormalized_widths[inside_intvl_mask, :], 49 | unnormalized_heights=unnormalized_heights[inside_intvl_mask, :], 50 | unnormalized_derivatives=unnormalized_derivatives[inside_intvl_mask, :], 51 | inverse=inverse, 52 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 53 | min_bin_width=min_bin_width, 54 | min_bin_height=min_bin_height, 55 | min_derivative=min_derivative 56 | ) 57 | return outputs, logabsdet 58 | 59 | def RQS(inputs, unnormalized_widths, unnormalized_heights, 60 | unnormalized_derivatives, inverse=False, left=0., right=1., 61 | bottom=0., top=1., min_bin_width=DEFAULT_MIN_BIN_WIDTH, 62 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 63 | min_derivative=DEFAULT_MIN_DERIVATIVE): 64 | if torch.min(inputs) < left or torch.max(inputs) > right: 65 | raise ValueError("Input outside domain") 66 | 67 | num_bins = unnormalized_widths.shape[-1] 68 | 69 | if min_bin_width * num_bins > 1.0: 70 | raise ValueError('Minimal bin width too large for the number of bins') 71 | if min_bin_height * num_bins > 1.0: 72 | raise ValueError('Minimal bin height too large for the number of bins') 73 | 74 | widths = F.softmax(unnormalized_widths, dim=-1) 75 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 76 | cumwidths = torch.cumsum(widths, dim=-1) 77 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 78 | cumwidths = (right - left) * cumwidths + left 79 | cumwidths[..., 0] = left 80 | cumwidths[..., -1] = right 81 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 82 | 83 | 
derivatives = min_derivative + F.softplus(unnormalized_derivatives) 84 | 85 | heights = F.softmax(unnormalized_heights, dim=-1) 86 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 87 | cumheights = torch.cumsum(heights, dim=-1) 88 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 89 | cumheights = (top - bottom) * cumheights + bottom 90 | cumheights[..., 0] = bottom 91 | cumheights[..., -1] = top 92 | heights = cumheights[..., 1:] - cumheights[..., :-1] 93 | 94 | if inverse: 95 | bin_idx = searchsorted(cumheights, inputs)[..., None] 96 | else: 97 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 98 | 99 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 100 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 101 | 102 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 103 | delta = heights / widths 104 | input_delta = delta.gather(-1, bin_idx)[..., 0] 105 | 106 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 107 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx) 108 | input_derivatives_plus_one = input_derivatives_plus_one[..., 0] 109 | 110 | input_heights = heights.gather(-1, bin_idx)[..., 0] 111 | 112 | if inverse: 113 | a = (((inputs - input_cumheights) * (input_derivatives \ 114 | + input_derivatives_plus_one - 2 * input_delta) \ 115 | + input_heights * (input_delta - input_derivatives))) 116 | b = (input_heights * input_derivatives - (inputs - input_cumheights) \ 117 | * (input_derivatives + input_derivatives_plus_one \ 118 | - 2 * input_delta)) 119 | c = - input_delta * (inputs - input_cumheights) 120 | 121 | discriminant = b.pow(2) - 4 * a * c 122 | assert (discriminant >= 0).all() 123 | 124 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 125 | outputs = root * input_bin_widths + input_cumwidths 126 | 127 | theta_one_minus_theta = root * (1 - root) 128 | denominator = input_delta \ 129 | + ((input_derivatives + input_derivatives_plus_one \ 130 | - 2 * 
input_delta) * theta_one_minus_theta) 131 | derivative_numerator = input_delta.pow(2) \ 132 | * (input_derivatives_plus_one * root.pow(2) \ 133 | + 2 * input_delta * theta_one_minus_theta \ 134 | + input_derivatives * (1 - root).pow(2)) 135 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 136 | return outputs, -logabsdet 137 | else: 138 | theta = (inputs - input_cumwidths) / input_bin_widths 139 | theta_one_minus_theta = theta * (1 - theta) 140 | 141 | numerator = input_heights * (input_delta * theta.pow(2) \ 142 | + input_derivatives * theta_one_minus_theta) 143 | denominator = input_delta + ((input_derivatives \ 144 | + input_derivatives_plus_one - 2 * input_delta) \ 145 | * theta_one_minus_theta) 146 | outputs = input_cumheights + numerator / denominator 147 | 148 | derivative_numerator = input_delta.pow(2) \ 149 | * (input_derivatives_plus_one * theta.pow(2) \ 150 | + 2 * input_delta * theta_one_minus_theta \ 151 | + input_derivatives * (1 - theta).pow(2)) 152 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 153 | return outputs, logabsdet 154 | 155 | class NSF_AR(nn.Module): 156 | """ Neural spline flow, coupling layer, [Durkan et al. 
2019] """ 157 | 158 | def __init__(self, dim, K=5, B=3, hidden_dim=8, base_network=MLP): 159 | super().__init__() 160 | self.dim = dim 161 | self.K = K 162 | self.B = B 163 | self.layers = nn.ModuleList() 164 | self.init_param = nn.Parameter(torch.Tensor(3 * K - 1)) 165 | for i in range(1, dim): 166 | self.layers += [base_network(i, 3 * K - 1, hidden_dim)] 167 | self.reset_parameters() 168 | 169 | def reset_parameters(self): 170 | init.uniform_(self.init_param, - 1 / 2, 1 / 2) 171 | 172 | def forward(self, x): 173 | z = torch.zeros_like(x) 174 | log_det = torch.zeros(z.shape[0]) 175 | for i in range(self.dim): 176 | if i == 0: 177 | init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1) 178 | W, H, D = torch.split(init_param, self.K, dim = 1) 179 | else: 180 | out = self.layers[i - 1](x[:, :i]) 181 | W, H, D = torch.split(out, self.K, dim = 1) 182 | W, H = torch.softmax(W, dim = 1), torch.softmax(H, dim = 1) 183 | W, H = 2 * self.B * W, 2 * self.B * H 184 | D = F.softplus(D) 185 | z[:, i], ld = unconstrained_RQS(x[:, i], W, H, D, inverse=False, tail_bound=self.B) 186 | log_det += ld 187 | return z, log_det 188 | 189 | def backward(self, z): 190 | x = torch.zeros_like(z) 191 | log_det = torch.zeros(x.shape[0]) 192 | for i in range(self.dim): 193 | if i == 0: 194 | init_param = self.init_param.expand(x.shape[0], 3 * self.K - 1) 195 | W, H, D = torch.split(init_param, self.K, dim = 1) 196 | else: 197 | out = self.layers[i - 1](x[:, :i]) 198 | W, H, D = torch.split(out, self.K, dim = 1) 199 | W, H = torch.softmax(W, dim = 1), torch.softmax(H, dim = 1) 200 | W, H = 2 * self.B * W, 2 * self.B * H 201 | D = F.softplus(D) 202 | x[:, i], ld = unconstrained_RQS(z[:, i], W, H, D, inverse = True, tail_bound = self.B) 203 | log_det += ld 204 | return x, log_det 205 | 206 | 207 | class NSF_CL(nn.Module): 208 | """ Neural spline flow, coupling layer, [Durkan et al. 
2019] """ 209 | 210 | def __init__(self, dim, K=5, B=3, hidden_dim=8, base_network=MLP): 211 | super().__init__() 212 | self.dim = dim 213 | self.K = K 214 | self.B = B 215 | self.f1 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim) 216 | self.f2 = base_network(dim // 2, (3 * K - 1) * dim // 2, hidden_dim) 217 | 218 | def forward(self, x): 219 | log_det = torch.zeros(x.shape[0]) 220 | lower, upper = x[:, :self.dim // 2], x[:, self.dim // 2:] 221 | out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1) 222 | W, H, D = torch.split(out, self.K, dim = 2) 223 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 224 | W, H = 2 * self.B * W, 2 * self.B * H 225 | D = F.softplus(D) 226 | upper, ld = unconstrained_RQS(upper, W, H, D, inverse=False, tail_bound=self.B) 227 | log_det += torch.sum(ld, dim = 1) 228 | out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1) 229 | W, H, D = torch.split(out, self.K, dim = 2) 230 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 231 | W, H = 2 * self.B * W, 2 * self.B * H 232 | D = F.softplus(D) 233 | lower, ld = unconstrained_RQS(lower, W, H, D, inverse=False, tail_bound=self.B) 234 | log_det += torch.sum(ld, dim = 1) 235 | return torch.cat([lower, upper], dim = 1), log_det 236 | 237 | def backward(self, z): 238 | log_det = torch.zeros(z.shape[0]) 239 | lower, upper = z[:, :self.dim // 2], z[:, self.dim // 2:] 240 | out = self.f2(upper).reshape(-1, self.dim // 2, 3 * self.K - 1) 241 | W, H, D = torch.split(out, self.K, dim = 2) 242 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 243 | W, H = 2 * self.B * W, 2 * self.B * H 244 | D = F.softplus(D) 245 | lower, ld = unconstrained_RQS(lower, W, H, D, inverse=True, tail_bound=self.B) 246 | log_det += torch.sum(ld, dim = 1) 247 | out = self.f1(lower).reshape(-1, self.dim // 2, 3 * self.K - 1) 248 | W, H, D = torch.split(out, self.K, dim = 2) 249 | W, H = torch.softmax(W, dim = 2), torch.softmax(H, dim = 2) 250 | W, H = 2 * 
self.B * W, 2 * self.B * H 251 | D = F.softplus(D) 252 | upper, ld = unconstrained_RQS(upper, W, H, D, inverse = True, tail_bound = self.B) 253 | log_det += torch.sum(ld, dim = 1) 254 | return torch.cat([lower, upper], dim = 1), log_det -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | from data.dataset import LIDC 6 | from models.cflownet import cFlowNet 7 | import torch.nn as nn 8 | from models.unet import Unet 9 | from utils.utils import l2_regularisation,ged 10 | import time 11 | from utils.tools import makeLogFile, writeLog, dice_loss 12 | import pdb 13 | import argparse 14 | import sys 15 | import os 16 | 17 | torch.manual_seed(42) 18 | np.random.seed(42) 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--flow', action='store_true', default=False, help=' Train with Flow model') 23 | parser.add_argument('--glow', action='store_true', default=False, help=' Train with Glow model') 24 | parser.add_argument('--data', type=str, default='data/lidc/',help='Path to data.') 25 | parser.add_argument('--probUnet', action='store_true', default=False, help='Train with Prob. Unet') 26 | parser.add_argument('--unet', action='store_true', default=False, help='Train with Det. 
Unet') 27 | parser.add_argument('--singleRater', action='store_true', default=False, help='Train with single rater') 28 | parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs') 29 | parser.add_argument('--batch_size', type=int, default=96, help='Batch size') 30 | parser.add_argument('--num_flows', type=int, default=4, help='Num flows') 31 | parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate') 32 | 33 | args = parser.parse_args() 34 | 35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 36 | dataset = LIDC(data_dir=args.data) 37 | dataset_size = len(dataset) 38 | indices = list(range(dataset_size)) 39 | split = int(np.floor(0.2 * dataset_size)) 40 | 41 | np.random.shuffle(indices) 42 | valid_indices, test_indices, train_indices = indices[:split], indices[split:2*split], indices[2*split:] 43 | train_sampler = SubsetRandomSampler(train_indices) 44 | valid_sampler = SubsetRandomSampler(valid_indices) 45 | test_sampler = SubsetRandomSampler(test_indices) 46 | train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler) 47 | valid_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=valid_sampler) 48 | test_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=test_sampler) 49 | 50 | print("Number of train/valid/test patches:", (len(train_indices),len(valid_indices),len(test_indices))) 51 | 52 | fName = time.strftime("%Y%m%d_%H_%M") 53 | 54 | if args.singleRater: 55 | print("Using a single rater..") 56 | fName = fName+'_1R' 57 | else: 58 | print("Using all experts...") 59 | fName = fName+'_4R' 60 | 61 | if args.flow: 62 | print("Using flow based model with %d steps"%args.num_flows) 63 | fName = fName+'_flow' 64 | net = cFlowNet(input_channels=1, num_classes=1, 65 | num_filters=[32,64,128,256], latent_dim=6, 66 | no_convs_fcomb=4, num_flows=args.num_flows, 67 | norm=True,flow=args.flow) 68 | elif args.glow: 69 | print("Using Glow based model with 
%d steps"%args.num_flows) 70 | fName = fName+'_glow' 71 | net = cFlowNet(input_channels=1, num_classes=1, 72 | num_filters=[32,64,128,256], latent_dim=6, 73 | no_convs_fcomb=4, norm=True,num_flows=args.num_flows, 74 | flow=args.flow,glow=args.glow) 75 | elif args.probUnet: 76 | print("Using probUnet") 77 | fName = fName+'_probUnet' 78 | net = cFlowNet(input_channels=1, num_classes=1, 79 | num_filters=[32,64,128,256], latent_dim=6, 80 | no_convs_fcomb=4, norm=True,flow=args.flow) 81 | elif args.unet: 82 | print("Using Det. Unet") 83 | fName = fName+'_Unet' 84 | net = Unet(input_channels=1, num_classes=1, 85 | num_filters=[32,64,128,256], apply_last_layer=True, 86 | padding=True, norm=True, 87 | initializers={'w':'he_normal', 'b':'normal'}) 88 | criterion = nn.BCELoss(size_average=False) 89 | else: 90 | print("Choose a model.\nAborting....") 91 | sys.exit() 92 | 93 | if not os.path.exists('logs'): 94 | os.mkdir('logs') 95 | 96 | logFile = 'logs/'+fName+'.txt' 97 | makeLogFile(logFile) 98 | 99 | net.to(device) 100 | optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-5) 101 | nTrain = len(train_loader) 102 | nValid = len(valid_loader) 103 | nTest = len(test_loader) 104 | 105 | minLoss = 1e8 106 | 107 | convIter=0 108 | convCheck = 20 109 | 110 | for epoch in range(args.epochs): 111 | trLoss = [] 112 | vlLoss = [] 113 | vlGed = [0] 114 | klEpoch = [0] 115 | recEpoch = [0] 116 | kl = torch.zeros(1) 117 | recLoss = torch.zeros(1) 118 | dGED = 0 119 | t = time.time() 120 | for step, (patch, masks) in enumerate(train_loader): 121 | patch = patch.to(device) 122 | masks = masks.to(device) 123 | if args.singleRater or args.unet: 124 | rater = 0 125 | else: 126 | # Choose a random mask 127 | rater = torch.randperm(4)[0] 128 | mask = masks[:,[rater]] 129 | if not args.unet: 130 | net.forward(patch, mask, training=True) 131 | _,_,_,elbo = net.elbo(mask,use_mask=False,analytic_kl=False) 132 | reg_loss = l2_regularisation(net.posterior) + 
l2_regularisation(net.prior) 133 | loss = -elbo + 1e-5 * reg_loss 134 | else: 135 | pred = torch.sigmoid(net.forward(patch,False)) 136 | loss = criterion(target=mask,input=pred) 137 | optimizer.zero_grad() 138 | loss.backward() 139 | optimizer.step() 140 | trLoss.append(loss.item()) 141 | 142 | if (step+1) % 5 == 0: 143 | with torch.no_grad(): 144 | for idx, (patch, masks) in enumerate(valid_loader): 145 | patch = patch.to(device) 146 | masks = masks.to(device) 147 | # Choose a random mask 148 | mask = masks[:,[rater]] 149 | if not args.unet: 150 | net.forward(patch, mask, training=True) 151 | _,recLoss, kl, elbo = net.elbo(mask,use_mask=False, 152 | analytic_kl=False) 153 | reg_loss = l2_regularisation(net.posterior) + \ 154 | l2_regularisation(net.prior) 155 | loss = -elbo + 1e-5 * reg_loss 156 | klEpoch.append(kl.item()) 157 | recEpoch.append(recLoss.item()) 158 | else: 159 | pred = torch.sigmoid(net.forward(patch, False)) 160 | loss = criterion(target=mask, input=pred) 161 | vlLoss.append(loss.item()) 162 | break 163 | print ('Epoch [{}/{}], Step [{}/{}], TrLoss: {:.4f}, VlLoss: {:.4f}, RecLoss: {:.4f}, kl: {:.4f}, GED: {:.4f}' 164 | .format(epoch+1, args.epochs, step+1, nTrain, trLoss[-1], vlLoss[-1], recLoss.item(),\ 165 | kl.item(), vlGed[-1])) 166 | epValidLoss = np.mean(vlLoss) 167 | if (epoch+1) % 1 == 0 and epValidLoss > 0 and epValidLoss < minLoss: 168 | convIter = 0 169 | minLoss = epValidLoss 170 | print("New min: %.2f\nSaving model..."%(minLoss)) 171 | torch.save(net.state_dict(),'../models/'+fName+'.pt') 172 | else: 173 | convIter += 1 174 | writeLog(logFile, epoch, np.mean(trLoss), 175 | epValidLoss,np.mean(recEpoch), np.mean(klEpoch), time.time()-t) 176 | 177 | if convIter == convCheck: 178 | print("Converged at epoch %d"%(epoch+1-convCheck)) 179 | break 180 | elif np.isnan(epValidLoss): 181 | print("Nan error!") 182 | break 183 | 184 | -------------------------------------------------------------------------------- /utils/__init__.py: 
# (utils/__init__.py is not inlined in this dump; original file:
#  https://raw.githubusercontent.com/raghavian/cFlow/79caf3cc9ccc1b6a21a3bc157f432b40947ff217/utils/__init__.py)

# ---- /utils/tools.py ----
import numpy as np
import scipy.sparse as sp
import torch
from os.path import isfile, splitext
from os import rename
SMOOTH=1
import pdb
from sklearn.metrics import auc, roc_curve
import torch.nn.functional as F
import torch.nn as nn
from PIL.ImageFilter import GaussianBlur
from scipy.ndimage import gaussian_filter


def wCELoss(prediction, target):
    """Weighted binary cross-entropy (mean over elements).

    False negatives are weighted 1.33 and false positives 0.66. Predictions
    are clamped to [1e-3, 0.999] inside the logs to avoid log(0).
    """
    w1 = 1.33 # False negative penalty
    w2 = .66 # False positive penalty
    return -torch.mean(w1 * target * torch.log(prediction.clamp_min(1e-3))
                       + w2 * (1. - target) * torch.log(1. - prediction.clamp_max(.999)))

class GaussianFilter(object):
    """Apply Gaussian blur to the PIL image
    Args:
        sigma (float): Sigma of Gaussian kernel. Default value 1.0
    """
    def __init__(self, sigma=1):
        self.sigma = sigma
        self.filter = GaussianBlur(radius=sigma)

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be blurred.

        Returns:
            PIL Image: Blurred image.
        """
        return img.filter(self.filter)

    def __repr__(self):
        return self.__class__.__name__ + '(sigma={})'.format(self.sigma)

class GaussianLayer(nn.Module):
    """Fixed depthwise Gaussian blur for 3-channel images.

    Reflection-pads by ``size`` and convolves each channel with a
    (2*size+1) x (2*size+1) Gaussian kernel, so spatial size is preserved.
    """
    def __init__(self, sigma=1, size=10):
        super().__init__()
        self.sigma = sigma
        self.size = size
        ksize = 2 * size + 1  # must match the kernel built in weights_init
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(size),
            nn.Conv2d(3, 3, ksize, stride=1, padding=0, bias=False, groups=3)
        )
        self.weights_init()

    def forward(self, x):
        return self.seq(x)

    def weights_init(self):
        """Load a Gaussian kernel (blurred centred impulse) into every weight."""
        s = self.size * 2 + 1
        k = np.zeros((s,s))
        k[self.size, self.size] = 1  # centre of the s x s grid
        kernel = gaussian_filter(k,sigma=self.sigma)
        for name, f in self.named_parameters():
            # (s, s) kernel broadcasts into the (3, 1, s, s) depthwise weight.
            f.data.copy_(torch.from_numpy(kernel))

class focalLoss(nn.Module):
    """Focal loss alpha * (1 - pt)**gamma * BCE, on probabilities or logits."""
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(focalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # reduction='none' replaces the deprecated reduce=False.
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss

def computeAuc(target,preds):
    """Area under the ROC curve of ``preds`` against binary ``target``."""
    fpr, tpr, thresholds = roc_curve(target,preds)
    aucVal = auc(fpr,tpr)
    return aucVal

class hingeLoss(torch.nn.Module):
    """Mean hinge loss after mapping output and target from {0,1} to {-1,1}."""

    def __init__(self):
        super(hingeLoss, self).__init__()

    def forward(self, output, target):
        target = 2*target-1
        output = 2*output-1
        hinge_loss = 1 - torch.mul(output, target)
        hinge_loss[hinge_loss < 0] = 0
        return hinge_loss.mean()


def makeBatchAdj(adj,bSize):
    """Block-diagonal sparse matrix holding ``bSize`` copies of sparse ``adj``."""
    E = adj._nnz()
    N = adj.shape[0]
    batch_idx = torch.zeros(2,bSize*E).type(torch.LongTensor)
    batch_val = torch.zeros(bSize*E)

    idx = adj._indices()
    vals = adj._values()

    for i in range(bSize):
        # Offset both row and column indices by i*N for the i-th copy.
        batch_idx[:,i*E:(i+1)*E] = idx + i*N
        batch_val[i*E:(i+1)*E] = vals

    return torch.sparse.FloatTensor(batch_idx,batch_val,(bSize*N,bSize*N))


def makeAdj(ngbrs, normalize=True):
    """ Create an adjacency matrix, given the neighbour indices
    Input: Nxd neighbourhood, where N is number of nodes; negative entries
           mark missing neighbours
    Output: NxN sparse torch adjacency matrix (with self connections,
            row-normalised when ``normalize`` is True)
    """
    N, d = ngbrs.shape
    validNgbrs = (ngbrs >= 0) # Mask for valid neighbours amongst the d-neighbours
    row = np.repeat(np.arange(N),d) # Row indices like in sparse matrix formats
    row = row[validNgbrs.reshape(-1)] #Remove non-neighbour row indices
    col = (ngbrs*validNgbrs).reshape(-1) # Obtain neighbour col indices
    col = col[validNgbrs.reshape(-1)] # Remove non-neighbour col indices
    adj = sp.csr_matrix((np.ones(col.size, dtype=bool),(row, col)), shape=(N, N)).toarray() # Make adj matrix
    adj = adj + np.eye(N) # Self connections
    adj = sp.csr_matrix(adj, dtype=np.float32)#/(d+1)
    if normalize:
        adj = row_normalize(adj)
    adj = sparse_mx_to_torch_sparse_tensor(adj)

    return adj

def makeRegAdj(numNgbrs=26):
    """ Make regular pixel neighbourhoods"""
    # NOTE(review): this function references numEl, xdim, ydim, zdim and
    # idxVol, none of which are defined here, and it returns nothing. Calling
    # it raises NameError. It needs the volume dimensions passed in before use.
    idx = 0
    ngbrOffset = np.zeros((3,numNgbrs),dtype=int)
    for i in range(-1,2):
        for j in range(-1,2):
            for k in range(-1,2):
                if(i | j | k):
                    ngbrOffset[:,idx] = [i,j,k]
                    idx+=1
    idx = 0
    ngbrs = np.zeros((numEl, numNgbrs), dtype=int)

    for i in range(xdim):
        for j in range(ydim):
            for k in range(zdim):
                xIdx = np.mod(ngbrOffset[0,:]+i,xdim)
                yIdx = np.mod(ngbrOffset[1,:]+j,ydim)
                zIdx = np.mod(ngbrOffset[2,:]+k,zdim)
                ngbrs[idx,:] = idxVol[xIdx, yIdx, zIdx]
                idx += 1


def makeAdjWithInvNgbrs(ngbrs, normalize=False):
    """ Create an adjacency matrix, given the neighbour indices. Invalid
    (negative) neighbour indices are replaced by random node indices.
    Input: Nxd neighbourhood, where N is number of nodes
    Output: NxN sparse torch adjacency matrix
    Note: reseeds the global NumPy RNG (np.random.seed(2)).
    """
    np.random.seed(2)
    N, d = ngbrs.shape
    row = np.arange(N).reshape(-1,1)
    random = np.random.randint(0,N-1,(N,d))
    valIdx = np.array((ngbrs < 0),dtype=int)
    ngbrs = random*valIdx + ngbrs*(1-valIdx)# Mask for valid neighbours amongst the d-neighbours
    row = np.repeat(row,d).reshape(-1) # Row indices like in sparse matrix formats
    col = ngbrs.reshape(-1) # Obtain neighbour col indices
    adj = sp.csr_matrix((np.ones(col.size, dtype=bool),(row, col)), shape=(N, N)).toarray() # Make adj matrix
    adj = adj + np.eye(N) # Self connections
    adj = sp.csr_matrix(adj, dtype=np.float32)#/(d+1)
    if normalize:
        adj = row_normalize(adj)
    adj = sparse_mx_to_torch_sparse_tensor(adj)
    adj = adj.coalesce()
    adj._values = adj.values()
    return adj


def transformers(adj):
    """ Obtain source and sink node transformer matrices"""
    edges = adj._indices()
    N = adj.shape[0]
    nnz = adj._nnz()
    val = torch.ones(nnz)
    idx0 = torch.arange(nnz)

    idx = torch.stack((idx0,edges[1,:]))
    n2e_in = torch.sparse.FloatTensor(idx,val,(nnz,N))

    idx = torch.stack((idx0,edges[0,:]))
    n2e_out = torch.sparse.FloatTensor(idx,val,(nnz,N))

    return n2e_in, n2e_out

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row,
                                          sparse_mx.col))).long()
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

def to_linear_idx(x_idx, y_idx, num_cols):
    """Row-major linear index y*num_cols + x (int32)."""
    assert num_cols > np.max(x_idx)
    x_idx = np.array(x_idx, dtype=np.int32)
    y_idx = np.array(y_idx, dtype=np.int32)
    return y_idx * num_cols + x_idx


def row_normalize(mx):
    """Row-normalize sparse matrix (all-zero rows stay zero)."""
    rowsum = np.array(mx.sum(1), dtype=np.float32)
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def to_2d_idx(idx, num_cols):
    """Inverse of to_linear_idx: return (x, y) for row-major linear indices."""
    idx = np.array(idx, dtype=np.int64)
    y_idx = np.array(np.floor(idx / float(num_cols)), dtype=np.int64)
    x_idx = idx % num_cols
    return x_idx, y_idx

def dice_loss(input, target):
    "Return soft dice loss: 1 - (2*sum(input*target) + SMOOTH) / (sum(input**2) + sum(target) + SMOOTH)."
    preds_sq = input**2
    return 1 - (2. * (torch.sum(input * target)) + SMOOTH) / \
        (preds_sq.sum() + target.sum() + SMOOTH)

def binary_accuracy(output, labels):
    """Fraction of ``output > 0.5`` equal to ``labels``, normalised by len(labels)."""
    preds = output > 0.5
    correct = preds.type_as(labels).eq(labels).double()
    correct = correct.sum()
    return correct / len(labels)

def multiClassAccuracy(output, labels):
    """Fraction of argmax(output, 1) equal to the flattened labels."""
    preds = output.argmax(1)
    correct = (preds == labels.view(-1))
    correct = correct.sum().float()
    return correct / len(labels)

def regrAcc(output, labels):
    """Fraction of rounded regression outputs equal to the flattened labels."""
    preds = output.round().type(torch.long).type_as(labels)
    correct = (preds == labels.view(-1))
    correct = correct.sum().float()
    return correct / len(labels)


def rescaledRegAcc(output,labels,lRange=37,lMin=-20):
    """regrAcc after mapping outputs from [-1, 1] to [lMin, lMin + lRange]."""
    preds = (output+1)*(lRange)/2 + lMin
    preds = preds.round().type(torch.long).type_as(labels)
    correct = (preds == labels.view(-1))
    correct = correct.sum().float()
    return correct / len(labels)

def focalCE(preds, labels, gamma=1):
    "Return focal cross entropy"
    loss = -torch.mean( ( ((1-preds)**gamma) * labels * torch.log(preds) ) \
                        + ( ((preds)**gamma) * (1-labels) * torch.log(1-preds) ) )
    return loss

def dice(preds, labels):
    "Return dice score"
    preds_bin = (preds > 0.5).type_as(labels)
    return 2. * torch.sum(preds_bin * labels) / (preds_bin.sum() + labels.sum())

def wBCE(preds, labels, w):
    "Return weighted CE loss."
    return -torch.mean( w*labels*torch.log(preds) + (1-w)*(1-labels)*torch.log(1-preds) )

def makeLogFile(filename="lossHistory.txt"):
    """Create (overwrite) a tab-separated training log with a header row.

    An existing file is first renamed to ``<stem>Old<ext>`` in the same
    directory (``lossHistoryOld.txt`` for the default name), instead of always
    being moved to ``lossHistoryOld.txt`` in the working directory.
    """
    if isfile(filename):
        root, ext = splitext(filename)
        rename(filename, root + "Old" + ext)

    with open(filename,"w") as text_file:
        print('Epoch\tlossTr\tlossVl\tRecLoss\tKL\ttime(s)',file=text_file)
    print("Log file created...")
    return

def writeLog(logFile, epoch, lossTr, lossVl, recLoss, kl,eTime):
    """Print one epoch summary and append it as a row to ``logFile``.

    The file row holds bare numbers so that it lines up with the header
    written by makeLogFile.
    """
    print('Epoch:{:04d}\t'.format(epoch + 1),
          'lossTr:{:.4f}\t'.format(lossTr),
          'lossVl:{:.4f}\t'.format(lossVl),
          'recLoss:{:.4f}\t'.format(recLoss),
          'KL:{:.4f}\t'.format(kl),
          'time:{:.4f}'.format(eTime))

    with open(logFile,"a") as text_file:
        print('{:04d}\t'.format(epoch + 1),
              '{:.4f}\t'.format(lossTr),
              '{:.4f}\t'.format(lossVl),
              '{:.4f}\t'.format(recLoss),
              '{:.4f}\t'.format(kl),
              '{:.4f}'.format(eTime),file=text_file)
    return

def plotLearningCurve():
    """Plot curves saved in loss_tr/loss_vl/dice_tr/dice_vl .npz files."""
    import matplotlib.pyplot as plt  # plt was used here without being imported
    plt.clf()
    tmp = np.load('loss_tr.npz')['arr_0']
    plt.plot(tmp,label='Tr.Loss')
    tmp = np.load('loss_vl.npz')['arr_0']
    plt.plot(tmp,label='Vl.Loss')
    tmp = np.load('dice_tr.npz')['arr_0']
    plt.plot(tmp,label='Tr.Dice')
    tmp = np.load('dice_vl.npz')['arr_0']
    plt.plot(tmp,label='Vl.Dice')
    plt.legend()
    plt.grid()
    plt.show()

# ---- /utils/utils.py ----
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pdb
import numpy as np

EPS = 1e-6

def ncc(a,v, zero_norm=True):
    """Normalised cross-correlation of two arrays (flattened) via np.correlate."""
    a = a.flatten()
    v = v.flatten()

    if zero_norm:
        a = (a - np.mean(a)) / (np.std(a) * len(a))
        v = (v - np.mean(v)) / np.std(v)
    else:
        a = (a) / (np.std(a) * len(a))
        v = (v) / np.std(v)

    return np.correlate(a,v)

def variance_ncc_dist(sample_arr, gt_arr):
    """Mean NCC between the sample cross-entropy map and each ground-truth one.

    Only element 0 along axis 0 of each argument is used. After that
    indexing, ``sample_arr`` is treated as N x X x Y and ``gt_arr`` as
    M x X x Y.

    :param sample_arr: samples, indexed [0] -> N x X x Y
    :param gt_arr: ground truths, indexed [0] -> M x X x Y
    :return: average over the M ground truths of ncc(E_ss, E_sy[j])
    """
    def pixel_wise_xent(m_samp, m_gt, eps=1e-8):
        # NOTE(review): sums over the last axis; for X x Y inputs without a
        # class axis this reduces over Y. Confirm the intended input layout.
        log_samples = np.log(m_samp + eps)
        return -1.0*np.sum(m_gt*log_samples, axis=-1)

    gt_arr = gt_arr[0]
    sample_arr = sample_arr[0]
    mean_seg = np.mean(sample_arr, axis=0)

    N = sample_arr.shape[0]
    M = gt_arr.shape[0]

    sX = sample_arr.shape[1]
    sY = sample_arr.shape[2]

    E_ss_arr = np.zeros((N,sX,sY))
    for i in range(N):
        E_ss_arr[i,...] = pixel_wise_xent(sample_arr[i,...], mean_seg)

    E_ss = np.mean(E_ss_arr, axis=0)

    E_sy_arr = np.zeros((M,N, sX, sY))
    for j in range(M):
        for i in range(N):
            E_sy_arr[j,i, ...] = pixel_wise_xent(sample_arr[i,...], gt_arr[j,...])

    E_sy = np.mean(E_sy_arr, axis=1)

    ncc_list = []
    for j in range(M):
        ncc_list.append(ncc(E_ss, E_sy[j,...]))

    return (1/M)*sum(np.array(ncc_list))


def pdist(a,b):
    """Mean pairwise (1 - IoU) between the masks of ``a`` and ``b``.

    ``a`` is (B, N, H, W) and ``b`` is (B, M, H, W) integer/bool masks. Every
    pair (a[:, n], b[:, m]) is compared via broadcasting. The previous
    repeat+view construction paired the wrong elements when N != M.
    Returns a (B,) tensor.
    """
    inter = (a[:, :, None] & b[:, None]).float().sum(-1).sum(-1) + EPS  # (B, N, M)
    union = (a[:, :, None] | b[:, None]).float().sum(-1).sum(-1) + EPS
    IoU = inter/union
    dis = (1-IoU).mean(-1).mean(-1)
    return dis

def ged(seg,prd):
    """Generalised energy distance 2*d(S,P) - d(S,S) - d(P,P) with d = pdist."""
    seg = seg.type(torch.ByteTensor)
    prd = prd.type_as(seg)

    dSP = pdist(seg,prd)
    dSS = pdist(seg,seg)
    dPP = pdist(prd,prd)

    return (2*dSP - dSS - dPP)

def truncated_normal_(tensor, mean=0, std=1):
    """In-place approximate truncated normal: pick, per element, one of four
    N(0,1) draws lying in (-2, 2) when one exists, then scale by std and shift
    by mean."""
    size = tensor.shape
    tmp = tensor.new_empty(size + (4,)).normal_()
    valid = (tmp < 2) & (tmp > -2)
    ind = valid.max(-1, keepdim=True)[1]
    tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
    tensor.data.mul_(std).add_(mean)

def init_weights(m):
    """Kaiming-normal weights and tiny truncated-normal bias for (transposed) convs."""
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        truncated_normal_(m.bias, mean=0, std=0.001)

def init_weights_orthogonal_normal(m):
    """Orthogonal weights and tiny truncated-normal bias for (transposed) convs."""
    if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        nn.init.orthogonal_(m.weight)
        truncated_normal_(m.bias, mean=0, std=0.001)

def l2_regularisation(m):
    """Sum of the L2 norms of all parameters of ``m`` (None if it has none)."""
    l2_reg = None

    for W in m.parameters():
        if l2_reg is None:
            l2_reg = W.norm(2)
        else:
            l2_reg = l2_reg + W.norm(2)
    return l2_reg

def save_mask_prediction_example(mask, pred, iter):
    """Save images/<iter>_prediction.png and images/<iter>_mask.png (channel 0)."""
    plt.imshow(pred[0,:,:],cmap='Greys')
    plt.savefig('images/'+str(iter)+"_prediction.png")
    plt.imshow(mask[0,:,:],cmap='Greys')
    plt.savefig('images/'+str(iter)+"_mask.png")