├── README.md
├── download.sh
├── environment.yml
├── evaluate.py
├── loader.py
├── model.py
└── train.py

/README.md:
--------------------------------------------------------------------------------
# MRNet

Dataset from Clinical Hospital Centre Rijeka, Croatia, originally appears in:

I. Štajduhar, M. Mamula, D. Miletić, G. Unal, Semi-automated detection of anterior cruciate ligament injury from MRI, Computer Methods and Programs in Biomedicine, Volume 140, 2017, Pages 151–164. (http://www.riteh.uniri.hr/~istajduh/projects/kneeMRI/data/Stajduhar2017.pdf)

## Setup

`bash download.sh` (caution: downloads ~6.68 GB of data)

`conda env create -f environment.yml`

`source activate mrnet`

## Train

`python train.py --rundir [experiment-name] --diagnosis 0 --gpu`

- `--diagnosis` is the highest diagnosis score still treated as a negative label (0 = injury task, 1 = tear task)
- arguments are saved at `[experiment-name]/args.json`
- training & validation metrics (loss & AUC) are printed after each epoch
- models are saved at `[experiment-name]/val[val_loss]_train[train_loss]_epoch[epoch_num]`

## Evaluate

`python evaluate.py --split [train/valid/test] --diagnosis 0 --model_path [experiment-name]/val[val_loss]_train[train_loss]_epoch[epoch_num] --gpu`

- prints loss & AUC
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
#!/bin/bash
set -e  # abort on the first failed download or extraction

# metadata.csv holds the labels and volume filenames
wget http://www.riteh.uniri.hr/~istajduh/projects/kneeMRI/data/metadata.csv

# fetch and unpack the ten volumetric data archives
for i in '01' '02' '03' '04' '05' '06' '07' '08' '09' '10'
do
    wget http://www.riteh.uniri.hr/~istajduh/projects/kneeMRI/data/volumetric_data/vol$i.7z
    mkdir -p "vol$i"
    mv "vol$i.7z" "vol$i"
    cd "vol$i"
    7za e "vol$i.7z"
    rm "vol$i.7z"
    cd ..
done
--------------------------------------------------------------------------------
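A quick illustration of the `--diagnosis` threshold may help. The sketch below is not part of the repo; it mirrors `loader.py`, which reads column 2 of `metadata.csv` as the diagnosis score, column 10 as the volume filename, and marks a scan positive when its score exceeds the threshold:

```python
# Sketch (not in the repo): show how --diagnosis turns raw scores in
# metadata.csv into binary labels, mirroring the thresholding in loader.py.
import csv
from collections import Counter

def label_counts(diagnosis):
    with open('metadata.csv') as f:
        rows = list(csv.reader(f))[1:]  # skip the header row, as loader.py does
    # column 2 holds the diagnosis score; positive when score > threshold
    return Counter(int(int(row[2]) > diagnosis) for row in rows)

print(label_counts(0))  # injury task: score 0 is negative, higher scores positive
print(label_counts(1))  # tear task: scores 0-1 negative, higher scores positive
```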
/environment.yml:
--------------------------------------------------------------------------------
name: mrnet
channels:
  - pytorch
  - defaults
dependencies:
  - blas=1.0=mkl
  - ca-certificates=2018.03.07=0
  - certifi=2018.8.24=py36_1
  - cffi=1.11.5=py36he75722e_1
  - freetype=2.9.1=h8a8886c_1
  - intel-openmp=2018.0.3=0
  - jpeg=9b=h024ee3a_2
  - libedit=3.1.20170329=h6b74fdf_2
  - libffi=3.2.1=hd88cf55_4
  - libgcc-ng=8.2.0=hdf63c60_1
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libpng=1.6.34=hb9fc6fc_0
  - libstdcxx-ng=8.2.0=hdf63c60_1
  - libtiff=4.0.9=he85c1e1_2
  - mkl=2018.0.3=1
  - mkl_fft=1.0.4=py36h4414c95_1
  - mkl_random=1.0.1=py36h4414c95_1
  - ncurses=6.1=hf484d3e_0
  - ninja=1.8.2=py36h6bb024c_1
  - numpy=1.15.1=py36h1d66e8a_0
  - numpy-base=1.15.1=py36h81de0dd_0
  - olefile=0.45.1=py36_0
  - openssl=1.0.2p=h14c3975_0
  - pillow=5.2.0=py36heded4f4_0
  - pip=10.0.1=py36_0
  - pycparser=2.18=py36_1
  - python=3.6.6=hc3d631a_0
  - readline=7.0=h7b6447c_5
  - scikit-learn=0.19.1=py36hedc7406_0
  - scipy=1.1.0=py36hfa4b5c9_1
  - setuptools=40.2.0=py36_0
  - six=1.11.0=py36_1
  - sqlite=3.24.0=h84994c4_0
  - tk=8.6.8=hbc83047_0
  - wheel=0.31.1=py36_0
  - xz=5.2.4=h14c3975_4
  - zlib=1.2.11=ha838bed_2
  - pytorch=0.4.1=py36_py35_py27__9.0.176_7.1.2_2
  - torchvision=0.2.1=py36_1
  - pip:
    - torch==0.4.1.post2
--------------------------------------------------------------------------------
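The environment pins PyTorch 0.4.1 with CUDA 9.0 builds. A quick sanity check (my addition, not in the repo) before training confirms the pinned versions installed and that the GPU is visible when using `--gpu`:

```python
# Sanity check (not in the repo): verify the pinned builds from environment.yml.
import torch
import torchvision

print(torch.__version__)          # expected: 0.4.1 per environment.yml
print(torchvision.__version__)    # expected: 0.2.1
print(torch.cuda.is_available())  # must be True for the --gpu flag to work
```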
/evaluate.py:
--------------------------------------------------------------------------------
import argparse
import torch

from sklearn import metrics
from torch.autograd import Variable

from loader import load_data
from model import MRNet

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--split', type=str, required=True)
    parser.add_argument('--diagnosis', type=int, required=True)
    parser.add_argument('--gpu', action='store_true')
    return parser

def run_model(model, loader, train=False, optimizer=None):
    # shared by train.py (train=True) and evaluate.py (train=False)
    preds = []
    labels = []

    if train:
        model.train()
    else:
        model.eval()

    total_loss = 0.
    num_batches = 0

    for batch in loader:
        if train:
            optimizer.zero_grad()

        vol, label = batch
        if loader.dataset.use_gpu:
            vol = vol.cuda()
            label = label.cuda()
        # Variable is a no-op wrapper on the pinned PyTorch 0.4.1
        vol = Variable(vol)
        label = Variable(label)

        logit = model(vol)

        loss = loader.dataset.weighted_loss(logit, label)
        total_loss += loss.item()

        pred = torch.sigmoid(logit)
        # batch size is fixed at 1, so [0][0] extracts the scalar
        pred_npy = pred.data.cpu().numpy()[0][0]
        label_npy = label.data.cpu().numpy()[0][0]

        preds.append(pred_npy)
        labels.append(label_npy)

        if train:
            loss.backward()
            optimizer.step()
        num_batches += 1

    avg_loss = total_loss / num_batches

    fpr, tpr, _ = metrics.roc_curve(labels, preds)
    auc = metrics.auc(fpr, tpr)

    return avg_loss, auc, preds, labels

def evaluate(split, model_path, diagnosis, use_gpu):
    train_loader, valid_loader, test_loader = load_data(diagnosis, use_gpu)

    model = MRNet()
    state_dict = torch.load(model_path, map_location=(None if use_gpu else 'cpu'))
    model.load_state_dict(state_dict)

    if use_gpu:
        model = model.cuda()

    if split == 'train':
        loader = train_loader
    elif split == 'valid':
        loader = valid_loader
    elif split == 'test':
        loader = test_loader
    else:
        raise ValueError("split must be 'train', 'valid', or 'test'")

    loss, auc, preds, labels = run_model(model, loader)

    print(f'{split} loss: {loss:0.4f}')
    print(f'{split} AUC: {auc:0.4f}')

    return preds, labels

if __name__ == '__main__':
    args = get_parser().parse_args()
    evaluate(args.split, args.model_path, args.diagnosis, args.gpu)
--------------------------------------------------------------------------------
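`evaluate()` returns the per-scan predictions and labels, so plotting an ROC curve takes only a few extra lines. A sketch (my addition; the repo itself only prints loss and AUC, and the checkpoint path below is hypothetical):

```python
# Sketch (not in the repo): plot an ROC curve from evaluate()'s outputs.
import matplotlib.pyplot as plt
from sklearn import metrics

from evaluate import evaluate

# hypothetical checkpoint path, following train.py's naming scheme
preds, labels = evaluate('valid', 'run1/val0.3000_train0.2500_epoch10', 0, False)

fpr, tpr, _ = metrics.roc_curve(labels, preds)
plt.plot(fpr, tpr, label=f'AUC = {metrics.auc(fpr, tpr):0.4f}')
plt.plot([0, 1], [0, 1], linestyle='--')  # chance line
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.legend()
plt.savefig('roc_valid.png')
```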
/loader.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import pickle
import torch
import torch.nn.functional as F
import torch.utils.data as data

from torch.autograd import Variable

INPUT_DIM = 224
MAX_PIXEL_VAL = 255
MEAN = 58.09
STDDEV = 49.73

class Dataset(data.Dataset):
    def __init__(self, datadirs, diagnosis, use_gpu):
        super().__init__()
        self.use_gpu = use_gpu

        label_dict = {}
        self.paths = []

        # metadata.csv: column 2 holds the diagnosis score, column 10 the
        # volume filename; a scan is positive when its score exceeds the
        # --diagnosis threshold
        with open('metadata.csv') as f:
            for i, line in enumerate(f):
                if i == 0:
                    continue  # skip the header row
                fields = line.strip().split(',')
                path = fields[10]
                label = fields[2]
                label_dict[path] = int(int(label) > diagnosis)

        for datadir in datadirs:
            for filename in os.listdir(datadir):
                self.paths.append(datadir + '/' + filename)

        # path[6:] drops the 'volXX/' prefix to match the filenames in metadata.csv
        self.labels = [label_dict[path[6:]] for path in self.paths]

        # weight each class by the prevalence of the other, so the rarer
        # class contributes more per example
        neg_weight = np.mean(self.labels)
        self.weights = [neg_weight, 1 - neg_weight]

    def weighted_loss(self, prediction, target):
        weights_npy = np.array([self.weights[int(t[0])] for t in target.data])
        weights_tensor = torch.FloatTensor(weights_npy)
        if self.use_gpu:
            weights_tensor = weights_tensor.cuda()
        loss = F.binary_cross_entropy_with_logits(prediction, target, weight=Variable(weights_tensor))
        return loss

    def __getitem__(self, index):
        path = self.paths[index]
        with open(path, 'rb') as file_handler:  # must use 'rb' as the data is binary
            vol = pickle.load(file_handler).astype(np.int32)

        # crop the middle INPUT_DIM x INPUT_DIM region of each slice
        pad = int((vol.shape[2] - INPUT_DIM)/2)
        if pad > 0:  # guard: pad can be 0 when a slice is already INPUT_DIM wide
            vol = vol[:, pad:-pad, pad:-pad]

        # standardize to [0, MAX_PIXEL_VAL]
        vol = (vol - np.min(vol)) / (np.max(vol) - np.min(vol)) * MAX_PIXEL_VAL

        # normalize
        vol = (vol - MEAN) / STDDEV

        # replicate to 3 channels so slices match AlexNet's RGB input
        vol = np.stack((vol,)*3, axis=1)

        vol_tensor = torch.FloatTensor(vol)
        label_tensor = torch.FloatTensor([self.labels[index]])

        return vol_tensor, label_tensor

    def __len__(self):
        return len(self.paths)

def load_data(diagnosis, use_gpu=False):
    train_dirs = ['vol08','vol04','vol03','vol09','vol06','vol07']
    valid_dirs = ['vol10','vol05']
    test_dirs = ['vol01','vol02']

    train_dataset = Dataset(train_dirs, diagnosis, use_gpu)
    valid_dataset = Dataset(valid_dirs, diagnosis, use_gpu)
    test_dataset = Dataset(test_dirs, diagnosis, use_gpu)

    train_loader = data.DataLoader(train_dataset, batch_size=1, num_workers=8, shuffle=True)
    valid_loader = data.DataLoader(valid_dataset, batch_size=1, num_workers=8, shuffle=False)
    test_loader = data.DataLoader(test_dataset, batch_size=1, num_workers=8, shuffle=False)

    return train_loader, valid_loader, test_loader
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from torchvision import models

class MRNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.alexnet(pretrained=True)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, 1)

    def forward(self, x):
        x = torch.squeeze(x, dim=0)           # only batch size 1 supported: (1, s, 3, 224, 224) -> (s, 3, 224, 224)
        x = self.model.features(x)            # AlexNet features per slice: (s, 256, 6, 6)
        x = self.gap(x).view(x.size(0), -1)   # global average pool: (s, 256)
        x = torch.max(x, 0, keepdim=True)[0]  # max over slices: (1, 256)
        x = self.classifier(x)                # one logit per scan: (1, 1)
        return x
--------------------------------------------------------------------------------
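The forward pass is easiest to follow with a dummy volume. A sketch (my addition, not in the repo) tracing the shapes through MRNet, assuming a scan with 20 slices:

```python
# Sketch (not in the repo): trace MRNet's shape flow with a dummy scan.
import torch

from model import MRNet

model = MRNet()
model.eval()

# (batch=1, slices=20, RGB, H, W), the layout loader.py produces
vol = torch.randn(1, 20, 3, 224, 224)

with torch.no_grad():
    logit = model(vol)

print(logit.shape)                  # torch.Size([1, 1]): one logit per scan
print(torch.sigmoid(logit).item())  # probability of a positive diagnosis
```

Max-pooling over the per-slice features is what lets the model accept scans with any number of slices while still producing a single fixed-size vector for the classifier.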
/train.py:
--------------------------------------------------------------------------------
import argparse
import json
import numpy as np
import os
import torch

from datetime import datetime
from pathlib import Path

from evaluate import run_model
from loader import load_data
from model import MRNet

def train(rundir, diagnosis, epochs, learning_rate, weight_decay, max_patience, factor, use_gpu):
    train_loader, valid_loader, test_loader = load_data(diagnosis, use_gpu)

    model = MRNet()

    if use_gpu:
        model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=max_patience, factor=factor, threshold=1e-4)

    best_val_loss = float('inf')

    start_time = datetime.now()

    for epoch in range(epochs):
        change = datetime.now() - start_time
        print('starting epoch {}. time passed: {}'.format(epoch+1, str(change)))

        train_loss, train_auc, _, _ = run_model(model, train_loader, train=True, optimizer=optimizer)
        print(f'train loss: {train_loss:0.4f}')
        print(f'train AUC: {train_auc:0.4f}')

        val_loss, val_auc, _, _ = run_model(model, valid_loader)
        print(f'valid loss: {val_loss:0.4f}')
        print(f'valid AUC: {val_auc:0.4f}')

        scheduler.step(val_loss)

        # save a checkpoint whenever validation loss improves
        if val_loss < best_val_loss:
            best_val_loss = val_loss

            file_name = f'val{val_loss:0.4f}_train{train_loss:0.4f}_epoch{epoch+1}'
            save_path = Path(rundir) / file_name
            torch.save(model.state_dict(), save_path)

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--rundir', type=str, required=True)
    parser.add_argument('--diagnosis', type=int, required=True)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--learning_rate', default=1e-05, type=float)
    parser.add_argument('--weight_decay', default=0.01, type=float)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--max_patience', default=5, type=int)
    parser.add_argument('--factor', default=0.3, type=float)
    return parser

if __name__ == '__main__':
    args = get_parser().parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu:
        torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.rundir, exist_ok=True)

    with open(Path(args.rundir) / 'args.json', 'w') as out:
        json.dump(vars(args), out, indent=4)

    # pass all parsed hyperparameters through (the original hard-coded
    # weight decay, patience, and factor, ignoring the CLI flags)
    train(args.rundir, args.diagnosis, args.epochs, args.learning_rate,
          args.weight_decay, args.max_patience, args.factor, args.gpu)
--------------------------------------------------------------------------------
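One piece of the training pipeline worth a worked example is the class weighting in `loader.py`: negatives are weighted by the positive prevalence and positives by the negative prevalence, so both classes contribute equally to the loss in expectation. A sketch (my addition) with made-up numbers:

```python
# Worked example (not in the repo) of loader.py's class weighting.
import numpy as np

labels = [1, 0, 0, 1, 0, 0, 0, 1, 0, 0]  # hypothetical labels, 30% positive
neg_weight = np.mean(labels)              # 0.3, applied to negative examples
weights = [neg_weight, 1 - neg_weight]    # [0.3, 0.7]

# the expected total weight per class balances out:
# 7 negatives * 0.3 = 2.1 and 3 positives * 0.7 = 2.1
print(weights)
```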