├── .gitignore ├── data_utils.py ├── loss.py ├── playground.ipynb ├── prediction.ipynb ├── prediction_dense.ipynb ├── relaynet.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | datasets/ 3 | 4 | # Created by https://www.gitignore.io/api/linux,macos,python,windows,pycharm+all 5 | 6 | ### Linux ### 7 | *~ 8 | 9 | # temporary files which can be created if a process still has a handle open of a deleted file 10 | .fuse_hidden* 11 | 12 | # KDE directory preferences 13 | .directory 14 | 15 | # Linux trash folder which might appear on any partition or disk 16 | .Trash-* 17 | 18 | # .nfs files are created when an open file is removed but is still being accessed 19 | .nfs* 20 | 21 | ### macOS ### 22 | *.DS_Store 23 | .AppleDouble 24 | .LSOverride 25 | 26 | # Icon must end with two \r 27 | Icon 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### PyCharm+all ### 49 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 50 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 51 | 52 | # User-specific stuff: 53 | .idea/**/workspace.xml 54 | .idea/**/tasks.xml 55 | .idea/dictionaries 56 | 57 | # Sensitive or high-churn files: 58 | .idea/**/dataSources/ 59 | .idea/**/dataSources.ids 60 | .idea/**/dataSources.xml 61 | .idea/**/dataSources.local.xml 62 | .idea/**/sqlDataSources.xml 63 | .idea/**/dynamic.xml 64 | .idea/**/uiDesigner.xml 65 | 66 | # Gradle: 67 | .idea/**/gradle.xml 68 | .idea/**/libraries 69 | 70 | # CMake 71 | cmake-build-debug/ 72 | 73 | # Mongo Explorer plugin: 74 | .idea/**/mongoSettings.xml 75 | 76 | ## File-based project format: 77 | *.iws 78 | 79 | ## Plugin-specific files: 80 | 81 | # IntelliJ 82 | /out/ 83 | 84 | # mpeltonen/sbt-idea plugin 85 | .idea_modules/ 86 | 87 | # JIRA plugin 88 | atlassian-ide-plugin.xml 89 | 90 | # Cursive Clojure plugin 91 | .idea/replstate.xml 92 | 93 | # Ruby plugin and RubyMine 94 | /.rakeTasks 95 | 96 | # Crashlytics plugin (for Android Studio and IntelliJ) 97 | com_crashlytics_export_strings.xml 98 | crashlytics.properties 99 | crashlytics-build.properties 100 | fabric.properties 101 | 102 | ### PyCharm+all Patch ### 103 | # Ignores the whole .idea folder and all .iml files 104 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 105 | 106 | .idea/ 107 | 108 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 109 | 110 | *.iml 111 | modules.xml 112 | .idea/misc.xml 113 | *.ipr 114 | 115 | ### Python ### 116 | # Byte-compiled / optimized / DLL files 117 | __pycache__/ 118 | *.py[cod] 119 | *$py.class 120 | 121 | # C extensions 122 | *.so 123 | 124 | # Distribution / packaging 125 | .Python 126 | build/ 127 | develop-eggs/ 128 | dist/ 129 | downloads/ 130 | eggs/ 131 | .eggs/ 132 | lib/ 133 | lib64/ 134 | parts/ 135 | sdist/ 136 | var/ 137 | wheels/ 138 | *.egg-info/ 139 | .installed.cfg 140 | *.egg 141 | 142 | # PyInstaller 143 | # Usually these files are written by a python script from a template 144 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 145 | *.manifest 146 | *.spec 147 | 148 | # Installer logs 149 | pip-log.txt 150 | pip-delete-this-directory.txt 151 | 152 | # Unit test / coverage reports 153 | htmlcov/ 154 | .tox/ 155 | .coverage 156 | .coverage.* 157 | .cache 158 | .pytest_cache/ 159 | nosetests.xml 160 | coverage.xml 161 | *.cover 162 | .hypothesis/ 163 | 164 | # Translations 165 | *.mo 166 | *.pot 167 | 168 | # Flask stuff: 169 | instance/ 170 | .webassets-cache 171 | 172 | # Scrapy stuff: 173 | .scrapy 174 | 175 | # Sphinx documentation 176 | docs/_build/ 177 | 178 | # PyBuilder 179 | target/ 180 | 181 | # Jupyter Notebook 182 | .ipynb_checkpoints 183 | 184 | # pyenv 185 | .python-version 186 | 187 | # celery beat schedule file 188 | celerybeat-schedule.* 189 | 190 | # SageMath parsed files 191 | *.sage.py 192 | 193 | # Environments 194 | .env 195 | .venv 196 | env/ 197 | venv/ 198 | ENV/ 199 | env.bak/ 200 | venv.bak/ 201 | 202 | # Spyder project settings 203 | .spyderproject 204 | .spyproject 205 | 206 | # Rope project settings 207 | .ropeproject 208 | 209 | # mkdocs documentation 210 | /site 211 | 212 | # mypy 213 | .mypy_cache/ 214 | 215 | ### Windows ### 216 | # Windows thumbnail cache files 217 | Thumbs.db 218 | ehthumbs.db 219 | ehthumbs_vista.db 220 | 221 | # Folder config file 222 | Desktop.ini 223 | 224 | # Recycle Bin used on file shares 225 | $RECYCLE.BIN/ 226 | 227 | # Windows Installer files 228 | *.cab 229 | *.msi 230 | *.msm 231 | *.msp 232 | 233 | # Windows shortcuts 234 | *.lnk 235 | 236 | 237 | # End of https://www.gitignore.io/api/linux,macos,python,windows,pycharm+all 238 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | """Data utility functions.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import h5py 7 | from scipy.io import loadmat 8 | 9 | 10 | def salt_and_pepper(img, prob, shape): 11 | rnd = np.random.rand(*shape) 12 | noisy = img[:] 13 | noisy[rnd < prob/2] = 0. 14 | noisy[rnd > 1 - prob/2] = 1. 15 | return noisy 16 | 17 | 18 | class MatDataset(data.Dataset): 19 | def __init__(self, path): 20 | data = loadmat(path) 21 | oct = data['volumedata'] 22 | annotations = data['O1'] 23 | 24 | oct = np.transpose(oct, (2, 1, 0)) 25 | oct = oct[:, 61 + 16:573, :] 26 | 27 | sz = oct.shape 28 | self.oct = oct.reshape([sz[0], 1, sz[1], sz[2]]) 29 | 30 | annotations = np.transpose(annotations, (2, 1, 0)) 31 | self.annotations = annotations[:, 61 + 16:573, :] 32 | 33 | def convert_annotation(self, a): 34 | a = a.astype(np.int) 35 | label = np.zeros((a.shape[0], a.shape[0])) 36 | last = list() 37 | for i in range(a.shape[0]): 38 | last.append(0) 39 | 40 | for c in range(9): 41 | for i in range(a.shape[0]): 42 | if a[i, c] == 0: 43 | continue 44 | label[i, last[i]:a[i, c]] = c 45 | last[i] = a[i, c] 46 | 47 | return label 48 | 49 | def __len__(self): 50 | return len(self.oct) 51 | 52 | def __getitem__(self, item): 53 | img = self.oct[item].astype(np.float32) 54 | annotation = self.annotations[item] 55 | label = self.convert_annotation(annotation) 56 | label += 1 57 | label_bin = np.zeros((9, label.shape[0], label.shape[1]), dtype=np.int32) 58 | i, j = np.mgrid[0:label.shape[0], 0:label.shape[1]] 59 | label_bin[label.astype(np.int), i, j] = 1 60 | 61 | img = torch.from_numpy(img) 62 | label = torch.from_numpy(label) 63 | label_bin = torch.from_numpy(label_bin) 64 | 65 | return img, label, label_bin, 1 # no weight available 66 | 67 | class ImdbData(data.Dataset): 68 | def __init__(self, X, y, yb, w, salt_pepper_noise_prob=0): 69 | self.X = X 70 | self.y = y 71 | self.yb = yb 72 | self.w = w 73 | self.salt_pepper_noise_prob = salt_pepper_noise_prob 74 | 75 | def __getitem__(self, index): 76 | img = self.X[index] 77 | label = self.y[index] 78 | label_bin = self.yb[index] 79 | weight = self.w[index] 80 | if self.salt_pepper_noise_prob > 0: 81 | img = salt_and_pepper(img, self.salt_pepper_noise_prob, img.shape) 82 | 83 | img = torch.from_numpy(img) 84 | label = torch.from_numpy(label) 85 | label[label == 9] = 1 86 | label_bin = torch.from_numpy(label_bin) 87 | label_bin[1] = label_bin[1] + label_bin[9] 88 | label_bin = label_bin[:9] 89 | weight = torch.from_numpy(weight) 90 | weight[1] = weight[1] + weight[9] 91 | weight = weight[:9] 92 | return img, label, label_bin, weight 93 | 94 | def __len__(self): 95 | return len(self.y) 96 | 97 | 98 | def get_imdb_data(): 99 | # TODO: Need to change later 100 | NumClass = 10 101 | 102 | # Load DATA 103 | Data = h5py.File('datasets/Data.h5', 'r') 104 | a_group_key = list(Data.keys())[0] 105 | Data = list(Data[a_group_key]) 106 | Data = np.squeeze(np.asarray(Data)) 107 | Label = h5py.File('datasets/label.h5', 'r') 108 | a_group_key = list(Label.keys())[0] 109 | Label = list(Label[a_group_key]) 110 | Label = np.squeeze(np.asarray(Label)) 111 | set = h5py.File('datasets/set.h5', 'r') 112 | a_group_key = list(set.keys())[0] 113 | set = list(set[a_group_key]) 114 | set = np.squeeze(np.asarray(set)) 115 | sz = Data.shape 116 | Data = Data.reshape([sz[0], 1, sz[1], sz[2]]) 117 | Data = Data[:, :, 61:573, :] 118 | weights = Label[:, 1, 61:573, :] 119 | Label = Label[:, 0, 61:573, :] 120 | sz = Label.shape 121 | Label = Label.reshape([sz[0], 1, sz[1], sz[2]]) 122 | weights = weights.reshape([sz[0], 1, sz[1], sz[2]]) 123 | train_id = set == 1 124 | test_id = set == 3 125 | 126 | Tr_Dat = Data[train_id, :, :, :] 127 | Tr_Label = np.squeeze(Label[train_id, :, :, :]) 128 | Tr_weights = weights[train_id, :, :, :] 129 | Tr_weights = np.tile(Tr_weights, [1, NumClass, 1, 1]) 130 | 131 | Te_Dat = Data[test_id, :, :, :] 132 | Te_Label = np.squeeze(Label[test_id, :, :, :]) 133 | Te_weights = weights[test_id, :, :, :] 134 | Te_weights = np.tile(Te_weights, [1, NumClass, 1, 1]) 135 | 136 | sz = Tr_Dat.shape 137 | sz_test = Te_Dat.shape 138 | y2 = np.ones((sz[0], NumClass, sz[2], sz[3])) 139 | y_test = np.ones((sz_test[0], NumClass, sz_test[2], sz_test[3])) 140 | for i in range(NumClass): 141 | y2[:, i, :, :] = np.squeeze(np.multiply(np.ones(Tr_Label.shape), ((Tr_Label == i)))) 142 | y_test[:, i, :, :] = np.squeeze(np.multiply(np.ones(Te_Label.shape), ((Te_Label == i)))) 143 | 144 | Tr_Label_bin = y2 145 | Te_Label_bin = y_test 146 | 147 | return (ImdbData(Tr_Dat, Tr_Label, Tr_Label_bin, Tr_weights), 148 | ImdbData(Te_Dat, Te_Label, Te_Label_bin, Te_weights)) -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def dice_coeff(output, target_bin, n_classes=9): 6 | batch_size = output.size()[0] 7 | upper = (output * target_bin).view(batch_size, n_classes, -1).sum(dim=-1) 8 | lower = (output ** 2).view(batch_size, n_classes, -1).sum(dim=-1) \ 9 | + (target_bin ** 2).view(batch_size, n_classes, -1).sum(dim=-1) 10 | return 2 * upper / lower 11 | 12 | 13 | class DiceLoss(nn.Module): 14 | def forward(self, output, target_bin): 15 | loss = dice_coeff(output, target_bin) 16 | return torch.mean(1 - loss) 17 | 18 | 19 | class WeightedClassificationLoss(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.criterion = nn.NLLLoss(size_average=True, reduce=True) 23 | 24 | def forward(self, output, target, weight): 25 | loss = self.criterion(output, target) 26 | # TODO use weight accordingly 27 | return torch.mean(loss) 28 | 29 | 30 | class TotalLoss(nn.Module): 31 | def __init__(self,classification_weight=1, dice_weight=0.5): 32 | super().__init__() 33 | self.dice_loss = DiceLoss() 34 | self.classification_loss = WeightedClassificationLoss() 35 | self.classification_weight = classification_weight 36 | self.dice_weight = dice_weight 37 | 38 | def forward(self, output, target, weight, target_bin): 39 | return self.classification_weight * self.classification_loss(output, target, weight) + \ 40 | self.dice_weight * self.dice_loss(output, target_bin) 41 | 42 | -------------------------------------------------------------------------------- /relaynet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | def __init__(self, num_input_channels=1, kernel=(3, 7), stride=1, num_output_channels=64, dropout_prob=0.3): 8 | super().__init__() 9 | padding = (np.asarray(kernel) - 1) / 2 10 | padding = tuple(padding.astype(np.int)) 11 | 12 | self.model = nn.Sequential( 13 | nn.Conv2d(in_channels=num_input_channels, out_channels=num_output_channels, 14 | kernel_size=kernel, 15 | padding=padding, 16 | stride=stride), 17 | nn.BatchNorm2d(num_features=num_output_channels), 18 | nn.PReLU() 19 | ) 20 | 21 | if dropout_prob > 0: 22 | self.model.add_module(str(len(self.model)), nn.Dropout2d(dropout_prob)) 23 | 24 | def forward(self, input): 25 | return self.model(input) 26 | 27 | 28 | class DenseBlock(nn.Module): 29 | def __init__(self, num_input_channels=1, kernel=(3, 7), stride=1, num_output_channels=64, dropout_prob=0.3): 30 | super().__init__() 31 | self.dense_modules = nn.ModuleList([ 32 | BasicBlock(num_input_channels, kernel, stride, num_output_channels, dropout_prob), 33 | BasicBlock(num_input_channels + num_output_channels, kernel, stride, num_output_channels, dropout_prob), 34 | BasicBlock(num_input_channels + 2 * num_output_channels, (1, 1), stride, num_output_channels, dropout_prob) 35 | ]) 36 | 37 | def forward(self, input): 38 | outputs = [] 39 | for module in self.dense_modules: 40 | input_cat = torch.cat([input] + outputs, dim=1) 41 | output = module(input_cat) 42 | outputs.append(output) 43 | 44 | return outputs[-1] 45 | 46 | 47 | class EncoderBlock(nn.Module): 48 | def __init__(self, num_input_channels=1, kernel=(3, 7), stride_conv=1, stride_pool=2, num_output_channels=64, 49 | dropout_prob=0.3, basic_block=BasicBlock): 50 | super().__init__() 51 | self.basic = basic_block(num_input_channels, kernel, stride_conv, num_output_channels, dropout_prob) 52 | self.pool = nn.MaxPool2d(kernel, stride_pool, return_indices=True) 53 | 54 | def forward(self, input): 55 | tmp = self.basic(input) 56 | out, indices = self.pool(tmp) 57 | return out, indices, tmp 58 | 59 | 60 | class DecoderBlock(nn.Module): 61 | def __init__(self, num_input_channels=64, kernel=(3, 7), stride_conv=1, stride_pool=2, num_output_channels=64, 62 | dropout_prob=0.3, basic_block=BasicBlock): 63 | super().__init__() 64 | self.basic = basic_block(num_input_channels * 2, kernel, stride_conv, num_output_channels, dropout_prob) 65 | self.unpool = nn.MaxUnpool2d(kernel, stride_pool) 66 | 67 | def forward(self, input, indices, encoder_block): 68 | tmp = self.unpool(input, indices, output_size=encoder_block.size()) 69 | tmp = torch.cat((encoder_block, tmp), dim=1) 70 | return self.basic(tmp) 71 | 72 | 73 | class ClassifierBlock(nn.Module): 74 | def __init__(self, num_input_channels=64, kernel=(1, 1), stride_conv=1, num_classes=10): 75 | super().__init__() 76 | self.classify = nn.Sequential( 77 | nn.Conv2d(num_input_channels, num_classes, kernel, stride_conv), 78 | nn.Softmax2d() 79 | ) 80 | 81 | def forward(self, input): 82 | return self.classify(input) 83 | 84 | 85 | class RelayNet(nn.Module): 86 | def __init__(self, num_input_channels=1, kernel=(3, 3), stride_conv=1, stride_pool=2, num_output_channels=64, 87 | num_encoders=3, num_classes=9, kernel_classify=(1, 1), dropout_prob=0.3, basic_block=BasicBlock): 88 | super().__init__() 89 | self.encoders = nn.ModuleList([EncoderBlock(num_input_channels if i == 0 else num_output_channels, kernel, 90 | stride_conv, stride_pool, num_output_channels, dropout_prob, 91 | basic_block) 92 | for i in range(num_encoders)]) 93 | self.bottleneck = basic_block(num_output_channels, kernel, stride_conv, num_output_channels, dropout_prob) 94 | self.decoders = nn.ModuleList( 95 | [DecoderBlock(num_output_channels, kernel, stride_conv, stride_pool, num_output_channels, dropout_prob, 96 | basic_block) 97 | for _ in range(num_encoders)]) 98 | self.classify = ClassifierBlock(num_output_channels, kernel_classify, stride_conv, num_classes) 99 | 100 | def forward(self, input): 101 | out = input 102 | encodings = list() 103 | for encoder in self.encoders: 104 | out, indices, before_maxpool = encoder(out) 105 | encodings.append((out, indices, before_maxpool)) 106 | 107 | out = self.bottleneck(encodings[-1][0]) 108 | 109 | for i, encoded in enumerate(reversed(encodings)): 110 | decoder = self.decoders[i] 111 | out = decoder(out, encoded[1], encoded[2]) 112 | 113 | return self.classify(out) 114 | 115 | def train(self, mode=True): 116 | super().train(mode) 117 | 118 | # to do MC dropout we would like to keep dropout also during evaluation 119 | for module in self.modules(): 120 | if 'dropout' in module.__class__.__name__.lower(): 121 | module.train(False) 122 | 123 | def predict(self, input, times=10): 124 | self.eval() 125 | results = list() 126 | for _ in range(times): 127 | out = self.forward(input) 128 | results.append(out.data.cpu().numpy()) 129 | 130 | results = np.asarray(results, dtype=np.float) 131 | average = results.mean(axis=0).squeeze() 132 | per_class_entropy = -np.sum(results * np.log(results + 1e-12), axis=0) 133 | overall_entropy = -np.sum(results * np.log(results + 1e-12), axis=(0, 2)) # 1 is batch size 134 | 135 | return average, per_class_entropy / times, overall_entropy / times, results 136 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import multiprocessing 4 | 5 | import torch 6 | import numpy as np 7 | from torch import optim 8 | from torch.autograd import Variable 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | from tqdm import trange 12 | 13 | 14 | from data_utils import get_imdb_data 15 | from loss import TotalLoss, dice_coeff 16 | from relaynet import RelayNet, DenseBlock, BasicBlock 17 | 18 | 19 | def train(epoch, data, net, criterion, optimizer, args): 20 | train_set = DataLoader(data, batch_size=args.batch_size, num_workers=multiprocessing.cpu_count(), shuffle=True) 21 | 22 | progress_bar = tqdm(iter(train_set)) 23 | moving_loss = 0 24 | 25 | net.train() 26 | for img, label, label_bin, weight in progress_bar: 27 | img, label, label_bin, weight = Variable(img), Variable(label), Variable(label_bin), Variable(weight) 28 | label = label.type(torch.LongTensor) 29 | label_bin = label_bin.type(torch.FloatTensor) 30 | 31 | if args.cuda: 32 | img, label, label_bin, weight = img.cuda(), label.cuda(), label_bin.cuda(), weight.cuda() 33 | 34 | output = net(img) 35 | loss = criterion(output, label, weight, label_bin) 36 | net.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | 40 | if moving_loss == 0: 41 | moving_loss = loss.item() 42 | else: 43 | moving_loss = moving_loss * 0.9 + loss.item() * 0.1 44 | 45 | dice_avg = torch.mean(dice_coeff(output, label_bin)) 46 | 47 | progress_bar.set_description( 48 | 'Epoch: {}; Loss: {:.5f}; Avg: {:.5f}; Dice: {:.5f}' 49 | .format(epoch + 1, loss.item(), moving_loss, dice_avg.item())) 50 | 51 | 52 | def valid(data, net, args, mc_samples=1): 53 | valid_set = DataLoader(data, batch_size=args.batch_size // 2, num_workers=multiprocessing.cpu_count(), shuffle=True) 54 | net.eval() 55 | 56 | progress_bar = tqdm(iter(valid_set)) 57 | 58 | dice_avg = list() 59 | entropy_avg = list() 60 | for img, label, label_bin, weight in progress_bar: 61 | img, label, label_bin, weight = Variable(img), Variable(label), Variable(label_bin), Variable(weight) 62 | label_bin = label_bin.type(torch.FloatTensor) 63 | 64 | if args.cuda: 65 | img, label_bin = img.cuda(), label_bin.cuda() 66 | 67 | if mc_samples > 1: 68 | # lol this is insanely inefficient 69 | avg, _, overall_entropy, _ = net.predict(img, times=mc_samples) 70 | entropy_avg.append(np.mean(overall_entropy)) 71 | output = Variable(torch.Tensor(avg)) 72 | if args.cuda: 73 | output = output.cuda() 74 | else: 75 | output = net(img) 76 | 77 | dice_avg.append(torch.mean(dice_coeff(output, label_bin)).item()) 78 | 79 | dice_avg = np.asarray(dice_avg).mean() 80 | entropy_avg = np.asarray(entropy_avg).mean() 81 | 82 | print('Validation dice avg: {}'.format(dice_avg)) 83 | print('Validation entropy avg: {}'.format(entropy_avg)) 84 | 85 | return dice_avg, entropy_avg 86 | 87 | 88 | def parse_args(): 89 | parser = argparse.ArgumentParser(description='Train SqueezeNet with PyTorch.') 90 | parser.add_argument('--batch-size', action='store', type=int, dest='batch_size', default=8) 91 | parser.add_argument('--epochs', action='store', type=int, dest='epochs', default=90) 92 | parser.add_argument('--cuda', action='store', type=bool, dest='cuda', default=True) 93 | parser.add_argument('--validation', action='store_true', dest='validation', default=True) 94 | parser.add_argument('--model-checkpoint-dir', action='store', type=str, default='./models') 95 | parser.add_argument('--use-dense-connections', action='store_true', dest='dense', default=False) 96 | parser.add_argument('--dropout-prob', action='store', type=float, default=0.5) 97 | 98 | return parser.parse_args() 99 | 100 | 101 | def main(): 102 | print('number of cpus used for loading data: {}'.format(multiprocessing.cpu_count())) 103 | args = parse_args() 104 | os.makedirs(args.model_checkpoint_dir, exist_ok=True) 105 | 106 | relay_net = RelayNet(basic_block=DenseBlock if args.dense else BasicBlock, dropout_prob=args.dropout_prob) 107 | print(relay_net) 108 | if args.cuda: 109 | relay_net = relay_net.cuda() 110 | 111 | criterion = TotalLoss() 112 | optimizer = optim.Adam(relay_net.parameters(), lr=0.001, weight_decay=0.0001) 113 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.95) 114 | 115 | train_data, valid_data = get_imdb_data() 116 | 117 | for epoch in trange(args.epochs): 118 | scheduler.step(epoch) 119 | train(epoch, train_data, relay_net, criterion, optimizer, args) 120 | 121 | torch.save(relay_net.state_dict(), os.path.join(args.model_checkpoint_dir, 'model-{}.model'.format(epoch))) 122 | 123 | del criterion, optimizer, scheduler 124 | 125 | if args.validation: 126 | best = (-1, -1) 127 | for epoch in trange(args.epochs): 128 | relay_net.load_state_dict(torch.load(os.path.join(args.model_checkpoint_dir, 'model-{}.model'.format(epoch)))) 129 | if args.cuda: 130 | relay_net = relay_net.cuda() 131 | dice, entropy = valid(valid_data, relay_net, args) 132 | _, best_dice = best 133 | 134 | if dice > best_dice: 135 | best = (epoch, dice) 136 | 137 | print('Best model with epoch {} and dice {}'.format(*best)) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from itertools import product 3 | 4 | import torch 5 | import numpy as np 6 | 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from loss import dice_coeff 12 | 13 | def iter_data_and_predict(data, net, args, mc_samples=10): 14 | valid_set = DataLoader(data, batch_size=args.batch_size, num_workers=multiprocessing.cpu_count(), shuffle=True) 15 | net.eval() 16 | 17 | progress_bar = tqdm(iter(valid_set)) 18 | 19 | for img, label, label_bin, weight in progress_bar: 20 | img, label, label_bin, weight = Variable(img), Variable(label), Variable(label_bin), Variable(weight) 21 | label = label.type(torch.LongTensor) 22 | 23 | if args.cuda: 24 | img, label = img.cuda(), label.cuda() 25 | 26 | avg, per_class_entropy, overall_entropy, samples = net.predict(img, times=mc_samples) 27 | 28 | yield avg, per_class_entropy, overall_entropy, samples, label 29 | 30 | 31 | def error_map_dice(data, net, args, mc_samples=10, entropy_threshold=0.5): 32 | """ 33 | Computes the dice score between prediction error map and the entropy. This is a measure on how well the entropy 34 | describes the actual error the network makes. 35 | :param data: 36 | :param net: 37 | :param args: 38 | :param mc_samples: 39 | :param entropy_threshold: 40 | :return: 41 | """ 42 | 43 | dice_avg = list() 44 | for avg, _, overall_entropy, _, label in iter_data_and_predict(data, net, args, mc_samples): 45 | 46 | overall_entropy = overall_entropy > entropy_threshold 47 | overall_entropy = Variable(torch.Tensor(overall_entropy.astype(np.float32))) 48 | 49 | indices = np.argmax(avg, axis=1) # 1 is class dim 50 | indices = Variable(torch.LongTensor(indices)) 51 | 52 | if args.cuda: 53 | overall_entropy, indices = overall_entropy.cuda(), indices.cuda() 54 | 55 | error_map = label != indices 56 | error_map = error_map.type(torch.cuda.FloatTensor if args.cuda else torch.FloatTensor) 57 | 58 | dice_avg.append(torch.mean(dice_coeff(overall_entropy, error_map, n_classes=1)).item()) 59 | 60 | dice_avg = np.asarray(dice_avg).mean() 61 | 62 | 63 | print('dice avg: {}'.format(dice_avg)) 64 | 65 | return dice_avg 66 | 67 | 68 | def structure_wise_uncertainty_dice(data, net, args, mc_samples=10, n_classes=9): 69 | 70 | dice_avg = list() 71 | 72 | for _, _, _, samples, _ in iter_data_and_predict(data, net, args, mc_samples): 73 | 74 | samples = torch.Tensor(samples) 75 | if args.cuda: 76 | samples = samples.cuda() 77 | 78 | for i, j in product(range(mc_samples), range(mc_samples)): 79 | if i == j: 80 | continue 81 | 82 | dice_score = dice_coeff(samples[i], samples[j, :], n_classes=n_classes) 83 | dice_avg.append(torch.mean(dice_score, dim=0).cpu().numpy()) 84 | 85 | dice_avg = np.asarray(dice_avg).mean(axis=0) 86 | 87 | return dice_avg 88 | 89 | 90 | def structure_wise_uncertainty_cv(data, net, args, mc_samples=10): 91 | """ 92 | Coefficient of variance = mean/std_dev 93 | :param data: 94 | :param net: 95 | :param args: 96 | :param mc_samples: 97 | :return: 98 | """ 99 | cvs = list() 100 | 101 | for avg, _, _, samples, _ in iter_data_and_predict(data, net, args, mc_samples): 102 | 103 | std_dev = samples.std(axis=0) 104 | cv = avg / (std_dev + 1e-6) 105 | 106 | for x in cv: 107 | cvs.append(x) 108 | 109 | cv = np.asarray(cvs).mean(axis=(0, 2, 3)) 110 | 111 | return cv --------------------------------------------------------------------------------