├── 2d_from_3d.py
├── CODE_OF_CONDUCT.md
├── Data_Loader.py
├── LICENSE
├── Metrics.py
├── Models.py
├── README.md
├── dice.png
├── images
│   ├── att-r2u.png
│   ├── att-unet.png
│   ├── filt1.png
│   ├── in1.png
│   ├── in2.png
│   ├── l2.png
│   ├── nested.jpg
│   ├── r2unet.png
│   ├── tensorb.png
│   └── unet1.png
├── losses.py
├── ploting.py
├── pytorch_run.py
├── pytorch_run_old.py
└── requirements.txt
/2d_from_3d.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np   # scipy.rot90 and scipy.misc.imsave were removed from SciPy; use NumPy/OpenCV instead
3 |
4 | import SimpleITK as sitk #reading MR images
5 |
6 | import glob
7 |
8 |
9 | readfolderT = glob.glob('/home/bat161/Desktop/Thesis/EADC_HHP/*_MNI.nii.gz')
10 | readfolderL = glob.glob('/home/bat161/Desktop/Thesis/EADC_HHP/*_HHP_EADC.nii.gz')
11 |
12 |
13 | TrainingImagesList = []
14 | TrainingLabelsList = []
15 |
16 |
17 | for i in range(len(readfolderT)):
18 | y_folder = readfolderT[i]
19 | yread = sitk.ReadImage(y_folder)
20 | yimage = sitk.GetArrayFromImage(yread)
21 | x = yimage[:184,:232,112:136]
22 | x = np.rot90(x)
23 | x = np.rot90(x)
24 | for j in range(x.shape[2]):
25 | TrainingImagesList.append((x[:184,:224,j]))
26 |
27 | for i in range(len(readfolderL)):
28 | y_folder = readfolderL[i]
29 | yread = sitk.ReadImage(y_folder)
30 | yimage = sitk.GetArrayFromImage(yread)
31 | x = yimage[:184,:232,112:136]
32 | x = np.rot90(x)
33 | x = np.rot90(x)
34 | for j in range(x.shape[2]):
35 | TrainingLabelsList.append((x[:184,:224,j]))
36 |
37 | for i in range(len(TrainingImagesList)):
38 |
39 | xchangeL = TrainingImagesList[i]
40 | xchangeL = cv2.resize(xchangeL,(128,128))
41 | xchangeL = cv2.normalize(xchangeL, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)  # rescale to 0-255 as scipy.misc.imsave used to do
42 | cv2.imwrite('/home/bat161/Desktop/Thesis/Image/png_1C_images/'+str(i)+'.png',xchangeL)
43 | for i in range(len(TrainingLabelsList)):
44 |
45 | xchangeL = TrainingLabelsList[i]
46 | xchangeL = cv2.resize(xchangeL,(128,128))
47 | xchangeL = cv2.normalize(xchangeL, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
48 | cv2.imwrite('/home/bat161/Desktop/Thesis/Image/png_1C_labels/'+str(i)+'.png',xchangeL)
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at malav.b93@gmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/Data_Loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import torch
5 | import torch.utils.data
6 | import torchvision
7 | from skimage import io
8 | from torch.utils.data import Dataset
9 | import random
10 | import numpy as np
11 |
12 |
13 | class Images_Dataset(Dataset):
14 | """Class for getting data as a Dict
15 | Args:
16 | images_dir = path of input images
17 | labels_dir = path of labeled images
18 | transformI = Input Images transformation (default: None)
19 | transformM = Input Labels transformation (default: None)
20 | Output:
21 | sample : Dict of images and labels"""
22 |
23 | def __init__(self, images_dir, labels_dir, transformI = None, transformM = None):
24 |
25 | self.labels_dir = labels_dir
26 | self.images_dir = images_dir
27 | self.transformI = transformI
28 | self.transformM = transformM
29 |
30 | def __len__(self):
31 | return len(self.images_dir)
32 |
33 | def __getitem__(self, idx):
34 |
35 | # index the requested sample directly (the old loop re-read every image on each call and returned the last one)
36 | image = io.imread(self.images_dir[idx])
37 | label = io.imread(self.labels_dir[idx])
38 | if self.transformI:
39 | image = self.transformI(image)
40 | if self.transformM:
41 | label = self.transformM(label)
42 | sample = {'images': image, 'labels': label}
43 |
44 | return sample
45 |
46 |
47 | class Images_Dataset_folder(torch.utils.data.Dataset):
48 | """Class for getting individual transformations and data
49 | Args:
50 | images_dir = path of input images
51 | labels_dir = path of labeled images
52 | transformI = Input Images transformation (default: None)
53 | transformM = Input Labels transformation (default: None)
54 | Output:
55 | tx = Transformed images
56 | lx = Transformed labels"""
57 |
58 | def __init__(self, images_dir, labels_dir,transformI = None, transformM = None):
59 | self.images = sorted(os.listdir(images_dir))
60 | self.labels = sorted(os.listdir(labels_dir))
61 | self.images_dir = images_dir
62 | self.labels_dir = labels_dir
63 | self.transformI = transformI
64 | self.transformM = transformM
65 |
66 | if self.transformI:
67 | self.tx = self.transformI
68 | else:
69 | self.tx = torchvision.transforms.Compose([
70 | # torchvision.transforms.Resize((128,128)),
71 | torchvision.transforms.CenterCrop(96),
72 | torchvision.transforms.RandomRotation((-10,10)),
73 | # torchvision.transforms.RandomHorizontalFlip(),
74 | torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
75 | torchvision.transforms.ToTensor(),
76 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
77 | ])
78 |
79 | if self.transformM:
80 | self.lx = self.transformM
81 | else:
82 | self.lx = torchvision.transforms.Compose([
83 | # torchvision.transforms.Resize((128,128)),
84 | torchvision.transforms.CenterCrop(96),
85 | torchvision.transforms.RandomRotation((-10,10)),
86 | torchvision.transforms.Grayscale(),
87 | torchvision.transforms.ToTensor(),
88 | #torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
89 | ])
90 |
91 | def __len__(self):
92 |
93 | return len(self.images)
94 |
95 | def __getitem__(self, i):
96 | i1 = Image.open(os.path.join(self.images_dir, self.images[i]))
97 | l1 = Image.open(os.path.join(self.labels_dir, self.labels[i]))
98 |
99 | seed = np.random.randint(0, 2**31 - 1) # shared seed so image and label get identical random transforms (2**32 overflows int32 on some platforms)
100 |
101 | # apply this seed to the image transforms
102 | random.seed(seed)
103 | torch.manual_seed(seed)
104 | img = self.tx(i1)
105 |
106 | # apply this seed to the target/label transforms
107 | random.seed(seed)
108 | torch.manual_seed(seed)
109 | label = self.lx(l1)
110 |
111 |
112 |
113 | return img, label
114 |
115 |
--------------------------------------------------------------------------------
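
Usage sketch for `Images_Dataset_folder`, mirroring how `pytorch_run.py` wires it up (the directory paths below are placeholders):

```python
from torch.utils.data import DataLoader
from Data_Loader import Images_Dataset_folder

# Placeholder paths: two directories of matching image/label PNGs.
dataset = Images_Dataset_folder('./DATA/train_images/', './DATA/train_labels/')
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

images, labels = next(iter(loader))  # float tensors of shape [B, C, H, W]
```
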
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Malav Bateriwala
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import spatial
3 |
4 |
5 | def dice_coeff(im1, im2, empty_score=1.0):
6 | """Calculates the dice coefficient for the images"""
7 |
8 | im1 = np.asarray(im1).astype(bool)   # np.bool was removed in NumPy 1.24
9 | im2 = np.asarray(im2).astype(bool)
10 |
11 | if im1.shape != im2.shape:
12 | raise ValueError("Shape mismatch: im1 and im2 must have the same shape.")
13 |
14 | im1 = im1 > 0.5
15 | im2 = im2 > 0.5
16 |
17 | im_sum = im1.sum() + im2.sum()
18 | if im_sum == 0:
19 | return empty_score
20 |
21 | # Compute Dice coefficient
22 | intersection = np.logical_and(im1, im2)
23 | #print(im_sum)
24 |
25 | return 2. * intersection.sum() / im_sum
26 |
27 |
28 | def numeric_score(prediction, groundtruth):
29 | """Computes scores:
30 | FP = False Positives
31 | FN = False Negatives
32 | TP = True Positives
33 | TN = True Negatives
34 | return: FP, FN, TP, TN"""
35 |
36 | FP = float(np.sum((prediction == 1) & (groundtruth == 0)))   # np.float was removed in NumPy 1.24
37 | FN = float(np.sum((prediction == 0) & (groundtruth == 1)))
38 | TP = float(np.sum((prediction == 1) & (groundtruth == 1)))
39 | TN = float(np.sum((prediction == 0) & (groundtruth == 0)))
40 |
41 | return FP, FN, TP, TN
42 |
43 |
44 | def accuracy_score(prediction, groundtruth):
45 | """Getting the accuracy of the model"""
46 |
47 | FP, FN, TP, TN = numeric_score(prediction, groundtruth)
48 | N = FP + FN + TP + TN
49 | accuracy = np.divide(TP + TN, N)
50 | return accuracy * 100.0
--------------------------------------------------------------------------------
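
A minimal sketch of calling these metrics on already-binarized predictions (the toy arrays below are illustrative):

```python
import numpy as np
from Metrics import dice_coeff, accuracy_score

pred = np.array([[0, 1], [1, 1]])
gt = np.array([[0, 1], [0, 1]])

print(dice_coeff(pred, gt))      # 2*|A∩B| / (|A|+|B|) = 4/5 = 0.8
print(accuracy_score(pred, gt))  # (TP+TN)/N * 100 = 75.0
```
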
/Models.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.data
5 | import torch
6 |
7 |
8 | class conv_block(nn.Module):
9 | """
10 | Convolution Block
11 | """
12 | def __init__(self, in_ch, out_ch):
13 | super(conv_block, self).__init__()
14 |
15 | self.conv = nn.Sequential(
16 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
17 | nn.BatchNorm2d(out_ch),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
20 | nn.BatchNorm2d(out_ch),
21 | nn.ReLU(inplace=True))
22 |
23 | def forward(self, x):
24 |
25 | x = self.conv(x)
26 | return x
27 |
28 |
29 | class up_conv(nn.Module):
30 | """
31 | Up Convolution Block
32 | """
33 | def __init__(self, in_ch, out_ch):
34 | super(up_conv, self).__init__()
35 | self.up = nn.Sequential(
36 | nn.Upsample(scale_factor=2),
37 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
38 | nn.BatchNorm2d(out_ch),
39 | nn.ReLU(inplace=True)
40 | )
41 |
42 | def forward(self, x):
43 | x = self.up(x)
44 | return x
45 |
46 |
47 | class U_Net(nn.Module):
48 | """
49 | UNet - Basic Implementation
50 | Paper : https://arxiv.org/abs/1505.04597
51 | """
52 | def __init__(self, in_ch=3, out_ch=1):
53 | super(U_Net, self).__init__()
54 |
55 | n1 = 64
56 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
57 |
58 | self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
59 | self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
60 | self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
61 | self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
62 |
63 | self.Conv1 = conv_block(in_ch, filters[0])
64 | self.Conv2 = conv_block(filters[0], filters[1])
65 | self.Conv3 = conv_block(filters[1], filters[2])
66 | self.Conv4 = conv_block(filters[2], filters[3])
67 | self.Conv5 = conv_block(filters[3], filters[4])
68 |
69 | self.Up5 = up_conv(filters[4], filters[3])
70 | self.Up_conv5 = conv_block(filters[4], filters[3])
71 |
72 | self.Up4 = up_conv(filters[3], filters[2])
73 | self.Up_conv4 = conv_block(filters[3], filters[2])
74 |
75 | self.Up3 = up_conv(filters[2], filters[1])
76 | self.Up_conv3 = conv_block(filters[2], filters[1])
77 |
78 | self.Up2 = up_conv(filters[1], filters[0])
79 | self.Up_conv2 = conv_block(filters[1], filters[0])
80 |
81 | self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
82 |
83 | # self.active = torch.nn.Sigmoid()
84 |
85 | def forward(self, x):
86 |
87 | e1 = self.Conv1(x)
88 |
89 | e2 = self.Maxpool1(e1)
90 | e2 = self.Conv2(e2)
91 |
92 | e3 = self.Maxpool2(e2)
93 | e3 = self.Conv3(e3)
94 |
95 | e4 = self.Maxpool3(e3)
96 | e4 = self.Conv4(e4)
97 |
98 | e5 = self.Maxpool4(e4)
99 | e5 = self.Conv5(e5)
100 |
101 | d5 = self.Up5(e5)
102 | d5 = torch.cat((e4, d5), dim=1)
103 |
104 | d5 = self.Up_conv5(d5)
105 |
106 | d4 = self.Up4(d5)
107 | d4 = torch.cat((e3, d4), dim=1)
108 | d4 = self.Up_conv4(d4)
109 |
110 | d3 = self.Up3(d4)
111 | d3 = torch.cat((e2, d3), dim=1)
112 | d3 = self.Up_conv3(d3)
113 |
114 | d2 = self.Up2(d3)
115 | d2 = torch.cat((e1, d2), dim=1)
116 | d2 = self.Up_conv2(d2)
117 |
118 | out = self.Conv(d2)
119 |
120 | #d1 = self.active(out)
121 |
122 | return out
123 |
124 |
125 | class Recurrent_block(nn.Module):
126 | """
127 | Recurrent Block for R2Unet_CNN
128 | """
129 | def __init__(self, out_ch, t=2):
130 | super(Recurrent_block, self).__init__()
131 |
132 | self.t = t
133 | self.out_ch = out_ch
134 | self.conv = nn.Sequential(
135 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
136 | nn.BatchNorm2d(out_ch),
137 | nn.ReLU(inplace=True)
138 | )
139 |
140 | def forward(self, x):
141 | for i in range(self.t):  # recurrent convolution: x1 = conv(x + x1), repeated t times
142 | if i == 0:
143 | x1 = self.conv(x)
144 | x1 = self.conv(x + x1)  # was conv(x + x), which discarded the recurrent state
145 | return x1
146 |
147 |
148 | class RRCNN_block(nn.Module):
149 | """
150 | Recurrent Residual Convolutional Neural Network Block
151 | """
152 | def __init__(self, in_ch, out_ch, t=2):
153 | super(RRCNN_block, self).__init__()
154 |
155 | self.RCNN = nn.Sequential(
156 | Recurrent_block(out_ch, t=t),
157 | Recurrent_block(out_ch, t=t)
158 | )
159 | self.Conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
160 |
161 | def forward(self, x):
162 | x1 = self.Conv(x)
163 | x2 = self.RCNN(x1)
164 | out = x1 + x2
165 | return out
166 |
167 |
168 | class R2U_Net(nn.Module):
169 | """
170 | R2U-Unet implementation
171 | Paper: https://arxiv.org/abs/1802.06955
172 | """
173 | def __init__(self, img_ch=3, output_ch=1, t=2):
174 | super(R2U_Net, self).__init__()
175 |
176 | n1 = 64
177 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
178 |
179 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
180 | self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
181 | self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
182 | self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
183 |
184 | self.Upsample = nn.Upsample(scale_factor=2)
185 |
186 | self.RRCNN1 = RRCNN_block(img_ch, filters[0], t=t)
187 |
188 | self.RRCNN2 = RRCNN_block(filters[0], filters[1], t=t)
189 |
190 | self.RRCNN3 = RRCNN_block(filters[1], filters[2], t=t)
191 |
192 | self.RRCNN4 = RRCNN_block(filters[2], filters[3], t=t)
193 |
194 | self.RRCNN5 = RRCNN_block(filters[3], filters[4], t=t)
195 |
196 | self.Up5 = up_conv(filters[4], filters[3])
197 | self.Up_RRCNN5 = RRCNN_block(filters[4], filters[3], t=t)
198 |
199 | self.Up4 = up_conv(filters[3], filters[2])
200 | self.Up_RRCNN4 = RRCNN_block(filters[3], filters[2], t=t)
201 |
202 | self.Up3 = up_conv(filters[2], filters[1])
203 | self.Up_RRCNN3 = RRCNN_block(filters[2], filters[1], t=t)
204 |
205 | self.Up2 = up_conv(filters[1], filters[0])
206 | self.Up_RRCNN2 = RRCNN_block(filters[1], filters[0], t=t)
207 |
208 | self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)
209 |
210 | # self.active = torch.nn.Sigmoid()
211 |
212 |
213 | def forward(self, x):
214 |
215 | e1 = self.RRCNN1(x)
216 |
217 | e2 = self.Maxpool(e1)
218 | e2 = self.RRCNN2(e2)
219 |
220 | e3 = self.Maxpool1(e2)
221 | e3 = self.RRCNN3(e3)
222 |
223 | e4 = self.Maxpool2(e3)
224 | e4 = self.RRCNN4(e4)
225 |
226 | e5 = self.Maxpool3(e4)
227 | e5 = self.RRCNN5(e5)
228 |
229 | d5 = self.Up5(e5)
230 | d5 = torch.cat((e4, d5), dim=1)
231 | d5 = self.Up_RRCNN5(d5)
232 |
233 | d4 = self.Up4(d5)
234 | d4 = torch.cat((e3, d4), dim=1)
235 | d4 = self.Up_RRCNN4(d4)
236 |
237 | d3 = self.Up3(d4)
238 | d3 = torch.cat((e2, d3), dim=1)
239 | d3 = self.Up_RRCNN3(d3)
240 |
241 | d2 = self.Up2(d3)
242 | d2 = torch.cat((e1, d2), dim=1)
243 | d2 = self.Up_RRCNN2(d2)
244 |
245 | out = self.Conv(d2)
246 |
247 | # out = self.active(out)
248 |
249 | return out
250 |
251 |
252 | class Attention_block(nn.Module):
253 | """
254 | Attention Block
255 | """
256 |
257 | def __init__(self, F_g, F_l, F_int):
258 | super(Attention_block, self).__init__()
259 |
260 | self.W_g = nn.Sequential(
261 | nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),  # gating signal g has F_g channels
262 | nn.BatchNorm2d(F_int)
263 | )
264 |
265 | self.W_x = nn.Sequential(
266 | nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),  # skip connection x has F_l channels
267 | nn.BatchNorm2d(F_int)
268 | )
269 |
270 | self.psi = nn.Sequential(
271 | nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
272 | nn.BatchNorm2d(1),
273 | nn.Sigmoid()
274 | )
275 |
276 | self.relu = nn.ReLU(inplace=True)
277 |
278 | def forward(self, g, x):
279 | g1 = self.W_g(g)
280 | x1 = self.W_x(x)
281 | psi = self.relu(g1 + x1)
282 | psi = self.psi(psi)
283 | out = x * psi
284 | return out
285 |
286 |
287 | class AttU_Net(nn.Module):
288 | """
289 | Attention Unet implementation
290 | Paper: https://arxiv.org/abs/1804.03999
291 | """
292 | def __init__(self, img_ch=3, output_ch=1):
293 | super(AttU_Net, self).__init__()
294 |
295 | n1 = 64
296 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
297 |
298 | self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
299 | self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
300 | self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
301 | self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
302 |
303 | self.Conv1 = conv_block(img_ch, filters[0])
304 | self.Conv2 = conv_block(filters[0], filters[1])
305 | self.Conv3 = conv_block(filters[1], filters[2])
306 | self.Conv4 = conv_block(filters[2], filters[3])
307 | self.Conv5 = conv_block(filters[3], filters[4])
308 |
309 | self.Up5 = up_conv(filters[4], filters[3])
310 | self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
311 | self.Up_conv5 = conv_block(filters[4], filters[3])
312 |
313 | self.Up4 = up_conv(filters[3], filters[2])
314 | self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
315 | self.Up_conv4 = conv_block(filters[3], filters[2])
316 |
317 | self.Up3 = up_conv(filters[2], filters[1])
318 | self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
319 | self.Up_conv3 = conv_block(filters[2], filters[1])
320 |
321 | self.Up2 = up_conv(filters[1], filters[0])
322 | self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
323 | self.Up_conv2 = conv_block(filters[1], filters[0])
324 |
325 | self.Conv = nn.Conv2d(filters[0], output_ch, kernel_size=1, stride=1, padding=0)
326 |
327 | #self.active = torch.nn.Sigmoid()
328 |
329 |
330 | def forward(self, x):
331 |
332 | e1 = self.Conv1(x)
333 |
334 | e2 = self.Maxpool1(e1)
335 | e2 = self.Conv2(e2)
336 |
337 | e3 = self.Maxpool2(e2)
338 | e3 = self.Conv3(e3)
339 |
340 | e4 = self.Maxpool3(e3)
341 | e4 = self.Conv4(e4)
342 |
343 | e5 = self.Maxpool4(e4)
344 | e5 = self.Conv5(e5)
345 |
346 | #print(x5.shape)
347 | d5 = self.Up5(e5)
348 | #print(d5.shape)
349 | x4 = self.Att5(g=d5, x=e4)
350 | d5 = torch.cat((x4, d5), dim=1)
351 | d5 = self.Up_conv5(d5)
352 |
353 | d4 = self.Up4(d5)
354 | x3 = self.Att4(g=d4, x=e3)
355 | d4 = torch.cat((x3, d4), dim=1)
356 | d4 = self.Up_conv4(d4)
357 |
358 | d3 = self.Up3(d4)
359 | x2 = self.Att3(g=d3, x=e2)
360 | d3 = torch.cat((x2, d3), dim=1)
361 | d3 = self.Up_conv3(d3)
362 |
363 | d2 = self.Up2(d3)
364 | x1 = self.Att2(g=d2, x=e1)
365 | d2 = torch.cat((x1, d2), dim=1)
366 | d2 = self.Up_conv2(d2)
367 |
368 | out = self.Conv(d2)
369 |
370 | # out = self.active(out)
371 |
372 | return out
373 |
374 |
375 | class R2AttU_Net(nn.Module):
376 | """
377 | Recurrent Residual block with Attention U-Net
378 | Implementation : https://github.com/LeeJunHyun/Image_Segmentation
379 | """
380 | def __init__(self, in_ch=3, out_ch=1, t=2):
381 | super(R2AttU_Net, self).__init__()
382 |
383 | n1 = 64
384 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
385 |
386 | self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
387 | self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
388 | self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
389 | self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2)
390 |
391 | self.RRCNN1 = RRCNN_block(in_ch, filters[0], t=t)
392 | self.RRCNN2 = RRCNN_block(filters[0], filters[1], t=t)
393 | self.RRCNN3 = RRCNN_block(filters[1], filters[2], t=t)
394 | self.RRCNN4 = RRCNN_block(filters[2], filters[3], t=t)
395 | self.RRCNN5 = RRCNN_block(filters[3], filters[4], t=t)
396 |
397 | self.Up5 = up_conv(filters[4], filters[3])
398 | self.Att5 = Attention_block(F_g=filters[3], F_l=filters[3], F_int=filters[2])
399 | self.Up_RRCNN5 = RRCNN_block(filters[4], filters[3], t=t)
400 |
401 | self.Up4 = up_conv(filters[3], filters[2])
402 | self.Att4 = Attention_block(F_g=filters[2], F_l=filters[2], F_int=filters[1])
403 | self.Up_RRCNN4 = RRCNN_block(filters[3], filters[2], t=t)
404 |
405 | self.Up3 = up_conv(filters[2], filters[1])
406 | self.Att3 = Attention_block(F_g=filters[1], F_l=filters[1], F_int=filters[0])
407 | self.Up_RRCNN3 = RRCNN_block(filters[2], filters[1], t=t)
408 |
409 | self.Up2 = up_conv(filters[1], filters[0])
410 | self.Att2 = Attention_block(F_g=filters[0], F_l=filters[0], F_int=32)
411 | self.Up_RRCNN2 = RRCNN_block(filters[1], filters[0], t=t)
412 |
413 | self.Conv = nn.Conv2d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
414 |
415 | # self.active = torch.nn.Sigmoid()
416 |
417 |
418 | def forward(self, x):
419 |
420 | e1 = self.RRCNN1(x)
421 |
422 | e2 = self.Maxpool1(e1)
423 | e2 = self.RRCNN2(e2)
424 |
425 | e3 = self.Maxpool2(e2)
426 | e3 = self.RRCNN3(e3)
427 |
428 | e4 = self.Maxpool3(e3)
429 | e4 = self.RRCNN4(e4)
430 |
431 | e5 = self.Maxpool4(e4)
432 | e5 = self.RRCNN5(e5)
433 |
434 | d5 = self.Up5(e5)
435 | e4 = self.Att5(g=d5, x=e4)
436 | d5 = torch.cat((e4, d5), dim=1)
437 | d5 = self.Up_RRCNN5(d5)
438 |
439 | d4 = self.Up4(d5)
440 | e3 = self.Att4(g=d4, x=e3)
441 | d4 = torch.cat((e3, d4), dim=1)
442 | d4 = self.Up_RRCNN4(d4)
443 |
444 | d3 = self.Up3(d4)
445 | e2 = self.Att3(g=d3, x=e2)
446 | d3 = torch.cat((e2, d3), dim=1)
447 | d3 = self.Up_RRCNN3(d3)
448 |
449 | d2 = self.Up2(d3)
450 | e1 = self.Att2(g=d2, x=e1)
451 | d2 = torch.cat((e1, d2), dim=1)
452 | d2 = self.Up_RRCNN2(d2)
453 |
454 | out = self.Conv(d2)
455 |
456 | # out = self.active(out)
457 |
458 | return out
459 |
460 | # For the nested model, 3 input channels are required
461 |
462 | class conv_block_nested(nn.Module):
463 |
464 | def __init__(self, in_ch, mid_ch, out_ch):
465 | super(conv_block_nested, self).__init__()
466 | self.activation = nn.ReLU(inplace=True)
467 | self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
468 | self.bn1 = nn.BatchNorm2d(mid_ch)
469 | self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
470 | self.bn2 = nn.BatchNorm2d(out_ch)
471 |
472 | def forward(self, x):
473 | x = self.conv1(x)
474 | x = self.bn1(x)
475 | x = self.activation(x)
476 |
477 | x = self.conv2(x)
478 | x = self.bn2(x)
479 | output = self.activation(x)
480 |
481 | return output
482 |
483 | #Nested Unet
484 |
485 | class NestedUNet(nn.Module):
486 | """
487 | Implementation of this paper:
488 | https://arxiv.org/pdf/1807.10165.pdf
489 | """
490 | def __init__(self, in_ch=3, out_ch=1):
491 | super(NestedUNet, self).__init__()
492 |
493 | n1 = 64
494 | filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
495 |
496 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
497 | self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
498 |
499 | self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
500 | self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
501 | self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
502 | self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
503 | self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
504 |
505 | self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
506 | self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
507 | self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
508 | self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])
509 |
510 | self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
511 | self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
512 | self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])
513 |
514 | self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
515 | self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])
516 |
517 | self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])
518 |
519 | self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)
520 |
521 |
522 | def forward(self, x):
523 |
524 | x0_0 = self.conv0_0(x)
525 | x1_0 = self.conv1_0(self.pool(x0_0))
526 | x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))
527 |
528 | x2_0 = self.conv2_0(self.pool(x1_0))
529 | x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
530 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))
531 |
532 | x3_0 = self.conv3_0(self.pool(x2_0))
533 | x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
534 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
535 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))
536 |
537 | x4_0 = self.conv4_0(self.pool(x3_0))
538 | x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
539 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
540 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
541 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))
542 |
543 | output = self.final(x0_4)
544 | return output
545 |
546 | # Dictionary U-Net
547 | # useful if the filter counts and model parameters of each step are needed
548 |
549 | class ConvolutionBlock(nn.Module):
550 | """Convolution block"""
551 |
552 | def __init__(self, in_filters, out_filters, kernel_size=3, batchnorm=True, last_active=F.relu):
553 | super(ConvolutionBlock, self).__init__()
554 |
555 | self.bn = batchnorm
556 | self.last_active = last_active
557 | self.c1 = nn.Conv2d(in_filters, out_filters, kernel_size, padding=1)
558 | self.b1 = nn.BatchNorm2d(out_filters)
559 | self.c2 = nn.Conv2d(out_filters, out_filters, kernel_size, padding=1)
560 | self.b2 = nn.BatchNorm2d(out_filters)
561 |
562 | def forward(self, x):
563 | x = self.c1(x)
564 | if self.bn:
565 | x = self.b1(x)
566 | x = F.relu(x)
567 | x = self.c2(x)
568 | if self.bn:
569 | x = self.b2(x)
570 | x = self.last_active(x)
571 | return x
572 |
573 |
574 | class ContractiveBlock(nn.Module):
575 | """Deconvuling Block"""
576 |
577 | def __init__(self, in_filters, out_filters, conv_kern=3, pool_kern=2, dropout=0.5, batchnorm=True):
578 | super(ContractiveBlock, self).__init__()
579 | self.c1 = ConvolutionBlock(in_filters=in_filters, out_filters=out_filters, kernel_size=conv_kern,
580 | batchnorm=batchnorm)
581 | self.p1 = nn.MaxPool2d(kernel_size=pool_kern, ceil_mode=True)
582 | self.d1 = nn.Dropout2d(dropout)
583 |
584 | def forward(self, x):
585 | c = self.c1(x)
586 | return c, self.d1(self.p1(c))
587 |
588 |
589 | class ExpansiveBlock(nn.Module):
590 | """Upconvole Block"""
591 |
592 | def __init__(self, in_filters1, in_filters2, out_filters, tr_kern=3, conv_kern=3, stride=2, dropout=0.5):
593 | super(ExpansiveBlock, self).__init__()
594 | self.t1 = nn.ConvTranspose2d(in_filters1, out_filters, tr_kern, stride=2, padding=1, output_padding=1)
595 | self.d1 = nn.Dropout(dropout)
596 | self.c1 = ConvolutionBlock(out_filters + in_filters2, out_filters, conv_kern)
597 |
598 | def forward(self, x, contractive_x):
599 | x_ups = self.t1(x)
600 | x_concat = torch.cat([x_ups, contractive_x], 1)
601 | x_fin = self.c1(self.d1(x_concat))
602 | return x_fin
603 |
604 |
605 | class Unet_dict(nn.Module):
606 | """Unet which operates with filters dictionary values"""
607 |
608 | def __init__(self, n_labels, n_filters=32, p_dropout=0.5, batchnorm=True):
609 | super(Unet_dict, self).__init__()
610 | filters_dict = {}
611 | filt_pair = [3, n_filters]
612 |
613 | for i in range(4):
614 | self.add_module('contractive_' + str(i), ContractiveBlock(filt_pair[0], filt_pair[1], batchnorm=batchnorm))
615 | filters_dict['contractive_' + str(i)] = (filt_pair[0], filt_pair[1])
616 | filt_pair[0] = filt_pair[1]
617 | filt_pair[1] = filt_pair[1] * 2
618 |
619 | self.bottleneck = ConvolutionBlock(filt_pair[0], filt_pair[1], batchnorm=batchnorm)
620 | filters_dict['bottleneck'] = (filt_pair[0], filt_pair[1])
621 |
622 | for i in reversed(range(4)):
623 | self.add_module('expansive_' + str(i),
624 | ExpansiveBlock(filt_pair[1], filters_dict['contractive_' + str(i)][1], filt_pair[0]))
625 | filters_dict['expansive_' + str(i)] = (filt_pair[1], filt_pair[0])
626 | filt_pair[1] = filt_pair[0]
627 | filt_pair[0] = filt_pair[0] // 2
628 |
629 | self.output = nn.Conv2d(filt_pair[1], n_labels, kernel_size=1)
630 | filters_dict['output'] = (filt_pair[1], n_labels)
631 | self.filters_dict = filters_dict
632 |
633 | # final_forward
634 | def forward(self, x):
635 | c00, c0 = self.contractive_0(x)
636 | c11, c1 = self.contractive_1(c0)
637 | c22, c2 = self.contractive_2(c1)
638 | c33, c3 = self.contractive_3(c2)
639 | bottle = self.bottleneck(c3)
640 | u3 = F.relu(self.expansive_3(bottle, c33))
641 | u2 = F.relu(self.expansive_2(u3, c22))
642 | u1 = F.relu(self.expansive_1(u2, c11))
643 | u0 = F.relu(self.expansive_0(u1, c00))
644 | return F.softmax(self.output(u0), dim=1)
645 |
646 | # Need to check why this UNet is not working properly
647 | #
648 | # class Convolution2(nn.Module):
649 | # """Convolution Block using 2 Conv2D
650 | # Args:
651 | # in_channels = Input Channels
652 | # out_channels = Output Channels
653 | # kernal_size = 3
654 | # activation = Relu
655 | # batchnorm = True
656 | #
657 | # Output:
658 | # Sequential Relu output """
659 | #
660 | # def __init__(self, in_channels, out_channels, kernal_size=3, activation='Relu', batchnorm=True):
661 | # super(Convolution2, self).__init__()
662 | #
663 | # self.in_channels = in_channels
664 | # self.out_channels = out_channels
665 | # self.kernal_size = kernal_size
666 | # self.batchnorm1 = batchnorm
667 | #
668 | # self.batchnorm2 = batchnorm
669 | # self.activation = activation
670 | #
671 | # self.conv1 = nn.Conv2d(self.in_channels, self.out_channels, self.kernal_size, padding=1, bias=True)
672 | # self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, self.kernal_size, padding=1, bias=True)
673 | #
674 | # self.b1 = nn.BatchNorm2d(out_channels)
675 | # self.b2 = nn.BatchNorm2d(out_channels)
676 | #
677 | # if self.activation == 'LRelu':
678 | # self.a1 = nn.LeakyReLU(inplace=True)
679 | # if self.activation == 'Relu':
680 | # self.a1 = nn.ReLU(inplace=True)
681 | #
682 | # if self.activation == 'LRelu':
683 | # self.a2 = nn.LeakyReLU(inplace=True)
684 | # if self.activation == 'Relu':
685 | # self.a2 = nn.ReLU(inplace=True)
686 | #
687 | # def forward(self, x):
688 | # x1 = self.conv1(x)
689 | #
690 | # if self.batchnorm1:
691 | # x1 = self.b1(x1)
692 | #
693 | # x1 = self.a1(x1)
694 | #
695 | # x1 = self.conv2(x1)
696 | #
697 | # if self.batchnorm2:
698 | # x1 = self.b1(x1)
699 | #
700 | # x = self.a2(x1)
701 | #
702 | # return x
703 | #
704 | #
705 | # class UNet(nn.Module):
706 | # """Implementation of U-Net: Convolutional Networks for Biomedical Image Segmentation (Ronneberger et al., 2015)
707 | # https://arxiv.org/abs/1505.04597
708 | # Args:
709 | # n_class = no. of classes"""
710 | #
711 | # def __init__(self, n_class, dropout=0.4):
712 | # super(UNet, self).__init__()
713 | #
714 | # in_ch = 3
715 | # n1 = 64
716 | # n2 = n1*2
717 | # n3 = n2*2
718 | # n4 = n3*2
719 | # n5 = n4*2
720 | #
721 | # self.dconv_down1 = Convolution2(in_ch, n1)
722 | # self.dconv_down2 = Convolution2(n1, n2)
723 | # self.dconv_down3 = Convolution2(n2, n3)
724 | # self.dconv_down4 = Convolution2(n3, n4)
725 | # self.dconv_down5 = Convolution2(n4, n5)
726 | #
727 | # self.maxpool1 = nn.MaxPool2d(2)
728 | # self.maxpool2 = nn.MaxPool2d(2)
729 | # self.maxpool3 = nn.MaxPool2d(2)
730 | # self.maxpool4 = nn.MaxPool2d(2)
731 | #
732 | # self.upsample1 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True)
733 | # self.upsample2 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True)
734 | # self.upsample3 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True)
735 | # self.upsample4 = nn.Upsample(scale_factor=2)#, mode='bilinear', align_corners=True)
736 | #
737 | # self.dropout1 = nn.Dropout(dropout)
738 | # self.dropout2 = nn.Dropout(dropout)
739 | # self.dropout3 = nn.Dropout(dropout)
740 | # self.dropout4 = nn.Dropout(dropout)
741 | # self.dropout5 = nn.Dropout(dropout)
742 | # self.dropout6 = nn.Dropout(dropout)
743 | # self.dropout7 = nn.Dropout(dropout)
744 | # self.dropout8 = nn.Dropout(dropout)
745 | #
746 | # self.dconv_up4 = Convolution2(n4 + n5, n4)
747 | # self.dconv_up3 = Convolution2(n3 + n4, n3)
748 | # self.dconv_up2 = Convolution2(n2 + n3, n2)
749 | # self.dconv_up1 = Convolution2(n1 + n2, n1)
750 | #
751 | # self.conv_last = nn.Conv2d(n1, n_class, kernel_size=1, stride=1, padding=0)
752 | # # self.active = torch.nn.Sigmoid()
753 | #
754 | #
755 | #
756 | # def forward(self, x):
757 | # conv1 = self.dconv_down1(x)
758 | # x = self.maxpool1(conv1)
759 | # # x = self.dropout1(x)
760 | #
761 | # conv2 = self.dconv_down2(x)
762 | # x = self.maxpool2(conv2)
763 | # # x = self.dropout2(x)
764 | #
765 | # conv3 = self.dconv_down3(x)
766 | # x = self.maxpool3(conv3)
767 | # # x = self.dropout3(x)
768 | #
769 | # conv4 = self.dconv_down4(x)
770 | # x = self.maxpool4(conv4)
771 | # #x = self.dropout4(x)
772 | #
773 | # x = self.dconv_down5(x)
774 | #
775 | # x = self.upsample4(x)
776 | # x = torch.cat((x, conv4), dim=1)
777 | # #x = self.dropout5(x)
778 | #
779 | # x = self.dconv_up4(x)
780 | # x = self.upsample3(x)
781 | # x = torch.cat((x, conv3), dim=1)
782 | # # x = self.dropout6(x)
783 | #
784 | # x = self.dconv_up3(x)
785 | # x = self.upsample2(x)
786 | # x = torch.cat((x, conv2), dim=1)
787 | # #x = self.dropout7(x)
788 | #
789 | # x = self.dconv_up2(x)
790 | # x = self.upsample1(x)
791 | # x = torch.cat((x, conv1), dim=1)
792 | # #x = self.dropout8(x)
793 | #
794 | # x = self.dconv_up1(x)
795 | #
796 | # x = self.conv_last(x)
797 | # # out = self.active(x)
798 | #
799 | # return x
800 |
--------------------------------------------------------------------------------
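
All five networks above share the same constructor and forward interface; a quick shape check with a dummy input (a sketch):

```python
import torch
from Models import U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet

x = torch.randn(1, 3, 128, 128)  # dummy 3-channel input
for Net in (U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet):
    model = Net(3, 1)  # 3 input channels, 1 output channel
    print(Net.__name__, model(x).shape)  # expect torch.Size([1, 1, 128, 128])
```
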
/README.md:
--------------------------------------------------------------------------------
1 | # Unet-Segmentation-Pytorch-Nest-of-Unets
2 |
3 | [](https://www.python.org/)
4 |
5 | [](http://hits.dwyl.io/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets)
6 | [](https://opensource.org/licenses/MIT)
7 | [](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/graphs/commit-activity)
8 | [](https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/issues)
9 | [](https://paperswithcode.com/sota/semantic-segmentation-on-cityscapes-val?p=unet-a-nested-u-net-architecture-for-medical)
10 |
11 | Implementation of different kinds of Unet Models for Image Segmentation
12 |
13 | 1) **UNet** - U-Net: Convolutional Networks for Biomedical Image Segmentation
14 | https://arxiv.org/abs/1505.04597
15 |
16 | 2) **RCNN-UNet** - Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation
17 | https://arxiv.org/abs/1802.06955
18 |
19 | 3) **Attention Unet** - Attention U-Net: Learning Where to Look for the Pancreas
20 | https://arxiv.org/abs/1804.03999
21 |
22 | 4) **RCNN-Attention Unet** - Attention R2U-Net: an integration of two recent works (R2U-Net + Attention U-Net)
23 |
24 |
25 | 5) **Nested UNet** - UNet++: A Nested U-Net Architecture for Medical Image Segmentation
26 | https://arxiv.org/abs/1807.10165
27 |
28 | With Layer Visualization
29 |
30 | ## 1. Getting Started
31 |
32 | Clone the repo:
33 |
34 | ```bash
35 | git clone https://github.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets.git
36 | ```
37 |
38 | ## 2. Requirements
39 |
40 | ```
41 | python>=3.6
42 | torch>=0.4.0
43 | torchvision
44 | torchsummary
45 | tensorboardx
46 | natsort
47 | numpy
48 | pillow
49 | scipy
50 | scikit-image
51 | scikit-learn
52 | ```
53 | Install all dependent libraries:
54 | ```bash
55 | pip install -r requirements.txt
56 | ```
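
A quick, repo-agnostic way to confirm the environment is usable:

```python
# Minimal environment check: library versions and GPU availability.
import torch
import torchvision

print(torch.__version__, torchvision.__version__)
print('CUDA available:', torch.cuda.is_available())
```
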
57 | ## 3. Run the file
58 |
59 | Add your data paths at lines 106-113 of `pytorch_run.py`:
60 | ```
61 | t_data = '' # Input data
62 | l_data = '' #Input Label
63 | test_image = '' #Image to be predicted while training
64 | test_label = '' #Label of the prediction Image
65 | test_folderP = '' #Test folder Image
66 | test_folderL = '' #Test folder Label for calculating the Dice score
67 | ```
68 |
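For example, the variables might be filled in like this (hypothetical paths; point them at your own data), after which training starts with `python pytorch_run.py`:

```
t_data = './DATA/train_images/'              # Input data
l_data = './DATA/train_labels/'              # Input labels
test_image = './DATA/test_images/0001.png'   # Image to be predicted while training
test_label = './DATA/test_labels/0001.png'   # Label of the prediction image
test_folderP = './DATA/test_images/*'        # Test folder images
test_folderL = './DATA/test_labels/*'        # Test folder labels for the Dice score
```
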
69 | ## 4. Types of Unet
70 |
71 | **Unet**
72 | 
73 |
74 | **RCNN Unet**
75 | 
76 |
77 |
78 | **Attention Unet**
79 | 
80 |
81 |
82 | **Attention-RCNN Unet**
83 | 
84 |
85 |
86 | **Nested Unet**
87 |
88 | 
89 |
90 | ## 5. Visualization
91 |
92 | To plot the loss, Visdom is required. The code is already written; just uncomment the relevant parts.
93 | Gradient-flow plotting is also available, adapted from https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/10 (see the sketch below).
94 |
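A minimal sketch of hooking the gradient-flow plot into the training loop, following the usage note in `ploting.py` (`model_test`, `opt` and `n_iter` are the names used in `pytorch_run.py`):

```python
from ploting import plot_grad_flow

# inside the training loop, right after the backward pass:
loss.backward()
plot_grad_flow(model_test.named_parameters(), n_iter)
opt.step()
```
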
95 | A model folder is created and all generated data is stored inside it.
96 | By default the last layer is saved in the model folder. If a particular layer is required, specify it at line 361.
97 |
98 | **Layer Visualization**
99 |
100 | 
101 |
102 | **Filter Visualization**
103 |
104 | 
105 |
106 | **TensorboardX**
107 | Some parameters still need tweaking before this visualization works. It was partly broken while trying to make PyTorch 1.1.0 work with TensorBoard directly (which, as it turned out, supported nothing beyond linear graphs at the time).
108 |
109 |
110 | **Input Image Visualization for checking**
111 |
112 | **a) Original Image**
113 |
114 |
115 |
116 | **b) CenterCrop Image**
117 |
118 |
119 |
120 | ## 6. Results
121 |
122 | **Dice Score for hippocampus segmentation**
123 | ADNI-LONI Dataset
124 |
125 |
126 |
127 | ## 7. Citation
128 |
129 | If you find it useful for your work, please cite:
130 | ```
131 | @article{DBLP:journals/corr/abs-1906-07160,
132 | author = {Malav Bateriwala and
133 | Pierrick Bourgeat},
134 | title = {Enforcing temporal consistency in Deep Learning segmentation of brain
135 | {MR} images},
136 | journal = {CoRR},
137 | volume = {abs/1906.07160},
138 | year = {2019},
139 | url = {http://arxiv.org/abs/1906.07160},
140 | archivePrefix = {arXiv},
141 | eprint = {1906.07160},
142 | timestamp = {Mon, 24 Jun 2019 17:28:45 +0200},
143 | biburl = {https://dblp.org/rec/bib/journals/corr/abs-1906-07160},
144 | bibsource = {dblp computer science bibliography, https://dblp.org}
145 | }
146 | ```
147 |
148 | ## 8. Blog about different Unets
149 | ```
150 | In progress
151 | ```
152 |
153 |
154 |
--------------------------------------------------------------------------------
/dice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/dice.png
--------------------------------------------------------------------------------
/images/att-r2u.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/att-r2u.png
--------------------------------------------------------------------------------
/images/att-unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/att-unet.png
--------------------------------------------------------------------------------
/images/filt1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/filt1.png
--------------------------------------------------------------------------------
/images/in1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/in1.png
--------------------------------------------------------------------------------
/images/in2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/in2.png
--------------------------------------------------------------------------------
/images/l2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/l2.png
--------------------------------------------------------------------------------
/images/nested.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/nested.jpg
--------------------------------------------------------------------------------
/images/r2unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/r2unet.png
--------------------------------------------------------------------------------
/images/tensorb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/tensorb.png
--------------------------------------------------------------------------------
/images/unet1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets/f63262de13dc7a31f426e3ac6cd22eafdbd60131/images/unet1.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | def dice_loss(prediction, target):
6 | """Calculating the dice loss
7 | Args:
8 | prediction = predicted image
9 | target = Targeted image
10 | Output:
11 | dice_loss"""
12 |
13 | smooth = 1.0
14 |
15 | i_flat = prediction.view(-1)
16 | t_flat = target.view(-1)
17 |
18 | intersection = (i_flat * t_flat).sum()
19 |
20 | return 1 - ((2. * intersection + smooth) / (i_flat.sum() + t_flat.sum() + smooth))
21 |
22 |
23 | def calc_loss(prediction, target, bce_weight=0.5):
24 | """Calculating the loss and metrics
25 | Args:
26 | prediction = predicted image
27 | target = Targeted image
28 | metrics = Metrics printed
29 | bce_weight = 0.5 (default)
30 | Output:
31 | loss : dice loss of the epoch """
32 | bce = F.binary_cross_entropy_with_logits(prediction, target)
33 | prediction = torch.sigmoid(prediction)  # F.sigmoid is deprecated
34 | dice = dice_loss(prediction, target)
35 |
36 | loss = bce * bce_weight + dice * (1 - bce_weight)
37 |
38 | return loss
39 |
40 |
41 | def threshold_predictions_v(predictions, thr=150):
42 | thresholded_preds = predictions.copy()  # [:] gives a NumPy view, which would mutate the input in place
43 | # hist = cv2.calcHist([predictions], [0], None, [2], [0, 2])
44 | # plt.plot(hist)
45 | # plt.xlim([0, 2])
46 | # plt.show()
47 | low_values_indices = thresholded_preds < thr
48 | thresholded_preds[low_values_indices] = 0
49 | low_values_indices = thresholded_preds >= thr
50 | thresholded_preds[low_values_indices] = 255
51 | return thresholded_preds
52 |
53 |
54 | def threshold_predictions_p(predictions, thr=0.01):
55 | thresholded_preds = predictions.copy()  # copy to avoid mutating the input in place
56 | #hist = cv2.calcHist([predictions], [0], None, [256], [0, 256])
57 | low_values_indices = thresholded_preds < thr
58 | thresholded_preds[low_values_indices] = 0
59 | low_values_indices = thresholded_preds >= thr
60 | thresholded_preds[low_values_indices] = 1
61 | return thresholded_preds
--------------------------------------------------------------------------------
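
A sketch of one training step using `calc_loss` (variable names follow `pytorch_run.py`; `device`, `opt`, `model_test` and the `images`/`labels` batch are assumed to be set up as there):

```python
images, labels = images.to(device), labels.to(device)

opt.zero_grad()
pred = model_test(images)                        # raw logits; calc_loss handles the sigmoid
loss = calc_loss(pred, labels, bce_weight=0.5)   # 0.5 * BCE + 0.5 * dice
loss.backward()
opt.step()
```
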
/ploting.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib.lines import Line2D
3 | import numpy as np
4 | from visdom import Visdom
5 |
6 |
7 | def show_images(images, labels):
8 | """Show image with label
9 | Args:
10 | images = input images
11 | labels = input labels
12 | Output:
13 | plt = concatenated image and label """
14 |
15 | plt.imshow(images.permute(1, 2, 0))
16 | plt.imshow(labels, alpha=0.7, cmap='gray')
17 | plt.figure()
18 |
19 |
20 | def show_training_dataset(training_dataset):
21 | """Showing the images in training set for dict images and labels
22 | Args:
23 | training_dataset = dictionary of images and labels
24 | Output:
25 | figure = 3 images shown"""
26 |
27 | if training_dataset:
28 | print(len(training_dataset))
29 |
30 | for i in range(len(training_dataset)):
31 | sample = training_dataset[i]
32 |
33 | print(i, sample['images'].shape, sample['labels'].shape)
34 |
35 | ax = plt.subplot(1, 4, i + 1)
36 | plt.tight_layout()
37 | ax.set_title('Sample #{}'.format(i))
38 | ax.axis('off')
39 | show_images(sample['images'],sample['labels'])
40 |
41 | if i == 3:
42 | plt.show()
43 | break
44 |
45 | class VisdomLinePlotter(object):
46 |
47 | """Plots to Visdom"""
48 |
49 | def __init__(self, env_name='main'):
50 | self.viz = Visdom()
51 | self.env = env_name
52 | self.plots = {}
53 |
54 | def plot(self, var_name, split_name, title_name, x, y):
55 | if var_name not in self.plots:
56 | self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict(
57 | legend=[split_name],
58 | title=title_name,
59 | xlabel='Epochs',
60 | ylabel=var_name
61 | ))
62 | else:
63 | self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update = 'append')
64 |
65 |
66 | def input_images(x, y, i, n_iter, k=1):
67 | """
68 |
69 | :param x: takes input image
70 | :param y: take input label
71 | :param i: the epoch number
72 | :param n_iter: iteration number (used in the saved filename)
73 | :param k: for keeping it in the loop
74 | :return: saves a side-by-side figure of the image and label to ./model/pred/
75 | """
76 | if k == 1:
77 | x1 = x
78 | y1 = y
79 |
80 | x2 = x1.to('cpu')
81 | y2 = y1.to('cpu')
82 | x2 = x2.detach().numpy()
83 | y2 = y2.detach().numpy()
84 |
85 | x3 = x2[1, 1, :, :]
86 | y3 = y2[1, 0, :, :]
87 |
88 | fig = plt.figure()
89 |
90 | ax1 = fig.add_subplot(1, 2, 1)
91 | ax1.imshow(x3)
92 | ax1.axis('off')
93 | ax1.set_xticklabels([])
94 | ax1.set_yticklabels([])
95 | ax1 = fig.add_subplot(1, 2, 2)
96 | ax1.imshow(y3)
97 | ax1.axis('off')
98 | ax1.set_xticklabels([])
99 | ax1.set_yticklabels([])
100 | plt.savefig(
101 | './model/pred/L_' + str(n_iter-1) + '_epoch_'
102 | + str(i))
103 |
104 |
105 | def plot_kernels(tensor, n_iter, num_cols=5, cmap="gray"):
106 | """Plotting the kernals and layers
107 | Args:
108 | Tensor :Input layer,
109 | n_iter : number of interation,
110 | num_cols : number of columbs required for figure
111 | Output:
112 | Gives the figure of the size decided with output layers activation map
113 |
114 | Default : Last layer will be taken into consideration
115 | """
116 | if not len(tensor.shape) == 4:
117 | raise Exception("assumes a 4D tensor")
118 |
119 | fig = plt.figure()
120 | i = 0
121 | t = tensor.data.numpy()
122 | b = 0
123 | a = 1
124 |
125 | for t1 in t:
126 | for t2 in t1:
127 | i += 1
128 |
129 | ax1 = fig.add_subplot(5, num_cols, i)
130 | ax1.imshow(t2, cmap=cmap)
131 | ax1.axis('off')
132 | ax1.set_xticklabels([])
133 | ax1.set_yticklabels([])
134 |
135 | if i == 1:
136 | a = 1
137 | if a == 10:
138 | break
139 | a += 1
140 | if i % a == 0:
141 | a = 0
142 | b += 1
143 | if b == 20:
144 | break
145 |
146 | plt.savefig(
147 | './model/pred/Kernal_' + str(n_iter - 1) + '_epoch_'
148 | + str(i))
149 |
150 |
151 | class LayerActivations():
152 | """Getting the hooks on each layer"""
153 |
154 | features = None
155 |
156 | def __init__(self, layer):
157 | self.hook = layer.register_forward_hook(self.hook_fn)
158 |
159 | def hook_fn(self, module, input, output):
160 | self.features = output.cpu()
161 |
162 | def remove(self):
163 | self.hook.remove()
164 |
165 |
166 | #to get gradient flow
167 | #From Pytorch-forums
168 | def plot_grad_flow(named_parameters,n_iter):
169 |
170 | '''Plots the gradients flowing through different layers in the net during training.
171 | Can be used for checking for possible gradient vanishing / exploding problems.
172 |
173 | Usage: Plug this function in Trainer class after loss.backwards() as
174 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
175 | ave_grads = []
176 | max_grads = []
177 | layers = []
178 | for n, p in named_parameters:
179 | if (p.requires_grad) and ("bias" not in n):
180 | layers.append(n)
181 | ave_grads.append(p.grad.abs().mean())
182 | max_grads.append(p.grad.abs().max())
183 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
184 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
185 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
186 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
187 | plt.xlim(left=0, right=len(ave_grads))
188 | plt.ylim(bottom=-0.001, top=0.02) # zoom in on the lower gradient regions
189 | plt.xlabel("Layers")
190 | plt.ylabel("average gradient")
191 | plt.title("Gradient flow")
192 | plt.grid(True)
193 | plt.legend([Line2D([0], [0], color="c", lw=4),
194 | Line2D([0], [0], color="b", lw=4),
195 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
196 | #plt.savefig('./model/pred/Grad_Flow_' + str(n_iter - 1))
197 |
--------------------------------------------------------------------------------
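
Usage sketch for `VisdomLinePlotter` (start the server with `python -m visdom.server` first; the loss values below are placeholders):

```python
from ploting import VisdomLinePlotter

plotter = VisdomLinePlotter(env_name='Tutorial Plots')  # env name as in pytorch_run.py
for epoch, train_loss in enumerate([0.9, 0.6, 0.4]):    # placeholder losses
    plotter.plot('loss', 'train', 'Training Loss', epoch, train_loss)
```
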
/pytorch_run.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 | #import SimpleITK as sitk
7 | from torch import optim
8 | import torch.utils.data
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | import torch.nn
13 | import torchvision
14 | import matplotlib.pyplot as plt
15 | import natsort
16 | from torch.utils.data.sampler import SubsetRandomSampler
17 | from Data_Loader import Images_Dataset, Images_Dataset_folder
18 | import torchsummary
19 | #from torch.utils.tensorboard import SummaryWriter
20 | #from tensorboardX import SummaryWriter
21 |
22 | import shutil
23 | import random
24 | from Models import Unet_dict, NestedUNet, U_Net, R2U_Net, AttU_Net, R2AttU_Net
25 | from losses import calc_loss, dice_loss, threshold_predictions_v,threshold_predictions_p
26 | from ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow
27 | from Metrics import dice_coeff, accuracy_score
28 | import time
29 | #from ploting import VisdomLinePlotter
30 | #from visdom import Visdom
31 |
32 |
33 | #######################################################
34 | #Checking if GPU is used
35 | #######################################################
36 |
37 | train_on_gpu = torch.cuda.is_available()
38 |
39 | if not train_on_gpu:
40 | print('CUDA is not available. Training on CPU')
41 | else:
42 | print('CUDA is available. Training on GPU')
43 |
44 | device = torch.device("cuda:0" if train_on_gpu else "cpu")
45 |
46 | #######################################################
47 | #Setting the basic parameters of the model
48 | #######################################################
49 |
50 | batch_size = 4
51 | print('batch_size = ' + str(batch_size))
52 |
53 | valid_size = 0.15
54 |
55 | epoch = 15
56 | print('epoch = ' + str(epoch))
57 |
58 | random_seed = random.randint(1, 100)
59 | print('random_seed = ' + str(random_seed))
60 |
61 | shuffle = True
62 | valid_loss_min = np.Inf
63 | num_workers = 4
64 | lossT = []
65 | lossL = []
66 | lossL.append(np.inf)
67 | lossT.append(np.inf)
68 | epoch_valid = epoch-2
69 | n_iter = 1
70 | i_valid = 0
71 |
72 | pin_memory = False
73 | if train_on_gpu:
74 | pin_memory = True
75 |
76 | #plotter = VisdomLinePlotter(env_name='Tutorial Plots')
77 |
78 | #######################################################
79 | #Setting up the model
80 | #######################################################
81 |
82 | model_Inputs = [U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet]
83 |
84 |
85 | def model_unet(model_input, in_channel=3, out_channel=1):
86 | model_test = model_input(in_channel, out_channel)
87 | return model_test
88 |
89 | #passing the model constructor through model_unet so that AttU_Net or R2AttU_Net doesn't throw an error in torchsummary
90 |
91 |
92 | model_test = model_unet(model_Inputs[0], 3, 1)
93 |
94 | model_test.to(device)
95 |
96 | #######################################################
97 | #Getting the Summary of Model
98 | #######################################################
99 |
100 | torchsummary.summary(model_test, input_size=(3, 128, 128))
101 |
102 | #######################################################
103 | #Passing the Dataset of Images and Labels
104 | #######################################################
105 |
106 | t_data = '/flush1/bat161/segmentation/New_Trails/venv/DATA/new_3C_I_ori/'
107 | l_data = '/flush1/bat161/segmentation/New_Trails/venv/DATA/new_3C_L_ori/'
108 | test_image = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_I_ori/0131_0009.png'
109 | test_label = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_L_ori/0131_0009.png'
110 | test_folderP = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_I_ori/*'
111 | test_folderL = '/flush1/bat161/segmentation/New_Trails/venv/DATA/test_new_3C_L_ori/*'
112 |
113 | Training_Data = Images_Dataset_folder(t_data,
114 | l_data)
115 |
116 | #######################################################
117 | #Giving a transformation for input data
118 | #######################################################
119 |
120 | data_transform = torchvision.transforms.Compose([
121 | # torchvision.transforms.Resize((128,128)),
122 | # torchvision.transforms.CenterCrop(96),
123 | torchvision.transforms.ToTensor(),
124 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
125 | ])
126 |
127 | #######################################################
128 | #Training-Validation Split
129 | #######################################################
130 |
131 | num_train = len(Training_Data)
132 | indices = list(range(num_train))
133 | split = int(np.floor(valid_size * num_train))
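   | # e.g. with valid_size = 0.15 and 1000 images, the first 150 shuffled indices
   | # form the validation set and the remaining 850 the training set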
134 |
135 | if shuffle:
136 | np.random.seed(random_seed)
137 | np.random.shuffle(indices)
138 |
139 | train_idx, valid_idx = indices[split:], indices[:split]
140 | train_sampler = SubsetRandomSampler(train_idx)
141 | valid_sampler = SubsetRandomSampler(valid_idx)
142 |
143 | train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler,
144 | num_workers=num_workers, pin_memory=pin_memory,)
145 |
146 | valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler,
147 | num_workers=num_workers, pin_memory=pin_memory,)
148 |
149 | #######################################################
150 | #Using Adam as Optimizer
151 | #######################################################
152 |
153 | initial_lr = 0.001
154 | opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr) # try SGD
155 | #opt = optim.SGD(model_test.parameters(), lr = initial_lr, momentum=0.99)
156 |
157 | MAX_STEP = int(1e10)
158 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5)
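   | # with T_max = 1e10 the cosine schedule is essentially flat over a 15-epoch run,
   | # so the learning rate stays very close to initial_lr throughout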
159 | #scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epoch, 1)
160 |
161 | #######################################################
162 | #Writing the params to tensorboard
163 | #######################################################
164 |
165 | #writer1 = SummaryWriter()
166 | #dummy_inp = torch.randn(1, 3, 128, 128)
167 | #model_test.to('cpu')
168 | #writer1.add_graph(model_test, model_test(torch.randn(3, 3, 128, 128, requires_grad=True)))
169 | #model_test.to(device)
170 |
171 | #######################################################
172 | #Creating a folder for all the outputs of the program
173 | #######################################################
174 |
175 | New_folder = './model'
176 |
177 | if os.path.exists(New_folder) and os.path.isdir(New_folder):
178 | shutil.rmtree(New_folder)
179 |
180 | try:
181 | os.mkdir(New_folder)
182 | except OSError:
183 | print("Creation of the main directory '%s' failed " % New_folder)
184 | else:
185 | print("Successfully created the main directory '%s' " % New_folder)
186 |
187 | #######################################################
188 | #Setting the folder of saving the predictions
189 | #######################################################
190 |
191 | read_pred = './model/pred'
192 |
193 | #######################################################
194 | #Checking if prediction folder exists
195 | #######################################################
196 |
197 | if os.path.exists(read_pred) and os.path.isdir(read_pred):
198 | shutil.rmtree(read_pred)
199 |
200 | try:
201 | os.mkdir(read_pred)
202 | except OSError:
203 | print("Creation of the prediction directory '%s' failed of dice loss" % read_pred)
204 | else:
205 | print("Successfully created the prediction directory '%s' of dice loss" % read_pred)
206 |
207 | #######################################################
208 | #checking if the model exists and if true then delete
209 | #######################################################
210 |
211 | read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size)
212 |
213 | if os.path.exists(read_model_path) and os.path.isdir(read_model_path):
214 | shutil.rmtree(read_model_path)
215 |     print('Model folder already exists; deleted it to store the new model')
216 |
217 | try:
218 | os.mkdir(read_model_path)
219 | except OSError:
220 | print("Creation of the model directory '%s' failed" % read_model_path)
221 | else:
222 | print("Successfully created the model directory '%s' " % read_model_path)
223 |
224 | #######################################################
225 | #Training loop
226 | #######################################################
227 |
228 | for i in range(epoch):
229 |
230 | train_loss = 0.0
231 | valid_loss = 0.0
232 | since = time.time()
233 | scheduler.step(i)
234 | lr = scheduler.get_lr()
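   |     # NOTE: stepping the scheduler with an epoch argument before opt.step() and
   |     # calling get_lr() are legacy (pre-1.1) PyTorch patterns; newer versions
   |     # expect scheduler.step() after opt.step() and get_last_lr() instead.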
235 |
236 | #######################################################
237 | #Training Data
238 | #######################################################
239 |
240 | model_test.train()
241 | k = 1
242 |
243 | for x, y in train_loader:
244 | x, y = x.to(device), y.to(device)
245 |
246 |         #Saves the input images and labels (with their augmentation) so the data flowing into the net can be checked
247 | input_images(x, y, i, n_iter, k)
248 |
249 | # grid_img = torchvision.utils.make_grid(x)
250 | #writer1.add_image('images', grid_img, 0)
251 |
252 | # grid_lab = torchvision.utils.make_grid(y)
253 |
254 | opt.zero_grad()
255 |
256 | y_pred = model_test(x)
257 | lossT = calc_loss(y_pred, y) # Dice_loss Used
258 |
259 | train_loss += lossT.item() * x.size(0)
260 | lossT.backward()
261 | # plot_grad_flow(model_test.named_parameters(), n_iter)
262 | opt.step()
263 | x_size = lossT.item() * x.size(0)
264 | k = 2
265 |
266 | # for name, param in model_test.named_parameters():
267 | # name = name.replace('.', '/')
268 | # writer1.add_histogram(name, param.data.cpu().numpy(), i + 1)
269 | # writer1.add_histogram(name + '/grad', param.grad.data.cpu().numpy(), i + 1)
270 |
271 |
272 | #######################################################
273 | #Validation Step
274 | #######################################################
275 |
276 |     model_test.eval()
277 |     with torch.no_grad():  # gradient tracking disabled: validation needs no gradients and uses less memory
278 | 
279 |         for x1, y1 in valid_loader:
280 |             x1, y1 = x1.to(device), y1.to(device)
281 | 
282 |             y_pred1 = model_test(x1)
283 |             lossL = calc_loss(y_pred1, y1) # Dice_loss Used
284 | 
285 |             valid_loss += lossL.item() * x1.size(0)
286 |             x_size1 = lossL.item() * x1.size(0)
287 |
288 | #######################################################
289 | #Saving the predictions
290 | #######################################################
291 |
292 | im_tb = Image.open(test_image)
293 | im_label = Image.open(test_label)
294 | s_tb = data_transform(im_tb)
295 | s_label = data_transform(im_label)
296 | s_label = s_label.detach().numpy()
297 |
298 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
299 |     pred_tb = torch.sigmoid(pred_tb) # torch.sigmoid replaces the deprecated F.sigmoid
300 | pred_tb = pred_tb.detach().numpy()
301 |
302 | #pred_tb = threshold_predictions_v(pred_tb)
303 |
304 | x1 = plt.imsave(
305 | './model/pred/img_iteration_' + str(n_iter) + '_epoch_'
306 | + str(i) + '.png', pred_tb[0][0])
307 |
308 | # accuracy = accuracy_score(pred_tb[0][0], s_label)
309 |
310 | #######################################################
311 | #To write in Tensorboard
312 | #######################################################
313 |
314 | train_loss = train_loss / len(train_idx)
315 | valid_loss = valid_loss / len(valid_idx)
316 |
317 | if (i+1) % 1 == 0:
318 | print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss,
319 | valid_loss))
320 | # writer1.add_scalar('Train Loss', train_loss, n_iter)
321 | # writer1.add_scalar('Validation Loss', valid_loss, n_iter)
322 | #writer1.add_image('Pred', pred_tb[0]) #try to get output of shape 3
323 |
324 |
325 | #######################################################
326 | #Early Stopping
327 | #######################################################
328 |
329 | if valid_loss <= valid_loss_min and epoch_valid >= i: # and i_valid <= 2:
330 |
331 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss))
332 | torch.save(model_test.state_dict(),'./model/Unet_D_' +
333 | str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch)
334 | + '_batchsize_' + str(batch_size) + '.pth')
335 | # print(accuracy)
336 | if round(valid_loss, 4) == round(valid_loss_min, 4):
337 | print(i_valid)
338 | i_valid = i_valid+1
339 | valid_loss_min = valid_loss
340 | #if i_valid ==3:
341 | # break
342 |
343 | #######################################################
344 | # Extracting the intermediate layers
345 | #######################################################
346 |
347 | #####################################
348 | # for kernals
349 | #####################################
350 | x1 = torch.nn.ModuleList(model_test.children())
351 | # x2 = torch.nn.ModuleList(x1[16].children())
352 | #x3 = torch.nn.ModuleList(x2[0].children())
353 |
354 | #To get filters in the layers
355 | #plot_kernels(x1.weight.detach().cpu(), 7)
356 |
357 | #####################################
358 | # for images
359 | #####################################
360 | x2 = len(x1)
361 | dr = LayerActivations(x1[x2-1]) #Getting the last Conv Layer
362 |
363 | img = Image.open(test_image)
364 | s_tb = data_transform(img)
365 |
366 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
367 |     pred_tb = torch.sigmoid(pred_tb) # torch.sigmoid replaces the deprecated F.sigmoid
368 | pred_tb = pred_tb.detach().numpy()
369 |
370 | plot_kernels(dr.features, n_iter, 7, cmap="rainbow")
371 |
372 | time_elapsed = time.time() - since
373 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
374 | n_iter += 1
375 |
376 | #######################################################
377 | #closing the tensorboard writer
378 | #######################################################
379 |
380 | #writer1.close()
381 |
382 | #######################################################
383 | #if using dict
384 | #######################################################
385 |
386 | #model_test.filter_dict
387 |
388 | #######################################################
389 | #Loading the model
390 | #######################################################
391 |
392 | test1 = model_test.load_state_dict(torch.load('./model/Unet_D_' +
393 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch)
394 | + '_batchsize_' + str(batch_size) + '.pth'))
395 |
396 |
397 | #######################################################
398 | #checking if cuda is available
399 | #######################################################
400 |
401 | if torch.cuda.is_available():
402 | torch.cuda.empty_cache()
403 |
404 | #######################################################
405 | #Loading the model
406 | #######################################################
407 |
408 | model_test.load_state_dict(torch.load('./model/Unet_D_' +
409 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch)
410 | + '_batchsize_' + str(batch_size) + '.pth'))
411 |
412 | model_test.eval()
413 |
414 | #######################################################
415 | #opening the test folder and creating a folder for generated images
416 | #######################################################
417 |
418 | read_test_folder = glob.glob(test_folderP)
419 | x_sort_test = natsort.natsorted(read_test_folder) # To sort
420 |
421 |
422 | read_test_folder112 = './model/gen_images'
423 |
424 |
425 | if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112):
426 | shutil.rmtree(read_test_folder112)
427 |
428 | try:
429 | os.mkdir(read_test_folder112)
430 | except OSError:
431 | print("Creation of the testing directory %s failed" % read_test_folder112)
432 | else:
433 | print("Successfully created the testing directory %s " % read_test_folder112)
434 |
435 |
436 | #For Prediction Threshold
437 |
438 | read_test_folder_P_Thres = './model/pred_threshold'
439 |
440 |
441 | if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres):
442 | shutil.rmtree(read_test_folder_P_Thres)
443 |
444 | try:
445 | os.mkdir(read_test_folder_P_Thres)
446 | except OSError:
447 | print("Creation of the testing directory %s failed" % read_test_folder_P_Thres)
448 | else:
449 | print("Successfully created the testing directory %s " % read_test_folder_P_Thres)
450 |
451 | #For Label Threshold
452 |
453 | read_test_folder_L_Thres = './model/label_threshold'
454 |
455 |
456 | if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres):
457 | shutil.rmtree(read_test_folder_L_Thres)
458 |
459 | try:
460 | os.mkdir(read_test_folder_L_Thres)
461 | except OSError:
462 | print("Creation of the testing directory %s failed" % read_test_folder_L_Thres)
463 | else:
464 | print("Successfully created the testing directory %s " % read_test_folder_L_Thres)
465 |
466 |
467 |
468 |
469 | #######################################################
470 | #saving the images in the files
471 | #######################################################
472 |
473 | img_test_no = 0
474 |
475 | for i in range(len(read_test_folder)):
476 | im = Image.open(x_sort_test[i])
477 |
478 | im1 = im
479 | im_n = np.array(im1)
480 | im_n_flat = im_n.reshape(-1, 1)
481 |
482 | for j in range(im_n_flat.shape[0]):
483 | if im_n_flat[j] != 0:
484 | im_n_flat[j] = 255
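   |     # NOTE: im_n_flat is binarised above but never used afterwards; the
   |     # unmodified image im is what gets transformed and fed to the model below.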
485 |
486 | s = data_transform(im)
487 |     pred = model_test(s.unsqueeze(0).to(device)).cpu() # .to(device) instead of .cuda() so CPU-only runs also work
488 |     pred = torch.sigmoid(pred) # torch.sigmoid replaces the deprecated F.sigmoid
489 | pred = pred.detach().numpy()
490 |
491 | # pred = threshold_predictions_p(pred) #Value kept 0.01 as max is 1 and noise is very small.
492 |
493 | if i % 24 == 0:
494 | img_test_no = img_test_no + 1
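   |     # img_test_no increments every 24 slices, which appears to regroup the 2D
   |     # test slices into their source scans (assuming 24 slices per 3D volume).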
495 |
496 | x1 = plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i)
497 | + '_img_no_' + str(img_test_no) + '.png', pred[0][0])
498 |
499 |
500 | ####################################################
501 | #Calculating the Dice Score
502 | ####################################################
503 |
504 | data_transform = torchvision.transforms.Compose([
505 | # torchvision.transforms.Resize((128,128)),
506 | # torchvision.transforms.CenterCrop(96),
507 | torchvision.transforms.Grayscale(),
508 | # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
509 | ])
510 |
511 |
512 |
513 | read_test_folderP = glob.glob('./model/gen_images/*')
514 | x_sort_testP = natsort.natsorted(read_test_folderP)
515 |
516 |
517 | read_test_folderL = glob.glob(test_folderL)
518 | x_sort_testL = natsort.natsorted(read_test_folderL) # To sort
519 |
520 |
521 | dice_score123 = 0.0
522 | x_count = 0
523 | x_dice = 0
524 |
525 | for i in range(len(read_test_folderP)):
526 |
527 | x = Image.open(x_sort_testP[i])
528 | s = data_transform(x)
529 | s = np.array(s)
530 | s = threshold_predictions_v(s)
531 |
532 | #save the images
533 | x1 = plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
534 | + '_img_no_' + str(img_test_no) + '.png', s)
535 |
536 | y = Image.open(x_sort_testL[i])
537 | s2 = data_transform(y)
538 | s3 = np.array(s2)
539 | # s2 =threshold_predictions_v(s2)
540 |
541 | #save the Images
542 | y1 = plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
543 | + '_img_no_' + str(img_test_no) + '.png', s3)
544 |
545 | total = dice_coeff(s, s3)
546 | print(total)
547 |
548 | if total <= 0.3:
549 | x_count += 1
550 | if total > 0.3:
551 | x_dice = x_dice + total
552 | dice_score123 = dice_score123 + total
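   |     # x_count tallies predictions with Dice <= 0.3, while x_dice and (in this
   |     # version) dice_score123 accumulate only the scores above that cut-off, so
   |     # the average printed below divides the filtered sum by all test images.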
553 |
554 |
555 | print('Dice Score : ' + str(dice_score123/len(read_test_folderP)))
556 | #print(x_count)
557 | #print(x_dice)
558 | #print('Dice Score : ' + str(float(x_dice/(len(read_test_folderP)-x_count))))
559 |
560 |
--------------------------------------------------------------------------------
/pytorch_run_old.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import glob
6 |
7 | from torch import optim
8 | import torch.utils.data
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | import torch.nn
13 | import torchvision
14 | import matplotlib.pyplot as plt
15 | import natsort
16 | from torch.utils.data.sampler import SubsetRandomSampler
17 | from Data_Loader import Images_Dataset, Images_Dataset_folder
18 | import torchsummary
19 | #from torch.utils.tensorboard import SummaryWriter
20 | from tensorboardX import SummaryWriter
21 |
22 | import shutil
23 | import random
24 | from Models import Unet_dict, NestedUNet, U_Net, R2U_Net, AttU_Net, R2AttU_Net
25 | from losses import calc_loss, dice_loss, threshold_predictions_v,threshold_predictions_p
26 | from ploting import plot_kernels, LayerActivations, input_images, plot_grad_flow
27 | from Metrics import dice_coeff, accuracy_score
28 | import time
29 | #from ploting import VisdomLinePlotter
30 | #from visdom import Visdom
31 |
32 |
33 | #######################################################
34 | #Making sure you want to run the program
35 | #######################################################
36 |
37 | x = input('start the model training? (yes/no): ')
38 | if x == 'yes':
39 | pass
40 | else:
41 | exit()
42 |
43 | #######################################################
44 | #Checking if GPU is used
45 | #######################################################
46 |
47 | train_on_gpu = torch.cuda.is_available()
48 |
49 | if not train_on_gpu:
50 | print('CUDA is not available. Training on CPU')
51 | else:
52 | print('CUDA is available. Training on GPU')
53 |
54 | device = torch.device("cuda:0" if train_on_gpu else "cpu")
55 |
56 | #######################################################
57 | #Setting the basic parameters of the model
58 | #######################################################
59 |
60 | batch_size = 4
61 | print('batch_size = ' + str(batch_size))
62 |
63 | valid_size = 0.15
64 |
65 | epoch = 10
66 | print('epoch = ' + str(epoch))
67 |
68 | random_seed = random.randint(1, 100)
69 | print('random_seed = ' + str(random_seed))
70 |
71 | shuffle = True
72 | valid_loss_min = np.Inf
73 | num_workers = 4
74 | lossT = []
75 | lossL = []
76 | lossL.append(np.inf)
77 | lossT.append(np.inf)
78 | epoch_valid = epoch-2
79 | n_iter = 1
80 | i_valid = 0
81 |
82 | pin_memory = False
83 | if train_on_gpu:
84 | pin_memory = True
85 |
86 | #plotter = VisdomLinePlotter(env_name='Tutorial Plots')
87 |
88 | #######################################################
89 | #Setting up the model
90 | #######################################################
91 |
92 | model_Inputs = [U_Net, R2U_Net, AttU_Net, R2AttU_Net, NestedUNet]
93 |
94 |
95 | def model_unet(model_input, in_channel=3, out_channel=1):
96 | model_test = model_input(in_channel, out_channel)
97 | return model_test
98 |
99 | #passing the model class through model_unet so that AttU_Net or R2AttU_Net do not throw an error in torchsummary
100 |
101 |
102 | model_test = model_unet(model_Inputs[0], 3, 1)
103 |
104 | model_test.to(device)
105 |
106 | #######################################################
107 | #Getting the Summary of Model
108 | #######################################################
109 |
110 | torchsummary.summary(model_test, input_size=(3, 128, 128))
111 |
112 | #######################################################
113 | #Passing the Dataset of Images and Labels
114 | #######################################################
115 |
116 | Training_Data = Images_Dataset_folder('/home/malav/Desktop/Pytorch_Computer/DATA/new_3C_I_ori_same/',
117 | '/home/malav/Desktop/Pytorch_Computer/DATA/new_3C_L_ori_same/')
118 |
119 | #######################################################
120 | #Giving a transformation for input data
121 | #######################################################
122 |
123 | data_transform = torchvision.transforms.Compose([
124 | # torchvision.transforms.Resize((128,128)),
125 | torchvision.transforms.CenterCrop(96),
126 | torchvision.transforms.ToTensor(),
127 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
128 | ])
129 |
130 | #######################################################
131 | #Training-Validation Split
132 | #######################################################
133 |
134 | num_train = len(Training_Data)
135 | indices = list(range(num_train))
136 | split = int(np.floor(valid_size * num_train))
137 |
138 | if shuffle:
139 | np.random.seed(random_seed)
140 | np.random.shuffle(indices)
141 |
142 | train_idx, valid_idx = indices[split:], indices[:split]
143 | train_sampler = SubsetRandomSampler(train_idx)
144 | valid_sampler = SubsetRandomSampler(valid_idx)
145 |
146 | train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler,
147 | num_workers=num_workers, pin_memory=pin_memory,)
148 |
149 | valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler,
150 | num_workers=num_workers, pin_memory=pin_memory,)
151 |
152 | #######################################################
153 | #Using Adam as Optimizer
154 | #######################################################
155 |
156 | initial_lr = 0.001
157 | opt = torch.optim.Adam(model_test.parameters(), lr=initial_lr)
158 | MAX_STEP = int(1e10)
159 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, MAX_STEP, eta_min=1e-5)
160 | #scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, epoch, 1)
161 |
162 | #######################################################
163 | #Writing the params to tensorboard
164 | #######################################################
165 |
166 | writer1 = SummaryWriter()
167 | dummy_inp = torch.randn(1, 3, 128, 128)
168 | model_test.to('cpu')
169 | writer1.add_graph(model_test, model_test(torch.randn(3, 3, 128, 128, requires_grad=True)))
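   | # NOTE: the second argument here is the model's *output*; very old tensorboardX
   | # versions traced the graph from an output variable, whereas current
   | # add_graph(model, input_to_model) expects the input tensor instead.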
170 | model_test.to(device)
171 |
172 | #######################################################
173 | #Creating a folder for all the outputs of the program
174 | #######################################################
175 |
176 | New_folder = './model'
177 |
178 | if os.path.exists(New_folder) and os.path.isdir(New_folder):
179 | shutil.rmtree(New_folder)
180 |
181 | try:
182 | os.mkdir(New_folder)
183 | except OSError:
184 | print("Creation of the main directory '%s' failed " % New_folder)
185 | else:
186 | print("Successfully created the main directory '%s' " % New_folder)
187 |
188 | #######################################################
189 | #Setting the folder of saving the predictions
190 | #######################################################
191 |
192 | read_pred = './model/pred'
193 |
194 | #######################################################
195 | #Checking if prediction folder exists
196 | #######################################################
197 |
198 | if os.path.exists(read_pred) and os.path.isdir(read_pred):
199 | shutil.rmtree(read_pred)
200 |
201 | try:
202 | os.mkdir(read_pred)
203 | except OSError:
204 | print("Creation of the prediction directory '%s' failed of dice loss" % read_pred)
205 | else:
206 | print("Successfully created the prediction directory '%s' of dice loss" % read_pred)
207 |
208 | #######################################################
209 | #checking if the model exists and if true then delete
210 | #######################################################
211 |
212 | read_model_path = './model/Unet_D_' + str(epoch) + '_' + str(batch_size)
213 |
214 | if os.path.exists(read_model_path) and os.path.isdir(read_model_path):
215 | shutil.rmtree(read_model_path)
216 |     print('Model folder already exists; deleted it to store the new model')
217 |
218 | try:
219 | os.mkdir(read_model_path)
220 | except OSError:
221 | print("Creation of the model directory '%s' failed" % read_model_path)
222 | else:
223 | print("Successfully created the model directory '%s' " % read_model_path)
224 |
225 | #######################################################
226 | #Training loop
227 | #######################################################
228 |
229 | for i in range(epoch):
230 |
231 | train_loss = 0.0
232 | valid_loss = 0.0
233 | since = time.time()
234 | scheduler.step(i)
235 | lr = scheduler.get_lr()
236 |
237 | #######################################################
238 | #Training Data
239 | #######################################################
240 |
241 | model_test.train()
242 |
243 | for x, y in train_loader:
244 | x, y = x.to(device), y.to(device)
245 |
246 |         #Saves the input images and labels (with their augmentation) so the data flowing into the net can be checked
247 | input_images(x, y, i, n_iter)
248 |
249 | # grid_img = torchvision.utils.make_grid(x)
250 | #writer1.add_image('images', grid_img, 0)
251 |
252 | # grid_lab = torchvision.utils.make_grid(y)
253 |
254 | opt.zero_grad()
255 |
256 | y_pred = model_test(x)
257 | lossT = calc_loss(y_pred, y) # Dice_loss Used
258 |
259 | train_loss += lossT.item() * x.size(0)
260 | lossT.backward()
261 | # plot_grad_flow(model_test.named_parameters(), n_iter)
262 | opt.step()
263 | x_size = lossT.item() * x.size(0)
264 | k = 2
265 |
266 | # for name, param in model_test.named_parameters():
267 | # name = name.replace('.', '/')
268 | # writer1.add_histogram(name, param.data.cpu().numpy(), i + 1)
269 | # writer1.add_histogram(name + '/grad', param.grad.data.cpu().numpy(), i + 1)
270 |
271 |
272 | #######################################################
273 | #Validation Step
274 | #######################################################
275 |
276 |     model_test.eval()
277 |     with torch.no_grad():  # gradient tracking disabled: validation needs no gradients and uses less memory
278 | 
279 |         for x1, y1 in valid_loader:
280 |             x1, y1 = x1.to(device), y1.to(device)
281 | 
282 |             y_pred1 = model_test(x1)
283 |             lossL = calc_loss(y_pred1, y1) # Dice_loss Used
284 | 
285 |             valid_loss += lossL.item() * x1.size(0)
286 |             x_size1 = lossL.item() * x1.size(0)
287 |
288 | #######################################################
289 | #Saving the predictions
290 | #######################################################
291 |
292 | im_tb = Image.open('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_I_ori_same/0131_0009.png')
293 | im_label = Image.open('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_L_ori_same/0131_0009.png')
294 | s_tb = data_transform(im_tb)
295 | s_label = data_transform(im_label)
296 |
297 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
298 |     pred_tb = torch.sigmoid(pred_tb) # torch.sigmoid replaces the deprecated F.sigmoid
299 | pred_tb = pred_tb.detach().numpy()
300 |
301 | #pred_tb = threshold_predictions_v(pred_tb)
302 |
303 | x1 = plt.imsave(
304 | './model/pred/img_iteration_' + str(n_iter) + '_epoch_'
305 | + str(i) + '.png', pred_tb[0][0])
306 |
307 | accuracy = accuracy_score(pred_tb[0][0], s_label)
308 |
309 | #######################################################
310 | #To write in Tensorboard
311 | #######################################################
312 |
313 | train_loss = train_loss / len(train_idx)
314 | valid_loss = valid_loss / len(valid_idx)
315 |
316 | if (i+1) % 1 == 0:
317 | print('Epoch: {}/{} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(i + 1, epoch, train_loss,
318 | valid_loss))
319 | writer1.add_scalar('Train Loss', train_loss, n_iter)
320 | writer1.add_scalar('Validation Loss', valid_loss, n_iter)
321 | #writer1.add_image('Pred', pred_tb[0]) #try to get output of shape 3
322 |
323 |
324 | #######################################################
325 | #Early Stopping
326 | #######################################################
327 |
328 | if valid_loss <= valid_loss_min and epoch_valid >= i: # and i_valid <= 2:
329 |
330 | print('Validation loss decreased ({:.6f} --> {:.6f}). Saving model '.format(valid_loss_min, valid_loss))
331 | torch.save(model_test.state_dict(),'./model/Unet_D_' +
332 | str(epoch) + '_' + str(batch_size) + '/Unet_epoch_' + str(epoch)
333 | + '_batchsize_' + str(batch_size) + '.pth')
334 | print(accuracy)
335 | if round(valid_loss, 4) == round(valid_loss_min, 4):
336 | print(i_valid)
337 | i_valid = i_valid+1
338 | valid_loss_min = valid_loss
339 | #if i_valid ==3:
340 | # break
341 |
342 | #######################################################
343 | # Extracting the intermediate layers
344 | #######################################################
345 |
346 | #####################################
347 | # for kernals
348 | #####################################
349 | x1 = torch.nn.ModuleList(model_test.children())
350 | # x2 = torch.nn.ModuleList(x1[16].children())
351 | # x3 = torch.nn.ModuleList(x2[0].children())
352 |
353 | #To get filters in the layers
354 | # plot_kernels(x3[3].weight.detach().cpu(), 7)
355 |
356 | #####################################
357 | # for images
358 | #####################################
359 | x2 = len(x1)
360 | dr = LayerActivations(x1[x2-1]) #Getting the last Conv Layer
361 |
362 | img = Image.open('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_I_ori_same/0131_0009.png')
363 | s_tb = data_transform(img)
364 |
365 | pred_tb = model_test(s_tb.unsqueeze(0).to(device)).cpu()
366 |     pred_tb = torch.sigmoid(pred_tb) # torch.sigmoid replaces the deprecated F.sigmoid
367 | pred_tb = pred_tb.detach().numpy()
368 |
369 | plot_kernels(dr.features, n_iter, 7, cmap="rainbow")
370 |
371 | time_elapsed = time.time() - since
372 | print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
373 | n_iter += 1
374 |
375 | #######################################################
376 | #closing the tensorboard writer
377 | #######################################################
378 |
379 | writer1.close()
380 |
381 | #######################################################
382 | #if using dict
383 | #######################################################
384 |
385 | #model_test.filter_dict
386 |
387 | #######################################################
388 | #Loading the model
389 | #######################################################
390 |
391 | test1 = model_test.load_state_dict(torch.load('./model/Unet_D_' +
392 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch)
393 | + '_batchsize_' + str(batch_size) + '.pth'))
394 |
395 |
396 | #######################################################
397 | #checking if cuda is available
398 | #######################################################
399 |
400 | if torch.cuda.is_available():
401 | torch.cuda.empty_cache()
402 |
403 | #######################################################
404 | #Loading the model
405 | #######################################################
406 |
407 | model_test.load_state_dict(torch.load('./model/Unet_D_' +
408 | str(epoch) + '_' + str(batch_size)+ '/Unet_epoch_' + str(epoch)
409 | + '_batchsize_' + str(batch_size) + '.pth'))
410 |
411 | model_test.eval()
412 |
413 | #######################################################
414 | #opening the test folder and creating a folder for generated images
415 | #######################################################
416 |
417 | read_test_folder = glob.glob('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_I_ori_same/*')
418 | x_sort_test = natsort.natsorted(read_test_folder) # To sort
419 |
420 |
421 | read_test_folder112 = './model/gen_images'
422 |
423 |
424 | if os.path.exists(read_test_folder112) and os.path.isdir(read_test_folder112):
425 | shutil.rmtree(read_test_folder112)
426 |
427 | try:
428 | os.mkdir(read_test_folder112)
429 | except OSError:
430 | print("Creation of the testing directory %s failed" % read_test_folder112)
431 | else:
432 | print("Successfully created the testing directory %s " % read_test_folder112)
433 |
434 |
435 | #For Prediction Threshold
436 |
437 | read_test_folder_P_Thres = './model/pred_threshold'
438 |
439 |
440 | if os.path.exists(read_test_folder_P_Thres) and os.path.isdir(read_test_folder_P_Thres):
441 | shutil.rmtree(read_test_folder_P_Thres)
442 |
443 | try:
444 | os.mkdir(read_test_folder_P_Thres)
445 | except OSError:
446 | print("Creation of the testing directory %s failed" % read_test_folder_P_Thres)
447 | else:
448 | print("Successfully created the testing directory %s " % read_test_folder_P_Thres)
449 |
450 | #For Label Threshold
451 |
452 | read_test_folder_L_Thres = './model/label_threshold'
453 |
454 |
455 | if os.path.exists(read_test_folder_L_Thres) and os.path.isdir(read_test_folder_L_Thres):
456 | shutil.rmtree(read_test_folder_L_Thres)
457 |
458 | try:
459 | os.mkdir(read_test_folder_L_Thres)
460 | except OSError:
461 | print("Creation of the testing directory %s failed" % read_test_folder_L_Thres)
462 | else:
463 | print("Successfully created the testing directory %s " % read_test_folder_L_Thres)
464 |
465 |
466 |
467 | #######################################################
468 | #data transform for test Set (same as before)
469 | #######################################################
470 |
471 | data_transform = torchvision.transforms.Compose([
472 | # torchvision.transforms.Resize((128, 128)),
473 | # torchvision.transforms.Grayscale(),
474 | torchvision.transforms.CenterCrop(96),
475 | torchvision.transforms.ToTensor(),
476 | torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
477 | ])
478 |
479 | #######################################################
480 | #saving the images in the files
481 | #######################################################
482 |
483 | img_test_no = 0
484 |
485 | for i in range(len(read_test_folder)):
486 | im = Image.open(x_sort_test[i])
487 |
488 | im1 = im
489 | im_n = np.array(im1)
490 |     im_n_flat = im_n.reshape(-1, 1)
491 |
492 | for j in range(im_n_flat.shape[0]):
493 | if im_n_flat[j] != 0:
494 | im_n_flat[j] = 255
495 |
496 | s = data_transform(im)
497 |     pred = model_test(s.unsqueeze(0).to(device)).cpu() # .to(device) instead of .cuda() so CPU-only runs also work
498 |     pred = torch.sigmoid(pred) # torch.sigmoid replaces the deprecated F.sigmoid
499 | pred = pred.detach().numpy()
500 |
501 | # pred = threshold_predictions_p(pred) #Value kept 0.01 as max is 1 and noise is very small.
502 |
503 | if i % 24 == 0:
504 | img_test_no = img_test_no + 1
505 |
506 | x1 = plt.imsave('./model/gen_images/im_epoch_' + str(epoch) + 'int_' + str(i)
507 | + '_img_no_' + str(img_test_no) + '.png', pred[0][0])
508 |
509 | ####################################################
510 | #data transform for test Set (same as before)
511 | ####################################################
512 |
513 | data_transform_test = torchvision.transforms.Compose([
514 | # torchvision.transforms.Resize((128, 128)),
515 | torchvision.transforms.CenterCrop(96),
516 | torchvision.transforms.Grayscale(),
517 | ])
518 |
519 | ####################################################
520 | #Calculating the Dice Score
521 | ####################################################
522 |
523 | read_test_folderP = glob.glob('./model/gen_images/*')
524 | x_sort_testP = natsort.natsorted(read_test_folderP)
525 |
526 |
527 | read_test_folderL = glob.glob('/home/malav/Desktop/Pytorch_Computer/DATA/test_new_3C_L_ori_same/*')
528 | x_sort_testL = natsort.natsorted(read_test_folderL) # To sort
529 |
530 |
531 | dice_score123 = 0.0
532 | x_count = 0
533 | x_dice = 0
534 |
535 | for i in range(len(read_test_folderP)):
536 |
537 | x = Image.open(x_sort_testP[i])
538 | s = data_transform_test(x)
539 | s = np.array(s)
540 | s = threshold_predictions_v(s)
541 |
542 | #save the images
543 | x1 = plt.imsave('./model/pred_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
544 | + '_img_no_' + str(img_test_no) + '.png', s)
545 |
546 | y = Image.open(x_sort_testL[i])
547 | s2 = data_transform_test(y)
548 | s3 = np.array(s2)
549 | # s2 =threshold_predictions_v(s2)
550 |
551 | #save the Images
552 | y1 = plt.imsave('./model/label_threshold/im_epoch_' + str(epoch) + 'int_' + str(i)
553 | + '_img_no_' + str(img_test_no) + '.png', s3)
554 |
555 | total = dice_coeff(s, s3)
556 | print(total)
557 |
558 | if total <= 0.3:
559 | x_count += 1
560 | if total > 0.3:
561 | x_dice = x_dice + total
562 | dice_score123 = dice_score123 + total
563 |
564 |
565 | print('Dice Score : ' + str(dice_score123/len(read_test_folderP)))
566 | print(x_count)
567 | print(x_dice)
568 | print('Dice Score : ' + str(float(x_dice/(len(read_test_folderP)-x_count))))
569 |
570 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Python >= 3.6 required (python itself is not a pip-installable package)
2 | torch>=0.4.0
3 | torchvision
4 | torchsummary
5 | tensorboardx
6 | natsort
7 | numpy
8 | pillow
9 | scipy
10 | scikit-image
11 | scikit-learn
12 |
--------------------------------------------------------------------------------