├── README.md ├── eval.py ├── labels.csv ├── main.py ├── model.py ├── preprocessing.py ├── split.py └── tuning.sh /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-dicom-classification 2 | PyTorch framework to classify DICOM (.dcm) files 3 |
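4 | 5 | ### Dependencies 6 | ``` 7 | python 3.6.4 8 | pytorch 0.4.0 9 | torchvision 0.2.1 10 | numpy 1.14.1 11 | pydicom 1.0.2 12 | scikit-image 0.13.1 13 | ``` 14 | eval.py additionally needs scikit-learn, pandas, matplotlib and seaborn; preprocessing.py also imports pypng (`png`).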
15 | 16 | 17 | ### Usage 18 | ##### CAUTION: Labels are derived from file names, so you must define your own labeling function in model.py (see `binary_label` and `multi_label`). 19 |
20 | 21 | #### Preprocess the dataset (converts every .dcm file into a .npy array scaled to [0, 1]) 22 | ``` 23 | python preprocessing.py /path/to/src/dir/ /path/to/dest/dir/ 24 | ``` 25 |
26 | 27 | #### Split the dataset for k-fold cross-validation (copies the files into sibling directories `dir-0`, ..., `dir-(k-1)`) 28 | ``` 29 | python split.py /path/to/src/dir k 30 | ``` 31 |
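32 | 33 | #### Train (`--src` is the directory given to split.py, without a trailing slash; folds are read from `src-0`, ..., `src-(k-1)`) 34 | ``` 35 | python main.py --architecture resnet152 --output_dim 8192 --num_labels 17 --k 5 --src /path/to/src/dir 36 | ``` 37 |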
38 | 39 | #### Evaluation (`--ckpt` is a `.pt` file saved by training) 40 | ``` 41 | python eval.py --ckpt /path/to/checkpoint.pt --data_dir /path/to/src/dir/ --multilabel True --batch_size 64 --labels labels.csv 42 | ``` 43 |
44 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import confusion_matrix, roc_curve, auc 4 | import argparse 5 | from model import * 6 | from matplotlib import pyplot as plt 7 | import itertools 8 | from itertools import cycle 9 | import pandas as pd 10 | import seaborn as sns 11 | from scipy import interp 12 | 13 | def str2bool(s): 14 | if s == "True": 15 | return True 16 | elif s == "False": 17 | return False 18 | else: 19 | raise NotImplementedError 20 | 21 | def get_output(model, loader, with_prob=True): 22 | y_pred, y_true, = [], [] 23 | if with_prob: 24 | y_prob = [] 25 | else: 26 | y_prob = None 27 | for inputs, labels in loader: 28 | if torch.cuda.is_available(): 29 | inputs = inputs.cuda() 30 | labels = labels.cuda() 31 | outputs = model(inputs) 32 | _, preds = torch.max(outputs, 1) 33 | if with_prob: 34 | probs = torch.nn.functional.softmax(outputs, dim=1) 35 | else: 36 | probs = None 37 | y_pred.append(preds.cpu().numpy()) 38 | y_true.append(labels.cpu().numpy()) 39 | if with_prob: 40 | y_prob.append(probs.detach().cpu().numpy()) 41 | y_pred = np.concatenate(y_pred) 42 | y_true = np.concatenate(y_true) 43 | if with_prob: 44 | y_prob = np.concatenate(y_prob) 45 | return y_pred, y_true, y_prob 46 | 47 | def print_roc_curve(y_test, y_score, n_classes, figsize = (8, 6)): 48 | lw = 2 49 | # Compute ROC curve and ROC area for each class 50 | fpr = dict() 51 | tpr = dict() 52 | roc_auc = dict() 53 | for i in range(n_classes): 54 | fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i]) 55 | roc_auc[i] = auc(fpr[i], tpr[i]) 56 | # Compute micro-average ROC curve and ROC area 57 | fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel()) 58 | roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) 59 | # First aggregate all false positive rates 60 | all_fpr = 
np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) 61 | # Then interpolate all ROC curves at this points 62 | mean_tpr = np.zeros_like(all_fpr) 63 | for i in range(n_classes): 64 | mean_tpr += interp(all_fpr, fpr[i], tpr[i]) 65 | 66 | # Finally average it and compute AUC 67 | mean_tpr /= n_classes 68 | 69 | fpr["macro"] = all_fpr 70 | tpr["macro"] = mean_tpr 71 | roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) 72 | fig = plt.figure(figsize=figsize) 73 | """ 74 | plt.plot(fpr["micro"], tpr["micro"], 75 | label='micro-average ROC curve (area = {0:0.2f})' 76 | ''.format(roc_auc["micro"]), 77 | color='deeppink', linestyle=':', linewidth=4) 78 | """ 79 | plt.plot(fpr["macro"], tpr["macro"], 80 | label='macro-average ROC curve (area = {0:0.2f})' 81 | ''.format(roc_auc["macro"]), 82 | color='navy', linestyle=':', linewidth=4) 83 | plt.plot([0, 1], [0, 1], 'k--', lw=lw) 84 | plt.xlim([0.0, 1.0]) 85 | plt.ylim([0.0, 1.05]) 86 | plt.xlabel('False Positive Rate') 87 | plt.ylabel('True Positive Rate') 88 | #plt.title('Some extension of Receiver operating characteristic to multi-class') 89 | plt.legend(loc="lower right") 90 | return fig 91 | 92 | 93 | def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14): 94 | df_cm = pd.DataFrame( 95 | confusion_matrix, index=class_names, columns=class_names, 96 | ) 97 | fig = plt.figure(figsize=figsize) 98 | try: 99 | heatmap = sns.heatmap(df_cm, annot=True, fmt="d") 100 | except ValueError: 101 | raise ValueError("Confusion matrix values must be integers.") 102 | heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize) 103 | heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize) 104 | plt.ylabel('True label') 105 | plt.xlabel('Predicted label') 106 | return fig 107 | 108 | 109 | def main(args): 110 | # obtain outputs of the model 111 | model = torch.load(args.ckpt) 112 | if args.multilabel: 113 | 
alloc_label = multi_label 114 | else: 115 | alloc_label = binary_label 116 | test_dataset = EarDataset(binary_dir=args.data_dir, 117 | alloc_label = alloc_label, 118 | transforms=transforms.Compose([Rescale((256, 256)), ToTensor(), Normalize()])) 119 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8) 120 | y_pred, y_true, y_score = get_output(model, test_loader) 121 | print(y_pred.shape, y_true.shape, y_score.shape) 122 | 123 | # save the confusion matrix 124 | with open(args.labels, 'r+') as f: 125 | labels = f.readlines() 126 | labels = [l.replace('\n', '') for l in labels] 127 | if not args.multilabel: 128 | labels = ['Normal', 'Abnormal'] 129 | if not os.path.exists(args.result_dir): 130 | os.mkdir(args.result_dir) 131 | cnf_matrix = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels))) 132 | np.set_printoptions(precision=2) 133 | fig = print_confusion_matrix(cnf_matrix, labels, figsize=(16,14), fontsize=10) 134 | fig.savefig(os.path.join(args.result_dir, args.cfmatrix_name)) 135 | 136 | # save the roc curve 137 | y_onehot = np.zeros((y_true.shape[0], len(labels)), dtype=np.uint8) 138 | y_onehot[np.arange(y_true.shape[0]), y_true] = 1 139 | sums = y_onehot.sum(axis=0) 140 | useless_cols = [] 141 | for i, c in enumerate(sums): 142 | if c == 0: 143 | print('useless column {}'.format(i)) 144 | useless_cols.append(i) 145 | useful_cols = np.array([i for i in range(len(labels)) if i not in useless_cols]) 146 | if args.multilabel: 147 | y_onehot = y_onehot[:,useful_cols] 148 | y_score = y_score[:,useful_cols] 149 | fig = print_roc_curve(y_onehot, y_score, useful_cols.shape[0], figsize=(8,6)) 150 | fig.savefig(os.path.join(args.result_dir, args.roc_name)) 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser(description="evaluation") 155 | 156 | parser.add_argument('--ckpt', type=str, help='path to checkpoint') 157 | parser.add_argument('--data_dir', type=str, help='path to the 
dataset') 158 | parser.add_argument('--result_dir', default='results', type=str, help='path in which we save the result') 159 | parser.add_argument('--cfmatrix_name', default='confusion_matrix', type=str, help='fname of confusion matrix') 160 | parser.add_argument('--roc_name', default='roc_curve', type=str, help='fname of roc curve') 161 | parser.add_argument('--multilabel', default=True, type=str2bool, help='if multilabel, then true, else false') 162 | parser.add_argument('--batch_size', default=16, type=int, help='batch size') 163 | parser.add_argument('--labels', default='labels.csv', type=str, help='fname including labels') 164 | 165 | 166 | args = parser.parse_args() 167 | 168 | main(args) 169 | -------------------------------------------------------------------------------- /labels.csv: -------------------------------------------------------------------------------- 1 | OME 2 | AdOM 3 | AR 4 | Otit_cerum 5 | PSC 6 | Normal_HPerf 7 | GT 8 | AOM 9 | Normal_Typical 10 | Otit_typical 11 | Normal_Tymp 12 | Cerum 13 | TP 14 | Otit_Typical 15 | Myri 16 | EAC 17 | CC 18 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from model import * 2 | import argparse 3 | 4 | def main(): 5 | parser = argparse.ArgumentParser(description="neural network framework for dicom datasets") 6 | 7 | parser.add_argument('--architecture', default='resnet152', type=str, help='a NN architecture supported by torchvision e.g. 
resnet152') 8 | parser.add_argument('--output_dim', default=8192, type=int, help='the final hidden layer\'s dim') 9 | parser.add_argument('--num_labels', default=2, type=int, help="# of labels") 10 | parser.add_argument('--k', default=5, type=int, help="\'k\'-fold") 11 | parser.add_argument('--src', type=str, help="all directories must be src-0, src-1, ..., src-k") 12 | parser.add_argument('--lr', default=1e-3, type=float, help="learning rate") 13 | parser.add_argument('--beta_1', default=0.9, type=float, help="first beta value") 14 | parser.add_argument('--beta_2', default=0.999, type=float, help="second beta value") 15 | parser.add_argument('--weight_decay', default=.0, type=float, help="weight decay") 16 | parser.add_argument('--nb_epochs', default=25, type=int, help="# of epochs") 17 | parser.add_argument('--batch_size', default=32, type=int, help="batch size") 18 | parser.add_argument('--start_fold', default=0, type=int, help="start fold") 19 | parser.add_argument('--end_fold', default=0, type=int, help="end fold, if it is 0, it will be interpreted as using full k fold") 20 | 21 | parser = parser.parse_args() 22 | 23 | architecture = parser.architecture 24 | output_dim = parser.output_dim 25 | num_labels = parser.num_labels 26 | k = parser.k 27 | src = [ parser.src + "-%d"%(i) for i in range(k) ] 28 | lr = parser.lr 29 | betas = (parser.beta_1, parser.beta_2) 30 | weight_decay = parser.weight_decay 31 | nb_epochs = parser.nb_epochs 32 | batch_size = parser.batch_size 33 | start_fold = parser.start_fold 34 | end_fold = parser.end_fold 35 | if end_fold == 0: 36 | end_fold = k 37 | if num_labels == 2: 38 | train(architecture, output_dim, k, src, binary_label, num_labels, lr, betas, weight_decay, nb_epochs, batch_size, start_fold, end_fold) 39 | else: 40 | train(architecture, output_dim, k, src, multi_label, num_labels, lr, betas, weight_decay, nb_epochs, batch_size, start_fold, end_fold) 41 | 42 | if __name__ == "__main__": 43 | main() 44 | 
-------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import glob 4 | import torch 5 | import torchvision 6 | import torchvision.models as models 7 | from torch.utils.data import Dataset, DataLoader 8 | import torch.nn as nn 9 | from torchvision import transforms, utils 10 | from skimage.transform import resize 11 | import collections 12 | import time 13 | import copy 14 | 15 | def binary_label(fnames): 16 | """ 17 | read file names and return their binary classes 18 | 0: Normal 19 | 1: Abnormal 20 | """ 21 | labeled = [] 22 | for f in fnames: 23 | if 'Normal' in f: 24 | labeled.append(0) 25 | else: 26 | labeled.append(1) 27 | return np.array(labeled), ["Normal", "Abnormal"] 28 | 29 | def extract_label(fname): 30 | return fname.split('__')[-1].split('.')[-3] 31 | 32 | def multi_label(fnames): 33 | labeled = [] 34 | with open('./labels.csv', 'r') as f: 35 | label_table = f.readlines() 36 | label_table = [s.replace('\n', '') for s in label_table] 37 | label_dict = {l:i for i, l in enumerate(label_table)} 38 | 39 | for f in fnames: 40 | labeled.append(label_dict[extract_label(f)]) 41 | return np.array(labeled), label_table 42 | 43 | class EarDataset(Dataset): 44 | def __init__(self, binary_dir, alloc_label, transforms=None): 45 | """ 46 | binary_dir: directory where binary files (.npy files) exist 47 | allocate_label: a function to allocate labels 48 | transforms: ex. 
ToTensor 49 | load all file names 50 | allocate their classes 51 | """ 52 | if not isinstance(binary_dir,str): 53 | self.fnames = [] 54 | for curr_dir in binary_dir: 55 | self.fnames += glob.glob(os.path.join(curr_dir, "*")) 56 | else: 57 | self.fnames = glob.glob(os.path.join(binary_dir, "*")) 58 | self.labels, self.class_names = alloc_label(self.fnames) 59 | assert len(self.fnames) == len(self.labels), "Wrong labels" 60 | self.transforms = transforms 61 | 62 | def __len__(self): 63 | return len(self.fnames) 64 | 65 | def __getitem__(self, idx): 66 | img = np.load(self.fnames[idx]).astype(np.float16) 67 | label = self.labels[idx] 68 | sample = (img, label) 69 | if self.transforms: 70 | try: 71 | sample = self.transforms(sample) 72 | except: 73 | for trs in self.transforms: 74 | sample = trs(sample) 75 | return sample 76 | 77 | class Rescale: 78 | def __init__(self, output_size): 79 | self.output_size = output_size 80 | 81 | def __call__(self, sample): 82 | rescaled = resize(sample[0], self.output_size, mode='constant') 83 | return (rescaled, sample[1]) 84 | 85 | class ToTensor: 86 | def __call__(self, sample): 87 | image, label = sample 88 | # swap color axis because 89 | # numpy image: H x W x C 90 | # torch image: C X H X W 91 | image = image.transpose((2, 0, 1)) 92 | return (torch.FloatTensor(image), label) 93 | class Normalize: 94 | def __call__(self, sample): 95 | image, label = sample 96 | image[:, 0] = (image[:, 0]-0.485)/0.229 97 | image[:, 1] = (image[:, 1]-0.456)/0.224 98 | image[:, 2] = (image[:, 2]-0.406)/0.225 99 | return (image, label) 100 | 101 | 102 | def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, class_names, device, num_epochs=25): 103 | since = time.time() 104 | best_model_wts = copy.deepcopy(model.state_dict()) 105 | best_acc = 0.0 106 | 107 | history = {'train_loss':[], 'train_acc':[], 'val_loss':[], 'val_acc':[]} 108 | 109 | for epoch in range(num_epochs): 110 | print('Epoch {}/{}'.format(epoch, num_epochs 
- 1)) 111 | print('-' * 10) 112 | 113 | # Each epoch has a training and validation phase 114 | for phase in ['train', 'val']: 115 | if phase == 'train': 116 | if scheduler: 117 | scheduler.step() 118 | model.train() # Set model to training mode 119 | else: 120 | model.eval() # Set model to evaluate mode 121 | 122 | running_loss = 0.0 123 | running_corrects = 0 124 | 125 | # Iterate over data. 126 | for inputs, labels in dataloaders[phase]: 127 | inputs = inputs.to(device) 128 | labels = labels.to(device) 129 | 130 | # zero the parameter gradients 131 | optimizer.zero_grad() 132 | 133 | # forward 134 | # track history if only in train 135 | with torch.set_grad_enabled(phase == 'train'): 136 | outputs = model(inputs) 137 | if isinstance(outputs, tuple): 138 | outputs = outputs[0] 139 | _, preds = torch.max(outputs, 1) 140 | loss = criterion(outputs, labels) 141 | 142 | # backward + optimize only if in training phase 143 | if phase == 'train': 144 | loss.backward() 145 | optimizer.step() 146 | 147 | # statistics 148 | running_loss += loss.item() * inputs.size(0) 149 | running_corrects += torch.sum(preds == labels.data) 150 | 151 | epoch_loss = running_loss / dataset_sizes[phase] 152 | epoch_acc = running_corrects.double() / dataset_sizes[phase] 153 | 154 | print('{} Loss: {:.4f} Acc: {:.4f}'.format( 155 | phase, epoch_loss, epoch_acc)) 156 | history["%s_loss"%(phase)].append(epoch_loss) 157 | history["%s_acc"%(phase)].append(epoch_acc) 158 | # deep copy the model 159 | if phase == 'val' and epoch_acc > best_acc: 160 | best_acc = epoch_acc 161 | best_model_wts = copy.deepcopy(model.state_dict()) 162 | print() 163 | 164 | time_elapsed = time.time() - since 165 | print('Training complete in {:.0f}m {:.0f}s'.format( 166 | time_elapsed // 60, time_elapsed % 60)) 167 | print('Best val Acc: {:4f}'.format(best_acc)) 168 | 169 | # load best model weights 170 | model.load_state_dict(best_model_wts) 171 | return model, history, best_acc 172 | 173 | def save_history(fname, 
history): 174 | nb_epochs = len(history['train_loss']) 175 | with open(fname, 'w+') as f: 176 | f.write('epoch train_loss train_acc val_loss val_acc\n') 177 | for i in range(nb_epochs): 178 | f.write('%d %.4f %.4f %.4f %.4f\n'%(i, history['train_loss'][i], 179 | history['train_acc'][i], 180 | history['val_loss'][i], 181 | history['val_acc'][i])) 182 | 183 | 184 | 185 | def train(architecture, output_dim, k, src, alloc_label, num_labels=2, lr=1e-3, betas=(0.9, 0.999), weight_decay=0, nb_epochs=25, batch_size=32, start_fold=0, end_fold=None): 186 | """ 187 | k: "k"-fold 188 | src: k src lists 189 | alloc_label: fct to alloc labels 190 | define a dataset and the loader 191 | load a densenet pretrained using ImageNet 192 | train the network 193 | save the model 194 | """ 195 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 196 | for curr_fold in range(start_fold, end_fold): 197 | train_src = [] 198 | for i in range(k): 199 | if i != curr_fold: 200 | train_src.append(src[i]) 201 | test_src = src[curr_fold] 202 | if architecture == "inception_v3": 203 | shape = (299, 299) 204 | output_dim = 2048 205 | else: 206 | shape = (256, 256) 207 | if architecture == 'resnet50' or architecture == 'resnet101' or architecture == 'resnet152': 208 | output_dim = 8192 209 | elif architecture == 'resnet18' or architecture == 'resnet34': 210 | output_dim = 2048 211 | train_dataset = EarDataset(binary_dir=train_src, 212 | alloc_label=alloc_label, 213 | transforms=transforms.Compose([Rescale(shape), ToTensor(), Normalize()])) 214 | test_dataset = EarDataset(binary_dir=test_src, 215 | alloc_label = alloc_label, 216 | transforms=transforms.Compose([Rescale(shape), ToTensor(), Normalize()])) 217 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4) 218 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4) 219 | dataloaders = {'train':train_loader, 'val':test_loader} 220 | dataset_sizes = 
{'train':len(train_dataset), 'val':len(test_dataset)} 221 | model_name = "resnet18" 222 | network = models.resnet18(pretrained=True).to(device) 223 | _global = {"network":network, "models":models, "device":device, "model_name":model_name} 224 | exec("network = models.%s(pretrained=True).to(device)\nmodel_name=\'%s\'"%(architecture, architecture),_global) 225 | network = _global['network'] 226 | model_name = _global['model_name'] 227 | print(model_name,"is successfully loaded") 228 | #num_ftrs = network.fc.in_features 229 | #network.fc = nn.Linear(num_ftrs, num_labels).cuda() 230 | network.fc = nn.Linear(output_dim, num_labels).to(device) 231 | class_names = train_dataset.class_names 232 | criterion = torch.nn.CrossEntropyLoss().to(device) 233 | optimizer = torch.optim.Adam(network.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) 234 | trained_model, curr_history, curr_best = train_model(network, criterion, optimizer, None, dataloaders, dataset_sizes, class_names, device, num_epochs=nb_epochs) 235 | save_history("%s_%.4facc_%dth_fold_lr-%.5f_beta1-%.2f_beta2-%.3f.csv"%(architecture, curr_best, curr_fold, lr, betas[0], betas[1]), curr_history) 236 | torch.save(trained_model, "%s_%.4facc_%dth-fold_lr-%.5f_beta1-%.2f_beta2-%.3f.pt"%(architecture, curr_best, curr_fold, lr, betas[0], betas[1])) 237 | 238 | -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import png 3 | import pydicom 4 | import numpy as np 5 | import sys 6 | 7 | def mri_to_png(mri_file, png_file): 8 | """ Function to convert from a DICOM image to png 9 | 10 | @param mri_file: An opened file like object to read te dicom data 11 | @param png_file: An opened file like object to write the png data 12 | """ 13 | 14 | # Extracting data from the mri file 15 | plan = pydicom.read_file(mri_file) 16 | shape = plan.pixel_array.shape 17 | 18 | #Convert to 
float to avoid overflow or underflow losses. 19 | image_2d = plan.pixel_array.astype(float) 20 | 21 | # Rescaling grey scale between 0-255 22 | image_2d_scaled = (np.maximum(image_2d,0) / image_2d.max()) 23 | 24 | np.save(png_file, image_2d_scaled.astype(np.float16)) 25 | 26 | 27 | def convert_file(mri_file_path, png_file_path): 28 | """ Function to convert an MRI binary file to a 29 | PNG image file. 30 | 31 | @param mri_file_path: Full path to the mri file 32 | @param png_file_path: Fill path to the png file 33 | """ 34 | 35 | # Making sure that the mri file exists 36 | if not os.path.exists(mri_file_path): 37 | raise Exception('File "%s" does not exists' % mri_file_path) 38 | 39 | # Making sure the png file does not exist 40 | if os.path.exists(png_file_path): 41 | raise Exception('File "%s" already exists' % png_file_path) 42 | 43 | mri_file = open(mri_file_path, 'rb') 44 | png_file = open(png_file_path, 'wb') 45 | 46 | mri_to_png(mri_file, png_file) 47 | 48 | png_file.close() 49 | 50 | 51 | def convert_folder(mri_folder, png_folder): 52 | """ Convert all MRI files in a folder to png files 53 | in a destination folder 54 | """ 55 | 56 | # Create the folder for the pnd directory structure 57 | os.makedirs(png_folder) 58 | 59 | # Recursively traverse all sub-folders in the path 60 | for mri_sub_folder, subdirs, files in os.walk(mri_folder): 61 | for mri_file in os.listdir(mri_sub_folder): 62 | mri_file_path = os.path.join(mri_sub_folder, mri_file) 63 | 64 | # Make sure path is an actual file 65 | if os.path.isfile(mri_file_path): 66 | 67 | # Replicate the original file structure 68 | rel_path = os.path.relpath(mri_sub_folder, mri_folder) 69 | png_folder_path = os.path.join(png_folder, rel_path) 70 | if not os.path.exists(png_folder_path): 71 | os.makedirs(png_folder_path) 72 | png_file_path = os.path.join(png_folder_path, '%s.npy' % mri_file) 73 | 74 | try: 75 | # Convert the actual file 76 | convert_file(mri_file_path, png_file_path) 77 | print('SUCCESS: %s --> 
%s' % (mri_file_path, png_file_path)) 78 | except Exception as e: 79 | print('FAIL: %s --> %s : %s' % (mri_file_path, png_file_path, e)) 80 | 81 | convert_folder(sys.argv[1], sys.argv[2]) 82 | -------------------------------------------------------------------------------- /split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import shutil 4 | import sys 5 | import os 6 | 7 | def split_and_save(root_dir, k): 8 | fnames = glob.glob(os.path.join(root_dir, "*")) 9 | np.random.shuffle(fnames) 10 | idx = np.array_split(np.arange(len(fnames)), k) 11 | splitted = [[fnames[j] for j in idx[i]] for i in range(k)] 12 | for i, arr in enumerate(splitted): 13 | dest = os.path.join(root_dir, "..", "%s-%d"%(root_dir.split('/')[-1], i)) 14 | try: 15 | os.mkdir(dest) 16 | except: 17 | pass 18 | for fname in arr: 19 | shutil.copy(fname, os.path.join(dest, fname.split('/')[-1])) 20 | 21 | split_and_save(sys.argv[1], int(sys.argv[2])) 22 | -------------------------------------------------------------------------------- /tuning.sh: -------------------------------------------------------------------------------- 1 | for lr in 0.0001 0.0002 0.0005 0.002 0.001 0.005 2 | do 3 | for beta1 in 0.5 0.6 0.7 0.8 0.9 4 | do 5 | for beta2 in 0.999 0.99 0.9 0.995 6 | do 7 | python main.py --src ../ear-binary --lr $lr --beta_1 $beta1 --beta_2 $beta2 --nb_epochs 10 --start_fold 0 --end_fold 1 8 | done 9 | done 10 | done 11 | --------------------------------------------------------------------------------