├── .gitignore
├── GCN
│   ├── README.md
│   ├── code
│   │   ├── build_model.py
│   │   ├── data_loader.py
│   │   ├── evaluate.py
│   │   ├── inferences.py
│   │   └── train_model.py
│   └── pics
│       ├── gcn4.png
│       ├── gcn6.png
│       ├── result_NEWMCUCXR_0019_0.png
│       └── shrunk_result_NEWMCUCXR_0019_0.png
├── HDC_DUC
│   ├── README.md
│   └── code
│       ├── augmentation.py
│       ├── build_model.py
│       ├── data_loader.py
│       ├── flip.py
│       ├── inferences.py
│       ├── preprocess.py
│       └── train_model.py
├── README.md
├── SegNet
│   ├── README.md
│   └── code
│       ├── build_model.py
│       ├── data_loader.py
│       ├── gaussian_noise_and_flipping.py
│       ├── inferences.py
│       ├── metric_calc.py
│       ├── preprocess_combining.py
│       ├── preprocess_padding.py
│       └── train_model_segnet.py
└── VGG_UNet
    ├── README.md
    └── code
        ├── build_model.py
        ├── data_loader.py
        ├── gaussian_noise_and_flipping.py
        ├── inferences.py
        ├── metric_calc.py
        ├── preprocess_combining.py
        ├── preprocess_padding.py
        ├── train_model_Unet.py
        ├── train_model_segnet.py
        └── unet_parts.py

/.gitignore:
--------------------------------------------------------------------------------
Flowchart_lungseg.PNG
VGG_UNet/Unet_results.PNG
VGG_UNet/README.html
SegNet/segnet_results.PNG
SegNet/segnet_results_compiled.PNG
HDC_DUC/results_git/
HDC_DUC/pics/
GCN/gcn4.png
GCN/gcn6.png
GCN/result_NEWMCUCXR_0019_0.png
GCN/shrunk_result_NEWMCUCXR_0019_0.png
HDC_DUC/code/requirements.txt

--------------------------------------------------------------------------------
/GCN/README.md:
--------------------------------------------------------------------------------
# Global Convolutional Network
* [Introduction](#introduction)
* [Dataset](#dataset)
  * [Montgomery Dataset](#montgomery-dataset)
  * [Data Preprocessing](#data-preprocessing)
* [Architecture](#architecture)
* [Training](#training)
  * [Loss Function](#loss-function)
  * [Evaluation Metrics](#evaluation-metrics)
* [Results](#results)
* [References](#references)

## Introduction
The GCN architecture is proposed in the paper "Large Kernel Matters — Improve Semantic Segmentation by Global Convolutional Network"[$^{[1]}$](https://arxiv.org/abs/1703.02719).

A GCN-based architecture, called ResNet-GCN, is used here for the purpose of lung segmentation from chest x-rays.

## Dataset
### Montgomery Dataset
This architecture is used to segment the lungs out of a chest radiograph (colloquially known as a chest X-ray, CXR). The dataset is the [Montgomery County X-Ray Set](https://ceb.nlm.nih.gov/repositories/tuberculosis-chest-x-ray-image-data-sets/), which contains 138 posterior-anterior x-rays. The motivation is that this segmentation can be further used to detect chest abnormalities such as shrunken lungs or other structural deformities, which is especially useful in detecting tuberculosis in patients.

### Data Preprocessing
The x-rays are 4892x4020 pixels in size. Due to GPU memory limitations, they are resized to 1024x1024.

The dataset is augmented by randomly rotating and flipping the images.
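A minimal sketch of these augmentations, mirroring the transforms actually used in `code/train_model.py` (the resize to 1024x1024 matches the preprocessing above):

```python
from torchvision import transforms

img_size = (1024, 1024)
transformations_train = transforms.Compose([
    transforms.Resize(img_size),
    transforms.RandomRotation(10),      # random rotation within +/- 10 degrees
    transforms.RandomHorizontalFlip(),  # random left-right flip with p = 0.5
    transforms.ToTensor(),
])
```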
## Architecture

### Intuition

The Global Convolutional Network, or GCN, is an architecture proposed for the task of segmenting images. An image segmenter has to perform two tasks: classification as well as localization. This is inherently challenging, as the two tasks have diametrically opposite demands: while a classifier has to be transformation- and rotation-invariant, a localizer has to be sensitive to those same transformations.
The GCN architecture finds a balance between the two demands with the following properties:

1. To retain spatial information, no fully connected layers are used and an FCN framework is adopted.
2. For better classification, a large kernel size is adopted to enable dense connections in the feature maps.

For segmentation to have semantic context, the local context obtained from simple CNN architectures is not sufficient; a bigger view (i.e. global context) is critical.
This architecture, coined ResNet-GCN, is basically a modified ResNet model with additional GCN blocks obtaining the required global view, and Boundary Refinement blocks further improving the segmentation performance near object boundaries.

The entire pipeline of this architecture is visualized below:

![enter image description here](https://lh3.googleusercontent.com/jma3XKGwaLnS4-0TYajAsD8gNYg0_uJ0W81Xj5ssOjub3DdEhkjxhrcUEAoTJEyZ6_l7VBCmPybM "ResNet-GCN Pipeline")

### GCN Block
The GCN block is essentially a kx1 convolution followed by a 1xk convolution, summed with a 1xk followed by a kx1 convolution computed in parallel. This results in a large kxk kernel with dense connections.
NOTE: the blocks act on feature maps, so the channel width is larger than 3.

### Boundary Refinement Block

The BR block improves the segmentation near the boundaries of objects, where segmentation is less like a pure classification problem. Its design is inspired by that of ResNets: a parallel branch of Conv+ReLU followed by another convolutional layer, whose output is added back to the input.

![enter image description here](https://lh3.googleusercontent.com/b-WoF5ESCbTOWeR1mvHd6LTd-I0HAZ1V2pFX1E1NnSnTZhPb_eDnHevCPnUwTCb3aH6ituCTFz-_ "GCN and BR Block")

## Training

A pretrained ResNet-50 [$^{[2]}$](https://arxiv.org/abs/1512.03385) is used and later fine-tuned. The rationale is that, while medical images are vastly different from natural images, a ResNet is a good generic feature extractor (of edges, blobs, etc.). This is further supported by the fact that many components in a medical image have features that resemble those of natural images, e.g. nuclei look similar to balls.

Refer to http://ethereon.github.io/netscope/#/gist/db945b393d40bfa26006 for the ResNet-50 architecture and https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py for the torchvision.models code.

57 CXRs, with their corresponding masks, were used to train the model, while 20 were used for validation (hold-out cross-validation). Another 61 images were reserved as the test set.

### Loss Function

A linear combination of Soft Dice Loss, Soft Inverse Dice Loss, and Binary Cross-Entropy Loss (with logits) is used to train the model end-to-end. The best performance was obtained by weighting the three criteria at 0.25:0.5:0.25 (respectively).

#### Binary Cross-Entropy Loss (with logits)

This is calculated by passing the output of the network through a sigmoid activation before applying cross-entropy loss.

> The sigmoid and cross-entropy calculations are done in one class to exploit the log-sum-exp trick for greater numerical stability (as compared to sequentially applying a sigmoid activation and then using vanilla BCE).

$$l_n = - w_n \left[ t_n \cdot \log \sigma(x_n) + (1 - t_n) \cdot \log (1 - \sigma(x_n)) \right],$$

$$L(x,y) = \sum_{i=1}^{N} l_i$$
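As a minimal illustration of why the fused criterion is preferred (the logit value below is an arbitrary example, not a model output):

```python
import torch
import torch.nn as nn

x = torch.tensor([40.0])  # an extreme logit, chosen only for illustration
t = torch.tensor([0.0])   # the corresponding target

# naive: sigmoid(40) rounds to exactly 1.0 in float32,
# so log(1 - p) = log(0) = -inf
p = torch.sigmoid(x)
naive = -(t * torch.log(p) + (1 - t) * torch.log(1 - p))

# fused: BCEWithLogitsLoss evaluates log(sigmoid(.)) stably via log-sum-exp
fused = nn.BCEWithLogitsLoss()(x, t)

print(naive.item(), fused.item())  # inf vs. 40.0
```

The naive form loses the answer entirely once the sigmoid saturates, while the fused form returns the exact loss.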
71 | 72 | 73 | 74 | $$l_n = - w_n \left[ t_n \cdot \log \sigma(x_n) + (1 - t_n) \cdot \log (1 - \sigma(x_n)) \right],$$ 75 | 76 | $$ L(x,y) = \sum_{i=1}^{N}l_i$$ 77 | 78 | #### Soft Dice Loss 79 | 80 | Dice Loss gives a measure of how accurate the overlap of the mask and ground truth is. 81 | The Sørensen–Dice coefficient is calculated as: $\frac{2. X\cap Y}{|X| + |Y|} = \frac{2. TP}{2. TP + FP + FN}$ and the Dice Loss is simply 1 - Dice coeff. 82 | For Soft version of the loss, the output of the network is passed through a sigmoid before standard dice loss is evaluated. 83 | 84 | #### Soft Inverse Dice Loss 85 | 86 | Inverse Dice loss checks for how accurately the background is masked. This penalizes the excess areas in the predicted mask. It is found by inverting the output before using the soft dice loss. This is added to account for true-negatives in our prediction. 87 | 88 | ### Evaluation Metrics 89 | 90 | Three metrics were used to evaluate the trained models; 91 | - Intersection over Union (IoU) 92 | - Dice Index 93 | - Inverse Dice Index 94 | 95 | #### Intersection over Union (IoU) 96 | 97 | IoU measures the accuracy of the predicted mask. It rewards better overlap of the prediction with the ground truth. 98 | $$\text{IoU} =\frac{P\cap GT}{P\cup GT} = \frac{TP}{TP+FP+FN} $$ 99 | 100 | 101 | > P stands for Predicted Mask while GT is ground truth. 102 | The $\epsilon$ is added to ensure bounded ratios --> 103 | 104 | 105 | #### Dice Index 106 | The Dice index (also known as Sørensen–Dice similarity coefficient) has been discussed earlier. 107 | Like IoU, Dice Index gives a measure of accuracy. 108 | > While for a single inference, both Dice and IoU are functionally equialent, over an average both have different inferences. 109 | > 110 | > While the Dice score is a measure of the average performance of the model, the IoU score is harsher towards errors and is a measure of the worst performance of the model. [$^{[\dagger]}$](https://stats.stackexchange.com/questions/273537/f1-dice-score-vs-iou) 111 | 112 | #### Inverse Dice Index 113 | As mentioned before, the Inverse dice index is obtained by inverting the masks and ground truth before calculating their dice score. 114 | >Due to the relatively smaller area of lung compared to the background, Inverse Dice score is large for every model. 115 | 116 | ## Results 117 | After 35 epochs of training, with learning rate = $10^{-3}$, scheduled to decrease by a factor of 5 after $15^{th}$ and $30^{th}$ epoch. 118 | 119 | The model performed as follows: 120 | 121 | Mean IoU: `0.8313548250035635` 122 | Mean Dice: `0.9072525421846304` 123 | Mean Inv. Dice: `0.9705243499345569` 124 | 125 | ### Examples of output 126 | ![ 127 | ](https://lh3.googleusercontent.com/r5vPdjpPavjGm8lHB5P_HbgrqKxqowjK4xe0Q56ooCbPTe0e2lxeHOT4v-lSCixl-RYYfzjO8b63 "gcn4") 128 | 129 | ![ 130 | ](https://lh3.googleusercontent.com/o9cZnU9i9lNV2orFc6S-t_pz61Ga02bY11VgS1njUaYitMhYfeAFjE4XC0Wlk21Y08EY0JDYLM9b "gcn6") 131 | 132 | 133 | *The red boundary denotes the ground truth while the blue shaded portion is the predicted mask* 134 | 135 | 136 | ### Observations 137 | 138 | - The GCN architecture is comparatively lightweight (in terms of GPU consumption) 139 | - The GCN architecture performs remarkably well for the task of lung segmentation even with very little training. 140 | - However, further work is required to achieve better performance. 
#### Inverse Dice Index
As mentioned before, the Inverse Dice index is obtained by inverting the masks and ground truth before calculating their dice score.
> Due to the relatively small area of the lungs compared to the background, the Inverse Dice score is large for every model.

## Results
The model was trained for 35 epochs, with learning rate $10^{-3}$ scheduled to decrease by a factor of 5 after the $15^{th}$ and $30^{th}$ epochs.

The model performed as follows:

Mean IoU: `0.8313548250035635`
Mean Dice: `0.9072525421846304`
Mean Inv. Dice: `0.9705243499345569`

### Examples of output
![](https://lh3.googleusercontent.com/r5vPdjpPavjGm8lHB5P_HbgrqKxqowjK4xe0Q56ooCbPTe0e2lxeHOT4v-lSCixl-RYYfzjO8b63 "gcn4")

![](https://lh3.googleusercontent.com/o9cZnU9i9lNV2orFc6S-t_pz61Ga02bY11VgS1njUaYitMhYfeAFjE4XC0Wlk21Y08EY0JDYLM9b "gcn6")

*The red boundary denotes the ground truth, while the blue shaded portion is the predicted mask.*

### Observations

- The GCN architecture is comparatively lightweight (in terms of GPU consumption).
- The GCN architecture performs remarkably well for the task of lung segmentation even with very little training.
- However, further work is required to achieve better performance; Type-1 errors (false positives) are particularly prevalent in the predictions.

## References
[[1]](https://arxiv.org/abs/1703.02719) Chao Peng, Xiangyu Zhang, Gang Yu, Guiming Luo, and Jian Sun. Large kernel matters - improve semantic segmentation by global convolutional network. CoRR, abs/1703.02719, 2017.

[[2]](https://arxiv.org/abs/1512.03385) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. CoRR, abs/1512.03385, 2015.

[[3]](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. "ImageNet classification with deep convolutional neural networks." In Advances in Neural Information Processing Systems, pp. 1097-1105, 2012.

[[4]](https://arxiv.org/abs/1409.4842) Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going deeper with convolutions." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1-9, 2015.

[[5]](https://arxiv.org/abs/1502.03167) Sergey Ioffe and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." In International Conference on Machine Learning, pp. 448-456, 2015.

[[6]](http://jmlr.org/papers/v15/srivastava14a.html) Nitish Srivastava, Geoffrey E. Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. "Dropout: a simple way to prevent neural networks from overfitting." Journal of Machine Learning Research 15, no. 1 (2014): 1929-1958.

[[7]](https://arxiv.org/abs/1409.1556) Karen Simonyan and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv:1409.1556, 2014.

[8] Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.

[[9]](http://www.deeplearningbook.org) Ian Goodfellow, Yoshua Bengio, and Aaron Courville. Deep Learning. MIT Press, 2016.

[[10]](https://www.tensorflow.org/) TensorFlow

[[11]](https://keras.io/) Keras

[[12]](https://pytorch.org/docs/stable/index.html) PyTorch

Return to the main [README](https://github.com/medal-iitb/LungSegmentation).
168 | 169 | -------------------------------------------------------------------------------- /GCN/code/build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | import torch.utils.model_zoo as model_zoo 6 | from torchvision import models 7 | import math 8 | 9 | class GCN(nn.Module): 10 | def __init__(self,c,out_c,k=(7,7)): #out_Channel=21 in paper 11 | super(GCN, self).__init__() 12 | self.conv_l1 = nn.Conv2d(c, out_c, kernel_size=(k[0],1), padding =(int((k[0]-1)/2),0)) 13 | self.conv_l2 = nn.Conv2d(out_c, out_c, kernel_size=(1,k[0]), padding =(0,int((k[0]-1)/2))) 14 | self.conv_r1 = nn.Conv2d(c, out_c, kernel_size=(1,k[1]), padding =(0,int((k[1]-1)/2))) 15 | self.conv_r2 = nn.Conv2d(out_c, out_c, kernel_size=(k[1],1), padding =(int((k[1]-1)/2),0)) 16 | 17 | def forward(self, x): 18 | x_l = self.conv_l1(x) 19 | x_l = self.conv_l2(x_l) 20 | 21 | x_r = self.conv_r1(x) 22 | x_r = self.conv_r2(x_r) 23 | 24 | x = x_l + x_r 25 | 26 | return x 27 | 28 | class BR(nn.Module): 29 | def __init__(self, out_c): 30 | super(BR, self).__init__() 31 | # self.bn = nn.BatchNorm2d(out_c) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv1 = nn.Conv2d(out_c,out_c, kernel_size=3,padding=1) 34 | self.conv2 = nn.Conv2d(out_c,out_c, kernel_size=3,padding=1) 35 | 36 | def forward(self,x): 37 | x_res = self.conv1(x) 38 | x_res = self.relu(x_res) 39 | x_res = self.conv2(x_res) 40 | 41 | x = x + x_res 42 | 43 | return x 44 | 45 | class FCN_GCN(nn.Module): 46 | def __init__(self, num_classes): 47 | super(FCN_GCN, self).__init__() 48 | self.num_classes = num_classes #21 in paper 49 | 50 | resnet = models.resnet50(pretrained=True) 51 | 52 | self.conv1 = resnet.conv1 # 7x7,64, stride=2 53 | self.bn0 = resnet.bn1 #BatchNorm2d(64) 54 | self.relu = resnet.relu 55 | # self.maxpool = resnet.maxpool # maxpool /2 (kernel_size=3, stride=2, padding=1) 56 | self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1) #res-2 o/p = 56x56,256 57 | self.layer2 = resnet.layer2 #res-3 o/p = 28x28,512 58 | self.layer3 = resnet.layer3 #res-4 o/p = 14x14,1024 59 | self.layer4 = resnet.layer4 #res-5 o/p = 7x7,2048 60 | 61 | self.gcn1 = GCN(256,self.num_classes) #gcn_i after layer-1 62 | self.gcn2 = GCN(512,self.num_classes) 63 | self.gcn3 = GCN(1024,self.num_classes) 64 | self.gcn4 = GCN(2048,self.num_classes) 65 | 66 | self.br1 = BR(num_classes) 67 | self.br2 = BR(num_classes) 68 | self.br3 = BR(num_classes) 69 | self.br4 = BR(num_classes) 70 | self.br5 = BR(num_classes) 71 | self.br6 = BR(num_classes) 72 | self.br7 = BR(num_classes) 73 | self.br8 = BR(num_classes) 74 | self.br9 = BR(num_classes) 75 | 76 | def _classifier(self, in_c): 77 | return nn.Sequential( 78 | nn.Conv2d(in_c,in_c,3,padding=1,bias=False), 79 | nn.BatchNorm2d(in_c/2), 80 | nn.ReLU(inplace=True), 81 | #nn.Dropout(.5), 82 | nn.Conv2d(in_c/2, self.num_classes, 1), 83 | 84 | ) 85 | 86 | def forward(self,x): 87 | input = x 88 | x = self.conv1(x) 89 | x = self.bn0(x) 90 | x = self.relu(x) 91 | pooled_x = x 92 | fm1 = self.layer1(x) 93 | fm2 = self.layer2(fm1) 94 | fm3 = self.layer3(fm2) 95 | fm4 = self.layer4(fm3) 96 | 97 | gc_fm1 = self.br1(self.gcn1(fm1)) 98 | gc_fm2 = self.br2(self.gcn2(fm2)) 99 | gc_fm3 = self.br3(self.gcn3(fm3)) 100 | gc_fm4 = self.br4(self.gcn4(fm4)) 101 | 102 | gc_fm4 = F.upsample(gc_fm4, fm3.size()[2:], mode='bilinear', align_corners=True) 103 | gc_fm3 = F.upsample(self.br5(gc_fm3 + gc_fm4), 
fm2.size()[2:], mode='bilinear', align_corners=True) 104 | gc_fm2 = F.upsample(self.br6(gc_fm2 + gc_fm3), fm1.size()[2:], mode='bilinear', align_corners=True) 105 | gc_fm1 = F.upsample(self.br7(gc_fm1 + gc_fm2), pooled_x.size()[2:], mode='bilinear', align_corners=True) 106 | 107 | gc_fm1 = F.upsample(self.br8(gc_fm1), scale_factor=2, mode='bilinear', align_corners=True) 108 | 109 | out = F.upsample(self.br9(gc_fm1), input.size()[2:], mode='bilinear', align_corners=True) 110 | 111 | return out 112 | -------------------------------------------------------------------------------- /GCN/code/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data.dataset import Dataset 3 | from torchvision import transforms 4 | from skimage import io, transform 5 | from PIL import Image 6 | import os 7 | import numpy as np 8 | 9 | 10 | class LungSeg(Dataset): 11 | def __init__(self, path='Image/Train/Images', transforms=None): 12 | self.path = path 13 | self.list = os.listdir(self.path) 14 | 15 | self.transforms = transforms 16 | 17 | def __getitem__(self, index): 18 | # stuff 19 | image_path = 'Image/Train/Images/' 20 | mask_path = 'Mask/Train/Images/' 21 | image = Image.open(image_path+self.list[index]) 22 | image = image.convert('RGB') 23 | mask = Image.open(mask_path+self.list[index]) 24 | mask = mask.convert('L') 25 | if self.transforms is not None: 26 | image = self.transforms(image) 27 | mask = self.transforms(mask) 28 | # If the transform variable is not empty 29 | # then it applies the operations in the transforms with the order that it is created. 30 | return (image, mask) 31 | 32 | def __len__(self): 33 | return len(self.list) # of how many data(images?) you have 34 | 35 | 36 | class LungSegTest(Dataset): 37 | def __init__(self, path='Image/Test/Images', transforms=None): 38 | self.path = path 39 | self.list = os.listdir(self.path) 40 | 41 | self.transforms = transforms 42 | 43 | def __getitem__(self, index): 44 | # stuff 45 | image_path = 'Image/Test/Images/' 46 | mask_path = 'Mask/Test/Images/' 47 | image = Image.open(image_path+self.list[index]) 48 | image = image.convert('RGB') 49 | mask = Image.open(mask_path+self.list[index]) 50 | mask = mask.convert('L') 51 | if self.transforms is not None: 52 | image = self.transforms(image) 53 | mask = self.transforms(mask) 54 | # If the transform variable is not empty 55 | # then it applies the operations in the transforms with the order that it is created. 56 | return (image, mask) 57 | 58 | def __len__(self): 59 | return len(self.list) -------------------------------------------------------------------------------- /GCN/code/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from build_model import FCN_GCN 3 | from torch.utils.data import DataLoader 4 | from data_loader_EVAL import LungSegTest 5 | from torchvision import transforms 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | from skimage import morphology, color, io, exposure 10 | 11 | def IoU(y_true, y_pred): 12 | """Returns Intersection over Union score for ground truth and predicted masks.""" 13 | assert y_true.dtype == bool and y_pred.dtype == bool 14 | y_true_f = y_true.flatten() 15 | y_pred_f = y_pred.flatten() 16 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 17 | union = np.logical_or(y_true_f, y_pred_f).sum() 18 | return (intersection + 1) * 1. 
/ (union + 1) 19 | 20 | def Dice(y_true, y_pred): 21 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 22 | assert y_true.dtype == bool and y_pred.dtype == bool 23 | y_true_f = y_true.flatten() 24 | y_pred_f = y_pred.flatten() 25 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 26 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 27 | 28 | def Inv_Dice(y_true, y_pred): 29 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 30 | assert y_true.dtype == bool and y_pred.dtype == bool 31 | y_true_f = np.logical_not(y_true.flatten()) 32 | y_pred_f = np.logical_not(y_pred.flatten()) 33 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 34 | return (2. * intersection + 1.) / (y_true_f.sum() + y_pred_f.sum() + 1.) 35 | 36 | def masked(img, gt, mask, alpha=1): 37 | """Returns image with GT lung field outlined with red, predicted lung field 38 | filled with blue.""" 39 | rows, cols = img.shape[:2] 40 | color_mask = np.zeros((rows, cols, 3)) 41 | boundary = morphology.dilation(gt, morphology.disk(3)) ^ gt 42 | color_mask[mask == 1] = [0, 0, 1] 43 | color_mask[boundary == 1] = [1, 0, 0] 44 | 45 | img_hsv = color.rgb2hsv(img) 46 | color_mask_hsv = color.rgb2hsv(color_mask) 47 | 48 | img_hsv[..., 0] = color_mask_hsv[..., 0] 49 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 50 | 51 | img_masked = color.hsv2rgb(img_hsv) 52 | return img_masked 53 | 54 | def remove_small_regions(img, size): 55 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 56 | img = morphology.remove_small_objects(img, size) 57 | img = morphology.remove_small_holes(img, size) 58 | return img 59 | 60 | if __name__ == '__main__': 61 | 62 | # Path to csv-file. File should contain X-ray filenames as first column, 63 | # mask filenames as second column. 64 | # Load test data 65 | img_size = (1024, 1024) 66 | 67 | n_test = 61 68 | inp_shape = (1024,1024,3) 69 | batch_size=1 70 | 71 | # Load model 72 | net = FCN_GCN(1) 73 | 74 | net.load_state_dict(torch.load('Weights_221_2/cp_19_0.1336055189371109.pth')) 75 | net.eval() 76 | 77 | 78 | ious = np.zeros(n_test) 79 | dices = np.zeros(n_test) 80 | inv_dices = np.zeros(n_test) 81 | seed = 1 82 | transformations_test = transforms.Compose([transforms.Resize(img_size), 83 | transforms.ToTensor()]) 84 | test_set = LungSegTest(transforms = transformations_test) 85 | test_loader = DataLoader(test_set, batch_size=batch_size) 86 | 87 | 88 | i = 0 89 | for xx, yy, name in test_loader: 90 | #img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 91 | pred = net(xx) 92 | pred = F.sigmoid(pred) 93 | pred = pred.detach().numpy()[0,0,:,:] 94 | mask = yy.numpy()[0,0,:,:] 95 | xx = xx.numpy()[0,:,:,:].transpose(1,2,0) 96 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 97 | 98 | # Binarize masks 99 | gt = mask > 0.5 100 | pr = pred > 0.5 101 | 102 | # Remove regions smaller than 2% of the image 103 | pr = remove_small_regions(pr, 0.02 * np.prod(img_size)) 104 | 105 | io.imsave('results/{}.png'.format(name[0][:-4]), masked(img, gt, pr, 1)) 106 | 107 | ious[i] = IoU(gt, pr) 108 | dices[i] = Dice(gt, pr) 109 | inv_dices[i] = Inv_Dice(gt, pr) 110 | 111 | i += 1 112 | if i == n_test: 113 | break 114 | 115 | print ('Mean IoU:', ious.mean()) 116 | print ('Mean Dice:', dices.mean()) 117 | print ('Mean Inv. 
Dice:', inv_dices.mean()) 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /GCN/code/inferences.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from build_model import FCN_GCN 3 | from torch.utils.data import DataLoader 4 | from data_loader import LungSegTest 5 | from torchvision import transforms 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | from skimage import morphology, color, io, exposure 10 | 11 | def IoU(y_true, y_pred): 12 | """Returns Intersection over Union score for ground truth and predicted masks.""" 13 | assert y_true.dtype == bool and y_pred.dtype == bool 14 | y_true_f = y_true.flatten() 15 | y_pred_f = y_pred.flatten() 16 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 17 | union = np.logical_or(y_true_f, y_pred_f).sum() 18 | return (intersection + 1) * 1. / (union + 1) 19 | 20 | def Dice(y_true, y_pred): 21 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 22 | assert y_true.dtype == bool and y_pred.dtype == bool 23 | y_true_f = y_true.flatten() 24 | y_pred_f = y_pred.flatten() 25 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 26 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 27 | 28 | def Inv_Dice(y_true, y_pred): 29 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 30 | assert y_true.dtype == bool and y_pred.dtype == bool 31 | y_true_f = np.logical_not(y_true.flatten()) 32 | y_pred_f = np.logical_not(y_pred.flatten()) 33 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 34 | return (2. * intersection + 1.) / (y_true_f.sum() + y_pred_f.sum() + 1.) 35 | 36 | def masked(img, gt, mask, alpha=1): 37 | """Returns image with GT lung field outlined with red, predicted lung field 38 | filled with blue.""" 39 | rows, cols = img.shape[:2] 40 | color_mask = np.zeros((rows, cols, 3)) 41 | boundary = morphology.dilation(gt, morphology.disk(3)) ^ gt 42 | color_mask[mask == 1] = [0, 0, 1] 43 | color_mask[boundary == 1] = [1, 0, 0] 44 | 45 | img_hsv = color.rgb2hsv(img) 46 | color_mask_hsv = color.rgb2hsv(color_mask) 47 | 48 | img_hsv[..., 0] = color_mask_hsv[..., 0] 49 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 50 | 51 | img_masked = color.hsv2rgb(img_hsv) 52 | return img_masked 53 | 54 | def remove_small_regions(img, size): 55 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 56 | img = morphology.remove_small_objects(img, size) 57 | img = morphology.remove_small_holes(img, size) 58 | return img 59 | 60 | if __name__ == '__main__': 61 | 62 | # Path to csv-file. File should contain X-ray filenames as first column, 63 | # mask filenames as second column. 
64 | # Load test data 65 | img_size = (1024, 1024) 66 | 67 | n_test = 20 68 | inp_shape = (1024,1024,3) 69 | batch_size=1 70 | 71 | # Load model 72 | #model_name = 'model.020.hdf5' 73 | #UNet = load_model(model_name) 74 | net = FCN_GCN(1) 75 | 76 | net.load_state_dict(torch.load('Weights_UNO/cp_20_0.13776783943176268.pth')) 77 | net.eval() 78 | 79 | 80 | ious = np.zeros(n_test) 81 | dices = np.zeros(n_test) 82 | inv_dices = np.zeros(n_test) 83 | seed = 1 84 | transformations_test = transforms.Compose([transforms.Resize(img_size), 85 | transforms.ToTensor()]) 86 | test_set = LungSegTest(transforms = transformations_test) 87 | test_loader = DataLoader(test_set, batch_size=batch_size) 88 | 89 | 90 | i = 0 91 | for xx, yy in test_loader: 92 | #img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 93 | pred = net(xx) 94 | pred = F.sigmoid(pred) 95 | pred = pred.detach().numpy()[0,0,:,:] 96 | mask = yy.numpy()[0,0,:,:] 97 | xx = xx.numpy()[0,:,:,:].transpose(1,2,0) 98 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 99 | 100 | # Binarize masks 101 | gt = mask > 0.5 102 | pr = pred > 0.5 103 | 104 | # Remove regions smaller than 2% of the image 105 | pr = remove_small_regions(pr, 0.02 * np.prod(img_size)) 106 | 107 | io.imsave('results_UNO/{}.png'.format(i), masked(img, gt, pr, 1)) 108 | 109 | ious[i] = IoU(gt, pr) 110 | dices[i] = Dice(gt, pr) 111 | inv_dices[i] = Inv_Dice(gt, pr) 112 | i += 1 113 | if i == n_test: 114 | break 115 | 116 | print ('Mean IoU:', ious.mean()) 117 | print ('Mean Dice:', dices.mean()) 118 | print ('Mean Inv. Dice:', inv_dices.mean()) 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /GCN/code/train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms 6 | from torch.autograd import Variable 7 | from build_model import FCN_GCN 8 | import os 9 | import csv 10 | 11 | class SoftDiceLoss(nn.Module): 12 | def __init__(self, weight=None, size_average=True): 13 | super(SoftDiceLoss, self).__init__() 14 | 15 | def forward(self, logits, targets): 16 | smooth = 1. 17 | logits = F.sigmoid(logits) 18 | iflat = logits.view(-1) 19 | tflat = targets.view(-1) 20 | intersection = (iflat * tflat).sum() 21 | 22 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 23 | 24 | class SoftInvDiceLoss(nn.Module): 25 | def __init__(self, weight=None, size_average=True): 26 | super(SoftInvDiceLoss, self).__init__() 27 | 28 | def forward(self, logits, targets): 29 | smooth = 1. 30 | logits = F.sigmoid(logits) 31 | iflat = 1 - logits.view(-1) 32 | tflat = 1 - targets.view(-1) 33 | intersection = (iflat * tflat).sum() 34 | 35 | return 1 - ((2. 
* intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 36 | 37 | img_size = (1024,1024) 38 | transformations_train = transforms.Compose([transforms.Resize(img_size), 39 | transforms.RandomRotation(10), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor()]) 42 | 43 | transformations_test = transforms.Compose([transforms.Resize(img_size), 44 | transforms.ToTensor()]) 45 | 46 | 47 | 48 | from data_loader import LungSeg 49 | from data_loader import LungSegTest 50 | train_set = LungSeg(transforms = transformations_train) 51 | test_set = LungSegTest(transforms = transformations_test) 52 | batch_size = 1 53 | num_epochs = 30 54 | 55 | class Average(object): 56 | def __init__(self): 57 | self.reset() 58 | 59 | def reset(self): 60 | self.sum = 0 61 | self.count = 0 62 | 63 | def update(self, val, n=1): 64 | self.sum += val 65 | self.count += n 66 | 67 | @property 68 | def avg(self): 69 | return self.sum / self.count 70 | 71 | def train(): 72 | cuda = torch.cuda.is_available() 73 | net = FCN_GCN(1) 74 | net.load_state_dict(torch.load('cp.pth')) 75 | 76 | criterion1 = nn.BCEWithLogitsLoss() 77 | criterion2 = SoftDiceLoss() 78 | criterion3 = SoftInvDiceLoss() 79 | 80 | if cuda: 81 | net = net.cuda() 82 | criterion1 = criterion1.cuda() 83 | criterion2 = criterion2.cuda() 84 | criterion3 = criterion3.cuda() 85 | 86 | optimizer = torch.optim.Adam(net.parameters(), lr=4e-5) 87 | #scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,20], gamma=0.5) 88 | 89 | print("preparing training data ...") 90 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 91 | print("done ...") 92 | 93 | test_set = LungSegTest(transforms = transformations_test) 94 | test_loader = DataLoader(test_set, batch_size=batch_size) 95 | for epoch in range(num_epochs): 96 | train_loss = Average() 97 | net.train() 98 | 99 | #scheduler.step() 100 | 101 | for i, (images, masks) in enumerate(train_loader): 102 | images = Variable(images) 103 | masks = Variable(masks) 104 | if cuda: 105 | images = images.cuda() 106 | masks = masks.cuda() 107 | 108 | optimizer.zero_grad() 109 | outputs = net(images) 110 | loss = 0.4*criterion1(outputs, masks) + 0.4*criterion2(outputs, masks) + 0.2*criterion3(outputs, masks) 111 | loss.backward() 112 | optimizer.step() 113 | train_loss.update(loss.item(), images.size(0)) 114 | 115 | val_loss = Average() 116 | val_loss_dice = Average() 117 | net.eval() 118 | for images, masks in test_loader: 119 | images = Variable(images) 120 | masks = Variable(masks) 121 | if cuda: 122 | images = images.cuda() 123 | masks = masks.cuda() 124 | 125 | outputs = net(images) 126 | vloss = 0.4*criterion1(outputs, masks) + 0.4*criterion2(outputs, masks) + 0.2*criterion3(outputs, masks) 127 | vloss_dice = criterion2(outputs, masks) 128 | val_loss.update(vloss.item(), images.size(0)) 129 | val_loss_dice.update(vloss_dice.item(), images.size(0)) 130 | 131 | print("Epoch {}/{}, Loss: {}, Validation Loss: {}, Validation Dice Loss: {}".format(epoch+1,num_epochs, train_loss.avg, val_loss.avg, val_loss_dice.avg)) 132 | 133 | torch.save(net.state_dict(), 'Weights_221/cp_{}_{}.pth'.format(epoch+1, val_loss_dice.avg)) 134 | 135 | return net 136 | 137 | def test(model): 138 | model.eval() 139 | 140 | 141 | 142 | if __name__ == "__main__": 143 | train() 144 | -------------------------------------------------------------------------------- /GCN/pics/gcn4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MEDAL-IITB/Lung-Segmentation/d57a23536edd9e39c62608407eee7fbea388b43e/GCN/pics/gcn4.png
--------------------------------------------------------------------------------
/GCN/pics/gcn6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MEDAL-IITB/Lung-Segmentation/d57a23536edd9e39c62608407eee7fbea388b43e/GCN/pics/gcn6.png
--------------------------------------------------------------------------------
/GCN/pics/result_NEWMCUCXR_0019_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MEDAL-IITB/Lung-Segmentation/d57a23536edd9e39c62608407eee7fbea388b43e/GCN/pics/result_NEWMCUCXR_0019_0.png
--------------------------------------------------------------------------------
/GCN/pics/shrunk_result_NEWMCUCXR_0019_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MEDAL-IITB/Lung-Segmentation/d57a23536edd9e39c62608407eee7fbea388b43e/GCN/pics/shrunk_result_NEWMCUCXR_0019_0.png
--------------------------------------------------------------------------------
/HDC_DUC/README.md:
--------------------------------------------------------------------------------
# Hybrid Dilated Convolution and Dense Upsampling Convolution for Lung Segmentation
* [Introduction](#introduction)
* [What are HDC and DUC?](#what-are-hdc-and-duc)
* [Architecture](#architecture)
* [Training](#training)
* [Results](#results)
* [References](#references)

## Introduction

The paper "Understanding Convolution for Semantic Segmentation" [$^{[1]}$](https://arxiv.org/abs/1702.08502) describes Hybrid Dilated Convolution (HDC) and Dense Upsampling Convolution (DUC) frameworks built on a base CNN architecture for semantic segmentation of images. This architecture achieved state-of-the-art results (at the time of submission) on the Cityscapes, KITTI road estimation (overall), and PASCAL VOC 2012 datasets.

We use the HDC and DUC frameworks proposed in the paper above on the Montgomery County X-Ray Set, which contains 138 posterior-anterior chest X-ray images.

## What are HDC and DUC?

![](https://lh3.googleusercontent.com/Hc5LBZAz15_VEPg3nxUP5vxfsawHip3xSwILiVhBDvQSlScaQYtv_k3WIyBFNIOsaiPtNri5z2g0 "hdc_arch")

**Hybrid Dilated Convolution (HDC)** is a set of convolution layers in which each layer has a different dilation rate ($r_i$). HDC helps detect finer details at higher resolutions. It effectively enlarges the receptive field of the network to aggregate global information, using fewer parameters than a conventional convolution with the same receptive field, hence making it computationally more efficient. It also avoids the gridding issue, which causes loss of local information.

Suppose we have $N$ convolutional layers with kernel size $K \times K$ and dilation rates $[r_1, ..., r_i, ..., r_n]$. The goal of HDC is to let the final receptive field of the series of convolutional operations fully cover a square region without any holes or missing edges.
The "maximum distance between two nonzero values" is defined as
$$M_i = \max\left[\, M_{i+1} - 2r_i,\ M_{i+1} - 2(M_{i+1} - r_i),\ r_i \,\right], \quad \text{with } M_n = r_n.$$
The design goal is to have $M_2 \leq K$.

Note: dilation rates such as [2, 4, 8] cannot be used, as gridding still happens with rates that share a common factor. A quick check of this criterion is sketched below.
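A minimal sketch of that criterion (kernel size $K = 3$ is assumed, matching the 3x3 convolutions in ResNet blocks; the helper name `hdc_ok` is ours):

```python
def hdc_ok(rates, K=3):
    """Check the HDC design rule M_2 <= K for dilation rates [r_1, ..., r_n]."""
    M = rates[-1]                    # M_n = r_n
    for r in reversed(rates[1:-1]):  # work backwards down to M_2
        # note: M_{i+1} - 2*(M_{i+1} - r_i) simplifies to 2*r_i - M_{i+1}
        M = max(M - 2 * r, 2 * r - M, r)
    return M <= K

print(hdc_ok([1, 2, 5, 9]))  # True  -- the per-group rates used in res4b below
print(hdc_ok([2, 4, 8]))     # False -- common factor 2, gridding persists
```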
![](https://lh3.googleusercontent.com/rlkiBpYIHdu3bssClJUAssPeNW7UwxurR3EzM7BNO0RQCbjTbG34Ym-h3EpiBdWQ6NVwGauAT5fX "dilatedconv")
(a) Conventional dilation causes gridding; (b) HDC prevents gridding.

**Dense Upsampling Convolution (DUC)** is a convolution operation performed on the feature map (h x w x c) obtained as the output of the backbone CNN, producing a feature map of dimensions (H x W x L), where h = H/d, w = W/d, c = $d^2$L, H x W are the spatial dimensions of the input image, L is the number of segmentation classes, and d is the downsampling factor. The output of the DUC layer is then reshaped to H x W x L, passed through a softmax layer, and an element-wise argmax operator is applied to get the final label map. The upsampling done here is dense in the sense that calculations are performed on every pixel and no zeros are inserted. This is better than bilinear upsampling because DUC is learnable. DUC is particularly good at detecting small and far-off objects, and it is very easy to implement.
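A minimal shape-check of this channel-to-space rearrangement (the sizes below are illustrative; the actual model in `code/build_model.py` uses d = 32 inside its DUC module):

```python
import torch
import torch.nn as nn

d, L = 8, 1   # downsampling factor and number of classes (illustrative values)
h, w = 32, 32 # spatial size of the backbone feature map

x = torch.randn(1, d * d * L, h, w)  # the DUC convolution outputs d^2 * L channels
shuffle = nn.PixelShuffle(d)         # rearranges (d^2*L, h, w) -> (L, h*d, w*d)
y = shuffle(x)

print(y.shape)  # torch.Size([1, 1, 256, 256]) == (N, L, H, W)
```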
## Architecture

For our task of lung segmentation from X-ray images, we use a ResNet-101 backbone (with pretrained ImageNet weights) followed by an HDC unit and then a DUC unit. The HDC unit groups every four blocks of the res4b module together and uses dilation rates of 1, 2, 5 and 9, respectively. We also modify the dilation rates of the res5b module to 5, 9 and 17.

## Training

We used a random split of 57 images for training, 20 images for cross-validation, and 61 images for testing. The loss used was a combination of binary cross-entropy, Dice, and inverse Dice losses; the metrics chosen for evaluation are mean IoU (mIoU) and mean Dice scores.
Since the dataset is small, augmentation was applied: the images were flipped horizontally and vertically, and Gaussian noise was then added to these images.
After observing initial training results, a weighted mean of Binary Cross-Entropy Loss, Dice Loss and Inverse Dice Loss was chosen, with weights 1, 1 and 1.5.

The model with the best validation score was selected; it was trained for 100 epochs with the Adam optimizer, a batch size of 5, and a scheduled learning rate.
The cross-validation scores were Mean IoU: 0.7863257 and Mean Dice: 0.8781220.

## Results
The results over the test data were Mean IoU: 0.7461923 and Mean Dice: 0.8500786.
![](https://lh3.googleusercontent.com/02cq0q1Jj1AI_U3laEazAQ8wdISfC_mvDsFzU369v0oW1ByDGkEKUtyadnIT0Es7NTeIXGEb6NKb "30")
![](https://lh3.googleusercontent.com/7sDIuCWYfxR7u-cj_LykVV-fWF7Ql8jE9G443Mv5OkpnBlIksT3_xlK0vPb03flWwFRKgtDOdPLe "17")
![](https://lh3.googleusercontent.com/_hjqTg4fmaYyDLlOCcHI3Ppc7F-qgkGVZ1pSC-XFRgLyAaRYb4w18cSwxrSXuyCP72JQ_V3nbidq "32")

## References
[[1]](https://arxiv.org/abs/1702.08502) Panqu Wang, Pengfei Chen, Ye Yuan, Ding Liu, Zehua Huang, Xiaodi Hou, and Garrison W. Cottrell. Understanding convolution for semantic segmentation. CoRR, abs/1702.08502, 2017.

[[2]](https://arxiv.org/abs/1512.03385) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. CoRR, abs/1512.03385, 2015.

[[3]](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. "ImageNet classification with deep convolutional neural networks." In Advances in Neural Information Processing Systems, pp. 1097-1105, 2012.

[[4]](https://arxiv.org/abs/1409.4842) Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going deeper with convolutions." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1-9, 2015.

[[5]](https://arxiv.org/abs/1502.03167) Sergey Ioffe and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." In International Conference on Machine Learning, pp. 448-456, 2015.

[[6]](http://jmlr.org/papers/v15/srivastava14a.html) Nitish Srivastava, Geoffrey E. Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. "Dropout: a simple way to prevent neural networks from overfitting." Journal of Machine Learning Research 15, no. 1 (2014): 1929-1958.

[[7]](https://arxiv.org/abs/1409.1556) Karen Simonyan and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv:1409.1556, 2014.

[8] Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.

[[9]](http://www.deeplearningbook.org) Ian Goodfellow, Yoshua Bengio, and Aaron Courville. Deep Learning. MIT Press, 2016.

[[10]](https://www.tensorflow.org/) TensorFlow

[[11]](https://keras.io/) Keras

[[12]](https://pytorch.org/docs/stable/index.html) PyTorch

Return to the main [README](https://github.com/medal-iitb/LungSegmentation).

--------------------------------------------------------------------------------
/HDC_DUC/code/augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import cv2
4 | from tqdm import tqdm
5 | import os
6 | import sys
7 | from PIL import Image
8 | 
9 | def sp_noise(image,var):  # note: despite its name, this adds Gaussian noise
10 |     row,col,ch = image.shape
11 |     mean = 0
12 |     #var = 0.1
13 |     sigma = var**0.5
14 |     gauss = np.random.normal(mean,sigma,(row,col,ch))
15 |     gauss = gauss.reshape(row,col,ch)
16 |     noisy = image + gauss
17 |     return noisy
18 | '''
19 | output = np.zeros(image.size,np.uint8)
20 | thres = 1 - prob
21 | for i in range(image.shape[0]):
22 |     for j in range(image.shape[1]):
23 |         rdn = random.random()
24 |         if rdn < prob:
25 |             output[i][j] = 0
26 |         elif rdn > thres:
27 |             output[i][j] = 255
28 |         else:
29 |             output[i][j] = image[i][j]
30 | return output
31 | '''
32 | 
33 | 
34 | # allimages = []
35 | # maskimages = []
36 | 
37 | 
38 | image_names_image_test = os.listdir('Image/Test/Images')
39 | image_names_image_train = os.listdir('Image/Train/Images')
40 | image_names_manualmask_left = os.listdir('ManualMask/leftMask')
41 | image_names_manualmask_right = os.listdir('ManualMask/rightMask')
42 | image_names_mask_test = os.listdir('Mask/Test/Images')
43 | image_names_mask_train = os.listdir('Mask/Train/Images')
44 | 
45 | for images in tqdm(image_names_image_test):
46 |     #print(images)
47 |     im = cv2.imread('Image/Test/Images/'+images)  # cv2.imread returns a 3-channel BGR array by default
48 |     #print(type(im))
49 |     noise_img = sp_noise(im,0.1)
50 |     cv2.imwrite('Image/Test/Images/'+images.split('.')[0] + '_gauss.png', noise_img)
51 | 
52 | for images in tqdm(image_names_image_train):
53 |     #print(images)
54 |     im = cv2.imread('Image/Train/Images/'+images)  # cv2.imread returns a 3-channel BGR array by default
55 |     #print(type(im))
56 |     noise_img = sp_noise(im,0.1)
57 | cv2.imwrite('Image/Train/Images/'+images.split('.')[0] + '_gauss.png', noise_img) 58 | 59 | for images in tqdm(image_names_manualmask_left): 60 | #print(images) 61 | im = cv2.imread('ManualMask/leftMask/'+images) # Only for grayscale image 62 | #print(type(im)) 63 | noise_img = sp_noise(im,0.0) 64 | cv2.imwrite('ManualMask/leftMask/'+images.split('.')[0] + '_gauss.png', noise_img) 65 | 66 | for images in tqdm(image_names_manualmask_right): 67 | #print(images) 68 | im = cv2.imread('ManualMask/rightMask/'+images) # Only for grayscale image 69 | #print(type(im)) 70 | noise_img = sp_noise(im,0.0) 71 | cv2.imwrite('ManualMask/rightMask/'+images.split('.')[0] + '_gauss.png', noise_img) 72 | 73 | for images in tqdm(image_names_mask_test): 74 | #print(images) 75 | im = cv2.imread('Mask/Test/Images/'+images) # Only for grayscale image 76 | #print(type(im)) 77 | noise_img = sp_noise(im,0.0) 78 | cv2.imwrite('Mask/Test/Images/'+images.split('.')[0] + '_gauss.png', noise_img) 79 | 80 | for images in tqdm(image_names_mask_train): 81 | #print(images) 82 | im = cv2.imread('Mask/Train/Images/'+images) # Only for grayscale image 83 | #print(type(im)) 84 | noise_img = sp_noise(im,0.0) 85 | cv2.imwrite('Mask/Train/Images/'+images.split('.')[0] + '_gauss.png', noise_img) 86 | -------------------------------------------------------------------------------- /HDC_DUC/code/build_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function, division 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data import sampler 9 | import torchvision 10 | from torchvision import models 11 | from torchvision import datasets, transforms 12 | 13 | 14 | #import torchvision.datasets as dset 15 | #import torchvision.transforms as T 16 | import torch.nn.functional as F 17 | 18 | import numpy as np 19 | import matplotlib as mpl 20 | mpl.use('Agg') 21 | import matplotlib.pyplot as plt 22 | import time 23 | import os 24 | import copy 25 | 26 | 27 | class DUC(nn.Module): 28 | #d=downsample_factor, L=num_of_classes 29 | def __init__(self, in_channels, d, L ): 30 | super(DUC, self).__init__() 31 | out_channels = (d**2)*L 32 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3,3)) 33 | self.BN = nn.BatchNorm2d(out_channels, affine = False) #Should affine be True only? 
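        # (Clarifying the question above: affine=False disables BatchNorm's
        # learned per-channel scale and shift, leaving pure normalization;
        # nn.BatchNorm2d defaults to affine=True. Both settings run fine here.)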
34 | self.pixel_shuffle = nn.PixelShuffle(d) 35 | 36 | def forward(self,x): 37 | x = self.conv(x) 38 | x = self.BN(x) 39 | x = F.relu(x) 40 | x = self.pixel_shuffle(x) 41 | return x 42 | 43 | class ResNetwithHDCDUC(nn.Module): 44 | def __init__(self, L, pretrained=True): 45 | super(ResNetwithHDCDUC, self).__init__() 46 | model = torchvision.models.resnet101(pretrained=True) 47 | self.res1 = nn.Sequential(*list(model.children())[0:3]) 48 | self.res2 = nn.Sequential(*list(model.children())[4]) 49 | self.res3 = nn.Sequential(*list(model.children())[5]) 50 | self.res4 = nn.Sequential(*list(model.children())[6]) 51 | self.res5 = nn.Sequential(*list(model.children())[7]) 52 | self.avg_pool = list(model.children())[8] 53 | #self.avg_pool = nn.Sequential(*list(model.children())[8]) 54 | self.max_pool = nn.MaxPool2d(kernel_size=(2,2), stride=1) 55 | 56 | layer4_group_config = [1, 2, 5, 9] 57 | for i in range(len(self.res4)): 58 | self.res4[i].conv2.dilation = (layer4_group_config[i % 4], layer4_group_config[i % 4]) 59 | self.res4[i].conv2.padding = (layer4_group_config[i % 4], layer4_group_config[i % 4]) 60 | layer5_group_config = [5, 9, 17] 61 | for i in range(len(self.res5)): 62 | self.res5[i].conv2.dilation = (layer5_group_config[i], layer5_group_config[i]) 63 | self.res5[i].conv2.padding = (layer5_group_config[i], layer5_group_config[i]) 64 | 65 | 66 | in_channels = 2048 67 | d = 32 68 | self.duc_func = DUC(in_channels, d, L=L) 69 | 70 | def forward(self, x): 71 | x1 = self.res1(x) 72 | x2 = self.res2(x1) 73 | x3 = self.res3(x2) 74 | x4 = self.res4(x3) 75 | x5 = self.res5(x4) 76 | x6 = self.avg_pool(x5) 77 | #x8 = self.max_pool(x6) 78 | #in_channels = 2048 79 | #d = float(x.shape[2]/x5.shape[2]) 80 | #duc_func = DUC(in_channels, d, 1) 81 | x7 = self.duc_func(x6) 82 | 83 | return x7 84 | -------------------------------------------------------------------------------- /HDC_DUC/code/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data.dataset import Dataset 3 | from torchvision import transforms 4 | from skimage import io, transform 5 | from PIL import Image 6 | import os 7 | import numpy as np 8 | 9 | 10 | class LungSeg(Dataset): 11 | def __init__(self, path='Image/Train/Images', transforms=None): 12 | self.path = path 13 | self.list = os.listdir(self.path) 14 | 15 | self.transforms = transforms 16 | 17 | def __getitem__(self, index): 18 | # stuff 19 | image_path = 'Image/Train/Images/' 20 | mask_path = 'Mask/Train/Images/' 21 | image = Image.open(image_path+self.list[index]) 22 | image = image.convert('RGB') 23 | mask = Image.open(mask_path+self.list[index]) 24 | mask = mask.convert('L') 25 | if self.transforms is not None: 26 | image = self.transforms(image) 27 | mask = self.transforms(mask) 28 | # If the transform variable is not empty 29 | # then it applies the operations in the transforms with the order that it is created. 30 | return (image, mask) 31 | 32 | def __len__(self): 33 | return len(self.list) # of how many data(images?) 
you have 34 | 35 | 36 | class LungSegTest(Dataset): 37 | def __init__(self, path='test_image', transforms=None): 38 | self.path = path 39 | self.list = os.listdir(self.path) 40 | 41 | self.transforms = transforms 42 | 43 | def __getitem__(self, index): 44 | # stuff 45 | image_path = 'test_image/' 46 | mask_path = 'test_mask/' 47 | image = Image.open(image_path+self.list[index]) 48 | image = image.convert('RGB') 49 | mask = Image.open(mask_path+self.list[index]) 50 | mask = mask.convert('L') 51 | if self.transforms is not None: 52 | image = self.transforms(image) 53 | mask = self.transforms(mask) 54 | # If the transform variable is not empty 55 | # then it applies the operations in the transforms with the order that it is created. 56 | return (image, mask) 57 | 58 | def __len__(self): 59 | return len(self.list) 60 | """ 61 | 62 | class LungSegTest(Dataset): 63 | def __init__(self, path='Image/Test/Images', transforms=None): 64 | self.path = path 65 | self.list = os.listdir(self.path) 66 | 67 | self.transforms = transforms 68 | 69 | def __getitem__(self, index): 70 | # stuff 71 | image_path = 'Image/Test/Images/' 72 | mask_path = 'Mask/Test/Images/' 73 | image = Image.open(image_path+self.list[index]) 74 | image = image.convert('RGB') 75 | mask = Image.open(mask_path+self.list[index]) 76 | mask = mask.convert('L') 77 | if self.transforms is not None: 78 | image = self.transforms(image) 79 | mask = self.transforms(mask) 80 | # If the transform variable is not empty 81 | # then it applies the operations in the transforms with the order that it is created. 82 | return (image, mask) 83 | 84 | def __len__(self): 85 | return len(self.list) 86 | 87 | """ -------------------------------------------------------------------------------- /HDC_DUC/code/flip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import cv2 4 | from tqdm import tqdm 5 | import os 6 | import sys 7 | from PIL import Image 8 | 9 | image_names_image_test = os.listdir('Image/Test/Images') 10 | image_names_image_train = os.listdir('Image/Train/Images') 11 | image_names_manualmask_left = os.listdir('ManualMask/leftMask') 12 | image_names_manualmask_right = os.listdir('ManualMask/rightMask') 13 | image_names_mask_test = os.listdir('Mask/Test/Images') 14 | image_names_mask_train = os.listdir('Mask/Train/Images') 15 | 16 | # maskimages.append(image_names_manualmask_right) 17 | # maskimages.append(image_names_manualmask_left) 18 | # maskimages.append(image_names_mask_test) 19 | # maskimages.append(image_names_mask_train) 20 | 21 | # allimages.append(image_names_image_train) 22 | # allimages.append(image_names_image_test) 23 | 24 | for images in tqdm(image_names_image_train): 25 | img = cv2.imread('Image/Train/Images/'+images) # Only for grayscale image 26 | horizontal_img = img.copy() 27 | vertical_img = img.copy() 28 | horizontal_img = cv2.flip( img, 0 ) 29 | vertical_img = cv2.flip( img, 1 ) 30 | cv2.imwrite('Image/Train/Images/'+images.split('.')[0] + '_horizontal.png', horizontal_img) 31 | cv2.imwrite('Image/Train/Images/'+images.split('.')[0] + '_vertical.png', vertical_img) 32 | 33 | for images in tqdm(image_names_image_test): 34 | img = cv2.imread('Image/Test/Images/'+images) # Only for grayscale image 35 | horizontal_img = img.copy() 36 | vertical_img = img.copy() 37 | horizontal_img = cv2.flip( img, 0 ) 38 | vertical_img = cv2.flip( img, 1 ) 39 | cv2.imwrite('Image/Test/Images/'+images.split('.')[0] + '_horizontal.png', horizontal_img) 40 | 
cv2.imwrite('Image/Test/Images/'+images.split('.')[0] + '_vertical.png', vertical_img) 41 | 42 | for images in tqdm(image_names_manualmask_left): 43 | img = cv2.imread('ManualMask/leftMask/'+images) # Only for grayscale image 44 | horizontal_img = img.copy() 45 | vertical_img = img.copy() 46 | horizontal_img = cv2.flip( img, 0 ) 47 | vertical_img = cv2.flip( img, 1 ) 48 | cv2.imwrite('ManualMask/leftMask/'+images.split('.')[0] + '_horizontal.png', horizontal_img) 49 | cv2.imwrite('ManualMask/leftMask/'+images.split('.')[0] + '_vertical.png', vertical_img) 50 | 51 | for images in tqdm(image_names_manualmask_right): 52 | img = cv2.imread('ManualMask/rightMask/'+images) # Only for grayscale image 53 | horizontal_img = img.copy() 54 | vertical_img = img.copy() 55 | horizontal_img = cv2.flip( img, 0 ) 56 | vertical_img = cv2.flip( img, 1 ) 57 | cv2.imwrite('ManualMask/rightMask/'+images.split('.')[0] + '_horizontal.png', horizontal_img) 58 | cv2.imwrite('ManualMask/rightMask/'+images.split('.')[0] + '_vertical.png', vertical_img) 59 | 60 | for images in tqdm(image_names_mask_test): 61 | img = cv2.imread('Mask/Test/Images/'+images) # Only for grayscale image 62 | horizontal_img = img.copy() 63 | vertical_img = img.copy() 64 | horizontal_img = cv2.flip( img, 0 ) 65 | vertical_img = cv2.flip( img, 1 ) 66 | cv2.imwrite('Mask/Test/Images/'+images.split('.')[0] + '_horizontal.png', horizontal_img) 67 | cv2.imwrite('Mask/Test/Images/'+images.split('.')[0] + '_vertical.png', vertical_img) 68 | 69 | for images in tqdm(image_names_mask_train): 70 | img = cv2.imread('Mask/Train/Images/'+images) # Only for grayscale image 71 | horizontal_img = img.copy() 72 | vertical_img = img.copy() 73 | horizontal_img = cv2.flip( img, 0 ) 74 | vertical_img = cv2.flip( img, 1 ) 75 | cv2.imwrite('Mask/Train/Images/'+images.split('.')[0] + '_horizontal.png', horizontal_img) 76 | cv2.imwrite('Mask/Train/Images/'+images.split('.')[0] + '_vertical.png', vertical_img) 77 | -------------------------------------------------------------------------------- /HDC_DUC/code/inferences.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from build_model import ResNetwithHDCDUC 3 | from torch.utils.data import DataLoader 4 | from data_loader import LungSegTest 5 | from torchvision import transforms 6 | import torch.nn.functional as F 7 | import numpy as np 8 | 9 | from skimage import morphology, color, io, exposure 10 | 11 | def IoU(y_true, y_pred): 12 | """Returns Intersection over Union score for ground truth and predicted masks.""" 13 | assert y_true.dtype == bool and y_pred.dtype == bool 14 | y_true_f = y_true.flatten() 15 | y_pred_f = y_pred.flatten() 16 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 17 | union = np.logical_or(y_true_f, y_pred_f).sum() 18 | return (intersection + 1) * 1. / (union + 1) 19 | 20 | def Dice(y_true, y_pred): 21 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 22 | assert y_true.dtype == bool and y_pred.dtype == bool 23 | y_true_f = y_true.flatten() 24 | y_pred_f = y_pred.flatten() 25 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 26 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 
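# The +1. smoothing terms in IoU/Dice above keep the ratios defined (avoiding
# 0/0) for images where neither the prediction nor the ground truth contains
# any positive pixels.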
27 | 28 | def masked(img, gt, mask, alpha=1): 29 | """Returns image with GT lung field outlined with red, predicted lung field 30 | filled with blue.""" 31 | rows, cols = img.shape[:2] 32 | color_mask = np.zeros((rows, cols, 3)) 33 | boundary = morphology.dilation(gt, morphology.disk(3)) ^ gt 34 | color_mask[mask == 1] = [0, 0, 1] 35 | color_mask[boundary == 1] = [1, 0, 0] 36 | 37 | img_hsv = color.rgb2hsv(img) 38 | color_mask_hsv = color.rgb2hsv(color_mask) 39 | 40 | img_hsv[..., 0] = color_mask_hsv[..., 0] 41 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 42 | 43 | img_masked = color.hsv2rgb(img_hsv) 44 | return img_masked 45 | 46 | def remove_small_regions(img, size): 47 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 48 | img = morphology.remove_small_objects(img, size) 49 | img = morphology.remove_small_holes(img, size) 50 | return img 51 | 52 | if __name__ == '__main__': 53 | 54 | # Path to csv-file. File should contain X-ray filenames as first column, 55 | # mask filenames as second column. 56 | # Load test data 57 | img_size = (256, 256) 58 | 59 | n_test = 61 60 | inp_shape = (256,256,3) 61 | batch_size=1 62 | 63 | # Load model 64 | #model_name = 'model.020.hdf5' 65 | #UNet = load_model(model_name) 66 | net = ResNetwithHDCDUC(1) 67 | 68 | net.load_state_dict(torch.load('Weights/cp_88_0.09929631054401397.pth.tar')) 69 | net.eval() 70 | 71 | 72 | ious = np.zeros(n_test) 73 | dices = np.zeros(n_test) 74 | seed = 1 75 | transformations_test = transforms.Compose([transforms.Resize(img_size), 76 | transforms.ToTensor()]) 77 | test_set = LungSegTest(transforms = transformations_test) 78 | test_loader = DataLoader(test_set, batch_size=batch_size) 79 | 80 | 81 | i = 0 82 | for xx, yy in test_loader: 83 | #img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 84 | pred = net(xx) 85 | pred = F.sigmoid(pred) 86 | pred = pred.detach().numpy()[0,0,:,:] 87 | mask = yy.numpy()[0,0,:,:] 88 | xx = xx.numpy()[0,:,:,:].transpose(1,2,0) 89 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 90 | 91 | # Binarize masks 92 | gt = mask > 0.5 93 | pr = pred > 0.5 94 | 95 | # Remove regions smaller than 2% of the image 96 | pr = remove_small_regions(pr, 0.02 * np.prod(img_size)) 97 | 98 | io.imsave('results/{}.png'.format(i), masked(img, gt, pr, 1)) 99 | 100 | ious[i] = IoU(gt, pr) 101 | dices[i] = Dice(gt, pr) 102 | 103 | i += 1 104 | if i == n_test: 105 | break 106 | 107 | print ('Mean IoU:', ious.mean()) 108 | print ('Mean Dice:', dices.mean()) 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /HDC_DUC/code/preprocess.py: -------------------------------------------------------------------------------- 1 | from skimage import io 2 | import os 3 | from tqdm import tqdm 4 | 5 | 6 | image_list = os.listdir('test_image') 7 | 8 | for images in tqdm(image_list): 9 | left_mask = io.imread('test_mask/left/'+images) 10 | right_mask = io.imread('test_mask/right/'+images) 11 | mask = left_mask + right_mask 12 | io.imsave('test_mask/'+images, mask) 13 | 14 | 15 | -------------------------------------------------------------------------------- /HDC_DUC/code/train_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms 6 | from torch.autograd import Variable 7 | from build_model import ResNetwithHDCDUC 8 
| import csv 9 | 10 | import matplotlib as mpl 11 | mpl.use('Agg') 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | class SoftDiceLoss(nn.Module): 16 | def __init__(self, weight=None, size_average=True): 17 | super(SoftDiceLoss, self).__init__() 18 | 19 | def forward(self, logits, targets): 20 | smooth = 1. 21 | logits = F.sigmoid(logits) 22 | iflat = logits.view(-1) 23 | tflat = targets.view(-1) 24 | intersection = (iflat * tflat).sum() 25 | 26 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 27 | 28 | class InverseSoftDiceLoss(nn.Module): 29 | def __init__(self, weight=None, size_average=True): 30 | super(InverseSoftDiceLoss, self).__init__() 31 | 32 | def forward(self, logits, targets): 33 | smooth = 1. 34 | logits = F.sigmoid(logits) 35 | iflat = 1-logits.view(-1) 36 | tflat = 1-targets.view(-1) 37 | intersection = (iflat * tflat).sum() 38 | 39 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 40 | 41 | img_size = (256,256) 42 | transformations_train = transforms.Compose([transforms.Resize(img_size), 43 | transforms.RandomRotation(10), 44 | transforms.RandomHorizontalFlip(), 45 | transforms.ToTensor()]) 46 | 47 | transformations_test = transforms.Compose([transforms.Resize(img_size), 48 | transforms.ToTensor()]) 49 | 50 | 51 | 52 | from data_loader import LungSeg 53 | from data_loader import LungSegTest 54 | train_set = LungSeg(transforms = transformations_train) 55 | test_set = LungSegTest(transforms = transformations_test) 56 | batch_size = 5 57 | num_epochs = 100 58 | 59 | class Average(object): 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.sum = 0 65 | self.count = 0 66 | 67 | def update(self, val, n=1): 68 | self.sum += val 69 | self.count += n 70 | 71 | @property 72 | def avg(self): 73 | return self.sum / self.count 74 | loss_list = [] 75 | vloss_list = [] 76 | 77 | def train(): 78 | cuda = torch.cuda.is_available() 79 | net = ResNetwithHDCDUC(1) 80 | if cuda: 81 | net = net.cuda() 82 | criterion1 = nn.BCEWithLogitsLoss().cuda() 83 | criterion2 = SoftDiceLoss().cuda() 84 | criterion3 = InverseSoftDiceLoss().cuda() 85 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) 86 | 87 | print("preparing training data ...") 88 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 89 | print("done ...") 90 | 91 | test_set = LungSegTest(transforms = transformations_test) 92 | test_loader = DataLoader(test_set, batch_size=batch_size) 93 | for epoch in range(num_epochs): 94 | train_loss = Average() 95 | net.train() 96 | 97 | if epoch<20: 98 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) 99 | elif 20<=epoch<30: 100 | optimizer = torch.optim.Adam(net.parameters(), lr=3.33e-5) 101 | elif 30<=epoch<40: 102 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-5) 103 | elif 40<=epoch<100: 104 | optimizer = torch.optim.Adam(net.parameters(), lr=3.33e-6) 105 | elif 150<=epoch<201: 106 | optimizer = torch.optim.Adam(net.parameters(), lr=1.e-6) 107 | 108 | 109 | for i, (images, masks) in enumerate(train_loader): 110 | images = Variable(images) 111 | masks = Variable(masks) 112 | if cuda: 113 | images = images.cuda() 114 | masks = masks.cuda() 115 | 116 | optimizer.zero_grad() 117 | outputs = net(images) 118 | loss = (criterion2(outputs, masks) + criterion1(outputs, masks) + 1.5*criterion3(outputs, masks))/3.5 119 | #loss_list.append(loss) 120 | 121 | loss.backward() 122 | optimizer.step() 123 | train_loss.update(loss.item(), images.size(0)) 124 | 125 | val_loss = 
Average()
        net.eval()
        for images, masks in test_loader:
            images = Variable(images)
            masks = Variable(masks)
            if cuda:
                images = images.cuda()
                masks = masks.cuda()

            outputs = net(images)
            vloss = criterion2(outputs, masks)
            #vloss_list.append(vloss)
            val_loss.update(vloss.item(), images.size(0))

        print("Epoch {}, Loss: {}, Validation Loss: {}".format(epoch+1, train_loss.avg, val_loss.avg))
        loss_list.append(train_loss.avg)
        vloss_list.append(val_loss.avg)

        with open('Log.csv', 'a') as logFile:
            FileWriter = csv.writer(logFile)
            FileWriter.writerow([epoch+1, train_loss.avg, val_loss.avg])
        torch.save(net.state_dict(), 'Weights/cp_{}_{}.pth.tar'.format(epoch+1, val_loss.avg))

    torch.save(net.state_dict(), 'cp.pth')
    """
    key = list(range(2))
    plt.plot(key, loss_list, 'r')
    plt.plot(key, vloss_list, 'b')
    plt.show()
    """
    return net, loss_list, vloss_list

def test(model):
    model.eval()


if __name__ == "__main__":
    train()
    #print("done!")
    key = list(range(num_epochs))
    plt.plot(key, loss_list, 'r')
    plt.plot(key, vloss_list, 'b')
    # save before show(); with the 'Agg' backend, show() is a no-op anyway
    plt.savefig('Losses plot')
    plt.show()

    """
    alpha0 = 1e-5
    k = 9.0/num_epochs

    optimizer = torch.optim.Adam(net.parameters(), lr=alpha0/(1+k*epoch))
    """
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
># Lung Segmentation
>by [MeDAL - IIT Bombay](https://github.com/MEDAL-IITB)

* [Introduction](#introduction)
* [Dataset](#dataset)
  * [Montgomery Dataset](#dataset)
  * [Data Preprocessing](#data-preprocessing)
* [GCN](#global-convolutional-network)
* [VGG Unet](#vgg-unet)
* [SegNet](#segnet)
* [HDC/DUC](#hybrid-dilated-convolution-and-dense-upsampling-convolution)
* [Results](#results)


## Introduction

Chest X-ray (CXR) is one of the most commonly prescribed medical imaging procedures, and such a large volume of CXR scans places a significant workload on radiologists and medical practitioners.
Organ segmentation is a crucial step towards effective computer-aided detection on CXR.
Future applications include:
1. Detecting abnormal shapes/sizes of the lungs
   - cardiomegaly (enlargement of the heart), pneumothorax (lung collapse), pleural effusion, and emphysema

2. An initial (preprocessing) step for deeper analysis - e.g. tumor detection

In this work, we demonstrate the effectiveness of Fully Convolutional Networks (FCN) in segmenting lung fields in CXR images.
The FCN consists primarily of an encoder and a decoder network that produces a segmentation mask for the CXR. During training, the network learns to generate a mask which can then be used to segment the organ. Via supervised learning, the FCN learns the higher-order structures needed to achieve realistic segmentation outcomes.

## Dataset
These architectures segment the lungs from a chest radiograph (colloquially known as a chest X-ray, CXR). The dataset is the [Montgomery County X-Ray Set](https://ceb.nlm.nih.gov/repositories/tuberculosis-chest-x-ray-image-data-sets/), which contains 138 posterior-anterior x-rays. The motivation is that this information can then be used to detect chest abnormalities like shrunken lungs or other structural deformities, which is especially useful for detecting tuberculosis.

### Data Preprocessing
The x-rays are 4892x4020 pixels in size. Due to GPU memory limitations, they are resized to 1024x1024 (GCN) or 256x256 (the other models).

The dataset is augmented by randomly rotating and flipping the images, and by adding Gaussian noise.
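For illustration, a minimal `torchvision` sketch of such a pipeline (the exact transforms live in each model's `code/` scripts; the noise step is written here as an illustrative tensor-space `Lambda`, with std 0.1 matching the variance of 0.01 used by the SegNet scripts):

    import torch
    from torchvision import transforms

    img_size = (256, 256)  # 1024x1024 for GCN

    augment = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomRotation(10),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Gaussian noise with variance 0.01 (std 0.1), clamped back to [0, 1]
        transforms.Lambda(lambda x: torch.clamp(x + 0.1 * torch.randn_like(x), 0., 1.)),
    ])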
## Flow Chart
![flowchart](https://lh3.googleusercontent.com/4jhBbczKqk8j4k2NyvMzljuzpdZYUMqZHpiT4OSQ4F0Z_-yvZAfNCfC1ge6wvg-BI-MAwXGKQzjD "flowchart")

## Models

### Global Convolutional Network
For details, go [here](https://github.com/MEDAL-IITB/Lung-Segmentation/tree/master/GCN/).

### VGG Unet
For details, go [here](https://github.com/MEDAL-IITB/Lung-Segmentation/tree/master/VGG_UNet/).

### SegNet
For details, go [here](https://github.com/MEDAL-IITB/Lung-Segmentation/tree/master/SegNet/).

### Hybrid Dilated Convolution and Dense Upsampling Convolution
For details, go [here](https://github.com/MEDAL-IITB/Lung-Segmentation/tree/master/HDC_DUC/).

## Results
A few results of the various models are displayed below (all scores are means).

| Model    | Dice Score | IoU    |
| -------- | ---------- | ------ |
| VGG UNet | 0.9623     | 0.9295 |
| SegNet   | 0.9293     | 0.8731 |
| GCN      | 0.907      | 0.8314 |
| HDC/DUC  | 0.8501     | 0.7462 |

**U-Net Result**
![UNet Results](https://lh3.googleusercontent.com/ku0vzfGUgolooGUqcYm6haipYcm_QLA33aw-ywOatslqHRX2cbat54HQsCRyX-xDpy2zkX2DuVx4 "UNet Results")

**SegNet Result**
![SegNet](https://lh3.googleusercontent.com/2SueAM5xuMZJ99UwSgW1-Ne4mRC9-WsXt7NyCZ0mMYh3wP9QlFPt_uFd80cIpqzmtBZEzXB5vGDu "SegNet")

**HDC/DUC Result**
![hdc](https://lh3.googleusercontent.com/emerB7tePI6Cw90KCaHhqtPj_26Uo7R1z2yafjwlNeKgfIk2m1saP9ybWm2ChB09LiyYOCXUY9a6 "hdc")

**GCN Result**
![gcn4](https://lh3.googleusercontent.com/VxAr3JeDDNO1yocRDYmxwqcHdjCcg1lOZraIHz7XDSXy4YVU6U3TExnEdJeWdfAOEExQiWstoQh8 "gcn4")
--------------------------------------------------------------------------------
/SegNet/README.md:
--------------------------------------------------------------------------------
# Introduction
This report describes the usage of the **SegNet** and **U-Net** architectures for **medical image segmentation**.

We divide the article into the following parts:

- [Dataset](#dataset)
- [SegNet](#segnet)
- [U-Net](#u-net)
- [Loss Function Used](#loss-function-used)
- Results
- References
- Further Help

# Dataset
## Montgomery Dataset

The dataset contains chest X-ray images, which we use to perform lung segmentation.
>The dataset can be found [here](http://openi.nlm.nih.gov/imgs/collections/NLM-MontgomeryCXRSet.zip)

Structure:

We use the following structure for the given dataset:

![](https://lh4.googleusercontent.com/J-kHm2BX9ywKISMuY_BaCaFf--UuPJOKlFYLO89gYgvjmqlM9RrFive2wOU30X8N7bzI03uwMCtnb_oCHDPaobyxTMEFlfsTSNXALS629uuAkSUZfm9y-lUv5FORquPe1P8CPp4p)

## Data Preprocessing
![](https://lh4.googleusercontent.com/TTLhU_8UxfxPWPURwLsqNbJu09EfPTReyCXHH9mX7saLzfK6aLgxK_NQd1VNeL7u1acwVnppg2pOZeLO9S4hoxpxjRSoXUHRlK8OAo6peOHpvv_zzTv2g43Wy4HMmk_i-aoATdEG)
Figure 1

The Montgomery dataset contains images from the Department of Health and Human Services, Montgomery County, Maryland, USA. It consists of 138 CXRs: 80 normal patients and 58 patients with manifest tuberculosis (TB). The CXR images are 12-bit gray-scale images of dimension 4020 × 4892 or 4892 × 4020. Only the two separate lung-mask annotations are available; these were combined into a single mask to make the segmentation task easier for the network to learn (Fig. 1). To make all images square, we padded each image up to its larger dimension (4892 x 4892); this preserves the aspect ratio of the CXR during the subsequent resizing. We then scale all images to 1024 x 1024 pixels, which retains sufficient visual detail for the vascular structures in the lung fields and was the largest size that could be accommodated on the GPU alongside U-Net. All pixel values are scaled to 0-1. Data augmentation was applied by flipping around the vertical axis and adding Gaussian noise with mean 0 and variance 0.01. Rotations about the centre by subtle angles of 5-10 degrees were also performed at runtime to make the model more robust.
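A minimal sketch of the padding step (mirroring `code/preprocess_padding.py`, but splitting odd differences explicitly so the output is exactly square; the script approximates this with `row//2` on both sides):

    import numpy as np

    def pad_to_square(img):
        """Zero-pad a 2-D image to max(H, W) x max(H, W), preserving aspect ratio."""
        side = max(img.shape[:2])
        rows, cols = side - img.shape[0], side - img.shape[1]
        return np.pad(img, ((rows // 2, rows - rows // 2),
                            (cols // 2, cols - cols // 2)), mode='constant')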
# SegNet

### Introduction
![SegNet](http://mi.eng.cam.ac.uk/projects/segnet/images/segnet.png)

SegNet has an encoder network and a corresponding decoder network, followed by a final pixelwise classification layer. This architecture is illustrated in the figure above. The encoder network consists of 13 convolutional layers which correspond to the first 13 convolutional layers in the VGG16 network designed for object classification. We can therefore initialize the training process from weights trained for classification on large datasets. We can also discard the fully connected layers in favour of retaining higher resolution feature maps at the deepest encoder output. This also reduces the number of parameters in the SegNet encoder network significantly (from 134M to 14.7M) as compared to other recent architectures.

Each encoder layer has a corresponding decoder layer, so the decoder network also has 13 layers. The final decoder output is fed to a multi-class soft-max classifier (or, for a binary classification task, to a sigmoid activation function) to produce class probabilities for each pixel independently. Each encoder in the encoder network performs convolution with a filter bank to produce a set of feature maps. These are then batch normalized, and an element-wise rectified-linear non-linearity (ReLU), max(0, x), is applied. Following that, max-pooling with a 2 × 2 window and stride 2 (non-overlapping windows) is performed and the resulting output is sub-sampled by a factor of 2. Max-pooling is used to achieve translation invariance over small spatial shifts in the input image. Sub-sampling results in a large input image context (spatial window) for each pixel in the feature map. While several layers of max-pooling and sub-sampling can achieve more translation invariance for robust classification, there is correspondingly a loss of spatial resolution in the feature maps. This increasingly lossy (boundary detail) image representation is not beneficial for segmentation, where boundary delineation is vital.
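SegNet counters this by storing the max-pooling indices in the encoder and reusing them in the decoder, as done in `code/build_model.py`; a minimal, self-contained sketch of that pairing:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 64, 8, 8)  # a feature map at some encoder stage
    pooled, idx = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
    # ... decoder computation on `pooled` ...
    unpooled = F.max_unpool2d(pooled, idx, kernel_size=2, stride=2)
    assert unpooled.shape == x.shape  # detail restored at the original max locations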
# Loss Function Used
We use three loss functions here, viz. `Binary Cross Entropy`, `Dice loss`, and `Inverted Dice loss`.

#### Binary Cross Entropy Loss
Cross-entropy loss, or log loss, measures the performance of a classification model whose output is a probability value between 0 and 1. Cross-entropy loss increases as the predicted probability diverges from the actual label. So predicting a probability of `.012` when the actual observation label is `1` would be bad and result in a **high loss** value. A perfect model would have a log loss of `0`.

In binary classification, where the number of classes $M$ equals 2, it is defined mathematically as:

$$
-\frac{1}{N}\sum_{i=1}^N(y_{i}\log(p_{i}) + (1-y_{i})\log(1-p_{i}))
$$

#### Dice Coefficient Loss
The dice coefficient loss measures the overlap between the predicted and target masks, closely related to the `intersection over union`.

Mathematically, the Dice score is
$$\frac{2 |P \cap R|}{|P| + |R|}$$

and the corresponding loss is
$$1-\frac{2 |P\cap R|}{|P| + |R|}$$

$$1- \frac{2\sum_{i=1}^{N}p_{i}r_{i}+\epsilon}{\sum_{i=1}^{N}p_{i}+ \sum_{i=1}^{N}r_{i}+\epsilon}\,,\quad p_{i} \in P,\; r_{i} \in R$$

The dice loss is defined in code as:


    class SoftDiceLoss(nn.Module):
        def __init__(self, weight=None, size_average=True):
            super(SoftDiceLoss, self).__init__()

        def forward(self, logits, targets):
            smooth = 1
            num = targets.size(0)
            probs = F.sigmoid(logits)
            m1 = probs.view(num, -1)
            m2 = targets.view(num, -1)
            intersection = (m1 * m2)

            score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
            score = 1 - score.sum() / num
            return score

#### Inverted Dice Coefficient Loss
The formula below measures the overlap after inverting the masks, i.e. taking their complements.

Mathematically, the inverted Dice score is
$$\frac{2|\overline{P}\cap\overline{R}|}{|\overline{P}| +|\overline{R}| }$$
and the corresponding loss is
$$1-\frac{2|\overline{P}\cap\overline{R}|}{|\overline{P}| +|\overline{R}| }$$
$$1- \frac{2\sum_{i=1}^{N}(1-p_{i})(1-r_{i})+\epsilon}{\sum_{i=1}^{N}(1-p_{i})+ \sum_{i=1}^{N}(1-r_{i})+\epsilon}\,,\quad p_{i} \in P,\; r_{i} \in R$$


    class SoftInvDiceLoss(nn.Module):
        def __init__(self, weight=None, size_average=True):
            # note: super() must name this class, not SoftDiceLoss
            super(SoftInvDiceLoss, self).__init__()

        def forward(self, logits, targets):
            smooth = 1
            num = targets.size(0)
            probs = F.sigmoid(logits)
            m1 = probs.view(num, -1)
            m2 = targets.view(num, -1)
            m1, m2 = 1.-m1, 1.-m2
            intersection = (m1 * m2)

            score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
            score = 1 - score.sum() / num
            return score


> NOTE: Intersection is implemented as a multiplication and the cardinality as a `sum()` over axis 1 (the flattened per-sample values) because the predictions and targets are one-hot encoded vectors.
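In training (see `code/train_model_segnet.py`) the three criteria are summed with equal weight; a minimal sketch of the combined objective, assuming the loss classes defined above:

    import torch.nn as nn

    bce = nn.BCEWithLogitsLoss()   # fuses sigmoid + BCE for numerical stability
    dcl = SoftDiceLoss()           # defined above
    idcl = SoftInvDiceLoss()       # defined above

    def combined_loss(logits, targets):
        # BCE + DCL + IDCL, the best-performing combination in the results below
        return bce(logits, targets) + dcl(logits, targets) + idcl(logits, targets)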
# Results

| Loss         | Val mIoU | Val mDice | Test mIoU | Test mDice |
|:------------:|:--------:|:---------:|:---------:|:----------:|
| BCE          | 0.8867   | 0.9396    | -         | -          |
| BCE+DCL      | 0.9011   | 0.9477    | -         | -          |
| BCE+DCL+IDCL | 0.9234   | 0.9600    | 0.8731    | 0.9293     |

The results with this network are good; some of the best ones are shown here:
![results](https://lh3.googleusercontent.com/IYCDXq8yFP5drY8miu-IYgBUE9HfKWDTJxLxBlPdYXPCQtF4Gwrhk-EnKoRftNpE-Z4paZ90VFBA "results")


# References
[[1]](https://arxiv.org/abs/1511.00561) Vijay Badrinarayanan, Alex Kendall, and Roberto Cipolla. "Segnet: A deep convolutional encoder-decoder architecture for image segmentation". CoRR, abs/1511.00561, 2015

[[2]](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. ”Imagenet classification with deep convolutional neural networks.” In Advances in neural information processing systems, pp. 1097-1105. 2012

[[3]](https://arxiv.org/abs/1409.4842) Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. ”Going deeper with convolutions.” In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1-9. 2015

[[4]](https://arxiv.org/abs/1502.03167) Ioffe, Sergey, and Christian Szegedy. ”Batch normalization: Accelerating deep network training by reducing internal covariate shift.” In International Conference on Machine Learning, pp. 448-456. 2015

[[5]](http://jmlr.org/papers/v15/srivastava14a.html) Srivastava, Nitish, Geoffrey E. Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. ”Dropout: a simple way to prevent neural networks from overfitting.” Journal of Machine Learning Research 15, no. 1 (2014): 1929-1958

[[6]](https://arxiv.org/abs/1409.1556) Simonyan, Karen, and Andrew Zisserman. ”Very deep convolutional networks for large-scale image recognition.” arXiv preprint arXiv:1409.1556, 2014

[7] Bishop, Christopher M. Pattern recognition and machine learning. Springer, 2006

[[8]](http://www.deeplearningbook.org) Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep learning. MIT Press, 2016

[[9]](https://www.tensorflow.org/) TensorFlow

[[10]](https://keras.io/) Keras

[[11]](https://pytorch.org/docs/stable/index.html) PyTorch

Return to main [README](www.github.com/medal-iitb/LungSegmentation/README.md).
159 | -------------------------------------------------------------------------------- /SegNet/code/build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | class SegNet(nn.Module): 7 | def __init__(self,input_nbr,label_nbr): 8 | super(SegNet, self).__init__() 9 | 10 | 11 | self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1) 12 | self.bn11 = nn.BatchNorm2d(64) 13 | self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1) 14 | self.bn12 = nn.BatchNorm2d(64) 15 | 16 | self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 17 | self.bn21 = nn.BatchNorm2d(128) 18 | self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 19 | self.bn22 = nn.BatchNorm2d(128) 20 | 21 | self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 22 | self.bn31 = nn.BatchNorm2d(256) 23 | self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 24 | self.bn32 = nn.BatchNorm2d(256) 25 | self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 26 | self.bn33 = nn.BatchNorm2d(256) 27 | 28 | self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 29 | self.bn41 = nn.BatchNorm2d(512) 30 | self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 31 | self.bn42 = nn.BatchNorm2d(512) 32 | self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 33 | self.bn43 = nn.BatchNorm2d(512) 34 | 35 | self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 36 | self.bn51 = nn.BatchNorm2d(512) 37 | self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 38 | self.bn52 = nn.BatchNorm2d(512) 39 | self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 40 | self.bn53 = nn.BatchNorm2d(512) 41 | 42 | self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 43 | self.bn53d = nn.BatchNorm2d(512) 44 | self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 45 | self.bn52d = nn.BatchNorm2d(512) 46 | self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 47 | self.bn51d = nn.BatchNorm2d(512) 48 | 49 | self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 50 | self.bn43d = nn.BatchNorm2d(512) 51 | self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1) 52 | self.bn42d = nn.BatchNorm2d(512) 53 | self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1) 54 | self.bn41d = nn.BatchNorm2d(256) 55 | 56 | self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 57 | self.bn33d = nn.BatchNorm2d(256) 58 | self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1) 59 | self.bn32d = nn.BatchNorm2d(256) 60 | self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1) 61 | self.bn31d = nn.BatchNorm2d(128) 62 | 63 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1) 64 | self.bn22d = nn.BatchNorm2d(128) 65 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1) 66 | self.bn21d = nn.BatchNorm2d(64) 67 | 68 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1) 69 | self.bn12d = nn.BatchNorm2d(64) 70 | self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1) 71 | self.Dropout = nn.Dropout(0.5) 72 | 73 | 74 | def forward(self, x): 75 | 76 | # Stage 1 77 | x11 = F.relu(self.bn11(self.conv11(x))) 78 | x11 = self.Dropout(x11) 79 | x12 = F.relu(self.bn12(self.conv12(x11))) 80 | x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True) 81 | 82 | # Stage 2 83 | x21 = F.relu(self.bn21(self.conv21(x1p))) 84 | x22 = F.relu(self.bn22(self.conv22(x21))) 85 | x2p, id2 = 
F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True) 86 | 87 | # Stage 3 88 | x31 = F.relu(self.bn31(self.conv31(x2p))) 89 | x31 = self.Dropout(x31) 90 | x32 = F.relu(self.bn32(self.conv32(x31))) 91 | x33 = F.relu(self.bn33(self.conv33(x32))) 92 | x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True) 93 | 94 | # Stage 4 95 | x41 = F.relu(self.bn41(self.conv41(x3p))) 96 | x42 = F.relu(self.bn42(self.conv42(x41))) 97 | x43 = F.relu(self.bn43(self.conv43(x42))) 98 | x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True) 99 | 100 | # Stage 5 101 | x51 = F.relu(self.bn51(self.conv51(x4p))) 102 | x51 = self.Dropout(x51) 103 | x52 = F.relu(self.bn52(self.conv52(x51))) 104 | x53 = F.relu(self.bn53(self.conv53(x52))) 105 | x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True) 106 | 107 | 108 | # Stage 5d 109 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2) 110 | x53d = F.relu(self.bn53d(self.conv53d(x5d))) 111 | x52d = F.relu(self.bn52d(self.conv52d(x53d))) 112 | x51d = F.relu(self.bn51d(self.conv51d(x52d))) 113 | 114 | # Stage 4d 115 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2) 116 | x43d = F.relu(self.bn43d(self.conv43d(x4d))) 117 | x42d = F.relu(self.bn42d(self.conv42d(x43d))) 118 | x41d = F.relu(self.bn41d(self.conv41d(x42d))) 119 | 120 | # Stage 3d 121 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2) 122 | x33d = F.relu(self.bn33d(self.conv33d(x3d))) 123 | x32d = F.relu(self.bn32d(self.conv32d(x33d))) 124 | x31d = F.relu(self.bn31d(self.conv31d(x32d))) 125 | 126 | # Stage 2d 127 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2) 128 | x22d = F.relu(self.bn22d(self.conv22d(x2d))) 129 | x21d = F.relu(self.bn21d(self.conv21d(x22d))) 130 | 131 | # Stage 1d 132 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2) 133 | x12d = F.relu(self.bn12d(self.conv12d(x1d))) 134 | x11d = self.conv11d(x12d) 135 | 136 | return x11d 137 | 138 | 139 | import torch 140 | import torch.nn as nn 141 | import torch.nn.functional as F 142 | 143 | from unet_parts import * 144 | 145 | 146 | class UNet(nn.Module): 147 | def __init__(self, n_channels, n_classes): 148 | super(UNet, self).__init__() 149 | self.inc = inconv(n_channels, 64) 150 | self.down1 = down(64, 128) 151 | self.down2 = down(128, 256) 152 | self.down3 = down(256, 512) 153 | self.down4 = down(512, 512) 154 | self.up1 = up(1024, 256) 155 | self.up2 = up(512, 128) 156 | self.up3 = up(256, 64) 157 | self.up4 = up(128, 64) 158 | self.outc = outconv(64, n_classes) 159 | self.Dropout = nn.Dropout(0.5) 160 | 161 | def forward(self, x): 162 | x1 = self.inc(x) 163 | x2 = self.down1(x1) 164 | x3 = self.down2(x2) 165 | x3 = self.Dropout(x3) 166 | x4 = self.down3(x3) 167 | x5 = self.down4(x4) 168 | x = self.up1(x5, x4) 169 | x = self.up2(x, x3) 170 | x = self.up3(x, x2) 171 | x = self.up4(x, x1) 172 | x = self.outc(x) 173 | return x 174 | -------------------------------------------------------------------------------- /SegNet/code/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data.dataset import Dataset 3 | from torchvision import transforms 4 | from skimage import io, transform 5 | from PIL import Image 6 | import os 7 | import numpy as np 8 | 9 | 10 | class LungSegTrain(Dataset): 11 | def __init__(self, path='Images_padded1/Train/', transforms=None): 12 | self.path = path 13 | self.list = os.listdir(self.path) 14 | 15 | self.transforms = transforms 16 | 17 | def __getitem__(self, 
index): 18 | # stuff 19 | image_path = 'Images_padded1/Train/' 20 | mask_path = 'Mask_padded1/Train/' 21 | image = Image.open(image_path+self.list[index]) 22 | image = image.convert('RGB') 23 | mask = Image.open(mask_path+self.list[index]) 24 | mask = mask.convert('L') 25 | if self.transforms is not None: 26 | image = self.transforms(image) 27 | mask = self.transforms(mask) 28 | # If the transform variable is not empty 29 | # then it applies the operations in the transforms with the order that it is created. 30 | return (image, mask) 31 | 32 | def __len__(self): 33 | return len(self.list) # of how many data(images?) you have 34 | 35 | 36 | class LungSegVal(Dataset): 37 | def __init__(self, path='Images_padded_test/', transforms=None): 38 | self.path = path 39 | self.list = os.listdir(self.path) 40 | 41 | self.transforms = transforms 42 | 43 | def __getitem__(self, index): 44 | # stuff 45 | image_path = 'Images_padded_test/' 46 | mask_path = 'Mask_padded_test/' 47 | image_name = self.list[index] 48 | image = Image.open(image_path+self.list[index]) 49 | image = image.convert('RGB') 50 | mask = Image.open(mask_path+self.list[index]) 51 | mask = mask.convert('L') 52 | if self.transforms is not None: 53 | image = self.transforms(image) 54 | mask = self.transforms(mask) 55 | # If the transform variable is not empty 56 | # then it applies the operations in the transforms with the order that it is created. 57 | return (image, mask, image_name) 58 | 59 | def __len__(self): 60 | return len(self.list) 61 | -------------------------------------------------------------------------------- /SegNet/code/gaussian_noise_and_flipping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | from skimage import io 5 | from skimage.util import random_noise 6 | from skimage.filters import gaussian 7 | from tqdm import tqdm 8 | 9 | image_names = os.listdir("Images_padded1/Train/") 10 | 11 | for name in tqdm(image_names): 12 | img = io.imread('Images_padded1/Train/'+name) 13 | mask = io.imread('Mask_padded1/Train/'+name) 14 | 15 | noise_image_01 = random_noise(img, mode='gaussian', seed=None, clip=True, var = 0.01) 16 | noise_image_01 = noise_image_01*255. 17 | noise_image_01 = noise_image_01.astype('uint8') 18 | io.imsave('Images_padded1/Train/'+name[:-4]+'_01.png', noise_image_01) 19 | io.imsave('Mask_padded1/Train/'+name[:-4]+'_01.png', mask) 20 | 21 | ''' 22 | image_names = os.listdir("Images_padded1/Train/") 23 | 24 | for name in tqdm(image_names): 25 | img = cv2.imread('Images_padded1/Train/'+name,0) 26 | mask = cv2.imread('Mask_padded1/Train/'+name,0) 27 | 28 | # copy image to display all 4 variations 29 | horizontal_img = img.copy() 30 | horizontal_mask = mask.copy() 31 | 32 | 33 | # flip img horizontally, vertically, 34 | # and both axes with flip() 35 | horizontal_img = cv2.flip( img, 1 ) 36 | horizontal_mask = cv2.flip(mask, 1) 37 | 38 | 39 | # display the images on screen with imshow() 40 | cv2.imwrite( 'Images_padded1/Train/'+name[:-4]+'_flip.png', horizontal_img ) 41 | cv2.imwrite( 'Mask_padded1/Train/'+name[:-4]+'_flip.png', horizontal_mask ) 42 | 43 | ''' 44 | ''' 45 | img = io.imread('MCUCXR_0006_0.png') 46 | 47 | 48 | noise_image_1 = random_noise(img, mode='gaussian', seed=None, clip=True, var = 0.001) 49 | noise_image_1 = noise_image_1*255. 
50 | noise_image_1 = noise_image_1.astype('uint8') 51 | 52 | io.imsave('MCUCXR_0006_0_noise_1.png', noise_image_1) 53 | ''' 54 | ''' 55 | blur = gaussian(img, sigma=5, output=None, mode='nearest', cval=0, multichannel=None, preserve_range=False, truncate=4.0) 56 | blur = blur*255. 57 | blur = blur.astype('uint16') 58 | io.imsave('MCUCXR_0006_0_vlur.png', blur) 59 | ''' 60 | -------------------------------------------------------------------------------- /SegNet/code/inferences.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 4 | import torch 5 | from build_model import * 6 | from torch.utils.data import DataLoader 7 | from data_loader import LungSegVal 8 | from torchvision import transforms 9 | import torch.nn.functional as F 10 | import numpy as np 11 | 12 | from skimage import morphology, color, io, exposure 13 | 14 | def IoU(y_true, y_pred): 15 | """Returns Intersection over Union score for ground truth and predicted masks.""" 16 | assert y_true.dtype == bool and y_pred.dtype == bool 17 | y_true_f = y_true.flatten() 18 | y_pred_f = y_pred.flatten() 19 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 20 | union = np.logical_or(y_true_f, y_pred_f).sum() 21 | return (intersection + 1) * 1. / (union + 1) 22 | 23 | def Dice(y_true, y_pred): 24 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 25 | assert y_true.dtype == bool and y_pred.dtype == bool 26 | y_true_f = y_true.flatten() 27 | y_pred_f = y_pred.flatten() 28 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 29 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 30 | 31 | def masked(img, gt, mask, alpha=1): 32 | """Returns image with GT lung field outlined with red, predicted lung field 33 | filled with blue.""" 34 | rows, cols = img.shape[:2] 35 | color_mask = np.zeros((rows, cols, 3)) 36 | boundary = morphology.dilation(gt, morphology.disk(3))^gt 37 | color_mask[mask == 1] = [0, 0, 1] 38 | color_mask[boundary == 1] = [1, 0, 0] 39 | 40 | img_hsv = color.rgb2hsv(img) 41 | color_mask_hsv = color.rgb2hsv(color_mask) 42 | 43 | img_hsv[..., 0] = color_mask_hsv[..., 0] 44 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 45 | 46 | img_masked = color.hsv2rgb(img_hsv) 47 | return img_masked 48 | 49 | def remove_small_regions(img, size): 50 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 51 | img = morphology.remove_small_objects(img, size) 52 | img = morphology.remove_small_holes(img, size) 53 | return img 54 | 55 | if __name__ == '__main__': 56 | 57 | # Path to csv-file. File should contain X-ray filenames as first column, 58 | # mask filenames as second column. 
59 | # Load test data 60 | img_size = (864, 864) 61 | 62 | inp_shape = (864,864,3) 63 | batch_size=1 64 | 65 | # Load model 66 | #model_name = 'model.020.hdf5' 67 | #UNet = load_model(model_name) 68 | net = SegNet(3,1) 69 | net.cuda() 70 | 71 | net.load_state_dict(torch.load('Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot52_0.04634043872356415.pth.tar')) 72 | net.eval() 73 | 74 | 75 | 76 | seed = 1 77 | transformations_test = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 78 | test_set = LungSegVal(transforms = transformations_test) 79 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle = False) 80 | ious = np.zeros(len(test_loader)) 81 | dices = np.zeros(len(test_loader)) 82 | if not(os.path.exists('./results_Unet_BCE_Dice_InvDice_test')): 83 | os.mkdir('./results_Unet_BCE_Dice_InvDice_test') 84 | 85 | 86 | i = 0 87 | for xx, yy, name in tqdm(test_loader): 88 | xx = xx.cuda() 89 | yy = yy 90 | 91 | name = name[0][:-4] 92 | print (name) 93 | pred = net(xx) 94 | pred = F.sigmoid(pred) 95 | pred = pred.cpu() 96 | pred = pred.detach().numpy()[0,0,:,:] 97 | mask = yy.numpy()[0,0,:,:] 98 | xx = xx.cpu() 99 | xx = xx.numpy()[0,:,:,:].transpose(1,2,0) 100 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 101 | 102 | # Binarize masks 103 | gt = mask > 0.5 104 | pr = pred > 0.5 105 | 106 | # Remove regions smaller than 2% of the image 107 | #pr = remove_small_regions(pr, 0.02 * np.prod(img_size)) 108 | 109 | 110 | io.imsave('results_Unet_BCE_Dice_InvDice_test/{}.png'.format(name), pr*255) 111 | 112 | ious[i] = IoU(gt, pr) 113 | dices[i] = Dice(gt, pr) 114 | 115 | i += 1 116 | if i == len(test_loader): 117 | break 118 | 119 | print ('Mean IoU:', ious.mean()) 120 | print ('Mean Dice:', dices.mean()) 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /SegNet/code/metric_calc.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import cv2 4 | import os 5 | import sys 6 | from PIL import Image 7 | 8 | 9 | 10 | 11 | def get_IoU(Gi,Si): 12 | #print(Gi.shape, Si.shape) 13 | intersect = 1.0*np.sum(np.logical_and(Gi,Si)) 14 | union = 1.0*np.sum(np.logical_or(Gi,Si)) 15 | return intersect/union 16 | #check cv2.connectedComponents and what it returns, alongwith channels first or last 17 | def generate_list(G,S): 18 | G = G.astype('uint8') 19 | S = S.astype('uint8') 20 | #print(np.unique(G)) 21 | gland_obj_cnt,gland_obj = cv2.connectedComponents(G,connectivity=8) 22 | seg_obj_cnt,seg_obj = cv2.connectedComponents(S,connectivity=8) 23 | gland_obj_list = [] 24 | seg_obj_list = [] 25 | for i in range(1,gland_obj_cnt): 26 | gland_obj_list.append( (gland_obj==(i)).astype('int32') ) 27 | for i in range(1,seg_obj_cnt): 28 | seg_obj_list.append( (seg_obj==(i)).astype('int32') ) 29 | gland_obj_list = np.array(gland_obj_list) 30 | seg_obj_list = np.array(seg_obj_list) 31 | return gland_obj_list,seg_obj_list 32 | 33 | ####Find why channel parameter was passed 34 | def AGI_core(gland_obj_list,seg_obj_list,channel='last'): 35 | C = 0.0 36 | U = 0.0 37 | ##check below: 38 | ''' 39 | Swapping is not required. 
40 | if(channel=='last'): 41 | # make channels first 42 | gland_obj_list = np.swapaxes( np.swapaxes( gland_obj_list , 0,2) , 1 , 2 ) 43 | seg_obj_list = np.swapaxes( np.swapaxes( seg_obj_list , 0,2) , 1 , 2 ) 44 | ''' 45 | #print(gland_obj_list.shape) 46 | seg_nonused = np.ones(len(seg_obj_list)) 47 | for gi in gland_obj_list: 48 | iou = np.multiply( [get_IoU(gi,si) for si in seg_obj_list] , seg_nonused ) 49 | max_iou = np.max(iou) 50 | j = np.argmax(iou) 51 | C = C + np.sum(np.logical_and(gi,seg_obj_list[j]) ) 52 | U = U + np.sum(np.logical_or(gi,seg_obj_list[j]) ) 53 | seg_nonused[j] = 0 54 | for ind in range(len(seg_obj_list)): 55 | if((seg_nonused[ind])==1): 56 | U = U + np.sum(seg_obj_list[ind]) 57 | return C*1./U 58 | 59 | def Acc_Jacard_Index(G,S): 60 | gland_obj_list,seg_obj_list = generate_list(G, S) 61 | print("In AJI and length is:{}".format(len(gland_obj_list))) 62 | return AGI_core(gland_obj_list,seg_obj_list) 63 | 64 | #----------------------------------------------- 65 | 66 | def F1_core(gland_obj_list,seg_obj_list,channel='first'): 67 | TP,FP,FN = 0.0,0.0,0.0 68 | if(channel=='last'): 69 | # make channels first 70 | gland_obj_list = np.swapaxes( np.swapaxes( gland_obj_list , 0,2) , 1 , 2 ) 71 | seg_obj_list = np.swapaxes( np.swapaxes( seg_obj_list , 0,2) , 1 , 2 ) 72 | seg_nonused = np.ones(len(seg_obj_list)) 73 | gland_unhit = np.ones(len(seg_obj_list)) 74 | 75 | for ind in range(len(gland_obj_list)): 76 | gi = gland_obj_list[ind] 77 | overlap_s = np.multiply( np.sum( seg_obj_list*gi , axis=(1,2) ) , seg_nonused ) 78 | max_ov = np.max(overlap_s) 79 | percent_overlap = max_ov/np.sum(gi) 80 | if percent_overlap>=0.01 : 81 | # hit 82 | TP = TP +1 83 | j = np.argmax(overlap_s) 84 | seg_nonused[j] = 0 85 | else: 86 | # unhit 87 | FN = FN + 1 88 | 89 | FP = np.sum(seg_nonused) 90 | F1_val = (2*TP)/(2*TP + FP + FN) 91 | return F1_val 92 | 93 | def F1_score(G,S): 94 | #y_mask = np.asarray(G[:, :, :, 0]).astype('uint8') 95 | #print y_mask.shape 96 | #y_pred = np.asarray(S[:, :, :, 0]).astype('uint8') 97 | #print y_pred.shape 98 | #print type(y_pred) 99 | #y_dist = K.expand_dims(G[:, :, :, 1], axis=-1) 100 | gland_obj_list,seg_obj_list = generate_list(G, S) 101 | return F1_core(gland_obj_list,seg_obj_list) 102 | 103 | def Dice(y_true, y_pred): 104 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 105 | #print(y_true.dtype) 106 | #print(y_pred.dtype) 107 | y_true = np.squeeze(y_true)/255 108 | y_pred = np.squeeze(y_pred)/255 109 | y_true.astype('bool') 110 | y_pred.astype('bool') 111 | intersection = np.logical_and(y_true, y_pred).sum() 112 | return ((2. * intersection.sum()) + 1.) / (y_true.sum() + y_pred.sum() + 1.) 
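# Illustrative example (not part of the original script): Dice on two toy
# 0/255 masks, assuming the Dice() defined above.
#
#   >>> a = np.full((4, 4), 255, dtype=np.uint8)             # 16 foreground pixels
#   >>> b = np.zeros((4, 4), dtype=np.uint8); b[:2] = 255    # 8 foreground pixels
#   >>> Dice(a, b)   # intersection = 8 -> (2*8 + 1) / (16 + 8 + 1)
#   0.68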
113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | smooth = 1 121 | 122 | image_names = os.listdir('results_scratch_custom99/') 123 | 124 | mean_dice = [] 125 | mean_F1 = [] 126 | aggr_jacard = [] 127 | 128 | 129 | for images in tqdm(image_names): 130 | 131 | S = np.expand_dims(np.array(Image.open('results_scratch_custom99/'+images).convert('L')),axis=-1) 132 | G = np.expand_dims(np.array(Image.open('/home/sahyadri/Testing/Test_40_y_HE/'+images).convert('L')),axis=-1) 133 | #print S.shape 134 | #G.shape 135 | #print(Acc_Jacard_Index(G,S)) 136 | #aggr_jacard.append(Acc_Jacard_Index(G,S)) 137 | #mean_F1.append(F1_score(G,S)) 138 | mean_dice.append(Dice(G, S)) 139 | 140 | 141 | print ('Mean_Dice = ', np.mean(np.array(mean_dice))) 142 | #print ('Mean_F1 = ', np.mean(np.array(mean_F1))) 143 | #print (len(aggr_jacard), aggr_jacard) 144 | #print ('Mean_Aggr_Jacard = ', np.mean(np.array(aggr_jacard))) 145 | 146 | f = open('lung.txt','w') 147 | a = 'Mean Dice : {}'.format(np.mean(np.array(mean_dice)))+ '\n' + 'Mean F1 : {}'.format(np.mean(np.array(mean_F1)))+ '\n' + 'Mean Aggregate Jacard : {}'.format(np.mean(np.array(aggr_jacard)))+ '\n' 148 | 149 | f.write(str(a)) 150 | f.close() 151 | 152 | -------------------------------------------------------------------------------- /SegNet/code/preprocess_combining.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from skimage import io, exposure 4 | 5 | 6 | 7 | def make_masks(): 8 | path = 'test_image/' 9 | for i, filename in enumerate(os.listdir(path)): 10 | left = io.imread('test_mask/left/' + filename[:-4] + '.png') 11 | right = io.imread('test_mask/right/' + filename[:-4] + '.png') 12 | io.imsave('test_mask/Mask/' + filename[:-4] + '.png', np.clip(left + right, 0, 255)) 13 | print ('Mask', i, filename) 14 | 15 | make_masks() 16 | 17 | -------------------------------------------------------------------------------- /SegNet/code/preprocess_padding.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | from PIL import Image 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | 9 | img_names = os.listdir('test_image/') 10 | 11 | 12 | for name in tqdm(img_names): 13 | 14 | img = np.asarray(Image.open('test_image/'+name)) 15 | max_dim_img = max(img.shape[0], img.shape[1]) 16 | row_img = max_dim_img - img.shape[0] 17 | cols_img = max_dim_img - img.shape[1] 18 | padded_img = np.pad(img, ((row_img//2,row_img//2), (cols_img//2,cols_img//2)), mode ='constant') 19 | img = Image.fromarray(padded_img) 20 | img.save('Images_padded_test/'+name) 21 | 22 | mask = np.asarray(Image.open('test_mask/Mask/'+name)) 23 | max_dim_mask = max(mask.shape[0], mask.shape[1]) 24 | row_mask = max_dim_mask- mask.shape[0] 25 | cols_mask = max_dim_mask - mask.shape[1] 26 | padded_mask = np.pad(mask, ((row_mask//2,row_mask//2), (cols_mask//2,cols_mask//2)), mode ='constant') 27 | mask = Image.fromarray(padded_mask) 28 | mask.save('Mask_padded_test/'+name) 29 | 30 | -------------------------------------------------------------------------------- /SegNet/code/train_model_segnet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torch.autograd import Variable 8 | from build_model import * 9 | import os 10 | from tqdm import tqdm 11 | from tensorboardX import 
SummaryWriter 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 15 | #-------------------------- 16 | class Average(object): 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.sum += val 26 | self.count += n 27 | 28 | #property 29 | def avg(self): 30 | return self.sum / self.count 31 | #------------------------------ 32 | # import csv 33 | writer = SummaryWriter() 34 | #---------------------------------------- 35 | class SoftDiceLoss(nn.Module): 36 | ''' 37 | Soft Dice Loss 38 | ''' 39 | def __init__(self, weight=None, size_average=True): 40 | super(SoftDiceLoss, self).__init__() 41 | 42 | def forward(self, logits, targets): 43 | smooth = 1. 44 | logits = F.sigmoid(logits) 45 | iflat = logits.view(-1) 46 | tflat = targets.view(-1) 47 | intersection = (iflat * tflat).sum() 48 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 49 | #------------------------------------------------- 50 | ''' 51 | class SoftDicescore(nn.Module): 52 | ''' 53 | #Soft Dice Loss 54 | ''' 55 | def __init__(self, weight=None, size_average=True): 56 | super(SoftDicescore, self).__init__() 57 | 58 | def forward(self, logits, targets): 59 | smooth = 1. 60 | logits = F.sigmoid(logits) 61 | iflat = logits.view(-1) 62 | tflat = targets.view(-1) 63 | intersection = (iflat * tflat).sum() 64 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 65 | ''' 66 | #------------------------------------------------- 67 | ''' 68 | class W_bce(nn.Module): 69 | 70 | #weighted crossentropy per image 71 | 72 | def __init__(self, weight=None, size_average=True): 73 | super(W_bce, self).__init__() 74 | 75 | def forward(self, logits, targets): 76 | eps = 1e-6 77 | total_size = targets.view(-1).size()[0] 78 | #print "total_size", total_size 79 | ones_size = torch.sum(targets.view(-1,1)).item() 80 | #print "one_size", ones_size 81 | zero_size = total_size - ones_size 82 | #print "zero_size", zero_size 83 | #assert total_size == (ones_size + zero_size) 84 | #print "crossed assertion" 85 | loss_1 = torch.mean(-(targets.view(-1)* ( total_size/ones_size) * torch.log(torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps))))#.sum(axis=1) 86 | #print "crossed loss1" 87 | loss_0 = torch.mean(-((1.-targets.view(-1))* ( total_size/zero_size) * torch.log((1.-torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps)))))#.sum(axis=1) 88 | #print "crossed loss0" 89 | return loss_1 + loss_0 90 | ''' 91 | #---------------------------------- 92 | class InvSoftDiceLoss(nn.Module): 93 | 94 | ''' 95 | Inverted Soft Dice Loss 96 | ''' 97 | def __init__(self, weight=None, size_average=True): 98 | super(InvSoftDiceLoss, self).__init__() 99 | 100 | def forward(self, logits, targets): 101 | smooth = 1. 102 | logits = F.sigmoid(logits) 103 | iflat = 1-logits.view(-1) 104 | tflat = 1-targets.view(-1) 105 | intersection = (iflat * tflat).sum() 106 | 107 | 108 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 109 | #-------------------------------------- 110 | ''' 111 | class InvSoftDicescore(nn.Module): 112 | 113 | ''' 114 | #Inverted Soft Dice Loss 115 | ''' 116 | 117 | def __init__(self, weight=None, size_average=True): 118 | super(InvSoftDicescore, self).__init__() 119 | 120 | def forward(self, logits, targets): 121 | smooth = 1. 
122 | logits = F.sigmoid(logits) 123 | iflat = 1-logits.view(-1) 124 | tflat = 1-targets.view(-1) 125 | intersection = (iflat * tflat).sum() 126 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 127 | ''' 128 | #---------------------------------------- 129 | ''' 130 | class int_custom_loss(nn.Module): 131 | ''' 132 | #custom loss 133 | ''' 134 | def __init__(self, weight=None, size_average=True): 135 | super(int_custom_loss, self).__init__() 136 | 137 | def forward(self, logits, targets): 138 | loss_inv_dice = InvSoftDicescore() 139 | loss_dice = SoftDicescore() 140 | total_size = targets.view(-1).size()[0] 141 | ones_size = torch.sum(targets.view(-1,1)).item() 142 | th = 0.2 * total_size 143 | if(ones_size > th): 144 | return (- 0.8*torch.log(loss_dice(logits,targets))-0.2*torch.log(loss_inv_dice(logits, targets))) 145 | else: 146 | return(-0.2*torch.log(loss_dice(logits, targets))-0.8*torch.log(loss_inv_dice(logits, targets))) 147 | ''' 148 | ''' 149 | class weighted_dice_invdice(nn.Module): 150 | ''' 151 | #custom loss 152 | ''' 153 | def __init__(self, weight=None, size_average=True): 154 | super(weighted_dice_invdice, self).__init__() 155 | 156 | def forward(self, logits, targets): 157 | loss_inv_dice = InvSoftDicescore() 158 | loss_dice = SoftDicescore() 159 | total_size = targets.view(-1).size()[0] 160 | ones_size = torch.sum(targets.view(-1,1)).item() 161 | zero_size = total_size - ones_size 162 | th = 0.2 * total_size 163 | return (-(zero_size/total_size)*torch.log(loss_dice(logits,targets))-(ones_size/total_size)*torch.log(loss_inv_dice(logits, targets))) 164 | ''' 165 | 166 | #Tranformations------------------------------------------------ 167 | transformations_train = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 168 | 169 | transformations_val = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 170 | #------------------------------------------------------------- 171 | 172 | from data_loader import LungSegTrain 173 | from data_loader import LungSegVal 174 | train_set = LungSegTrain(transforms = transformations_train) 175 | batch_size = 1 176 | num_epochs = 75 177 | 178 | def train(): 179 | cuda = torch.cuda.is_available() 180 | net = SegNet(3,1) 181 | if cuda: 182 | net = net.cuda() 183 | #net.load_state_dict(torch.load('Weights_BCE_Dice/cp_bce_lr_05_100_0.222594484687.pth.tar')) 184 | criterion1 = nn.BCEWithLogitsLoss().cuda() 185 | criterion2 = SoftDiceLoss().cuda() 186 | criterion3 = InvSoftDiceLoss().cuda() 187 | #criterion4 = W_bce().cuda() 188 | #criterion5 = int_custom_loss() 189 | #criterion6 = weighted_dice_invdice() 190 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) 191 | #scheduler = MultiStepLR(optimizer, milestones=[2,10,75,100], gamma=0.1) 192 | 193 | print("preparing training data ...") 194 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 195 | print("done ...") 196 | val_set = LungSegVal(transforms = transformations_val) 197 | val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False) 198 | for epoch in tqdm(range(num_epochs)): 199 | #scheduler.step() 200 | train_loss = Average() 201 | net.train() 202 | for i, (images, masks) in tqdm(enumerate(train_loader)): 203 | images = Variable(images) 204 | masks = Variable(masks) 205 | if cuda: 206 | images = images.cuda() 207 | masks = masks.cuda() 208 | 209 | optimizer.zero_grad() 210 | outputs = net(images) 211 | #writer.add_image('Training Input',images) 212 | #writer.add_image('Training 
Pred',F.sigmoid(outputs)>0.5) 213 | c1 = criterion1(outputs,masks) + criterion2(outputs, masks) + criterion3(outputs, masks) 214 | loss = c1 215 | writer.add_scalar('Train Loss',loss,epoch) 216 | loss.backward() 217 | optimizer.step() 218 | train_loss.update(loss.item(), images.size(0)) 219 | for param_group in optimizer.param_groups: 220 | writer.add_scalar('Learning Rate',param_group['lr']) 221 | val_loss1 = Average() 222 | val_loss2 = Average() 223 | val_loss3 = Average() 224 | net.eval() 225 | for images, masks,_ in tqdm(val_loader): 226 | images = Variable(images) 227 | masks = Variable(masks) 228 | if cuda: 229 | images = images.cuda() 230 | masks = masks.cuda() 231 | 232 | outputs = net(images) 233 | if (epoch)%10==0: 234 | writer.add_image('Validation Input',images,epoch) 235 | writer.add_image('Validation GT ',masks,epoch) 236 | writer.add_image('Validation Pred0.5',F.sigmoid(outputs)>0.5,epoch) 237 | writer.add_image('Validation Pred0.3',F.sigmoid(outputs)>0.3,epoch) 238 | writer.add_image('Validation Pred0.65',F.sigmoid(outputs)>0.65,epoch) 239 | 240 | vloss1 = criterion1(outputs, masks) 241 | vloss2 = criterion2(outputs, masks) 242 | vloss3 = criterion3(outputs, masks) #+ criterion2(outputs, masks) 243 | #vloss = vloss2 + vloss3 244 | writer.add_scalar('Validation loss(BCE)',vloss1,epoch) 245 | writer.add_scalar('Validation loss(Dice)',vloss2,epoch) 246 | writer.add_scalar('Validation loss(InvDice)',vloss3,epoch) 247 | 248 | val_loss1.update(vloss1.item(), images.size(0)) 249 | val_loss2.update(vloss2.item(), images.size(0)) 250 | val_loss3.update(vloss3.item(), images.size(0)) 251 | 252 | print("Epoch {}, Training Loss(BCE+Dice): {}, Validation Loss(BCE): {}, Validation Loss(Dice): {}, Validation Loss(InvDice): {}".format(epoch+1, train_loss.avg(), val_loss1.avg(), val_loss2.avg(), val_loss3.avg())) 253 | 254 | # with open('Log.csv', 'a') as logFile: 255 | # FileWriter = csv.writer(logFile) 256 | # FileWriter.writerow([epoch+1, train_loss.avg, val_loss1.avg, val_loss2.avg, val_loss3.avg]) 257 | 258 | torch.save(net.state_dict(), 'Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot{}_{}.pth.tar'.format(epoch+1, val_loss2.avg())) 259 | return net 260 | 261 | def test(model): 262 | model.eval() 263 | 264 | 265 | 266 | if __name__ == "__main__": 267 | train() 268 | -------------------------------------------------------------------------------- /VGG_UNet/README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This report describes the usage of **VGG Unet** architecture for **medical image segmentation**. 3 | 4 | We divide the article into the following parts 5 | 6 | - [Dataset](#dataset) 7 | - [VGG Unet Architecture](#vgg-unet) 8 | - [Results](#results) 9 | - [References](#references) 10 | 11 | # Dataset 12 | ## Montgomory Dataset 13 | 14 | The dataset contains Chest X-Ray images. We use this dataset to perform lung segmentation. 
>The dataset can be found [here](http://openi.nlm.nih.gov/imgs/collections/NLM-MontgomeryCXRSet.zip)

Structure:

We use the following structure for the given dataset:

![](https://lh4.googleusercontent.com/J-kHm2BX9ywKISMuY_BaCaFf--UuPJOKlFYLO89gYgvjmqlM9RrFive2wOU30X8N7bzI03uwMCtnb_oCHDPaobyxTMEFlfsTSNXALS629uuAkSUZfm9y-lUv5FORquPe1P8CPp4p)

## Data Preprocessing
We apply random rotations as the only augmentation technique, as other techniques like centre crops or flips would distort the data, and the results would not be as expected.

Since each image is approx. `4000x4000`, we resize the images to a manageable size of `512x512`, as we were limited by GPU memory.


# VGG Unet

## Introduction

We use **TernausNet** *([Vladimir Iglovikov](https://arxiv.org/search?searchtype=author&query=Iglovikov%2C+V), [Alexey Shvets](https://arxiv.org/search?searchtype=author&query=Shvets%2C+A))*, a network with a pretrained `VGG 11` encoder, trained on medical images to segment them according to a given mask.

## Implementation of the network

VGG 11 is implemented as `configuration A` specified in the following image:

> ![enter image description here](https://qph.ec.quoracdn.net/main-qimg-30abbdf1982c8cb049ac65f3cf9d5640)
> (source: https://www.quora.com/How-does-VGG-network-architecture-work)

We remove the FC, last max-pool and softmax layers; the upsampling layers are then mirrored to match the channels of each VGG layer, and skip connections are added per layer. The entire network can be visualized as:

![](https://camo.githubusercontent.com/cf2ff198ddd4f4600726fa0f2844e77c4041186b/68747470733a2f2f686162726173746f726167652e6f72672f776562742f68752f6a692f69722f68756a696972767067706637657377713838685f783761686c69772e706e67)
(Source: https://github.com/ternaus/TernausNet)

VGG net can be visualized as:

![](https://www.cs.toronto.edu/~frossard/post/vgg16/vgg16.png)
(Source: https://www.cs.toronto.edu/~frossard/post/vgg16/)


## Working

We input the images to the network, where they are first passed through the VGG 11 encoder, which outputs a 512-channel feature map. This feature map is then upsampled back to the original size using transposed convolutions.

The transposed convolution can, in its simplest form, be seen mathematically as an operation that upsamples a given feature map to the desired dimensions.

![](https://lh5.googleusercontent.com/qOJ46aQEsUShQswuF9m7Sj7ZVocttxzxZHBm1jzhpb80gE8VSDpzBayc2KGnaCC2INmoUbrXu3-HUXNfzWRngfj3fewcnQ0aZzqSMVO5LDu7UQwlIuaMjaTs-0YlUkrKH_kQCohR)
(Source: https://datascience.stackexchange.com/a/20176)

At each upsampling layer, a skip connection to the corresponding layer in the encoder is added; the channels from both layers are concatenated and used as the input to the next upsampling layer.

Finally, a sigmoid activation is applied to the last layer, and the resulting feature map is the segmented image.
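A minimal, self-contained sketch of one decoder step (illustrative channel sizes, not the exact TernausNet configuration) showing this upsample-then-concatenate pattern:

    import torch
    import torch.nn as nn

    up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)  # doubles H and W

    bottom = torch.randn(1, 512, 16, 16)  # feature map from the layer below
    skip = torch.randn(1, 256, 32, 32)    # matching encoder feature map

    x = up(bottom)               # -> (1, 256, 32, 32)
    x = torch.cat([x, skip], 1)  # -> (1, 512, 32, 32), input to the next block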
## Loss Functions used
We use two loss functions here, viz. `Binary Cross Entropy` and `Dice loss`.

### Binary Cross Entropy Loss
Cross-entropy measures how well the predicted probability of an item belonging to a class matches the true label; binary cross-entropy is the same concept, except that here there are only 2 classes.

In binary classification, where the number of classes $M$ equals 2, it is defined mathematically as:

$$
-(y \log(p) + (1-y) \log(1-p))
$$
where $p$ is the predicted value and $y$ is the ground truth label.

### Dice Coefficient Loss
The dice coefficient loss measures the overlap between the predicted and target masks, closely related to the `intersection over union`.

Mathematically, the Dice score is
$$\frac{2 |X \cap Y|}{|X| + |Y|}$$


The dice loss is defined in code as:


    class SoftDiceLoss(nn.Module):
        def __init__(self, weight=None, size_average=True):
            super(SoftDiceLoss, self).__init__()

        def forward(self, logits, targets):
            smooth = 1
            num = targets.size(0)
            probs = F.sigmoid(logits)
            m1 = probs.view(num, -1)
            m2 = targets.view(num, -1)
            intersection = (m1 * m2)

            score = 2. * (intersection.sum(1) + smooth) / (m1.sum(1) + m2.sum(1) + smooth)
            score = 1 - score.sum() / num
            return score


> NOTE: Intersection is implemented as a multiplication and the cardinality as a `sum()` over axis 1 (the flattened per-sample values) because the predictions and targets are one-hot encoded vectors.

## Results

| Loss         | Val mIoU | Val mDice | Test mIoU | Test mDice |
|:------------:|:--------:|:---------:|:---------:|:----------:|
| BCE          | 0.9403   | 0.9692    | -         | -          |
| BCE+DCL      | 0.9426   | 0.9704    | -         | -          |
| BCE+DCL+IDCL | 0.9665   | 0.9829    | 0.9295    | 0.9623     |

The results with this network are good; some of the best ones are shown here:

Input Image
![prediction](https://imgur.com/MwOoEno.png)

Corresponding Segmented Image
![Segmented Image](https://i.imgur.com/ak3Aa2M.png)


# References

[[1]](https://arxiv.org/abs/1801.05746) V. Iglovikov and A. Shvets - TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation, 2018
- The code we used was taken from
> https://github.com/ternaus/TernausNet

[[2]](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. ”Imagenet classification with deep convolutional neural networks.” In Advances in neural information processing systems, pp. 1097-1105. 2012

[[3]](https://arxiv.org/abs/1409.4842) Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. ”Going deeper with convolutions.” In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1-9. 2015

[[4]](https://arxiv.org/abs/1502.03167) Ioffe, Sergey, and Christian Szegedy. ”Batch normalization: Accelerating deep network training by reducing internal covariate shift.” In International Conference on Machine Learning, pp. 448-456. 2015

[[5]](http://jmlr.org/papers/v15/srivastava14a.html) Srivastava, Nitish, Geoffrey E.
112 | ## Results
113 | 
114 | | Loss | Validation mIoU | Validation mDice | Test mIoU | Test mDice |
115 | |:------------:|:------:|:------:|:------:|:------:|
116 | | BCE | 0.9403 | 0.9692 | - | - |
117 | | BCE+DCL | 0.9426 | 0.9704 | - | - |
118 | | BCE+DCL+IDCL | 0.9665 | 0.9829 | 0.9295 | 0.9623 |
119 | 
120 | Here `DCL` denotes the Dice coefficient loss and `IDCL` the inverted Dice coefficient loss. The results with this network are good, and some of the best ones are shown here.
121 | 
122 | Input Image
123 | ![prediction](https://imgur.com/MwOoEno.png)
124 | 
125 | Corresponding Segmented Image
126 | ![Segmented Image](https://i.imgur.com/ak3Aa2M.png)
127 | 
128 | 
129 | 
130 | 
131 | # References
132 | 
133 | 
134 | [[1]](https://arxiv.org/abs/1801.05746) V. Iglovikov and A. Shvets. "TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation." 2018.
135 | - The code we used was adapted from
136 | > https://github.com/ternaus/TernausNet
137 | 
138 | [[2]](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "ImageNet classification with deep convolutional neural networks." In Advances in Neural Information Processing Systems, pp. 1097-1105. 2012.
139 | 
140 | [[3]](https://arxiv.org/abs/1409.4842) Szegedy, Christian, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. "Going deeper with convolutions." In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1-9. 2015.
141 | 
142 | [[4]](https://arxiv.org/abs/1502.03167) Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." In International Conference on Machine Learning, pp. 448-456. 2015.
143 | 
144 | [[5]](http://jmlr.org/papers/v15/srivastava14a.html) Srivastava, Nitish, Geoffrey E. Hinton, Alex Krizhevsky, Ilya Sutskever, and Ruslan Salakhutdinov. "Dropout: a simple way to prevent neural networks from overfitting." Journal of Machine Learning Research 15, no. 1 (2014): 1929-1958.
145 | 
146 | [[6]](https://arxiv.org/abs/1409.1556) Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv preprint arXiv:1409.1556, 2014.
147 | 
148 | [7] Bishop, Christopher M. Pattern Recognition and Machine Learning. Springer, 2006.
149 | 
150 | [[8]](http://www.deeplearningbook.org) Goodfellow, Ian, Yoshua Bengio, and Aaron Courville. Deep Learning. MIT Press, 2016.
151 | 
152 | [[9]](https://www.tensorflow.org/) TensorFlow
153 | 
154 | [[10]](https://keras.io/) Keras
155 | 
156 | [[11]](https://pytorch.org/docs/stable/index.html) PyTorch
157 | 
158 | Return to the main [README](https://github.com/medal-iitb/LungSegmentation).
159 | 
-------------------------------------------------------------------------------- /VGG_UNet/code/build_model.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from collections import OrderedDict
5 | 
6 | class SegNet(nn.Module):
7 |     def __init__(self,input_nbr,label_nbr):
8 |         super(SegNet, self).__init__()
9 | 
10 | 
11 |         self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
12 |         self.bn11 = nn.BatchNorm2d(64)
13 |         self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
14 |         self.bn12 = nn.BatchNorm2d(64)
15 | 
16 |         self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
17 |         self.bn21 = nn.BatchNorm2d(128)
18 |         self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
19 |         self.bn22 = nn.BatchNorm2d(128)
20 | 
21 |         self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
22 |         self.bn31 = nn.BatchNorm2d(256)
23 |         self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
24 |         self.bn32 = nn.BatchNorm2d(256)
25 |         self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
26 |         self.bn33 = nn.BatchNorm2d(256)
27 | 
28 |         self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
29 |         self.bn41 = nn.BatchNorm2d(512)
30 |         self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
31 |         self.bn42 = nn.BatchNorm2d(512)
32 |         self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
33 |         self.bn43 = nn.BatchNorm2d(512)
34 | 
35 |         self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
36 |         self.bn51 = nn.BatchNorm2d(512)
37 |         self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
38 |         self.bn52 = nn.BatchNorm2d(512)
39 |         self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
40 |         self.bn53 = nn.BatchNorm2d(512)
41 | 
42 |         self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
43 |         self.bn53d = nn.BatchNorm2d(512)
44 |         self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
45 |         self.bn52d = nn.BatchNorm2d(512)
46 |         self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
47 |         self.bn51d = nn.BatchNorm2d(512)
48 | 
49 |         self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
50 |         self.bn43d = nn.BatchNorm2d(512)
51 |         self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
52 |         self.bn42d = nn.BatchNorm2d(512)
53 |         self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
54 |         self.bn41d = nn.BatchNorm2d(256)
55 | 
56 |         self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
57 |         self.bn33d = nn.BatchNorm2d(256)
58 |         self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
59 |         self.bn32d = nn.BatchNorm2d(256)
60 |         self.conv31d = nn.Conv2d(256, 128, kernel_size=3, 
padding=1) 61 | self.bn31d = nn.BatchNorm2d(128) 62 | 63 | self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1) 64 | self.bn22d = nn.BatchNorm2d(128) 65 | self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1) 66 | self.bn21d = nn.BatchNorm2d(64) 67 | 68 | self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1) 69 | self.bn12d = nn.BatchNorm2d(64) 70 | self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1) 71 | self.Dropout = nn.Dropout(0.5) 72 | 73 | 74 | def forward(self, x): 75 | 76 | # Stage 1 77 | x11 = F.relu(self.bn11(self.conv11(x))) 78 | x11 = self.Dropout(x11) 79 | x12 = F.relu(self.bn12(self.conv12(x11))) 80 | x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True) 81 | 82 | # Stage 2 83 | x21 = F.relu(self.bn21(self.conv21(x1p))) 84 | x22 = F.relu(self.bn22(self.conv22(x21))) 85 | x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True) 86 | 87 | # Stage 3 88 | x31 = F.relu(self.bn31(self.conv31(x2p))) 89 | x31 = self.Dropout(x31) 90 | x32 = F.relu(self.bn32(self.conv32(x31))) 91 | x33 = F.relu(self.bn33(self.conv33(x32))) 92 | x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True) 93 | 94 | # Stage 4 95 | x41 = F.relu(self.bn41(self.conv41(x3p))) 96 | x42 = F.relu(self.bn42(self.conv42(x41))) 97 | x43 = F.relu(self.bn43(self.conv43(x42))) 98 | x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True) 99 | 100 | # Stage 5 101 | x51 = F.relu(self.bn51(self.conv51(x4p))) 102 | x51 = self.Dropout(x51) 103 | x52 = F.relu(self.bn52(self.conv52(x51))) 104 | x53 = F.relu(self.bn53(self.conv53(x52))) 105 | x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True) 106 | 107 | 108 | # Stage 5d 109 | x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2) 110 | x53d = F.relu(self.bn53d(self.conv53d(x5d))) 111 | x52d = F.relu(self.bn52d(self.conv52d(x53d))) 112 | x51d = F.relu(self.bn51d(self.conv51d(x52d))) 113 | 114 | # Stage 4d 115 | x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2) 116 | x43d = F.relu(self.bn43d(self.conv43d(x4d))) 117 | x42d = F.relu(self.bn42d(self.conv42d(x43d))) 118 | x41d = F.relu(self.bn41d(self.conv41d(x42d))) 119 | 120 | # Stage 3d 121 | x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2) 122 | x33d = F.relu(self.bn33d(self.conv33d(x3d))) 123 | x32d = F.relu(self.bn32d(self.conv32d(x33d))) 124 | x31d = F.relu(self.bn31d(self.conv31d(x32d))) 125 | 126 | # Stage 2d 127 | x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2) 128 | x22d = F.relu(self.bn22d(self.conv22d(x2d))) 129 | x21d = F.relu(self.bn21d(self.conv21d(x22d))) 130 | 131 | # Stage 1d 132 | x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2) 133 | x12d = F.relu(self.bn12d(self.conv12d(x1d))) 134 | x11d = self.conv11d(x12d) 135 | 136 | return x11d 137 | 138 | 139 | import torch 140 | import torch.nn as nn 141 | import torch.nn.functional as F 142 | 143 | from unet_parts import * 144 | 145 | 146 | class UNet(nn.Module): 147 | def __init__(self, n_channels, n_classes): 148 | super(UNet, self).__init__() 149 | self.inc = inconv(n_channels, 64) 150 | self.down1 = down(64, 128) 151 | self.down2 = down(128, 256) 152 | self.down3 = down(256, 512) 153 | self.down4 = down(512, 512) 154 | self.up1 = up(1024, 256) 155 | self.up2 = up(512, 128) 156 | self.up3 = up(256, 64) 157 | self.up4 = up(128, 64) 158 | self.outc = outconv(64, n_classes) 159 | self.Dropout = nn.Dropout(0.5) 160 | 161 | def forward(self, x): 162 | x1 = self.inc(x) 163 | x2 = self.down1(x1) 164 | x3 = self.down2(x2) 
165 |         x3 = self.Dropout(x3)
166 |         x4 = self.down3(x3)
167 |         x5 = self.down4(x4)
168 |         x = self.up1(x5, x4)
169 |         x = self.up2(x, x3)
170 |         x = self.up3(x, x2)
171 |         x = self.up4(x, x1)
172 |         x = self.outc(x)
173 |         return x
174 | -------------------------------------------------------------------------------- /VGG_UNet/code/data_loader.py: --------------------------------------------------------------------------------
1 | 
2 | from torch.utils.data.dataset import Dataset
3 | from PIL import Image
4 | import os
5 | 
6 | 
7 | class LungSegTrain(Dataset):
8 |     def __init__(self, path='Images_padded1/Train/', mask_path='Mask_padded1/Train/', transforms=None):
9 |         self.path = path
10 |         self.mask_path = mask_path
11 |         self.list = os.listdir(self.path)
12 |         self.transforms = transforms
13 | 
14 |     def __getitem__(self, index):
15 |         # Use the paths given to the constructor instead of re-hardcoding them here
16 |         image = Image.open(self.path + self.list[index])
17 |         image = image.convert('RGB')
18 |         mask = Image.open(self.mask_path + self.list[index])
19 |         mask = mask.convert('L')
20 |         if self.transforms is not None:
21 |             # Apply the transform pipeline, in order, to both image and mask
22 |             image = self.transforms(image)
23 |             mask = self.transforms(mask)
24 |         return (image, mask)
25 | 
26 |     def __len__(self):
27 |         return len(self.list)  # number of images in the dataset
28 | 
29 | 
30 | class LungSegVal(Dataset):
31 |     def __init__(self, path='Images_padded_test/', mask_path='Mask_padded_test/', transforms=None):
32 |         self.path = path
33 |         self.mask_path = mask_path
34 |         self.list = os.listdir(self.path)
35 |         self.transforms = transforms
36 | 
37 |     def __getitem__(self, index):
38 |         image_name = self.list[index]
39 |         image = Image.open(self.path + image_name)
40 |         image = image.convert('RGB')
41 |         mask = Image.open(self.mask_path + image_name)
42 |         mask = mask.convert('L')
43 |         if self.transforms is not None:
44 |             image = self.transforms(image)
45 |             mask = self.transforms(mask)
46 |         return (image, mask, image_name)
47 | 
48 |     def __len__(self):
49 |         return len(self.list)
50 | -------------------------------------------------------------------------------- /VGG_UNet/code/gaussian_noise_and_flipping.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | from skimage import io
5 | from skimage.util import random_noise
6 | from skimage.filters import gaussian
7 | from tqdm import tqdm
8 | 
9 | image_names = os.listdir("Images_padded1/Train/")
10 | 
11 | for name in tqdm(image_names):
12 |     img = io.imread('Images_padded1/Train/'+name)
13 |     mask = io.imread('Mask_padded1/Train/'+name)
14 | 
15 |     noise_image_01 = random_noise(img, mode='gaussian', seed=None, clip=True, var = 0.01)
16 |     noise_image_01 = noise_image_01*255.
17 | noise_image_01 = noise_image_01.astype('uint8') 18 | io.imsave('Images_padded1/Train/'+name[:-4]+'_01.png', noise_image_01) 19 | io.imsave('Mask_padded1/Train/'+name[:-4]+'_01.png', mask) 20 | 21 | ''' 22 | image_names = os.listdir("Images_padded1/Train/") 23 | 24 | for name in tqdm(image_names): 25 | img = cv2.imread('Images_padded1/Train/'+name,0) 26 | mask = cv2.imread('Mask_padded1/Train/'+name,0) 27 | 28 | # copy image to display all 4 variations 29 | horizontal_img = img.copy() 30 | horizontal_mask = mask.copy() 31 | 32 | 33 | # flip img horizontally, vertically, 34 | # and both axes with flip() 35 | horizontal_img = cv2.flip( img, 1 ) 36 | horizontal_mask = cv2.flip(mask, 1) 37 | 38 | 39 | # display the images on screen with imshow() 40 | cv2.imwrite( 'Images_padded1/Train/'+name[:-4]+'_flip.png', horizontal_img ) 41 | cv2.imwrite( 'Mask_padded1/Train/'+name[:-4]+'_flip.png', horizontal_mask ) 42 | 43 | ''' 44 | ''' 45 | img = io.imread('MCUCXR_0006_0.png') 46 | 47 | 48 | noise_image_1 = random_noise(img, mode='gaussian', seed=None, clip=True, var = 0.001) 49 | noise_image_1 = noise_image_1*255. 50 | noise_image_1 = noise_image_1.astype('uint8') 51 | 52 | io.imsave('MCUCXR_0006_0_noise_1.png', noise_image_1) 53 | ''' 54 | ''' 55 | blur = gaussian(img, sigma=5, output=None, mode='nearest', cval=0, multichannel=None, preserve_range=False, truncate=4.0) 56 | blur = blur*255. 57 | blur = blur.astype('uint16') 58 | io.imsave('MCUCXR_0006_0_vlur.png', blur) 59 | ''' 60 | -------------------------------------------------------------------------------- /VGG_UNet/code/inferences.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import os 3 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 4 | import torch 5 | from build_model import * 6 | from torch.utils.data import DataLoader 7 | from data_loader import LungSegVal 8 | from torchvision import transforms 9 | import torch.nn.functional as F 10 | import numpy as np 11 | 12 | from skimage import morphology, color, io, exposure 13 | 14 | def IoU(y_true, y_pred): 15 | """Returns Intersection over Union score for ground truth and predicted masks.""" 16 | assert y_true.dtype == bool and y_pred.dtype == bool 17 | y_true_f = y_true.flatten() 18 | y_pred_f = y_pred.flatten() 19 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 20 | union = np.logical_or(y_true_f, y_pred_f).sum() 21 | return (intersection + 1) * 1. / (union + 1) 22 | 23 | def Dice(y_true, y_pred): 24 | """Returns Dice Similarity Coefficient for ground truth and predicted masks.""" 25 | assert y_true.dtype == bool and y_pred.dtype == bool 26 | y_true_f = y_true.flatten() 27 | y_pred_f = y_pred.flatten() 28 | intersection = np.logical_and(y_true_f, y_pred_f).sum() 29 | return (2. * intersection + 1.) / (y_true.sum() + y_pred.sum() + 1.) 
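
# Note: IoU and Dice are monotonically related for the same pair of masks --
# Dice = 2 * IoU / (1 + IoU), so e.g. IoU = 0.5 corresponds to Dice ~= 0.667.
# The "+ 1" terms in both functions above act as smoothing so that two empty
# masks score 1.0 instead of raising a division by zero.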
30 | 31 | def masked(img, gt, mask, alpha=1): 32 | """Returns image with GT lung field outlined with red, predicted lung field 33 | filled with blue.""" 34 | rows, cols = img.shape[:2] 35 | color_mask = np.zeros((rows, cols, 3)) 36 | boundary = morphology.dilation(gt, morphology.disk(3))^gt 37 | color_mask[mask == 1] = [0, 0, 1] 38 | color_mask[boundary == 1] = [1, 0, 0] 39 | 40 | img_hsv = color.rgb2hsv(img) 41 | color_mask_hsv = color.rgb2hsv(color_mask) 42 | 43 | img_hsv[..., 0] = color_mask_hsv[..., 0] 44 | img_hsv[..., 1] = color_mask_hsv[..., 1] * alpha 45 | 46 | img_masked = color.hsv2rgb(img_hsv) 47 | return img_masked 48 | 49 | def remove_small_regions(img, size): 50 | """Morphologically removes small (less than size) connected regions of 0s or 1s.""" 51 | img = morphology.remove_small_objects(img, size) 52 | img = morphology.remove_small_holes(img, size) 53 | return img 54 | 55 | if __name__ == '__main__': 56 | 57 | # Path to csv-file. File should contain X-ray filenames as first column, 58 | # mask filenames as second column. 59 | # Load test data 60 | img_size = (864, 864) 61 | 62 | inp_shape = (864,864,3) 63 | batch_size=1 64 | 65 | # Load model 66 | #model_name = 'model.020.hdf5' 67 | #UNet = load_model(model_name) 68 | net = SegNet(3,1) 69 | net.cuda() 70 | 71 | net.load_state_dict(torch.load('Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot52_0.04634043872356415.pth.tar')) 72 | net.eval() 73 | 74 | 75 | 76 | seed = 1 77 | transformations_test = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 78 | test_set = LungSegVal(transforms = transformations_test) 79 | test_loader = DataLoader(test_set, batch_size=batch_size, shuffle = False) 80 | ious = np.zeros(len(test_loader)) 81 | dices = np.zeros(len(test_loader)) 82 | if not(os.path.exists('./results_Unet_BCE_Dice_InvDice_test')): 83 | os.mkdir('./results_Unet_BCE_Dice_InvDice_test') 84 | 85 | 86 | i = 0 87 | for xx, yy, name in tqdm(test_loader): 88 | xx = xx.cuda() 89 | yy = yy 90 | 91 | name = name[0][:-4] 92 | print (name) 93 | pred = net(xx) 94 | pred = F.sigmoid(pred) 95 | pred = pred.cpu() 96 | pred = pred.detach().numpy()[0,0,:,:] 97 | mask = yy.numpy()[0,0,:,:] 98 | xx = xx.cpu() 99 | xx = xx.numpy()[0,:,:,:].transpose(1,2,0) 100 | img = exposure.rescale_intensity(np.squeeze(xx), out_range=(0,1)) 101 | 102 | # Binarize masks 103 | gt = mask > 0.5 104 | pr = pred > 0.5 105 | 106 | # Remove regions smaller than 2% of the image 107 | #pr = remove_small_regions(pr, 0.02 * np.prod(img_size)) 108 | 109 | 110 | io.imsave('results_Unet_BCE_Dice_InvDice_test/{}.png'.format(name), pr*255) 111 | 112 | ious[i] = IoU(gt, pr) 113 | dices[i] = Dice(gt, pr) 114 | 115 | i += 1 116 | if i == len(test_loader): 117 | break 118 | 119 | print ('Mean IoU:', ious.mean()) 120 | print ('Mean Dice:', dices.mean()) 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /VGG_UNet/code/metric_calc.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import cv2 4 | import os 5 | import sys 6 | from PIL import Image 7 | 8 | 9 | 10 | 11 | def get_IoU(Gi,Si): 12 | #print(Gi.shape, Si.shape) 13 | intersect = 1.0*np.sum(np.logical_and(Gi,Si)) 14 | union = 1.0*np.sum(np.logical_or(Gi,Si)) 15 | return intersect/union 16 | #check cv2.connectedComponents and what it returns, alongwith channels first or last 17 | def generate_list(G,S): 18 | G = G.astype('uint8') 19 | S = S.astype('uint8') 20 | 
#print(np.unique(G))
21 |     gland_obj_cnt,gland_obj = cv2.connectedComponents(G,connectivity=8)
22 |     seg_obj_cnt,seg_obj = cv2.connectedComponents(S,connectivity=8)
23 |     gland_obj_list = []
24 |     seg_obj_list = []
25 |     # Label 0 is the background, so the per-object masks start from label 1
26 |     for i in range(1,gland_obj_cnt):
27 |         gland_obj_list.append( (gland_obj==(i)).astype('int32') )
28 |     for i in range(1,seg_obj_cnt):
29 |         seg_obj_list.append( (seg_obj==(i)).astype('int32') )
30 |     gland_obj_list = np.array(gland_obj_list)
31 |     seg_obj_list = np.array(seg_obj_list)
32 |     return gland_obj_list,seg_obj_list
33 | 
34 | ####Find why channel parameter was passed
35 | def AGI_core(gland_obj_list,seg_obj_list,channel='last'):
36 |     C = 0.0
37 |     U = 0.0
38 |     ##check below:
39 |     '''
40 |     Swapping is not required.
41 |     if(channel=='last'):
42 |         # make channels first
43 |         gland_obj_list = np.swapaxes( np.swapaxes( gland_obj_list , 0,2) , 1 , 2 )
44 |         seg_obj_list = np.swapaxes( np.swapaxes( seg_obj_list , 0,2) , 1 , 2 )
45 |     '''
46 |     #print(gland_obj_list.shape)
47 |     seg_nonused = np.ones(len(seg_obj_list))
48 |     # Greedily match each ground-truth object to its best unused predicted object
49 |     for gi in gland_obj_list:
50 |         iou = np.multiply( [get_IoU(gi,si) for si in seg_obj_list] , seg_nonused )
51 |         j = np.argmax(iou)
52 |         C = C + np.sum(np.logical_and(gi,seg_obj_list[j]) )
53 |         U = U + np.sum(np.logical_or(gi,seg_obj_list[j]) )
54 |         seg_nonused[j] = 0
55 |     # Unmatched predicted objects are pure false positives and only grow the union
56 |     for ind in range(len(seg_obj_list)):
57 |         if((seg_nonused[ind])==1):
58 |             U = U + np.sum(seg_obj_list[ind])
59 |     return C*1./U
60 | 
61 | def Acc_Jacard_Index(G,S):
62 |     gland_obj_list,seg_obj_list = generate_list(G, S)
63 |     print("In AJI and length is:{}".format(len(gland_obj_list)))
64 |     return AGI_core(gland_obj_list,seg_obj_list)
65 | 
66 | #-----------------------------------------------
67 | 
68 | def F1_core(gland_obj_list,seg_obj_list,channel='first'):
69 |     TP,FP,FN = 0.0,0.0,0.0
70 |     if(channel=='last'):
71 |         # make channels first
72 |         gland_obj_list = np.swapaxes( np.swapaxes( gland_obj_list , 0,2) , 1 , 2 )
73 |         seg_obj_list = np.swapaxes( np.swapaxes( seg_obj_list , 0,2) , 1 , 2 )
74 |     seg_nonused = np.ones(len(seg_obj_list))
75 | 
76 |     for ind in range(len(gland_obj_list)):
77 |         gi = gland_obj_list[ind]
78 |         overlap_s = np.multiply( np.sum( seg_obj_list*gi , axis=(1,2) ) , seg_nonused )
79 |         max_ov = np.max(overlap_s)
80 |         percent_overlap = max_ov/np.sum(gi)
81 |         if percent_overlap>=0.01 :
82 |             # hit
83 |             TP = TP +1
84 |             j = np.argmax(overlap_s)
85 |             seg_nonused[j] = 0
86 |         else:
87 |             # unhit
88 |             FN = FN + 1
89 | 
90 |     FP = np.sum(seg_nonused)
91 |     F1_val = (2*TP)/(2*TP + FP + FN)
92 |     return F1_val
93 | 
94 | def F1_score(G,S):
95 |     gland_obj_list,seg_obj_list = generate_list(G, S)
96 |     return F1_core(gland_obj_list,seg_obj_list)
97 | 
98 | def Dice(y_true, y_pred):
99 |     """Returns Dice Similarity Coefficient for ground truth and predicted masks."""
100 |     # Masks are stored as {0, 255}; rescale and cast so they become boolean arrays
101 |     y_true = (np.squeeze(y_true)/255).astype(bool)
102 |     y_pred = (np.squeeze(y_pred)/255).astype(bool)
103 |     intersection = np.logical_and(y_true, y_pred).sum()
104 |     return ((2. * intersection) + 1.) / (y_true.sum() + y_pred.sum() + 1.)
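
# Illustrative usage (the file names here are hypothetical): all three metrics
# expect single-channel masks of shape (H, W, 1) with values in {0, 255}:
#
#     G = np.expand_dims(np.array(Image.open('gt_mask.png').convert('L')), axis=-1)
#     S = np.expand_dims(np.array(Image.open('pred_mask.png').convert('L')), axis=-1)
#     print(Dice(G, S), F1_score(G, S), Acc_Jacard_Index(G, S))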
113 | 
114 | 
115 | 
116 | 
117 | 
118 | 
119 | 
120 | image_names = os.listdir('results_scratch_custom99/')
121 | 
122 | mean_dice = []
123 | mean_F1 = []
124 | aggr_jacard = []
125 | 
126 | 
127 | for images in tqdm(image_names):
128 | 
129 |     S = np.expand_dims(np.array(Image.open('results_scratch_custom99/'+images).convert('L')),axis=-1)
130 |     G = np.expand_dims(np.array(Image.open('/home/sahyadri/Testing/Test_40_y_HE/'+images).convert('L')),axis=-1)
131 |     #aggr_jacard.append(Acc_Jacard_Index(G,S))
132 |     #mean_F1.append(F1_score(G,S))
133 |     mean_dice.append(Dice(G, S))
134 | 
135 | 
136 | print ('Mean_Dice = ', np.mean(np.array(mean_dice)))
137 | 
138 | # Only the Dice metric is computed above; the F1 and AJI lists stay empty unless
139 | # the corresponding lines in the loop are re-enabled, so only Dice is written out.
140 | f = open('lung.txt','w')
141 | a = 'Mean Dice : {}'.format(np.mean(np.array(mean_dice))) + '\n'
142 | f.write(str(a))
143 | f.close()
144 | -------------------------------------------------------------------------------- /VGG_UNet/code/preprocess_combining.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from skimage import io
4 | 
5 | 
6 | def make_masks():
7 |     path = 'test_image/'
8 |     for i, filename in enumerate(os.listdir(path)):
9 |         left = io.imread('test_mask/left/' + filename[:-4] + '.png')
10 |         right = io.imread('test_mask/right/' + filename[:-4] + '.png')
11 |         # Cast to a wider dtype before adding so overlapping uint8 masks cannot wrap around
12 |         combined = np.clip(left.astype(np.uint16) + right.astype(np.uint16), 0, 255).astype(np.uint8)
13 |         io.imsave('test_mask/Mask/' + filename[:-4] + '.png', combined)
14 |         print ('Mask', i, filename)
15 | 
16 | make_masks()
17 | -------------------------------------------------------------------------------- /VGG_UNet/code/preprocess_padding.py: --------------------------------------------------------------------------------
1 | 
2 | import os
3 | 
4 | from PIL import Image
5 | import numpy as np
6 | from tqdm import tqdm
7 | 
8 | 
9 | img_names = os.listdir('test_image/')
10 | 
11 | 
12 | for name in tqdm(img_names):
13 | 
14 |     img = np.asarray(Image.open('test_image/'+name))
15 |     max_dim_img = max(img.shape[0], img.shape[1])
16 |     row_img = max_dim_img - img.shape[0]
17 |     cols_img = max_dim_img - img.shape[1]
18 |     # Split the padding as (n//2, n - n//2) so odd differences still produce a square image
19 |     padded_img = np.pad(img, ((row_img//2, row_img - row_img//2), (cols_img//2, cols_img - cols_img//2)), mode ='constant')
20 |     img = Image.fromarray(padded_img)
21 |     img.save('Images_padded_test/'+name)
22 | 
23 |     mask = np.asarray(Image.open('test_mask/Mask/'+name))
24 |     max_dim_mask = max(mask.shape[0], mask.shape[1])
25 |     row_mask = max_dim_mask - mask.shape[0]
26 |     cols_mask = max_dim_mask - mask.shape[1]
27 |     padded_mask = np.pad(mask, ((row_mask//2, row_mask - row_mask//2), (cols_mask//2, cols_mask - cols_mask//2)), mode ='constant')
28 |     mask = Image.fromarray(padded_mask)
29 |     mask.save('Mask_padded_test/'+name)
30 | 
-------------------------------------------------------------------------------- /VGG_UNet/code/train_model_Unet.py: --------------------------------------------------------------------------------
1 | 
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 | from torch.utils.data import DataLoader
6 | from torchvision import transforms
7 | from torch.autograd import Variable
8 | from build_model import *
9 | import os
10 | from tqdm import tqdm
11 | from tensorboardX 
import SummaryWriter 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 15 | #-------------------------- 16 | class Average(object): 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.sum += val 26 | self.count += n 27 | 28 | #property 29 | def avg(self): 30 | return self.sum / self.count 31 | #------------------------------ 32 | # import csv 33 | writer = SummaryWriter() 34 | #---------------------------------------- 35 | class SoftDiceLoss(nn.Module): 36 | ''' 37 | Soft Dice Loss 38 | ''' 39 | def __init__(self, weight=None, size_average=True): 40 | super(SoftDiceLoss, self).__init__() 41 | 42 | def forward(self, logits, targets): 43 | smooth = 1. 44 | logits = F.sigmoid(logits) 45 | iflat = logits.view(-1) 46 | tflat = targets.view(-1) 47 | intersection = (iflat * tflat).sum() 48 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 49 | #------------------------------------------------- 50 | ''' 51 | class SoftDicescore(nn.Module): 52 | ''' 53 | #Soft Dice Loss 54 | ''' 55 | def __init__(self, weight=None, size_average=True): 56 | super(SoftDicescore, self).__init__() 57 | 58 | def forward(self, logits, targets): 59 | smooth = 1. 60 | logits = F.sigmoid(logits) 61 | iflat = logits.view(-1) 62 | tflat = targets.view(-1) 63 | intersection = (iflat * tflat).sum() 64 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 65 | ''' 66 | #------------------------------------------------- 67 | ''' 68 | class W_bce(nn.Module): 69 | 70 | #weighted crossentropy per image 71 | 72 | def __init__(self, weight=None, size_average=True): 73 | super(W_bce, self).__init__() 74 | 75 | def forward(self, logits, targets): 76 | eps = 1e-6 77 | total_size = targets.view(-1).size()[0] 78 | #print "total_size", total_size 79 | ones_size = torch.sum(targets.view(-1,1)).item() 80 | #print "one_size", ones_size 81 | zero_size = total_size - ones_size 82 | #print "zero_size", zero_size 83 | #assert total_size == (ones_size + zero_size) 84 | #print "crossed assertion" 85 | loss_1 = torch.mean(-(targets.view(-1)* ( total_size/ones_size) * torch.log(torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps))))#.sum(axis=1) 86 | #print "crossed loss1" 87 | loss_0 = torch.mean(-((1.-targets.view(-1))* ( total_size/zero_size) * torch.log((1.-torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps)))))#.sum(axis=1) 88 | #print "crossed loss0" 89 | return loss_1 + loss_0 90 | ''' 91 | #---------------------------------- 92 | class InvSoftDiceLoss(nn.Module): 93 | 94 | ''' 95 | Inverted Soft Dice Loss 96 | ''' 97 | def __init__(self, weight=None, size_average=True): 98 | super(InvSoftDiceLoss, self).__init__() 99 | 100 | def forward(self, logits, targets): 101 | smooth = 1. 102 | logits = F.sigmoid(logits) 103 | iflat = 1-logits.view(-1) 104 | tflat = 1-targets.view(-1) 105 | intersection = (iflat * tflat).sum() 106 | 107 | 108 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 109 | #-------------------------------------- 110 | ''' 111 | class InvSoftDicescore(nn.Module): 112 | 113 | ''' 114 | #Inverted Soft Dice Loss 115 | ''' 116 | 117 | def __init__(self, weight=None, size_average=True): 118 | super(InvSoftDicescore, self).__init__() 119 | 120 | def forward(self, logits, targets): 121 | smooth = 1. 
122 | logits = F.sigmoid(logits) 123 | iflat = 1-logits.view(-1) 124 | tflat = 1-targets.view(-1) 125 | intersection = (iflat * tflat).sum() 126 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 127 | ''' 128 | #---------------------------------------- 129 | ''' 130 | class int_custom_loss(nn.Module): 131 | ''' 132 | #custom loss 133 | ''' 134 | def __init__(self, weight=None, size_average=True): 135 | super(int_custom_loss, self).__init__() 136 | 137 | def forward(self, logits, targets): 138 | loss_inv_dice = InvSoftDicescore() 139 | loss_dice = SoftDicescore() 140 | total_size = targets.view(-1).size()[0] 141 | ones_size = torch.sum(targets.view(-1,1)).item() 142 | th = 0.2 * total_size 143 | if(ones_size > th): 144 | return (- 0.8*torch.log(loss_dice(logits,targets))-0.2*torch.log(loss_inv_dice(logits, targets))) 145 | else: 146 | return(-0.2*torch.log(loss_dice(logits, targets))-0.8*torch.log(loss_inv_dice(logits, targets))) 147 | ''' 148 | ''' 149 | class weighted_dice_invdice(nn.Module): 150 | ''' 151 | #custom loss 152 | ''' 153 | def __init__(self, weight=None, size_average=True): 154 | super(weighted_dice_invdice, self).__init__() 155 | 156 | def forward(self, logits, targets): 157 | loss_inv_dice = InvSoftDicescore() 158 | loss_dice = SoftDicescore() 159 | total_size = targets.view(-1).size()[0] 160 | ones_size = torch.sum(targets.view(-1,1)).item() 161 | zero_size = total_size - ones_size 162 | th = 0.2 * total_size 163 | return (-(zero_size/total_size)*torch.log(loss_dice(logits,targets))-(ones_size/total_size)*torch.log(loss_inv_dice(logits, targets))) 164 | ''' 165 | 166 | #Tranformations------------------------------------------------ 167 | transformations_train = transforms.Compose([transforms.Resize((1024,1024)),transforms.ToTensor()]) 168 | 169 | transformations_val = transforms.Compose([transforms.Resize((1024,1024)),transforms.ToTensor()]) 170 | #------------------------------------------------------------- 171 | 172 | from data_loader import LungSegTrain 173 | from data_loader import LungSegVal 174 | train_set = LungSegTrain(transforms = transformations_train) 175 | batch_size = 1 176 | num_epochs = 75 177 | 178 | def train(): 179 | cuda = torch.cuda.is_available() 180 | net = UNet(3,1) 181 | if cuda: 182 | net = net.cuda() 183 | #net.load_state_dict(torch.load('Weights_BCE_Dice/cp_bce_lr_05_100_0.222594484687.pth.tar')) 184 | criterion1 = nn.BCEWithLogitsLoss().cuda() 185 | criterion2 = SoftDiceLoss().cuda() 186 | criterion3 = InvSoftDiceLoss().cuda() 187 | #criterion4 = W_bce().cuda() 188 | #criterion5 = int_custom_loss() 189 | #criterion6 = weighted_dice_invdice() 190 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) 191 | #scheduler = MultiStepLR(optimizer, milestones=[2,10,75,100], gamma=0.1) 192 | 193 | print("preparing training data ...") 194 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 195 | print("done ...") 196 | val_set = LungSegVal(transforms = transformations_val) 197 | val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False) 198 | for epoch in tqdm(range(num_epochs)): 199 | #scheduler.step() 200 | train_loss = Average() 201 | net.train() 202 | for i, (images, masks) in tqdm(enumerate(train_loader)): 203 | images = Variable(images) 204 | masks = Variable(masks) 205 | if cuda: 206 | images = images.cuda() 207 | masks = masks.cuda() 208 | 209 | optimizer.zero_grad() 210 | outputs = net(images) 211 | #writer.add_image('Training Input',images) 212 | #writer.add_image('Training 
Pred',F.sigmoid(outputs)>0.5) 213 | c1 = criterion1(outputs,masks) + criterion2(outputs, masks) + criterion3(outputs, masks) 214 | loss = c1 215 | writer.add_scalar('Train Loss',loss,epoch) 216 | loss.backward() 217 | optimizer.step() 218 | train_loss.update(loss.item(), images.size(0)) 219 | for param_group in optimizer.param_groups: 220 | writer.add_scalar('Learning Rate',param_group['lr']) 221 | 222 | val_loss1 = Average() 223 | val_loss2 = Average() 224 | val_loss3 = Average() 225 | net.eval() 226 | for images, masks,_ in tqdm(val_loader): 227 | images = Variable(images) 228 | masks = Variable(masks) 229 | if cuda: 230 | images = images.cuda() 231 | masks = masks.cuda() 232 | 233 | outputs = net(images) 234 | if (epoch)%10==0: 235 | writer.add_image('Validation Input',images,epoch) 236 | writer.add_image('Validation GT ',masks,epoch) 237 | writer.add_image('Validation Pred0.5',F.sigmoid(outputs)>0.5,epoch) 238 | writer.add_image('Validation Pred0.3',F.sigmoid(outputs)>0.3,epoch) 239 | writer.add_image('Validation Pred0.65',F.sigmoid(outputs)>0.65,epoch) 240 | 241 | vloss1 = criterion1(outputs, masks) 242 | vloss2 = criterion2(outputs, masks) 243 | vloss3 = criterion3(outputs, masks) #+ criterion2(outputs, masks) 244 | #vloss = vloss2 + vloss3 245 | writer.add_scalar('Validation loss(BCE)',vloss1,epoch) 246 | writer.add_scalar('Validation loss(Dice)',vloss2,epoch) 247 | writer.add_scalar('Validation loss(InvDice)',vloss3,epoch) 248 | 249 | val_loss1.update(vloss1.item(), images.size(0)) 250 | val_loss2.update(vloss2.item(), images.size(0)) 251 | val_loss3.update(vloss3.item(), images.size(0)) 252 | 253 | print("Epoch {}, Training Loss(BCE+Dice): {}, Validation Loss(BCE): {}, Validation Loss(Dice): {}, Validation Loss(InvDice): {}".format(epoch+1, train_loss.avg(), val_loss1.avg(), val_loss2.avg(), val_loss3.avg())) 254 | 255 | # with open('Log.csv', 'a') as logFile: 256 | # FileWriter = csv.writer(logFile) 257 | # FileWriter.writerow([epoch+1, train_loss.avg, val_loss1.avg, val_loss2.avg, val_loss3.avg]) 258 | 259 | torch.save(net.state_dict(), 'Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot{}_{}.pth.tar'.format(epoch+1, val_loss2.avg())) 260 | return net 261 | 262 | def test(model): 263 | model.eval() 264 | 265 | 266 | 267 | if __name__ == "__main__": 268 | train() 269 | -------------------------------------------------------------------------------- /VGG_UNet/code/train_model_segnet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torch.autograd import Variable 8 | from build_model import * 9 | import os 10 | from tqdm import tqdm 11 | from tensorboardX import SummaryWriter 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | 14 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 15 | #-------------------------- 16 | class Average(object): 17 | def __init__(self): 18 | self.reset() 19 | 20 | def reset(self): 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.sum += val 26 | self.count += n 27 | 28 | #property 29 | def avg(self): 30 | return self.sum / self.count 31 | #------------------------------ 32 | # import csv 33 | writer = SummaryWriter() 34 | #---------------------------------------- 35 | class SoftDiceLoss(nn.Module): 36 | ''' 37 | Soft Dice Loss 38 | ''' 39 | def __init__(self, weight=None, size_average=True): 40 | super(SoftDiceLoss, 
self).__init__() 41 | 42 | def forward(self, logits, targets): 43 | smooth = 1. 44 | logits = F.sigmoid(logits) 45 | iflat = logits.view(-1) 46 | tflat = targets.view(-1) 47 | intersection = (iflat * tflat).sum() 48 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 49 | #------------------------------------------------- 50 | ''' 51 | class SoftDicescore(nn.Module): 52 | ''' 53 | #Soft Dice Loss 54 | ''' 55 | def __init__(self, weight=None, size_average=True): 56 | super(SoftDicescore, self).__init__() 57 | 58 | def forward(self, logits, targets): 59 | smooth = 1. 60 | logits = F.sigmoid(logits) 61 | iflat = logits.view(-1) 62 | tflat = targets.view(-1) 63 | intersection = (iflat * tflat).sum() 64 | return ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 65 | ''' 66 | #------------------------------------------------- 67 | ''' 68 | class W_bce(nn.Module): 69 | 70 | #weighted crossentropy per image 71 | 72 | def __init__(self, weight=None, size_average=True): 73 | super(W_bce, self).__init__() 74 | 75 | def forward(self, logits, targets): 76 | eps = 1e-6 77 | total_size = targets.view(-1).size()[0] 78 | #print "total_size", total_size 79 | ones_size = torch.sum(targets.view(-1,1)).item() 80 | #print "one_size", ones_size 81 | zero_size = total_size - ones_size 82 | #print "zero_size", zero_size 83 | #assert total_size == (ones_size + zero_size) 84 | #print "crossed assertion" 85 | loss_1 = torch.mean(-(targets.view(-1)* ( total_size/ones_size) * torch.log(torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps))))#.sum(axis=1) 86 | #print "crossed loss1" 87 | loss_0 = torch.mean(-((1.-targets.view(-1))* ( total_size/zero_size) * torch.log((1.-torch.clamp(F.sigmoid(logits).view(-1),eps,1.-eps)))))#.sum(axis=1) 88 | #print "crossed loss0" 89 | return loss_1 + loss_0 90 | ''' 91 | #---------------------------------- 92 | class InvSoftDiceLoss(nn.Module): 93 | 94 | ''' 95 | Inverted Soft Dice Loss 96 | ''' 97 | def __init__(self, weight=None, size_average=True): 98 | super(InvSoftDiceLoss, self).__init__() 99 | 100 | def forward(self, logits, targets): 101 | smooth = 1. 102 | logits = F.sigmoid(logits) 103 | iflat = 1-logits.view(-1) 104 | tflat = 1-targets.view(-1) 105 | intersection = (iflat * tflat).sum() 106 | 107 | 108 | return 1 - ((2. * intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 109 | #-------------------------------------- 110 | ''' 111 | class InvSoftDicescore(nn.Module): 112 | 113 | ''' 114 | #Inverted Soft Dice Loss 115 | ''' 116 | 117 | def __init__(self, weight=None, size_average=True): 118 | super(InvSoftDicescore, self).__init__() 119 | 120 | def forward(self, logits, targets): 121 | smooth = 1. 122 | logits = F.sigmoid(logits) 123 | iflat = 1-logits.view(-1) 124 | tflat = 1-targets.view(-1) 125 | intersection = (iflat * tflat).sum() 126 | return ((2. 
* intersection + smooth) /(iflat.sum() + tflat.sum() + smooth)) 127 | ''' 128 | #---------------------------------------- 129 | ''' 130 | class int_custom_loss(nn.Module): 131 | ''' 132 | #custom loss 133 | ''' 134 | def __init__(self, weight=None, size_average=True): 135 | super(int_custom_loss, self).__init__() 136 | 137 | def forward(self, logits, targets): 138 | loss_inv_dice = InvSoftDicescore() 139 | loss_dice = SoftDicescore() 140 | total_size = targets.view(-1).size()[0] 141 | ones_size = torch.sum(targets.view(-1,1)).item() 142 | th = 0.2 * total_size 143 | if(ones_size > th): 144 | return (- 0.8*torch.log(loss_dice(logits,targets))-0.2*torch.log(loss_inv_dice(logits, targets))) 145 | else: 146 | return(-0.2*torch.log(loss_dice(logits, targets))-0.8*torch.log(loss_inv_dice(logits, targets))) 147 | ''' 148 | ''' 149 | class weighted_dice_invdice(nn.Module): 150 | ''' 151 | #custom loss 152 | ''' 153 | def __init__(self, weight=None, size_average=True): 154 | super(weighted_dice_invdice, self).__init__() 155 | 156 | def forward(self, logits, targets): 157 | loss_inv_dice = InvSoftDicescore() 158 | loss_dice = SoftDicescore() 159 | total_size = targets.view(-1).size()[0] 160 | ones_size = torch.sum(targets.view(-1,1)).item() 161 | zero_size = total_size - ones_size 162 | th = 0.2 * total_size 163 | return (-(zero_size/total_size)*torch.log(loss_dice(logits,targets))-(ones_size/total_size)*torch.log(loss_inv_dice(logits, targets))) 164 | ''' 165 | 166 | #Tranformations------------------------------------------------ 167 | transformations_train = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 168 | 169 | transformations_val = transforms.Compose([transforms.Resize((864,864)),transforms.ToTensor()]) 170 | #------------------------------------------------------------- 171 | 172 | from data_loader import LungSegTrain 173 | from data_loader import LungSegVal 174 | train_set = LungSegTrain(transforms = transformations_train) 175 | batch_size = 1 176 | num_epochs = 75 177 | 178 | def train(): 179 | cuda = torch.cuda.is_available() 180 | net = SegNet(3,1) 181 | if cuda: 182 | net = net.cuda() 183 | #net.load_state_dict(torch.load('Weights_BCE_Dice/cp_bce_lr_05_100_0.222594484687.pth.tar')) 184 | criterion1 = nn.BCEWithLogitsLoss().cuda() 185 | criterion2 = SoftDiceLoss().cuda() 186 | criterion3 = InvSoftDiceLoss().cuda() 187 | #criterion4 = W_bce().cuda() 188 | #criterion5 = int_custom_loss() 189 | #criterion6 = weighted_dice_invdice() 190 | optimizer = torch.optim.Adam(net.parameters(), lr=1e-4) 191 | #scheduler = MultiStepLR(optimizer, milestones=[2,10,75,100], gamma=0.1) 192 | 193 | print("preparing training data ...") 194 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True) 195 | print("done ...") 196 | val_set = LungSegVal(transforms = transformations_val) 197 | val_loader = DataLoader(val_set, batch_size=batch_size,shuffle=False) 198 | for epoch in tqdm(range(num_epochs)): 199 | #scheduler.step() 200 | train_loss = Average() 201 | net.train() 202 | for i, (images, masks) in tqdm(enumerate(train_loader)): 203 | images = Variable(images) 204 | masks = Variable(masks) 205 | if cuda: 206 | images = images.cuda() 207 | masks = masks.cuda() 208 | 209 | optimizer.zero_grad() 210 | outputs = net(images) 211 | #writer.add_image('Training Input',images) 212 | #writer.add_image('Training Pred',F.sigmoid(outputs)>0.5) 213 | c1 = criterion1(outputs,masks) + criterion2(outputs, masks) + criterion3(outputs, masks) 214 | loss = c1 215 | 
writer.add_scalar('Train Loss',loss,epoch) 216 | loss.backward() 217 | optimizer.step() 218 | train_loss.update(loss.item(), images.size(0)) 219 | for param_group in optimizer.param_groups: 220 | writer.add_scalar('Learning Rate',param_group['lr']) 221 | val_loss1 = Average() 222 | val_loss2 = Average() 223 | val_loss3 = Average() 224 | net.eval() 225 | for images, masks,_ in tqdm(val_loader): 226 | images = Variable(images) 227 | masks = Variable(masks) 228 | if cuda: 229 | images = images.cuda() 230 | masks = masks.cuda() 231 | 232 | outputs = net(images) 233 | if (epoch)%10==0: 234 | writer.add_image('Validation Input',images,epoch) 235 | writer.add_image('Validation GT ',masks,epoch) 236 | writer.add_image('Validation Pred0.5',F.sigmoid(outputs)>0.5,epoch) 237 | writer.add_image('Validation Pred0.3',F.sigmoid(outputs)>0.3,epoch) 238 | writer.add_image('Validation Pred0.65',F.sigmoid(outputs)>0.65,epoch) 239 | 240 | vloss1 = criterion1(outputs, masks) 241 | vloss2 = criterion2(outputs, masks) 242 | vloss3 = criterion3(outputs, masks) #+ criterion2(outputs, masks) 243 | #vloss = vloss2 + vloss3 244 | writer.add_scalar('Validation loss(BCE)',vloss1,epoch) 245 | writer.add_scalar('Validation loss(Dice)',vloss2,epoch) 246 | writer.add_scalar('Validation loss(InvDice)',vloss3,epoch) 247 | 248 | val_loss1.update(vloss1.item(), images.size(0)) 249 | val_loss2.update(vloss2.item(), images.size(0)) 250 | val_loss3.update(vloss3.item(), images.size(0)) 251 | 252 | print("Epoch {}, Training Loss(BCE+Dice): {}, Validation Loss(BCE): {}, Validation Loss(Dice): {}, Validation Loss(InvDice): {}".format(epoch+1, train_loss.avg(), val_loss1.avg(), val_loss2.avg(), val_loss3.avg())) 253 | 254 | # with open('Log.csv', 'a') as logFile: 255 | # FileWriter = csv.writer(logFile) 256 | # FileWriter.writerow([epoch+1, train_loss.avg, val_loss1.avg, val_loss2.avg, val_loss3.avg]) 257 | 258 | torch.save(net.state_dict(), 'Weights_BCE_Dice_InvDice/cp_bce_flip_lr_04_no_rot{}_{}.pth.tar'.format(epoch+1, val_loss2.avg())) 259 | return net 260 | 261 | def test(model): 262 | model.eval() 263 | 264 | 265 | 266 | if __name__ == "__main__": 267 | train() 268 | -------------------------------------------------------------------------------- /VGG_UNet/code/unet_parts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # sub-parts of the U-Net model 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class double_conv(nn.Module): 11 | '''(conv => BN => ReLU) * 2''' 12 | def __init__(self, in_ch, out_ch): 13 | super(double_conv, self).__init__() 14 | self.conv = nn.Sequential( 15 | nn.Conv2d(in_ch, out_ch, 3, padding=1), 16 | nn.BatchNorm2d(out_ch), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(out_ch, out_ch, 3, padding=1), 19 | nn.BatchNorm2d(out_ch), 20 | nn.ReLU(inplace=True) 21 | ) 22 | 23 | def forward(self, x): 24 | x = self.conv(x) 25 | return x 26 | 27 | 28 | class inconv(nn.Module): 29 | def __init__(self, in_ch, out_ch): 30 | super(inconv, self).__init__() 31 | self.conv = double_conv(in_ch, out_ch) 32 | 33 | def forward(self, x): 34 | x = self.conv(x) 35 | return x 36 | 37 | 38 | class down(nn.Module): 39 | def __init__(self, in_ch, out_ch): 40 | super(down, self).__init__() 41 | self.mpconv = nn.Sequential( 42 | nn.MaxPool2d(2), 43 | double_conv(in_ch, out_ch) 44 | ) 45 | 46 | def forward(self, x): 47 | x = self.mpconv(x) 48 | return x 49 | 50 | 51 | class up(nn.Module): 52 | def __init__(self, in_ch, out_ch, 
bilinear=True): 53 | super(up, self).__init__() 54 | 55 | # would be a nice idea if the upsampling could be learned too, 56 | # but my machine do not have enough memory to handle all those weights 57 | if bilinear: 58 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 59 | else: 60 | self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) 61 | 62 | self.conv = double_conv(in_ch, out_ch) 63 | 64 | def forward(self, x1, x2): 65 | x1 = self.up(x1) 66 | diffX = x1.size()[2] - x2.size()[2] 67 | diffY = x1.size()[3] - x2.size()[3] 68 | x2 = F.pad(x2, (diffX // 2, int(diffX / 2), 69 | diffY // 2, int(diffY / 2))) 70 | x = torch.cat([x2, x1], dim=1) 71 | x = self.conv(x) 72 | return x 73 | 74 | 75 | class outconv(nn.Module): 76 | def __init__(self, in_ch, out_ch): 77 | super(outconv, self).__init__() 78 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | return x 83 | --------------------------------------------------------------------------------