├── README.md
├── eval.py
├── labels.csv
├── main.py
├── model.py
├── preprocessing.py
├── split.py
└── tuning.sh
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-dicom-classification
2 | PyTorch framework to classify DICOM (.dcm) files
3 |
4 |
5 | ### Dependencies
6 | ```
7 | python 3.6.4
8 | pytorch 0.4.0
9 | torchvision 0.2.1
10 | numpy 1.14.1
11 | pydicom 1.0.2
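12 | scikit-image 0.13.1   (eval.py also needs scikit-learn, scipy, pandas, matplotlib and seaborn; preprocessing.py imports pypng)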
13 | ```
14 |
15 |
16 |
17 | ### Usage
18 | ##### CAUTION: You must define your own labeling function in model.py
19 |
20 |
21 | #### preprocess dataset
22 | ```
23 | python preprocessing.py /path/to/src/dir/ /path/to/dest/dir/
24 | ```
25 |
26 |
27 | #### split dataset for k-fold validation
28 | ```
29 | python split.py /path/to/src/dir k
30 | ```
31 |
32 |
33 | #### train the model (k-fold cross-validation over /path/to/src/dir-0 ... dir-(k-1))
34 | ```
35 | python main.py --architecture resnet152 --output_dim 8192 --num_labels 17 --k 5 --src /path/to/src/dir
36 | ```
37 |
38 |
39 | #### evaluation
40 | ```
41 | python eval.py --ckpt /path/to/checkpoint.pt --data_dir /path/to/src/dir/ --multilabel True --batch_size 64 --labels labels.csv
42 | ```
43 |
44 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import confusion_matrix, roc_curve, auc
4 | import argparse
5 | from model import *
6 | from matplotlib import pyplot as plt
7 | import itertools
8 | from itertools import cycle
9 | import pandas as pd
10 | import seaborn as sns
11 | from scipy import interp
12 |
def str2bool(s):
    """Convert a command-line string to a bool, for use as an argparse ``type``.

    Accepts ``"True"``/``"False"`` (the spelling used in the README) and,
    case-insensitively, ``true/false``, ``t/f``, ``yes/no``, ``y/n`` and ``1/0``.

    Raises:
        argparse.ArgumentTypeError: if ``s`` is not a recognised boolean
            spelling; argparse reports this as a normal usage error instead
            of a traceback.
    """
    value = s.strip().lower()
    if value in ("true", "t", "yes", "y", "1"):
        return True
    if value in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (True/False), got {!r}".format(s))
20 |
def get_output(model, loader, with_prob=True):
    """Run ``model`` over every batch of ``loader`` and collect its predictions.

    The model is switched to eval mode and run under ``torch.no_grad()`` so
    no autograd graph is built during inference.

    Args:
        model: a ``torch.nn.Module`` returning class scores with classes on
            dim 1, or a tuple whose first element is such scores.
        loader: iterable of ``(inputs, labels)`` tensor batches (e.g. a DataLoader).
        with_prob: if True, also collect softmax probabilities.

    Returns:
        ``(y_pred, y_true, y_prob)`` numpy arrays concatenated over all
        batches; ``y_prob`` is None when ``with_prob`` is False.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    y_pred, y_true = [], []
    y_prob = [] if with_prob else None
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            # same convention as train_model: keep the main logits of tuple outputs
            if isinstance(outputs, tuple):
                outputs = outputs[0]
            _, preds = torch.max(outputs, 1)
            y_pred.append(preds.cpu().numpy())
            y_true.append(labels.cpu().numpy())
            if with_prob:
                probs = torch.nn.functional.softmax(outputs, dim=1)
                y_prob.append(probs.cpu().numpy())
    if not y_pred:
        raise ValueError("loader yielded no batches")
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)
    if with_prob:
        y_prob = np.concatenate(y_prob)
    return y_pred, y_true, y_prob
46 |
def print_roc_curve(y_test, y_score, n_classes, figsize=(8, 6), plot_micro=False):
    """Plot the macro-average ROC curve (and optionally the micro-average).

    Args:
        y_test: one-hot ground truth, shape ``(n_samples, n_columns)`` with
            ``n_columns >= n_classes``.
        y_score: per-class scores with the same shape as ``y_test``.
        n_classes: number of leading columns treated as classes for the
            per-class / macro-average curves.
        figsize: matplotlib figure size.
        plot_micro: also draw the micro-average curve. This is off by default,
            matching the earlier behaviour in which that plot was disabled.

    Returns:
        The matplotlib ``Figure``.
    """
    lw = 2
    # Per-class ROC curves and areas
    fpr, tpr, roc_auc = {}, {}, {}
    for i in range(n_classes):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    # Micro-average: pool every (sample, class) decision into one curve
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    # Macro-average: interpolate each class curve on the union of FPR points
    # and average. np.interp replaces scipy.interp, a deprecated alias of it
    # that newer SciPy releases no longer provide.
    all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
    mean_tpr = np.zeros_like(all_fpr)
    for i in range(n_classes):
        mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
    mean_tpr /= n_classes

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    fig = plt.figure(figsize=figsize)
    if plot_micro:
        plt.plot(fpr["micro"], tpr["micro"],
                 label='micro-average ROC curve (area = {0:0.2f})'.format(roc_auc["micro"]),
                 color='deeppink', linestyle=':', linewidth=4)
    plt.plot(fpr["macro"], tpr["macro"],
             label='macro-average ROC curve (area = {0:0.2f})'.format(roc_auc["macro"]),
             color='navy', linestyle=':', linewidth=4)
    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend(loc="lower right")
    return fig
91 |
92 |
def print_confusion_matrix(confusion_matrix, class_names, figsize = (10,7), fontsize=14):
    """Draw ``confusion_matrix`` as an annotated seaborn heatmap.

    Args:
        confusion_matrix: square array of integer counts; rows are plotted as
            the true label and columns as the predicted label.
        class_names: tick labels for both axes, one per row/column.
        figsize: matplotlib figure size.
        fontsize: tick-label font size.

    Returns:
        The matplotlib ``Figure``.

    Raises:
        ValueError: if the counts are not integers (cells are annotated with
            ``fmt="d"``).
    """
    df_cm = pd.DataFrame(confusion_matrix, index=class_names, columns=class_names)
    fig = plt.figure(figsize=figsize)
    try:
        heatmap = sns.heatmap(df_cm, annot=True, fmt="d")
    except ValueError as err:
        # don't leak the half-built figure, and keep seaborn's error as the cause
        plt.close(fig)
        raise ValueError("Confusion matrix values must be integers.") from err
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=fontsize)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=fontsize)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    return fig
107 |
108 |
def _read_class_names(path):
    """Return the class names listed one per line in ``path``, dropping trailing blank lines."""
    with open(path, 'r') as f:
        names = [line.rstrip('\n') for line in f]
    while names and not names[-1]:
        names.pop()
    return names


def main(args):
    """Evaluate a saved checkpoint on ``args.data_dir`` and save two plots.

    Writes a confusion matrix (``args.cfmatrix_name``) and a macro-average
    ROC curve (``args.roc_name``) into ``args.result_dir``, creating it if
    needed. Class names come from ``args.labels`` in multilabel mode, and are
    ``['Normal', 'Abnormal']`` otherwise.
    """
    import os

    # train() saves the whole model with torch.save(model, ...); map_location
    # lets a checkpoint saved on a GPU be evaluated on a CPU-only machine.
    model = torch.load(args.ckpt, map_location=None if torch.cuda.is_available() else 'cpu')
    model.eval()
    alloc_label = multi_label if args.multilabel else binary_label
    test_dataset = EarDataset(binary_dir=args.data_dir,
                              alloc_label=alloc_label,
                              transforms=transforms.Compose([Rescale((256, 256)), ToTensor(), Normalize()]))
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)
    y_pred, y_true, y_score = get_output(model, test_loader)
    print(y_pred.shape, y_true.shape, y_score.shape)

    # save the confusion matrix (the labels file is only needed in multilabel mode)
    labels = _read_class_names(args.labels) if args.multilabel else ['Normal', 'Abnormal']
    os.makedirs(args.result_dir, exist_ok=True)
    cnf_matrix = confusion_matrix(y_true, y_pred, labels=np.arange(len(labels)))
    np.set_printoptions(precision=2)
    fig = print_confusion_matrix(cnf_matrix, labels, figsize=(16, 14), fontsize=10)
    fig.savefig(os.path.join(args.result_dir, args.cfmatrix_name))
    plt.close(fig)

    # save the roc curve; classes that never occur in y_true have no positive
    # samples (undefined ROC), so their columns are dropped
    y_onehot = np.zeros((y_true.shape[0], len(labels)), dtype=np.uint8)
    y_onehot[np.arange(y_true.shape[0]), y_true] = 1
    present = y_onehot.sum(axis=0) > 0
    for i in np.flatnonzero(~present):
        print('useless column {}'.format(i))
    useful_cols = np.flatnonzero(present)
    # Slice in both modes so the column count always equals n_classes below
    # (previously binary mode passed unsliced arrays with a reduced n_classes).
    y_onehot = y_onehot[:, useful_cols]
    y_score = y_score[:, useful_cols]
    fig = print_roc_curve(y_onehot, y_score, useful_cols.shape[0], figsize=(8, 6))
    fig.savefig(os.path.join(args.result_dir, args.roc_name))
    plt.close(fig)
151 |
152 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="evaluation")

    # --ckpt and --data_dir have no usable default; fail early with a usage message
    parser.add_argument('--ckpt', type=str, required=True, help='path to checkpoint')
    parser.add_argument('--data_dir', type=str, required=True, help='path to the dataset')
    parser.add_argument('--result_dir', default='results', type=str, help='path in which we save the result')
    parser.add_argument('--cfmatrix_name', default='confusion_matrix', type=str, help='fname of confusion matrix')
    parser.add_argument('--roc_name', default='roc_curve', type=str, help='fname of roc curve')
    parser.add_argument('--multilabel', default=True, type=str2bool, help='if multilabel, then true, else false')
    parser.add_argument('--batch_size', default=16, type=int, help='batch size')
    parser.add_argument('--labels', default='labels.csv', type=str, help='fname including labels')

    args = parser.parse_args()

    main(args)
169 |
--------------------------------------------------------------------------------
/labels.csv:
--------------------------------------------------------------------------------
1 | OME
2 | AdOM
3 | AR
4 | Otit_cerum
5 | PSC
6 | Normal_HPerf
7 | GT
8 | AOM
9 | Normal_Typical
10 | Otit_typical
11 | Normal_Tymp
12 | Cerum
13 | TP
14 | Otit_Typical
15 | Myri
16 | EAC
17 | CC
18 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from model import *
2 | import argparse
3 |
def main(argv=None):
    """Parse command-line options and run k-fold training via ``train``.

    Fold ``i`` is read from the directory ``<src>-<i>``, which is the layout
    split.py creates. A trailing path separator on ``--src`` is ignored.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.
    """
    import os

    parser = argparse.ArgumentParser(description="neural network framework for dicom datasets")

    parser.add_argument('--architecture', default='resnet152', type=str, help='a NN architecture supported by torchvision e.g. resnet152')
    parser.add_argument('--output_dim', default=8192, type=int, help='the final hidden layer\'s dim')
    parser.add_argument('--num_labels', default=2, type=int, help="# of labels")
    parser.add_argument('--k', default=5, type=int, help="\'k\'-fold")
    parser.add_argument('--src', type=str, required=True, help="all directories must be src-0, src-1, ..., src-(k-1)")
    parser.add_argument('--lr', default=1e-3, type=float, help="learning rate")
    parser.add_argument('--beta_1', default=0.9, type=float, help="first beta value")
    parser.add_argument('--beta_2', default=0.999, type=float, help="second beta value")
    parser.add_argument('--weight_decay', default=.0, type=float, help="weight decay")
    parser.add_argument('--nb_epochs', default=25, type=int, help="# of epochs")
    parser.add_argument('--batch_size', default=32, type=int, help="batch size")
    parser.add_argument('--start_fold', default=0, type=int, help="start fold")
    parser.add_argument('--end_fold', default=0, type=int, help="end fold, if it is 0, it will be interpreted as using full k fold")

    args = parser.parse_args(argv)

    k = args.k
    start_fold = args.start_fold
    # end_fold == 0 means "run every fold"
    end_fold = args.end_fold if args.end_fold != 0 else k
    if k < 2:
        parser.error("--k must be at least 2")
    if not 0 <= start_fold < end_fold <= k:
        parser.error("folds must satisfy 0 <= start_fold < end_fold <= k "
                     "(got start_fold=%d, end_fold=%d, k=%d)" % (start_fold, end_fold, k))

    # normpath drops a trailing separator, so "dir/" maps to "dir-0", not "dir/-0"
    src_root = os.path.normpath(args.src)
    src = ["%s-%d" % (src_root, i) for i in range(k)]
    alloc_label = binary_label if args.num_labels == 2 else multi_label
    train(args.architecture, args.output_dim, k, src, alloc_label, args.num_labels,
          args.lr, (args.beta_1, args.beta_2), args.weight_decay, args.nb_epochs,
          args.batch_size, start_fold, end_fold)


if __name__ == "__main__":
    main()
44 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import glob
4 | import torch
5 | import torchvision
6 | import torchvision.models as models
7 | from torch.utils.data import Dataset, DataLoader
8 | import torch.nn as nn
9 | from torchvision import transforms, utils
10 | from skimage.transform import resize
11 | import collections
12 | import time
13 | import copy
14 |
def binary_label(fnames):
    """Assign each file name a binary class index.

    Any name containing the substring ``'Normal'`` anywhere in the string is
    class 0 ("Normal"); every other name is class 1 ("Abnormal").

    Returns:
        ``(labels, class_names)``: a numpy array with one label per name,
        and the list ``["Normal", "Abnormal"]``.
    """
    class_names = ["Normal", "Abnormal"]
    labels = np.array([0 if 'Normal' in name else 1 for name in fnames])
    return labels, class_names
28 |
def extract_label(fname):
    """Return the class label encoded in a preprocessed file name.

    Takes the final ``'__'``-separated segment of ``fname`` and returns its
    third dot-separated field from the end. For example,
    ``"id__OME.dcm.npy"`` gives ``"OME"``.
    """
    tail = fname.split('__')[-1]
    fields = tail.split('.')
    return fields[-3]
31 |
def multi_label(fnames, labels_path='./labels.csv'):
    """Map each file name to the index of its label in a label table.

    The table is read from ``labels_path`` (one label per line; a label's
    index is its line number). Each file's label is obtained with
    ``extract_label``.

    Args:
        fnames: iterable of file names.
        labels_path: path of the label table. The default keeps the previous
            hard-coded, working-directory-relative ``./labels.csv``.

    Returns:
        ``(labels, label_table)``: a numpy array of indices, one per name,
        and the list of label strings.

    Raises:
        KeyError: if a file's label is not listed in the table.
    """
    with open(labels_path, 'r') as f:
        label_table = [line.replace('\n', '') for line in f]
    label_dict = {name: idx for idx, name in enumerate(label_table)}

    labeled = []
    for fname in fnames:
        label = extract_label(fname)
        try:
            labeled.append(label_dict[label])
        except KeyError:
            raise KeyError("label {!r} extracted from {!r} is not listed in {}".format(
                label, fname, labels_path)) from None
    return np.array(labeled), label_table
42 |
class EarDataset(Dataset):
    """Dataset of preprocessed image arrays (``.npy`` files) and their labels."""

    def __init__(self, binary_dir, alloc_label, transforms=None):
        """
        binary_dir: a directory (str) or an iterable of directories holding
            the binary (.npy) files; every regular file directly inside is used
        alloc_label: function taking the list of file names and returning
            (labels, class_names), e.g. binary_label or multi_label
        transforms: optional callable applied to each (image, label) sample
            (e.g. a transforms.Compose), or an iterable of such callables
            applied in order
        """
        dirs = [binary_dir] if isinstance(binary_dir, str) else binary_dir
        self.fnames = []
        for curr_dir in dirs:
            # glob order is filesystem-dependent, so sort for a reproducible
            # order; skip sub-directories, which np.load could not read anyway
            entries = sorted(glob.glob(os.path.join(curr_dir, "*")))
            self.fnames += [f for f in entries if os.path.isfile(f)]
        self.labels, self.class_names = alloc_label(self.fnames)
        # explicit check rather than assert, which is stripped under `python -O`
        if len(self.fnames) != len(self.labels):
            raise ValueError("Wrong labels")
        self.transforms = transforms

    def __len__(self):
        return len(self.fnames)

    def __getitem__(self, idx):
        """Load file ``idx`` as float16 and return the (optionally transformed) ``(image, label)``."""
        img = np.load(self.fnames[idx]).astype(np.float16)
        sample = (img, self.labels[idx])
        if self.transforms:
            # A callable (e.g. Compose) is applied directly; anything else is
            # treated as a sequence of callables. The former bare `except:`
            # fallback hid real errors raised inside a callable transform.
            if callable(self.transforms):
                sample = self.transforms(sample)
            else:
                for trs in self.transforms:
                    sample = trs(sample)
        return sample
76 |
class Rescale:
    """Resize the image of an ``(image, label)`` sample; the label passes through unchanged."""

    def __init__(self, output_size):
        # target shape handed straight to skimage.transform.resize
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample[0], sample[1]
        resized = resize(image, self.output_size, mode='constant')
        return resized, label
84 |
class ToTensor:
    """Turn an ``(H x W x C numpy image, label)`` sample into ``(C x H x W FloatTensor, label)``."""

    def __call__(self, sample):
        img, lbl = sample
        # numpy images are channel-last; torch expects channel-first
        channels_first = img.transpose(2, 0, 1)
        return (torch.FloatTensor(channels_first), lbl)
class Normalize:
    """Normalize the channels of a ``C x H x W`` image tensor in place.

    ``ToTensor`` (applied before this in every pipeline in this repo) yields
    channel-first tensors, so channel ``c`` is ``image[c]``. The defaults are
    the standard ImageNet statistics used by torchvision's pretrained models.

    NOTE: the previous version indexed ``image[:, c]``, i.e. row ``c`` of every
    channel, instead of channel ``c``. Checkpoints trained with that version
    saw differently normalized inputs.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        if len(mean) != len(std):
            raise ValueError("mean and std must have the same length")
        self.mean = tuple(mean)
        self.std = tuple(std)

    def __call__(self, sample):
        image, label = sample
        for channel, (m, s) in enumerate(zip(self.mean, self.std)):
            image[channel] = (image[channel] - m) / s
        return (image, label)
100 |
101 |
def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, class_names, device, num_epochs=25):
    """Train ``model`` with a train and a val phase per epoch; keep the best val weights.

    Args:
        model: the ``nn.Module`` to train (already on ``device``).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler (falsy to disable). It is stepped
            once per epoch after the training phase, i.e. after
            ``optimizer.step()``, which is the order PyTorch expects.
        dataloaders: dict with ``'train'`` and ``'val'`` iterables of
            ``(inputs, labels)`` batches.
        dataset_sizes: dict with the number of samples per phase, used to
            average loss and accuracy.
        class_names: unused; kept for interface compatibility.
        device: device the batches are moved to.
        num_epochs: number of epochs.

    Returns:
        ``(model, history, best_acc)``, where:
        - ``model`` has the weights of the best validation epoch loaded;
        - ``history`` maps ``train_loss``, ``train_acc``, ``val_loss`` and
          ``val_acc`` to per-epoch Python floats;
        - ``best_acc`` is the best validation accuracy as a float.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode for 'train', eval mode for 'val'

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # track gradients only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    # some models (e.g. inception_v3 in train mode) return a
                    # tuple; the first element holds the main logits
                    if isinstance(outputs, tuple):
                        outputs = outputs[0]
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # accumulate plain Python numbers so history holds floats, not device tensors
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train and scheduler:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            history["%s_loss" % phase].append(epoch_loss)
            history["%s_acc" % phase].append(epoch_acc)

            # keep a copy of the best validation weights
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))

    model.load_state_dict(best_model_wts)
    return model, history, best_acc
172 |
def save_history(fname, history):
    """Write per-epoch training history to ``fname`` as a space-separated table.

    One row is written per entry of ``history['train_loss']``; the other
    three lists must be at least that long.
    """
    columns = ('train_loss', 'train_acc', 'val_loss', 'val_acc')
    n_rows = len(history['train_loss'])
    with open(fname, 'w+') as out:
        out.write('epoch train_loss train_acc val_loss val_acc\n')
        for epoch in range(n_rows):
            row = (epoch,) + tuple(history[col][epoch] for col in columns)
            out.write('%d %.4f %.4f %.4f %.4f\n' % row)
def _input_config(architecture, output_dim):
    """Return ``(input_shape, fc_in_features)`` for ``architecture``.

    ``output_dim`` is used only for architectures not listed here.
    NOTE(review): 8192 for resnet50/101/152 at 256x256 input presumably relies
    on torchvision versions whose ResNets end in a fixed 7x7 average pool (the
    README pins torchvision 0.2.1). Confirm before upgrading torchvision.
    """
    if architecture == "inception_v3":
        return (299, 299), 2048
    if architecture in ('resnet50', 'resnet101', 'resnet152'):
        return (256, 256), 8192
    if architecture in ('resnet18', 'resnet34'):
        return (256, 256), 2048
    return (256, 256), output_dim


def _load_pretrained(architecture, device):
    """Build ``torchvision.models.<architecture>(pretrained=True)`` on ``device``.

    This replaces an ``exec`` of a string built from ``architecture``, which
    would run arbitrary code passed on the command line.

    Raises:
        ValueError: if ``architecture`` is not a public callable of torchvision.models.
    """
    factory = None if architecture.startswith('_') else getattr(models, architecture, None)
    if not callable(factory):
        raise ValueError("unknown torchvision architecture: %r" % (architecture,))
    return factory(pretrained=True).to(device)


def train(architecture, output_dim, k, src, alloc_label, num_labels=2, lr=1e-3, betas=(0.9, 0.999), weight_decay=0, nb_epochs=25, batch_size=32, start_fold=0, end_fold=None):
    """Train a pretrained torchvision model on each fold of a k-fold split.

    Folds run for ``curr_fold`` in ``[start_fold, end_fold)``. Each fold's
    model is trained on ``src[i]`` for every ``i < k`` with ``i != curr_fold``
    and validated on ``src[curr_fold]``. The best-validation model
    (``torch.save``) and its per-epoch history (``save_history``) are written
    to the working directory. Their file names encode the architecture, best
    accuracy, fold and Adam settings.

    Args:
        architecture: name of a torchvision.models constructor, e.g. "resnet152".
        output_dim: input size of the new ``fc`` layer; overridden for
            inception_v3 and the ResNets (see ``_input_config``).
        k: number of folds.
        src: list of (at least) ``k`` fold directories.
        alloc_label: labeling function passed to ``EarDataset``.
        num_labels: number of output classes.
        lr, betas, weight_decay: Adam settings.
        nb_epochs: epochs per fold.
        batch_size: DataLoader batch size.
        start_fold: first fold to run.
        end_fold: one past the last fold to run; ``None`` means ``k``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if end_fold is None:
        end_fold = k
    shape, fc_in_features = _input_config(architecture, output_dim)

    for curr_fold in range(start_fold, end_fold):
        train_src = [src[i] for i in range(k) if i != curr_fold]
        test_src = src[curr_fold]
        train_dataset = EarDataset(binary_dir=train_src,
                                   alloc_label=alloc_label,
                                   transforms=transforms.Compose([Rescale(shape), ToTensor(), Normalize()]))
        test_dataset = EarDataset(binary_dir=test_src,
                                  alloc_label=alloc_label,
                                  transforms=transforms.Compose([Rescale(shape), ToTensor(), Normalize()]))
        dataloaders = {
            'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4),
            'val': DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4),
        }
        dataset_sizes = {'train': len(train_dataset), 'val': len(test_dataset)}

        network = _load_pretrained(architecture, device)
        print(architecture, "is successfully loaded")
        # NOTE(review): assumes the classifier head is `network.fc` (as in
        # ResNet/Inception); models with a differently named head need their own swap.
        network.fc = nn.Linear(fc_in_features, num_labels).to(device)

        criterion = torch.nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.Adam(network.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        trained_model, curr_history, curr_best = train_model(
            network, criterion, optimizer, None, dataloaders, dataset_sizes,
            train_dataset.class_names, device, num_epochs=nb_epochs)
        save_history("%s_%.4facc_%dth_fold_lr-%.5f_beta1-%.2f_beta2-%.3f.csv"%(architecture, curr_best, curr_fold, lr, betas[0], betas[1]), curr_history)
        torch.save(trained_model, "%s_%.4facc_%dth-fold_lr-%.5f_beta1-%.2f_beta2-%.3f.pt"%(architecture, curr_best, curr_fold, lr, betas[0], betas[1]))
237 |
238 |
--------------------------------------------------------------------------------
/preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import png
3 | import pydicom
4 | import numpy as np
5 | import sys
6 |
def mri_to_png(mri_file, png_file):
    """Convert a DICOM image to a float16 ``.npy`` array scaled to [0, 1].

    Despite the name, no PNG is produced. The pixel data is clipped at 0,
    divided by its maximum and written with ``np.save``.

    @param mri_file: An opened file-like object to read the DICOM data from
    @param png_file: An opened file-like object to write the .npy data to
    """
    # dcmread is the current pydicom reader; read_file is the older,
    # deprecated alias kept as a fallback for old installs
    reader = getattr(pydicom, "dcmread", None) or pydicom.read_file
    plan = reader(mri_file)

    # Convert to float to avoid overflow or underflow losses.
    image = plan.pixel_array.astype(float)

    # Rescale to [0, 1]. An all-zero (or all-negative) image has no positive
    # peak; leave it at zero instead of dividing by zero and writing NaNs.
    clipped = np.maximum(image, 0)
    peak = clipped.max()
    scaled = clipped / peak if peak > 0 else clipped

    np.save(png_file, scaled.astype(np.float16))
25 |
26 |
def convert_file(mri_file_path, png_file_path):
    """Convert one DICOM file on disk to a ``.npy`` file using ``mri_to_png``.

    @param mri_file_path: Full path to the DICOM file
    @param png_file_path: Full path of the output file; it must not exist yet
    @raise Exception: if the input is missing or the output already exists.
        Any error raised during conversion is re-raised after the partially
        written output is removed, so a later run can retry the file.
    """
    # Making sure that the mri file exists
    if not os.path.exists(mri_file_path):
        raise Exception('File "%s" does not exists' % mri_file_path)

    # Making sure the output file does not exist
    if os.path.exists(png_file_path):
        raise Exception('File "%s" already exists' % png_file_path)

    try:
        # both files are closed even if the conversion fails
        with open(mri_file_path, 'rb') as mri_file, open(png_file_path, 'wb') as png_file:
            mri_to_png(mri_file, png_file)
    except BaseException:
        # a truncated output would make every re-run fail with "already exists"
        if os.path.exists(png_file_path):
            os.remove(png_file_path)
        raise
49 |
50 |
def convert_folder(mri_folder, png_folder):
    """Convert every file under ``mri_folder`` (recursively) into ``png_folder``.

    The sub-folder structure is mirrored, and each output is named
    ``<original file name>.npy``. Per-file failures, including an output that
    already exists, are printed as ``FAIL`` and skipped. This lets the
    function be re-run on a partially converted destination.
    """
    # exist_ok lets a previous, partially completed run be resumed
    os.makedirs(png_folder, exist_ok=True)

    # Recursively traverse all sub-folders in the path
    for mri_sub_folder, _subdirs, files in os.walk(mri_folder):
        # Replicate the original file structure
        rel_path = os.path.relpath(mri_sub_folder, mri_folder)
        png_folder_path = os.path.join(png_folder, rel_path)

        for mri_file in files:
            mri_file_path = os.path.join(mri_sub_folder, mri_file)
            # skip broken symlinks and other non-regular entries
            if not os.path.isfile(mri_file_path):
                continue

            os.makedirs(png_folder_path, exist_ok=True)
            png_file_path = os.path.join(png_folder_path, '%s.npy' % mri_file)

            try:
                convert_file(mri_file_path, png_file_path)
                print('SUCCESS: %s --> %s' % (mri_file_path, png_file_path))
            except Exception as e:
                print('FAIL: %s --> %s : %s' % (mri_file_path, png_file_path, e))
80 |
if __name__ == "__main__":
    # guarded so that importing this module does not start a conversion
    if len(sys.argv) != 3:
        sys.exit("usage: python preprocessing.py /path/to/src/dir/ /path/to/dest/dir/")
    convert_folder(sys.argv[1], sys.argv[2])
82 |
--------------------------------------------------------------------------------
/split.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import shutil
4 | import sys
5 | import os
6 |
def split_and_save(root_dir, k, seed=None):
    """Randomly split the files in ``root_dir`` into ``k`` folds and copy them.

    Fold ``i`` is copied into the sibling directory ``<basename>-<i>``: for
    example, ``.../data`` gives ``.../data-0`` through ``.../data-<k-1>``,
    which is the layout main.py reads via ``--src``. A trailing path
    separator on ``root_dir`` is ignored.

    Args:
        root_dir: directory whose regular files are split. The scan is not
            recursive, and sub-directories are skipped.
        k: number of folds. ``np.array_split`` makes the fold sizes differ by
            at most one.
        seed: optional seed for a reproducible split. ``None`` uses NumPy's
            global random state, as before.
    """
    # normpath drops a trailing separator; otherwise the basename would be ''
    root_dir = os.path.normpath(root_dir)
    base = os.path.basename(root_dir)

    # sort first so that a given seed always yields the same split
    fnames = sorted(f for f in glob.glob(os.path.join(root_dir, "*")) if os.path.isfile(f))
    rng = np.random if seed is None else np.random.RandomState(seed)
    rng.shuffle(fnames)

    for i, fold_idx in enumerate(np.array_split(np.arange(len(fnames)), k)):
        dest = os.path.join(root_dir, "..", "%s-%d" % (base, i))
        os.makedirs(dest, exist_ok=True)
        for j in fold_idx:
            fname = fnames[j]
            shutil.copy(fname, os.path.join(dest, os.path.basename(fname)))
20 |
if __name__ == "__main__":
    # guarded so that importing this module does not copy any files
    if len(sys.argv) != 3:
        sys.exit("usage: python split.py /path/to/src/dir k")
    split_and_save(sys.argv[1], int(sys.argv[2]))
22 |
--------------------------------------------------------------------------------
/tuning.sh:
--------------------------------------------------------------------------------
#!/bin/sh
# Grid search over Adam hyper-parameters. Every (lr, beta1, beta2)
# combination trains fold 0 only (--start_fold 0 --end_fold 1) for 10 epochs.
LEARNING_RATES="0.0001 0.0002 0.0005 0.002 0.001 0.005"
BETA1_VALUES="0.5 0.6 0.7 0.8 0.9"
BETA2_VALUES="0.999 0.99 0.9 0.995"

for lr in $LEARNING_RATES; do
    for beta1 in $BETA1_VALUES; do
        for beta2 in $BETA2_VALUES; do
            python main.py --src ../ear-binary --lr "$lr" --beta_1 "$beta1" --beta_2 "$beta2" --nb_epochs 10 --start_fold 0 --end_fold 1
        done
    done
done
11 |
--------------------------------------------------------------------------------