├── .gitignore ├── .ipynb_checkpoints └── Run_ReLayNet-checkpoint.ipynb ├── LICENSE ├── README.md ├── Run_ReLayNet.ipynb ├── build └── lib │ └── relaynet_pytorch │ ├── __init__.py │ ├── data_utils.py │ ├── net_api │ ├── __init__.py │ ├── losses.py │ └── sub_module.py │ ├── relay_net.py │ └── solver.py ├── datasets └── .gitignore ├── dist ├── relaynet-pytorch-1.1.tar.gz └── relaynet_pytorch-1.1-py3-none-any.whl ├── models └── Exp01 │ ├── relaynet_epoch1.model │ ├── relaynet_epoch10.model │ ├── relaynet_epoch11.model │ ├── relaynet_epoch12.model │ ├── relaynet_epoch13.model │ ├── relaynet_epoch14.model │ ├── relaynet_epoch15.model │ ├── relaynet_epoch16.model │ ├── relaynet_epoch17.model │ ├── relaynet_epoch18.model │ ├── relaynet_epoch19.model │ ├── relaynet_epoch2.model │ ├── relaynet_epoch20.model │ ├── relaynet_epoch3.model │ ├── relaynet_epoch4.model │ ├── relaynet_epoch5.model │ ├── relaynet_epoch6.model │ ├── relaynet_epoch7.model │ ├── relaynet_epoch8.model │ └── relaynet_epoch9.model ├── relaynet_pytorch.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt └── top_level.txt ├── relaynet_pytorch ├── .gitignore ├── __init__.py ├── data_utils.py ├── net_api │ ├── __init__.py │ ├── losses.py │ └── sub_module.py ├── relay_net.py └── solver.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | networks/classifiers/__pycache__/ 3 | 4 | *.pyc 5 | 6 | \.idea/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Abhijit Guha Roy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # relaynet_pytorch 2 | 3 | PyTorch implementation of ReLayNet. There are still some bugs and issues in the code; we are working on fixing them. 4 | 5 | Coded by Abhijit Guha Roy and Shayan Siddiqui (https://github.com/shayansiddiqui) 6 | 7 | If you use this code for any academic purpose, please cite: 8 | 9 | A. Guha Roy, S. Conjeti, S. P. K. Karri, D. Sheet, A. Katouzian, C. Wachinger, and N. Navab, "ReLayNet: retinal layer and fluid segmentation of macular optical coherence tomography using fully convolutional networks," Biomed. Opt. Express 8, 3627-3642 (2017) 10 | Link: https://arxiv.org/abs/1704.02161 11 | 12 | Enjoy!!
:) 13 | -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/build/lib/relaynet_pytorch/__init__.py -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/data_utils.py: -------------------------------------------------------------------------------- 1 | """Data utility functions.""" 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | import h5py 8 | 9 | 10 | class ImdbData(data.Dataset): 11 | def __init__(self, X, y, w): 12 | self.X = X 13 | self.y = y 14 | self.w = w 15 | 16 | def __getitem__(self, index): 17 | img = self.X[index] 18 | label = self.y[index] 19 | weight = self.w[index] 20 | 21 | img = torch.from_numpy(img) 22 | label = torch.from_numpy(label) 23 | weight = torch.from_numpy(weight) 24 | return img, label, weight 25 | 26 | def __len__(self): 27 | return len(self.y) 28 | 29 | 30 | def get_imdb_data(): 31 | # TODO: Need to change later 32 | NumClass = 9 33 | 34 | # Load DATA 35 | Data = h5py.File('datasets/Data.h5', 'r') 36 | a_group_key = list(Data.keys())[0] 37 | Data = list(Data[a_group_key]) 38 | Data = np.squeeze(np.asarray(Data)) 39 | Label = h5py.File('datasets/label.h5', 'r') 40 | a_group_key = list(Label.keys())[0] 41 | Label = list(Label[a_group_key]) 42 | Label = np.squeeze(np.asarray(Label)) 43 | set = h5py.File('datasets/set.h5', 'r') 44 | a_group_key = list(set.keys())[0] 45 | set = list(set[a_group_key]) 46 | set = np.squeeze(np.asarray(set)) 47 | sz = Data.shape 48 | Data = Data.reshape([sz[0], 1, sz[1], sz[2]]) 49 | Data = Data[:, :, 61:573, :] 50 | weights = Label[:, 1, 61:573, :] 51 | Label = Label[:, 0, 61:573, :] 52 | sz = Label.shape 53 | Label = Label.reshape([sz[0], 1, sz[1], sz[2]]) 
54 | weights = weights.reshape([sz[0], 1, sz[1], sz[2]]) 55 | train_id = set == 1 56 | test_id = set == 3 57 | 58 | Tr_Dat = Data[train_id, :, :, :] 59 | Tr_Label = np.squeeze(Label[train_id, :, :, :]) - 1 # Index from [0-(NumClass-1)] 60 | Tr_weights = weights[train_id, :, :, :] 61 | Tr_weights = np.tile(Tr_weights, [1, NumClass, 1, 1]) 62 | 63 | Te_Dat = Data[test_id, :, :, :] 64 | Te_Label = np.squeeze(Label[test_id, :, :, :]) - 1 65 | Te_weights = weights[test_id, :, :, :] 66 | Te_weights = np.tile(Te_weights, [1, NumClass, 1, 1]) 67 | 68 | 69 | 70 | return (ImdbData(Tr_Dat, Tr_Label, Tr_weights), 71 | ImdbData(Te_Dat, Te_Label, Te_weights)) 72 | -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/net_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/build/lib/relaynet_pytorch/net_api/__init__.py -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/net_api/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn.modules.loss import _Loss 4 | from torch.autograd import Function, Variable 5 | import torch.nn as nn 6 | import torch 7 | import numpy as np 8 | from torch.nn.modules.loss import _Loss 9 | from torch.autograd import Function, Variable 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class DiceCoeff(nn.Module): 15 | """Dice coeff for individual examples""" 16 | 17 | def __init__(self): 18 | super(DiceCoeff, self).__init__() 19 | 20 | def forward(self, input, target): 21 | inter = torch.dot(input, target) + 0.0001 22 | union = torch.sum(input ** 2) + torch.sum(target ** 2) + 0.0001 23 | 24 | t = 2 * inter.float() / union.float() 25 | return t 26 | 27 | 28 | def dice_coeff(input, 
target): 29 | """Dice coeff for batches""" 30 | if input.is_cuda: 31 | s = Variable(torch.FloatTensor(1).cuda().zero_()) 32 | else: 33 | s = Variable(torch.FloatTensor(1).zero_()) 34 | 35 | for i, c in enumerate(zip(input, target)): 36 | s = s + DiceCoeff().forward(c[0], c[1]) 37 | 38 | return s / (i + 1) 39 | 40 | 41 | class DiceLoss(_Loss): 42 | def forward(self, output, target, weights=None, ignore_index=None): 43 | """ 44 | output : NxCxHxW Variable 45 | target : NxHxW LongTensor 46 | weights : C FloatTensor 47 | ignore_index : int index to ignore from loss 48 | """ 49 | eps = 0.0001 50 | 51 | output = output.exp() 52 | encoded_target = output.detach() * 0 53 | if ignore_index is not None: 54 | mask = target == ignore_index 55 | target = target.clone() 56 | target[mask] = 0 57 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 58 | mask = mask.unsqueeze(1).expand_as(encoded_target) 59 | encoded_target[mask] = 0 60 | else: 61 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 62 | 63 | if weights is None: 64 | weights = 1 65 | 66 | intersection = output * encoded_target 67 | numerator = 2 * intersection.sum(0).sum(1).sum(1) 68 | denominator = output + encoded_target 69 | 70 | if ignore_index is not None: 71 | denominator[mask] = 0 72 | denominator = denominator.sum(0).sum(1).sum(1) + eps 73 | loss_per_channel = weights * (1 - (numerator / denominator)) 74 | 75 | return loss_per_channel.sum() / output.size(1) 76 | 77 | 78 | class CrossEntropyLoss2d(nn.Module): 79 | def __init__(self, weight=None, size_average=True): 80 | super(CrossEntropyLoss2d, self).__init__() 81 | self.nll_loss = nn.CrossEntropyLoss(weight, size_average) 82 | 83 | def forward(self, inputs, targets): 84 | return self.nll_loss(inputs, targets) 85 | 86 | 87 | class CombinedLoss(nn.Module): 88 | def __init__(self): 89 | super(CombinedLoss, self).__init__() 90 | self.cross_entropy_loss = CrossEntropyLoss2d() 91 | self.dice_loss = DiceLoss() 92 | 93 | def forward(self, input, target, weight): 
94 | # TODO: why? 95 | target = target.type(torch.LongTensor).cuda() 96 | input_soft = F.softmax(input,dim=1) 97 | y2 = torch.mean(self.dice_loss(input_soft, target)) 98 | y1 = torch.mean(torch.mul(self.cross_entropy_loss.forward(input, target), weight)) 99 | y = y1 + y2 100 | return y 101 | 102 | -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/net_api/sub_module.py: -------------------------------------------------------------------------------- 1 | # List of APIs 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | ''' 8 | param ={ 9 | 'num_channels':1, 10 | 'num_filters':64, 11 | 'kernel_h':7, 12 | 'kernel_w':3, 13 | 'stride_conv':1, 14 | 'pool':2, 15 | 'stride_pool':2, 16 | 'num_classes':10 17 | } 18 | 19 | ''' 20 | 21 | def __init__(self, params): 22 | super(BasicBlock, self).__init__() 23 | 24 | padding_h = int((params['kernel_h'] - 1) / 2) 25 | padding_w = int((params['kernel_w'] - 1) / 2) 26 | 27 | self.conv = nn.Conv2d(in_channels=params['num_channels'], out_channels=params['num_filters'], 28 | kernel_size=(params['kernel_h'], params['kernel_w']), 29 | padding=(padding_h, padding_w), 30 | stride=params['stride_conv']) 31 | self.batchnorm = nn.BatchNorm2d(num_features=params['num_filters']) 32 | self.prelu = nn.PReLU() 33 | 34 | def forward(self, input): 35 | out_conv = self.conv(input) 36 | out_bn = self.batchnorm(out_conv) 37 | out_prelu = self.prelu(out_bn) 38 | return out_prelu 39 | 40 | 41 | class EncoderBlock(BasicBlock): 42 | def __init__(self, params): 43 | super(EncoderBlock, self).__init__(params) 44 | self.maxpool = nn.MaxPool2d(kernel_size=params['pool'], stride=params['stride_pool'], return_indices=True) 45 | 46 | def forward(self, input): 47 | out_block = super(EncoderBlock, self).forward(input) 48 | out_encoder, indices = self.maxpool(out_block) 49 | return out_encoder, out_block, indices 50 | 51 | 52 | class DecoderBlock(BasicBlock): 53 | def 
__init__(self, params): 54 | super(DecoderBlock, self).__init__(params) 55 | self.unpool = nn.MaxUnpool2d(kernel_size=params['pool'], stride=params['stride_pool']) 56 | 57 | def forward(self, input, out_block, indices): 58 | unpool = self.unpool(input, indices) 59 | concat = torch.cat((out_block, unpool), dim=1) 60 | out_block = super(DecoderBlock, self).forward(concat) 61 | 62 | return out_block 63 | 64 | 65 | class ClassifierBlock(nn.Module): 66 | def __init__(self, params): 67 | super(ClassifierBlock, self).__init__() 68 | self.conv = nn.Conv2d(params['num_channels'], params['num_class'], params['kernel_c'], params['stride_conv']) 69 | self.softmax = nn.Softmax2d() 70 | 71 | def forward(self, input): 72 | out_conv = self.conv(input) 73 | #out_logit = self.softmax(out_conv) 74 | return out_conv 75 | -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/relay_net.py: -------------------------------------------------------------------------------- 1 | """ClassificationCNN""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from relaynet_pytorch.net_api import sub_module as sm 6 | 7 | 8 | class ReLayNet(nn.Module): 9 | """ 10 | A PyTorch implementation of ReLayNet 11 | Coded by Shayan and Abhijit 12 | 13 | param ={ 14 | 'num_channels':1, 15 | 'num_filters':64, 16 | 'num_channels':64, 17 | 'kernel_h':7, 18 | 'kernel_w':3, 19 | 'stride_conv':1, 20 | 'pool':2, 21 | 'stride_pool':2, 22 | 'num_classes':10 23 | } 24 | 25 | """ 26 | 27 | def __init__(self, params): 28 | super(ReLayNet, self).__init__() 29 | 30 | self.encode1 = sm.EncoderBlock(params) 31 | params['num_channels'] = 64 32 | self.encode2 = sm.EncoderBlock(params) 33 | # params['num_channels'] = 64 # This can be used to change the numchannels for each block 34 | self.encode3 = sm.EncoderBlock(params) 35 | self.bottleneck = sm.BasicBlock(params) 36 | params['num_channels'] = 128 37 | self.decode1 = sm.DecoderBlock(params) 38 | self.decode2 = 
sm.DecoderBlock(params) 39 | self.decode3 = sm.DecoderBlock(params) 40 | params['num_channels'] = 64 41 | self.classifier = sm.ClassifierBlock(params) 42 | 43 | def forward(self, input): 44 | e1, out1, ind1 = self.encode1.forward(input) 45 | e2, out2, ind2 = self.encode2.forward(e1) 46 | e3, out3, ind3 = self.encode3.forward(e2) 47 | bn = self.bottleneck.forward(e3) 48 | 49 | d3 = self.decode1.forward(bn, out3, ind3) 50 | d2 = self.decode2.forward(d3, out2, ind2) 51 | d1 = self.decode3.forward(d2, out1, ind1) 52 | prob = self.classifier.forward(d1) 53 | 54 | return prob 55 | 56 | @property 57 | def is_cuda(self): 58 | """ 59 | Check if model parameters are allocated on the GPU. 60 | """ 61 | return next(self.parameters()).is_cuda 62 | 63 | def save(self, path): 64 | """ 65 | Save model with its parameters to the given path. Conventionally the 66 | path should end with "*.model". 67 | 68 | Inputs: 69 | - path: path string 70 | """ 71 | print('Saving model... %s' % path) 72 | torch.save(self, path) 73 | -------------------------------------------------------------------------------- /build/lib/relaynet_pytorch/solver.py: -------------------------------------------------------------------------------- 1 | from random import shuffle 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch 5 | from torch.autograd import Variable 6 | from relaynet_pytorch.net_api.losses import CombinedLoss 7 | from torch.optim import lr_scheduler 8 | import os 9 | 10 | 11 | def per_class_dice(y_pred, y_true, num_class): 12 | avg_dice = 0 13 | y_pred = y_pred.data.cpu().numpy() 14 | y_true = y_true.data.cpu().numpy() 15 | for i in range(num_class): 16 | GT = y_true == (i + 1) 17 | Pred = y_pred == (i + 1) 18 | inter = np.sum(np.matmul(GT, Pred)) + 0.0001 19 | union = np.sum(GT) + np.sum(Pred) + 0.0001 20 | t = 2 * inter / union 21 | avg_dice = avg_dice + (t / num_class) 22 | return avg_dice 23 | 24 | 25 | def create_exp_directory(exp_dir_name): 26 | if not 
os.path.exists('models/' + exp_dir_name): 27 | os.makedirs('models/' + exp_dir_name) 28 | 29 | 30 | class Solver(object): 31 | # global optimiser parameters 32 | default_optim_args = {"lr": 1e-2, 33 | "betas": (0.9, 0.999), 34 | "eps": 1e-8, 35 | "weight_decay": 0.0001} 36 | gamma = 0.5 37 | step_size = 5 38 | NumClass = 9 39 | 40 | def __init__(self, optim=torch.optim.Adam, optim_args={}, 41 | loss_func=CombinedLoss()): 42 | optim_args_merged = self.default_optim_args.copy() 43 | optim_args_merged.update(optim_args) 44 | self.optim_args = optim_args_merged 45 | self.optim = optim 46 | self.loss_func = loss_func 47 | 48 | self._reset_histories() 49 | 50 | def _reset_histories(self): 51 | """ 52 | Resets train and val histories for the accuracy and the loss. 53 | """ 54 | self.train_loss_history = [] 55 | self.train_acc_history = [] 56 | self.val_acc_history = [] 57 | self.val_loss_history = [] 58 | 59 | def train(self, model, train_loader, val_loader, num_epochs=10, log_nth=5, exp_dir_name='exp_default'): 60 | """ 61 | Train a given model with the provided data. 
62 | 63 | Inputs: 64 | - model: model object initialized from a torch.nn.Module 65 | - train_loader: train data in torch.utils.data.DataLoader 66 | - val_loader: val data in torch.utils.data.DataLoader 67 | - num_epochs: total number of training epochs 68 | - log_nth: log training accuracy and loss every nth iteration 69 | """ 70 | optim = self.optim(model.parameters(), **self.optim_args) 71 | scheduler = lr_scheduler.StepLR(optim, step_size=self.step_size, 72 | gamma=self.gamma) # decay LR by a factor of 0.5 every 5 epochs 73 | 74 | self._reset_histories() 75 | iter_per_epoch = 1 76 | # iter_per_epoch = len(train_loader) 77 | 78 | if torch.cuda.is_available(): 79 | model.cuda() 80 | 81 | print('START TRAIN.') 82 | curr_iter = 0 83 | 84 | create_exp_directory(exp_dir_name) 85 | 86 | for epoch in range(num_epochs): 87 | scheduler.step() 88 | for i_batch, sample_batched in enumerate(train_loader): 89 | X = Variable(sample_batched[0]) 90 | y = Variable(sample_batched[1]) 91 | w = Variable(sample_batched[2]) 92 | 93 | if model.is_cuda: 94 | X, y, w = X.cuda(), y.cuda(), w.cuda() 95 | 96 | for iter in range(iter_per_epoch): 97 | curr_iter += iter 98 | optim.zero_grad() 99 | output = model(X) 100 | loss = self.loss_func(output, y, w) 101 | loss.backward() 102 | optim.step() 103 | if iter % log_nth == 0: 104 | self.train_loss_history.append(loss.data[0]) 105 | #print('[Iteration : ' + str(iter) + '/' + str(iter_per_epoch * num_epochs) + '] : ' + str( 106 | # loss.data[0])) 107 | 108 | 109 | #_, batch_output = torch.max(F.softmax(model(X),dim=1), dim=1) 110 | #avg_dice = per_class_dice(batch_output, y, self.NumClass) 111 | #print('Per class average dice score is ' + str(avg_dice)) 112 | # self.train_acc_history.append(train_accuracy) 113 | # 114 | # val_output = torch.max(model(Variable(torch.from_numpy(val_loader.dataset.X))), dim= 1) 115 | # val_accuracy = self.accuracy(val_output[1], Variable(torch.from_numpy(val_loader.dataset.y))) 116 | # 
self.val_acc_history.append(val_accuracy) 117 | print('[Epoch : ' + str(epoch) + '/' + str(num_epochs) + '] : ' + str(loss.data[0])) 118 | model.save('models/' + exp_dir_name + '/relaynet_epoch' + str(epoch + 1) + '.model') 119 | print('FINISH.') 120 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | *.p 2 | /segmentation_data 3 | /segmentation_data_test 4 | seg.py -------------------------------------------------------------------------------- /dist/relaynet-pytorch-1.1.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/dist/relaynet-pytorch-1.1.tar.gz -------------------------------------------------------------------------------- /dist/relaynet_pytorch-1.1-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/dist/relaynet_pytorch-1.1-py3-none-any.whl -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch1.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch1.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch10.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch10.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch11.model: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch11.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch12.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch12.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch13.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch13.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch14.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch14.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch15.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch15.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch16.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch16.model -------------------------------------------------------------------------------- 
/models/Exp01/relaynet_epoch17.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch17.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch18.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch18.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch19.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch19.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch2.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch2.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch20.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch20.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch3.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch3.model 
-------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch4.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch4.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch5.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch5.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch6.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch6.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch7.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch7.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch8.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch8.model -------------------------------------------------------------------------------- /models/Exp01/relaynet_epoch9.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/models/Exp01/relaynet_epoch9.model 
-------------------------------------------------------------------------------- /relaynet_pytorch.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: relaynet-pytorch 3 | Version: 1.1 4 | Summary: Retinal Layer and Fluid Segmentation of Macular Optical Coherence Tomography using Fully Convolutional Networks. 5 | Home-page: https://github.com/abhi4ssj/relaynet_pytorch 6 | Author: Abhijit Guha Roy 7 | Author-email: abhi4ssj@gmail.com 8 | License: UNKNOWN 9 | Description-Content-Type: UNKNOWN 10 | Description: UNKNOWN 11 | Platform: UNKNOWN 12 | -------------------------------------------------------------------------------- /relaynet_pytorch.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | README.md 2 | setup.py 3 | relaynet_pytorch/__init__.py 4 | relaynet_pytorch/data_utils.py 5 | relaynet_pytorch/relay_net.py 6 | relaynet_pytorch/solver.py 7 | relaynet_pytorch.egg-info/PKG-INFO 8 | relaynet_pytorch.egg-info/SOURCES.txt 9 | relaynet_pytorch.egg-info/dependency_links.txt 10 | relaynet_pytorch.egg-info/top_level.txt 11 | relaynet_pytorch/net_api/__init__.py 12 | relaynet_pytorch/net_api/losses.py 13 | relaynet_pytorch/net_api/sub_module.py -------------------------------------------------------------------------------- /relaynet_pytorch.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /relaynet_pytorch.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | relaynet_pytorch 2 | -------------------------------------------------------------------------------- /relaynet_pytorch/.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | im2col_cython.c 3 | im2col_cython.so 4 | 
-------------------------------------------------------------------------------- /relaynet_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/relaynet_pytorch/__init__.py -------------------------------------------------------------------------------- /relaynet_pytorch/data_utils.py: -------------------------------------------------------------------------------- 1 | """Data utility functions.""" 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | import h5py 8 | 9 | 10 | class ImdbData(data.Dataset): 11 | def __init__(self, X, y, w): 12 | self.X = X 13 | self.y = y 14 | self.w = w 15 | 16 | def __getitem__(self, index): 17 | img = self.X[index] 18 | label = self.y[index] 19 | weight = self.w[index] 20 | 21 | img = torch.from_numpy(img) 22 | label = torch.from_numpy(label) 23 | weight = torch.from_numpy(weight) 24 | return img, label, weight 25 | 26 | def __len__(self): 27 | return len(self.y) 28 | 29 | 30 | def get_imdb_data(): 31 | # TODO: Need to change later 32 | NumClass = 9 33 | 34 | # Load DATA 35 | Data = h5py.File('datasets/Data.h5', 'r') 36 | a_group_key = list(Data.keys())[0] 37 | Data = list(Data[a_group_key]) 38 | Data = np.squeeze(np.asarray(Data)) 39 | Label = h5py.File('datasets/label.h5', 'r') 40 | a_group_key = list(Label.keys())[0] 41 | Label = list(Label[a_group_key]) 42 | Label = np.squeeze(np.asarray(Label)) 43 | set = h5py.File('datasets/set.h5', 'r') 44 | a_group_key = list(set.keys())[0] 45 | set = list(set[a_group_key]) 46 | set = np.squeeze(np.asarray(set)) 47 | sz = Data.shape 48 | Data = Data.reshape([sz[0], 1, sz[1], sz[2]]) 49 | Data = Data[:, :, 61:573, :] 50 | weights = Label[:, 1, 61:573, :] 51 | Label = Label[:, 0, 61:573, :] 52 | sz = Label.shape 53 | Label = Label.reshape([sz[0], 1, sz[1], sz[2]]) 54 | weights = weights.reshape([sz[0], 
1, sz[1], sz[2]]) 55 | train_id = set == 1 56 | test_id = set == 3 57 | 58 | Tr_Dat = Data[train_id, :, :, :] 59 | Tr_Label = np.squeeze(Label[train_id, :, :, :]) - 1 # Index from [0-(NumClass-1)] 60 | Tr_weights = weights[train_id, :, :, :] 61 | Tr_weights = np.tile(Tr_weights, [1, NumClass, 1, 1]) 62 | 63 | Te_Dat = Data[test_id, :, :, :] 64 | Te_Label = np.squeeze(Label[test_id, :, :, :]) - 1 65 | Te_weights = weights[test_id, :, :, :] 66 | Te_weights = np.tile(Te_weights, [1, NumClass, 1, 1]) 67 | 68 | 69 | 70 | return (ImdbData(Tr_Dat, Tr_Label, Tr_weights), 71 | ImdbData(Te_Dat, Te_Label, Te_weights)) 72 | -------------------------------------------------------------------------------- /relaynet_pytorch/net_api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ai-med/relaynet_pytorch/40ae1aa56e426da14ddca37e06c2f31966febea5/relaynet_pytorch/net_api/__init__.py -------------------------------------------------------------------------------- /relaynet_pytorch/net_api/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn.modules.loss import _Loss 4 | from torch.autograd import Function, Variable 5 | import torch.nn as nn 6 | import torch 7 | import numpy as np 8 | from torch.nn.modules.loss import _Loss 9 | from torch.autograd import Function, Variable 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class DiceCoeff(nn.Module): 15 | """Dice coeff for individual examples""" 16 | 17 | def __init__(self): 18 | super(DiceCoeff, self).__init__() 19 | 20 | def forward(self, input, target): 21 | inter = torch.dot(input, target) + 0.0001 22 | union = torch.sum(input ** 2) + torch.sum(target ** 2) + 0.0001 23 | 24 | t = 2 * inter.float() / union.float() 25 | return t 26 | 27 | 28 | def dice_coeff(input, target): 29 | """Dice coeff for batches""" 30 | if input.is_cuda: 31 
| s = Variable(torch.FloatTensor(1).cuda().zero_()) 32 | else: 33 | s = Variable(torch.FloatTensor(1).zero_()) 34 | 35 | for i, c in enumerate(zip(input, target)): 36 | s = s + DiceCoeff().forward(c[0], c[1]) 37 | 38 | return s / (i + 1) 39 | 40 | 41 | class DiceLoss(_Loss): 42 | def forward(self, output, target, weights=None, ignore_index=None): 43 | """ 44 | output : NxCxHxW Variable 45 | target : NxHxW LongTensor 46 | weights : C FloatTensor 47 | ignore_index : int index to ignore from loss 48 | """ 49 | eps = 0.0001 50 | 51 | output = output.exp() 52 | encoded_target = output.detach() * 0 53 | if ignore_index is not None: 54 | mask = target == ignore_index 55 | target = target.clone() 56 | target[mask] = 0 57 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 58 | mask = mask.unsqueeze(1).expand_as(encoded_target) 59 | encoded_target[mask] = 0 60 | else: 61 | encoded_target.scatter_(1, target.unsqueeze(1), 1) 62 | 63 | if weights is None: 64 | weights = 1 65 | 66 | intersection = output * encoded_target 67 | numerator = 2 * intersection.sum(0).sum(1).sum(1) 68 | denominator = output + encoded_target 69 | 70 | if ignore_index is not None: 71 | denominator[mask] = 0 72 | denominator = denominator.sum(0).sum(1).sum(1) + eps 73 | loss_per_channel = weights * (1 - (numerator / denominator)) 74 | 75 | return loss_per_channel.sum() / output.size(1) 76 | 77 | 78 | class CrossEntropyLoss2d(nn.Module): 79 | def __init__(self, weight=None, size_average=True): 80 | super(CrossEntropyLoss2d, self).__init__() 81 | self.nll_loss = nn.CrossEntropyLoss(weight, size_average) 82 | 83 | def forward(self, inputs, targets): 84 | return self.nll_loss(inputs, targets) 85 | 86 | 87 | class CombinedLoss(nn.Module): 88 | def __init__(self): 89 | super(CombinedLoss, self).__init__() 90 | self.cross_entropy_loss = CrossEntropyLoss2d() 91 | self.dice_loss = DiceLoss() 92 | 93 | def forward(self, input, target, weight): 94 | # TODO: why? 
95 | target = target.type(torch.LongTensor).cuda() 96 | input_soft = F.softmax(input,dim=1) 97 | y2 = torch.mean(self.dice_loss(input_soft, target)) 98 | y1 = torch.mean(torch.mul(self.cross_entropy_loss.forward(input, target), weight)) 99 | y = y1 + y2 100 | return y 101 | 102 | -------------------------------------------------------------------------------- /relaynet_pytorch/net_api/sub_module.py: -------------------------------------------------------------------------------- 1 | # List of APIs 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | ''' 8 | param ={ 9 | 'num_channels':1, 10 | 'num_filters':64, 11 | 'kernel_h':7, 12 | 'kernel_w':3, 13 | 'stride_conv':1, 14 | 'pool':2, 15 | 'stride_pool':2, 16 | 'num_classes':10 17 | } 18 | 19 | ''' 20 | 21 | def __init__(self, params): 22 | super(BasicBlock, self).__init__() 23 | 24 | padding_h = int((params['kernel_h'] - 1) / 2) 25 | padding_w = int((params['kernel_w'] - 1) / 2) 26 | 27 | self.conv = nn.Conv2d(in_channels=params['num_channels'], out_channels=params['num_filters'], 28 | kernel_size=(params['kernel_h'], params['kernel_w']), 29 | padding=(padding_h, padding_w), 30 | stride=params['stride_conv']) 31 | self.batchnorm = nn.BatchNorm2d(num_features=params['num_filters']) 32 | self.prelu = nn.PReLU() 33 | 34 | def forward(self, input): 35 | out_conv = self.conv(input) 36 | out_bn = self.batchnorm(out_conv) 37 | out_prelu = self.prelu(out_bn) 38 | return out_prelu 39 | 40 | 41 | class EncoderBlock(BasicBlock): 42 | def __init__(self, params): 43 | super(EncoderBlock, self).__init__(params) 44 | self.maxpool = nn.MaxPool2d(kernel_size=params['pool'], stride=params['stride_pool'], return_indices=True) 45 | 46 | def forward(self, input): 47 | out_block = super(EncoderBlock, self).forward(input) 48 | out_encoder, indices = self.maxpool(out_block) 49 | return out_encoder, out_block, indices 50 | 51 | 52 | class DecoderBlock(BasicBlock): 53 | def __init__(self, params): 54 | 
super(DecoderBlock, self).__init__(params) 55 | self.unpool = nn.MaxUnpool2d(kernel_size=params['pool'], stride=params['stride_pool']) 56 | 57 | def forward(self, input, out_block, indices): 58 | unpool = self.unpool(input, indices) 59 | concat = torch.cat((out_block, unpool), dim=1) 60 | out_block = super(DecoderBlock, self).forward(concat) 61 | 62 | return out_block 63 | 64 | 65 | class ClassifierBlock(nn.Module): 66 | def __init__(self, params): 67 | super(ClassifierBlock, self).__init__() 68 | self.conv = nn.Conv2d(params['num_channels'], params['num_class'], params['kernel_c'], params['stride_conv']) 69 | self.softmax = nn.Softmax2d() 70 | 71 | def forward(self, input): 72 | out_conv = self.conv(input) 73 | #out_logit = self.softmax(out_conv) 74 | return out_conv 75 | -------------------------------------------------------------------------------- /relaynet_pytorch/relay_net.py: -------------------------------------------------------------------------------- 1 | """ClassificationCNN""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from relaynet_pytorch.net_api import sub_module as sm 6 | 7 | 8 | class ReLayNet(nn.Module): 9 | """ 10 | A PyTorch implementation of ReLayNet 11 | Coded by Shayan and Abhijit 12 | 13 | param ={ 14 | 'num_channels':1, 15 | 'num_filters':64, 16 | 'num_channels':64, 17 | 'kernel_h':7, 18 | 'kernel_w':3, 19 | 'stride_conv':1, 20 | 'pool':2, 21 | 'stride_pool':2, 22 | 'num_classes':10 23 | } 24 | 25 | """ 26 | 27 | def __init__(self, params): 28 | super(ReLayNet, self).__init__() 29 | 30 | self.encode1 = sm.EncoderBlock(params) 31 | params['num_channels'] = 64 32 | self.encode2 = sm.EncoderBlock(params) 33 | # params['num_channels'] = 64 # This can be used to change the numchannels for each block 34 | self.encode3 = sm.EncoderBlock(params) 35 | self.bottleneck = sm.BasicBlock(params) 36 | params['num_channels'] = 128 37 | self.decode1 = sm.DecoderBlock(params) 38 | self.decode2 = sm.DecoderBlock(params) 39 | self.decode3 = 
sm.DecoderBlock(params) 40 | params['num_channels'] = 64 41 | self.classifier = sm.ClassifierBlock(params) 42 | 43 | def forward(self, input): 44 | e1, out1, ind1 = self.encode1.forward(input) 45 | e2, out2, ind2 = self.encode2.forward(e1) 46 | e3, out3, ind3 = self.encode3.forward(e2) 47 | bn = self.bottleneck.forward(e3) 48 | 49 | d3 = self.decode1.forward(bn, out3, ind3) 50 | d2 = self.decode2.forward(d3, out2, ind2) 51 | d1 = self.decode3.forward(d2, out1, ind1) 52 | prob = self.classifier.forward(d1) 53 | 54 | return prob 55 | 56 | @property 57 | def is_cuda(self): 58 | """ 59 | Check if model parameters are allocated on the GPU. 60 | """ 61 | return next(self.parameters()).is_cuda 62 | 63 | def save(self, path): 64 | """ 65 | Save model with its parameters to the given path. Conventionally the 66 | path should end with "*.model". 67 | 68 | Inputs: 69 | - path: path string 70 | """ 71 | print('Saving model... %s' % path) 72 | torch.save(self, path) 73 | -------------------------------------------------------------------------------- /relaynet_pytorch/solver.py: -------------------------------------------------------------------------------- 1 | from random import shuffle 2 | import numpy as np 3 | import torch.nn.functional as F 4 | import torch 5 | from torch.autograd import Variable 6 | from relaynet_pytorch.net_api.losses import CombinedLoss 7 | from torch.optim import lr_scheduler 8 | import os 9 | 10 | 11 | def per_class_dice(y_pred, y_true, num_class): 12 | avg_dice = 0 13 | y_pred = y_pred.data.cpu().numpy() 14 | y_true = y_true.data.cpu().numpy() 15 | for i in range(num_class): 16 | GT = y_true == (i + 1) 17 | Pred = y_pred == (i + 1) 18 | inter = np.sum(np.matmul(GT, Pred)) + 0.0001 19 | union = np.sum(GT) + np.sum(Pred) + 0.0001 20 | t = 2 * inter / union 21 | avg_dice = avg_dice + (t / num_class) 22 | return avg_dice 23 | 24 | 25 | def create_exp_directory(exp_dir_name): 26 | if not os.path.exists('models/' + exp_dir_name): 27 | 
os.makedirs('models/' + exp_dir_name) 28 | 29 | 30 | class Solver(object): 31 | # global optimiser parameters 32 | default_optim_args = {"lr": 1e-2, 33 | "betas": (0.9, 0.999), 34 | "eps": 1e-8, 35 | "weight_decay": 0.0001} 36 | gamma = 0.5 37 | step_size = 5 38 | NumClass = 9 39 | 40 | def __init__(self, optim=torch.optim.Adam, optim_args={}, 41 | loss_func=CombinedLoss()): 42 | optim_args_merged = self.default_optim_args.copy() 43 | optim_args_merged.update(optim_args) 44 | self.optim_args = optim_args_merged 45 | self.optim = optim 46 | self.loss_func = loss_func 47 | 48 | self._reset_histories() 49 | 50 | def _reset_histories(self): 51 | """ 52 | Resets train and val histories for the accuracy and the loss. 53 | """ 54 | self.train_loss_history = [] 55 | self.train_acc_history = [] 56 | self.val_acc_history = [] 57 | self.val_loss_history = [] 58 | 59 | def train(self, model, train_loader, val_loader, num_epochs=10, log_nth=5, exp_dir_name='exp_default'): 60 | """ 61 | Train a given model with the provided data. 
62 | 63 | Inputs: 64 | - model: model object initialized from a torch.nn.Module 65 | - train_loader: train data in torch.utils.data.DataLoader 66 | - val_loader: val data in torch.utils.data.DataLoader 67 | - num_epochs: total number of training epochs 68 | - log_nth: log training accuracy and loss every nth iteration 69 | """ 70 | optim = self.optim(model.parameters(), **self.optim_args) 71 | scheduler = lr_scheduler.StepLR(optim, step_size=self.step_size, 72 | gamma=self.gamma) # decay LR by a factor of 0.5 every 5 epochs 73 | 74 | self._reset_histories() 75 | iter_per_epoch = 1 76 | # iter_per_epoch = len(train_loader) 77 | 78 | if torch.cuda.is_available(): 79 | model.cuda() 80 | 81 | print('START TRAIN.') 82 | curr_iter = 0 83 | 84 | create_exp_directory(exp_dir_name) 85 | 86 | for epoch in range(num_epochs): 87 | scheduler.step() 88 | for i_batch, sample_batched in enumerate(train_loader): 89 | X = Variable(sample_batched[0]) 90 | y = Variable(sample_batched[1]) 91 | w = Variable(sample_batched[2]) 92 | 93 | if model.is_cuda: 94 | X, y, w = X.cuda(), y.cuda(), w.cuda() 95 | 96 | for iter in range(iter_per_epoch): 97 | curr_iter += iter 98 | optim.zero_grad() 99 | output = model(X) 100 | loss = self.loss_func(output, y, w) 101 | loss.backward() 102 | optim.step() 103 | if iter % log_nth == 0: 104 | self.train_loss_history.append(loss.data[0]) 105 | print('[Iteration : ' + str(iter) + '/' + str(iter_per_epoch * num_epochs) + '] : ' + str( 106 | loss.data[0])) 107 | 108 | 109 | #_, batch_output = torch.max(F.softmax(model(X),dim=1), dim=1) 110 | #avg_dice = per_class_dice(batch_output, y, self.NumClass) 111 | #print('Per class average dice score is ' + str(avg_dice)) 112 | # self.train_acc_history.append(train_accuracy) 113 | # 114 | # val_output = torch.max(model(Variable(torch.from_numpy(val_loader.dataset.X))), dim= 1) 115 | # val_accuracy = self.accuracy(val_output[1], Variable(torch.from_numpy(val_loader.dataset.y))) 116 | # 
self.val_acc_history.append(val_accuracy) 117 | print('[Epoch : ' + str(epoch) + '/' + str(num_epochs) + '] : ' + str(loss.data[0])) 118 | model.save('models/' + exp_dir_name + '/relaynet_epoch' + str(epoch + 1) + '.model') 119 | print('FINISH.') 120 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup(name="relaynet-pytorch", 4 | version="1.1", 5 | url="https://github.com/abhi4ssj/relaynet_pytorch", 6 | author="Abhijit Guha Roy", 7 | author_email="abhi4ssj@gmail.com", 8 | description="Retinal Layer and Fluid Segmentation of Macular Optical Coherence Tomography using Fully Convolutional Networks.", 9 | packages=setuptools.find_packages(), 10 | install_requires=[]) 11 | --------------------------------------------------------------------------------