├── README.md
├── Report.txt
├── Report_Unet.txt
├── __pycache__
│   ├── build_model.cpython-35.pyc
│   ├── data_loader.cpython-35.pyc
│   └── unet_parts.cpython-35.pyc
├── build_model.py
├── data_loader.py
├── gaussian_noise_and_flipping.py
├── inferences.py
├── metric_calc.py
├── preprocess_combining.py
├── preprocess_padding.py
├── train_model_Unet.py
├── train_model_segnet.py
└── unet_parts.py

/README.md:
--------------------------------------------------------------------------------
# Introduction
This report describes the usage of **SegNet** and **U-Net** architectures for **medical image segmentation**.

We divide the article into the following parts:

- [Dataset](#dataset)
- [SegNet](#segnet)
- [U-Net](#u-net)
- [Loss Functions Used](#loss-functions-used)
- [Results](#results)
- [References](#references)
- Further Help

# Dataset
## Montgomery Dataset

The dataset contains chest X-ray (CXR) images. We use this dataset to perform lung segmentation.
>The dataset can be found [here](http://openi.nlm.nih.gov/imgs/collections/NLM-MontgomeryCXRSet.zip)

Structure:

We organise the given dataset as follows:

![](https://lh4.googleusercontent.com/J-kHm2BX9ywKISMuY_BaCaFf--UuPJOKlFYLO89gYgvjmqlM9RrFive2wOU30X8N7bzI03uwMCtnb_oCHDPaobyxTMEFlfsTSNXALS629uuAkSUZfm9y-lUv5FORquPe1P8CPp4p)

## Data Preprocessing
![](https://lh4.googleusercontent.com/TTLhU_8UxfxPWPURwLsqNbJu09EfPTReyCXHH9mX7saLzfK6aLgxK_NQd1VNeL7u1acwVnppg2pOZeLO9S4hoxpxjRSoXUHRlK8OAo6peOHpvv_zzTv2g43Wy4HMmk_i-aoATdEG)

Figure 1: The two separate lung-mask annotations are combined into a single mask image.

The Montgomery dataset contains images from the Department of Health and Human Services, Montgomery County, Maryland, USA. It consists of 138 CXRs: 80 from normal patients and 58 from patients with manifested tuberculosis (TB). The CXRs are 12-bit gray-scale images of dimension 4020 × 4892 or 4892 × 4020. Only the two separate lung-mask annotations are available; these were combined into a single image to make the segmentation task easier for the network to learn (Fig. 1). To make all images square, we padded each picture along its smaller dimension up to the larger one, giving 4892 × 4892 images; this preserves the aspect ratio of the CXR during resizing. We then scale all images to 1024 × 1024 pixels, which retains sufficient visual detail for the vascular structures in the lung fields and is about the largest size that, together with U-Net, can be accommodated in Graphics Processing Unit (GPU) memory. All pixel values are scaled to the range 0-1. Data augmentation was applied by flipping around the vertical axis and adding Gaussian noise with mean 0 and variance 0.01. Rotations about the centre by subtle angles of 5-10 degrees were also applied at runtime to make the model more robust.
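The padding step is simple enough to sketch in a few lines. Below is a minimal illustration, assuming a single grayscale image file; the repository's `preprocess_padding.py` applies the same idea over the actual train/test directories:

    import numpy as np
    from PIL import Image

    def pad_to_square(filename):
        # Zero-pad a grayscale CXR (or its mask) so that height == width,
        # preserving the aspect ratio for the later 1024 x 1024 resize.
        img = np.asarray(Image.open(filename))
        pad_rows = max(img.shape[:2]) - img.shape[0]
        pad_cols = max(img.shape[:2]) - img.shape[1]
        # Split the padding between both sides; any odd remainder goes
        # to the bottom/right.
        padded = np.pad(img, ((pad_rows // 2, pad_rows - pad_rows // 2),
                              (pad_cols // 2, pad_cols - pad_cols // 2)),
                        mode='constant')
        return Image.fromarray(padded)

    # e.g. pad_to_square('MCUCXR_0006_0.png').resize((1024, 1024))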
# SegNet

### Introduction
![SegNet](http://mi.eng.cam.ac.uk/projects/segnet/images/segnet.png)

SegNet has an encoder network and a corresponding decoder network, followed by a final pixelwise classification layer. This architecture is illustrated in the figure above. The encoder network consists of 13 convolutional layers, corresponding to the first 13 convolutional layers of the VGG16 network designed for object classification. We can therefore initialize the training process from weights trained for classification on large datasets. We can also discard the fully connected layers in favour of retaining higher-resolution feature maps at the deepest encoder output, which significantly reduces the number of parameters in the SegNet encoder network (from 134M to 14.7M) compared to other recent architectures. Each encoder layer has a corresponding decoder layer, so the decoder network also has 13 layers. The final decoder output is fed to a multi-class soft-max classifier, or, for a binary task such as ours, to a sigmoid activation function that produces a class probability for each pixel independently. Each encoder in the encoder network performs convolution with a filter bank to produce a set of feature maps. These are then batch normalized, after which an element-wise rectified linear non-linearity (ReLU), max(0, x), is applied. Following that, max-pooling with a 2 × 2 window and stride 2 (non-overlapping windows) is performed and the resulting output is sub-sampled by a factor of 2. Max-pooling is used to achieve translation invariance over small spatial shifts in the input image, and sub-sampling gives each pixel in the feature map a large input-image context (spatial window). While several layers of max-pooling and sub-sampling can achieve more translation invariance for robust classification, there is a corresponding loss of spatial resolution in the feature maps. This increasingly lossy (in boundary detail) image representation is not beneficial for segmentation, where boundary delineation is vital. SegNet therefore stores the max-pooling indices from each encoder stage and reuses them in the corresponding decoder stage to upsample its input (the `return_indices=True` / `max_unpool2d` pairing in `build_model.py`).
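This index-passing mechanism can be sketched with PyTorch's functional API (a minimal illustration of the pooling/unpooling pair only, not the full network):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 64, 8, 8)  # a feature map from some encoder stage
    # Encoder: pool, remembering which element of each 2x2 window was the max
    pooled, indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
    # Decoder: place each value back at its remembered location, zeros elsewhere
    unpooled = F.max_unpool2d(pooled, indices, kernel_size=2, stride=2)
    print(pooled.shape, unpooled.shape)  # (1, 64, 4, 4) and (1, 64, 8, 8)

The sparse maps produced by unpooling are then densified by the decoder convolutions.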
# U-Net

### Introduction
**U-Net** *(O. Ronneberger, P. Fischer and T. Brox)* is a network that is trained on medical images to segment them according to a given mask. The network architecture is illustrated in the figure below. It consists of a contracting path (left side) and an expansive path (right side). The contracting path follows the typical architecture of a convolutional network: the repeated application of two 3x3 convolutions (padded convolutions in this case), each followed by a rectified linear unit (ReLU), and a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step we double the number of feature channels. Every step in the expansive path consists of an upsampling of the feature map (a 2x2 up-convolution that halves the number of feature channels in the original paper; bilinear interpolation in our implementation), a concatenation with the corresponding feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total the network has 23 convolutional layers. It is important to select the input image size such that all 2x2 max-pooling operations are applied to a layer with an even x- and y-size.

![Vanilla U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

### Working

The input image is first passed through the encoder, which outputs a bottleneck feature map (1024 channels in the figure above; 512 in our bilinear variant, see `build_model.py`). This feature map is then upsampled back to the original size using bilinear interpolation. Each upsampling layer has a skip connection to its corresponding encoder layer: the channels from both layers are concatenated and used as input to the next upsampling layer.

Finally, a sigmoid activation is applied at the last layer and the resulting probability map is thresholded at 0.5 to give the segmented image.

# Loss Functions Used
We use three loss functions here, viz. `Binary Cross Entropy`, `Dice loss` and `Inverted Dice loss`.

#### Binary Cross Entropy Loss
Cross-entropy loss, or log loss, measures the performance of a classification model whose output is a probability value between 0 and 1. Cross-entropy loss increases as the predicted probability diverges from the actual label. So predicting a probability of `.012` when the actual observation label is `1` would be bad and result in a **high loss** value. A perfect model would have a log loss of `0`.

In binary classification, where the number of classes $M$ equals 2, cross-entropy can be calculated as:

$$
-\frac{1}{N}\sum_{i=1}^N\big(y_{i}\log(p_{i}) + (1-y_{i})\log(1-p_{i})\big)
$$
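The training scripts in this repository compute this loss directly on the network logits with `nn.BCEWithLogitsLoss`, which folds the sigmoid into the loss for numerical stability. A minimal sketch:

    import torch
    import torch.nn as nn

    criterion = nn.BCEWithLogitsLoss()            # applies the sigmoid internally
    logits = torch.randn(2, 1, 4, 4)              # raw network outputs
    targets = torch.empty(2, 1, 4, 4).random_(2)  # binary ground-truth masks
    loss = criterion(logits, targets)             # scalar, averaged over all pixels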
#### Dice Coefficient Loss
The Dice coefficient loss measures the overlap between the output and target masks; the Dice score is closely related to the `intersection over union`.

Mathematically, the Dice score is
$$\frac{2 |P \cap R|}{|P| + |R|}$$

and the corresponding loss is
$$1-\frac{2 |P\cap R|}{|P| + |R|}$$

In its soft (differentiable) form, with $\epsilon$ a small smoothing constant (the `smooth = 1` in the code below):
$$1- \frac{2\sum_{i=1}^N p_{i}r_{i}+\epsilon}{\sum_{i=1}^N p_{i}+ \sum_{i=1}^N r_{i}+\epsilon},\quad p_{i} \in P,\ r_{i} \in R$$

The Dice loss is defined in code as:


    class SoftDiceLoss(nn.Module):
        def __init__(self, weight=None, size_average=True):
            super(SoftDiceLoss, self).__init__()

        def forward(self, logits, targets):
            smooth = 1
            num = targets.size(0)
            probs = F.sigmoid(logits)
            m1 = probs.view(num, -1)
            m2 = targets.view(num, -1)
            intersection = (m1 * m2)

            score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
            score = 1 - score.sum() / num
            return score

#### Inverted Dice Coefficient Loss
The formula below measures the overlap after inverting the masks, i.e. taking the complement of both prediction and target; it rewards agreement on the background rather than the foreground.

Mathematically, the inverted Dice score is
$$\frac{2|\overline{P}\cap\overline{R}|}{|\overline{P}| +|\overline{R}| }$$
and the corresponding loss is
$$1-\frac{2|\overline{P}\cap\overline{R}|}{|\overline{P}| +|\overline{R}| }$$
$$1- \frac{2\sum_{i=1}^N(1-p_{i})(1-r_{i})+\epsilon}{\sum_{i=1}^N(1-p_{i})+ \sum_{i=1}^N(1-r_{i})+\epsilon},\quad p_{i} \in P,\ r_{i} \in R$$

The inverted Dice loss is defined in code as:


    class SoftInvDiceLoss(nn.Module):
        def __init__(self, weight=None, size_average=True):
            super(SoftInvDiceLoss, self).__init__()

        def forward(self, logits, targets):
            smooth = 1
            num = targets.size(0)
            probs = F.sigmoid(logits)
            m1 = probs.view(num, -1)
            m2 = targets.view(num, -1)
            m1, m2 = 1. - m1, 1. - m2
            intersection = (m1 * m2)

            score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
            score = 1 - score.sum() / num
            return score


> NOTE: Intersection is implemented as an element-wise multiplication and the cardinalities as `sum()` over axis 1 (the flattened pixels of each sample) because the predictions are per-pixel probabilities and the targets are binary masks; for hard 0/1 values the product is exactly the set intersection.

# Results

| Architecture | Loss         | Validation mIoU | Validation mDice | Test mIoU | Test mDice |
|:------------:|:------------:|:---------------:|:----------------:|:---------:|:----------:|
| U-Net        | BCE          | 0.9403          | 0.9692           | -         | -          |
| U-Net        | BCE+DCL      | 0.9426          | 0.9704           | -         | -          |
| U-Net        | BCE+DCL+IDCL | 0.9665          | 0.9829           | 0.9295    | 0.9623     |
| SegNet       | BCE          | 0.8867          | 0.9396           | -         | -          |
| SegNet       | BCE+DCL      | 0.9011          | 0.9477           | -         | -          |
| SegNet       | BCE+DCL+IDCL | 0.9234          | 0.9600           | 0.8731    | 0.9293     |

(DCL = Dice coefficient loss, IDCL = inverted Dice coefficient loss; test scores were computed only for the best configuration of each architecture.)

The results with these networks are good, and some of the best ones are shown here.

# References
- V. Badrinarayanan, A. Kendall and R. Cipolla, "SegNet: A Deep Convolutional Encoder-Decoder Architecture for Image Segmentation," IEEE Transactions on Pattern Analysis and Machine Intelligence, 2017.
- O. Ronneberger, P. Fischer and T. Brox, "U-Net: Convolutional Networks for Biomedical Image Segmentation," MICCAI, 2015.
- S. Jaeger et al., "Two public chest X-ray datasets for computer-aided screening of pulmonary diseases," Quantitative Imaging in Medicine and Surgery, 2014 (Montgomery dataset).

--------------------------------------------------------------------------------
/Report.txt:
--------------------------------------------------------------------------------
SegNet BCE
Mean IoU: 0.8867305371545303
Mean Dice: 0.9396421301873673

SegNet BCE+Dice
Mean IoU: 0.901176846870474
Mean Dice: 0.9477213769634781

SegNet BCE+Dice+InvDice
Mean IoU: 0.9234133376663273
Mean Dice: 0.9600279186729525

SegNet BCE+Dice+InvDice Test
Mean IoU: 0.8731433000894546
Mean Dice: 0.9293213644884696

--------------------------------------------------------------------------------
/Report_Unet.txt:
--------------------------------------------------------------------------------
UNet(bilinear) BCE_loss

(with Flipping in realtime)
Mean IoU: 0.9060794828941379
Mean Dice: 0.9505207747694377

(no Flipping)
Mean IoU: 0.9350350587463033
Mean Dice: 0.9663559721631583

(Flipping and Noise)
Mean IoU: 0.9403301265614896
Mean Dice: 0.969221751380565


UNet(bilinear) BCE_loss + Dice

(Flipping in realtime)
Mean IoU: 0.8993405777542494
Mean Dice: 0.9468607435441629

(Flipping + Noise)
Mean IoU: 0.9426167844499045
Mean Dice: 0.970421860911066


UNet(bilinear) BCE + Dice + InvDice
(Flipping + Noise)
Mean IoU: 0.9379613208481775
Mean Dice: 0.9679506541131151

(Flipping + Noise + No random Rotation)
Mean IoU: 0.9665012522183872
Mean Dice: 0.9829327195790153

Test Report
Mean IoU: 0.9295389541727913
Mean Dice: 0.9623430243881239

--------------------------------------------------------------------------------
/__pycache__/build_model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav104/LungSegmentation/3bd6e821494e75191b917b8d5bdbafc5f6241cd3/__pycache__/build_model.cpython-35.pyc
--------------------------------------------------------------------------------
/__pycache__/data_loader.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gaurav104/LungSegmentation/3bd6e821494e75191b917b8d5bdbafc5f6241cd3/__pycache__/data_loader.cpython-35.pyc -------------------------------------------------------------------------------- /__pycache__/unet_parts.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaurav104/LungSegmentation/3bd6e821494e75191b917b8d5bdbafc5f6241cd3/__pycache__/unet_parts.cpython-35.pyc -------------------------------------------------------------------------------- /build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | class SegNet(nn.Module): 7 | def __init__(self,input_nbr,label_nbr): 8 | super(SegNet, self).__init__() 9 | 10 | 11 | self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1) 12 | self.bn11 = nn.BatchNorm2d(64) 13 | self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 14 | self.bn12 = nn.BatchNorm2d(64) 15 | 16 | self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 17 | self.bn21 = nn.BatchNorm2d(128) 18 | self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 19 | self.bn22 = nn.BatchNorm2d(128) 20 | 21 | self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 22 | self.bn31 = nn.BatchNorm2d(256) 23 | self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 24 | self.bn32 = nn.BatchNorm2d(256) 25 | self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 26 | self.bn33 = nn.BatchNorm2d(256) 27 | 28 | self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 29 | self.bn41 = nn.BatchNorm2d(512) 30 | self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 31 | self.bn42 = nn.BatchNorm2d(512) 32 | self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 33 | self.bn43 = nn.BatchNorm2d(512) 34 | 35 | self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 36 | self.bn51 = nn.BatchNorm2d(512) 37 | self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 38 | self.bn52 = nn.BatchNorm2d(512) 39 | self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.bn53 = nn.BatchNorm2d(512) 41 | 42 | self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | self.bn53d = nn.BatchNorm2d(512) 44 | self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 45 | self.bn52d = nn.BatchNorm2d(512) 46 | self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 47 | self.bn51d = nn.BatchNorm2d(512) 48 | 49 | self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 50 | self.bn43d = nn.BatchNorm2d(512) 51 | self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 52 | self.bn42d = nn.BatchNorm2d(512) 53 | self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1) 54 | self.bn41d = nn.BatchNorm2d(256) 55 | 56 | self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 57 | self.bn33d = nn.BatchNorm2d(256) 58 | self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 59 | self.bn32d = nn.BatchNorm2d(256) 60 | self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1) 61 | self.bn31d = nn.BatchNorm2d(128) 62 | 63 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1) 64 | self.bn22d = nn.BatchNorm2d(128) 65 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1) 66 | self.bn21d = nn.BatchNorm2d(64) 67 | 68 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1) 69 | self.bn12d = nn.BatchNorm2d(64) 70 | self.conv11d = nn.Conv2d(64, 
label_nbr, kernel_size=3, padding=1) 71 | self.Dropout = nn.Dropout(0.5) 72 | 73 | 74 | def forward(self, x): 75 | 76 | # Stage 1 77 | x11 = F.relu(self.bn11(self.conv11(x))) 78 | x11 = self.Dropout(x11) 79 | x12 = F.relu(self.bn12(self.conv12(x11))) 80 | x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True) 81 | 82 | # Stage 2 83 | x21 = F.relu(self.bn21(self.conv21(x1p))) 84 | x22 = F.relu(self.bn22(self.conv22(x21))) 85 | x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True) 86 | 87 | # Stage 3 88 | x31 = F.relu(self.bn31(self.conv31(x2p))) 89 | x31 = self.Dropout(x31) 90 | x32 = F.relu(self.bn32(self.conv32(x31))) 91 | x33 = F.relu(self.bn33(self.conv33(x32))) 92 | x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True) 93 | 94 | # Stage 4 95 | x41 = F.relu(self.bn41(self.conv41(x3p))) 96 | x42 = F.relu(self.bn42(self.conv42(x41))) 97 | x43 = F.relu(self.bn43(self.conv43(x42))) 98 | x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True) 99 | 100 | # Stage 5 101 | x51 = F.relu(self.bn51(self.conv51(x4p))) 102 | x51 = self.Dropout(x51) 103 | x52 = F.relu(self.bn52(self.conv52(x51))) 104 | x53 = F.relu(self.bn53(self.conv53(x52))) 105 | x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True) 106 | 107 | 108 | # Stage 5d 109 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2) 110 | x53d = F.relu(self.bn53d(self.conv53d(x5d))) 111 | x52d = F.relu(self.bn52d(self.conv52d(x53d))) 112 | x51d = F.relu(self.bn51d(self.conv51d(x52d))) 113 | 114 | # Stage 4d 115 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2) 116 | x43d = F.relu(self.bn43d(self.conv43d(x4d))) 117 | x42d = F.relu(self.bn42d(self.conv42d(x43d))) 118 | x41d = F.relu(self.bn41d(self.conv41d(x42d))) 119 | 120 | # Stage 3d 121 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2) 122 | x33d = F.relu(self.bn33d(self.conv33d(x3d))) 123 | x32d = F.relu(self.bn32d(self.conv32d(x33d))) 124 | x31d = F.relu(self.bn31d(self.conv31d(x32d))) 125 | 126 | # Stage 2d 127 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2) 128 | x22d = F.relu(self.bn22d(self.conv22d(x2d))) 129 | x21d = F.relu(self.bn21d(self.conv21d(x22d))) 130 | 131 | # Stage 1d 132 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2) 133 | x12d = F.relu(self.bn12d(self.conv12d(x1d))) 134 | x11d = self.conv11d(x12d) 135 | 136 | return x11d 137 | 138 | 139 | import torch 140 | import torch.nn as nn 141 | import torch.nn.functional as F 142 | 143 | from unet_parts import * 144 | 145 | 146 | class UNet(nn.Module): 147 | def __init__(self, n_channels, n_classes): 148 | super(UNet, self).__init__() 149 | self.inc = inconv(n_channels, 64) 150 | self.down1 = down(64, 128) 151 | self.down2 = down(128, 256) 152 | self.down3 = down(256, 512) 153 | self.down4 = down(512, 512) 154 | self.up1 = up(1024, 256) 155 | self.up2 = up(512, 128) 156 | self.up3 = up(256, 64) 157 | self.up4 = up(128, 64) 158 | self.outc = outconv(64, n_classes) 159 | self.Dropout = nn.Dropout(0.5) 160 | 161 | def forward(self, x): 162 | x1 = self.inc(x) 163 | x2 = self.down1(x1) 164 | x3 = self.down2(x2) 165 | x3 = self.Dropout(x3) 166 | x4 = self.down3(x3) 167 | x5 = self.down4(x4) 168 | x = self.up1(x5, x4) 169 | x = self.up2(x, x3) 170 | x = self.up3(x, x2) 171 | x = self.up4(x, x1) 172 | x = self.outc(x) 173 | return x 174 | -------------------------------------------------------------------------------- /data_loader.py: 
--------------------------------------------------------------------------------
 1 | 
 2 | from torch.utils.data.dataset import Dataset
 3 | from PIL import Image
 4 | import os
 5 | 
 6 | 
 7 | class LungSegTrain(Dataset):
 8 |     def __init__(self, path='Images_padded1/Train/',
 9 |                  mask_path='Mask_padded1/Train/', transforms=None):
10 |         self.path = path
11 |         self.mask_path = mask_path
12 |         self.list = os.listdir(self.path)
13 |         self.transforms = transforms
14 | 
15 |     def __getitem__(self, index):
16 |         # Images and masks share file names, so one directory listing
17 |         # indexes both.
18 |         image = Image.open(self.path + self.list[index])
19 |         image = image.convert('RGB')
20 |         mask = Image.open(self.mask_path + self.list[index])
21 |         mask = mask.convert('L')
22 |         # If transforms were given, they are applied to image and mask
23 |         # in the order in which they were composed.
24 |         if self.transforms is not None:
25 |             image = self.transforms(image)
26 |             mask = self.transforms(mask)
27 |         return (image, mask)
28 | 
29 |     def __len__(self):
30 |         return len(self.list)  # number of training images
31 | 
32 | 
33 | class LungSegVal(Dataset):
34 |     def __init__(self, path='Images_padded_test/',
35 |                  mask_path='Mask_padded_test/', transforms=None):
36 |         self.path = path
37 |         self.mask_path = mask_path
38 |         self.list = os.listdir(self.path)
39 |         self.transforms = transforms
40 | 
41 |     def __getitem__(self, index):
42 |         image_name = self.list[index]
43 |         image = Image.open(self.path + image_name)
44 |         image = image.convert('RGB')
45 |         mask = Image.open(self.mask_path + image_name)
46 |         mask = mask.convert('L')
47 |         if self.transforms is not None:
48 |             image = self.transforms(image)
49 |             mask = self.transforms(mask)
50 |         return (image, mask, image_name)
51 | 
52 |     def __len__(self):
53 |         return len(self.list)  # number of validation/test images
54 | 
--------------------------------------------------------------------------------
/gaussian_noise_and_flipping.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import os
 3 | import cv2
 4 | from skimage import io
 5 | from skimage.util import random_noise
 6 | from skimage.filters import gaussian
 7 | from tqdm import tqdm
 8 | 
 9 | image_names = os.listdir("Images_padded1/Train/")
10 | 
11 | for name in tqdm(image_names):
12 |     img = io.imread('Images_padded1/Train/'+name)
13 |     mask = io.imread('Mask_padded1/Train/'+name)
14 | 
15 |     noise_image_01 = random_noise(img, mode='gaussian', seed=None, clip=True, var = 0.01)
16 |     noise_image_01 = noise_image_01*255.
17 | noise_image_01 = noise_image_01.astype('uint8') 18 | io.imsave('Images_padded1/Train/'+name[:-4]+'_01.png', noise_image_01) 19 | io.imsave('Mask_padded1/Train/'+name[:-4]+'_01.png', mask) 20 | 21 | ''' 22 | image_names = os.listdir("Images_padded1/Train/") 23 | 24 | for name in tqdm(image_names): 25 | img = cv2.imread('Images_padded1/Train/'+name,0) 26 | mask = cv2.imread('Mask_padded1/Train/'+name,0) 27 | 28 | # copy image to display all 4 variations 29 | horizontal_img = img.copy() 30 | horizontal_mask = mask.copy() 31 | 32 | 33 | # flip img horizontally, vertically, 34 | # and both axes with flip() 35 | horizontal_img = cv2.flip( img, 1 ) 36 | horizontal_mask = cv2.flip(mask, 1) 37 | 38 | 39 | # display the images on screen with imshow() 40 | cv2.imwrite( 'Images_padded1/Train/'+name[:-4]+'_flip.png', horizontal_img ) 41 | cv2.imwrite( 'Mask_padded1/Train/'+name[:-4]+'_flip.png', horizontal_mask ) 42 | 43 | ''' 44 | ''' 45 | img = io.imread('MCUCXR_0006_0.png') 46 | 47 | 48 | noise_image_1 = random_noise(img, mode='gaussian', seed=None, clip=True, var = 0.001) 49 | noise_image_1 = noise_image_1*255. 50 | noise_image_1 = noise_image_1.astype('uint8') 51 | 52 | io.imsave('MCUCXR_0006_0_noise_1.png', noise_image_1) 53 | ''' 54 | ''' 55 | blur = gaussian(img, sigma=5, output=None, mode='nearest', cval=0, multichannel=None, preserve_range=False, truncate=4.0) 56 | blur = blur*255. 57 | blur = blur.astype('uint16') 58 | io.imsave('MCUCXR_0006_0_vlur.png', blur) 59 | ''' 60 | -------------------------------------------------------------------------------- /inferences.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 4 | import torch 5 | from build_model import * 6 | from torch.utils.data import DataLoader 7 | from data_loader import LungSegVal 8 | from torchvision import transforms 9 | import torch.nn.functional as F 10 | import numpy as np 11 | 12 | from skimage import morphology, color, io, exposure 13 | 14 | def IoU(y_true, y_pred): 15 | """Returns Intersection over Union score for ground truth and predicted masks.""" 16 | assert y_true.dtype == bool and y_pred.dtype == bool 17 | y_true_f = y_true.flatten() 18 | y_pred_f = y_pred.flatten() 19 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 20 | union = np.logical_or(y_true_f, y_pred_f).sum() 21 | return (intersection + 1) * 1. / (union + 1) 22 | 23 | def Dice(y_true, y_pred): 24 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 25 | assert y_true.dtype == bool and y_pred.dtype == bool 26 | y_true_f = y_true.flatten() 27 | y_pred_f = y_pred.flatten() 28 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 29 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 
30 | 31 | def masked(img, gt, mask, alpha=1): 32 | """Returns image with GT lung field outlined with red, predicted lung field 33 | filled with blue.""" 34 | rows, cols = img.shape[:2] 35 | color_mask = np.zeros((rows, cols, 3)) 36 | boundary = morphology.dilation(gt, morphology.disk(3))^gt 37 | color_mask[mask == 1] = [0, 0, 1] 38 | color_mask[boundary == 1] = [1, 0, 0] 39 | 40 | img_hsv = color.rgb2hsv(img) 41 | color_mask_hsv = color.rgb2hsv(color_mask) 42 | 43 | img_hsv[..., 0] = color_mask_hsv[..., 0] 44 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 45 | 46 | img_masked = color.hsv2rgb(img_hsv) 47 | return img_masked 48 | 49 | def remove_small_regions(img, size): 50 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 51 | img = morphology.remove_small_objects(img, size) 52 | img = morphology.remove_small_holes(img, size) 53 | return img 54 | 55 | if __name__ == '__main__': 56 | 57 | # Path to csv-file. File should contain X-ray filenames as first column, 58 | # mask filenames as second column. 59 | # Load test data 60 | img_size = (864, 864) 61 | 62 | inp_shape = (864,864,3) 63 | batch_size=1 64 | 65 | # Load model 66 | #model_name = 'model.020.hdf5' 67 | #UNet = load_model(model_name) 68 | net = SegNet(3,1) 69 | net.cuda() 70 | 71 | net.load_state_dict(torch.load('Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot52_0.04634043872356415.pth.tar')) 72 | net.eval() 73 | 74 | 75 | 76 | seed = 1 77 | transformations_test = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 78 | test_set = LungSegVal(transforms = transformations_test) 79 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle = False) 80 | ious = np.zeros(len(test_loader)) 81 | dices = np.zeros(len(test_loader)) 82 | if not(os.path.exists('./results_Unet_BCE_Dice_InvDice_test')): 83 | os.mkdir('./results_Unet_BCE_Dice_InvDice_test') 84 | 85 | 86 | i = 0 87 | for xx, yy, name in tqdm(test_loader): 88 | xx = xx.cuda() 89 | yy = yy 90 | 91 | name = name[0][:-4] 92 | print (name) 93 | pred = net(xx) 94 | pred = F.sigmoid(pred) 95 | pred = pred.cpu() 96 | pred = pred.detach().numpy()[0,0,:,:] 97 | mask = yy.numpy()[0,0,:,:] 98 | xx = xx.cpu() 99 | xx = xx.numpy()[0,:,:,:].transpose(1,2,0) 100 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 101 | 102 | # Binarize masks 103 | gt = mask > 0.5 104 | pr = pred > 0.5 105 | 106 | # Remove regions smaller than 2% of the image 107 | #pr = remove_small_regions(pr, 0.02 * np.prod(img_size)) 108 | 109 | 110 | io.imsave('results_Unet_BCE_Dice_InvDice_test/{}.png'.format(name), pr*255) 111 | 112 | ious[i] = IoU(gt, pr) 113 | dices[i] = Dice(gt, pr) 114 | 115 | i += 1 116 | if i == len(test_loader): 117 | break 118 | 119 | print ('Mean IoU:', ious.mean()) 120 | print ('Mean Dice:', dices.mean()) 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /metric_calc.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import cv2 4 | import os 5 | import sys 6 | from PIL import Image 7 | 8 | 9 | 10 | 11 | def get_IoU(Gi,Si): 12 | #print(Gi.shape, Si.shape) 13 | intersect = 1.0*np.sum(np.logical_and(Gi,Si)) 14 | union = 1.0*np.sum(np.logical_or(Gi,Si)) 15 | return intersect/union 16 | #check cv2.connectedComponents and what it returns, alongwith channels first or last 17 | def generate_list(G,S): 18 | G = G.astype('uint8') 19 | S = S.astype('uint8') 20 | #print(np.unique(G)) 
21 |     gland_obj_cnt,gland_obj = cv2.connectedComponents(G,connectivity=8)
22 |     seg_obj_cnt,seg_obj = cv2.connectedComponents(S,connectivity=8)
23 |     gland_obj_list = []
24 |     seg_obj_list = []
25 |     for i in range(1,gland_obj_cnt):
26 |         gland_obj_list.append( (gland_obj==(i)).astype('int32') )
27 |     for i in range(1,seg_obj_cnt):
28 |         seg_obj_list.append( (seg_obj==(i)).astype('int32') )
29 |     gland_obj_list = np.array(gland_obj_list)
30 |     seg_obj_list = np.array(seg_obj_list)
31 |     return gland_obj_list,seg_obj_list
32 | 
33 | ####Find why channel parameter was passed
34 | def AGI_core(gland_obj_list,seg_obj_list,channel='last'):
35 |     C = 0.0
36 |     U = 0.0
37 |     ##check below:
38 |     '''
39 |     Swapping is not required.
40 |     if(channel=='last'):
41 |         # make channels first
42 |         gland_obj_list = np.swapaxes( np.swapaxes( gland_obj_list , 0,2) , 1 , 2 )
43 |         seg_obj_list = np.swapaxes( np.swapaxes( seg_obj_list , 0,2) , 1 , 2 )
44 |     '''
45 |     #print(gland_obj_list.shape)
46 |     seg_nonused = np.ones(len(seg_obj_list))
47 |     for gi in gland_obj_list:
48 |         iou = np.multiply( [get_IoU(gi,si) for si in seg_obj_list] , seg_nonused )
49 |         max_iou = np.max(iou)
50 |         j = np.argmax(iou)
51 |         C = C + np.sum(np.logical_and(gi,seg_obj_list[j]) )
52 |         U = U + np.sum(np.logical_or(gi,seg_obj_list[j]) )
53 |         seg_nonused[j] = 0
54 |     for ind in range(len(seg_obj_list)):
55 |         if((seg_nonused[ind])==1):
56 |             U = U + np.sum(seg_obj_list[ind])
57 |     return C*1./U
58 | 
59 | def Acc_Jacard_Index(G,S):
60 |     gland_obj_list,seg_obj_list = generate_list(G, S)
61 |     print("In AJI and length is:{}".format(len(gland_obj_list)))
62 |     return AGI_core(gland_obj_list,seg_obj_list)
63 | 
64 | #-----------------------------------------------
65 | 
66 | def F1_core(gland_obj_list,seg_obj_list,channel='first'):
67 |     TP,FP,FN = 0.0,0.0,0.0
68 |     if(channel=='last'):
69 |         # make channels first
70 |         gland_obj_list = np.swapaxes( np.swapaxes( gland_obj_list , 0,2) , 1 , 2 )
71 |         seg_obj_list = np.swapaxes( np.swapaxes( seg_obj_list , 0,2) , 1 , 2 )
72 |     seg_nonused = np.ones(len(seg_obj_list))
73 |     gland_unhit = np.ones(len(seg_obj_list))
74 | 
75 |     for ind in range(len(gland_obj_list)):
76 |         gi = gland_obj_list[ind]
77 |         overlap_s = np.multiply( np.sum( seg_obj_list*gi , axis=(1,2) ) , seg_nonused )
78 |         max_ov = np.max(overlap_s)
79 |         percent_overlap = max_ov/np.sum(gi)
80 |         if percent_overlap>=0.01 :
81 |             # hit
82 |             TP = TP +1
83 |             j = np.argmax(overlap_s)
84 |             seg_nonused[j] = 0
85 |         else:
86 |             # unhit
87 |             FN = FN + 1
88 | 
89 |     FP = np.sum(seg_nonused)
90 |     F1_val = (2*TP)/(2*TP + FP + FN)
91 |     return F1_val
92 | 
93 | def F1_score(G,S):
94 |     #y_mask = np.asarray(G[:, :, :, 0]).astype('uint8')
95 |     #print y_mask.shape
96 |     #y_pred = np.asarray(S[:, :, :, 0]).astype('uint8')
97 |     #print y_pred.shape
98 |     #print type(y_pred)
99 |     #y_dist = K.expand_dims(G[:, :, :, 1], axis=-1)
100 |     gland_obj_list,seg_obj_list = generate_list(G, S)
101 |     return F1_core(gland_obj_list,seg_obj_list)
102 | 
103 | def Dice(y_true, y_pred):
104 |     """Returns Dice Similarity Coefficient for ground truth and predicted masks."""
105 |     y_true = np.squeeze(y_true)/255
106 |     y_pred = np.squeeze(y_pred)/255
107 |     # astype returns a new array, so the result must be assigned back;
108 |     # the original code discarded it.
109 |     y_true = y_true.astype('bool')
110 |     y_pred = y_pred.astype('bool')
111 |     intersection = np.logical_and(y_true, y_pred).sum()
112 |     return ((2. * intersection) + 1.) / (y_true.sum() + y_pred.sum() + 1.)
113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | smooth = 1 121 | 122 | image_names = os.listdir('results_scratch_custom99/') 123 | 124 | mean_dice = [] 125 | mean_F1 = [] 126 | aggr_jacard = [] 127 | 128 | 129 | for images in tqdm(image_names): 130 | 131 | S = np.expand_dims(np.array(Image.open('results_scratch_custom99/'+images).convert('L')),axis=-1) 132 | G = np.expand_dims(np.array(Image.open('/home/sahyadri/Testing/Test_40_y_HE/'+images).convert('L')),axis=-1) 133 | #print S.shape 134 | #G.shape 135 | #print(Acc_Jacard_Index(G,S)) 136 | #aggr_jacard.append(Acc_Jacard_Index(G,S)) 137 | #mean_F1.append(F1_score(G,S)) 138 | mean_dice.append(Dice(G, S)) 139 | 140 | 141 | print ('Mean_Dice = ', np.mean(np.array(mean_dice))) 142 | #print ('Mean_F1 = ', np.mean(np.array(mean_F1))) 143 | #print (len(aggr_jacard), aggr_jacard) 144 | #print ('Mean_Aggr_Jacard = ', np.mean(np.array(aggr_jacard))) 145 | 146 | f = open('lung.txt','w') 147 | a = 'Mean Dice : {}'.format(np.mean(np.array(mean_dice)))+ '\n' + 'Mean F1 : {}'.format(np.mean(np.array(mean_F1)))+ '\n' + 'Mean Aggregate Jacard : {}'.format(np.mean(np.array(aggr_jacard)))+ '\n' 148 | 149 | f.write(str(a)) 150 | f.close() 151 | 152 | -------------------------------------------------------------------------------- /preprocess_combining.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from skimage import io, exposure 4 | 5 | 6 | 7 | def make_masks(): 8 | path = 'test_image/' 9 | for i, filename in enumerate(os.listdir(path)): 10 | left = io.imread('test_mask/left/' + filename[:-4] + '.png') 11 | right = io.imread('test_mask/right/' + filename[:-4] + '.png') 12 | io.imsave('test_mask/Mask/' + filename[:-4] + '.png', np.clip(left + right, 0, 255)) 13 | print ('Mask', i, filename) 14 | 15 | make_masks() 16 | 17 | -------------------------------------------------------------------------------- /preprocess_padding.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | img_names = os.listdir('test_image/') 10 | 11 | 12 | for name in tqdm(img_names): 13 | 14 | img = np.asarray(Image.open('test_image/'+name)) 15 | max_dim_img = max(img.shape[0], img.shape[1]) 16 | row_img = max_dim_img - img.shape[0] 17 | cols_img = max_dim_img - img.shape[1] 18 | padded_img = np.pad(img, ((row_img//2,row_img//2), (cols_img//2,cols_img//2)), mode ='constant') 19 | img = Image.fromarray(padded_img) 20 | img.save('Images_padded_test/'+name) 21 | 22 | mask = np.asarray(Image.open('test_mask/Mask/'+name)) 23 | max_dim_mask = max(mask.shape[0], mask.shape[1]) 24 | row_mask = max_dim_mask- mask.shape[0] 25 | cols_mask = max_dim_mask - mask.shape[1] 26 | padded_mask = np.pad(mask, ((row_mask//2,row_mask//2), (cols_mask//2,cols_mask//2)), mode ='constant') 27 | mask = Image.fromarray(padded_mask) 28 | mask.save('Mask_padded_test/'+name) 29 | 30 | -------------------------------------------------------------------------------- /train_model_Unet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torch.autograd import Variable 8 | from build_model import * 9 | import os 10 | from tqdm import tqdm 11 | from tensorboardX import SummaryWriter 12 | from 
torch.optim.lr_scheduler import MultiStepLR 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 15 | #-------------------------- 16 | class Average(object): 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.sum += val 26 | self.count += n 27 | 28 | #property 29 | def avg(self): 30 | return self.sum / self.count 31 | #------------------------------ 32 | # import csv 33 | writer = SummaryWriter() 34 | #---------------------------------------- 35 | class SoftDiceLoss(nn.Module): 36 | ''' 37 | Soft Dice Loss 38 | ''' 39 | def __init__(self, weight=None, size_average=True): 40 | super(SoftDiceLoss, self).__init__() 41 | 42 | def forward(self, logits, targets): 43 | smooth = 1. 44 | logits = F.sigmoid(logits) 45 | iflat = logits.view(-1) 46 | tflat = targets.view(-1) 47 | intersection = (iflat * tflat).sum() 48 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 49 | #------------------------------------------------- 50 | ''' 51 | class SoftDicescore(nn.Module): 52 | ''' 53 | #Soft Dice Loss 54 | ''' 55 | def __init__(self, weight=None, size_average=True): 56 | super(SoftDicescore, self).__init__() 57 | 58 | def forward(self, logits, targets): 59 | smooth = 1. 60 | logits = F.sigmoid(logits) 61 | iflat = logits.view(-1) 62 | tflat = targets.view(-1) 63 | intersection = (iflat * tflat).sum() 64 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 65 | ''' 66 | #------------------------------------------------- 67 | ''' 68 | class W_bce(nn.Module): 69 | 70 | #weighted crossentropy per image 71 | 72 | def __init__(self, weight=None, size_average=True): 73 | super(W_bce, self).__init__() 74 | 75 | def forward(self, logits, targets): 76 | eps = 1e-6 77 | total_size = targets.view(-1).size()[0] 78 | #print "total_size", total_size 79 | ones_size = torch.sum(targets.view(-1,1)).item() 80 | #print "one_size", ones_size 81 | zero_size = total_size - ones_size 82 | #print "zero_size", zero_size 83 | #assert total_size == (ones_size + zero_size) 84 | #print "crossed assertion" 85 | loss_1 = torch.mean(-(targets.view(-1)* ( total_size/ones_size) * torch.log(torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps))))#.sum(axis=1) 86 | #print "crossed loss1" 87 | loss_0 = torch.mean(-((1.-targets.view(-1))* ( total_size/zero_size) * torch.log((1.-torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps)))))#.sum(axis=1) 88 | #print "crossed loss0" 89 | return loss_1 + loss_0 90 | ''' 91 | #---------------------------------- 92 | class InvSoftDiceLoss(nn.Module): 93 | 94 | ''' 95 | Inverted Soft Dice Loss 96 | ''' 97 | def __init__(self, weight=None, size_average=True): 98 | super(InvSoftDiceLoss, self).__init__() 99 | 100 | def forward(self, logits, targets): 101 | smooth = 1. 102 | logits = F.sigmoid(logits) 103 | iflat = 1-logits.view(-1) 104 | tflat = 1-targets.view(-1) 105 | intersection = (iflat * tflat).sum() 106 | 107 | 108 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 109 | #-------------------------------------- 110 | ''' 111 | class InvSoftDicescore(nn.Module): 112 | 113 | ''' 114 | #Inverted Soft Dice Loss 115 | ''' 116 | 117 | def __init__(self, weight=None, size_average=True): 118 | super(InvSoftDicescore, self).__init__() 119 | 120 | def forward(self, logits, targets): 121 | smooth = 1. 
122 | logits = F.sigmoid(logits) 123 | iflat = 1-logits.view(-1) 124 | tflat = 1-targets.view(-1) 125 | intersection = (iflat * tflat).sum() 126 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 127 | ''' 128 | #---------------------------------------- 129 | ''' 130 | class int_custom_loss(nn.Module): 131 | ''' 132 | #custom loss 133 | ''' 134 | def __init__(self, weight=None, size_average=True): 135 | super(int_custom_loss, self).__init__() 136 | 137 | def forward(self, logits, targets): 138 | loss_inv_dice = InvSoftDicescore() 139 | loss_dice = SoftDicescore() 140 | total_size = targets.view(-1).size()[0] 141 | ones_size = torch.sum(targets.view(-1,1)).item() 142 | th = 0.2 * total_size 143 | if(ones_size > th): 144 | return (- 0.8*torch.log(loss_dice(logits,targets))-0.2*torch.log(loss_inv_dice(logits, targets))) 145 | else: 146 | return(-0.2*torch.log(loss_dice(logits, targets))-0.8*torch.log(loss_inv_dice(logits, targets))) 147 | ''' 148 | ''' 149 | class weighted_dice_invdice(nn.Module): 150 | ''' 151 | #custom loss 152 | ''' 153 | def __init__(self, weight=None, size_average=True): 154 | super(weighted_dice_invdice, self).__init__() 155 | 156 | def forward(self, logits, targets): 157 | loss_inv_dice = InvSoftDicescore() 158 | loss_dice = SoftDicescore() 159 | total_size = targets.view(-1).size()[0] 160 | ones_size = torch.sum(targets.view(-1,1)).item() 161 | zero_size = total_size - ones_size 162 | th = 0.2 * total_size 163 | return (-(zero_size/total_size)*torch.log(loss_dice(logits,targets))-(ones_size/total_size)*torch.log(loss_inv_dice(logits, targets))) 164 | ''' 165 | 166 | #Tranformations------------------------------------------------ 167 | transformations_train = transforms.Compose([transforms.Resize((1024,1024)),transforms.ToTensor()]) 168 | 169 | transformations_val = transforms.Compose([transforms.Resize((1024,1024)),transforms.ToTensor()]) 170 | #------------------------------------------------------------- 171 | 172 | from data_loader import LungSegTrain 173 | from data_loader import LungSegVal 174 | train_set = LungSegTrain(transforms = transformations_train) 175 | batch_size = 1 176 | num_epochs = 75 177 | 178 | def train(): 179 | cuda = torch.cuda.is_available() 180 | net = UNet(3,1) 181 | if cuda: 182 | net = net.cuda() 183 | #net.load_state_dict(torch.load('Weights_BCE_Dice/cp_bce_lr_05_100_0.222594484687.pth.tar')) 184 | criterion1 = nn.BCEWithLogitsLoss().cuda() 185 | criterion2 = SoftDiceLoss().cuda() 186 | criterion3 = InvSoftDiceLoss().cuda() 187 | #criterion4 = W_bce().cuda() 188 | #criterion5 = int_custom_loss() 189 | #criterion6 = weighted_dice_invdice() 190 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) 191 | #scheduler = MultiStepLR(optimizer, milestones=[2,10,75,100], gamma=0.1) 192 | 193 | print("preparing training data ...") 194 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 195 | print("done ...") 196 | val_set = LungSegVal(transforms = transformations_val) 197 | val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False) 198 | for epoch in tqdm(range(num_epochs)): 199 | #scheduler.step() 200 | train_loss = Average() 201 | net.train() 202 | for i, (images, masks) in tqdm(enumerate(train_loader)): 203 | images = Variable(images) 204 | masks = Variable(masks) 205 | if cuda: 206 | images = images.cuda() 207 | masks = masks.cuda() 208 | 209 | optimizer.zero_grad() 210 | outputs = net(images) 211 | #writer.add_image('Training Input',images) 212 | #writer.add_image('Training 
Pred',F.sigmoid(outputs)>0.5) 213 | c1 = criterion1(outputs,masks) + criterion2(outputs, masks) + criterion3(outputs, masks) 214 | loss = c1 215 | writer.add_scalar('Train Loss',loss,epoch) 216 | loss.backward() 217 | optimizer.step() 218 | train_loss.update(loss.item(), images.size(0)) 219 | for param_group in optimizer.param_groups: 220 | writer.add_scalar('Learning Rate',param_group['lr']) 221 | 222 | val_loss1 = Average() 223 | val_loss2 = Average() 224 | val_loss3 = Average() 225 | net.eval() 226 | for images, masks,_ in tqdm(val_loader): 227 | images = Variable(images) 228 | masks = Variable(masks) 229 | if cuda: 230 | images = images.cuda() 231 | masks = masks.cuda() 232 | 233 | outputs = net(images) 234 | if (epoch)%10==0: 235 | writer.add_image('Validation Input',images,epoch) 236 | writer.add_image('Validation GT ',masks,epoch) 237 | writer.add_image('Validation Pred0.5',F.sigmoid(outputs)>0.5,epoch) 238 | writer.add_image('Validation Pred0.3',F.sigmoid(outputs)>0.3,epoch) 239 | writer.add_image('Validation Pred0.65',F.sigmoid(outputs)>0.65,epoch) 240 | 241 | vloss1 = criterion1(outputs, masks) 242 | vloss2 = criterion2(outputs, masks) 243 | vloss3 = criterion3(outputs, masks) #+ criterion2(outputs, masks) 244 | #vloss = vloss2 + vloss3 245 | writer.add_scalar('Validation loss(BCE)',vloss1,epoch) 246 | writer.add_scalar('Validation loss(Dice)',vloss2,epoch) 247 | writer.add_scalar('Validation loss(InvDice)',vloss3,epoch) 248 | 249 | val_loss1.update(vloss1.item(), images.size(0)) 250 | val_loss2.update(vloss2.item(), images.size(0)) 251 | val_loss3.update(vloss3.item(), images.size(0)) 252 | 253 | print("Epoch {}, Training Loss(BCE+Dice): {}, Validation Loss(BCE): {}, Validation Loss(Dice): {}, Validation Loss(InvDice): {}".format(epoch+1, train_loss.avg(), val_loss1.avg(), val_loss2.avg(), val_loss3.avg())) 254 | 255 | # with open('Log.csv', 'a') as logFile: 256 | # FileWriter = csv.writer(logFile) 257 | # FileWriter.writerow([epoch+1, train_loss.avg, val_loss1.avg, val_loss2.avg, val_loss3.avg]) 258 | 259 | torch.save(net.state_dict(), 'Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot{}_{}.pth.tar'.format(epoch+1, val_loss2.avg())) 260 | return net 261 | 262 | def test(model): 263 | model.eval() 264 | 265 | 266 | 267 | if __name__ == "__main__": 268 | train() 269 | -------------------------------------------------------------------------------- /train_model_segnet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torch.autograd import Variable 8 | from build_model import * 9 | import os 10 | from tqdm import tqdm 11 | from tensorboardX import SummaryWriter 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 15 | #-------------------------- 16 | class Average(object): 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.sum += val 26 | self.count += n 27 | 28 | #property 29 | def avg(self): 30 | return self.sum / self.count 31 | #------------------------------ 32 | # import csv 33 | writer = SummaryWriter() 34 | #---------------------------------------- 35 | class SoftDiceLoss(nn.Module): 36 | ''' 37 | Soft Dice Loss 38 | ''' 39 | def __init__(self, weight=None, size_average=True): 40 | super(SoftDiceLoss, 
self).__init__() 41 | 42 | def forward(self, logits, targets): 43 | smooth = 1. 44 | logits = F.sigmoid(logits) 45 | iflat = logits.view(-1) 46 | tflat = targets.view(-1) 47 | intersection = (iflat * tflat).sum() 48 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 49 | #------------------------------------------------- 50 | ''' 51 | class SoftDicescore(nn.Module): 52 | ''' 53 | #Soft Dice Loss 54 | ''' 55 | def __init__(self, weight=None, size_average=True): 56 | super(SoftDicescore, self).__init__() 57 | 58 | def forward(self, logits, targets): 59 | smooth = 1. 60 | logits = F.sigmoid(logits) 61 | iflat = logits.view(-1) 62 | tflat = targets.view(-1) 63 | intersection = (iflat * tflat).sum() 64 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 65 | ''' 66 | #------------------------------------------------- 67 | ''' 68 | class W_bce(nn.Module): 69 | 70 | #weighted crossentropy per image 71 | 72 | def __init__(self, weight=None, size_average=True): 73 | super(W_bce, self).__init__() 74 | 75 | def forward(self, logits, targets): 76 | eps = 1e-6 77 | total_size = targets.view(-1).size()[0] 78 | #print "total_size", total_size 79 | ones_size = torch.sum(targets.view(-1,1)).item() 80 | #print "one_size", ones_size 81 | zero_size = total_size - ones_size 82 | #print "zero_size", zero_size 83 | #assert total_size == (ones_size + zero_size) 84 | #print "crossed assertion" 85 | loss_1 = torch.mean(-(targets.view(-1)* ( total_size/ones_size) * torch.log(torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps))))#.sum(axis=1) 86 | #print "crossed loss1" 87 | loss_0 = torch.mean(-((1.-targets.view(-1))* ( total_size/zero_size) * torch.log((1.-torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps)))))#.sum(axis=1) 88 | #print "crossed loss0" 89 | return loss_1 + loss_0 90 | ''' 91 | #---------------------------------- 92 | class InvSoftDiceLoss(nn.Module): 93 | 94 | ''' 95 | Inverted Soft Dice Loss 96 | ''' 97 | def __init__(self, weight=None, size_average=True): 98 | super(InvSoftDiceLoss, self).__init__() 99 | 100 | def forward(self, logits, targets): 101 | smooth = 1. 102 | logits = F.sigmoid(logits) 103 | iflat = 1-logits.view(-1) 104 | tflat = 1-targets.view(-1) 105 | intersection = (iflat * tflat).sum() 106 | 107 | 108 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 109 | #-------------------------------------- 110 | ''' 111 | class InvSoftDicescore(nn.Module): 112 | 113 | ''' 114 | #Inverted Soft Dice Loss 115 | ''' 116 | 117 | def __init__(self, weight=None, size_average=True): 118 | super(InvSoftDicescore, self).__init__() 119 | 120 | def forward(self, logits, targets): 121 | smooth = 1. 122 | logits = F.sigmoid(logits) 123 | iflat = 1-logits.view(-1) 124 | tflat = 1-targets.view(-1) 125 | intersection = (iflat * tflat).sum() 126 | return ((2. 
* intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 127 | ''' 128 | #---------------------------------------- 129 | ''' 130 | class int_custom_loss(nn.Module): 131 | ''' 132 | #custom loss 133 | ''' 134 | def __init__(self, weight=None, size_average=True): 135 | super(int_custom_loss, self).__init__() 136 | 137 | def forward(self, logits, targets): 138 | loss_inv_dice = InvSoftDicescore() 139 | loss_dice = SoftDicescore() 140 | total_size = targets.view(-1).size()[0] 141 | ones_size = torch.sum(targets.view(-1,1)).item() 142 | th = 0.2 * total_size 143 | if(ones_size > th): 144 | return (- 0.8*torch.log(loss_dice(logits,targets))-0.2*torch.log(loss_inv_dice(logits, targets))) 145 | else: 146 | return(-0.2*torch.log(loss_dice(logits, targets))-0.8*torch.log(loss_inv_dice(logits, targets))) 147 | ''' 148 | ''' 149 | class weighted_dice_invdice(nn.Module): 150 | ''' 151 | #custom loss 152 | ''' 153 | def __init__(self, weight=None, size_average=True): 154 | super(weighted_dice_invdice, self).__init__() 155 | 156 | def forward(self, logits, targets): 157 | loss_inv_dice = InvSoftDicescore() 158 | loss_dice = SoftDicescore() 159 | total_size = targets.view(-1).size()[0] 160 | ones_size = torch.sum(targets.view(-1,1)).item() 161 | zero_size = total_size - ones_size 162 | th = 0.2 * total_size 163 | return (-(zero_size/total_size)*torch.log(loss_dice(logits,targets))-(ones_size/total_size)*torch.log(loss_inv_dice(logits, targets))) 164 | ''' 165 | 166 | #Tranformations------------------------------------------------ 167 | transformations_train = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 168 | 169 | transformations_val = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 170 | #------------------------------------------------------------- 171 | 172 | from data_loader import LungSegTrain 173 | from data_loader import LungSegVal 174 | train_set = LungSegTrain(transforms = transformations_train) 175 | batch_size = 1 176 | num_epochs = 75 177 | 178 | def train(): 179 | cuda = torch.cuda.is_available() 180 | net = SegNet(3,1) 181 | if cuda: 182 | net = net.cuda() 183 | #net.load_state_dict(torch.load('Weights_BCE_Dice/cp_bce_lr_05_100_0.222594484687.pth.tar')) 184 | criterion1 = nn.BCEWithLogitsLoss().cuda() 185 | criterion2 = SoftDiceLoss().cuda() 186 | criterion3 = InvSoftDiceLoss().cuda() 187 | #criterion4 = W_bce().cuda() 188 | #criterion5 = int_custom_loss() 189 | #criterion6 = weighted_dice_invdice() 190 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) 191 | #scheduler = MultiStepLR(optimizer, milestones=[2,10,75,100], gamma=0.1) 192 | 193 | print("preparing training data ...") 194 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 195 | print("done ...") 196 | val_set = LungSegVal(transforms = transformations_val) 197 | val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False) 198 | for epoch in tqdm(range(num_epochs)): 199 | #scheduler.step() 200 | train_loss = Average() 201 | net.train() 202 | for i, (images, masks) in tqdm(enumerate(train_loader)): 203 | images = Variable(images) 204 | masks = Variable(masks) 205 | if cuda: 206 | images = images.cuda() 207 | masks = masks.cuda() 208 | 209 | optimizer.zero_grad() 210 | outputs = net(images) 211 | #writer.add_image('Training Input',images) 212 | #writer.add_image('Training Pred',F.sigmoid(outputs)>0.5) 213 | c1 = criterion1(outputs,masks) + criterion2(outputs, masks) + criterion3(outputs, masks) 214 | loss = c1 215 | 
writer.add_scalar('Train Loss',loss,epoch) 216 | loss.backward() 217 | optimizer.step() 218 | train_loss.update(loss.item(), images.size(0)) 219 | for param_group in optimizer.param_groups: 220 | writer.add_scalar('Learning Rate',param_group['lr']) 221 | val_loss1 = Average() 222 | val_loss2 = Average() 223 | val_loss3 = Average() 224 | net.eval() 225 | for images, masks,_ in tqdm(val_loader): 226 | images = Variable(images) 227 | masks = Variable(masks) 228 | if cuda: 229 | images = images.cuda() 230 | masks = masks.cuda() 231 | 232 | outputs = net(images) 233 | if (epoch)%10==0: 234 | writer.add_image('Validation Input',images,epoch) 235 | writer.add_image('Validation GT ',masks,epoch) 236 | writer.add_image('Validation Pred0.5',F.sigmoid(outputs)>0.5,epoch) 237 | writer.add_image('Validation Pred0.3',F.sigmoid(outputs)>0.3,epoch) 238 | writer.add_image('Validation Pred0.65',F.sigmoid(outputs)>0.65,epoch) 239 | 240 | vloss1 = criterion1(outputs, masks) 241 | vloss2 = criterion2(outputs, masks) 242 | vloss3 = criterion3(outputs, masks) #+ criterion2(outputs, masks) 243 | #vloss = vloss2 + vloss3 244 | writer.add_scalar('Validation loss(BCE)',vloss1,epoch) 245 | writer.add_scalar('Validation loss(Dice)',vloss2,epoch) 246 | writer.add_scalar('Validation loss(InvDice)',vloss3,epoch) 247 | 248 | val_loss1.update(vloss1.item(), images.size(0)) 249 | val_loss2.update(vloss2.item(), images.size(0)) 250 | val_loss3.update(vloss3.item(), images.size(0)) 251 | 252 | print("Epoch {}, Training Loss(BCE+Dice): {}, Validation Loss(BCE): {}, Validation Loss(Dice): {}, Validation Loss(InvDice): {}".format(epoch+1, train_loss.avg(), val_loss1.avg(), val_loss2.avg(), val_loss3.avg())) 253 | 254 | # with open('Log.csv', 'a') as logFile: 255 | # FileWriter = csv.writer(logFile) 256 | # FileWriter.writerow([epoch+1, train_loss.avg, val_loss1.avg, val_loss2.avg, val_loss3.avg]) 257 | 258 | torch.save(net.state_dict(), 'Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot{}_{}.pth.tar'.format(epoch+1, val_loss2.avg())) 259 | return net 260 | 261 | def test(model): 262 | model.eval() 263 | 264 | 265 | 266 | if __name__ == "__main__": 267 | train() 268 | -------------------------------------------------------------------------------- /unet_parts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # sub-parts of the U-Net model 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class double_conv(nn.Module): 11 | '''(conv => BN => ReLU) * 2''' 12 | def __init__(self, in_ch, out_ch): 13 | super(double_conv, self).__init__() 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 16 | nn.BatchNorm2d(out_ch), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 19 | nn.BatchNorm2d(out_ch), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | return x 26 | 27 | 28 | class inconv(nn.Module): 29 | def __init__(self, in_ch, out_ch): 30 | super(inconv, self).__init__() 31 | self.conv = double_conv(in_ch, out_ch) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class down(nn.Module): 39 | def __init__(self, in_ch, out_ch): 40 | super(down, self).__init__() 41 | self.mpconv = nn.Sequential( 42 | nn.MaxPool2d(2), 43 | double_conv(in_ch, out_ch) 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.mpconv(x) 48 | return x 49 | 50 | 51 | class up(nn.Module): 52 | def __init__(self, in_ch, out_ch, bilinear=True): 
53 |         super(up, self).__init__()
54 | 
55 |         # would be a nice idea if the upsampling could be learned too,
56 |         # but my machine does not have enough memory to handle all those weights
57 |         if bilinear:
58 |             self.up = nn.UpsamplingBilinear2d(scale_factor=2)
59 |         else:
60 |             self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2)
61 | 
62 |         self.conv = double_conv(in_ch, out_ch)
63 | 
64 |     def forward(self, x1, x2):
65 |         x1 = self.up(x1)
66 |         diffH = x1.size()[2] - x2.size()[2]  # height difference
67 |         diffW = x1.size()[3] - x2.size()[3]  # width difference
68 |         # F.pad pads the last dimension first: (left, right, top, bottom)
69 |         x2 = F.pad(x2, (diffW // 2, diffW - diffW // 2, diffH // 2, diffH - diffH // 2))
70 |         x = torch.cat([x2, x1], dim=1)
71 |         x = self.conv(x)
72 |         return x
73 | 
74 | 
75 | class outconv(nn.Module):
76 |     def __init__(self, in_ch, out_ch):
77 |         super(outconv, self).__init__()
78 |         self.conv = nn.Conv2d(in_ch, out_ch, 1)
79 | 
80 |     def forward(self, x):
81 |         x = self.conv(x)
82 |         return x
83 | 
--------------------------------------------------------------------------------