├── CLSTM.py
├── LICENSE
├── README.md
├── data.py
├── dataParser.py
├── dice.png
├── hybridunet.png
├── losses.py
├── main.py
├── main_bdclstm.py
├── main_small.py
├── models.py
├── plot_ims.py
├── result_comparison.png
├── smallunet.png
└── unet.png
/CLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | # Batch x NumChannels x Height x Width
6 | # UNET --> BatchSize x 1 (3?) x 240 x 240
7 | # BDCLSTM --> BatchSize x 64 x 240 x240
8 |
9 | ''' Class CLSTMCell.
10 | This represents a single node in a CLSTM series.
11 | It produces just one time (spatial) step output.
12 | '''
13 |
14 |
15 | class CLSTMCell(nn.Module):
16 |
17 | # Constructor
18 | def __init__(self, input_channels, hidden_channels,
19 | kernel_size, bias=True):
20 | super(CLSTMCell, self).__init__()
21 |
22 | assert hidden_channels % 2 == 0
23 |
24 | self.input_channels = input_channels
25 | self.hidden_channels = hidden_channels
26 | self.bias = bias
27 | self.kernel_size = kernel_size
28 | self.num_features = 4
29 |
30 | self.padding = (kernel_size - 1) // 2
31 | self.conv = nn.Conv2d(self.input_channels + self.hidden_channels,
32 | self.num_features * self.hidden_channels,
33 | self.kernel_size,
34 | 1,
35 | self.padding)
36 |
37 |     # Forward propagation formulation
38 | def forward(self, x, h, c):
39 | # print('x: ', x.type)
40 | # print('h: ', h.type)
41 | combined = torch.cat((x, h), dim=1)
42 | A = self.conv(combined)
43 |
44 |         # NOTE: A_k = x_t * Wx_k + h_{t-1} * Wh_k + b_k for k in {i, f, o, g},
45 |         #       where * is convolution (one conv computes all four gates at once)
45 | (Ai, Af, Ao, Ag) = torch.split(A,
46 | A.size()[1] // self.num_features,
47 | dim=1)
48 |
49 | i = torch.sigmoid(Ai) # input gate
50 | f = torch.sigmoid(Af) # forget gate
51 | o = torch.sigmoid(Ao) # output gate
52 | g = torch.tanh(Ag)
53 |
54 | c = c * f + i * g # cell activation state
55 | h = o * torch.tanh(c) # cell hidden state
56 |
57 | return h, c
58 |
59 |     @staticmethod
60 |     def init_hidden(batch_size, hidden_c, shape):
61 |         # Zero-initialised hidden and cell states; place them on the GPU
62 |         # when one is available, instead of relying on a bare except
63 |         h = Variable(torch.zeros(batch_size,
64 |                                  hidden_c,
65 |                                  shape[0],
66 |                                  shape[1]))
67 |         c = Variable(torch.zeros(batch_size,
68 |                                  hidden_c,
69 |                                  shape[0],
70 |                                  shape[1]))
71 |         if torch.cuda.is_available():
72 |             h = h.cuda()
73 |             c = c.cuda()
74 |         return h, c
79 |
80 |
81 | ''' Class CLSTM.
82 | This represents a series of CLSTM nodes (one direction)
83 | '''
84 |
85 |
86 | class CLSTM(nn.Module):
87 | # Constructor
88 | def __init__(self, input_channels=64, hidden_channels=[64],
89 | kernel_size=5, bias=True):
90 | super(CLSTM, self).__init__()
91 |
92 | # store stuff
93 | self.input_channels = [input_channels] + hidden_channels
94 | self.hidden_channels = hidden_channels
95 | self.kernel_size = kernel_size
96 | self.num_layers = len(hidden_channels)
97 |
98 | self.bias = bias
99 | self.all_layers = []
100 |
101 | # create a node for each layer in the CLSTM
102 | for layer in range(self.num_layers):
103 | name = 'cell{}'.format(layer)
104 | cell = CLSTMCell(self.input_channels[layer],
105 | self.hidden_channels[layer],
106 | self.kernel_size,
107 | self.bias)
108 | setattr(self, name, cell)
109 | self.all_layers.append(cell)
110 |
111 |     # Forward propagation
112 | # x --> BatchSize x NumSteps x NumChannels x Height x Width
113 | # BatchSize x 2 x 64 x 240 x 240
114 | def forward(self, x):
115 | bsize, steps, _, height, width = x.size()
116 | internal_state = []
117 | outputs = []
118 | for step in range(steps):
119 |             x_step = torch.squeeze(x[:, step, :, :, :], dim=1)
120 | for layer in range(self.num_layers):
121 | # populate hidden states for all layers
122 | if step == 0:
123 | (h, c) = CLSTMCell.init_hidden(bsize,
124 | self.hidden_channels[layer],
125 | (height, width))
126 | internal_state.append((h, c))
127 |
128 | # do forward
129 | name = 'cell{}'.format(layer)
130 | (h, c) = internal_state[layer]
131 |
132 |                 x_step, c = getattr(self, name)(
133 |                     x_step, h, c)  # forward propagation call
134 |                 internal_state[layer] = (x_step, c)
135 | 
136 |             outputs.append(x_step)
137 |
138 | #for i in range(len(outputs)):
139 | # print(outputs[i].size())
140 | return outputs
141 |
142 |
143 | class BDCLSTM(nn.Module):
144 | # Constructor
145 | def __init__(self, input_channels=64, hidden_channels=[64],
146 | kernel_size=5, bias=True, num_classes=2):
147 |
148 | super(BDCLSTM, self).__init__()
149 | self.forward_net = CLSTM(
150 | input_channels, hidden_channels, kernel_size, bias)
151 | self.reverse_net = CLSTM(
152 | input_channels, hidden_channels, kernel_size, bias)
153 | self.conv = nn.Conv2d(
154 | 2 * hidden_channels[-1], num_classes, kernel_size=1)
155 | self.soft = nn.Softmax2d()
156 |
157 |     # Forward propagation
158 | # x --> BatchSize x NumChannels x Height x Width
159 | # BatchSize x 64 x 240 x 240
160 | def forward(self, x1, x2, x3):
161 | x1 = torch.unsqueeze(x1, dim=1)
162 | x2 = torch.unsqueeze(x2, dim=1)
163 | x3 = torch.unsqueeze(x3, dim=1)
164 |
165 | xforward = torch.cat((x1, x2), dim=1)
166 | xreverse = torch.cat((x3, x2), dim=1)
167 |
168 | yforward = self.forward_net(xforward)
169 | yreverse = self.reverse_net(xreverse)
170 |
171 | # assumes y is BatchSize x NumClasses x 240 x 240
172 | # print(yforward[-1].type)
173 | ycat = torch.cat((yforward[-1], yreverse[-1]), dim=1)
174 | # print(ycat.size())
175 | y = self.conv(ycat)
176 | # print(y.type)
177 | y = self.soft(y)
178 | # print(y.type)
179 | return y
180 |
--------------------------------------------------------------------------------
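
A quick orientation to the shapes involved, since they are easy to lose track of: the sketch below (not part of the repository) pushes random tensors through `BDCLSTM` on the CPU. The 240x240 size comes from the comments above; with `num_classes=2`, the final 1x1 convolution plus `Softmax2d` yields a `B x 2 x H x W` map.

```python
# Hypothetical smoke test for BDCLSTM; not part of the repo.
import torch
from torch.autograd import Variable
from CLSTM import BDCLSTM

model = BDCLSTM(input_channels=64, hidden_channels=[64], num_classes=2)

# Three neighbouring feature-map slices, each BatchSize x 64 x 240 x 240.
# On a CUDA machine, move the model and inputs to the GPU first, since
# CLSTMCell.init_hidden allocates its state there when CUDA is available.
x1 = Variable(torch.randn(1, 64, 240, 240))
x2 = Variable(torch.randn(1, 64, 240, 240))
x3 = Variable(torch.randn(1, 64, 240, 240))

y = model(x1, x2, x3)
print(y.size())  # expected: torch.Size([1, 2, 240, 240])
```

With `kernel_size=5` and padding `(5 - 1) // 2 = 2`, every gate convolution preserves spatial size, so the output map matches the input resolution.
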
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Shreyas Padhy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UNet-Zoo
2 | A collection of UNet and hybrid architectures for 2D and 3D Biomedical Image segmentation, implemented in PyTorch.
3 |
4 | This repository contains a collection of architectures used for Biomedical Image Segmentation, implemented on the BraTS Brain Tumor Segmentation Challenge Dataset. The following architectures are implemented:
5 |
6 | 1. **UNet** - Standard UNet architecture as described in Ronneberger et al. (2015) [[reference]](https://arxiv.org/abs/1505.04597)
7 |
8 | ![UNet architecture](unet.png)
9 |
10 |
11 | 1. **Small UNet** - 40x smaller version of UNet that achieves similar performance [[Theano Implementation]](https://github.com/codedecde/Luna2016-Lung-Nodule-Detection)
12 |
13 | ![Small UNet architecture](smallunet.png)
14 |
15 |
16 | 1. **UNet with BDCLSTM** - Combining a BDC-LSTM network with UNet to encode spatial correlation for 3D segmentation [[reference]](https://arxiv.org/pdf/1609.01006.pdf)
17 |
18 | ![UNet with BDC-LSTM architecture](hybridunet.png)
19 |
20 |
21 |
22 | 1. **kUNet** - Combining multiple UNets for increasing hierarchical preservation of information (coming soon) [[reference]](https://arxiv.org/pdf/1701.03056.pdf)
23 | 1. **R-UNet** - UNet with recurrent connections for another way to encode `z`-context (coming soon)
24 | ### To Run
25 |
26 | First, apply for access to the BraTS Tumor dataset, and place the scans in a `Data/` folder, divided into `Train` and `Test`. Feel free to modify the BraTS PyTorch dataloaders in `data.py` for your use.
27 | 1. UNet - run `main.py`, type `--help` for information on arguments.
28 | Example: `python main.py --train --cuda --data-folder "./Data/"`
29 | 1. Small UNet - run `main_small.py`, and use `--help`
30 | 1. BDC-LSTM - run `main_bdclstm.py` and use the weights for either your trained UNet or Small-UNet models (`--help` is your savior).
31 |
32 | ### Some Results
33 |
34 | 1. Comparisons of UNet (top) and Small UNet (bottom)
35 |
36 | ![Comparison of UNet (top) and Small UNet (bottom) results](result_comparison.png)
37 |
38 |
39 |
40 | 2. DICE Scores for UNet and Small UNet
41 |
42 | ![DICE scores for UNet and Small UNet](dice.png)
43 |
44 |
45 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from torch.utils.data.dataset import Dataset
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from PIL import Image
7 | import os
8 | import scipy.io as sio
9 | from dataParser import getMaskFileName, getImg
10 | import torchvision.transforms as tr
11 |
12 |
13 | class BraTSDatasetUnet(Dataset):
14 | __file = []
15 | __im = []
16 | __mask = []
17 | im_ht = 0
18 | im_wd = 0
19 | dataset_size = 0
20 |
21 | def __init__(self, dataset_folder, train=True, keywords=["P1", "1", "flair"], im_size=[128, 128], transform=None):
22 |
23 | self.__file = []
24 | self.__im = []
25 | self.__mask = []
26 | self.im_ht = im_size[0]
27 | self.im_wd = im_size[1]
28 | self.transform = transform
29 |
30 | folder = dataset_folder
31 |         # Select the Train/ or Test/ subfolder
32 | if train:
33 | folder = dataset_folder + "Train/"
34 | else:
35 | folder = dataset_folder + "Test/"
36 |
37 | for file in os.listdir(folder):
38 | if file.endswith(".png"):
39 | filename = os.path.splitext(file)[0]
40 | filename_fragments = filename.split("_")
41 | samekeywords = list(set(filename_fragments) & set(keywords))
42 | if len(samekeywords) == len(keywords):
43 | # 1. read file name
44 | self.__file.append(filename)
45 | # 2. read raw image
46 | # TODO: I think we should open image only in getitem,
47 | # otherwise memory explodes
48 |
49 | # rawImage = getImg(folder + file)
50 | self.__im.append(folder + file)
51 | # 3. read mask image
52 | mask_file = getMaskFileName(file)
53 | # maskImage = getImg(folder + mask_file)
54 | self.__mask.append(folder + mask_file)
55 | # self.dataset_size = len(self.__file)
56 |
57 | # print("lengths : ", len(self.__im), len(self.__mask))
58 | self.dataset_size = len(self.__file)
59 |
60 | if not train:
61 | sio.savemat('filelist2.mat', {'data': self.__im})
62 |
63 | def __getitem__(self, index):
64 |
65 | img = getImg(self.__im[index])
66 | mask = getImg(self.__mask[index])
67 |
68 |         img = img.resize((self.im_ht, self.im_wd))  # NOTE: PIL resize takes (width, height)
69 | mask = mask.resize((self.im_ht, self.im_wd))
70 | # mask.show()
71 |
72 | if self.transform is not None:
73 | # TODO: Not sure why not take full image
74 | img_tr = self.transform(img)
75 | mask_tr = self.transform(mask)
76 | # img_tr = self.transform(img[None, :, :])
77 | # mask_tr = self.transform(mask[None, :, :])
78 |
79 | return img_tr, mask_tr
80 | # return img.float(), mask.float()
81 |
82 | def __len__(self):
83 |
84 | return len(self.__im)
85 |
86 |
87 | class BraTSDatasetLSTM(Dataset):
88 | __im = []
89 | __mask = []
90 | __im1 = []
91 | __im3 = []
92 | im_ht = 0
93 | im_wd = 0
94 | dataset_size = 0
95 |
96 | def __init__(self, dataset_folder, train=True, keywords=["P1", "1", "flair"], im_size=[128, 128], transform=None):
97 |
98 | self.__file = []
99 | self.__im = []
100 | self.__mask = []
101 | self.im_ht = im_size[0]
102 | self.im_wd = im_size[1]
103 | self.transform = transform
104 |
105 | folder = dataset_folder
106 |         # Select the Train/ or Test/ subfolder
107 | if train:
108 | folder = dataset_folder + "Train/"
109 | else:
110 | folder = dataset_folder + "Test/"
111 |
112 | # print("files : ", os.listdir(folder))
113 | # print("Folder : ", folder)
114 | max_file = 0
115 | min_file = 10000000
116 | for file in os.listdir(folder):
117 | if file.endswith(".png"):
118 | m = re.search('(P[0-9]*[_])([0-9]*)', file)
119 | pic_num = int(m.group(2))
120 | if pic_num > max_file:
121 | max_file = pic_num
122 | if pic_num < min_file:
123 | min_file = pic_num
124 |
125 | # print('min file number: ', min_file)
126 | # print('max file number: ', max_file)
127 |
128 | for file in os.listdir(folder):
129 | if file.endswith(".png"):
130 | filename = os.path.splitext(file)[0]
131 | filename_fragments = filename.split("_")
132 | samekeywords = list(set(filename_fragments) & set(keywords))
133 | if len(samekeywords) == len(keywords):
134 | # 1. read file name
135 | # 2. read raw image
136 | # TODO: I think we should open image only in getitem,
137 | # otherwise memory explodes
138 |
139 | # rawImage = getImg(folder + file)
140 |
141 | if (filename_fragments[2] != str(min_file)) and (filename_fragments[2] != str(max_file)):
142 | # print("TEST : ", filename_fragments[2])
143 | self.__im.append(folder + file)
144 |
145 | file1 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + str(
146 | int(filename_fragments[2]) - 1) + '_' + filename_fragments[3] + '.png'
147 |
148 | self.__im1.append(folder + file1)
149 |
150 | file3 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + str(
151 | int(filename_fragments[2]) + 1) + '_' + filename_fragments[3] + '.png'
152 |
153 | self.__im3.append(folder + file3)
154 | # 3. read mask image
155 | mask_file = getMaskFileName(file)
156 | # maskImage = getImg(folder + mask_file)
157 | self.__mask.append(folder + mask_file)
158 | # self.dataset_size = len(self.__file)
159 |
160 | # print("lengths : ", len(self.__im), len(self.__mask))
161 |         self.dataset_size = len(self.__im)
162 |
163 | def __getitem__(self, index):
164 |
165 | img1 = getImg(self.__im1[index])
166 | img = getImg(self.__im[index])
167 | img3 = getImg(self.__im3[index])
168 | mask = getImg(self.__mask[index])
169 |
170 | # img.show()
171 | # mask.show()
172 |
173 | if self.transform is not None:
174 | # TODO: Not sure why not take full image
175 | img_tr1 = self.transform(img1)
176 | img_tr = self.transform(img)
177 | img_tr3 = self.transform(img3)
178 | mask_tr = self.transform(mask)
179 | # img_tr = self.transform(img[None, :, :])
180 | # mask_tr = self.transform(mask[None, :, :])
181 |
182 | return img_tr1, img_tr, img_tr3, mask_tr
183 | # return img.float(), mask.float()
184 |
185 | def __len__(self):
186 |
187 | return len(self.__im)
188 |
--------------------------------------------------------------------------------
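
`BraTSDatasetLSTM` leans on a `P<patient>_<a>_<b>_<modality>.png` filename convention. Note that the regex captures the first numeric field for the min/max bookkeeping while the neighbour-slice construction increments the second; the sketch below (with a made-up filename, not from the dataset) shows exactly what both pieces of logic compute.

```python
import re

# Hypothetical filename in the shape BraTSDatasetLSTM expects.
file = "P1_1_42_flair.png"

m = re.search('(P[0-9]*[_])([0-9]*)', file)
print(m.group(2))  # '1' -- the first numeric field after the patient tag

fragments = file[:-len(".png")].split("_")  # ['P1', '1', '42', 'flair']

# Neighbouring-slice names, built as in BraTSDatasetLSTM.__init__
prev_name = fragments[0] + '_' + fragments[1] + '_' + \
    str(int(fragments[2]) - 1) + '_' + fragments[3] + '.png'
next_name = fragments[0] + '_' + fragments[1] + '_' + \
    str(int(fragments[2]) + 1) + '_' + fragments[3] + '.png'
print(prev_name, next_name)  # P1_1_41_flair.png P1_1_43_flair.png
```
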
/dataParser.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | from skimage import color
6 | from skimage import io
7 |
8 |
9 | def getMaskFileName(file):
10 |
11 | mask_file = file.replace("flair.png", "seg.png")
12 | mask_file = mask_file.replace("t1.png", "seg.png")
13 | mask_file = mask_file.replace("t2.png", "seg.png")
14 | mask_file = mask_file.replace("t1ce.png", "seg.png")
15 |
16 | return mask_file
17 |
18 |
19 | def getImg(imgpathway):
20 | # image_file = Image.open(imgpathway) # open colour image
21 | # img = image_file.convert('L')
22 | # IMG = np.asarray(img.getdata())
23 | # img = io.imread(imgpathway, as_grey=True)
24 | img = Image.open(imgpathway)
25 | # img = np.asarray(img)
26 | # img *= 65536.0 / np.max(img)
27 | # IMG.astype(np.uint16)
28 | # plt.imshow(IMG, cmap='gray')
29 | # plt.show()
30 | return img
31 |
32 |
33 | def File2Image(self, index):  # NOTE: orphaned helper, written as if it were a dataset method
34 | file = self.__file[index]
35 | filename_fragments = file.split("_")
36 | if filename_fragments[1] == '0' or filename_fragments[1] == '154':
37 | # Not sure what to do here
38 | return 0, 0
39 |
40 |     filename1 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + \
41 |         str(int(filename_fragments[2]) - 1) + '_' + filename_fragments[3]
42 |     filename3 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + \
43 |         str(int(filename_fragments[2]) + 1) + '_' + filename_fragments[3]
44 |
45 | idx1 = self.__file.index(filename1)
46 | idx3 = self.__file.index(filename3)
47 | img1 = self.__im[idx1]
48 | img3 = self.__im[idx3]
49 |
50 | return img1, img3
51 |
--------------------------------------------------------------------------------
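
`getMaskFileName` maps any modality image to its segmentation mask by plain string substitution, so it is easy to sanity-check. A minimal example with made-up filenames:

```python
from dataParser import getMaskFileName

# Hypothetical filenames; the modality suffix is swapped for seg.png.
print(getMaskFileName("P1_1_42_flair.png"))  # P1_1_42_seg.png
print(getMaskFileName("P1_1_42_t1ce.png"))   # P1_1_42_seg.png
```
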
/dice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/dice.png
--------------------------------------------------------------------------------
/hybridunet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/hybridunet.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | class DICELossMultiClass(nn.Module):
9 |
10 | def __init__(self):
11 | super(DICELossMultiClass, self).__init__()
12 |
13 |     def forward(self, output, mask):
14 |         num_classes = output.size(1)
15 |         dice_eso = 0
16 |         for i in range(num_classes):
17 |             probs = output[:, i, :, :]
18 |             # use a per-class view of the mask; overwriting `mask` here
19 |             # would break every iteration after the first
20 |             mask_i = mask[:, i, :, :]
21 | 
22 |             num = probs * mask_i
23 |             num = torch.sum(num, 2)
24 |             num = torch.sum(num, 1)
25 | 
26 |             den1 = probs * probs
27 |             den1 = torch.sum(den1, 2)
28 |             den1 = torch.sum(den1, 1)
29 | 
30 |             den2 = mask_i * mask_i
31 |             den2 = torch.sum(den2, 2)
32 |             den2 = torch.sum(den2, 1)
33 | 
34 |             eps = 1e-7
35 |             dice = 2 * ((num + eps) / (den1 + den2 + eps))
36 |             dice_eso += dice
37 | 
38 |         # average over batch and classes so a perfect prediction gives loss 0
39 |         loss = 1 - torch.sum(dice_eso) / (dice_eso.size(0) * num_classes)
40 |         return loss
46 |
47 |
48 | class DICELoss(nn.Module):
49 |
50 | def __init__(self):
51 | super(DICELoss, self).__init__()
52 |
53 | def forward(self, output, mask):
54 |
55 | probs = torch.squeeze(output, 1)
56 | mask = torch.squeeze(mask, 1)
57 |
58 | intersection = probs * mask
59 | intersection = torch.sum(intersection, 2)
60 | intersection = torch.sum(intersection, 1)
61 |
62 | den1 = probs * probs
63 | den1 = torch.sum(den1, 2)
64 | den1 = torch.sum(den1, 1)
65 |
66 | den2 = mask * mask
67 | den2 = torch.sum(den2, 2)
68 | den2 = torch.sum(den2, 1)
69 |
70 | eps = 1e-8
71 | dice = 2 * ((intersection + eps) / (den1 + den2 + eps))
72 | # dice_eso = dice[:, 1:]
73 | dice_eso = dice
74 |
75 | loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
76 | return loss
77 |
--------------------------------------------------------------------------------
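
A useful sanity check on `DICELoss` is that a perfect prediction drives the loss to roughly zero and a completely wrong one to roughly one. A minimal sketch, assuming `B x 1 x H x W` inputs as in `main_small.py`:

```python
import torch
from torch.autograd import Variable
from losses import DICELoss

criterion = DICELoss()

# A made-up binary mask with a square foreground region.
m = torch.zeros(2, 1, 8, 8)
m[:, :, 2:6, 2:6] = 1.0
mask = Variable(m)

# Using the mask itself as the prediction: DICE ~1, loss ~0.
print(criterion(mask, mask))

# Inverting the mask: no overlap, so the loss is ~1.
print(criterion(1.0 - mask, mask))
```
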
/main.py:
--------------------------------------------------------------------------------
1 | # %% -*- coding: utf-8 -*-
2 | '''
3 | Author: Shreyas Padhy
4 | Driver file for Standard UNet Implementation
5 | '''
6 | from __future__ import print_function
7 |
8 | import argparse
9 | import matplotlib
10 | matplotlib.use('Agg')
11 | import matplotlib.pyplot as plt
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.autograd import Variable
17 | from torch.utils.data import DataLoader
18 | import scipy.io as sio
19 | import torchvision.transforms as tr
20 |
21 | from data import BraTSDatasetUnet, BraTSDatasetLSTM
22 | from losses import DICELossMultiClass
23 | from models import UNet
24 | from tqdm import tqdm
25 | import numpy as np
26 |
27 | # %% import transforms
28 |
29 | # %% Training settings
30 | parser = argparse.ArgumentParser(
31 | description='UNet + BDCLSTM for BraTS Dataset')
32 | parser.add_argument('--batch-size', type=int, default=4, metavar='N',
33 |                     help='input batch size for training (default: 4)')
34 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
35 |                     help='input batch size for testing (default: 1000)')
36 | parser.add_argument('--train', action='store_true', default=False,
37 |                     help='Argument to train model (default: False)')
38 | parser.add_argument('--epochs', type=int, default=1, metavar='N',
39 |                     help='number of epochs to train (default: 1)')
40 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
41 |                     help='learning rate (default: 0.001)')
42 | parser.add_argument('--cuda', action='store_true', default=False,
43 | help='enables CUDA training (default: False)')
44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
45 | help='batches to wait before logging training status')
46 | parser.add_argument('--size', type=int, default=128, metavar='N',
47 | help='imsize')
48 | parser.add_argument('--load', type=str, default=None, metavar='str',
49 | help='weight file to load (default: None)')
50 | parser.add_argument('--data-folder', type=str, default='./Data/', metavar='str',
51 |                     help='folder that contains data (default: ./Data/)')
52 | parser.add_argument('--save', type=str, default='OutMasks', metavar='str',
53 | help='Identifier to save npy arrays with')
54 | parser.add_argument('--modality', type=str, default='flair', metavar='str',
55 | help='Modality to use for training (default: flair)')
56 | parser.add_argument('--optimizer', type=str, default='SGD', metavar='str',
57 | help='Optimizer (default: SGD)')
58 |
59 | args = parser.parse_args()
60 | args.cuda = args.cuda and torch.cuda.is_available()
61 |
62 | DATA_FOLDER = args.data_folder
63 |
64 | # %% Loading in the Dataset
65 | dset_train = BraTSDatasetUnet(DATA_FOLDER, train=True,
66 | keywords=[args.modality],
67 | im_size=[args.size, args.size],
68 | transform=tr.ToTensor())
69 |
70 | train_loader = DataLoader(dset_train,
71 | batch_size=args.batch_size,
72 | shuffle=True, num_workers=1)
73 |
74 | dset_test = BraTSDatasetUnet(DATA_FOLDER, train=False,
75 | keywords=[args.modality],
76 | im_size=[args.size, args.size],
77 | transform=tr.ToTensor())
78 |
79 | test_loader = DataLoader(dset_test,
80 | batch_size=args.test_batch_size,
81 | shuffle=False, num_workers=1)
82 |
83 |
84 | print("Training Data : ", len(train_loader.dataset))
85 | print("Test Data :", len(test_loader.dataset))
86 |
87 | # %% Loading in the model
88 | model = UNet()
89 |
90 | if args.cuda:
91 | model.cuda()
92 |
93 | if args.optimizer == 'SGD':
94 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
95 | momentum=0.99)
96 | if args.optimizer == 'ADAM':
97 |     optimizer = optim.Adam(model.parameters(), lr=args.lr,
98 |                            betas=(0.9, 0.999))  # argparse defines no beta flags; use Adam defaults
99 |
100 |
101 | # Defining Loss Function
102 | criterion = DICELossMultiClass()
103 |
104 |
105 | def train(epoch, loss_list):
106 | model.train()
107 | for batch_idx, (image, mask) in enumerate(train_loader):
108 | if args.cuda:
109 | image, mask = image.cuda(), mask.cuda()
110 |
111 | image, mask = Variable(image), Variable(mask)
112 |
113 | optimizer.zero_grad()
114 |
115 | output = model(image)
116 |
117 | loss = criterion(output, mask)
118 | loss_list.append(loss.data[0])
119 |
120 | loss.backward()
121 | optimizer.step()
122 |
123 | if batch_idx % args.log_interval == 0:
124 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage DICE Loss: {:.6f}'.format(
125 | epoch, batch_idx * len(image), len(train_loader.dataset),
126 | 100. * batch_idx / len(train_loader), loss.data[0]))
127 |
128 |
129 | def test(train_accuracy=False, save_output=False):
130 | test_loss = 0
131 |
132 | if train_accuracy:
133 | loader = train_loader
134 | else:
135 | loader = test_loader
136 |
137 | for batch_idx, (image, mask) in tqdm(enumerate(loader)):
138 | if args.cuda:
139 | image, mask = image.cuda(), mask.cuda()
140 |
141 | image, mask = Variable(image, volatile=True), Variable(
142 | mask, volatile=True)
143 |
144 | output = model(image)
145 |
146 | # test_loss += criterion(output, mask).data[0]
147 | maxes, out = torch.max(output, 1, keepdim=True)
148 |
149 | if save_output and (not train_accuracy):
150 | np.save('./npy-files/out-files/{}-batch-{}-outs.npy'.format(args.save,
151 | batch_idx),
152 | out.data.byte().cpu().numpy())
153 | np.save('./npy-files/out-files/{}-batch-{}-masks.npy'.format(args.save,
154 | batch_idx),
155 | mask.data.byte().cpu().numpy())
156 | np.save('./npy-files/out-files/{}-batch-{}-images.npy'.format(args.save,
157 | batch_idx),
158 | image.data.float().cpu().numpy())
159 |
160 | if save_output and train_accuracy:
161 | np.save('./npy-files/out-files/{}-train-batch-{}-outs.npy'.format(args.save,
162 | batch_idx),
163 | out.data.byte().cpu().numpy())
164 | np.save('./npy-files/out-files/{}-train-batch-{}-masks.npy'.format(args.save,
165 | batch_idx),
166 | mask.data.byte().cpu().numpy())
167 | np.save('./npy-files/out-files/{}-train-batch-{}-images.npy'.format(args.save,
168 | batch_idx),
169 | image.data.float().cpu().numpy())
170 |
171 | test_loss += criterion(output, mask).data[0]
172 |
173 |     # Average DICE loss over all batches (lower is better)
174 |     test_loss /= len(loader)
175 |     if train_accuracy:
176 |         print('\nTraining Set: Average DICE Loss: {:.4f}\n'.format(
177 |             test_loss))
178 |     else:
179 |         print('\nTest Set: Average DICE Loss: {:.4f}\n'.format(
180 |             test_loss))
181 |
182 |
183 | if args.train:
184 | loss_list = []
185 | for i in tqdm(range(args.epochs)):
186 | train(i, loss_list)
187 | test()
188 |
189 | plt.plot(loss_list)
190 | plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size,
191 | args.epochs, args.lr))
192 | plt.xlabel("Number of iterations")
193 | plt.ylabel("Average DICE loss per batch")
194 | plt.savefig("./plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(args.save,
195 | args.batch_size,
196 | args.epochs,
197 | args.lr))
198 |
199 | np.save('./npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(args.save,
200 | args.batch_size,
201 | args.epochs,
202 | args.lr),
203 | np.asarray(loss_list))
204 |
205 | torch.save(model.state_dict(), 'unet-final-{}-{}-{}'.format(args.batch_size,
206 | args.epochs,
207 | args.lr))
208 | elif args.load is not None:
209 | model.load_state_dict(torch.load(args.load))
210 | test(save_output=True)
211 | test(train_accuracy=True)
212 |
--------------------------------------------------------------------------------
/main_bdclstm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import torch.optim as optim
7 | from losses import DICELossMultiClass
8 |
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | import torchvision.transforms as tr
12 |
13 | from data import BraTSDatasetLSTM
14 | from CLSTM import BDCLSTM
15 | from models import *
16 |
17 | # %% import transforms
18 |
19 | UNET_MODEL_FILE = 'unetsmall-100-10-0.001'
20 | MODALITY = ["flair"]
21 |
22 | # %% Training settings
23 | parser = argparse.ArgumentParser(description='UNet+BDCLSTM for BraTS Dataset')
24 | parser.add_argument('--batch-size', type=int, default=4, metavar='N',
25 |                     help='input batch size for training (default: 4)')
26 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
27 |                     help='input batch size for testing (default: 1000)')
28 | parser.add_argument('--train', action='store_true', default=False,
29 |                     help='Argument to train model (default: False)')
30 | parser.add_argument('--epochs', type=int, default=1, metavar='N',
31 |                     help='number of epochs to train (default: 1)')
32 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
33 |                     help='learning rate (default: 0.001)')
34 | parser.add_argument('--mom', type=float, default=0.99, metavar='MOM',
35 | help='SGD momentum (default=0.99)')
36 | parser.add_argument('--cuda', action='store_true', default=False,
37 | help='enables CUDA training (default: False)')
38 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
39 | help='batches to wait before logging training status')
40 | parser.add_argument('--test-dataset', action='store_true', default=False,
41 | help='test on smaller dataset (default: False)')
42 | parser.add_argument('--size', type=int, default=128, metavar='N',
43 | help='imsize')
44 | parser.add_argument('--drop', action='store_true', default=False,
45 | help='enables drop')
46 | parser.add_argument('--data-folder', type=str, default='./Data-Nonzero/', metavar='str',
47 |                     help='folder that contains data (default: ./Data-Nonzero/)')
48 |
49 |
50 | args = parser.parse_args()
51 | args.cuda = args.cuda and torch.cuda.is_available()
52 | if args.cuda:
53 | print("We are on the GPU!")
54 |
55 | DATA_FOLDER = args.data_folder
56 |
57 | # %% Loading in the Dataset
58 | dset_test = BraTSDatasetLSTM(
59 |     DATA_FOLDER, train=False, keywords=MODALITY, transform=tr.ToTensor())
60 | test_loader = DataLoader(
61 | dset_test, batch_size=args.test_batch_size, shuffle=False, num_workers=1)
62 |
63 | dset_train = BraTSDatasetLSTM(
64 |     DATA_FOLDER, train=True, keywords=MODALITY, transform=tr.ToTensor())
65 | train_loader = DataLoader(
66 | dset_train, batch_size=args.batch_size, shuffle=True, num_workers=1)
67 |
68 |
69 | # %% Loading in the models
70 | unet = UNetSmall()
71 | unet.load_state_dict(torch.load(UNET_MODEL_FILE))
72 | model = BDCLSTM(input_channels=32, hidden_channels=[32])
73 |
74 | if args.cuda:
75 | unet.cuda()
76 | model.cuda()
77 |
78 | # Setting Optimizer
79 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom)
80 | criterion = DICELossMultiClass()
81 |
82 | # Define Training Loop
83 |
84 |
85 | def train(epoch):
86 | model.train()
87 | for batch_idx, (image1, image2, image3, mask) in enumerate(train_loader):
88 | if args.cuda:
89 | image1, image2, image3, mask = image1.cuda(), \
90 | image2.cuda(), \
91 | image3.cuda(), \
92 | mask.cuda()
93 |
94 | image1, image2, image3, mask = Variable(image1), \
95 | Variable(image2), \
96 | Variable(image3), \
97 | Variable(mask)
98 |
99 | optimizer.zero_grad()
100 |
101 | map1 = unet(image1, return_features=True)
102 | map2 = unet(image2, return_features=True)
103 | map3 = unet(image3, return_features=True)
104 |
105 | output = model(map1, map2, map3)
106 | loss = criterion(output, mask)
107 |
108 | loss.backward()
109 | optimizer.step()
110 | if batch_idx % args.log_interval == 0:
111 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
112 | epoch, batch_idx * len(image1), len(train_loader.dataset),
113 | 100. * batch_idx / len(train_loader), loss.data[0]))
114 |
115 |
116 | def test(train_accuracy=False):
117 | test_loss = 0
118 |
119 |     if train_accuracy:
120 | loader = train_loader
121 | else:
122 | loader = test_loader
123 |
124 | for (image1, image2, image3, mask) in loader:
125 | if args.cuda:
126 | image1, image2, image3, mask = image1.cuda(), \
127 | image2.cuda(), \
128 | image3.cuda(), \
129 | mask.cuda()
130 |
131 | image1, image2, image3, mask = Variable(image1, volatile=True), \
132 | Variable(image2, volatile=True), \
133 | Variable(image3, volatile=True), \
134 | Variable(mask, volatile=True)
135 | map1 = unet(image1, return_features=True)
136 | map2 = unet(image2, return_features=True)
137 | map3 = unet(image3, return_features=True)
138 |
139 | # print(image1.type)
140 | # print(map1.type)
141 |
142 | output = model(map1, map2, map3)
143 | test_loss += criterion(output, mask).data[0]
144 |
145 | test_loss /= len(loader)
146 |     if train_accuracy:
147 |         print(
148 |             '\nTraining Set: Average DICE Loss: {:.4f}\n'.format(test_loss))
149 |     else:
150 |         print(
151 |             '\nTest Set: Average DICE Loss: {:.4f}\n'.format(test_loss))
152 |
153 |
154 | if args.train:
155 | for i in range(args.epochs):
156 | train(i)
157 | test()
158 |
159 | torch.save(model.state_dict(),
160 | 'bdclstm-{}-{}-{}'.format(args.batch_size, args.epochs, args.lr))
161 | else:
162 | model.load_state_dict(torch.load('bdclstm-{}-{}-{}'.format(args.batch_size,
163 | args.epochs,
164 | args.lr)))
165 | test()
166 | test(train_accuracy=True)
167 |
--------------------------------------------------------------------------------
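
To make the wiring above concrete, here is a minimal sketch (random data, CPU, 128x128 slices assumed) of one forward pass through the hybrid pipeline: `UNetSmall` turns each of three neighbouring slices into 32-channel feature maps, and `BDCLSTM` fuses them into class probabilities for the middle slice.

```python
# Hypothetical end-to-end forward pass; not part of the training script.
import torch
from torch.autograd import Variable
from models import UNetSmall
from CLSTM import BDCLSTM

unet = UNetSmall()
model = BDCLSTM(input_channels=32, hidden_channels=[32])

# Three neighbouring slices, B x 1 x 128 x 128 each.
im1, im2, im3 = [Variable(torch.randn(1, 1, 128, 128)) for _ in range(3)]

map1 = unet(im1, return_features=True)  # 1 x 32 x 128 x 128
map2 = unet(im2, return_features=True)
map3 = unet(im3, return_features=True)

out = model(map1, map2, map3)  # softmax maps for the middle slice
print(out.size())  # expected: torch.Size([1, 2, 128, 128])
```
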
/main_small.py:
--------------------------------------------------------------------------------
1 | # %% -*- coding: utf-8 -*-
2 | '''
3 | Author: Shreyas Padhy
4 | Driver file for Unet and BDC-LSTM Implementation
5 | '''
6 |
7 | from __future__ import print_function
8 |
9 | import argparse
10 | import matplotlib
11 | matplotlib.use('Agg')
12 | import matplotlib.pyplot as plt
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.optim as optim
17 | from torch.autograd import Variable
18 | from torch.utils.data import DataLoader
19 | import torchvision.transforms as tr
20 |
21 | from data import BraTSDatasetUnet, BraTSDatasetLSTM
22 | from losses import DICELoss
23 | from models import UNetSmall
24 | from tqdm import tqdm
25 | import scipy.io as sio
26 | import numpy as np
27 |
28 | # %% import transforms
29 |
30 | # %% Training settings
31 | parser = argparse.ArgumentParser(description='UNet+BDCLSTM for BraTS Dataset')
32 | parser.add_argument('--batch-size', type=int, default=4, metavar='N',
33 |                     help='input batch size for training (default: 4)')
34 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
35 |                     help='input batch size for testing (default: 1000)')
36 | parser.add_argument('--train', action='store_true', default=False,
37 |                     help='Argument to train model (default: False)')
38 | parser.add_argument('--epochs', type=int, default=1, metavar='N',
39 |                     help='number of epochs to train (default: 1)')
40 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
41 |                     help='learning rate (default: 0.001)')
42 | parser.add_argument('--cuda', action='store_true', default=False,
43 | help='enables CUDA training (default: False)')
44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
45 | help='batches to wait before logging training status')
46 | parser.add_argument('--size', type=int, default=128, metavar='N',
47 | help='imsize')
48 | parser.add_argument('--load', type=str, default=None, metavar='str',
49 | help='weight file to load (default: None)')
50 | parser.add_argument('--data-folder', type=str, default='./Data/', metavar='str',
51 |                     help='folder that contains data (default: ./Data/)')
52 | parser.add_argument('--save', type=str, default='OutMasks', metavar='str',
53 | help='Identifier to save npy arrays with')
54 | parser.add_argument('--modality', type=str, default='flair', metavar='str',
55 | help='Modality to use for training (default: flair)')
56 | parser.add_argument('--optimizer', type=str, default='ADAM', metavar='str',
57 |                     help='Optimizer (default: ADAM)')
58 | parser.add_argument('--clip', action='store_true', default=False,
59 | help='enables gradnorm clip of 1.0 (default: False)')
60 | args = parser.parse_args()
61 | args.cuda = args.cuda and torch.cuda.is_available()
62 |
63 | DATA_FOLDER = args.data_folder
64 |
65 | # %% Loading in the Dataset
66 | dset_train = BraTSDatasetUnet(DATA_FOLDER, train=True,
67 | keywords=[args.modality],
68 | im_size=[args.size, args.size], transform=tr.ToTensor())
69 |
70 | train_loader = DataLoader(dset_train,
71 | batch_size=args.batch_size,
72 | shuffle=True, num_workers=1)
73 |
74 | dset_test = BraTSDatasetUnet(DATA_FOLDER, train=False,
75 | keywords=[args.modality],
76 | im_size=[args.size, args.size], transform=tr.ToTensor())
77 |
78 | test_loader = DataLoader(dset_test,
79 | batch_size=args.test_batch_size,
80 | shuffle=False, num_workers=1)
81 |
82 |
83 | print("Training Data : ", len(train_loader.dataset))
84 | print("Testing Data : ", len(test_loader.dataset))
85 |
86 | # %% Loading in the model
87 | model = UNetSmall()
88 |
89 | if args.cuda:
90 | model.cuda()
91 |
92 | if args.optimizer == 'SGD':
93 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
94 | momentum=0.99)
95 | if args.optimizer == 'ADAM':
96 | optimizer = optim.Adam(model.parameters(), lr=args.lr,
97 | betas=(0.9, 0.999))
98 |
99 |
100 | # Defining Loss Function
101 | criterion = DICELoss()
102 | # Define Training Loop
103 |
104 |
105 | def train(epoch, loss_list):
106 | model.train()
107 | for batch_idx, (image, mask) in enumerate(train_loader):
108 | if args.cuda:
109 | image, mask = image.cuda(), mask.cuda()
110 |
111 | image, mask = Variable(image), Variable(mask)
112 |
113 | optimizer.zero_grad()
114 |
115 | output = model(image)
116 |
117 | loss = criterion(output, mask)
118 | loss_list.append(loss.data[0])
119 |
120 |         loss.backward()
121 |         if args.clip:
122 |             # clip after backward and before the step, so it takes effect
123 |             nn.utils.clip_grad_norm(model.parameters(), max_norm=1)
124 |         optimizer.step()
125 |
126 | if batch_idx % args.log_interval == 0:
127 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
128 | epoch, batch_idx * len(image), len(train_loader.dataset),
129 | 100. * batch_idx / len(train_loader), loss.data[0]))
130 |
131 |
132 | def test(train_accuracy=False, save_output=False):
133 | test_loss = 0
134 | correct = 0
135 |
136 | if train_accuracy:
137 | loader = train_loader
138 | else:
139 | loader = test_loader
140 |
141 | for batch_idx, (image, mask) in tqdm(enumerate(loader)):
142 | if args.cuda:
143 | image, mask = image.cuda(), mask.cuda()
144 |
145 | image, mask = Variable(image, volatile=True), Variable(
146 | mask, volatile=True)
147 |
148 | output = model(image)
149 |
150 | test_loss += criterion(output, mask).data[0]
151 |
152 | output.data.round_()
153 |
154 | if save_output and (not train_accuracy):
155 | np.save('./npy-files/out-files/{}-unetsmall-batch-{}-outs.npy'.format(args.save,
156 | batch_idx),
157 | output.data.byte().cpu().numpy())
158 | np.save('./npy-files/out-files/{}-unetsmall--batch-{}-masks.npy'.format(args.save,
159 | batch_idx),
160 | mask.data.byte().cpu().numpy())
161 | np.save('./npy-files/out-files/{}-unetsmall--batch-{}-images.npy'.format(args.save,
162 | batch_idx),
163 | image.data.float().cpu().numpy())
164 |
165 | if save_output and train_accuracy:
166 | np.save('./npy-files/out-files/{}-unetsmall-train-batch-{}-outs.npy'.format(args.save,
167 | batch_idx),
168 | output.data.byte().cpu().numpy())
169 | np.save('./npy-files/out-files/{}-unetsmall-train-batch-{}-masks.npy'.format(args.save,
170 | batch_idx),
171 | mask.data.byte().cpu().numpy())
172 | np.save('./npy-files/out-files/{}-unetsmall-train-batch-{}-images.npy'.format(args.save,
173 | batch_idx),
174 | image.data.float().cpu().numpy())
175 |
176 |     # Average DICE loss over all batches (lower is better)
177 |     test_loss /= len(loader)
178 |     if train_accuracy:
179 |         print('\nTraining Set: Average DICE Loss: {:.4f}\n'.format(
180 |             test_loss))
181 |     else:
182 |         print('\nTest Set: Average DICE Loss: {:.4f}\n'.format(
183 |             test_loss))
184 |
185 |
186 | if args.train:
187 | loss_list = []
188 | for i in tqdm(range(args.epochs)):
189 | train(i, loss_list)
190 | test(train_accuracy=False, save_output=False)
191 | test(train_accuracy=True, save_output=False)
192 |
193 | plt.plot(loss_list)
194 | plt.title("UNetSmall bs={}, ep={}, lr={}".format(args.batch_size,
195 | args.epochs, args.lr))
196 | plt.xlabel("Number of iterations")
197 | plt.ylabel("Average DICE loss per batch")
198 | plt.savefig("./plots/{}-UNetSmall_Loss_bs={}_ep={}_lr={}.png".format(args.save,
199 | args.batch_size,
200 | args.epochs,
201 | args.lr))
202 |
203 | np.save('./npy-files/loss-files/{}-UNetSmall_Loss_bs={}_ep={}_lr={}.npy'.format(args.save,
204 | args.batch_size,
205 | args.epochs,
206 | args.lr),
207 | np.asarray(loss_list))
208 |
209 | torch.save(model.state_dict(), 'unetsmall-final-{}-{}-{}'.format(args.batch_size,
210 | args.epochs,
211 | args.lr))
212 | else:
213 | model.load_state_dict(torch.load(args.load))
214 | test(save_output=True)
215 | test(train_accuracy=True)
216 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | class UNet(nn.Module):
7 | def __init__(self, num_channels=1, num_classes=2):
8 | super(UNet, self).__init__()
9 | num_feat = [64, 128, 256, 512, 1024]
10 |
11 | self.down1 = nn.Sequential(Conv3x3(num_channels, num_feat[0]))
12 |
13 | self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
14 | Conv3x3(num_feat[0], num_feat[1]))
15 |
16 | self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
17 | Conv3x3(num_feat[1], num_feat[2]))
18 |
19 | self.down4 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
20 | Conv3x3(num_feat[2], num_feat[3]))
21 |
22 | self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
23 | Conv3x3(num_feat[3], num_feat[4]))
24 |
25 | self.up1 = UpConcat(num_feat[4], num_feat[3])
26 | self.upconv1 = Conv3x3(num_feat[4], num_feat[3])
27 |
28 | self.up2 = UpConcat(num_feat[3], num_feat[2])
29 | self.upconv2 = Conv3x3(num_feat[3], num_feat[2])
30 |
31 | self.up3 = UpConcat(num_feat[2], num_feat[1])
32 | self.upconv3 = Conv3x3(num_feat[2], num_feat[1])
33 |
34 | self.up4 = UpConcat(num_feat[1], num_feat[0])
35 | self.upconv4 = Conv3x3(num_feat[1], num_feat[0])
36 |
37 | self.final = nn.Sequential(nn.Conv2d(num_feat[0],
38 | num_classes,
39 | kernel_size=1),
40 | nn.Softmax2d())
41 |
42 | def forward(self, inputs, return_features=False):
43 | # print(inputs.data.size())
44 | down1_feat = self.down1(inputs)
45 | # print(down1_feat.size())
46 | down2_feat = self.down2(down1_feat)
47 | # print(down2_feat.size())
48 | down3_feat = self.down3(down2_feat)
49 | # print(down3_feat.size())
50 | down4_feat = self.down4(down3_feat)
51 | # print(down4_feat.size())
52 | bottom_feat = self.bottom(down4_feat)
53 |
54 | # print(bottom_feat.size())
55 | up1_feat = self.up1(bottom_feat, down4_feat)
56 | # print(up1_feat.size())
57 | up1_feat = self.upconv1(up1_feat)
58 | # print(up1_feat.size())
59 | up2_feat = self.up2(up1_feat, down3_feat)
60 | # print(up2_feat.size())
61 | up2_feat = self.upconv2(up2_feat)
62 | # print(up2_feat.size())
63 | up3_feat = self.up3(up2_feat, down2_feat)
64 | # print(up3_feat.size())
65 | up3_feat = self.upconv3(up3_feat)
66 | # print(up3_feat.size())
67 | up4_feat = self.up4(up3_feat, down1_feat)
68 | # print(up4_feat.size())
69 | up4_feat = self.upconv4(up4_feat)
70 | # print(up4_feat.size())
71 |
72 | if return_features:
73 | outputs = up4_feat
74 | else:
75 | outputs = self.final(up4_feat)
76 |
77 | return outputs
78 |
79 |
80 | class UNetSmall(nn.Module):
81 | def __init__(self, num_channels=1, num_classes=2):
82 | super(UNetSmall, self).__init__()
83 | num_feat = [32, 64, 128, 256]
84 |
85 | self.down1 = nn.Sequential(Conv3x3Small(num_channels, num_feat[0]))
86 |
87 | self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
88 | nn.BatchNorm2d(num_feat[0]),
89 | Conv3x3Small(num_feat[0], num_feat[1]))
90 |
91 | self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2),
92 | nn.BatchNorm2d(num_feat[1]),
93 | Conv3x3Small(num_feat[1], num_feat[2]))
94 |
95 | self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2),
96 | nn.BatchNorm2d(num_feat[2]),
97 | Conv3x3Small(num_feat[2], num_feat[3]),
98 | nn.BatchNorm2d(num_feat[3]))
99 |
100 | self.up1 = UpSample(num_feat[3], num_feat[2])
101 | self.upconv1 = nn.Sequential(Conv3x3Small(num_feat[3] + num_feat[2], num_feat[2]),
102 | nn.BatchNorm2d(num_feat[2]))
103 |
104 | self.up2 = UpSample(num_feat[2], num_feat[1])
105 | self.upconv2 = nn.Sequential(Conv3x3Small(num_feat[2] + num_feat[1], num_feat[1]),
106 | nn.BatchNorm2d(num_feat[1]))
107 |
108 | self.up3 = UpSample(num_feat[1], num_feat[0])
109 | self.upconv3 = nn.Sequential(Conv3x3Small(num_feat[1] + num_feat[0], num_feat[0]),
110 | nn.BatchNorm2d(num_feat[0]))
111 |
112 | self.final = nn.Sequential(nn.Conv2d(num_feat[0],
113 | 1,
114 | kernel_size=1),
115 | nn.Sigmoid())
116 |
117 | def forward(self, inputs, return_features=False):
118 | # print(inputs.data.size())
119 | down1_feat = self.down1(inputs)
120 | # print(down1_feat.size())
121 | down2_feat = self.down2(down1_feat)
122 | # print(down2_feat.size())
123 | down3_feat = self.down3(down2_feat)
124 | # print(down3_feat.size())
125 | bottom_feat = self.bottom(down3_feat)
126 |
127 | # print(bottom_feat.size())
128 | up1_feat = self.up1(bottom_feat, down3_feat)
129 | # print(up1_feat.size())
130 | up1_feat = self.upconv1(up1_feat)
131 | # print(up1_feat.size())
132 | up2_feat = self.up2(up1_feat, down2_feat)
133 | # print(up2_feat.size())
134 | up2_feat = self.upconv2(up2_feat)
135 | # print(up2_feat.size())
136 | up3_feat = self.up3(up2_feat, down1_feat)
137 | # print(up3_feat.size())
138 | up3_feat = self.upconv3(up3_feat)
139 | # print(up3_feat.size())
140 |
141 | if return_features:
142 | outputs = up3_feat
143 | else:
144 | outputs = self.final(up3_feat)
145 |
146 | return outputs
147 |
148 |
149 | class Conv3x3(nn.Module):
150 | def __init__(self, in_feat, out_feat):
151 | super(Conv3x3, self).__init__()
152 |
153 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
154 | kernel_size=3,
155 | stride=1,
156 | padding=1),
157 | nn.BatchNorm2d(out_feat),
158 | nn.ReLU())
159 |
160 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
161 | kernel_size=3,
162 | stride=1,
163 | padding=1),
164 | nn.BatchNorm2d(out_feat),
165 | nn.ReLU())
166 |
167 | def forward(self, inputs):
168 | outputs = self.conv1(inputs)
169 | outputs = self.conv2(outputs)
170 | return outputs
171 |
172 |
173 | class Conv3x3Drop(nn.Module):
174 | def __init__(self, in_feat, out_feat):
175 | super(Conv3x3Drop, self).__init__()
176 |
177 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
178 | kernel_size=3,
179 | stride=1,
180 | padding=1),
181 | nn.Dropout(p=0.2),
182 | nn.ReLU())
183 |
184 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
185 | kernel_size=3,
186 | stride=1,
187 | padding=1),
188 | nn.BatchNorm2d(out_feat),
189 | nn.ReLU())
190 |
191 | def forward(self, inputs):
192 | outputs = self.conv1(inputs)
193 | outputs = self.conv2(outputs)
194 | return outputs
195 |
196 |
197 | class Conv3x3Small(nn.Module):
198 | def __init__(self, in_feat, out_feat):
199 | super(Conv3x3Small, self).__init__()
200 |
201 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat,
202 | kernel_size=3,
203 | stride=1,
204 | padding=1),
205 | nn.ELU(),
206 | nn.Dropout(p=0.2))
207 |
208 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat,
209 | kernel_size=3,
210 | stride=1,
211 | padding=1),
212 | nn.ELU())
213 |
214 | def forward(self, inputs):
215 | outputs = self.conv1(inputs)
216 | outputs = self.conv2(outputs)
217 | return outputs
218 |
219 |
220 | class UpConcat(nn.Module):
221 | def __init__(self, in_feat, out_feat):
222 | super(UpConcat, self).__init__()
223 |
224 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
225 |
226 | # self.deconv = nn.ConvTranspose2d(in_feat, out_feat,
227 | # kernel_size=3,
228 | # stride=1,
229 | # dilation=1)
230 |
231 | self.deconv = nn.ConvTranspose2d(in_feat,
232 | out_feat,
233 | kernel_size=2,
234 | stride=2)
235 |
236 | def forward(self, inputs, down_outputs):
237 | # TODO: Upsampling required after deconv?
238 | # outputs = self.up(inputs)
239 | outputs = self.deconv(inputs)
240 | out = torch.cat([down_outputs, outputs], 1)
241 | return out
242 |
243 |
244 | class UpSample(nn.Module):
245 | def __init__(self, in_feat, out_feat):
246 | super(UpSample, self).__init__()
247 |
248 | self.up = nn.Upsample(scale_factor=2, mode='nearest')
249 |
250 | self.deconv = nn.ConvTranspose2d(in_feat,
251 | out_feat,
252 | kernel_size=2,
253 | stride=2)
254 |
255 | def forward(self, inputs, down_outputs):
256 | # TODO: Upsampling required after deconv?
257 | outputs = self.up(inputs)
258 | # outputs = self.deconv(inputs)
259 | out = torch.cat([outputs, down_outputs], 1)
260 | return out
261 |
--------------------------------------------------------------------------------
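
For reference, the `return_features` flag swaps the final classification head for the last decoder feature map: with the default `num_feat` lists, `UNet` returns 64 channels and `UNetSmall` 32, which is why `main_bdclstm.py` constructs `BDCLSTM(input_channels=32, ...)` on top of `UNetSmall`. A minimal shape check (random input, size assumed divisible by 16):

```python
import torch
from torch.autograd import Variable
from models import UNet, UNetSmall

x = Variable(torch.randn(1, 1, 128, 128))  # B x C x H x W

unet, small = UNet(), UNetSmall()

print(unet(x).size())                         # [1, 2, 128, 128]  Softmax2d class maps
print(unet(x, return_features=True).size())   # [1, 64, 128, 128] decoder features
print(small(x).size())                        # [1, 1, 128, 128]  Sigmoid mask
print(small(x, return_features=True).size())  # [1, 32, 128, 128] features for BDCLSTM
```
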
/plot_ims.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import scipy.io as sio
4 | # outs = np.load('Numpy-batch-7-outs.npy')
5 | # masks = np.load('Numpy-batch-7-masks.npy')
6 |
7 | final_outs = []
8 | final_masks = []
9 | final_images = []
10 |
11 | loss = np.load('./OutMasks-UNet_Loss_bs=4_ep=1_lr=0.001.npy')
12 |
13 | for i in range(38):
14 | outs = np.load(
15 | './npy-files/out-files/UNetSmall-NonZero-unetsmall-batch-{}-outs.npy'.format(i))
16 | masks = np.load(
17 | './npy-files/out-files/UNetSmall-NonZero-unetsmall--batch-{}-masks.npy'.format(i))
18 | images = np.load(
19 | './npy-files/out-files/UNetSmall-NonZero-unetsmall--batch-{}-images.npy'.format(i))
20 |
21 | final_outs.append(outs)
22 | final_masks.append(masks)
23 | final_images.append(images)
24 | final_outs = np.asarray(final_outs)
25 | final_masks = np.asarray(final_masks)
26 | final_images = np.asarray(final_images)
27 |
28 | print(final_images[0].shape)
29 | for i in range(38):
30 | print(final_images[i].shape)
31 | plt.imshow(np.squeeze(final_images[i][49, :, :]), cmap='gray')
32 | plt.show()
33 |
34 |
35 | # print(final_outs.shape)
36 | # print(final_masks.shape)
37 |
38 | # sio.savemat('./mat-files/final_outputs.mat', {'data': final_outs})
39 | # sio.savemat('./mat-files/final_masks.mat', {'data': final_masks})
40 | # sio.savemat('./mat-files/final_images.mat', {'data': final_images})
41 | # for i in range(len(outs)):
42 | # plt1 = 255 * np.squeeze(outs[i, :, :, :]).astype('uint8')
43 | # plt2 = 255 * np.squeeze(masks[i, :, :, :]).astype('uint8')
44 | # print(plt1, plt2)
45 | # plt.subplot(1, 2, 1)
46 | # plt.imshow(plt1, cmap='gray')
47 | # plt.title("UNet Out")
48 | # plt.subplot(1, 2, 2)
49 | # plt.imshow(plt2, cmap='gray')
50 | # plt.title("Mask")
51 |
--------------------------------------------------------------------------------
/result_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/result_comparison.png
--------------------------------------------------------------------------------
/smallunet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/smallunet.png
--------------------------------------------------------------------------------
/unet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/unet.png
--------------------------------------------------------------------------------