├── README.md
├── UNet.py
├── Unet_main.ipynb
├── engine.py
├── imm
│   ├── output_10_0.png
│   ├── output_10_1.png
│   ├── output_10_2.png
│   ├── output_12_0.png
│   ├── output_12_1.png
│   ├── output_16_0.png
│   ├── output_16_1.png
│   ├── output_17_0.png
│   ├── output_21_0.png
│   ├── output_23_0.png
│   ├── output_28_0.png
│   ├── output_28_1.png
│   ├── output_28_2.png
│   ├── output_31_0.png
│   ├── output_40_0.png
│   ├── output_40_1.png
│   ├── output_40_2.png
│   ├── output_9_0.png
│   ├── output_9_1.png
│   ├── output_9_2.png
│   ├── render0.png
│   └── unet_arch.png
├── learning_rate_range_test.py
└── utils.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# U-Net: Convolutional Networks for Biomedical Image Segmentation

## Overview

In the following, we implement U-Net in PyTorch and train it on the dataset used by the authors of [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://link.springer.com/chapter/10.1007%2F978-3-319-24574-4_28).

### Libraries


```python
from UNet import Unet
import utils
import engine
from learning_rate_range_test import LRTest

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

import albumentations as A
import gc
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
```

## Unet Implementation

The architecture of the model is the following:

![Unet](./imm/unet_arch.png)

It consists of a contracting path (left side) and an expansive path (right side). The contracting path follows the typical architecture of a convolutional neural network: it consists of the repeated application of two 3x3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU), and a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step, the number of feature channels is doubled. Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution ("up-convolution") that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in every convolution. At the final layer, a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total, the network has 23 convolutional layers.

Let's make sure that the implementation works. We'll check whether the output dimensions of the tensors match the ones mentioned in Figure 1 of the paper, when using the same architecture and input size:


```python
x = torch.Tensor(np.random.rand(2, 3, 572, 572))
ch = [3, 64, 128, 256, 512, 1024]
net = Unet(channels = ch, no_classes = 4)
x = net(x)
print(x.shape)

x = torch.Tensor(np.random.rand(3, 1, 572, 572))
ch = [1, 64, 128, 256, 512]
net = Unet(channels = ch, no_classes = 2, output_size = (572,572))
x = net(x)
print(x.shape)

del x, net
gc.collect();
```

    torch.Size([2, 4, 388, 388])
    torch.Size([3, 2, 572, 572])


Perfect.
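As a sanity check on those numbers: every 3x3 valid convolution trims 2 pixels per border, each 2x2 max pooling halves the resolution, and each up-convolution doubles it. A throwaway helper (not part of the repo) that traces this through the 5-level architecture:


```python
def unet_valid_output(size, depth = 4):
    ''' Trace the spatial size through a valid-convolution U-Net. '''
    for _ in range(depth):       # Contracting path
        size = (size - 4) // 2   # Two 3x3 valid convolutions, then 2x2 pooling
    size -= 4                    # Bottleneck convolutions
    for _ in range(depth):       # Expansive path
        size = size * 2 - 4      # 2x2 up-convolution, then two 3x3 convolutions
    return size

print(unet_valid_output(572))    # 388, matching Figure 1
```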
## Training on the 2012 EM Segmentation Challenge

A test case example of the paper was the [ISBI 2012 challenge](http://brainiac2.mit.edu/isbi_challenge/). This dataset contains 30 ssTEM (serial section Transmission Electron Microscopy) 512x512 images taken from the Drosophila larva ventral nerve cord (VNC). The images represent a set of consecutive slices within one 3D volume. The microcube measures approximately 2 x 2 x 1.5 microns, with a resolution of 4x4x50 nm/pixel.

The corresponding binary labels are provided in an in-out fashion, i.e. white for the pixels of segmented objects and black for the rest of the pixels (which correspond mostly to membranes).

Thirty single-channel 512x512 images easily fit in memory, so we can load the entire dataset at once.
Let's read the data and visualise a few examples:


```python
datadir = './data/images/'
labeldir = './data/labels/'

# Channel and image dimensions (from the data description)
C, H, W = 1, 512, 512

img_mtrx, mask_mtrx = utils.readData(datadir, labeldir)
```


```python
# Plot a few examples
for i in [1, 5, 10]:
    fig, ax = plt.subplots(1, 2, figsize = (8, 8))

    im = ax[0].imshow(img_mtrx[i, :, :], 'gray')
    plt.colorbar(im, ax = ax[0], fraction = 0.046, pad = 0.04)
    ax[0].axis('off')
    ax[0].set_title('image')

    im = ax[1].imshow(mask_mtrx[i, :, :], 'gray')
    plt.colorbar(im, ax = ax[1], fraction = 0.046, pad = 0.04)
    ax[1].axis('off')
    ax[1].set_title('mask')
```


![png](./imm/output_9_0.png)



![png](./imm/output_9_1.png)



![png](./imm/output_9_2.png)


## Weight map

The authors precompute a weight map for each ground-truth segmentation to compensate for the different frequencies of pixels from each class in the training dataset, and to force the network to learn the small separation borders that they introduce between touching cells.
The separation border is computed using morphological operations, and the weight map is then computed as:

$w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \exp \left( - \frac{\left[ d_1(\mathbf{x}) + d_2(\mathbf{x})\right]^2}{2\sigma^2} \right)$

where $w_c: \Omega \rightarrow \mathbb{R}$ is the weight map to balance the class frequencies, $d_1: \Omega \rightarrow \mathbb{R}$ is the distance to the border of the nearest cell, and $d_2: \Omega \rightarrow \mathbb{R}$ is the distance to the border of the second nearest cell.
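To get a feel for the border term, here is a quick, illustrative evaluation with the authors' parameters $w_0 = 10$ and $\sigma = 5$ (not part of the repo):


```python
w0, sigma = 10, 5
for d in [0, 5, 10, 20]:  # d = d1(x) + d2(x), in pixels
    print(d, round(w0 * np.exp(-d ** 2 / (2 * sigma ** 2)), 4))
# 0 -> 10.0, 5 -> 6.0653, 10 -> 1.3534, 20 -> 0.0034
```

The extra weight decays quickly, so it only emphasises pixels within roughly ten pixels of a gap between two touching cells.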
Let's check a few examples:


```python
# Parameters used by the authors
w0 = 10
sigma = 5

for i in np.random.randint(0, img_mtrx.shape[0], 2): # Randomly draw 2 images

    # Grab image and label map
    mask = mask_mtrx[i, :, :]
    img = img_mtrx[i, :, :]

    # Compute weight map
    w = utils.weight_map(mask = mask, w0 = w0, sigma = sigma)

    # Plot results
    fig, ax = plt.subplots(nrows = 1, ncols = 2, figsize = (10, 10));

    im = ax[0].imshow(img, 'gray')
    plt.colorbar(im, ax = ax[0], fraction = 0.046, pad = 0.04)
    ax[0].axis('off')
    ax[0].set_title('image')

    im = ax[1].imshow(w)
    plt.colorbar(im, ax = ax[1], fraction = 0.046, pad = 0.04)
    ax[1].axis('off');
    ax[1].set_title('weight map');
```


![png](./imm/output_12_0.png)



![png](./imm/output_12_1.png)


Very good.

### Data Augmentation

According to the authors, we need shift and rotation invariance, as well as robustness to deformations and gray value variations. In particular, random elastic deformations of the training samples were found to be the key concept for training a segmentation network with very few annotated images.

Let's define the transformations:


```python
# Define augmentation pipelines
p = 0.95
train_transform = A.Compose([
    A.OneOf([
        A.HorizontalFlip(p = p),
        A.VerticalFlip(p = p),
        A.Transpose(p = p),
        A.RandomRotate90(p = p),
        A.ShiftScaleRotate(p = p, shift_limit = 0.0625, scale_limit = 0.1, rotate_limit = 45)
        ], p = 1),
    A.GaussNoise(p = p, var_limit = (0, 20), mean = 0, per_channel = True),
    A.MultiplicativeNoise(p = p, multiplier=(0.9, 1.1), elementwise = True),
    A.ElasticTransform(p = p, alpha = 35, sigma = 5, alpha_affine = 3, approximate = True),
    A.RandomBrightnessContrast(p = p, brightness_limit = 0.15, contrast_limit = 0.15),
    A.PadIfNeeded(p = 1, min_height = 128, min_width = 128, border_mode = cv2.BORDER_REFLECT)
])
```
Let's have a look:


```python
# Parameters recommended by the authors
w0 = 10
sigma = 5

for i in np.random.randint(0, 30, 1): # Randomly draw 1 image

    # Grab image and label map
    mask = mask_mtrx[i, :, :]
    img = img_mtrx[i, :, :]

    # Apply transformations
    aug = train_transform(image = img, mask = mask)
    img_t = aug["image"]
    mask_t = aug["mask"]

    # Compute weight map
    weights = utils.weight_map(mask = mask_t, w0 = w0, sigma = sigma)

    # Plot
    fig, ax = plt.subplots(nrows = 3, ncols = 3, figsize = (10,8), constrained_layout=True)

    im = ax[0, 0].imshow(img, 'gray', interpolation = 'none')
    plt.colorbar(im, ax = ax[0, 0], fraction = 0.046, pad = 0.04)
    ax[0, 0].axis('off');
    ax[0, 0].set_title('image - original')

    im = ax[0, 1].imshow(mask, 'gray', interpolation = 'none')
    plt.colorbar(im, ax = ax[0, 1], fraction = 0.046, pad = 0.04)
    ax[0, 1].axis('off');
    ax[0, 1].set_title('mask - original')

    ax[0, 2].imshow(img, 'gray', interpolation = 'none')
    ax[0, 2].imshow(mask, 'gray', interpolation = 'none', alpha = 0.3)
    ax[0, 2].axis('off');
    ax[0, 2].set_title('image & mask')

    im = ax[1, 0].imshow(img_t, 'gray', interpolation = 'none')
    plt.colorbar(im, ax = ax[1, 0], fraction = 0.046, pad = 0.04)
    ax[1, 0].axis('off');
    ax[1, 0].set_title('image - transformed')

    im = ax[1, 1].imshow(mask_t, 'gray', interpolation = 'none')
    plt.colorbar(im, ax = ax[1, 1], fraction = 0.046, pad = 0.04)
    ax[1, 1].axis('off');
    ax[1, 1].set_title('mask - transformed')

    ax[1, 2].imshow(img_t, 'gray', interpolation = 'none')
    ax[1, 2].imshow(mask_t, 'gray', interpolation = 'none', alpha = 0.3)
    ax[1, 2].axis('off');
    ax[1, 2].set_title('image & mask')

    counts, _, _ = ax[2, 0].hist(img_t.reshape(-1, 1), bins = 50, density = True);
    ax[2, 0].vlines(img_t.mean(), 0, max(counts), colors='k')
    ax[2, 0].vlines(img_t.mean() + 2 * np.std(img_t), 0, max(counts) * 0.75, colors='r')
    ax[2, 0].vlines(img_t.mean() - 2 * np.std(img_t), 0, max(counts) * 0.75, colors='r')
    ax[2, 0].set_title('image histogram')

    ax[2, 1].hist(mask_t.reshape(-1, 1), bins = 50, density = True);
    ax[2, 1].set_title('mask histogram')

    im = ax[2, 2].imshow(weights, interpolation = 'none')
    plt.colorbar(im, ax = ax[2, 2], fraction = 0.046, pad = 0.04)
    ax[2, 2].axis('off');
    ax[2, 2].set_title('weights')
```


![png](./imm/output_17_0.png)


Perfect. We can start training.

### Learning rate range test

An early stopping class, a pixel-wise weighted negative log-loss, and the training/validation functions are implemented in the engine module. The authors mention that they used SGD with a batch size of 1 and a momentum of 0.99 (the high momentum lets a large number of previously seen samples determine each update, compensating for the tiny batch size), but they do not mention the learning rate.
Let's conduct an [LR range test](https://arxiv.org/pdf/1506.01186.pdf) to find a good one; the test is implemented as a class in the learning_rate_range_test module.
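For reference, the test sweeps the learning rate geometrically from $\eta_{\min}$ to $\eta_{\max}$ over $N$ iterations, i.e. at iteration $i$ it uses $\eta_i = \eta_{\min} \left( \eta_{\max} / \eta_{\min} \right)^{i/N}$, recording the training loss at each step.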
Let's run it:


```python
act_batch_size = 1 # Can't fit more than one image in GPU!
eff_batch_size = 1 # Effective batch size (gradient accumulation)
device = 'cuda'
momentum = 0.99
channels = [C, 64, 128, 256, 512, 1024]
w0 = 10
sigma = 5

min_lr = 1e-6
max_lr = 10
no_iter = 50

# Configure train test split
np.random.seed(123)
no_img = img_mtrx.shape[0]
test_idx = np.random.randint(0, no_img, 3) # Keep 3 images for test set
train_idx = np.setdiff1d(np.arange(0, no_img), test_idx)

# Make model
model = Unet(channels = channels, no_classes = 1).double().to(device)

# Configure criterion
criterion = engine.WeightedBCEWithLogitsLoss(batch_size = act_batch_size)

# Setup optimiser
optimizer = optim.SGD(model.parameters(), lr = min_lr, momentum = momentum)

# Make dataset
train_set = engine.SegmentationDataset(images = img_mtrx[train_idx, :, :],
                                       masks = mask_mtrx[train_idx, :, :],
                                       transform = train_transform,
                                       device = device, wmap_w0 = w0, wmap_sigma = sigma)

# Make dataloader
train_loader = DataLoader(dataset = train_set, batch_size = act_batch_size,
                          shuffle = True, num_workers = 0, pin_memory = False)

# Run LR range test
lr_test = LRTest(min_lr = min_lr, max_lr = max_lr, no_iter = no_iter, batch_size = eff_batch_size)
lr, loss = lr_test(train_loader, criterion, optimizer, model)
```

    Diverged on iteration 43 with loss 1367209663.550878


Let's plot the results:


```python
plt.figure(figsize = (10, 4))
plt.semilogx(lr, loss, marker = '.')
plt.ylim(min(loss) * 0.98, min(loss) * 1.35);
plt.title('Learning Rate Range Test')
plt.ylabel('Loss')
plt.xlabel('Learning Rate')
plt.grid(True, which = 'both', axis = 'both');
```


![png](./imm/output_23_0.png)


The loss is still decreasing steeply around 1e-2, just before the curve flattens out and eventually diverges, so we'll go with a learning rate of 1e-2.
### Training Loop

Now we can put everything together. The authors report the average performance over 7 rotations of the dataset; we'll just do one here to save time:


```python
epochs = 1000
learning_rate = 1e-2
act_batch_size = 1 # Can't fit more than one image in GPU!
eff_batch_size = 1 # Effective batch size (gradient accumulation)
momentum = 0.99
device = 'cuda'
channels = [C, 64, 128, 256, 512, 1024]
w0 = 10
sigma = 5
model_path = './model.pt'

# Early stopping (saves a checkpoint whenever the validation loss improves)
es = engine.EarlyStopping(patience = 100, fname = model_path)

# Make datasets
train_set = engine.SegmentationDataset(images = img_mtrx[train_idx, :, :],
                                       masks = mask_mtrx[train_idx, :, :],
                                       transform = train_transform,
                                       device = device, wmap_w0 = w0, wmap_sigma = sigma)

test_set = engine.SegmentationDataset(images = img_mtrx[test_idx, :, :],
                                      masks = mask_mtrx[test_idx, :, :],
                                      transform = None,
                                      device = device, wmap_w0 = w0, wmap_sigma = sigma)

# Make dataloaders
train_loader = DataLoader(dataset = train_set,
                          batch_size = act_batch_size,
                          shuffle = True,
                          num_workers = 0,     # Change to >0 for performance
                          pin_memory = False)  # Change to True for performance

test_loader = DataLoader(dataset = test_set,
                         batch_size = act_batch_size,
                         shuffle = False,
                         num_workers = 0,     # Change to >0 for performance
                         pin_memory = False)  # Change to True for performance

# Make progress bars
pbar_epoch = tqdm(total = epochs, unit = 'epoch', position = 0, leave = False)
pbar_train = tqdm(total = len(train_loader), unit = 'batch', position = 1, leave = False)

# Make model
model = Unet(channels = channels, no_classes = 1).double().to(device)

# Make optimiser
optimizer = optim.SGD(model.parameters(), lr = learning_rate, momentum = momentum)

# Make loss
criterion = engine.WeightedBCEWithLogitsLoss(batch_size = act_batch_size)

# Load checkpoint (if it exists)
cur_epoch = 0
if os.path.isfile(model_path):
    checkpoint = torch.load(model_path)
    cur_epoch = checkpoint['epoch']
    es.best_loss = checkpoint['loss']
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Hold stats for training process
stats = {'epoch': [], 'train_loss': [], 'val_loss': []}

# Training / validation loop
for epoch in range(cur_epoch, epochs):

    # Train / validate
    pbar_epoch.set_description_str(f'Epoch {epoch + 1}')
    train_loss = engine.train(model, optimizer, train_loader, criterion, eff_batch_size, pbar_train)
    val_loss = engine.validation(model, test_loader, criterion)

    # Append stats
    stats['epoch'].append(epoch)
    stats['train_loss'].append(train_loss)
    stats['val_loss'].append(val_loss)

    # Early stopping check. The call also checkpoints the model whenever the
    # validation loss improves; with `pass` we never actually stop early, we
    # simply keep the best model on disk.
    if es(epoch, val_loss, optimizer, model): pass

    # Update progress bars
    pbar_epoch.set_postfix(train_loss = train_loss, val_loss = val_loss)
    pbar_epoch.update(1)
    pbar_train.reset()
```
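The `stats` dictionary collected above makes it easy to eyeball convergence; a minimal sketch (assuming the loop above has run in the current session):


```python
plt.figure(figsize = (10, 4))
plt.plot(stats['epoch'], stats['train_loss'], label = 'train')
plt.plot(stats['epoch'], stats['val_loss'], label = 'validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True);
```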
Let's check the predictions on the validation set:


```python
# Load the best checkpoint
model = Unet(channels = channels, no_classes = 1).double().to(device)
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

with torch.no_grad():

    for batch_id, (X, y, weights) in enumerate(test_loader):

        # Forward
        y_hat = model(X)
        y_hat = torch.sigmoid(y_hat)

        # Convert to numpy
        X = np.squeeze(X.cpu().numpy())
        y = np.squeeze(y.cpu().numpy())
        w = np.squeeze(weights.cpu().numpy())
        y_hat = np.squeeze(y_hat.cpu().numpy())

        # Make binary mask
        y_hat2 = y_hat > 0.5

        # Plot
        fig, ax = plt.subplots(nrows = 1, ncols = 2, figsize = (8, 8))

        ax[0].imshow(y, 'gray', interpolation = 'none')
        ax[0].axis('off');
        ax[0].set_title('Target');

        ax[1].imshow(y_hat, 'gray', interpolation = 'none')
        ax[1].axis('off');
        ax[1].set_title('Prediction');
```


![png](./imm/output_28_0.png)



![png](./imm/output_28_1.png)



![png](./imm/output_28_2.png)


Not bad, not perfect either.

--------------------------------------------------------------------------------
/UNet.py:
--------------------------------------------------------------------------------

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

class Block(nn.Module):
    ''' One block of Unet.
        Contains 2 repeated 3 x 3 unpadded convolutions, each followed by a ReLU.
    '''

    def __init__(self, in_channel, out_channel, kernel_size):
        ''' Initialisation '''

        super().__init__()
        self.conv_1 = nn.Conv2d(in_channel, out_channel, kernel_size)
        self.conv_2 = nn.Conv2d(out_channel, out_channel, kernel_size)
        self.relu = nn.ReLU()

        # Initialise weights on the convolutional layers
        nn.init.normal_(self.conv_1.weight, mean = 0.0, std = self.init_std(in_channel, kernel_size))
        nn.init.normal_(self.conv_2.weight, mean = 0.0, std = self.init_std(out_channel, kernel_size))


    @staticmethod
    def init_std(channels, kernel_size):
        ''' Std for weight initialisation: sqrt(2 / N), where N is the number of
            incoming nodes of one neuron (channels * kernel_size^2), as in the paper
        '''
        return np.sqrt(2.0 / (channels * kernel_size ** 2))


    def forward(self, x):
        ''' Forward Phase '''

        x = self.conv_1(x)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.relu(x)

        return x


class Encoder(nn.Module):
    ''' Contractive Part of Unet '''

    def __init__(self, channels):
        '''Initialisation'''

        super().__init__()

        # Make block list
        modules = []

        for in_channel, out_channel in zip(channels[:-1], channels[1:]):
            block = Block(in_channel = in_channel, out_channel = out_channel, kernel_size = 3)
            modules.append(block)

        self.blocks = nn.ModuleList(modules = modules)
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = None)
        self.feat_maps = [] # Feature map of each block, to be concatenated with the decoder part

    def forward(self, x):
        '''Forward phase'''

        for layer_no, layer in enumerate(self.blocks):

            # Run block
            x = layer(x)

            if not self.is_final_layer(layer_no):

                # Store feature maps for the decoder
                self.feat_maps.append(x)

                # Perform max pooling operation
                x = self.max_pool(x)

        return x

    def is_final_layer(self, layer_no):
        return layer_no == len(self.blocks) - 1
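# Note: the skip-connection feature maps are kept as instance state
# (self.feat_maps) and are popped one by one in Decoder.forward below, so the
# list is empty again after every complete forward pass through the Unet.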
class Decoder(nn.Module):
    ''' Expansive Part of Unet '''

    def __init__(self, channels):
        '''Initialisation'''

        super().__init__()

        # Make module lists
        up_convs = []
        blocks = []
        for in_channel, out_channel in zip(channels[:-1], channels[1:]):

            # 2x2 Upconvolution
            upconv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size = 2, stride = 2)
            up_convs.append(upconv)

            # Block (2 convolutions with ReLUs)
            block = Block(in_channel, out_channel, kernel_size = 3)
            blocks.append(block)

        # Make modules
        self.upconvs = nn.ModuleList(up_convs)
        self.blocks = nn.ModuleList(blocks)


    def forward(self, x, encoded_feat_maps):

        for upconv, block in zip(self.upconvs, self.blocks):

            # Apply upconvolution
            x = upconv(x)

            # Grab corresponding feature map from the encoder
            fts = encoded_feat_maps.pop()

            # Crop it
            fts = self.crop(fts, x.shape[2], x.shape[3])

            # Concatenate it to the input
            x = torch.cat([x, fts], dim = 1)

            # Perform convs with ReLUs
            x = block(x)

        return x

    @staticmethod
    def crop(tnsr, new_H, new_W):
        ''' Center crop an input tensor to shape [new_H, new_W] '''

        # Grab existing size
        _, _, H, W = tnsr.size()

        # Compute one corner of the crop
        x1 = int(round( (H - new_H) / 2.))
        y1 = int(round( (W - new_W) / 2.))

        # Compute the other one
        x2 = x1 + new_H
        y2 = y1 + new_W

        return tnsr[:, :, x1:x2, y1:y2]


class Unet(nn.Module):
    ''' Unet class
        As suggested in "U-Net: Convolutional Networks for Biomedical Image Segmentation" (https://arxiv.org/pdf/1505.04597.pdf)
    '''

    def __init__(self, channels, no_classes, output_size = None):
        '''Initialisation'''

        super().__init__()

        self.output_size = output_size

        # Initialise encoder
        self.encoder = Encoder(channels)

        # Initialise decoder
        dec_channels = list(reversed(channels[1:])) # Flip the channels for the expansive part (omitting the input channels)
        self.decoder = Decoder(dec_channels)

        # Initialise final layer
        self.head = nn.Conv2d(in_channels = channels[1], out_channels = no_classes, kernel_size = 1)


    def forward(self, x):
        '''Forward Phase'''

        x = self.encoder(x)
        x = self.decoder(x, self.encoder.feat_maps)
        x = self.head(x)

        # Retain dimensions
        if self.output_size is not None:
            x = F.interpolate(x, self.output_size)

        return x

--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------

import torch
from torch.utils.data import Dataset
from utils import weight_map
from torch import nn
import numpy as np


def train(model, optimizer, dataloader, criterion, effective_batch_size, p_bar = None):
    ''' Training '''

    model.train()
    optimizer.zero_grad()
    running_loss = 0

    for batch_id, (X, y, weights) in enumerate(dataloader):

        if p_bar is not None:
            p_bar.set_description_str(f'Batch {batch_id + 1}')

        # Forward
        y_hat = model(X)

        # Compute loss, normalised by the number of accumulation steps
        loss = criterion(y, y_hat, weights) / effective_batch_size
        running_loss += loss.item()
        loss.backward()

        # Optimiser step, once enough gradients have accumulated (or at the end of the epoch)
        if ((batch_id + 1) % effective_batch_size == 0) or ((batch_id + 1) == len(dataloader)):
            optimizer.step()
            optimizer.zero_grad()

        # Update progress bar
        if p_bar is not None:
            p_bar.set_postfix(loss = loss.item())
            p_bar.update(1)

    # Compute the average (un-normalised) loss per batch
    running_loss = running_loss / len(dataloader) * effective_batch_size

    return running_loss
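# Note: `effective_batch_size` implements gradient accumulation: each loss is
# normalised before backward(), gradients from consecutive batches are summed,
# and the optimiser only steps every `effective_batch_size` batches (or at the
# end of the epoch), emulating a batch larger than fits in GPU memory.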
def validation(model, dataloader, criterion):
    ''' Validation '''

    model.eval()
    running_loss = 0

    with torch.no_grad():
        for X, y, weights in dataloader:

            # Forward
            y_hat = model(X)

            # Compute loss
            loss = criterion(y, y_hat, weights)
            running_loss += loss.item()

    # Compute average loss
    running_loss /= len(dataloader)

    return running_loss


class EarlyStopping(object):
    ''' Early stopping: checkpoints the model whenever the monitored loss improves,
        and signals once it has not improved for `patience` consecutive calls.
    '''

    def __init__(self, patience, fname):
        self.patience = patience
        self.best_loss = np.inf
        self.counter = 0
        self.filename = fname

    def __call__(self, epoch, loss, optimizer, model):

        if loss < self.best_loss:
            self.counter = 0
            self.best_loss = loss

            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, self.filename)

        else:
            self.counter += 1

        # >= (rather than ==), so the signal persists if the caller ignores it once
        return self.counter >= self.patience
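# Usage sketch (illustrative):
#   es = EarlyStopping(patience = 100, fname = './model.pt')
#   ...
#   if es(epoch, val_loss, optimizer, model): break  # stop once patience runs out
# The call itself checkpoints the model whenever the validation loss improves.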
class WeightedBCEWithLogitsLoss(nn.Module):
    ''' Pixel-wise weighted BCEWithLogitsLoss '''

    def __init__(self, batch_size):

        super().__init__()
        self.batch_size = batch_size
        self.unw_loss = nn.BCEWithLogitsLoss(reduction = 'none')

    def __call__(self, true, predicted, weights):

        # Compute the pixel-wise loss and apply the weight map
        loss = self.unw_loss(predicted, true) * weights

        # Sum over all channels
        loss = loss.sum(dim = 1)

        # Flatten and normalise by the total weight of each sample, so that the
        # result stays on approximately the same scale as the unweighted loss
        loss = loss.view(self.batch_size, -1).sum(dim = 1) / weights.view(self.batch_size, -1).sum(dim = 1)

        # Average over mini-batch
        loss = loss.mean()

        return loss


class SegmentationDataset(Dataset):
    ''' Dataset returning (image, mask, weight map) triplets '''

    def __init__(self, images, masks, wmap_w0, wmap_sigma, device, transform = None):
        ''' Initialisation function '''

        self.images = images
        self.masks = masks
        self.transform = transform
        self.device = device

        # Parameters for weight map calculation
        self.w0 = wmap_w0
        self.sigma = wmap_sigma

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        ''' Preprocess and return image, mask, and weight map '''

        image = self.images[idx, :, :]
        mask = self.masks[idx, :, :]

        # Apply transformations (if any)
        if self.transform:
            aug = self.transform(image = image, mask = mask)
            image = aug["image"]
            mask = aug["mask"]

        # Compute weight map
        weights = weight_map(mask = mask, w0 = self.w0, sigma = self.sigma)

        # Min-max scale image and mask
        image = self.min_max_scale(image, min_val = 0, max_val = 1)
        mask = self.min_max_scale(mask, min_val = 0, max_val = 1)

        # Add channel dimensions
        image = np.expand_dims(image, axis = 0)
        weights = np.expand_dims(weights, axis = 0)
        mask = np.expand_dims(mask, axis = 0)

        # Convert to tensors and send to device
        weights = torch.from_numpy(weights).double().to(self.device)
        image = torch.from_numpy(image).double().to(self.device)
        mask = torch.from_numpy(mask).double().to(self.device)

        # Centre-crop mask and weights to the network output size
        # (negative padding = cropping; size defined manually)
        mask = nn.ZeroPad2d(-94)(mask)
        weights = nn.ZeroPad2d(-94)(weights)

        return image, mask, weights

    @staticmethod
    def min_max_scale(image, max_val, min_val):
        ''' Normalisation to the range [min_val, max_val] '''

        image_new = (image - np.min(image)) * (max_val - min_val) / (np.max(image) - np.min(image)) + min_val
        return image_new
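# Note on the hard-coded ZeroPad2d(-94) crop in __getitem__ above: with
# 512x512 inputs and the 5-level Unet used in the notebook, the network
# output is 324x324 = 512 - 2*94, so the mask and weight map are centre-
# cropped by 94 pixels per side to match the prediction.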
--------------------------------------------------------------------------------
/learning_rate_range_test.py:
--------------------------------------------------------------------------------

from torch import optim
from tqdm.notebook import tqdm

class LRTest(object):
    ''' Learning rate range test (Smith, https://arxiv.org/pdf/1506.01186.pdf) '''

    def __init__(self, min_lr, max_lr, no_iter, batch_size):
        ''' Initialisation function '''

        self.batch_size = batch_size
        self.no_iter = no_iter

        # Geometric multiplier, so that no_iter steps sweep [min_lr, max_lr]
        self.lr_multiplier = (max_lr / min_lr) ** (1 / no_iter)
        self.dataiter = None


    # Function to perform the learning rate range test on one experiment
    def __call__(self, dataloader, criterion, optimizer, model):
        ''' LR Range test '''

        # Set model to training mode
        model.train()

        # Configure scheduler
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma = self.lr_multiplier)

        # Empty lists to hold results
        loss_arr, lr_arr = [], []

        # Counters' initialisation
        cur_iter, best_loss = 0, 1e9

        with tqdm(total = self.no_iter) as pbar:

            while cur_iter < self.no_iter:

                # Grab the current learning rate (before stepping the scheduler)
                lr_arr.append(scheduler.get_last_lr()[0])

                # Train a batch
                cur_loss = self.train_batch(model, criterion, optimizer, scheduler, dataloader)

                # Append loss
                loss_arr.append(cur_loss)

                # Keep track of the best loss so far
                if cur_loss < best_loss:
                    best_loss = cur_loss

                # Check for divergence and exit if needed
                if cur_loss > 2e2 * best_loss: # Divergence
                    print(f'Diverged on iteration {cur_iter} with loss {cur_loss}')
                    break

                # Update progress bar
                pbar.set_postfix(loss = cur_loss)
                pbar.update(1)
                cur_iter += 1

        return lr_arr, loss_arr


    # Return a batch
    def grab_batch(self, dataloader):

        # Lazy init
        if self.dataiter is None:
            self.dataiter = iter(dataloader)

        # Get next batch
        try:
            X, y, w = next(self.dataiter)

        except StopIteration: # End of dataset -> restart

            self.dataiter = iter(dataloader)
            X, y, w = next(self.dataiter)

        return X, y, w


    # Train batch
    def train_batch(self, model, criterion, optimizer, scheduler, dataloader):

        optimizer.zero_grad()

        cur_iter = 0
        run_loss = 0
        while cur_iter < self.batch_size:

            # Get sample
            X, y, w = self.grab_batch(dataloader)

            # Predict
            y_hat = model(X)

            # Compute loss, normalised by the number of accumulation steps
            loss = criterion(y, y_hat, w) / self.batch_size
            run_loss += loss.item()

            # Backprop
            loss.backward()

            # Update counter
            cur_iter += 1

        # Update parameters and learning rate
        optimizer.step()
        scheduler.step()

        return run_loss

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------

import numpy as np
import cv2
import os
from PIL import Image

def readData(datadir, labeldir, H = 512, W = 512):
    ''' Read the 30 image / mask pairs into two arrays '''

    img_mtrx = np.empty(shape = (30, H, W), dtype = np.uint8)
    mask_mtrx = np.empty(shape = (30, H, W))

    # Loop over all image / mask pairs (sorted, to keep the pairs aligned)
    for i, (img, lab) in enumerate(zip(sorted(os.listdir(datadir)), sorted(os.listdir(labeldir)))):

        # Load image
        image = np.array(Image.open(datadir + img))

        # Add to matrix
        img_mtrx[i, :, :] = image

        # Load mask in grayscale (single channel)
        mask = cv2.imread(labeldir + lab, cv2.IMREAD_GRAYSCALE)

        # Binarise
        _, mask = cv2.threshold(mask, 128, 255, cv2.THRESH_BINARY)

        # Add to matrix
        mask_mtrx[i, :, :] = mask

    return img_mtrx, mask_mtrx

def weight_map(mask, w0, sigma, background_class = 0):
    ''' Pixel-wise weight map of a binary mask, as defined in the U-Net paper:
        w(x) = wc(x) + w0 * exp(-(d1(x) + d2(x))^2 / (2 * sigma^2))
    '''

    # Fix mask datatype (should be unsigned 8 bit)
    if mask.dtype != 'uint8':
        mask = mask.astype('uint8')

    # Weight values to balance class frequencies
    wc = _class_weights(mask)

    # Assign a different label to each connected region of the image
    _, regions = cv2.connectedComponents(mask)

    # Get total no. of connected regions in the image and sort them, excluding the background
    region_ids = sorted(np.unique(regions))
    region_ids = [region_id for region_id in region_ids if region_id != background_class]

    if len(region_ids) > 1: # More than one connected region

        # Initialise distance matrix (dimensions: H x W x no. regions)
        distances = np.zeros((mask.shape[0], mask.shape[1], len(region_ids)))

        # For each region
        for i, region_id in enumerate(region_ids):

            # Mask all pixels belonging to a different region
            m = (regions != region_id).astype(np.uint8)

            # Compute the Euclidean distance to this region for all pixels belonging to a different region
            distances[:, :, i] = cv2.distanceTransform(m, distanceType = cv2.DIST_L2, maskSize = 0)

        # Sort distances w.r.t. region for every pixel
        distances = np.sort(distances, axis = 2)

        # Grab distances to the borders of the nearest and second nearest regions
        d1, d2 = distances[:, :, 0], distances[:, :, 1]

        # Compute the border term of the weight map, applied to background pixels only
        w = w0 * np.exp(-1 / (2 * sigma ** 2) * (d1 + d2) ** 2) * (regions == background_class)

    else: # Only a single region present in the image
        w = np.zeros(mask.shape)

    # Instantiate a matrix to hold the class weights
    # (float, rather than zeros_like(mask), to avoid truncating the weights to uint8)
    wc_x = np.zeros(mask.shape)

    # Compute class weights for each pixel class (background, etc.)
    for pixel_class, weight in wc.items():

        wc_x[mask == pixel_class] = weight

    # Add them to the weight map
    w = w + wc_x

    return w

def _class_weights(mask):
    ''' Create a dictionary containing the classes in a mask,
        and their corresponding weights to balance their occurrence
    '''

    wc = {}

    # Grab classes and their corresponding counts
    unique, counts = np.unique(mask, return_counts = True)

    # Convert counts to frequencies
    counts = counts / np.prod(mask.shape)

    # Get max. count
    max_count = max(counts)

    for val, count in zip(unique, counts):
        wc[val] = max_count / count

    return wc

--------------------------------------------------------------------------------
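# Worked example (illustrative): for a binary mask that is 80% foreground (255)
# and 20% background (0), _class_weights returns {0: 4.0, 255: 1.0}, i.e. the
# rarer class is up-weighted so that both classes contribute equally overall.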