├── 2d_from_3d.py
├── CODE_OF_CONDUCT.md
├── Data_Loader.py
├── LICENSE
├── Metrics.py
├── Models.py
├── README.md
├── dice.png
├── images
│   ├── att-r2u.png
│   ├── att-unet.png
│   ├── filt1.png
│   ├── in1.png
│   ├── in2.png
│   ├── l2.png
│   ├── nested.jpg
│   ├── r2unet.png
│   ├── tensorb.png
│   └── unet1.png
├── losses.py
├── ploting.py
├── pytorch_run.py
├── pytorch_run_old.py
└── requirements.txt


/2d_from_3d.py:
--------------------------------------------------------------------------------
import glob

import cv2
import numpy as np
import SimpleITK as sitk  # reading MR images

# Sorted so that the i-th image and the i-th label come from the same subject
# (glob returns files in arbitrary order).
readfolderT = sorted(glob.glob('/home/bat161/Desktop/Thesis/EADC_HHP/*_MNI.nii.gz'))
readfolderL = sorted(glob.glob('/home/bat161/Desktop/Thesis/EADC_HHP/*_HHP_EADC.nii.gz'))


def extract_slices(paths):
    """Read each volume, crop it, rotate it by 180 degrees in-plane and split it into 2D slices."""
    slices = []
    for path in paths:
        yimage = sitk.GetArrayFromImage(sitk.ReadImage(path))
        x = yimage[:184, :232, 112:136]
        x = np.rot90(x, k=2)  # same as two successive rot90 calls
        for j in range(x.shape[2]):
            # contiguous float32 copy, a layout and dtype that cv2.resize accepts
            slices.append(np.ascontiguousarray(x[:184, :224, j], dtype=np.float32))
    return slices


def save_png(path, slice_2d):
    """Min-max rescale the slice to 0-255 uint8 and save it as a single-channel PNG."""
    cv2.imwrite(path, cv2.normalize(slice_2d, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U))


TrainingImagesList = extract_slices(readfolderT)
TrainingLabelsList = extract_slices(readfolderL)

for i in range(len(TrainingImagesList)):
    xchangeL = cv2.resize(TrainingImagesList[i], (128, 128))
    save_png('/home/bat161/Desktop/Thesis/Image/png_1C_images/' + str(i) + '.png', xchangeL)

for i in range(len(TrainingLabelsList)):
    # nearest-neighbour interpolation keeps the label values discrete
    xchangeL = cv2.resize(TrainingLabelsList[i], (128, 128), interpolation=cv2.INTER_NEAREST)
    save_png('/home/bat161/Desktop/Thesis/Image/png_1C_labels/' + str(i) + '.png', xchangeL)
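The crop, 180-degree rotation and per-slice split in 2d_from_3d.py are easy to get wrong when the volume orientation or size changes. Below is a minimal NumPy-only sketch of that geometry, plus a min-max conversion to uint8 that approximates how the removed `scipy.misc.imsave` rescaled non-uint8 arrays before writing them. `volume_to_slices` and `rescale_to_uint8` are hypothetical helper names, not part of the repository; the default ranges are copied from the script.

```python
import numpy as np


def volume_to_slices(volume, z_range=(112, 136), pre_crop=(184, 232), crop=(184, 224)):
    """Crop a 3D array, rotate it 180 degrees in the (0, 1) plane and split it along axis 2.

    Hypothetical helper mirroring the indexing used in 2d_from_3d.py.
    """
    z0, z1 = z_range
    sub = volume[:pre_crop[0], :pre_crop[1], z0:z1]
    # k=2 rotates by 180 degrees: out[i, j, z] == sub[H - 1 - i, W - 1 - j, z]
    sub = np.rot90(sub, k=2, axes=(0, 1))
    return [sub[:crop[0], :crop[1], j] for j in range(sub.shape[2])]


def rescale_to_uint8(slice_2d):
    """Min-max rescale a 2D array to 0..255 (round half up); constant arrays map to 0."""
    arr = np.asarray(slice_2d, dtype=np.float64)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros(arr.shape, dtype=np.uint8)
    scaled = (arr - lo) * (255.0 / (hi - lo))
    return np.floor(scaled + 0.5).astype(np.uint8)


if __name__ == "__main__":
    vol = np.random.default_rng(0).normal(size=(200, 240, 150))
    slices = volume_to_slices(vol)
    print(len(slices), slices[0].shape)       # 24 (184, 224)
    print(rescale_to_uint8(slices[0]).dtype)  # uint8
```

The resulting uint8 array can be written with `cv2.imwrite(path, array)`, which stores a 2D uint8 array as a single-channel PNG.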


--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
# Contributor Covenant Code of Conduct

## Our Pledge

In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.

## Our Standards

Examples of behavior that contributes to creating a positive environment
include:

* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

* The use of sexualized language or imagery and unwelcome sexual attention or
  advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
  address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Our Responsibilities

Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.

Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.

## Scope

This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at malav.b93@gmail.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.

Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html

[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq


--------------------------------------------------------------------------------
/Data_Loader.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
import os
from PIL import Image
import torch
import torch.utils.data
import torchvision
from skimage import io
from torch.utils.data import Dataset
import random
import numpy as np


class Images_Dataset(Dataset):
    """Dataset that returns each image/label pair as a dict
    Args:
        images_dir = list of input image file paths
        labels_dir = list of label image file paths (same order as images_dir)
        transformI = Input Images transformation (default: None)
        transformM = Input Labels transformation (default: None)
    Output:
        sample : Dict with keys 'images' and 'labels'"""

    def __init__(self, images_dir, labels_dir, transformI = None, transformM = None):

        self.labels_dir = labels_dir
        self.images_dir = images_dir
        self.transformI = transformI
        self.transformM = transformM

    def __len__(self):
        return len(self.images_dir)

    def __getitem__(self, idx):
        # load only the pair at position idx
        image = io.imread(self.images_dir[idx])
        label = io.imread(self.labels_dir[idx])
        if self.transformI:
            image = self.transformI(image)
        if self.transformM:
            label = self.transformM(label)
        sample = {'images': image, 'labels': label}

        return sample


class Images_Dataset_folder(torch.utils.data.Dataset):
    """Dataset that reads image/label pairs from two folders and applies the same random geometric transforms to both
    Args:
        images_dir = folder containing the input images
        labels_dir = folder containing the label images (matched to images by sorted file name)
        transformI = Input Images transformation (default: None)
        transformM = Input Labels transformation (default: None)
    Output (of __getitem__):
        img, label = the transformed image and label tensors"""

    def __init__(self, images_dir, labels_dir, transformI = None, transformM = None):
        self.images = sorted(os.listdir(images_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transformI = transformI
        self.transformM = transformM

        if self.transformI:
            self.tx = self.transformI
        else:
            self.tx = torchvision.transforms.Compose([
              #  torchvision.transforms.Resize((128,128)),
                torchvision.transforms.CenterCrop(96),
                torchvision.transforms.RandomRotation((-10,10)),
               # torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
            ])

        if self.transformM:
            self.lx = self.transformM
        else:
            self.lx = torchvision.transforms.Compose([
              #  torchvision.transforms.Resize((128,128)),
                torchvision.transforms.CenterCrop(96),
                torchvision.transforms.RandomRotation((-10,10)),
                torchvision.transforms.Grayscale(),
                torchvision.transforms.ToTensor(),
                #torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
            ])

    def __len__(self):

        return len(self.images)

    def __getitem__(self, i):
        # os.path.join works whether or not the folder path ends with a separator
        i1 = Image.open(os.path.join(self.images_dir, self.images[i]))
        l1 = Image.open(os.path.join(self.labels_dir, self.labels[i]))

        # make a seed with the numpy generator; int64 keeps the 2**32 upper bound
        # valid on platforms whose default integer type is 32-bit
        seed = np.random.randint(0, 2**32, dtype=np.int64)

        # apply this seed to the image transforms
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.tx(i1)

        # apply the same seed to the target/label transforms, so that
        # RandomRotation draws the same angle as for the image
        random.seed(seed)
        torch.manual_seed(seed)
        label = self.lx(l1)

        return img, label


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Malav Bateriwala

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
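Images_Dataset_folder in Data_Loader.py keeps each image and mask aligned by reseeding `random` and `torch` before each of its two transform pipelines, so that `RandomRotation` draws the same angle twice. An alternative that leaves global RNG state alone is to draw the random parameters once and apply them to both arrays. The sketch below illustrates that idea with NumPy only; it substitutes 90-degree rotations and a horizontal flip for the ±10° `RandomRotation` used in the repository, and `paired_augment` is a hypothetical name.

```python
import numpy as np


def paired_augment(image, mask, rng):
    """Apply one randomly drawn rotation/flip to an image and its mask.

    Hypothetical helper: the random parameters are drawn once from `rng` and
    reused for both arrays, so the pair stays pixel-aligned without reseeding
    any global random number generator.
    """
    k = int(rng.integers(0, 4))      # number of 90-degree rotations
    flip = bool(rng.integers(0, 2))  # horizontal flip or not

    def apply(arr):
        out = np.rot90(arr, k=k, axes=(0, 1))
        if flip:
            out = out[:, ::-1]
        return np.ascontiguousarray(out)

    return apply(image), apply(mask)


if __name__ == "__main__":
    rng = np.random.default_rng(42)
    img = np.arange(12, dtype=float).reshape(3, 4)
    aug_img, aug_mask = paired_augment(img, img > 5, rng)
    # the mask is still exactly the thresholded image
    print(np.array_equal(aug_mask, aug_img > 5))  # True
```

Intensity-only augmentations, such as the `ColorJitter` in the repository's default image pipeline, would be applied to the image alone after the shared geometric step.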


--------------------------------------------------------------------------------
/Metrics.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy import spatial


def dice_coeff(im1, im2, empty_score=1.0):
    """Calculates the Dice coefficient 2 * |A and B| / (|A| + |B|) of two masks.

    Both inputs are binarised at 0.5; if both masks are empty, empty_score is returned."""

    # Binarise at 0.5 first, so values below the threshold count as background
    im1 = np.asarray(im1) > 0.5
    im2 = np.asarray(im2) > 0.5

    if im1.shape != im2.shape:
        raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")

    im_sum = im1.sum() + im2.sum()
    if im_sum == 0:
        return empty_score

    # Compute Dice coefficient
    intersection = np.logical_and(im1, im2)

    return 2. * intersection.sum() / im_sum


def numeric_score(prediction, groundtruth):
    """Computes scores:
    FP = False Positives
    FN = False Negatives
    TP = True Positives
    TN = True Negatives
    return: FP, FN, TP, TN"""

    # built-in float (the np.float alias no longer exists in recent NumPy)
    FP = float(np.sum((prediction == 1) & (groundtruth == 0)))
    FN = float(np.sum((prediction == 0) & (groundtruth == 1)))
    TP = float(np.sum((prediction == 1) & (groundtruth == 1)))
    TN = float(np.sum((prediction == 0) & (groundtruth == 0)))

    return FP, FN, TP, TN


def accuracy_score(prediction, groundtruth):
    """Returns the pixel accuracy (TP + TN) / N of the prediction, in percent"""

    FP, FN, TP, TN = numeric_score(prediction, groundtruth)
    N = FP + FN + TP + TN
    accuracy = np.divide(TP + TN, N)
    return accuracy * 100.0


--------------------------------------------------------------------------------
/Models.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torch


class conv_block(nn.Module):
""" 10 | Convolution Block 11 | """ 12 | def __init__(self, in_ch, out_ch): 13 | super(conv_block, self).__init__() 14 | 15 | self.conv = nn.Sequential( 16 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 17 | nn.BatchNorm2d(out_ch), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 20 | nn.BatchNorm2d(out_ch), 21 | nn.ReLU(inplace=True)) 22 | 23 | def forward(self, x): 24 | 25 | x = self.conv(x) 26 | return x 27 | 28 | 29 | class up_conv(nn.Module): 30 | """ 31 | Up Convolution Block 32 | """ 33 | def __init__(self, in_ch, out_ch): 34 | super(up_conv, self).__init__() 35 | self.up = nn.Sequential( 36 | nn.Upsample(scale_factor=2), 37 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 38 | nn.BatchNorm2d(out_ch), 39 | nn.ReLU(inplace=True) 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.up(x) 44 | return x 45 | 46 | 47 | class U_Net(nn.Module): 48 | """ 49 | UNet - Basic Implementation 50 | Paper : https://arxiv.org/abs/1505.04597 51 | """ 52 | def __init__(self, in_ch=3, out_ch=1): 53 | super(U_Net, self).__init__() 54 | 55 | n1 = 64 56 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] 57 | 58 | self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 59 | self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 60 | self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2) 61 | self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2) 62 | 63 | self.Conv1 = conv_block(in_ch, filters[0]) 64 | self.Conv2 = conv_block(filters[0], filters[1]) 65 | self.Conv3 = conv_block(filters[1], filters[2]) 66 | self.Conv4 = conv_block(filters[2], filters[3]) 67 | self.Conv5 = conv_block(filters[3], filters[4]) 68 | 69 | self.Up5 = up_conv(filters[4], filters[3]) 70 | self.Up_conv5 = conv_block(filters[4], filters[3]) 71 | 72 | self.Up4 = up_conv(filters[3], filters[2]) 73 | self.Up_conv4 = conv_block(filters[3], filters[2]) 74 | 75 | self.Up3 = up_conv(filters[2], filters[1]) 76 | 
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

        # self.active = torch.nn.Sigmoid()

    def forward(self, x):

        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)

        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        #d1 = self.active(out)

        return out


class Recurrent_block(nn.Module):
    """
    Recurrent Block for R2Unet_CNN
    """
    def __init__(self, out_ch, t=2):
        super(Recurrent_block, self).__init__()

        self.t = t
        self.out_ch = out_ch
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # Recurrent convolution with shared weights: x1 = conv(x), then the input x
        # is fed back in t times as x1 = conv(x + x1), following the
        # LeeJunHyun/Image_Segmentation implementation credited below.
        for i in range(self.t):
            if i == 0:
                x1 = self.conv(x)
            x1 = self.conv(x + x1)
        return x1


class RRCNN_block(nn.Module):
    """
    Recurrent Residual Convolutional Neural Network Block
    """
    def __init__(self, in_ch, out_ch, t=2):
        super(RRCNN_block, self).__init__()

        self.RCNN = nn.Sequential(
            Recurrent_block(out_ch, t=t),
            Recurrent_block(out_ch, t=t)
        )
        self.Conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x1 = self.Conv(x)
        x2 = self.RCNN(x1)
        out = x1 + x2
        return out


class R2U_Net(nn.Module):
    """
    R2U-Net (recurrent residual U-Net) implementation
    Paper: https://arxiv.org/abs/1802.06955
    """
    def __init__(self, img_ch=3, output_ch=1, t=2):
        super(R2U_Net, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Upsample = nn.Upsample(scale_factor=2)

        self.RRCNN1 = RRCNN_block(img_ch, filters[0], t=t)

        self.RRCNN2 = RRCNN_block(filters[0], filters[1], t=t)

        self.RRCNN3 = RRCNN_block(filters[1], filters[2], t=t)

        self.RRCNN4 = RRCNN_block(filters[2], filters[3], t=t)

        self.RRCNN5 = RRCNN_block(filters[3], filters[4], t=t)

        self.Up5 = up_conv(filters[4], filters[3])
        self.Up_RRCNN5 = RRCNN_block(filters[4], filters[3], t=t)

        self.Up4 = up_conv(filters[3], filters[2])
        self.Up_RRCNN4 = RRCNN_block(filters[3], filters[2], t=t)

        self.Up3 = up_conv(filters[2], filters[1])
        self.Up_RRCNN3 = RRCNN_block(filters[2], filters[1], t=t)

        self.Up2 = up_conv(filters[1], filters[0])
        self.Up_RRCNN2 = RRCNN_block(filters[1], filters[0], t=t)

        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)

        # self.active = torch.nn.Sigmoid()


    def forward(self, x):

        e1 = self.RRCNN1(x)

        e2 = self.Maxpool(e1)
        e2 = self.RRCNN2(e2)

        e3 = self.Maxpool1(e2)
        e3 = self.RRCNN3(e3)

        e4 = self.Maxpool2(e3)
        e4 = self.RRCNN4(e4)

        e5 = self.Maxpool3(e4)
        e5 = self.RRCNN5(e5)

        d5 = self.Up5(e5)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        out = self.Conv(d2)

        # out = self.active(out)

        return out


class Attention_block(nn.Module):
    """
    Additive attention gate: the gating signal g (F_g channels) and the skip
    features x (F_l channels) are projected to F_int channels, summed, passed
    through ReLU, then a 1x1 convolution, BatchNorm and a sigmoid; the resulting
    map in [0, 1] rescales x.
    """

    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()

        # W_g projects the gating signal g, which has F_g channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        # W_x projects the skip features x, which have F_l channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        out = x * psi
        return out


class AttU_Net(nn.Module):
    """
    Attention Unet implementation
    Paper: https://arxiv.org/abs/1804.03999
    """
    def __init__(self, img_ch=3, output_ch=1):
        super(AttU_Net, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(img_ch, filters[0])
        self.Conv2 = conv_block(filters[0], filters[1])
        self.Conv3 = conv_block(filters[1], filters[2])
        self.Conv4 = conv_block(filters[2], filters[3])
        self.Conv5 = conv_block(filters[3], filters[4])

        self.Up5 = up_conv(filters[4], filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_conv5 = conv_block(filters[4], filters[3])

        self.Up4 = up_conv(filters[3], filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_conv4 = conv_block(filters[3], filters[2])

        self.Up3 = up_conv(filters[2], filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_conv3 = conv_block(filters[2], filters[1])

        self.Up2 = up_conv(filters[1], filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
        self.Up_conv2 = conv_block(filters[1], filters[0])

        self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)

        #self.active = torch.nn.Sigmoid()


    def forward(self, x):

        e1 = self.Conv1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.Conv2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.Conv3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.Conv4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.Conv5(e5)

        d5 = self.Up5(e5)
        x4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        x3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        x2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        x1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        out = self.Conv(d2)

        # out = self.active(out)

        return out


class R2AttU_Net(nn.Module):
    """
    Recurrent residual U-Net with attention gates
    Implementation : https://github.com/LeeJunHyun/Image_Segmentation
    """
    def __init__(self, in_ch=3, out_ch=1, t=2):
        super(R2AttU_Net, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.RRCNN1 = RRCNN_block(in_ch, filters[0], t=t)
        self.RRCNN2 = RRCNN_block(filters[0], filters[1], t=t)
        self.RRCNN3 = RRCNN_block(filters[1], filters[2], t=t)
        self.RRCNN4 = RRCNN_block(filters[2], filters[3], t=t)
        self.RRCNN5 = RRCNN_block(filters[3], filters[4], t=t)

        self.Up5 = up_conv(filters[4], filters[3])
        self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
        self.Up_RRCNN5 = RRCNN_block(filters[4], filters[3], t=t)

        self.Up4 = up_conv(filters[3], filters[2])
        self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
        self.Up_RRCNN4 = RRCNN_block(filters[3], filters[2], t=t)

        self.Up3 = up_conv(filters[2], filters[1])
        self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
        self.Up_RRCNN3 = RRCNN_block(filters[2], filters[1], t=t)

        self.Up2 = up_conv(filters[1], filters[0])
        self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
        self.Up_RRCNN2 = RRCNN_block(filters[1], filters[0], t=t)

        self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)

        # self.active = torch.nn.Sigmoid()


    def forward(self, x):

        e1 = self.RRCNN1(x)

        e2 = self.Maxpool1(e1)
        e2 = self.RRCNN2(e2)

        e3 = self.Maxpool2(e2)
        e3 = self.RRCNN3(e3)

        e4 = self.Maxpool3(e3)
        e4 = self.RRCNN4(e4)

        e5 = self.Maxpool4(e4)
        e5 = self.RRCNN5(e5)

        d5 = self.Up5(e5)
        e4 = self.Att5(g=d5, x=e4)
        d5 = torch.cat((e4, d5), dim=1)
        d5 = self.Up_RRCNN5(d5)

        d4 = self.Up4(d5)
        e3 = self.Att4(g=d4, x=e3)
        d4 = torch.cat((e3, d4), dim=1)
        d4 = self.Up_RRCNN4(d4)

        d3 = self.Up3(d4)
        e2 = self.Att3(g=d3, x=e2)
        d3 = torch.cat((e2, d3), dim=1)
        d3 = self.Up_RRCNN3(d3)

        d2 = self.Up2(d3)
        e1 = self.Att2(g=d2, x=e1)
        d2 = torch.cat((e1, d2), dim=1)
        d2 = self.Up_RRCNN2(d2)

        out = self.Conv(d2)

        # out = self.active(out)

        return out

# The nested U-Net needs a block that takes three channel counts (in, mid, out)

class conv_block_nested(nn.Module):

    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)

        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x)

        return output

# Nested U-Net (UNet++)

class NestedUNet(nn.Module):
    """
    Implementation of this paper:
    https://arxiv.org/pdf/1807.10165.pdf
    """
    def __init__(self, in_ch=3, out_ch=1):
        super(NestedUNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])

        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])

        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])

        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)


    def forward(self, x):

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        output = self.final(x0_4)
        return output

# Dictionary U-Net
# Useful when the filter counts and parameters of each stage are needed (see filters_dict)

class ConvolutionBlock(nn.Module):
    """Convolution block: two convolutions, each optionally followed by BatchNorm"""

    def __init__(self, in_filters, out_filters, kernel_size=3, batchnorm=True, last_active=F.relu):
        super(ConvolutionBlock, self).__init__()

        self.bn = batchnorm
        self.last_active = last_active
        # padding = kernel_size // 2 keeps the spatial size for odd kernel sizes
        self.c1 = nn.Conv2d(in_filters, out_filters, kernel_size, padding=kernel_size // 2)
        self.b1 = nn.BatchNorm2d(out_filters)
        self.c2 = nn.Conv2d(out_filters, out_filters, kernel_size, padding=kernel_size // 2)
        self.b2 = nn.BatchNorm2d(out_filters)

    def forward(self, x):
        x = self.c1(x)
        if self.bn:
            x = self.b1(x)
        x = F.relu(x)
        x = self.c2(x)
        if self.bn:
            x = self.b2(x)
        x = self.last_active(x)
        return x


class ContractiveBlock(nn.Module):
    """Contracting (downsampling) block: convolution block, max-pooling and dropout"""

    def __init__(self, in_filters, out_filters, conv_kern=3, pool_kern=2, dropout=0.5, batchnorm=True):
        super(ContractiveBlock, self).__init__()
        self.c1 = ConvolutionBlock(in_filters=in_filters, out_filters=out_filters, kernel_size=conv_kern,
                                   batchnorm=batchnorm)
        self.p1 = nn.MaxPool2d(kernel_size=pool_kern, ceil_mode=True)
        self.d1 = nn.Dropout2d(dropout)

    def forward(self, x):
        # returns the pre-pooling features (for the skip connection) and the pooled output
        c = self.c1(x)
        return c, self.d1(self.p1(c))


class ExpansiveBlock(nn.Module):
"""Upconvole Block""" 591 | 592 | def __init__(self, in_filters1, in_filters2, out_filters, tr_kern=3, conv_kern=3, stride=2, dropout=0.5): 593 | super(ExpansiveBlock, self).__init__() 594 | self.t1 = nn.ConvTranspose2d(in_filters1, out_filters, tr_kern, stride=2, padding=1, output_padding=1) 595 | self.d1 = nn.Dropout(dropout) 596 | self.c1 = ConvolutionBlock(out_filters + in_filters2, out_filters, conv_kern) 597 | 598 | def forward(self, x, contractive_x): 599 | x_ups = self.t1(x) 600 | x_concat = torch.cat([x_ups, contractive_x], 1) 601 | x_fin = self.c1(self.d1(x_concat)) 602 | return x_fin 603 | 604 | 605 | class Unet_dict(nn.Module): 606 | """Unet which operates with filters dictionary values""" 607 | 608 | def __init__(self, n_labels, n_filters=32, p_dropout=0.5, batchnorm=True): 609 | super(Unet_dict, self).__init__() 610 | filters_dict = {} 611 | filt_pair = [3, n_filters] 612 | 613 | for i in range(4): 614 | self.add_module('contractive_' + str(i), ContractiveBlock(filt_pair[0], filt_pair[1], batchnorm=batchnorm)) 615 | filters_dict['contractive_' + str(i)] = (filt_pair[0], filt_pair[1]) 616 | filt_pair[0] = filt_pair[1] 617 | filt_pair[1] = filt_pair[1] * 2 618 | 619 | self.bottleneck = ConvolutionBlock(filt_pair[0], filt_pair[1], batchnorm=batchnorm) 620 | filters_dict['bottleneck'] = (filt_pair[0], filt_pair[1]) 621 | 622 | for i in reversed(range(4)): 623 | self.add_module('expansive_' + str(i), 624 | ExpansiveBlock(filt_pair[1], filters_dict['contractive_' + str(i)][1], filt_pair[0])) 625 | filters_dict['expansive_' + str(i)] = (filt_pair[1], filt_pair[0]) 626 | filt_pair[1] = filt_pair[0] 627 | filt_pair[0] = filt_pair[0] // 2 628 | 629 | self.output = nn.Conv2d(filt_pair[1], n_labels, kernel_size=1) 630 | filters_dict['output'] = (filt_pair[1], n_labels) 631 | self.filters_dict = filters_dict 632 | 633 | # final_forward 634 | def forward(self, x): 635 | c00, c0 = self.contractive_0(x) 636 | c11, c1 = self.contractive_1(c0) 637 | c22, c2 = 
self.contractive_2(c1) 638 | c33, c3 = self.contractive_3(c2) 639 | bottle = self.bottleneck(c3) 640 | u3 = F.relu(self.expansive_3(bottle, c33)) 641 | u2 = F.relu(self.expansive_2(u3, c22)) 642 | u1 = F.relu(self.expansive_1(u2, c11)) 643 | u0 = F.relu(self.expansive_0(u1, c00)) 644 | return F.softmax(self.output(u0), dim=1) 645 | 646 | # TODO: check why this U-Net is not working properly 647 | # 648 | # class Convolution2(nn.Module): 649 | # """Convolution Block using 2 Conv2D 650 | # Args: 651 | # in_channels = Input Channels 652 | # out_channels = Output Channels 653 | # kernal_size = 3 654 | # activation = Relu 655 | # batchnorm = True 656 | # 657 | # Output: 658 | # Sequential Relu output """ 659 | # 660 | # def __init__(self, in_channels, out_channels, kernal_size=3, activation='Relu', batchnorm=True): 661 | # super(Convolution2, self).__init__() 662 | # 663 | # self.in_channels = in_channels 664 | # self.out_channels = out_channels 665 | # self.kernal_size = kernal_size 666 | # self.batchnorm1 = batchnorm 667 | # 668 | # self.batchnorm2 = batchnorm 669 | # self.activation = activation 670 | # 671 | # self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, self.kernal_size, padding=1, bias=True) 672 | # self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, self.kernal_size, padding=1, bias=True) 673 | # 674 | # self.b1 = nn.BatchNorm2d(out_channels) 675 | # self.b2 = nn.BatchNorm2d(out_channels) 676 | # 677 | # if self.activation == 'LRelu': 678 | # self.a1 = nn.LeakyReLU(inplace=True) 679 | # if self.activation == 'Relu': 680 | # self.a1 = nn.ReLU(inplace=True) 681 | # 682 | # if self.activation == 'LRelu': 683 | # self.a2 = nn.LeakyReLU(inplace=True) 684 | # if self.activation == 'Relu': 685 | # self.a2 = nn.ReLU(inplace=True) 686 | # 687 | # def forward(self, x): 688 | # x1 = self.conv1(x) 689 | # 690 | # if self.batchnorm1: 691 | # x1 = self.b1(x1) 692 | # 693 | # x1 = self.a1(x1) 694 | # 695 | # x1 = self.conv2(x1) 696 | # 697 | # if 
self.batchnorm2: 698 | # x1 = self.b2(x1) 699 | # 700 | # x = self.a2(x1) 701 | # 702 | # return x 703 | # 704 | # 705 | # class UNet(nn.Module): 706 | # """Implementation of U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015) 707 | # https://arxiv.org/abs/1505.04597 708 | # Args: 709 | # n_class = no. of classes""" 710 | # 711 | # def __init__(self, n_class, dropout=0.4): 712 | # super(UNet, self).__init__() 713 | # 714 | # in_ch = 3 715 | # n1 = 64 716 | # n2 = n1*2 717 | # n3 = n2*2 718 | # n4 = n3*2 719 | # n5 = n4*2 720 | # 721 | # self.dconv_down1 = Convolution2(in_ch, n1) 722 | # self.dconv_down2 = Convolution2(n1, n2) 723 | # self.dconv_down3 = Convolution2(n2, n3) 724 | # self.dconv_down4 = Convolution2(n3, n4) 725 | # self.dconv_down5 = Convolution2(n4, n5) 726 | # 727 | # self.maxpool1 = nn.MaxPool2d(2) 728 | # self.maxpool2 = nn.MaxPool2d(2) 729 | # self.maxpool3 = nn.MaxPool2d(2) 730 | # self.maxpool4 = nn.MaxPool2d(2) 731 | # 732 | # self.upsample1 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True) 733 | # self.upsample2 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True) 734 | # self.upsample3 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True) 735 | # self.upsample4 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True) 736 | # 737 | # self.dropout1 = nn.Dropout(dropout) 738 | # self.dropout2 = nn.Dropout(dropout) 739 | # self.dropout3 = nn.Dropout(dropout) 740 | # self.dropout4 = nn.Dropout(dropout) 741 | # self.dropout5 = nn.Dropout(dropout) 742 | # self.dropout6 = nn.Dropout(dropout) 743 | # self.dropout7 = nn.Dropout(dropout) 744 | # self.dropout8 = nn.Dropout(dropout) 745 | # 746 | # self.dconv_up4 = Convolution2(n4 + n5, n4) 747 | # self.dconv_up3 = Convolution2(n3 + n4, n3) 748 | # self.dconv_up2 = Convolution2(n2 + n3, n2) 749 | # self.dconv_up1 = Convolution2(n1 + n2, n1) 750 | # 751 | # self.conv_last = nn.Conv2d(n1, n_class, 
kernel_size=1, stride=1, padding=0) 752 | # # self.active = torch.nn.Sigmoid() 753 | # 754 | # 755 | # 756 | # def forward(self, x): 757 | # conv1 = self.dconv_down1(x) 758 | # x = self.maxpool1(conv1) 759 | # # x = self.dropout1(x) 760 | # 761 | # conv2 = self.dconv_down2(x) 762 | # x = self.maxpool2(conv2) 763 | # # x = self.dropout2(x) 764 | # 765 | # conv3 = self.dconv_down3(x) 766 | # x = self.maxpool3(conv3) 767 | # # x = self.dropout3(x) 768 | # 769 | # conv4 = self.dconv_down4(x) 770 | # x = self.maxpool4(conv4) 771 | # #x = self.dropout4(x) 772 | # 773 | # x = self.dconv_down5(x) 774 | # 775 | # x = self.upsample4(x) 776 | # x = torch.cat((x, conv4), dim=1) 777 | # #x = self.dropout5(x) 778 | # 779 | # x = self.dconv_up4(x) 780 | # x = self.upsample3(x) 781 | # x = torch.cat((x, conv3), dim=1) 782 | # # x = self.dropout6(x) 783 | # 784 | # x = self.dconv_up3(x) 785 | # x = self.upsample2(x) 786 | # x = torch.cat((x, conv2), dim=1) 787 | # #x = self.dropout7(x) 788 | # 789 | # x = self.dconv_up2(x) 790 | # x = self.upsample1(x) 791 | # x = torch.cat((x, conv1), dim=1) 792 | # #x = self.dropout8(x) 793 | # 794 | # x = self.dconv_up1(x) 795 | # 796 | # x = self.conv_last(x) 797 | # # out = self.active(x) 798 | # 799 | # return x 800 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unet-Segmentation-Pytorch-Nest-of-Unets 2 | 3 | [![forthebadge](https://forthebadge.com/images/badges/made-with-python.svg)](https://www.python.org/) 4 | 5 | [![HitCount](http://hits.dwyl.io/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets.svg)](http://hits.dwyl.io/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets) 6 | [![License: MIT](https://img.shields.io/badge/License-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) 7 | 
[![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/graphs/commit-activity) 8 | [![GitHub issues](https://img.shields.io/github/issues/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets.svg)](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/issues) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/unet-a-nested-u-net-architecture-for-medical/semantic-segmentation-on-cityscapes-val)](https://paperswithcode.com/sota/semantic-segmentation-on-cityscapes-val?p=unet-a-nested-u-net-architecture-for-medical) 10 | 11 | Implementation of different kinds of Unet Models for Image Segmentation 12 | 13 | 1) **UNet** - U-Net: Convolutional Networks for Biomedical Image Segmentation 14 | https://arxiv.org/abs/1505.04597 15 | 16 | 2) **RCNN-UNet** - Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation 17 | https://arxiv.org/abs/1802.06955 18 | 19 | 3) **Attention Unet** - Attention U-Net: Learning Where to Look for the Pancreas 20 | https://arxiv.org/abs/1804.03999 21 | 22 | 4) **RCNN-Attention Unet** - Attention R2U-Net: a combination of two recent works (R2U-Net + Attention U-Net) 23 | 24 | 25 | 5) **Nested UNet** - UNet++: A Nested U-Net Architecture for Medical Image Segmentation 26 | https://arxiv.org/abs/1807.10165 27 | 28 | With Layer Visualization 29 | 30 | ## 1. Getting Started 31 | 32 | Clone the repo: 33 | 34 | ```bash 35 | git clone https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets.git 36 | ``` 37 | 38 | ## 2. Requirements 39 | 40 | ``` 41 | python>=3.6 42 | torch>=0.4.0 43 | torchvision 44 | torchsummary 45 | tensorboardx 46 | natsort 47 | numpy 48 | pillow 49 | scipy 50 | scikit-image 51 | sklearn 52 | ``` 53 | Install all dependent libraries: 54 | ```bash 55 | pip install -r requirements.txt 56 | ``` 57 | ## 3.
Run the file 58 | 59 | Set your data paths in lines 106-111 of `pytorch_run.py`: 60 | ``` 61 | t_data = '' # Input data 62 | l_data = '' #Input Label 63 | test_image = '' #Image to be predicted while training 64 | test_label = '' #Label of the prediction Image 65 | test_folderP = '' #Test folder Image 66 | test_folderL = '' #Test folder Label for calculating the Dice score 67 | ``` 68 | 69 | ## 4. Types of Unet 70 | 71 | **Unet** 72 | ![unet1](/images/unet1.png) 73 | 74 | **RCNN Unet** 75 | ![r2unet](/images/r2unet.png) 76 | 77 | 78 | **Attention Unet** 79 | ![att-unet](/images/att-unet.png) 80 | 81 | 82 | **Attention-RCNN Unet** 83 | ![att-r2u](/images/att-r2u.png) 84 | 85 | 86 | **Nested Unet** 87 | 88 | ![nested](/images/nested.jpg) 89 | 90 | ## 5. Visualization 91 | 92 | To plot the loss, Visdom is required. The code is already written; just uncomment the relevant parts. 93 | Gradient flow can be plotted too, using code taken from (https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10) 94 | 95 | A `model` folder is created and all outputs are stored inside it. 96 | The activations of the last layer are saved in the model folder. To visualize a different layer, change line 361 of `pytorch_run.py`. 97 | 98 | **Layer Visualization** 99 | 100 | ![l2](/images/l2.png) 101 | 102 | **Filter Visualization** 103 | 104 | ![filt1](/images/filt1.png) 105 | 106 | **TensorboardX** 107 | Some parameters still need tweaking to get the visualization working. This part broke while trying to make PyTorch 1.1.0 work with TensorBoard directly, which at the time turned out to support nothing apart from line graphs. 108 | 109 | 110 | **Input Image Visualization for checking** 111 | 112 | **a) Original Image** 113 | 114 | 115 | 116 | **b) CenterCrop Image** 117 | 118 | 119 | 120 | ## 6. Results 121 | 122 | **Dice Score for hippocampus segmentation** 123 | ADNI-LONI Dataset 124 | 125 | 126 | 127 | ## 7. Citation 128 | 129 | If you find it useful for your work, please cite:
130 | ``` 131 | @article{DBLP:journals/corr/abs-1906-07160, 132 | author = {Malav Bateriwala and 133 | Pierrick Bourgeat}, 134 | title = {Enforcing temporal consistency in Deep Learning segmentation of brain 135 | {MR} images}, 136 | journal = {CoRR}, 137 | volume = {abs/1906.07160}, 138 | year = {2019}, 139 | url = {http://arxiv.org/abs/1906.07160}, 140 | archivePrefix = {arXiv}, 141 | eprint = {1906.07160}, 142 | timestamp = {Mon, 24 Jun 2019 17:28:45 +0200}, 143 | biburl = {https://dblp.org/rec/bib/journals/corr/abs-1906-07160}, 144 | bibsource = {dblp computer science bibliography, https://dblp.org} 145 | } 146 | ``` 147 | 148 | ## 8. Blog about different Unets 149 | ``` 150 | In progress 151 | ``` 152 | 153 | 154 | -------------------------------------------------------------------------------- /dice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/dice.png -------------------------------------------------------------------------------- /images/att-r2u.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/att-r2u.png -------------------------------------------------------------------------------- /images/att-unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/att-unet.png -------------------------------------------------------------------------------- /images/filt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/filt1.png 
-------------------------------------------------------------------------------- /images/in1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/in1.png -------------------------------------------------------------------------------- /images/in2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/in2.png -------------------------------------------------------------------------------- /images/l2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/l2.png -------------------------------------------------------------------------------- /images/nested.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/nested.jpg -------------------------------------------------------------------------------- /images/r2unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/r2unet.png -------------------------------------------------------------------------------- /images/tensorb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/tensorb.png -------------------------------------------------------------------------------- 
/images/unet1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/unet1.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | def dice_loss(prediction, target): 6 | """Calculating the dice loss 7 | Args: 8 | prediction = predicted probabilities (after sigmoid) 9 | target = Targeted image (binary mask) 10 | Output: 11 | dice_loss = 1 - smoothed Dice coefficient""" 12 | 13 | smooth = 1.0 14 | 15 | i_flat = prediction.view(-1) 16 | t_flat = target.view(-1) 17 | 18 | intersection = (i_flat * t_flat).sum() 19 | 20 | return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth)) 21 | 22 | 23 | def calc_loss(prediction, target, bce_weight=0.5): 24 | """Calculating the combined BCE and Dice loss 25 | Args: 26 | prediction = predicted logits (before sigmoid) 27 | target = Targeted image (binary mask) 28 | bce_weight = weight of the BCE term, 0.5 (default) 29 | the Dice term is weighted by (1 - bce_weight) 30 | Output: 31 | loss : weighted sum of the BCE and Dice losses """ 32 | bce = F.binary_cross_entropy_with_logits(prediction, target) 33 | prediction = torch.sigmoid(prediction) 34 | dice = dice_loss(prediction, target) 35 | 36 | loss = bce * bce_weight + dice * (1 - bce_weight) 37 | 38 | return loss 39 | 40 | 41 | def threshold_predictions_v(predictions, thr=150): 42 | thresholded_preds = predictions[:] # for a NumPy array this is a view, so the input is thresholded in place 43 | # hist = cv2.calcHist([predictions], [0], None, [2], [0, 2]) 44 | # plt.plot(hist) 45 | # plt.xlim([0, 2]) 46 | # plt.show() 47 | low_values_indices = thresholded_preds < thr 48 | thresholded_preds[low_values_indices] = 0 49 | low_values_indices = thresholded_preds >= thr 50 | thresholded_preds[low_values_indices] = 255 51 | return thresholded_preds 52 | 53 | 54 | def threshold_predictions_p(predictions, thr=0.01): 55 | thresholded_preds = 
predictions[:] # for a NumPy array this is a view, so the input is thresholded in place 56 | #hist = cv2.calcHist([predictions], [0], None, [256], [0, 256]) 57 | low_values_indices = thresholded_preds < thr 58 | thresholded_preds[low_values_indices] = 0 59 | low_values_indices = thresholded_preds >= thr 60 | thresholded_preds[low_values_indices] = 1 61 | return thresholded_preds -------------------------------------------------------------------------------- /ploting.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.lines import Line2D 3 | import numpy as np 4 | from visdom import Visdom 5 | 6 | 7 | def show_images(images, labels): 8 | """Show image with label 9 | Args: 10 | images = input images 11 | labels = input labels 12 | Output: 13 | plt = image with the label overlaid on it """ 14 | 15 | plt.imshow(images.permute(1, 2, 0)) 16 | plt.imshow(labels, alpha=0.7, cmap='gray') 17 | plt.figure() 18 | 19 | 20 | def show_training_dataset(training_dataset): 21 | """Showing the first 4 samples of a training set whose items are dicts of images and labels 22 | Args: 23 | training_dataset = dataset of dicts with 'images' and 'labels' keys 24 | Output: 25 | figure = 4 images shown""" 26 | 27 | if training_dataset: 28 | print(len(training_dataset)) 29 | 30 | for i in range(len(training_dataset)): 31 | sample = training_dataset[i] 32 | 33 | print(i, sample['images'].shape, sample['labels'].shape) 34 | 35 | ax = plt.subplot(1, 4, i + 1) 36 | plt.tight_layout() 37 | ax.set_title('Sample #{}'.format(i)) 38 | ax.axis('off') 39 | show_images(sample['images'],sample['labels']) 40 | 41 | if i == 3: 42 | plt.show() 43 | break 44 | 45 | class VisdomLinePlotter(object): 46 | 47 | """Plots to Visdom""" 48 | 49 | def __init__(self, env_name='main'): 50 | self.viz = Visdom() 51 | self.env = env_name 52 | self.plots = {} 53 | 54 | def plot(self, var_name, split_name, title_name, x, y): 55 | if var_name not in self.plots: 56 | self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, 
opts=dict( 57 | legend=[split_name], 58 | title=title_name, 59 | xlabel='Epochs', 60 | ylabel=var_name 61 | )) 62 | else: 63 | self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update = 'append') 64 | 65 | 66 | def input_images(x, y, i, n_iter, k=1): 67 | """ 68 | 69 | :param x: batch of input images 70 | :param y: batch of input labels 71 | :param i: the epoch number 72 | :param n_iter: iteration counter, used in the saved file name 73 | :param k: the figure is only saved when k == 1 (first batch of the epoch) 74 | :return: None; saves the second sample of the batch (image channel 1, label channel 0) to ./model/pred 75 | """ 76 | if k == 1: 77 | x1 = x 78 | y1 = y 79 | 80 | x2 = x1.to('cpu') 81 | y2 = y1.to('cpu') 82 | x2 = x2.detach().numpy() 83 | y2 = y2.detach().numpy() 84 | 85 | x3 = x2[1, 1, :, :] 86 | y3 = y2[1, 0, :, :] 87 | 88 | fig = plt.figure() 89 | 90 | ax1 = fig.add_subplot(1, 2, 1) 91 | ax1.imshow(x3) 92 | ax1.axis('off') 93 | ax1.set_xticklabels([]) 94 | ax1.set_yticklabels([]) 95 | ax1 = fig.add_subplot(1, 2, 2) 96 | ax1.imshow(y3) 97 | ax1.axis('off') 98 | ax1.set_xticklabels([]) 99 | ax1.set_yticklabels([]) 100 | plt.savefig( 101 | './model/pred/L_' + str(n_iter-1) + '_epoch_' 102 | + str(i)) 103 | 104 | 105 | def plot_kernels(tensor, n_iter, num_cols=5, cmap="gray"): 106 | """Plotting the kernels and layers 107 | Args: 108 | tensor : input layer activations (4D tensor), 109 | n_iter : number of iteration, 110 | num_cols : number of columns required for the figure 111 | Output: 112 | Saves a figure of the layer's activation maps to ./model/pred 113 | 114 | Default : the last layer is used 115 | """ 116 | if not len(tensor.shape) == 4: 117 | raise Exception("assumes a 4D tensor") 118 | 119 | fig = plt.figure() 120 | i = 0 121 | t = tensor.data.numpy() 122 | b = 0 123 | a = 1 124 | 125 | for t1 in t: 126 | for t2 in t1: 127 | i += 1 128 | 129 | ax1 = fig.add_subplot(5, num_cols, i) 130 | ax1.imshow(t2, cmap=cmap) 131 | ax1.axis('off') 132 | ax1.set_xticklabels([]) 133 | ax1.set_yticklabels([]) 134 | 135 | if i == 1: 136 | a = 1 137 | if a == 10: 138 | 
break 139 | a += 1 140 | if i % a == 0: 141 | a = 0 142 | b += 1 143 | if b == 20: 144 | break 145 | 146 | plt.savefig( 147 | './model/pred/Kernal_' + str(n_iter - 1) + '_epoch_' 148 | + str(i)) 149 | 150 | 151 | class LayerActivations(): 152 | """Forward hook that stores the output of the given layer in self.features""" 153 | 154 | features = None 155 | 156 | def __init__(self, layer): 157 | self.hook = layer.register_forward_hook(self.hook_fn) 158 | 159 | def hook_fn(self, module, input, output): 160 | self.features = output.cpu() 161 | 162 | def remove(self): 163 | self.hook.remove() 164 | 165 | 166 | # Plot the gradient flow 167 | # Adapted from the PyTorch forums 168 | def plot_grad_flow(named_parameters,n_iter): 169 | 170 | '''Plots the gradients flowing through different layers in the net during training. 171 | Can be used for checking for possible gradient vanishing / exploding problems. 172 | 173 | Usage: Plug this function in Trainer class after loss.backward() as 174 | "plot_grad_flow(self.model.named_parameters(), n_iter)" to visualize the gradient flow''' 175 | ave_grads = [] 176 | max_grads = [] 177 | layers = [] 178 | for n, p in named_parameters: 179 | if (p.requires_grad) and ("bias" not in n): 180 | layers.append(n) 181 | ave_grads.append(p.grad.abs().mean().item()) 182 | max_grads.append(p.grad.abs().max().item()) 183 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c") 184 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b") 185 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 186 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 187 | plt.xlim(left=0, right=len(ave_grads)) 188 | plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions 189 | plt.xlabel("Layers") 190 | plt.ylabel("average gradient") 191 | plt.title("Gradient flow") 192 | plt.grid(True) 193 | plt.legend([Line2D([0], [0], color="c", lw=4), 194 | Line2D([0], [0], color="b", lw=4), 195 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 
'zero-gradient']) 196 | #plt.savefig('./model/pred/Grad_Flow_' + str(n_iter - 1)) 197 | -------------------------------------------------------------------------------- /pytorch_run.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | #import SimpleITK as sitk 7 | from torch import optim 8 | import torch.utils.data 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | import torch.nn 13 | import torchvision 14 | import matplotlib.pyplot as plt 15 | import natsort 16 | from torch.utils.data.sampler import SubsetRandomSampler 17 | from Data_Loader import Images_Dataset, Images_Dataset_folder 18 | import torchsummary 19 | #from torch.utils.tensorboard import SummaryWriter 20 | #from tensorboardX import SummaryWriter 21 | 22 | import shutil 23 | import random 24 | from Models import Unet_dict, NestedUNet, U_Net, R2U_Net, AttU_Net, R2AttU_Net 25 | from losses import calc_loss, dice_loss, threshold_predictions_v,threshold_predictions_p 26 | from ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow 27 | from Metrics import dice_coeff, accuracy_score 28 | import time 29 | #from ploting import VisdomLinePlotter 30 | #from visdom import Visdom 31 | 32 | 33 | ####################################################### 34 | #Checking if GPU is used 35 | ####################################################### 36 | 37 | train_on_gpu = torch.cuda.is_available() 38 | 39 | if not train_on_gpu: 40 | print('CUDA is not available. Training on CPU') 41 | else: 42 | print('CUDA is available. 
Training on GPU') 43 | 44 | device = torch.device("cuda:0" if train_on_gpu else "cpu") 45 | 46 | ####################################################### 47 | #Setting the basic parameters of the model 48 | ####################################################### 49 | 50 | batch_size = 4 51 | print('batch_size = ' + str(batch_size)) 52 | 53 | valid_size = 0.15 54 | 55 | epoch = 15 56 | print('epoch = ' + str(epoch)) 57 | 58 | random_seed = random.randint(1, 100) 59 | print('random_seed = ' + str(random_seed)) 60 | 61 | shuffle = True 62 | valid_loss_min = np.inf 63 | num_workers = 4 64 | lossT = [] 65 | lossL = [] 66 | lossL.append(np.inf) 67 | lossT.append(np.inf) 68 | epoch_valid = epoch-2 69 | n_iter = 1 70 | i_valid = 0 71 | 72 | pin_memory = False 73 | if train_on_gpu: 74 | pin_memory = True 75 | 76 | #plotter = VisdomLinePlotter(env_name='Tutorial Plots') 77 | 78 | ####################################################### 79 | #Setting up the model 80 | ####################################################### 81 | 82 | model_Inputs = [U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet] 83 | 84 | 85 | def model_unet(model_input, in_channel=3, out_channel=1): 86 | model_test = model_input(in_channel, out_channel) 87 | return model_test 88 | 89 | #passing this string so that if it's AttU_Net or R2AttU_Net it doesn't throw an error at torchsummary 90 | 91 | 92 | model_test = model_unet(model_Inputs[0], 3, 1) 93 | 94 | model_test.to(device) 95 | 96 | ####################################################### 97 | #Getting the Summary of Model 98 | ####################################################### 99 | 100 | torchsummary.summary(model_test, input_size=(3, 128, 128)) 101 | 102 | ####################################################### 103 | #Passing the Dataset of Images and Labels 104 | ####################################################### 105 | 106 | t_data = '/flush1/bat161/segmentation/New_Trails/venv/DATA/new_3C_I_ori/' 107 | l_data = 
'/flush1/bat161/segmentation/New_Trails/venv/DATA/new_3C_L_ori/' 108 | test_image = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_I_ori/0131_0009.png' 109 | test_label = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_L_ori/0131_0009.png' 110 | test_folderP = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_I_ori/*' 111 | test_folderL = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_L_ori/*' 112 | 113 | Training_Data = Images_Dataset_folder(t_data, 114 | l_data) 115 | 116 | ####################################################### 117 | #Giving a transformation for input data 118 | ####################################################### 119 | 120 | data_transform = torchvision.transforms.Compose([ 121 | # torchvision.transforms.Resize((128,128)), 122 | # torchvision.transforms.CenterCrop(96), 123 | torchvision.transforms.ToTensor(), 124 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 125 | ]) 126 | 127 | ####################################################### 128 | #Training/Validation Split 129 | ####################################################### 130 | 131 | num_train = len(Training_Data) 132 | indices = list(range(num_train)) 133 | split = int(np.floor(valid_size * num_train)) 134 | 135 | if shuffle: 136 | np.random.seed(random_seed) 137 | np.random.shuffle(indices) 138 | 139 | train_idx, valid_idx = indices[split:], indices[:split] 140 | train_sampler = SubsetRandomSampler(train_idx) 141 | valid_sampler = SubsetRandomSampler(valid_idx) 142 | 143 | train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler, 144 | num_workers=num_workers, pin_memory=pin_memory,) 145 | 146 | valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler, 147 | num_workers=num_workers, pin_memory=pin_memory,) 148 | 149 | ####################################################### 150 | #Using Adam as Optimizer 151 
| ####################################################### 152 | 153 | initial_lr = 0.001 154 | opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr) # try SGD 155 | #opt = optim.SGD(model_test.parameters(), lr = initial_lr, momentum=0.99) 156 | 157 | MAX_STEP = int(1e10) 158 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5) 159 | #scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epoch, 1) 160 | 161 | ####################################################### 162 | #Writing the params to tensorboard 163 | ####################################################### 164 | 165 | #writer1 = SummaryWriter() 166 | #dummy_inp = torch.randn(1, 3, 128, 128) 167 | #model_test.to('cpu') 168 | #writer1.add_graph(model_test, torch.randn(3, 3, 128, 128, requires_grad=True)) 169 | #model_test.to(device) 170 | 171 | ####################################################### 172 | #Creating a folder for all the data of the program 173 | ####################################################### 174 | 175 | New_folder = './model' 176 | 177 | if os.path.exists(New_folder) and os.path.isdir(New_folder): 178 | shutil.rmtree(New_folder) 179 | 180 | try: 181 | os.mkdir(New_folder) 182 | except OSError: 183 | print("Creation of the main directory '%s' failed " % New_folder) 184 | else: 185 | print("Successfully created the main directory '%s' " % New_folder) 186 | 187 | ####################################################### 188 | #Setting the folder for saving the predictions 189 | ####################################################### 190 | 191 | read_pred = './model/pred' 192 | 193 | ####################################################### 194 | #Checking if prediction folder exists 195 | ####################################################### 196 | 197 | if os.path.exists(read_pred) and os.path.isdir(read_pred): 198 | shutil.rmtree(read_pred) 199 | 200 | try: 201 | os.mkdir(read_pred) 202 | except OSError: 203 | print("Creation of the 
prediction directory '%s' failed" % read_pred) 204 | else: 205 | print("Successfully created the prediction directory '%s'" % read_pred) 206 | 207 | ####################################################### 208 | #checking if the model exists and if true then delete 209 | ####################################################### 210 | 211 | read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size) 212 | 213 | if os.path.exists(read_model_path) and os.path.isdir(read_model_path): 214 | shutil.rmtree(read_model_path) 215 | print('Model folder already existed, so it was deleted to make room for the new one') 216 | 217 | try: 218 | os.mkdir(read_model_path) 219 | except OSError: 220 | print("Creation of the model directory '%s' failed" % read_model_path) 221 | else: 222 | print("Successfully created the model directory '%s' " % read_model_path) 223 | 224 | ####################################################### 225 | #Training loop 226 | ####################################################### 227 | 228 | for i in range(epoch): 229 | 230 | train_loss = 0.0 231 | valid_loss = 0.0 232 | since = time.time() 233 | scheduler.step(i) 234 | lr = scheduler.get_lr() 235 | 236 | ####################################################### 237 | #Training Data 238 | ####################################################### 239 | 240 | model_test.train() 241 | k = 1 242 | 243 | for x, y in train_loader: 244 | x, y = x.to(device), y.to(device) 245 | 246 | #Save the input images (with any augmentation) to check the data flowing into the net 247 | input_images(x, y, i, n_iter, k) 248 | 249 | # grid_img = torchvision.utils.make_grid(x) 250 | #writer1.add_image('images', grid_img, 0) 251 | 252 | # grid_lab = torchvision.utils.make_grid(y) 253 | 254 | opt.zero_grad() 255 | 256 | y_pred = model_test(x) 257 | lossT = calc_loss(y_pred, y) # BCE + Dice loss 258 | 259 | train_loss += lossT.item() * x.size(0) 260 | lossT.backward() 261 | # plot_grad_flow(model_test.named_parameters(), 
n_iter) 262 | opt.step() 263 | x_size = lossT.item() * x.size(0) 264 | k = 2 265 | 266 | # for name, param in model_test.named_parameters(): 267 | # name = name.replace('.', '/') 268 | # writer1.add_histogram(name, param.data.cpu().numpy(), i + 1) 269 | # writer1.add_histogram(name + '/grad', param.grad.data.cpu().numpy(), i + 1) 270 | 271 | 272 | ####################################################### 273 | #Validation Step 274 | ####################################################### 275 | 276 | model_test.eval() 277 | # torch.no_grad() only takes effect as a context manager; it wraps the forward pass below so validation uses less memory 278 | 279 | for x1, y1 in valid_loader: 280 | x1, y1 = x1.to(device), y1.to(device) 281 | 282 | with torch.no_grad(): y_pred1 = model_test(x1) 283 | lossL = calc_loss(y_pred1, y1) # BCE + Dice loss 284 | 285 | valid_loss += lossL.item() * x1.size(0) 286 | x_size1 = lossL.item() * x1.size(0) 287 | 288 | ####################################################### 289 | #Saving the predictions 290 | ####################################################### 291 | 292 | im_tb = Image.open(test_image) 293 | im_label = Image.open(test_label) 294 | s_tb = data_transform(im_tb) 295 | s_label = data_transform(im_label) 296 | s_label = s_label.detach().numpy() 297 | 298 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu() 299 | pred_tb = torch.sigmoid(pred_tb) 300 | pred_tb = pred_tb.detach().numpy() 301 | 302 | #pred_tb = threshold_predictions_v(pred_tb) 303 | 304 | x1 = plt.imsave( 305 | './model/pred/img_iteration_' + str(n_iter) + '_epoch_' 306 | + str(i) + '.png', pred_tb[0][0]) 307 | 308 | # accuracy = accuracy_score(pred_tb[0][0], s_label) 309 | 310 | ####################################################### 311 | #To write in Tensorboard 312 | ####################################################### 313 | 314 | train_loss = train_loss / len(train_idx) 315 | valid_loss = valid_loss / len(valid_idx) 316 | 317 | if (i+1) % 1 == 0: 318 | print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i 
+ 1, epoch, train_loss, 319 | valid_loss)) 320 | # writer1.add_scalar('Train Loss', train_loss, n_iter) 321 | # writer1.add_scalar('Validation Loss', valid_loss, n_iter) 322 | #writer1.add_image('Pred', pred_tb[0]) #try to get output of shape 3 323 | 324 | 325 | ####################################################### 326 | #Early Stopping 327 | ####################################################### 328 | 329 | if valid_loss <= valid_loss_min and epoch_valid >= i: # and i_valid <= 2: 330 | 331 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss)) 332 | torch.save(model_test.state_dict(),'./model/Unet_D_' + 333 | str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch) 334 | + '_batchsize_' + str(batch_size) + '.pth') 335 | # print(accuracy) 336 | if round(valid_loss, 4) == round(valid_loss_min, 4): 337 | print(i_valid) 338 | i_valid = i_valid+1 339 | valid_loss_min = valid_loss 340 | #if i_valid ==3: 341 | # break 342 | 343 | ####################################################### 344 | # Extracting the intermediate layers 345 | ####################################################### 346 | 347 | ##################################### 348 | # for kernels 349 | ##################################### 350 | x1 = torch.nn.ModuleList(model_test.children()) 351 | # x2 = torch.nn.ModuleList(x1[16].children()) 352 | #x3 = torch.nn.ModuleList(x2[0].children()) 353 | 354 | #To get filters in the layers 355 | #plot_kernels(x1.weight.detach().cpu(), 7) 356 | 357 | ##################################### 358 | # for images 359 | ##################################### 360 | x2 = len(x1) 361 | dr = LayerActivations(x1[x2-1]) #Getting the last Conv Layer 362 | 363 | img = Image.open(test_image) 364 | s_tb = data_transform(img) 365 | 366 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu() 367 | pred_tb = torch.sigmoid(pred_tb) 368 | pred_tb = pred_tb.detach().numpy() 369 | 370 | plot_kernels(dr.features, n_iter, 7,
cmap="rainbow") 371 | 372 | time_elapsed = time.time() - since 373 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 374 | n_iter += 1 375 | 376 | ####################################################### 377 | #closing the tensorboard writer 378 | ####################################################### 379 | 380 | #writer1.close() 381 | 382 | ####################################################### 383 | #if using dict 384 | ####################################################### 385 | 386 | #model_test.filter_dict 387 | 388 | ####################################################### 389 | #Loading the model 390 | ####################################################### 391 | 392 | test1 =model_test.load_state_dict(torch.load('./model/Unet_D_' + 393 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch) 394 | + '_batchsize_' + str(batch_size) + '.pth')) 395 | 396 | 397 | ####################################################### 398 | #checking if cuda is available 399 | ####################################################### 400 | 401 | if torch.cuda.is_available(): 402 | torch.cuda.empty_cache() 403 | 404 | ####################################################### 405 | #Loading the model 406 | ####################################################### 407 | 408 | model_test.load_state_dict(torch.load('./model/Unet_D_' + 409 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch) 410 | + '_batchsize_' + str(batch_size) + '.pth')) 411 | 412 | model_test.eval() 413 | 414 | ####################################################### 415 | #opening the test folder and creating a folder for generated images 416 | ####################################################### 417 | 418 | read_test_folder = glob.glob(test_folderP) 419 | x_sort_test = natsort.natsorted(read_test_folder) # To sort 420 | 421 | 422 | read_test_folder112 = './model/gen_images' 423 | 424 | 425 | if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112): 
426 | shutil.rmtree(read_test_folder112) 427 | 428 | try: 429 | os.mkdir(read_test_folder112) 430 | except OSError: 431 | print("Creation of the testing directory %s failed" % read_test_folder112) 432 | else: 433 | print("Successfully created the testing directory %s " % read_test_folder112) 434 | 435 | 436 | #For Prediction Threshold 437 | 438 | read_test_folder_P_Thres = './model/pred_threshold' 439 | 440 | 441 | if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres): 442 | shutil.rmtree(read_test_folder_P_Thres) 443 | 444 | try: 445 | os.mkdir(read_test_folder_P_Thres) 446 | except OSError: 447 | print("Creation of the testing directory %s failed" % read_test_folder_P_Thres) 448 | else: 449 | print("Successfully created the testing directory %s " % read_test_folder_P_Thres) 450 | 451 | #For Label Threshold 452 | 453 | read_test_folder_L_Thres = './model/label_threshold' 454 | 455 | 456 | if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres): 457 | shutil.rmtree(read_test_folder_L_Thres) 458 | 459 | try: 460 | os.mkdir(read_test_folder_L_Thres) 461 | except OSError: 462 | print("Creation of the testing directory %s failed" % read_test_folder_L_Thres) 463 | else: 464 | print("Successfully created the testing directory %s " % read_test_folder_L_Thres) 465 | 466 | 467 | 468 | 469 | ####################################################### 470 | #saving the images in the files 471 | ####################################################### 472 | 473 | img_test_no = 0 474 | 475 | for i in range(len(read_test_folder)): 476 | im = Image.open(x_sort_test[i]) 477 | 478 | im1 = im 479 | im_n = np.array(im1) 480 | im_n_flat = im_n.reshape(-1, 1) 481 | 482 | for j in range(im_n_flat.shape[0]): 483 | if im_n_flat[j] != 0: 484 | im_n_flat[j] = 255 485 | 486 | s = data_transform(im) 487 | pred = model_test(s.unsqueeze(0).to(device)).cpu() # use the selected device so CPU-only runs also work 488 | pred = torch.sigmoid(pred) 489 | 
pred = pred.detach().numpy() 490 | 491 | # pred
= threshold_predictions_p(pred) #Value kept 0.01 as max is 1 and noise is very small. 492 | 493 | if i % 24 == 0: 494 | img_test_no = img_test_no + 1 495 | 496 | x1 = plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i) 497 | + '_img_no_' + str(img_test_no) + '.png', pred[0][0]) 498 | 499 | 500 | #################################################### 501 | #Calculating the Dice Score 502 | #################################################### 503 | 504 | data_transform = torchvision.transforms.Compose([ 505 | # torchvision.transforms.Resize((128,128)), 506 | # torchvision.transforms.CenterCrop(96), 507 | torchvision.transforms.Grayscale(), 508 | # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 509 | ]) 510 | 511 | 512 | 513 | read_test_folderP = glob.glob('./model/gen_images/*') 514 | x_sort_testP = natsort.natsorted(read_test_folderP) 515 | 516 | 517 | read_test_folderL = glob.glob(test_folderL) 518 | x_sort_testL = natsort.natsorted(read_test_folderL) # To sort 519 | 520 | 521 | dice_score123 = 0.0 522 | x_count = 0 523 | x_dice = 0 524 | 525 | for i in range(len(read_test_folderP)): 526 | 527 | x = Image.open(x_sort_testP[i]) 528 | s = data_transform(x) 529 | s = np.array(s) 530 | s = threshold_predictions_v(s) 531 | 532 | #save the images 533 | x1 = plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i) 534 | + '_img_no_' + str(img_test_no) + '.png', s) 535 | 536 | y = Image.open(x_sort_testL[i]) 537 | s2 = data_transform(y) 538 | s3 = np.array(s2) 539 | # s2 =threshold_predictions_v(s2) 540 | 541 | #save the Images 542 | y1 = plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i) 543 | + '_img_no_' + str(img_test_no) + '.png', s3) 544 | 545 | total = dice_coeff(s, s3) 546 | print(total) 547 | 548 | if total <= 0.3: 549 | x_count += 1 550 | if total > 0.3: 551 | x_dice = x_dice + total 552 | dice_score123 = dice_score123 + total 553 | 554 | 555 | print('Dice Score : ' + 
str(dice_score123/len(read_test_folderP))) 556 | #print(x_count) 557 | #print(x_dice) 558 | #print('Dice Score : ' + str(float(x_dice/(len(read_test_folderP)-x_count)))) 559 | 560 | -------------------------------------------------------------------------------- /pytorch_run_old.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import glob 6 | 7 | from torch import optim 8 | import torch.utils.data 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | import torch.nn 13 | import torchvision 14 | import matplotlib.pyplot as plt 15 | import natsort 16 | from torch.utils.data.sampler import SubsetRandomSampler 17 | from Data_Loader import Images_Dataset, Images_Dataset_folder 18 | import torchsummary 19 | #from torch.utils.tensorboard import SummaryWriter 20 | from tensorboardX import SummaryWriter 21 | 22 | import shutil 23 | import random 24 | from Models import Unet_dict, NestedUNet, U_Net, R2U_Net, AttU_Net, R2AttU_Net 25 | from losses import calc_loss, dice_loss, threshold_predictions_v,threshold_predictions_p 26 | from ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow 27 | from Metrics import dice_coeff, accuracy_score 28 | import time 29 | #from ploting import VisdomLinePlotter 30 | #from visdom import Visdom 31 | 32 | 33 | ####################################################### 34 | #to make sure you want to run the program 35 | ####################################################### 36 | 37 | x = input('start the model training: ') 38 | if x == 'yes': 39 | pass 40 | else: 41 | exit() 42 | 43 | ####################################################### 44 | #Checking if GPU is used 45 | ####################################################### 46 | 47 | train_on_gpu = torch.cuda.is_available() 48 | 49 | if not train_on_gpu: 50 | print('CUDA is not available. 
Training on CPU') 51 | else: 52 | print('CUDA is available. Training on GPU') 53 | 54 | device = torch.device("cuda:0" if train_on_gpu else "cpu") 55 | 56 | ####################################################### 57 | #Setting the basic parameters of the model 58 | ####################################################### 59 | 60 | batch_size = 4 61 | print('batch_size = ' + str(batch_size)) 62 | 63 | valid_size = 0.15 64 | 65 | epoch = 10 66 | print('epoch = ' + str(epoch)) 67 | 68 | random_seed = random.randint(1, 100) 69 | print('random_seed = ' + str(random_seed)) 70 | 71 | shuffle = True 72 | valid_loss_min = np.inf 73 | num_workers = 4 74 | lossT = [] 75 | lossL = [] 76 | lossL.append(np.inf) 77 | lossT.append(np.inf) 78 | epoch_valid = epoch-2 79 | n_iter = 1 80 | i_valid = 0 81 | 82 | pin_memory = False 83 | if train_on_gpu: 84 | pin_memory = True 85 | 86 | #plotter = VisdomLinePlotter(env_name='Tutorial Plots') 87 | 88 | ####################################################### 89 | #Setting up the model 90 | ####################################################### 91 | 92 | model_Inputs = [U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet] 93 | 94 | 95 | def model_unet(model_input, in_channel=3, out_channel=1): 96 | model_test = model_input(in_channel, out_channel) 97 | return model_test 98 | 99 | #passing this string so that if it's AttU_Net or R2AttU_Net it doesn't throw an error at torchsummary 100 | 101 | 102 | model_test = model_unet(model_Inputs[0], 3, 1) 103 | 104 | model_test.to(device) 105 | 106 | ####################################################### 107 | #Getting the Summary of Model 108 | ####################################################### 109 | 110 | torchsummary.summary(model_test, input_size=(3, 128, 128), device=device.type) # torchsummary defaults to "cuda"; pass "cpu" when no GPU is present 111 | 112 | ####################################################### 113 | #Passing the Dataset of Images and Labels 114 | 
####################################################### 115 | 116 | Training_Data =
Images_Dataset_folder('/home/malav/Desktop/Pytorch_Computer/DATA/new_3C_I_ori_same/', 117 | '/home/malav/Desktop/Pytorch_Computer/DATA/new_3C_L_ori_same/') 118 | 119 | ####################################################### 120 | #Defining the transformation for the input data 121 | ####################################################### 122 | 123 | data_transform = torchvision.transforms.Compose([ 124 | # torchvision.transforms.Resize((128,128)), 125 | torchvision.transforms.CenterCrop(96), 126 | torchvision.transforms.ToTensor(), 127 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 128 | ]) 129 | 130 | ####################################################### 131 | #Training/Validation Split 132 | ####################################################### 133 | 134 | num_train = len(Training_Data) 135 | indices = list(range(num_train)) 136 | split = int(np.floor(valid_size * num_train)) 137 | 138 | if shuffle: 139 | np.random.seed(random_seed) 140 | np.random.shuffle(indices) 141 | 142 | train_idx, valid_idx = indices[split:], indices[:split] 143 | train_sampler = SubsetRandomSampler(train_idx) 144 | valid_sampler = SubsetRandomSampler(valid_idx) 145 | 146 | train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler, 147 | num_workers=num_workers, pin_memory=pin_memory) 148 | 149 | valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler, 150 | num_workers=num_workers, pin_memory=pin_memory) 151 | 152 | ####################################################### 153 | #Using Adam as Optimizer 154 | ####################################################### 155 | 156 | initial_lr = 0.001 157 | opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr) 158 | MAX_STEP = int(1e10) 159 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5) 160 | #scheduler = optim.lr_scheduler.CosineAnnealingLr(opt, epoch, 1) 161 | 162 |
####################################################### 163 | #Writing the params to tensorboard 164 | ####################################################### 165 | 166 | writer1 = SummaryWriter() 167 | dummy_inp = torch.randn(1, 3, 128, 128) 168 | model_test.to('cpu') 169 | writer1.add_graph(model_test, dummy_inp) # add_graph expects an example input tensor, not the model's output 170 | model_test.to(device) 171 | 172 | ####################################################### 173 | #Creating the main folder for all program output 174 | ####################################################### 175 | 176 | New_folder = './model' 177 | 178 | if os.path.exists(New_folder) and os.path.isdir(New_folder): 179 | shutil.rmtree(New_folder) 180 | 181 | try: 182 | os.mkdir(New_folder) 183 | except OSError: 184 | print("Creation of the main directory '%s' failed " % New_folder) 185 | else: 186 | print("Successfully created the main directory '%s' " % New_folder) 187 | 188 | ####################################################### 189 | #Setting the folder for saving the predictions 190 | ####################################################### 191 | 192 | read_pred = './model/pred' 193 | 194 | ####################################################### 195 | #Checking if prediction folder exists 196 | ####################################################### 197 | 198 | if os.path.exists(read_pred) and os.path.isdir(read_pred): 199 | shutil.rmtree(read_pred) 200 | 201 | try: 202 | os.mkdir(read_pred) 203 | except OSError: 204 | print("Creation of the dice-loss prediction directory '%s' failed" % read_pred) 205 | else: 206 | print("Successfully created the dice-loss prediction directory '%s'" % read_pred) 207 | 208 | ####################################################### 209 | #checking if the model exists and if true then delete 210 | ####################################################### 211 | 212 | read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size) 213 | 214 | if
os.path.exists(read_model_path) and os.path.isdir(read_model_path): 215 | shutil.rmtree(read_model_path) 216 | print('Model folder already existed, so it was deleted to make room for the new one') 217 | 218 | try: 219 | os.mkdir(read_model_path) 220 | except OSError: 221 | print("Creation of the model directory '%s' failed" % read_model_path) 222 | else: 223 | print("Successfully created the model directory '%s' " % read_model_path) 224 | 225 | ####################################################### 226 | #Training loop 227 | ####################################################### 228 | 229 | for i in range(epoch): 230 | 231 | train_loss = 0.0 232 | valid_loss = 0.0 233 | since = time.time() 234 | scheduler.step(i) 235 | lr = scheduler.get_lr() 236 | 237 | ####################################################### 238 | #Training Data 239 | ####################################################### 240 | 241 | model_test.train() 242 | 243 | for x, y in train_loader: 244 | x, y = x.to(device), y.to(device) 245 | 246 | #Optional: save the (augmented) input images to check the data flowing into the network 247 | input_images(x, y, i, n_iter) 248 | 249 | # grid_img = torchvision.utils.make_grid(x) 250 | #writer1.add_image('images', grid_img, 0) 251 | 252 | # grid_lab = torchvision.utils.make_grid(y) 253 | 254 | opt.zero_grad() 255 | 256 | y_pred = model_test(x) 257 | lossT = calc_loss(y_pred, y) # Dice_loss Used 258 | 259 | train_loss += lossT.item() * x.size(0) 260 | lossT.backward() 261 | # plot_grad_flow(model_test.named_parameters(), n_iter) 262 | opt.step() 263 | x_size = lossT.item() * x.size(0) 264 | k = 2 265 | 266 | # for name, param in model_test.named_parameters(): 267 | # name = name.replace('.', '/') 268 | # writer1.add_histogram(name, param.data.cpu().numpy(), i + 1) 269 | # writer1.add_histogram(name + '/grad', param.grad.data.cpu().numpy(), i + 1) 270 | 271 | 272 | ####################################################### 273 | #Validation Step 274 |
####################################################### 275 | 276 | model_test.eval() 277 | with torch.no_grad(): # disable gradient tracking so validation uses less memory 278 | 279 | for x1, y1 in valid_loader: 280 | x1, y1 = x1.to(device), y1.to(device) 281 | 282 | y_pred1 = model_test(x1) 283 | lossL = calc_loss(y_pred1, y1) # Dice_loss Used 284 | 285 | valid_loss += lossL.item() * x1.size(0) 286 | x_size1 = lossL.item() * x1.size(0) 287 | 288 | ####################################################### 289 | #Saving the predictions 290 | ####################################################### 291 | 292 | im_tb = Image.open('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_I_ori_same/0131_0009.png') 293 | im_label = Image.open('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_L_ori_same/0131_0009.png') 294 | s_tb = data_transform(im_tb) 295 | s_label = data_transform(im_label) 296 | 297 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu() 298 | pred_tb = torch.sigmoid(pred_tb) 299 | pred_tb = pred_tb.detach().numpy() 300 | 301 | #pred_tb = threshold_predictions_v(pred_tb) 302 | 303 | x1 = plt.imsave( 304 | './model/pred/img_iteration_' + str(n_iter) + '_epoch_' 305 | + str(i) + '.png', pred_tb[0][0]) 306 | 307 | accuracy = accuracy_score(pred_tb[0][0], s_label) 308 | 309 | ####################################################### 310 | #To write in Tensorboard 311 | ####################################################### 312 | 313 | train_loss = train_loss / len(train_idx) 314 | valid_loss = valid_loss / len(valid_idx) 315 | 316 | if (i+1) % 1 == 0: 317 | print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss, 318 | valid_loss)) 319 | writer1.add_scalar('Train Loss', train_loss, n_iter) 320 | writer1.add_scalar('Validation Loss', valid_loss, n_iter) 321 | #writer1.add_image('Pred', pred_tb[0]) #try to get output of shape 3 322 | 323 | 324 | ####################################################### 325 | #Early Stopping
326 | ####################################################### 327 | 328 | if valid_loss <= valid_loss_min and epoch_valid >= i: # and i_valid <= 2: 329 | 330 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss)) 331 | torch.save(model_test.state_dict(),'./model/Unet_D_' + 332 | str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch) 333 | + '_batchsize_' + str(batch_size) + '.pth') 334 | print(accuracy) 335 | if round(valid_loss, 4) == round(valid_loss_min, 4): 336 | print(i_valid) 337 | i_valid = i_valid+1 338 | valid_loss_min = valid_loss 339 | #if i_valid ==3: 340 | # break 341 | 342 | ####################################################### 343 | # Extracting the intermediate layers 344 | ####################################################### 345 | 346 | ##################################### 347 | # for kernels 348 | ##################################### 349 | x1 = torch.nn.ModuleList(model_test.children()) 350 | # x2 = torch.nn.ModuleList(x1[16].children()) 351 | # x3 = torch.nn.ModuleList(x2[0].children()) 352 | 353 | #To get filters in the layers 354 | # plot_kernels(x3[3].weight.detach().cpu(), 7) 355 | 356 | ##################################### 357 | # for images 358 | ##################################### 359 | x2 = len(x1) 360 | dr = LayerActivations(x1[x2-1]) #Getting the last Conv Layer 361 | 362 | img = Image.open('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_I_ori_same/0131_0009.png') 363 | s_tb = data_transform(img) 364 | 365 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu() 366 | pred_tb = torch.sigmoid(pred_tb) 367 | pred_tb = pred_tb.detach().numpy() 368 | 369 | plot_kernels(dr.features, n_iter, 7, cmap="rainbow") 370 | 371 | time_elapsed = time.time() - since 372 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 373 | n_iter += 1 374 | 375 | ####################################################### 376 | #closing the tensorboard writer 377 |
####################################################### 378 | 379 | writer1.close() 380 | 381 | ####################################################### 382 | #if using dict 383 | ####################################################### 384 | 385 | #model_test.filter_dict 386 | 387 | ####################################################### 388 | #Loading the model 389 | ####################################################### 390 | 391 | test1 =model_test.load_state_dict(torch.load('./model/Unet_D_' + 392 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch) 393 | + '_batchsize_' + str(batch_size) + '.pth')) 394 | 395 | 396 | ####################################################### 397 | #checking if cuda is available 398 | ####################################################### 399 | 400 | if torch.cuda.is_available(): 401 | torch.cuda.empty_cache() 402 | 403 | ####################################################### 404 | #Loading the model 405 | ####################################################### 406 | 407 | model_test.load_state_dict(torch.load('./model/Unet_D_' + 408 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch) 409 | + '_batchsize_' + str(batch_size) + '.pth')) 410 | 411 | model_test.eval() 412 | 413 | ####################################################### 414 | #opening the test folder and creating a folder for generated images 415 | ####################################################### 416 | 417 | read_test_folder = glob.glob('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_I_ori_same/*') 418 | x_sort_test = natsort.natsorted(read_test_folder) # To sort 419 | 420 | 421 | read_test_folder112 = './model/gen_images' 422 | 423 | 424 | if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112): 425 | shutil.rmtree(read_test_folder112) 426 | 427 | try: 428 | os.mkdir(read_test_folder112) 429 | except OSError: 430 | print("Creation of the testing directory %s failed" % read_test_folder112) 431 | else: 432 | 
print("Successfully created the testing directory %s " % read_test_folder112) 433 | 434 | 435 | #For Prediction Threshold 436 | 437 | read_test_folder_P_Thres = './model/pred_threshold' 438 | 439 | 440 | if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres): 441 | shutil.rmtree(read_test_folder_P_Thres) 442 | 443 | try: 444 | os.mkdir(read_test_folder_P_Thres) 445 | except OSError: 446 | print("Creation of the testing directory %s failed" % read_test_folder_P_Thres) 447 | else: 448 | print("Successfully created the testing directory %s " % read_test_folder_P_Thres) 449 | 450 | #For Label Threshold 451 | 452 | read_test_folder_L_Thres = './model/label_threshold' 453 | 454 | 455 | if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres): 456 | shutil.rmtree(read_test_folder_L_Thres) 457 | 458 | try: 459 | os.mkdir(read_test_folder_L_Thres) 460 | except OSError: 461 | print("Creation of the testing directory %s failed" % read_test_folder_L_Thres) 462 | else: 463 | print("Successfully created the testing directory %s " % read_test_folder_L_Thres) 464 | 465 | 466 | 467 | ####################################################### 468 | #data transform for test Set (same as before) 469 | ####################################################### 470 | 471 | data_transform = torchvision.transforms.Compose([ 472 | # torchvision.transforms.Resize((128, 128)), 473 | # torchvision.transforms.Grayscale(), 474 | torchvision.transforms.CenterCrop(96), 475 | torchvision.transforms.ToTensor(), 476 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 477 | ]) 478 | 479 | ####################################################### 480 | #saving the images in the files 481 | ####################################################### 482 | 483 | img_test_no = 0 484 | 485 | for i in range(len(read_test_folder)): 486 | im = Image.open(x_sort_test[i]) 487 | 488 | im1 = im 489 | im_n = np.array(im1) 490 | 
im_n_flat = im_n.reshape(-1,1) 491 | 492 | for j in range(im_n_flat.shape[0]): 493 | if im_n_flat[j] != 0: 494 | im_n_flat[j] = 255 495 | 496 | s = data_transform(im) 497 | pred = model_test(s.unsqueeze(0).to(device)).cpu() # use the selected device so CPU-only runs also work 498 | pred = torch.sigmoid(pred) 499 | pred = pred.detach().numpy() 500 | 501 | # pred = threshold_predictions_p(pred) #Value kept 0.01 as max is 1 and noise is very small. 502 | 503 | if i % 24 == 0: 504 | img_test_no = img_test_no + 1 505 | 506 | x1 = plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i) 507 | + '_img_no_' + str(img_test_no) + '.png', pred[0][0]) 508 | 509 | #################################################### 510 | #data transform for thresholding the saved predictions and labels (crop + grayscale) 511 | #################################################### 512 | 513 | data_transform_test = torchvision.transforms.Compose([ 514 | # torchvision.transforms.Resize((128, 128)), 515 | torchvision.transforms.CenterCrop(96), 516 | torchvision.transforms.Grayscale(), 517 | ]) 518 | 519 | #################################################### 520 | #Calculating the Dice Score 521 | #################################################### 522 | 523 | read_test_folderP = glob.glob('./model/gen_images/*') 524 | x_sort_testP = natsort.natsorted(read_test_folderP) 525 | 526 | 527 | read_test_folderL = glob.glob('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_L_ori_same/*') 528 | x_sort_testL = natsort.natsorted(read_test_folderL) # To sort 529 | 530 | 531 | dice_score123 = 0.0 532 | x_count = 0 533 | x_dice = 0 534 | 535 | for i in range(len(read_test_folderP)): 536 | 537 | x = Image.open(x_sort_testP[i]) 538 | s = data_transform_test(x) 539 | s = np.array(s) 540 | s = threshold_predictions_v(s) 541 | 542 | #save the images 543 | x1 = plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i) 544 | + '_img_no_' + str(img_test_no) + '.png', s) 545 | 546 | y = 
Image.open(x_sort_testL[i]) 547 | s2 = data_transform_test(y) 548 | s3 = np.array(s2) 549 | #
s2 =threshold_predictions_v(s2) 550 | 551 | #save the Images 552 | y1 = plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i) 553 | + '_img_no_' + str(img_test_no) + '.png', s3) 554 | 555 | total = dice_coeff(s, s3) 556 | print(total) 557 | 558 | if total <= 0.3: 559 | x_count += 1 560 | if total > 0.3: 561 | x_dice = x_dice + total 562 | dice_score123 = dice_score123 + total 563 | 564 | 565 | print('Dice Score : ' + str(dice_score123/len(read_test_folderP))) 566 | print('Slices with Dice <= 0.3 : ' + str(x_count)) 567 | print('Sum of Dice over slices > 0.3 : ' + str(x_dice)) 568 | print('Dice Score (slices > 0.3 only) : ' + str(float(x_dice/(len(read_test_folderP)-x_count)))) 569 | 570 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # python>=3.6 (interpreter requirement; not installable with pip) 2 | torch>=0.4.0 3 | torchvision 4 | torchsummary 5 | tensorboardx 6 | natsort 7 | numpy 8 | pillow 9 | scipy 10 | scikit-image 11 | scikit-learn 12 | --------------------------------------------------------------------------------
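Both training scripts end with the same Dice evaluation. They threshold each saved prediction and its label, score the pair, and then report two averages: one over all slices, and one over only the slices scoring above 0.3. The sketch below reproduces that logic in isolation using only NumPy. It is not the repository's code. `binarize`, `dice_coefficient` and `summarize_dice` are hypothetical names standing in for `threshold_predictions_v` and `dice_coeff`, whose sources are not shown here. The threshold of 127 and the "two empty masks score 1.0" rule are also assumptions.

```python
import numpy as np


def binarize(image, threshold=127):
    """Turn a grayscale image into a boolean mask.

    Hypothetical stand-in for losses.threshold_predictions_v (source not shown);
    the threshold of 127 is an assumption.
    """
    return np.asarray(image) > threshold


def dice_coefficient(pred_mask, label_mask):
    """Dice = 2|A & B| / (|A| + |B|) for two binary masks.

    Hypothetical helper, not Metrics.dice_coeff. Two empty masks are
    treated as perfect agreement (an assumption).
    """
    pred_mask = np.asarray(pred_mask, dtype=bool)
    label_mask = np.asarray(label_mask, dtype=bool)
    intersection = np.logical_and(pred_mask, label_mask).sum()
    total = pred_mask.sum() + label_mask.sum()
    if total == 0:
        return 1.0
    return float(2.0 * intersection / total)


def summarize_dice(scores, cutoff=0.3):
    """Mirror the two averages printed by the scripts.

    Returns (mean over all slices, mean over slices with Dice > cutoff,
    number of slices at or below cutoff).
    """
    scores = list(scores)
    kept = [s for s in scores if s > cutoff]
    mean_all = sum(scores) / len(scores)
    mean_kept = sum(kept) / len(kept) if kept else float("nan")
    return mean_all, mean_kept, len(scores) - len(kept)


if __name__ == "__main__":
    pred = binarize([[0, 200], [255, 0]])
    label = binarize([[0, 255], [0, 0]])
    print(dice_coefficient(pred, label))  # 2 * 1 / (2 + 1), about 0.667
    print(summarize_dice([0.9, 0.2, 0.7]))
```

Unlike the scripts, `summarize_dice` returns NaN instead of raising `ZeroDivisionError` when every slice scores at or below the cutoff.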