├── LICENSE ├── Plots ├── Loss_plots.png ├── Train_dice_plots.png ├── Val_dice_plots.png ├── find_acc.ipynb ├── luna_inference.py └── new_test.py ├── README.md ├── data_prep ├── LUNA_mean_std.ipynb ├── luna_mask.ipynb ├── test_pts.npy ├── train_pts.npy └── val_pts.npy └── train_codes ├── LUNA_loader.py ├── lovasz_losses.py └── train_sumnet_luna_CE_Lov.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Rakshith Sathish 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Plots/Loss_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rakshith2597/Lung-nodule-detection-LUNA-16/4114784ad6c06467db8c106c266a1f128016f920/Plots/Loss_plots.png -------------------------------------------------------------------------------- /Plots/Train_dice_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rakshith2597/Lung-nodule-detection-LUNA-16/4114784ad6c06467db8c106c266a1f128016f920/Plots/Train_dice_plots.png -------------------------------------------------------------------------------- /Plots/Val_dice_plots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rakshith2597/Lung-nodule-detection-LUNA-16/4114784ad6c06467db8c106c266a1f128016f920/Plots/Val_dice_plots.png -------------------------------------------------------------------------------- /Plots/find_acc.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import pandas as pd\n", 11 | "from glob import glob\n", 12 | "import os\n", 13 | "import torch\n", 14 | "import SimpleITK as sitk\n", 15 | "from SUMNet_bn import SUMNet\n", 16 | "from torchvision import transforms\n", 17 | "import torch.nn.functional as F\n", 18 | "import cv2\n", 19 | "from tqdm import tqdm_notebook as tq" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "def load_itk_image(filename):\n", 29 | " itkimage = sitk.ReadImage(filename)\n", 30 | " numpyImage = sitk.GetArrayFromImage(itkimage)\n", 31 | " \n", 32 | " numpyOrigin = 
np.array(list(reversed(itkimage.GetOrigin())))\n", 33 | " numpySpacing = np.array(list(reversed(itkimage.GetSpacing())))\n", 34 | " return numpyImage, numpyOrigin, numpySpacing" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "seg_model_loadPath = '/home/siplab/rachana/rak/Results/SUMNet_new/Adam_1e-4_ep100_CE+Lov/'\n", 44 | "netS = SUMNet(in_ch=1,out_ch=2)\n", 45 | "netS.load_state_dict(torch.load(seg_model_loadPath+'sumnet_best.pt'))\n", 46 | "netS = netS.cuda()\n", 47 | "apply_norm = transforms.Normalize([-460.466],[444.421]) " 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 14, 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Subset: 3\n" 60 | ] 61 | }, 62 | { 63 | "data": { 64 | "application/vnd.jupyter.widget-view+json": { 65 | "model_id": "a52ffa83aeed4d4bb601e285c04e9fa9", 66 | "version_major": 2, 67 | "version_minor": 0 68 | }, 69 | "text/plain": [ 70 | "A Jupyter Widget" 71 | ] 72 | }, 73 | "metadata": {}, 74 | "output_type": "display_data" 75 | }, 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "cand_path = \"/home/siplab/rachana/rak/dataset/candidates.csv\"\n", 86 | "b_sz = 8\n", 87 | "df_node = pd.read_csv(cand_path)\n", 88 | "subset = ['3']#,'5']\n", 89 | "running_correct = 0\n", 90 | "count = 0\n", 91 | "\n", 92 | "orig_list = []\n", 93 | "pred_list = []\n", 94 | "for s in subset:\n", 95 | " print('Subset:',s)\n", 96 | " luna_subset_path = '/home/siplab/rachana/rak/dataset/subset'+str(s)+'/' \n", 97 | " all_files = os.listdir(luna_subset_path)\n", 98 | " mhd_files = []\n", 99 | " for f in all_files:\n", 100 | " if '.mhd' in f:\n", 101 | " mhd_files.append(f)\n", 102 | " count = 0\n", 103 | " for m in tq(mhd_files): \n", 104 | " mini_df = df_node[df_node[\"seriesuid\"]==m[:-4]]\n", 105 | " 
itk_img = sitk.ReadImage(luna_subset_path+m) \n", 106 | " img_array = sitk.GetArrayFromImage(itk_img)\n", 107 | " origin = np.array(itk_img.GetOrigin()) # x,y,z Origin in world coordinates (mm)\n", 108 | " spacing = np.array(itk_img.GetSpacing()) \n", 109 | " slice_list = []\n", 110 | " if len(mini_df)>0:\n", 111 | " for i in range(len(mini_df)):\n", 112 | " fName = mini_df['seriesuid'].values[i]\n", 113 | " z_coord = mini_df['coordZ'].values[i]\n", 114 | " orig_class = mini_df['class'].values[i]\n", 115 | " pred = 0\n", 116 | " v_center =np.rint((z_coord-origin[2])/spacing[2]) \n", 117 | " img_slice = img_array[int(v_center)]\n", 118 | " mid_mean = img_slice[100:400,100:400].mean() \n", 119 | " img_slice[img_slice==img_slice.min()] = mid_mean\n", 120 | " img_slice[img_slice==img_slice.max()] = mid_mean\n", 121 | " img_slice_tensor = torch.from_numpy(img_slice).unsqueeze(0).float()\n", 122 | " img_slice_norm = apply_norm(img_slice_tensor).unsqueeze(0)\n", 123 | " \n", 124 | " out = F.softmax(netS(img_slice_norm.cuda()),dim=1)\n", 125 | " out_np = np.asarray(out[0,1].squeeze(0).detach().cpu().numpy()*255,dtype=np.uint8)\n", 126 | "\n", 127 | " ret, thresh = cv2.threshold(out_np,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU)\n", 128 | " connectivity = 4 \n", 129 | " output = cv2.connectedComponentsWithStats(thresh, connectivity, cv2.CV_32S)\n", 130 | " stats = output[2]\n", 131 | " temp = stats[1:, cv2.CC_STAT_AREA]\n", 132 | " if len(temp)>0:\n", 133 | " largest_label = 1 + np.argmax(temp) \n", 134 | " areas = stats[1:, cv2.CC_STAT_AREA]\n", 135 | " max_area = np.max(areas)\n", 136 | " if max_area>150:\n", 137 | " pred = 1\n", 138 | " if pred == orig_class: \n", 139 | " running_correct += 1\n", 140 | " pred_list.append(pred)\n", 141 | " orig_list.append(orig_class)\n", 142 | " count += 1 " 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 15, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 
153 | "text": [ 154 | "Accuarcy: 92.38076137689615\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "print('Accuarcy:',(running_correct/count)*100)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 18, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "from sklearn.metrics import confusion_matrix" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "cf = confusion_matrix(orig_list, pred_list)\n", 178 | "tn, fp, fn, tp = cf.ravel()" 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "execution_count": 20, 184 | "metadata": {}, 185 | "outputs": [ 186 | { 187 | "data": { 188 | "text/plain": [ 189 | "array([[50641, 4066],\n", 190 | " [ 113, 28]])" 191 | ] 192 | }, 193 | "execution_count": 20, 194 | "metadata": {}, 195 | "output_type": "execute_result" 196 | } 197 | ], 198 | "source": [ 199 | "cf" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 23, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 25, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "Sensitivity: 0.19858156028368795\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "sensitivity = tp/(tp+fn)\n", 224 | "print('Sensitivity:',sensitivity)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": 29, 230 | "metadata": {}, 231 | "outputs": [ 232 | { 233 | "name": "stdout", 234 | "output_type": "stream", 235 | "text": [ 236 | "Specificity: 0.9256767872484326\n" 237 | ] 238 | } 239 | ], 240 | "source": [ 241 | "specificity = tn/(tn+fp)\n", 242 | "print('Specificity:',specificity)" 243 | ] 244 | }, 245 | { 246 | "cell_type": "code", 247 | "execution_count": 30, 248 | "metadata": {}, 249 | "outputs": [ 250 | { 251 | "name": "stdout", 252 | "output_type": 
"stream", 253 | "text": [ 254 | "Precision: 0.006839276990718124\n" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "precision = tp/(tp+fp)\n", 260 | "print('Precision:',precision)" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [] 269 | } 270 | ], 271 | "metadata": { 272 | "kernelspec": { 273 | "display_name": "Python 3", 274 | "language": "python", 275 | "name": "python3" 276 | }, 277 | "language_info": { 278 | "codemirror_mode": { 279 | "name": "ipython", 280 | "version": 3 281 | }, 282 | "file_extension": ".py", 283 | "mimetype": "text/x-python", 284 | "name": "python", 285 | "nbconvert_exporter": "python", 286 | "pygments_lexer": "ipython3", 287 | "version": "3.6.6" 288 | } 289 | }, 290 | "nbformat": 4, 291 | "nbformat_minor": 2 292 | } 293 | -------------------------------------------------------------------------------- /Plots/luna_inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from SUMNet_bn import SUMNet 7 | import SimpleITK as sitk 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | from glob import glob 11 | import os 12 | 13 | def load_itk_image(filename): 14 | itkimage = sitk.ReadImage(filename) 15 | numpyImage = sitk.GetArrayFromImage(itkimage) 16 | 17 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 18 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 19 | return numpyImage, numpyOrigin, numpySpacing 20 | 21 | luna_subset_path = '/home/siplab/rachana/rak/dataset/subset3/' 22 | img_file = '/home/siplab/rachana/rak/dataset/subset3/1.3.6.1.4.1.14519.5.2.1.6279.6001.292057261351416339496913597985.mhd' 23 | itk_img = sitk.ReadImage(img_file) 24 | img_tensor = torch.from_numpy(sitk.GetArrayFromImage(itk_img)).unsqueeze(1).float() 25 | # normalization to 
[0,1] 26 | img_tensor_norm = img_tensor-img_tensor.min() 27 | img_tensor_norm = img_tensor_norm/img_tensor_norm.max() 28 | 29 | 30 | seg_model_loadPath = '/home/siplab/rachana/rak/Results/SUMNet/Adam_1e-4_ep100/' 31 | netS = SUMNet(in_ch=1,out_ch=2) 32 | netS.load_state_dict(torch.load(seg_model_loadPath+'sumnet_cpu.pt')) 33 | 34 | # netS = netS.cuda() 35 | savePath = seg_model_loadPath+'seg_results/' 36 | if not os.path.isdir(savePath): 37 | os.makedirs(savePath) 38 | 39 | for sliceNum in range(img_tensor_norm.shape[0]): 40 | img_slice = img_tensor_norm[sliceNum].unsqueeze(0) 41 | out = F.softmax(netS(img_slice),dim=1) 42 | 43 | plt.figure(figsize=[15,5]) 44 | plt.subplot(121) 45 | plt.imshow(img_slice.squeeze(0).squeeze(0).numpy(),cmap='gray') 46 | plt.title('Original Image') 47 | plt.subplot(122) 48 | plt.imshow(out[0,1].squeeze(0).detach().numpy(),cmap='gray') 49 | plt.title('Segmented Nodules') 50 | plt.savefig(savePath+'results_slice_'+str(sliceNum)+'.png') 51 | plt.close() 52 | -------------------------------------------------------------------------------- /Plots/new_test.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from SUMNet_bn import SUMNet 7 | import SimpleITK as sitk 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | from glob import glob 11 | from torchvision import transforms 12 | 13 | def load_itk_image(filename): 14 | itkimage = sitk.ReadImage(filename) 15 | numpyImage = sitk.GetArrayFromImage(itkimage) 16 | 17 | numpyOrigin = np.array(list(reversed(itkimage.GetOrigin()))) 18 | numpySpacing = np.array(list(reversed(itkimage.GetSpacing()))) 19 | return numpyImage, numpyOrigin, numpySpacing 20 | 21 | def get_filename(case): 22 | global file_list 23 | for f in file_list: 24 | if case in f: 25 | return(f) 26 | 27 | 28 | luna_subset_path = 
'/home/siplab/rachana/rak/dataset/subset3/' 29 | result_path = '/home/siplab/rachana/rak/img_results/' 30 | img_file = '/home/siplab/rachana/rak/dataset/subset3/1.3.6.1.4.1.14519.5.2.1.6279.6001.244681063194071446501270815660.mhd' 31 | itk_img = sitk.ReadImage(img_file) 32 | img_tensor = torch.from_numpy(sitk.GetArrayFromImage(itk_img)).unsqueeze(1).float() 33 | 34 | 35 | seg_model_loadPath = '/home/siplab/rachana/rak/Results/SUMNet/Adam_1e-4_ep100/' 36 | netS = SUMNet(in_ch=1,out_ch=2) 37 | netS.load_state_dict(torch.load(seg_model_loadPath+'sumnet_best.pt')) 38 | 39 | apply_norm = transforms.Normalize([-460.466],[444.421]) 40 | N = int(img_tensor.shape[0]*0.5) 41 | for sliceNum in range(N-5,N+5): 42 | img_slice = img_tensor[sliceNum] 43 | mid_mean = img_slice[:,100:400,100:400].mean() 44 | img_slice[img_slice==img_slice.min()] = mid_mean 45 | img_slice[img_slice==img_slice.max()] = mid_mean 46 | img_slice_norm = apply_norm(img_slice).unsqueeze(0) 47 | 48 | out = F.softmax(netS(img_slice_norm),dim=1) 49 | out_np = np.asarray(out[0,1].squeeze(0).detach().cpu().numpy()*255,dtype=np.uint8) 50 | 51 | ret, thresh = cv2.threshold(out_np,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU) 52 | connectivity = 4 53 | output = cv2.connectedComponentsWithStats(thresh, connectivity, cv2.CV_32S) 54 | stats = output[2] 55 | temp = stats[1:, cv2.CC_STAT_AREA] 56 | if len(temp)>0: 57 | largest_label = 1 + np.argmax(temp) 58 | areas = stats[1:, cv2.CC_STAT_AREA] 59 | max_area = np.max(areas) 60 | if max_area>150: 61 | print('Slice:',sliceNum+1) 62 | out_mask = np.zeros((512,512)) 63 | idx = np.where(output[1]==largest_label) 64 | out_mask[idx] = 1 65 | plt.figure(figsize=[15,5]) 66 | plt.subplot(131) 67 | plt.imshow(img_slice.squeeze(0).squeeze(0).numpy(),cmap='gray') 68 | plt.title('Original image') 69 | plt.subplot(132) 70 | plt.imshow(out[0,1].squeeze(0).detach().numpy(),cmap='gray') 71 | plt.title('Segmented regions') 72 | plt.subplot(133) 73 | plt.imshow(out_mask,cmap='gray') 74 | 
plt.title('Detected largest nodule') 75 | plt.savefig(result_path+'slice_'+str(sliceNum+1)+'.png') 76 | plt.close() 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lung-nodule-detection-LUNA-16 2 | 3 | This GitHub repository contains the code written for my Bachelor of Technology (B.Tech.) main project. The purpose of this code is to detect nodules in lung CT scans and subsequently classify them as benign or malignant. 4 | 5 | Abstract: 6 | 7 | Lung cancer is one of the leading causes of cancer-related death in the world. Early detection of the tumor is 8 | a crucial part of giving patients the best chance of recovery. However, the analysis and cure of lung malignancy have been among 9 | the greatest challenges of the last few decades. Deep learning allows us to increase the 10 | accuracy of the automated initial diagnosis. This project uses a network with features of the U-Net architecture to classify cancer nodules as benign or malignant with an accuracy of 92.38% and a low false-positive rate (<10%). These figures come from Plots/find_acc.ipynb, which evaluates the candidates of LUNA16 subset 3; because that set is heavily imbalanced, the same run reports a sensitivity of 19.9% and a specificity of 92.6%. 11 | 12 | Dataset used: LUNA16 13 | 14 | The trained model and results have not been uploaded to the repo due to their size. 15 | 16 | The front end of the CAD system is in the repository below: 17 | https://github.com/Soumya-Raj/Main-project 18 | 19 | Repo organization: 20 | 21 | data_prep: contains the code used to prepare the LUNA16 dataset for training. 22 | 23 | train_codes: contains the code used to train the network. 24 | 25 | Plots: contains the scripts used to evaluate the network, along with the loss and Dice plots.
26 | 27 | -------------------------------------------------------------------------------- /data_prep/LUNA_mean_std.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import os\n", 11 | "from tqdm import tqdm_notebook as tq" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "loadPath = '../data/train/images/'\n", 21 | "trFiles = os.listdir(loadPath)\n", 22 | "print(len(trFiles))" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "temp = np.load(loadPath+trFiles[0])\n", 32 | "print(temp.shape)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "totalData = np.zeros((len(trFiles),512,512))\n", 42 | "for n in tq(range(len(trFiles))):\n", 43 | " totalData[n] = np.load(loadPath+trFiles[n])" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "print('Mean:',totalData.mean(),'Std:',totalData.std())" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "print('Min:',totalData.min(),'Max:',totalData.max())" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [] 70 | } 71 | ], 72 | "metadata": { 73 | "kernelspec": { 74 | "display_name": "Python 3", 75 | "language": "python", 76 | "name": "python3" 77 | }, 78 | "language_info": { 79 | "codemirror_mode": { 80 | "name": "ipython", 81 | "version": 3 82 | }, 83 | "file_extension": ".py", 84 | "mimetype": 
"text/x-python", 85 | "name": "python", 86 | "nbconvert_exporter": "python", 87 | "pygments_lexer": "ipython3", 88 | "version": "3.6.8" 89 | } 90 | }, 91 | "nbformat": 4, 92 | "nbformat_minor": 2 93 | } 94 | -------------------------------------------------------------------------------- /data_prep/test_pts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rakshith2597/Lung-nodule-detection-LUNA-16/4114784ad6c06467db8c106c266a1f128016f920/data_prep/test_pts.npy -------------------------------------------------------------------------------- /data_prep/train_pts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rakshith2597/Lung-nodule-detection-LUNA-16/4114784ad6c06467db8c106c266a1f128016f920/data_prep/train_pts.npy -------------------------------------------------------------------------------- /data_prep/val_pts.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rakshith2597/Lung-nodule-detection-LUNA-16/4114784ad6c06467db8c106c266a1f128016f920/data_prep/val_pts.npy -------------------------------------------------------------------------------- /train_codes/LUNA_loader.py: -------------------------------------------------------------------------------- 1 | #Code written by Rakshith Sathish 2 | #The work is made public with MIT License 3 | 4 | import os 5 | import collections 6 | import torch 7 | import numpy as np 8 | import scipy.misc as m 9 | import matplotlib.pyplot as plt 10 | from PIL import Image 11 | from torchvision import transforms 12 | 13 | 14 | from torch.utils import data 15 | 16 | 17 | class lunaLoader(data.Dataset): 18 | def __init__(self,split="train",is_transform=True,img_size=512): 19 | self.split = split 20 | self.path= "/home/rak/data/"+self.split 21 | self.is_transform = is_transform 22 | self.img_size = img_size 23 | self.files = 
os.listdir(self.path+'/images/') # [image1_img.npy, image2_img.npy] 24 | 25 | self.img_tf = transforms.Compose( 26 | [ transforms.Resize(self.img_size), 27 | transforms.ToTensor(), 28 | transforms.Normalize([-460.466],[444.421]) 29 | ]) 30 | 31 | self.label_tf = transforms.Compose( 32 | [ 33 | transforms.Resize(self.img_size,interpolation=0), 34 | transforms.ToTensor(), 35 | ]) 36 | 37 | 38 | 39 | def __len__(self): 40 | return len(self.files) 41 | 42 | def __getitem__(self,index): 43 | fname = self.files[index] # image1_img.npy, image1_label.npy 44 | img = Image.fromarray(np.load(self.path+'/images/'+fname).astype(float)) 45 | im_id = fname.split('_')[1] 46 | label = Image.fromarray(np.load(self.path+'/labels_small/masks_'+im_id)) 47 | 48 | if self.is_transform: 49 | img, label = self.transform(img,label) 50 | 51 | return img, label.squeeze(0) 52 | 53 | def transform(self,img,label): 54 | img = self.img_tf(img) 55 | label = self.label_tf(label) 56 | 57 | return img,label 58 | 59 | -------------------------------------------------------------------------------- /train_codes/lovasz_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch 3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) 4 | """ 5 | 6 | from __future__ import print_function, division 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | try: 13 | from itertools import ifilterfalse 14 | except ImportError: # py3k 15 | from itertools import filterfalse as ifilterfalse 16 | 17 | 18 | def lovasz_grad(gt_sorted): 19 | """ 20 | Computes gradient of the Lovasz extension w.r.t sorted errors 21 | See Alg. 1 in paper 22 | """ 23 | p = len(gt_sorted) 24 | gts = gt_sorted.sum() 25 | intersection = gts - gt_sorted.float().cumsum(0) 26 | union = gts + (1 - gt_sorted).float().cumsum(0) 27 | jaccard = 1. 
- intersection / union 28 | if p > 1: # cover 1-pixel case 29 | jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] 30 | return jaccard 31 | 32 | 33 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): 34 | """ 35 | IoU for foreground class 36 | binary: 1 foreground, 0 background 37 | """ 38 | if not per_image: 39 | preds, labels = (preds,), (labels,) 40 | ious = [] 41 | for pred, label in zip(preds, labels): 42 | intersection = ((label == 1) & (pred == 1)).sum() 43 | union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() 44 | if not union: 45 | iou = EMPTY 46 | else: 47 | iou = float(intersection) / float(union) 48 | ious.append(iou) 49 | iou = mean(ious) # mean accross images if per_image 50 | return 100 * iou 51 | 52 | 53 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): 54 | """ 55 | Array of IoU for each (non ignored) class 56 | """ 57 | if not per_image: 58 | preds, labels = (preds,), (labels,) 59 | ious = [] 60 | for pred, label in zip(preds, labels): 61 | iou = [] 62 | for i in range(C): 63 | if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) 64 | intersection = ((label == i) & (pred == i)).sum() 65 | union = ((label == i) | ((pred == i) & (label != ignore))).sum() 66 | if not union: 67 | iou.append(EMPTY) 68 | else: 69 | iou.append(float(intersection) / float(union)) 70 | ious.append(iou) 71 | ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image 72 | return 100 * np.array(ious) 73 | 74 | 75 | # --------------------------- BINARY LOSSES --------------------------- 76 | 77 | 78 | def lovasz_hinge(logits, labels, per_image=True, ignore=None): 79 | """ 80 | Binary Lovasz hinge loss 81 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 82 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 83 | per_image: compute the loss per image instead of per batch 84 | ignore: void class id 85 | """ 86 | if per_image: 87 | loss = 
mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) 88 | for log, lab in zip(logits, labels)) 89 | else: 90 | loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) 91 | return loss 92 | 93 | 94 | def lovasz_hinge_flat(logits, labels): 95 | """ 96 | Binary Lovasz hinge loss 97 | logits: [P] Variable, logits at each prediction (between -\infty and +\infty) 98 | labels: [P] Tensor, binary ground truth labels (0 or 1) 99 | ignore: label to ignore 100 | """ 101 | if len(labels) == 0: 102 | # only void pixels, the gradients should be 0 103 | return logits.sum() * 0. 104 | signs = 2. * labels.float() - 1. 105 | errors = (1. - logits * Variable(signs)) 106 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 107 | perm = perm.data 108 | gt_sorted = labels[perm] 109 | grad = lovasz_grad(gt_sorted) 110 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 111 | return loss 112 | 113 | 114 | def flatten_binary_scores(scores, labels, ignore=None): 115 | """ 116 | Flattens predictions in the batch (binary case) 117 | Remove labels equal to 'ignore' 118 | """ 119 | scores = scores.view(-1) 120 | labels = labels.view(-1) 121 | if ignore is None: 122 | return scores, labels 123 | valid = (labels != ignore) 124 | vscores = scores[valid] 125 | vlabels = labels[valid] 126 | return vscores, vlabels 127 | 128 | 129 | class StableBCELoss(torch.nn.modules.Module): 130 | def __init__(self): 131 | super(StableBCELoss, self).__init__() 132 | def forward(self, input, target): 133 | neg_abs = - input.abs() 134 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 135 | return loss.mean() 136 | 137 | 138 | def binary_xloss(logits, labels, ignore=None): 139 | """ 140 | Binary Cross entropy loss 141 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 142 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 143 | ignore: void class id 144 | """ 145 | logits, labels 
= flatten_binary_scores(logits, labels, ignore) 146 | loss = StableBCELoss()(logits, Variable(labels.float())) 147 | return loss 148 | 149 | 150 | # --------------------------- MULTICLASS LOSSES --------------------------- 151 | 152 | 153 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 154 | """ 155 | Multi-class Lovasz-Softmax loss 156 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 157 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 158 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 159 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 160 | per_image: compute the loss per image instead of per batch 161 | ignore: void class labels 162 | """ 163 | if per_image: 164 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 165 | for prob, lab in zip(probas, labels)) 166 | else: 167 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 168 | return loss 169 | 170 | 171 | def lovasz_softmax_flat(probas, labels, classes='present'): 172 | """ 173 | Multi-class Lovasz-Softmax loss 174 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 175 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 176 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 177 | """ 178 | if probas.numel() == 0: 179 | # only void pixels, the gradients should be 0 180 | return probas * 0. 
181 | C = probas.size(1) 182 | losses = [] 183 | class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes 184 | for c in class_to_sum: 185 | fg = (labels == c).float() # foreground for class c 186 | if (classes is 'present' and fg.sum() == 0): 187 | continue 188 | if C == 1: 189 | if len(classes) > 1: 190 | raise ValueError('Sigmoid output possible only with 1 class') 191 | class_pred = probas[:, 0] 192 | else: 193 | class_pred = probas[:, c] 194 | errors = (Variable(fg) - class_pred).abs() 195 | errors_sorted, perm = torch.sort(errors, 0, descending=True) 196 | perm = perm.data 197 | fg_sorted = fg[perm] 198 | losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) 199 | return mean(losses) 200 | 201 | 202 | def flatten_probas(probas, labels, ignore=None): 203 | """ 204 | Flattens predictions in the batch 205 | """ 206 | if probas.dim() == 3: 207 | # assumes output of a sigmoid layer 208 | B, H, W = probas.size() 209 | probas = probas.view(B, 1, H, W) 210 | B, C, H, W = probas.size() 211 | probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C 212 | labels = labels.view(-1) 213 | if ignore is None: 214 | return probas, labels 215 | valid = (labels != ignore) 216 | vprobas = probas[valid.nonzero().squeeze()] 217 | vlabels = labels[valid] 218 | return vprobas, vlabels 219 | 220 | def xloss(logits, labels, ignore=None): 221 | """ 222 | Cross entropy loss 223 | """ 224 | return F.cross_entropy(logits, Variable(labels), ignore_index=255) 225 | 226 | 227 | # --------------------------- HELPER FUNCTIONS --------------------------- 228 | def isnan(x): 229 | return x != x 230 | 231 | 232 | def mean(l, ignore_nan=False, empty=0): 233 | """ 234 | nanmean compatible with generators. 
235 | """ 236 | l = iter(l) 237 | if ignore_nan: 238 | l = ifilterfalse(isnan, l) 239 | try: 240 | n = 1 241 | acc = next(l) 242 | except StopIteration: 243 | if empty == 'raise': 244 | raise ValueError('Empty mean') 245 | return empty 246 | for n, v in enumerate(l, 2): 247 | acc += v 248 | if n == 1: 249 | return acc 250 | return acc / n -------------------------------------------------------------------------------- /train_codes/train_sumnet_luna_CE_Lov.py: -------------------------------------------------------------------------------- 1 | #Code written by Rakshith Sathish 2 | #The work is made public with MIT License 3 | 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch import optim 9 | import tqdm 10 | import time 11 | from torch.utils import data 12 | import os 13 | import torch.nn 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | import matplotlib.pyplot as plt 17 | plt.switch_backend('agg') 18 | from sklearn.metrics import confusion_matrix 19 | from SUMNet_bn import SUMNet 20 | from LUNA_loader import lunaLoader 21 | import lovasz_losses as L 22 | 23 | def dice_coefficient(pred, target): 24 | predC = torch.argmax(F.softmax(pred,dim=1),dim=1) 25 | c = confusion_matrix(target.view(-1).cpu().numpy(), predC.view(-1).cpu().numpy(),labels=[0,1]) 26 | TP = np.diag(c) 27 | FP = c.sum(axis=0) - np.diag(c) 28 | FN = c.sum(axis=1) - np.diag(c) 29 | TN = c.sum() - (FP + FN + TP) 30 | return (TP,FP,FN) 31 | 32 | 33 | 34 | savePath = 'Results/SUMNet_new/Adam_1e-4_ep100_CE+Lov/' 35 | if not os.path.isdir(savePath): 36 | os.makedirs(savePath) 37 | 38 | 39 | trainDset = lunaLoader(is_transform=True, split='train',img_size=256) 40 | valDset = lunaLoader(is_transform=True, split='val',img_size=256) 41 | 42 | trainDataLoader = data.DataLoader(trainDset,batch_size=16,shuffle=True,num_workers=4,pin_memory=True) 43 | validDataLoader = 
data.DataLoader(valDset,batch_size=16,shuffle=False,num_workers=4,pin_memory=True) 44 | 45 | n_classes = 2 46 | net = SUMNet(in_ch=1,out_ch=n_classes) 47 | 48 | use_gpu = torch.cuda.is_available() 49 | if use_gpu: 50 | net = net.cuda() 51 | 52 | 53 | optimizerS = optim.Adam(net.parameters(), lr = 1e-4, weight_decay = 1e-5) 54 | criterionS = nn.CrossEntropyLoss() 55 | 56 | 57 | 58 | epochs = 100 59 | trainLoss = [] 60 | validLoss = [] 61 | trainDiceCoeff = [] 62 | validDiceCoeff = [] 63 | start = time.time() 64 | 65 | bestValidDice = 0.0 66 | 67 | for epoch in range(epochs): 68 | epochStart = time.time() 69 | trainRunningLoss = 0 70 | validRunningLoss = 0 71 | trainBatches = 0 72 | validBatches = 0 73 | 74 | train_tp = np.zeros(n_classes) 75 | train_fp = np.zeros(n_classes) 76 | train_fn = np.zeros(n_classes) 77 | 78 | val_tp = np.zeros(n_classes) 79 | val_fp = np.zeros(n_classes) 80 | val_fn = np.zeros(n_classes) 81 | 82 | 83 | net.train(True) 84 | for data1 in tqdm.tqdm(trainDataLoader): 85 | imgs, mask = data1 86 | # print(imgs.shape) 87 | if use_gpu: 88 | inputs = imgs.cuda() 89 | labels = mask.cuda() 90 | 91 | 92 | cpmap = net(Variable(inputs)) 93 | cpmapD = F.softmax(cpmap,dim=1) 94 | 95 | LGce = criterionS(cpmap,labels.long()) 96 | L_lov = L.lovasz_softmax(F.softmax(cpmap,dim=1),labels) 97 | LGseg = LGce+L_lov 98 | 99 | optimizerS.zero_grad() 100 | 101 | LGseg.backward() 102 | 103 | optimizerS.step() 104 | 105 | trainRunningLoss += LGseg.item() 106 | 107 | train_cf = dice_coefficient(cpmapD,labels) 108 | train_tp += train_cf[0] 109 | train_fp += train_cf[1] 110 | train_fn += train_cf[2] 111 | trainBatches += 1 112 | # break 113 | 114 | 115 | train_dice = (2*train_tp)/(2*train_tp + train_fp + train_fn ) 116 | trainLoss.append(trainRunningLoss/trainBatches) 117 | trainDiceCoeff.append(train_dice) 118 | 119 | print("\n{}][{}]| LGseg: {:.4f} | " 120 | .format(epoch,epochs,LGseg.item())) 121 | 122 | with torch.no_grad(): 123 | for data1 in 
tqdm.tqdm(validDataLoader): 124 | imgs, mask = data1 125 | if use_gpu: 126 | inputs = imgs.cuda() 127 | labels = mask.cuda() 128 | 129 | 130 | cpmap = net(Variable(inputs)) 131 | cpmapD = F.softmax(cpmap.data,dim=1) 132 | 133 | val_cf = dice_coefficient(cpmapD,labels) 134 | val_tp += val_cf[0] 135 | val_fp += val_cf[1] 136 | val_fn += val_cf[2] 137 | validRunningLoss += LGseg.item() 138 | validBatches += 1 139 | # break 140 | 141 | 142 | val_dice = (2*val_tp)/(2*val_tp + val_fp + val_fn ) 143 | validLoss.append(validRunningLoss/validBatches) 144 | validDiceCoeff.append(val_dice) 145 | # scheduler.step(validRunningLoss/validBatches) 146 | if (val_dice[1] > bestValidDice): 147 | bestValidDice = val_dice[1] 148 | torch.save(net.state_dict(), savePath+'sumnet_best.pt') 149 | 150 | 151 | plt.figure() 152 | plt.plot(range(len(trainLoss)),trainLoss,'-r',label='Train') 153 | plt.plot(range(len(validLoss)),validLoss,'-g',label='Valid') 154 | if epoch==0: 155 | plt.legend() 156 | plt.savefig(savePath+'LossPlot.png') 157 | plt.close() 158 | epochEnd = time.time()-epochStart 159 | print('Epoch: {:.0f}/{:.0f} | Train Loss: {:.3f} | Valid Loss: {:.3f}'\ 160 | .format(epoch+1, epochs, trainRunningLoss/trainBatches, validRunningLoss/validBatches)) 161 | print('\nDice | Train | BG {:.3f} | Nodule {:.3f} |\n Valid | BG: {:.3f} | Nodule {:.3f} |' 162 | .format(train_dice[0],train_dice[1], val_dice[0], val_dice[1])) 163 | 164 | print('\nTime: {:.0f}m {:.0f}s'.format(epochEnd//60,epochEnd%60)) 165 | trainLoss_np = np.array(trainLoss) 166 | validLoss_np = np.array(validLoss) 167 | trainDiceCoeff_np = np.array(trainDiceCoeff) 168 | validDiceCoeff_np = np.array(validDiceCoeff) 169 | 170 | print('Saving losses') 171 | 172 | torch.save(trainLoss_np, savePath+'trainLoss.pt') 173 | torch.save(validLoss_np, savePath+'validLoss.pt') 174 | torch.save(trainDiceCoeff_np, savePath+'trainDice.pt') 175 | torch.save(validDiceCoeff_np, savePath+'validDice.pt') 176 | # break 177 | 178 | 179 | end = 
time.time()-start 180 | print('Training completed in {:.0f}m {:.0f}s'.format(end//60,end%60)) 181 | plt.figure() 182 | plt.plot(range(len(trainLoss)),trainLoss,'-r') 183 | plt.plot(range(len(validLoss)),validLoss,'-g') 184 | plt.title('Loss plot') 185 | plt.savefig(savePath+'trainLossFinal.png') 186 | plt.close() 187 | 188 | trainDiceCoeff_bg = [x[0] for x in trainDiceCoeff] 189 | trainDiceCoeff_nodule = [x[1] for x in trainDiceCoeff] 190 | plt.figure() 191 | plt.plot(range(len(trainDiceCoeff_bg)),trainDiceCoeff_bg,'-r',label='BG') 192 | plt.plot(range(len(trainDiceCoeff_nodule)),trainDiceCoeff_nodule,'-g',label='Nodule') 193 | plt.legend() 194 | plt.title('Dice coefficient: Train') 195 | plt.savefig(savePath+'trainDice.png') 196 | plt.close() 197 | 198 | validDiceCoeff_bg = [x[0] for x in validDiceCoeff] 199 | validDiceCoeff_nodule = [x[1] for x in validDiceCoeff] 200 | plt.figure() 201 | plt.plot(range(len(validDiceCoeff_bg)),validDiceCoeff_bg,'-r',label='BG') 202 | plt.plot(range(len(validDiceCoeff_nodule)),validDiceCoeff_nodule,'-g',label='Nodule') 203 | plt.legend() 204 | plt.title('Dice coefficient: Valid') 205 | plt.savefig(savePath+'validDice.png') 206 | plt.close() 207 | --------------------------------------------------------------------------------