├── .DS_Store
├── CLSTM.py
├── LICENSE
├── README.md
├── data.py
├── dataParser.py
├── dice.png
├── hybridunet.png
├── losses.py
├── main.py
├── main_bdclstm.py
├── main_small.py
├── models.py
├── plot_ims.py
├── result_comparison.png
├── smallunet.png
└── unet.png

/CLSTM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | # Batch x NumChannels x Height x Width
6 | # UNET --> BatchSize x 1 (3?) x 240 x 240
7 | # BDCLSTM --> BatchSize x 64 x 240 x 240
8 |
9 | ''' Class CLSTMCell.
10 | This represents a single node in a CLSTM series.
11 | It produces just one time (spatial) step output.
12 | '''
13 |
14 |
15 | class CLSTMCell(nn.Module):
16 |
17 |     # Constructor
18 |     def __init__(self, input_channels, hidden_channels,
19 |                  kernel_size, bias=True):
20 |         super(CLSTMCell, self).__init__()
21 |
22 |         assert hidden_channels % 2 == 0
23 |
24 |         self.input_channels = input_channels
25 |         self.hidden_channels = hidden_channels
26 |         self.bias = bias
27 |         self.kernel_size = kernel_size
28 |         self.num_features = 4
29 |
30 |         self.padding = (kernel_size - 1) // 2
31 |         # a single convolution computes all four gate pre-activations
32 |         self.conv = nn.Conv2d(self.input_channels + self.hidden_channels,
33 |                               self.num_features * self.hidden_channels,
34 |                               self.kernel_size,
35 |                               1,
36 |                               self.padding)
37 |
38 |     # Forward propagation formulation
39 |     def forward(self, x, h, c):
40 |         combined = torch.cat((x, h), dim=1)
41 |         A = self.conv(combined)
42 |
43 |         # NOTE: A_k = W_xk * x_t + W_hk * h_{t-1} + b_k for each gate
44 |         # k in {i, f, o, g}, where * is convolution
45 |         (Ai, Af, Ao, Ag) = torch.split(A,
46 |                                        A.size()[1] // self.num_features,
47 |                                        dim=1)
48 |
49 |         i = torch.sigmoid(Ai)  # input gate
50 |         f = torch.sigmoid(Af)  # forget gate
51 |         o = torch.sigmoid(Ao)  # output gate
52 |         g = torch.tanh(Ag)     # candidate cell update
53 |
54 |         c = c * f + i * g      # cell activation state
55 |         h = o * torch.tanh(c)  # cell hidden state
56 |
57 |         return h, c
58 |
59 |     @staticmethod
60 |     def init_hidden(batch_size, hidden_c, shape):
61 |         # zero-initialised hidden and cell states, placed on the GPU
62 |         # only when CUDA is actually available
63 |         h = Variable(torch.zeros(batch_size,
64 |                                  hidden_c,
65 |                                  shape[0],
66 |                                  shape[1]))
67 |         c = Variable(torch.zeros(batch_size,
68 |                                  hidden_c,
69 |                                  shape[0],
70 |                                  shape[1]))
71 |         if torch.cuda.is_available():
72 |             return (h.cuda(), c.cuda())
73 |         return (h, c)
74 |
75 |
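76 | # Illustrative usage sketch for a single cell (an example only, not part
77 | # of the training pipeline; shapes follow the conventions above):
78 | #   cell = CLSTMCell(input_channels=64, hidden_channels=64, kernel_size=5)
79 | #   x = torch.randn(1, 64, 240, 240)
80 | #   h, c = CLSTMCell.init_hidden(1, 64, (240, 240))
81 | #   h, c = cell(x, h, c)  # both h and c are 1 x 64 x 240 x 240
82 |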
83 | ''' Class CLSTM.
84 | This represents a series of CLSTM nodes (one direction)
85 | '''
86 |
87 |
88 | class CLSTM(nn.Module):
89 |     # Constructor
90 |     def __init__(self, input_channels=64, hidden_channels=[64],
91 |                  kernel_size=5, bias=True):
92 |         super(CLSTM, self).__init__()
93 |
94 |         # store stuff
95 |         self.input_channels = [input_channels] + hidden_channels
96 |         self.hidden_channels = hidden_channels
97 |         self.kernel_size = kernel_size
98 |         self.num_layers = len(hidden_channels)
99 |
100 |         self.bias = bias
101 |         self.all_layers = []
102 |
103 |         # create a node for each layer in the CLSTM
104 |         for layer in range(self.num_layers):
105 |             name = 'cell{}'.format(layer)
106 |             cell = CLSTMCell(self.input_channels[layer],
107 |                              self.hidden_channels[layer],
108 |                              self.kernel_size,
109 |                              self.bias)
110 |             setattr(self, name, cell)
111 |             self.all_layers.append(cell)
112 |
113 |     # Forward propagation
114 |     # x --> BatchSize x NumSteps x NumChannels x Height x Width
115 |     #       BatchSize x 2 x 64 x 240 x 240
116 |     def forward(self, x):
117 |         bsize, steps, _, height, width = x.size()
118 |         internal_state = []
119 |         outputs = []
120 |         for step in range(steps):
121 |             input = torch.squeeze(x[:, step, :, :, :], dim=1)
122 |             for layer in range(self.num_layers):
123 |                 # populate hidden states for all layers on the first step
124 |                 if step == 0:
125 |                     (h, c) = CLSTMCell.init_hidden(bsize,
126 |                                                    self.hidden_channels[layer],
127 |                                                    (height, width))
128 |                     internal_state.append((h, c))
129 |
130 |                 # do forward
131 |                 name = 'cell{}'.format(layer)
132 |                 (h, c) = internal_state[layer]
133 |
134 |                 input, c = getattr(self, name)(
135 |                     input, h, c)  # forward propagation call
136 |                 internal_state[layer] = (input, c)
137 |
138 |             outputs.append(input)
139 |
140 |         return outputs
141 |
142 |
143 | class BDCLSTM(nn.Module):
144 |     # Constructor
145 |     def __init__(self, input_channels=64, hidden_channels=[64],
146 |                  kernel_size=5, bias=True, num_classes=2):
147 |
148 |         super(BDCLSTM, self).__init__()
149 |         self.forward_net = CLSTM(
150 |             input_channels, hidden_channels, kernel_size, bias)
151 |         self.reverse_net = CLSTM(
152 |             input_channels, hidden_channels, kernel_size, bias)
153 |         self.conv = nn.Conv2d(
154 |             2 * hidden_channels[-1], num_classes, kernel_size=1)
155 |         self.soft = nn.Softmax2d()
156 |
157 |     # Forward propagation
158 |     # x --> BatchSize x NumChannels x Height x Width
159 |     #       BatchSize x 64 x 240 x 240
160 |     def forward(self, x1, x2, x3):
161 |         x1 = torch.unsqueeze(x1, dim=1)
162 |         x2 = torch.unsqueeze(x2, dim=1)
163 |         x3 = torch.unsqueeze(x3, dim=1)
164 |
165 |         # run the centre slice with each neighbour, once per direction
166 |         xforward = torch.cat((x1, x2), dim=1)
167 |         xreverse = torch.cat((x3, x2), dim=1)
168 |
169 |         yforward = self.forward_net(xforward)
170 |         yreverse = self.reverse_net(xreverse)
171 |
172 |         # assumes y is BatchSize x NumClasses x 240 x 240
173 |         ycat = torch.cat((yforward[-1], yreverse[-1]), dim=1)
174 |         y = self.conv(ycat)
175 |         y = self.soft(y)
176 |         return y
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Shreyas Padhy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation
the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UNet-Zoo 2 | A collection of UNet and hybrid architectures for 2D and 3D Biomedical Image segmentation, implemented in PyTorch. 3 | 4 | This repository contains a collection of architectures used for Biomedical Image Segmentation, implemented on the BraTS Brain Tumor Segmentation Challenge Dataset. The following architectures are implemented 5 | 6 | 1. **UNet** - Standard UNet architecture as described in the Ronneberger et al 2015 paper [[reference]](https://arxiv.org/abs/1505.04597) 7 |

8 | ![UNet architecture](unet.png)
9 |

10 | 11 | 1. **Small UNet** - 40x smaller version of UNet that achieves similar performance [[Theano Implementation]](https://github.com/codedecde/Luna2016-Lung-Nodule-Detection) 12 |

13 | ![Small UNet architecture](smallunet.png)
14 |

15 | 16 | 1. **UNet with BDCLSTM** - Combining a BDC-LSTM network with UNet to encode spatial correlation for 3D segmentation [[reference]](https://arxiv.org/pdf/1609.01006.pdf) 17 | 18 |

19 | ![UNet with BDC-LSTM hybrid architecture](hybridunet.png)
20 |

21 |
22 | 1. **kUNet** - Combining multiple UNets for increasing hierarchical preservation of information (coming soon) [[reference]](https://arxiv.org/pdf/1701.03056.pdf)
23 | 1. **R-UNet** - UNet with recurrent connections as another way to encode $z$-context (coming soon)
24 |
25 | ### To Run
26 |
27 | First, apply for access to the BraTS tumor dataset, and place the scans in a `Data/` folder, divided into `Train` and `Test`. Feel free to modify the BraTS PyTorch dataloaders in `data.py` for your use.
28 | 1. UNet - run `main.py`, type `--help` for information on arguments.
29 |    Example: `python main.py --train --cuda --data-folder "./Data/"`
30 | 1. Small UNet - run `main_small.py`, and use `--help`
31 | 1. BDC-LSTM - run `main_bdclstm.py` and use the weights for either your trained UNet or Small UNet models (`--help` is your savior); see the example workflow below.
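32 | For example, an end-to-end BDC-LSTM workflow looks like this: train the small model with `python main_small.py --train --cuda --data-folder "./Data/"`, point `UNET_MODEL_FILE` at the top of `main_bdclstm.py` to the weights it saved (e.g. `unetsmall-final-4-1-0.001` with the default hyperparameters), and then run `python main_bdclstm.py --train --cuda --data-folder "./Data/"`.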

33 |
34 | ### Some Results
35 |
36 | 1. Comparisons of UNet (top) and Small UNet (bottom)
37 | ![Segmentation comparison of UNet (top) and Small UNet (bottom)](result_comparison.png)

38 | 39 | 40 | 2. DICE Scores for UNet and Small UNet 41 | 42 |

43 | ![DICE scores for UNet and Small UNet](dice.png)

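44 | Note: the dataloaders read patient, volume, slice index, and modality out of the filenames; as parsed in `data.py` and `dataParser.py`, slices are expected to be named like `P1_1_80_flair.png` with a matching mask `P1_1_80_seg.png` (this pattern is inferred from the parsing code, so adapt it to your own export).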
45 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from torch.utils.data.dataset import Dataset 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from PIL import Image 7 | import os 8 | import scipy.io as sio 9 | from dataParser import getMaskFileName, getImg 10 | import torchvision.transforms as tr 11 | 12 | 13 | class BraTSDatasetUnet(Dataset): 14 | __file = [] 15 | __im = [] 16 | __mask = [] 17 | im_ht = 0 18 | im_wd = 0 19 | dataset_size = 0 20 | 21 | def __init__(self, dataset_folder, train=True, keywords=["P1", "1", "flair"], im_size=[128, 128], transform=None): 22 | 23 | self.__file = [] 24 | self.__im = [] 25 | self.__mask = [] 26 | self.im_ht = im_size[0] 27 | self.im_wd = im_size[1] 28 | self.transform = transform 29 | 30 | folder = dataset_folder 31 | # # Open and load text file including the whole training data 32 | if train: 33 | folder = dataset_folder + "Train/" 34 | else: 35 | folder = dataset_folder + "Test/" 36 | 37 | for file in os.listdir(folder): 38 | if file.endswith(".png"): 39 | filename = os.path.splitext(file)[0] 40 | filename_fragments = filename.split("_") 41 | samekeywords = list(set(filename_fragments) & set(keywords)) 42 | if len(samekeywords) == len(keywords): 43 | # 1. read file name 44 | self.__file.append(filename) 45 | # 2. read raw image 46 | # TODO: I think we should open image only in getitem, 47 | # otherwise memory explodes 48 | 49 | # rawImage = getImg(folder + file) 50 | self.__im.append(folder + file) 51 | # 3. read mask image 52 | mask_file = getMaskFileName(file) 53 | # maskImage = getImg(folder + mask_file) 54 | self.__mask.append(folder + mask_file) 55 | # self.dataset_size = len(self.__file) 56 | 57 | # print("lengths : ", len(self.__im), len(self.__mask)) 58 | self.dataset_size = len(self.__file) 59 | 60 | if not train: 61 | sio.savemat('filelist2.mat', {'data': self.__im}) 62 | 63 | def __getitem__(self, index): 64 | 65 | img = getImg(self.__im[index]) 66 | mask = getImg(self.__mask[index]) 67 | 68 | img = img.resize((self.im_ht, self.im_wd)) 69 | mask = mask.resize((self.im_ht, self.im_wd)) 70 | # mask.show() 71 | 72 | if self.transform is not None: 73 | # TODO: Not sure why not take full image 74 | img_tr = self.transform(img) 75 | mask_tr = self.transform(mask) 76 | # img_tr = self.transform(img[None, :, :]) 77 | # mask_tr = self.transform(mask[None, :, :]) 78 | 79 | return img_tr, mask_tr 80 | # return img.float(), mask.float() 81 | 82 | def __len__(self): 83 | 84 | return len(self.__im) 85 | 86 | 87 | class BraTSDatasetLSTM(Dataset): 88 | __im = [] 89 | __mask = [] 90 | __im1 = [] 91 | __im3 = [] 92 | im_ht = 0 93 | im_wd = 0 94 | dataset_size = 0 95 | 96 | def __init__(self, dataset_folder, train=True, keywords=["P1", "1", "flair"], im_size=[128, 128], transform=None): 97 | 98 | self.__file = [] 99 | self.__im = [] 100 | self.__mask = [] 101 | self.im_ht = im_size[0] 102 | self.im_wd = im_size[1] 103 | self.transform = transform 104 | 105 | folder = dataset_folder 106 | # # Open and load text file including the whole training data 107 | if train: 108 | folder = dataset_folder + "Train/" 109 | else: 110 | folder = dataset_folder + "Test/" 111 | 112 | # print("files : ", os.listdir(folder)) 113 | # print("Folder : ", folder) 114 | max_file = 0 115 | min_file = 10000000 116 | for file in os.listdir(folder): 117 | if file.endswith(".png"): 118 | m = 
re.search('(P[0-9]*[_])([0-9]*)', file) 119 | pic_num = int(m.group(2)) 120 | if pic_num > max_file: 121 | max_file = pic_num 122 | if pic_num < min_file: 123 | min_file = pic_num 124 | 125 | # print('min file number: ', min_file) 126 | # print('max file number: ', max_file) 127 | 128 | for file in os.listdir(folder): 129 | if file.endswith(".png"): 130 | filename = os.path.splitext(file)[0] 131 | filename_fragments = filename.split("_") 132 | samekeywords = list(set(filename_fragments) & set(keywords)) 133 | if len(samekeywords) == len(keywords): 134 | # 1. read file name 135 | # 2. read raw image 136 | # TODO: I think we should open image only in getitem, 137 | # otherwise memory explodes 138 | 139 | # rawImage = getImg(folder + file) 140 | 141 | if (filename_fragments[2] != str(min_file)) and (filename_fragments[2] != str(max_file)): 142 | # print("TEST : ", filename_fragments[2]) 143 | self.__im.append(folder + file) 144 | 145 | file1 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + str( 146 | int(filename_fragments[2]) - 1) + '_' + filename_fragments[3] + '.png' 147 | 148 | self.__im1.append(folder + file1) 149 | 150 | file3 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + str( 151 | int(filename_fragments[2]) + 1) + '_' + filename_fragments[3] + '.png' 152 | 153 | self.__im3.append(folder + file3) 154 | # 3. read mask image 155 | mask_file = getMaskFileName(file) 156 | # maskImage = getImg(folder + mask_file) 157 | self.__mask.append(folder + mask_file) 158 | # self.dataset_size = len(self.__file) 159 | 160 | # print("lengths : ", len(self.__im), len(self.__mask)) 161 | self.dataset_size = len(self.__file) 162 | 163 | def __getitem__(self, index): 164 | 165 | img1 = getImg(self.__im1[index]) 166 | img = getImg(self.__im[index]) 167 | img3 = getImg(self.__im3[index]) 168 | mask = getImg(self.__mask[index]) 169 | 170 | # img.show() 171 | # mask.show() 172 | 173 | if self.transform is not None: 174 | # TODO: Not sure why not take full image 175 | img_tr1 = self.transform(img1) 176 | img_tr = self.transform(img) 177 | img_tr3 = self.transform(img3) 178 | mask_tr = self.transform(mask) 179 | # img_tr = self.transform(img[None, :, :]) 180 | # mask_tr = self.transform(mask[None, :, :]) 181 | 182 | return img_tr1, img_tr, img_tr3, mask_tr 183 | # return img.float(), mask.float() 184 | 185 | def __len__(self): 186 | 187 | return len(self.__im) 188 | -------------------------------------------------------------------------------- /dataParser.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | from skimage import color 6 | from skimage import io 7 | 8 | 9 | def getMaskFileName(file): 10 | 11 | mask_file = file.replace("flair.png", "seg.png") 12 | mask_file = mask_file.replace("t1.png", "seg.png") 13 | mask_file = mask_file.replace("t2.png", "seg.png") 14 | mask_file = mask_file.replace("t1ce.png", "seg.png") 15 | 16 | return mask_file 17 | 18 | 19 | def getImg(imgpathway): 20 | # image_file = Image.open(imgpathway) # open colour image 21 | # img = image_file.convert('L') 22 | # IMG = np.asarray(img.getdata()) 23 | # img = io.imread(imgpathway, as_grey=True) 24 | img = Image.open(imgpathway) 25 | # img = np.asarray(img) 26 | # img *= 65536.0 / np.max(img) 27 | # IMG.astype(np.uint16) 28 | # plt.imshow(IMG, cmap='gray') 29 | # plt.show() 30 | return img 31 | 32 | 33 | def File2Image(self, index): 34 | file = self.__file[index] 35 | filename_fragments = file.split("_") 
36 |     if filename_fragments[2] == '0' or filename_fragments[2] == '154':
37 |         # boundary slices have no neighbour on one side, so skip them
38 |         return 0, 0
39 |
40 |     filename1 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + \
41 |         str(int(filename_fragments[2]) - 1) + '_' + filename_fragments[3]
42 |     filename3 = filename_fragments[0] + '_' + filename_fragments[1] + '_' + \
43 |         str(int(filename_fragments[2]) + 1) + '_' + filename_fragments[3]
44 |
45 |     idx1 = self.__file.index(filename1)
46 |     idx3 = self.__file.index(filename3)
47 |     img1 = self.__im[idx1]
48 |     img3 = self.__im[idx3]
49 |
50 |     return img1, img3
--------------------------------------------------------------------------------
/dice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/dice.png
--------------------------------------------------------------------------------
/hybridunet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/hybridunet.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 |
8 | class DICELossMultiClass(nn.Module):
9 |
10 |     def __init__(self):
11 |         super(DICELossMultiClass, self).__init__()
12 |
13 |     def forward(self, output, mask):
14 |         num_classes = output.size(1)
15 |         dice_eso = 0
16 |         for i in range(num_classes):
17 |             probs = torch.squeeze(output[:, i, :, :], 1)
18 |             # use a per-class view; reassigning mask itself would clobber
19 |             # the ground truth for the remaining classes
20 |             mask_i = torch.squeeze(mask[:, i, :, :], 1)
21 |
22 |             num = probs * mask_i
23 |             num = torch.sum(num, 2)
24 |             num = torch.sum(num, 1)
25 |
26 |             den1 = probs * probs
27 |             den1 = torch.sum(den1, 2)
28 |             den1 = torch.sum(den1, 1)
29 |
30 |             den2 = mask_i * mask_i
31 |             den2 = torch.sum(den2, 2)
32 |             den2 = torch.sum(den2, 1)
33 |
34 |             eps = 0.0000001
35 |             dice = 2 * ((num + eps) / (den1 + den2 + eps))
36 |             # dice_eso = dice[:, 1:]
37 |             dice_eso += dice
38 |
39 |         loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
40 |         return loss
41 |
42 |
43 | class DICELoss(nn.Module):
44 |
45 |     def __init__(self):
46 |         super(DICELoss, self).__init__()
47 |
48 |     def forward(self, output, mask):
49 |
50 |         probs = torch.squeeze(output, 1)
51 |         mask = torch.squeeze(mask, 1)
52 |
53 |         intersection = probs * mask
54 |         intersection = torch.sum(intersection, 2)
55 |         intersection = torch.sum(intersection, 1)
56 |
57 |         den1 = probs * probs
58 |         den1 = torch.sum(den1, 2)
59 |         den1 = torch.sum(den1, 1)
60 |
61 |         den2 = mask * mask
62 |         den2 = torch.sum(den2, 2)
63 |         den2 = torch.sum(den2, 1)
64 |
65 |         eps = 1e-8
66 |         dice = 2 * ((intersection + eps) / (den1 + den2 + eps))
67 |         dice_eso = dice
68 |
69 |         loss = 1 - torch.sum(dice_eso) / dice_eso.size(0)
70 |         return loss
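71 |
72 | # Illustrative sanity check (a sketch, not used by the training scripts):
73 | # a perfect binary prediction should give a DICE loss of ~0, e.g.
74 | #   m = torch.ones(2, 1, 4, 4)   # wrap in Variable on pre-0.4 PyTorch
75 | #   DICELoss()(m, m)             # dice = 2*16/(16+16) = 1, so loss = 1 - 1 = 0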
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # %% -*- coding: utf-8 -*-
2 | '''
3 | Author: Shreyas Padhy
4 | Driver file for Standard UNet Implementation
5 | '''
6 | from __future__ import print_function
7 |
8 | import argparse
9 | import matplotlib
10 | matplotlib.use('Agg')
11 | import matplotlib.pyplot as plt
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | from torch.autograd import Variable
17 | from torch.utils.data import DataLoader
18 | import scipy.io as sio
19 | import torchvision.transforms as tr
20 |
21 | from data import BraTSDatasetUnet, BraTSDatasetLSTM
22 | from losses import DICELossMultiClass
23 | from models import UNet
24 | from tqdm import tqdm
25 | import numpy as np
26 |
27 | # %% Training settings
28 | parser = argparse.ArgumentParser(
29 |     description='UNet + BDCLSTM for BraTS Dataset')
30 | parser.add_argument('--batch-size', type=int, default=4, metavar='N',
31 |                     help='input batch size for training (default: 4)')
32 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
33 |                     help='input batch size for testing (default: 1000)')
34 | parser.add_argument('--train', action='store_true', default=False,
35 |                     help='Argument to train model (default: False)')
36 | parser.add_argument('--epochs', type=int, default=1, metavar='N',
37 |                     help='number of epochs to train (default: 1)')
38 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
39 |                     help='learning rate (default: 0.001)')
40 | parser.add_argument('--cuda', action='store_true', default=False,
41 |                     help='enables CUDA training (default: False)')
42 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
43 |                     help='batches to wait before logging training status')
44 | parser.add_argument('--size', type=int, default=128, metavar='N',
45 |                     help='imsize')
46 | parser.add_argument('--load', type=str, default=None, metavar='str',
47 |                     help='weight file to load (default: None)')
48 | parser.add_argument('--data-folder', type=str, default='./Data/', metavar='str',
49 |                     help='folder that contains data (default: ./Data/)')
50 | parser.add_argument('--save', type=str, default='OutMasks', metavar='str',
51 |                     help='Identifier to save npy arrays with')
52 | parser.add_argument('--modality', type=str, default='flair', metavar='str',
53 |                     help='Modality to use for training (default: flair)')
54 | parser.add_argument('--optimizer', type=str, default='SGD', metavar='str',
55 |                     help='Optimizer (default: SGD)')
56 |
57 | args = parser.parse_args()
58 | args.cuda = args.cuda and torch.cuda.is_available()
59 |
60 | DATA_FOLDER = args.data_folder
61 |
62 | # %% Loading in the Dataset
63 | dset_train = BraTSDatasetUnet(DATA_FOLDER, train=True,
64 |                               keywords=[args.modality],
65 |                               im_size=[args.size, args.size],
66 |                               transform=tr.ToTensor())
67 |
68 | train_loader = DataLoader(dset_train,
69 |                           batch_size=args.batch_size,
70 |                           shuffle=True, num_workers=1)
71 |
72 | dset_test = BraTSDatasetUnet(DATA_FOLDER, train=False,
73 |                              keywords=[args.modality],
74 |                              im_size=[args.size, args.size],
75 |                              transform=tr.ToTensor())
76 |
77 | test_loader = DataLoader(dset_test,
78 |                          batch_size=args.test_batch_size,
79 |                          shuffle=False, num_workers=1)
80 |
81 | print("Training Data : ", len(train_loader.dataset))
82 | print("Test Data : ", len(test_loader.dataset))
83 |
84 | # %% Loading in the model
85 | model = UNet()
86 |
87 | if args.cuda:
88 |     model.cuda()
89 |
90 | if args.optimizer == 'SGD':
91 |     optimizer = optim.SGD(model.parameters(), lr=args.lr,
92 |                           momentum=0.99)
93 | if args.optimizer == 'ADAM':
94 |     # Adam betas follow the values used in main_small.py
95 |     optimizer = optim.Adam(model.parameters(), lr=args.lr,
96 |                            betas=(0.9, 0.999))
97 |
98 |
99 | # Defining Loss Function
100 | criterion = DICELossMultiClass()
101 |
102 |
103 | def train(epoch, loss_list):
104 |     model.train()
105 |     for batch_idx, (image, mask) in enumerate(train_loader):
106 |         if args.cuda:
107 |             image, mask = image.cuda(), mask.cuda()
108 |
109 |         image, mask = Variable(image), Variable(mask)
110 |
111 |         optimizer.zero_grad()
112 |
113 |         output = model(image)
114 |
115 |         loss = criterion(output, mask)
116 |         loss_list.append(loss.data[0])
117 |
118 |         loss.backward()
119 |         optimizer.step()
120 |
121 |         if batch_idx % args.log_interval == 0:
122 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tAverage DICE Loss: {:.6f}'.format(
123 |                 epoch, batch_idx * len(image), len(train_loader.dataset),
124 |                 100. * batch_idx / len(train_loader), loss.data[0]))
125 |
126 |
127 | def test(train_accuracy=False, save_output=False):
128 |     test_loss = 0
129 |
130 |     if train_accuracy:
131 |         loader = train_loader
132 |     else:
133 |         loader = test_loader
134 |
135 |     for batch_idx, (image, mask) in tqdm(enumerate(loader)):
136 |         if args.cuda:
137 |             image, mask = image.cuda(), mask.cuda()
138 |
139 |         image, mask = Variable(image, volatile=True), Variable(
140 |             mask, volatile=True)
141 |
142 |         output = model(image)
143 |
144 |         # collapse class probabilities to a hard per-pixel prediction
145 |         maxes, out = torch.max(output, 1, keepdim=True)
146 |
147 |         if save_output and (not train_accuracy):
148 |             np.save('./npy-files/out-files/{}-batch-{}-outs.npy'.format(args.save, batch_idx),
149 |                     out.data.byte().cpu().numpy())
150 |             np.save('./npy-files/out-files/{}-batch-{}-masks.npy'.format(args.save, batch_idx),
151 |                     mask.data.byte().cpu().numpy())
152 |             np.save('./npy-files/out-files/{}-batch-{}-images.npy'.format(args.save, batch_idx),
153 |                     image.data.float().cpu().numpy())
154 |
155 |         if save_output and train_accuracy:
156 |             np.save('./npy-files/out-files/{}-train-batch-{}-outs.npy'.format(args.save, batch_idx),
157 |                     out.data.byte().cpu().numpy())
158 |             np.save('./npy-files/out-files/{}-train-batch-{}-masks.npy'.format(args.save, batch_idx),
159 |                     mask.data.byte().cpu().numpy())
160 |             np.save('./npy-files/out-files/{}-train-batch-{}-images.npy'.format(args.save, batch_idx),
161 |                     image.data.float().cpu().numpy())
162 |
163 |         test_loss += criterion(output, mask).data[0]
164 |
165 |     # Average DICE loss over the loader
166 |     test_loss /= len(loader)
167 |     if train_accuracy:
168 |         print('\nTraining Set: Average DICE Loss: {:.4f}\n'.format(
169 |             test_loss))
170 |     else:
171 |         print('\nTest Set: Average DICE Loss: {:.4f}\n'.format(
172 |             test_loss))
173 |
174 |
175 | if args.train:
176 |     loss_list = []
177 |     for i in tqdm(range(args.epochs)):
178 |         train(i, loss_list)
179 |         test()
180 |
181 |     plt.plot(loss_list)
182 |     plt.title("UNet bs={}, ep={}, lr={}".format(args.batch_size,
183 |                                                 args.epochs, args.lr))
184 |     plt.xlabel("Number of iterations")
185 |     plt.ylabel("Average DICE loss per batch")
186 |     plt.savefig("./plots/{}-UNet_Loss_bs={}_ep={}_lr={}.png".format(
187 |         args.save, args.batch_size, args.epochs, args.lr))
188 |
189 |     np.save('./npy-files/loss-files/{}-UNet_Loss_bs={}_ep={}_lr={}.npy'.format(
190 |         args.save, args.batch_size, args.epochs, args.lr),
191 |         np.asarray(loss_list))
192 |
193 |     torch.save(model.state_dict(), 'unet-final-{}-{}-{}'.format(
194 |         args.batch_size, args.epochs, args.lr))
195 | elif args.load is not None:
196 |     model.load_state_dict(torch.load(args.load))
197 |     test(save_output=True)
198 |     test(train_accuracy=True)
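199 |
200 | # NOTE: this driver targets pre-0.4 PyTorch idioms (Variable, volatile=True,
201 | # loss.data[0]). On torch >= 0.4 the equivalent evaluation pattern would be,
202 | # roughly (a sketch, untested against this codebase):
203 | #   with torch.no_grad():
204 | #       output = model(image)
205 | #       test_loss += criterion(output, mask).item()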
--------------------------------------------------------------------------------
/main_bdclstm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | import torch
6 | import torch.optim as optim
7 | from losses import DICELossMultiClass
8 |
9 | from torch.autograd import Variable
10 | from torch.utils.data import DataLoader
11 | import torchvision.transforms as tr
12 |
13 | from data import BraTSDatasetLSTM
14 | from CLSTM import BDCLSTM
15 | from models import *
16 |
17 |
18 |
19 | UNET_MODEL_FILE = 'unetsmall-100-10-0.001'
20 | MODALITY = ["flair"]
21 |
22 | # %% Training settings
23 | parser = argparse.ArgumentParser(description='UNet+BDCLSTM for BraTS Dataset')
24 | parser.add_argument('--batch-size', type=int, default=4, metavar='N',
25 |                     help='input batch size for training (default: 4)')
26 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
27 |                     help='input batch size for testing (default: 1000)')
28 | parser.add_argument('--train', action='store_true', default=False,
29 |                     help='Argument to train model (default: False)')
30 | parser.add_argument('--epochs', type=int, default=1, metavar='N',
31 |                     help='number of epochs to train (default: 1)')
32 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
33 |                     help='learning rate (default: 0.001)')
34 | parser.add_argument('--mom', type=float, default=0.99, metavar='MOM',
35 |                     help='SGD momentum (default: 0.99)')
36 | parser.add_argument('--cuda', action='store_true', default=False,
37 |                     help='enables CUDA training (default: False)')
38 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
39 |                     help='batches to wait before logging training status')
40 | parser.add_argument('--test-dataset', action='store_true', default=False,
41 |                     help='test on smaller dataset (default: False)')
42 | parser.add_argument('--size', type=int, default=128, metavar='N',
43 |                     help='imsize')
44 | parser.add_argument('--drop', action='store_true', default=False,
45 |                     help='enables drop')
46 | parser.add_argument('--data-folder', type=str, default='./Data-Nonzero/', metavar='str',
47 |                     help='folder that contains data (default: ./Data-Nonzero/)')
48 |
49 |
50 | args = parser.parse_args()
51 | args.cuda = args.cuda and torch.cuda.is_available()
52 | if args.cuda:
53 |     print("We are on the GPU!")
54 |
55 | DATA_FOLDER = args.data_folder
56 |
57 | # %% Loading in the Dataset (Test/ split for evaluation, Train/ for training)
58 | dset_test = BraTSDatasetLSTM(
59 |     DATA_FOLDER, train=False, keywords=MODALITY, transform=tr.ToTensor())
60 | test_loader = DataLoader(
61 |     dset_test, batch_size=args.test_batch_size, shuffle=False, num_workers=1)
62 |
63 | dset_train = BraTSDatasetLSTM(
64 |     DATA_FOLDER, train=True, keywords=MODALITY, transform=tr.ToTensor())
65 | train_loader = DataLoader(
66 |     dset_train, batch_size=args.batch_size, shuffle=True, num_workers=1)
67 |
68 |
69 | # %% Loading in the models
70 | unet = UNetSmall()
71 | unet.load_state_dict(torch.load(UNET_MODEL_FILE))
72 | model = BDCLSTM(input_channels=32, hidden_channels=[32])
73 |
74 | if args.cuda:
75 |     unet.cuda()
76 |     model.cuda()
77 |
78 | # Setting Optimizer
79 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.mom)
80 | criterion = DICELossMultiClass()
81 |
82 | # Define Training Loop
83 |
84 |
85 | def train(epoch):
86 |     model.train()
87 |     for batch_idx, (image1, image2, image3, mask) in enumerate(train_loader):
88 |         if args.cuda:
89 |
image1, image2, image3, mask = image1.cuda(), \ 90 | image2.cuda(), \ 91 | image3.cuda(), \ 92 | mask.cuda() 93 | 94 | image1, image2, image3, mask = Variable(image1), \ 95 | Variable(image2), \ 96 | Variable(image3), \ 97 | Variable(mask) 98 | 99 | optimizer.zero_grad() 100 | 101 | map1 = unet(image1, return_features=True) 102 | map2 = unet(image2, return_features=True) 103 | map3 = unet(image3, return_features=True) 104 | 105 | output = model(map1, map2, map3) 106 | loss = criterion(output, mask) 107 | 108 | loss.backward() 109 | optimizer.step() 110 | if batch_idx % args.log_interval == 0: 111 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 112 | epoch, batch_idx * len(image1), len(train_loader.dataset), 113 | 100. * batch_idx / len(train_loader), loss.data[0])) 114 | 115 | 116 | def test(train_accuracy=False): 117 | test_loss = 0 118 | 119 | if train_accuracy == True: 120 | loader = train_loader 121 | else: 122 | loader = test_loader 123 | 124 | for (image1, image2, image3, mask) in loader: 125 | if args.cuda: 126 | image1, image2, image3, mask = image1.cuda(), \ 127 | image2.cuda(), \ 128 | image3.cuda(), \ 129 | mask.cuda() 130 | 131 | image1, image2, image3, mask = Variable(image1, volatile=True), \ 132 | Variable(image2, volatile=True), \ 133 | Variable(image3, volatile=True), \ 134 | Variable(mask, volatile=True) 135 | map1 = unet(image1, return_features=True) 136 | map2 = unet(image2, return_features=True) 137 | map3 = unet(image3, return_features=True) 138 | 139 | # print(image1.type) 140 | # print(map1.type) 141 | 142 | output = model(map1, map2, map3) 143 | test_loss += criterion(output, mask).data[0] 144 | 145 | test_loss /= len(loader) 146 | if train_accuracy: 147 | print( 148 | '\nTraining Set: Average Dice Coefficient: {:.4f}\n'.format(test_loss)) 149 | else: 150 | print( 151 | '\nTest Set: Average Dice Coefficient: {:.4f}\n'.format(test_loss)) 152 | 153 | 154 | if args.train: 155 | for i in range(args.epochs): 156 | train(i) 157 | test() 158 | 159 | torch.save(model.state_dict(), 160 | 'bdclstm-{}-{}-{}'.format(args.batch_size, args.epochs, args.lr)) 161 | else: 162 | model.load_state_dict(torch.load('bdclstm-{}-{}-{}'.format(args.batch_size, 163 | args.epochs, 164 | args.lr))) 165 | test() 166 | test(train_accuracy=True) 167 | -------------------------------------------------------------------------------- /main_small.py: -------------------------------------------------------------------------------- 1 | # %% -*- coding: utf-8 -*- 2 | ''' 3 | Author: Shreyas Padhy 4 | Driver file for Unet and BDC-LSTM Implementation 5 | ''' 6 | 7 | from __future__ import print_function 8 | 9 | import argparse 10 | import matplotlib 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.optim as optim 17 | from torch.autograd import Variable 18 | from torch.utils.data import DataLoader 19 | import torchvision.transforms as tr 20 | 21 | from data import BraTSDatasetUnet, BraTSDatasetLSTM 22 | from losses import DICELoss 23 | from models import UNetSmall 24 | from tqdm import tqdm 25 | import scipy.io as sio 26 | import numpy as np 27 | 28 | # %% import transforms 29 | 30 | # %% Training settings 31 | parser = argparse.ArgumentParser(description='UNet+BDCLSTM for BraTS Dataset') 32 | parser.add_argument('--batch-size', type=int, default=4, metavar='N', 33 | help='input batch size for training (default: 64)') 34 | parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', 35 | 
                    help='input batch size for testing (default: 1000)')
36 | parser.add_argument('--train', action='store_true', default=False,
37 |                     help='Argument to train model (default: False)')
38 | parser.add_argument('--epochs', type=int, default=1, metavar='N',
39 |                     help='number of epochs to train (default: 1)')
40 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
41 |                     help='learning rate (default: 0.001)')
42 | parser.add_argument('--cuda', action='store_true', default=False,
43 |                     help='enables CUDA training (default: False)')
44 | parser.add_argument('--log-interval', type=int, default=1, metavar='N',
45 |                     help='batches to wait before logging training status')
46 | parser.add_argument('--size', type=int, default=128, metavar='N',
47 |                     help='imsize')
48 | parser.add_argument('--load', type=str, default=None, metavar='str',
49 |                     help='weight file to load (default: None)')
50 | parser.add_argument('--data-folder', type=str, default='./Data/', metavar='str',
51 |                     help='folder that contains data (default: ./Data/)')
52 | parser.add_argument('--save', type=str, default='OutMasks', metavar='str',
53 |                     help='Identifier to save npy arrays with')
54 | parser.add_argument('--modality', type=str, default='flair', metavar='str',
55 |                     help='Modality to use for training (default: flair)')
56 | parser.add_argument('--optimizer', type=str, default='ADAM', metavar='str',
57 |                     help='Optimizer (default: ADAM)')
58 | parser.add_argument('--clip', action='store_true', default=False,
59 |                     help='enables gradnorm clip of 1.0 (default: False)')
60 | args = parser.parse_args()
61 | args.cuda = args.cuda and torch.cuda.is_available()
62 |
63 | DATA_FOLDER = args.data_folder
64 |
65 | # %% Loading in the Dataset
66 | dset_train = BraTSDatasetUnet(DATA_FOLDER, train=True,
67 |                               keywords=[args.modality],
68 |                               im_size=[args.size, args.size], transform=tr.ToTensor())
69 |
70 | train_loader = DataLoader(dset_train,
71 |                           batch_size=args.batch_size,
72 |                           shuffle=True, num_workers=1)
73 |
74 | dset_test = BraTSDatasetUnet(DATA_FOLDER, train=False,
75 |                              keywords=[args.modality],
76 |                              im_size=[args.size, args.size], transform=tr.ToTensor())
77 |
78 | test_loader = DataLoader(dset_test,
79 |                          batch_size=args.test_batch_size,
80 |                          shuffle=False, num_workers=1)
81 |
82 |
83 | print("Training Data : ", len(train_loader.dataset))
84 | print("Testing Data : ", len(test_loader.dataset))
85 |
86 | # %% Loading in the model
87 | model = UNetSmall()
88 |
89 | if args.cuda:
90 |     model.cuda()
91 |
92 | if args.optimizer == 'SGD':
93 |     optimizer = optim.SGD(model.parameters(), lr=args.lr,
94 |                           momentum=0.99)
95 | if args.optimizer == 'ADAM':
96 |     optimizer = optim.Adam(model.parameters(), lr=args.lr,
97 |                            betas=(0.9, 0.999))
98 |
99 |
100 | # Defining Loss Function
101 | criterion = DICELoss()
102 | # Define Training Loop
103 |
104 |
105 | def train(epoch, loss_list):
106 |     model.train()
107 |     for batch_idx, (image, mask) in enumerate(train_loader):
108 |         if args.cuda:
109 |             image, mask = image.cuda(), mask.cuda()
110 |
111 |         image, mask = Variable(image), Variable(mask)
112 |
113 |         optimizer.zero_grad()
114 |
115 |         output = model(image)
116 |
117 |         loss = criterion(output, mask)
118 |         loss_list.append(loss.data[0])
119 |
120 |         loss.backward()
121 |         # clip gradients before the update step so the clip takes effect
122 |         if args.clip:
123 |             nn.utils.clip_grad_norm(model.parameters(), max_norm=1)
124 |         optimizer.step()
125 |
126 |         if batch_idx % args.log_interval == 0:
127 |             print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
128 |                 epoch, batch_idx * len(image),
len(train_loader.dataset), 129 | 100. * batch_idx / len(train_loader), loss.data[0])) 130 | 131 | 132 | def test(train_accuracy=False, save_output=False): 133 | test_loss = 0 134 | correct = 0 135 | 136 | if train_accuracy: 137 | loader = train_loader 138 | else: 139 | loader = test_loader 140 | 141 | for batch_idx, (image, mask) in tqdm(enumerate(loader)): 142 | if args.cuda: 143 | image, mask = image.cuda(), mask.cuda() 144 | 145 | image, mask = Variable(image, volatile=True), Variable( 146 | mask, volatile=True) 147 | 148 | output = model(image) 149 | 150 | test_loss += criterion(output, mask).data[0] 151 | 152 | output.data.round_() 153 | 154 | if save_output and (not train_accuracy): 155 | np.save('./npy-files/out-files/{}-unetsmall-batch-{}-outs.npy'.format(args.save, 156 | batch_idx), 157 | output.data.byte().cpu().numpy()) 158 | np.save('./npy-files/out-files/{}-unetsmall--batch-{}-masks.npy'.format(args.save, 159 | batch_idx), 160 | mask.data.byte().cpu().numpy()) 161 | np.save('./npy-files/out-files/{}-unetsmall--batch-{}-images.npy'.format(args.save, 162 | batch_idx), 163 | image.data.float().cpu().numpy()) 164 | 165 | if save_output and train_accuracy: 166 | np.save('./npy-files/out-files/{}-unetsmall-train-batch-{}-outs.npy'.format(args.save, 167 | batch_idx), 168 | output.data.byte().cpu().numpy()) 169 | np.save('./npy-files/out-files/{}-unetsmall-train-batch-{}-masks.npy'.format(args.save, 170 | batch_idx), 171 | mask.data.byte().cpu().numpy()) 172 | np.save('./npy-files/out-files/{}-unetsmall-train-batch-{}-images.npy'.format(args.save, 173 | batch_idx), 174 | image.data.float().cpu().numpy()) 175 | 176 | # Average Dice Coefficient 177 | test_loss /= len(loader) 178 | if train_accuracy: 179 | print('\nTraining Set: Average DICE Coefficient: {:.4f})\n'.format( 180 | test_loss)) 181 | else: 182 | print('\nTest Set: Average DICE Coefficient: {:.4f})\n'.format( 183 | test_loss)) 184 | 185 | 186 | if args.train: 187 | loss_list = [] 188 | for i in tqdm(range(args.epochs)): 189 | train(i, loss_list) 190 | test(train_accuracy=False, save_output=False) 191 | test(train_accuracy=True, save_output=False) 192 | 193 | plt.plot(loss_list) 194 | plt.title("UNetSmall bs={}, ep={}, lr={}".format(args.batch_size, 195 | args.epochs, args.lr)) 196 | plt.xlabel("Number of iterations") 197 | plt.ylabel("Average DICE loss per batch") 198 | plt.savefig("./plots/{}-UNetSmall_Loss_bs={}_ep={}_lr={}.png".format(args.save, 199 | args.batch_size, 200 | args.epochs, 201 | args.lr)) 202 | 203 | np.save('./npy-files/loss-files/{}-UNetSmall_Loss_bs={}_ep={}_lr={}.npy'.format(args.save, 204 | args.batch_size, 205 | args.epochs, 206 | args.lr), 207 | np.asarray(loss_list)) 208 | 209 | torch.save(model.state_dict(), 'unetsmall-final-{}-{}-{}'.format(args.batch_size, 210 | args.epochs, 211 | args.lr)) 212 | else: 213 | model.load_state_dict(torch.load(args.load)) 214 | test(save_output=True) 215 | test(train_accuracy=True) 216 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class UNet(nn.Module): 7 | def __init__(self, num_channels=1, num_classes=2): 8 | super(UNet, self).__init__() 9 | num_feat = [64, 128, 256, 512, 1024] 10 | 11 | self.down1 = nn.Sequential(Conv3x3(num_channels, num_feat[0])) 12 | 13 | self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2), 14 | Conv3x3(num_feat[0], num_feat[1])) 15 
| 16 | self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2), 17 | Conv3x3(num_feat[1], num_feat[2])) 18 | 19 | self.down4 = nn.Sequential(nn.MaxPool2d(kernel_size=2), 20 | Conv3x3(num_feat[2], num_feat[3])) 21 | 22 | self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2), 23 | Conv3x3(num_feat[3], num_feat[4])) 24 | 25 | self.up1 = UpConcat(num_feat[4], num_feat[3]) 26 | self.upconv1 = Conv3x3(num_feat[4], num_feat[3]) 27 | 28 | self.up2 = UpConcat(num_feat[3], num_feat[2]) 29 | self.upconv2 = Conv3x3(num_feat[3], num_feat[2]) 30 | 31 | self.up3 = UpConcat(num_feat[2], num_feat[1]) 32 | self.upconv3 = Conv3x3(num_feat[2], num_feat[1]) 33 | 34 | self.up4 = UpConcat(num_feat[1], num_feat[0]) 35 | self.upconv4 = Conv3x3(num_feat[1], num_feat[0]) 36 | 37 | self.final = nn.Sequential(nn.Conv2d(num_feat[0], 38 | num_classes, 39 | kernel_size=1), 40 | nn.Softmax2d()) 41 | 42 | def forward(self, inputs, return_features=False): 43 | # print(inputs.data.size()) 44 | down1_feat = self.down1(inputs) 45 | # print(down1_feat.size()) 46 | down2_feat = self.down2(down1_feat) 47 | # print(down2_feat.size()) 48 | down3_feat = self.down3(down2_feat) 49 | # print(down3_feat.size()) 50 | down4_feat = self.down4(down3_feat) 51 | # print(down4_feat.size()) 52 | bottom_feat = self.bottom(down4_feat) 53 | 54 | # print(bottom_feat.size()) 55 | up1_feat = self.up1(bottom_feat, down4_feat) 56 | # print(up1_feat.size()) 57 | up1_feat = self.upconv1(up1_feat) 58 | # print(up1_feat.size()) 59 | up2_feat = self.up2(up1_feat, down3_feat) 60 | # print(up2_feat.size()) 61 | up2_feat = self.upconv2(up2_feat) 62 | # print(up2_feat.size()) 63 | up3_feat = self.up3(up2_feat, down2_feat) 64 | # print(up3_feat.size()) 65 | up3_feat = self.upconv3(up3_feat) 66 | # print(up3_feat.size()) 67 | up4_feat = self.up4(up3_feat, down1_feat) 68 | # print(up4_feat.size()) 69 | up4_feat = self.upconv4(up4_feat) 70 | # print(up4_feat.size()) 71 | 72 | if return_features: 73 | outputs = up4_feat 74 | else: 75 | outputs = self.final(up4_feat) 76 | 77 | return outputs 78 | 79 | 80 | class UNetSmall(nn.Module): 81 | def __init__(self, num_channels=1, num_classes=2): 82 | super(UNetSmall, self).__init__() 83 | num_feat = [32, 64, 128, 256] 84 | 85 | self.down1 = nn.Sequential(Conv3x3Small(num_channels, num_feat[0])) 86 | 87 | self.down2 = nn.Sequential(nn.MaxPool2d(kernel_size=2), 88 | nn.BatchNorm2d(num_feat[0]), 89 | Conv3x3Small(num_feat[0], num_feat[1])) 90 | 91 | self.down3 = nn.Sequential(nn.MaxPool2d(kernel_size=2), 92 | nn.BatchNorm2d(num_feat[1]), 93 | Conv3x3Small(num_feat[1], num_feat[2])) 94 | 95 | self.bottom = nn.Sequential(nn.MaxPool2d(kernel_size=2), 96 | nn.BatchNorm2d(num_feat[2]), 97 | Conv3x3Small(num_feat[2], num_feat[3]), 98 | nn.BatchNorm2d(num_feat[3])) 99 | 100 | self.up1 = UpSample(num_feat[3], num_feat[2]) 101 | self.upconv1 = nn.Sequential(Conv3x3Small(num_feat[3] + num_feat[2], num_feat[2]), 102 | nn.BatchNorm2d(num_feat[2])) 103 | 104 | self.up2 = UpSample(num_feat[2], num_feat[1]) 105 | self.upconv2 = nn.Sequential(Conv3x3Small(num_feat[2] + num_feat[1], num_feat[1]), 106 | nn.BatchNorm2d(num_feat[1])) 107 | 108 | self.up3 = UpSample(num_feat[1], num_feat[0]) 109 | self.upconv3 = nn.Sequential(Conv3x3Small(num_feat[1] + num_feat[0], num_feat[0]), 110 | nn.BatchNorm2d(num_feat[0])) 111 | 112 | self.final = nn.Sequential(nn.Conv2d(num_feat[0], 113 | 1, 114 | kernel_size=1), 115 | nn.Sigmoid()) 116 | 117 | def forward(self, inputs, return_features=False): 118 | # print(inputs.data.size()) 119 | down1_feat = 
self.down1(inputs) 120 | # print(down1_feat.size()) 121 | down2_feat = self.down2(down1_feat) 122 | # print(down2_feat.size()) 123 | down3_feat = self.down3(down2_feat) 124 | # print(down3_feat.size()) 125 | bottom_feat = self.bottom(down3_feat) 126 | 127 | # print(bottom_feat.size()) 128 | up1_feat = self.up1(bottom_feat, down3_feat) 129 | # print(up1_feat.size()) 130 | up1_feat = self.upconv1(up1_feat) 131 | # print(up1_feat.size()) 132 | up2_feat = self.up2(up1_feat, down2_feat) 133 | # print(up2_feat.size()) 134 | up2_feat = self.upconv2(up2_feat) 135 | # print(up2_feat.size()) 136 | up3_feat = self.up3(up2_feat, down1_feat) 137 | # print(up3_feat.size()) 138 | up3_feat = self.upconv3(up3_feat) 139 | # print(up3_feat.size()) 140 | 141 | if return_features: 142 | outputs = up3_feat 143 | else: 144 | outputs = self.final(up3_feat) 145 | 146 | return outputs 147 | 148 | 149 | class Conv3x3(nn.Module): 150 | def __init__(self, in_feat, out_feat): 151 | super(Conv3x3, self).__init__() 152 | 153 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat, 154 | kernel_size=3, 155 | stride=1, 156 | padding=1), 157 | nn.BatchNorm2d(out_feat), 158 | nn.ReLU()) 159 | 160 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat, 161 | kernel_size=3, 162 | stride=1, 163 | padding=1), 164 | nn.BatchNorm2d(out_feat), 165 | nn.ReLU()) 166 | 167 | def forward(self, inputs): 168 | outputs = self.conv1(inputs) 169 | outputs = self.conv2(outputs) 170 | return outputs 171 | 172 | 173 | class Conv3x3Drop(nn.Module): 174 | def __init__(self, in_feat, out_feat): 175 | super(Conv3x3Drop, self).__init__() 176 | 177 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat, 178 | kernel_size=3, 179 | stride=1, 180 | padding=1), 181 | nn.Dropout(p=0.2), 182 | nn.ReLU()) 183 | 184 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat, 185 | kernel_size=3, 186 | stride=1, 187 | padding=1), 188 | nn.BatchNorm2d(out_feat), 189 | nn.ReLU()) 190 | 191 | def forward(self, inputs): 192 | outputs = self.conv1(inputs) 193 | outputs = self.conv2(outputs) 194 | return outputs 195 | 196 | 197 | class Conv3x3Small(nn.Module): 198 | def __init__(self, in_feat, out_feat): 199 | super(Conv3x3Small, self).__init__() 200 | 201 | self.conv1 = nn.Sequential(nn.Conv2d(in_feat, out_feat, 202 | kernel_size=3, 203 | stride=1, 204 | padding=1), 205 | nn.ELU(), 206 | nn.Dropout(p=0.2)) 207 | 208 | self.conv2 = nn.Sequential(nn.Conv2d(out_feat, out_feat, 209 | kernel_size=3, 210 | stride=1, 211 | padding=1), 212 | nn.ELU()) 213 | 214 | def forward(self, inputs): 215 | outputs = self.conv1(inputs) 216 | outputs = self.conv2(outputs) 217 | return outputs 218 | 219 | 220 | class UpConcat(nn.Module): 221 | def __init__(self, in_feat, out_feat): 222 | super(UpConcat, self).__init__() 223 | 224 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 225 | 226 | # self.deconv = nn.ConvTranspose2d(in_feat, out_feat, 227 | # kernel_size=3, 228 | # stride=1, 229 | # dilation=1) 230 | 231 | self.deconv = nn.ConvTranspose2d(in_feat, 232 | out_feat, 233 | kernel_size=2, 234 | stride=2) 235 | 236 | def forward(self, inputs, down_outputs): 237 | # TODO: Upsampling required after deconv? 
238 | # outputs = self.up(inputs) 239 | outputs = self.deconv(inputs) 240 | out = torch.cat([down_outputs, outputs], 1) 241 | return out 242 | 243 | 244 | class UpSample(nn.Module): 245 | def __init__(self, in_feat, out_feat): 246 | super(UpSample, self).__init__() 247 | 248 | self.up = nn.Upsample(scale_factor=2, mode='nearest') 249 | 250 | self.deconv = nn.ConvTranspose2d(in_feat, 251 | out_feat, 252 | kernel_size=2, 253 | stride=2) 254 | 255 | def forward(self, inputs, down_outputs): 256 | # TODO: Upsampling required after deconv? 257 | outputs = self.up(inputs) 258 | # outputs = self.deconv(inputs) 259 | out = torch.cat([outputs, down_outputs], 1) 260 | return out 261 | -------------------------------------------------------------------------------- /plot_ims.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import scipy.io as sio 4 | # outs = np.load('Numpy-batch-7-outs.npy') 5 | # masks = np.load('Numpy-batch-7-masks.npy') 6 | 7 | final_outs = [] 8 | final_masks = [] 9 | final_images = [] 10 | 11 | loss = np.load('./OutMasks-UNet_Loss_bs=4_ep=1_lr=0.001.npy') 12 | 13 | for i in range(38): 14 | outs = np.load( 15 | './npy-files/out-files/UNetSmall-NonZero-unetsmall-batch-{}-outs.npy'.format(i)) 16 | masks = np.load( 17 | './npy-files/out-files/UNetSmall-NonZero-unetsmall--batch-{}-masks.npy'.format(i)) 18 | images = np.load( 19 | './npy-files/out-files/UNetSmall-NonZero-unetsmall--batch-{}-images.npy'.format(i)) 20 | 21 | final_outs.append(outs) 22 | final_masks.append(masks) 23 | final_images.append(images) 24 | final_outs = np.asarray(final_outs) 25 | final_masks = np.asarray(final_masks) 26 | final_images = np.asarray(final_images) 27 | 28 | print(final_images[0].shape) 29 | for i in range(38): 30 | print(final_images[i].shape) 31 | plt.imshow(np.squeeze(final_images[i][49, :, :]), cmap='gray') 32 | plt.show() 33 | 34 | 35 | # print(final_outs.shape) 36 | # print(final_masks.shape) 37 | 38 | # sio.savemat('./mat-files/final_outputs.mat', {'data': final_outs}) 39 | # sio.savemat('./mat-files/final_masks.mat', {'data': final_masks}) 40 | # sio.savemat('./mat-files/final_images.mat', {'data': final_images}) 41 | # for i in range(len(outs)): 42 | # plt1 = 255 * np.squeeze(outs[i, :, :, :]).astype('uint8') 43 | # plt2 = 255 * np.squeeze(masks[i, :, :, :]).astype('uint8') 44 | # print(plt1, plt2) 45 | # plt.subplot(1, 2, 1) 46 | # plt.imshow(plt1, cmap='gray') 47 | # plt.title("UNet Out") 48 | # plt.subplot(1, 2, 2) 49 | # plt.imshow(plt2, cmap='gray') 50 | # plt.title("Mask") 51 | -------------------------------------------------------------------------------- /result_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/result_comparison.png -------------------------------------------------------------------------------- /smallunet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/smallunet.png -------------------------------------------------------------------------------- /unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shreyaspadhy/UNet-Zoo/294b890d125e70e78cabe9d773a33b78f65d25f1/unet.png 
--------------------------------------------------------------------------------