├── .style.yapf ├── .gitattributes ├── scripts ├── NASNetMobile.sh ├── MobileNet.sh └── DenseNet169.sh ├── models ├── MobileNet.hdf5 ├── DenseNet169.hdf5 └── NASNetMobile.hdf5 ├── .gitignore ├── LICENSE ├── metrics.py ├── README.md ├── eval.py ├── download_and_convert_mura.py ├── mura.py ├── pytorch ├── dataloader.py └── train.py └── train.py /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | indent_width = 4 4 | column_limit = 120 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.psd filter=lfs diff=lfs merge=lfs -text 2 | *.hdf5 filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /scripts/NASNetMobile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --epochs 120 \ 4 | --resume 'models/86_NASNetMobile_2018-01-04T15:15:37.hdf5' -------------------------------------------------------------------------------- /models/MobileNet.hdf5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:229a1c4d0414337aa21f848fffb6ad9f7acb5dd96443053d613486c264e4d356 3 | size 38935688 4 | -------------------------------------------------------------------------------- /models/DenseNet169.hdf5: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a42d7a17f412646634ad128135403b38fdbab533647ce6df28669e5ff24d8fad 3 | size 152693520 4 | -------------------------------------------------------------------------------- /models/NASNetMobile.hdf5: -------------------------------------------------------------------------------- 1 | version 
https://git-lfs.github.com/spec/v1 2 | oid sha256:fe5f54a66ce1db8afc138cbdaf44d09a41e5006e7ac06b704bc9d64c95f1ef1c 3 | size 55106408 4 | -------------------------------------------------------------------------------- /scripts/MobileNet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --epochs 5 \ 4 | --b 32 5 | 6 | python train.py \ 7 | --b 32 \ 8 | --resume ./models/MobileNet.hdf5 \ 9 | --lr 1e-4 10 | 11 | python train.py \ 12 | --b 32 \ 13 | --resume ./models/MobileNet.hdf5 \ 14 | --lr 1e-5 15 | 16 | python train.py \ 17 | --b 32 \ 18 | --resume ./models/MobileNet.hdf5 \ 19 | --lr 1e-6 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | */*.pyc 2 | */**/*.pyc 3 | */**/**/*.pyc 4 | */**/**/**/*.pyc 5 | */**/**/**/**/*.pyc 6 | */*.so* 7 | */**/*.so* 8 | */**/*.dylib* 9 | 10 | # IPython notebook checkpoints 11 | .ipynb_checkpoints 12 | 13 | # Editor temporaries 14 | *.swn 15 | *.swo 16 | *.swp 17 | *~ 18 | 19 | # macOS dir files 20 | .DS_Store 21 | 22 | MURA-v1.0/ 23 | runs/ 24 | checkpoint.pth.tar 25 | *.pth.tar 26 | data/ 27 | *.h5 28 | logs/ 29 | .fuse_* -------------------------------------------------------------------------------- /scripts/DenseNet169.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python train.py \ 3 | --epochs 10 \ 4 | --b 8 5 | 6 | python train.py \ 7 | --epochs 15 \ 8 | --b 8 \ 9 | --resume ./models/DenseNet169.hdf5 \ 10 | --lr 1e-4 11 | 12 | python train.py \ 13 | --epochs 30 \ 14 | --b 8 \ 15 | --resume ./models/DenseNet169.hdf5 \ 16 | --lr 1e-5 17 | 18 | python train.py \ 19 | --epochs 60 \ 20 | --b 8 \ 21 | --resume ./models/DenseNet169.hdf5 \ 22 | --lr 1e-6 23 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Bobby D DeSimone 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import keras 5 | import numpy as np 6 | from sklearn.metrics import * 7 | 8 | 9 | class SKLearnMetrics(keras.callbacks.Callback): 10 | """ SKLearnMetrics computes various classification metrics at the end of a batch. 
11 | Unforunately, doesn't work when used with generators....""" 12 | 13 | def on_train_begin(self, logs={}): 14 | self.confusion = [] 15 | self.precision = [] 16 | self.recall = [] 17 | self.f1s = [] 18 | self.kappa = [] 19 | self.auc = [] 20 | 21 | def on_epoch_end(self, epoch, logs={}): 22 | score = np.asarray(self.model.predict(self.validation_data[0])) 23 | predict = np.round(np.asarray(self.model.predict(self.validation_data[0]))) 24 | target = self.validation_data[1] 25 | 26 | self.auc.append(roc_auc_score(target, score)) 27 | self.confusion.append(confusion_matrix(target, predict)) 28 | self.precision.append(precision_score(target, predict)) 29 | self.recall.append(recall_score(target, predict)) 30 | self.f1s.append(f1_score(target, predict)) 31 | self.kappa.append(cohen_kappa_score(target, predict)) 32 | return 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Musculoskeletal Radiographs Abnormality Classifier 2 | 3 | ## Experiments 4 | 5 | | Network | Accuracy (encounter) | Precision (encounter) | Recall (encounter) | F1 (encounter) | Kappa (encounter) | 6 | | ---------------------- | -------------------- | --------------------- | ------------------ | -------------- | ----------------- | 7 | | DenseNet169 (baseline) | .83 (.84) | .82 (.82) | .87 (.90) | .84 (.86) | .65 (.65) | 8 | | MobileNet | .81 (.83) | .80 (.82) | .85 (.89) | .82 (.85) | .62 (.62) | 9 | | NASNetMobile | .82 (.83) | .78 (.80) | .89 (.92) | .83 (.86) | .63 (.63) | 10 | 11 | Also, ResNet50 in pytorch which achieved equivalent results. 12 | 13 | ## The [Mura](https://arxiv.org/abs/1712.06957) Dataset 14 | 15 | ```latex 16 | @misc{1712.06957, 17 | Author = {Pranav Rajpurkar and Jeremy Irvin and Aarti Bagul and Daisy Ding and Tony Duan and Hershel Mehta and Brandon Yang and Kaylie Zhu and Dillon Laird and Robyn L. 
Ball and Curtis Langlotz and Katie Shpanskaya and Matthew P. Lungren and Andrew Ng}, 18 | Title = {MURA Dataset: Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs}, 19 | Year = {2017}, 20 | Eprint = {arXiv:1712.06957},} 21 | ``` 22 | 23 | | Study | Normal | Abnormal | Total | 24 | | --------: | :-------- | :-------- | ---------: | 25 | | Elbow | 1,203 | 768 | 1,971 | 26 | | Finger | 1,389 | 753 | 2,142 | 27 | | Forearm | 677 | 380 | 1,057 | 28 | | Hand | 1,613 | 602 | 2,215 | 29 | | Humerus | 411 | 367 | 778 | 30 | | Shoulder | 1,479 | 1,594 | 3,073 | 31 | | Wrist | 2,295 | 1,451 | 3,746 | 32 | | **Total** | **9,067** | **5,915** | **14,982** | 33 | 34 | - Each study contains 1-N views (images) 35 | - 40,895 multi-view radiographic images 36 | 37 | ### Their results (DenseNet169) 38 | 39 | | | Radiologists (95% CI) | Model (95% CI) | 40 | | ---------------: | :----------------------- | :----------------------- | 41 | | Elbow | **0.858** (0.707, 0.959) | 0.848 (0.691, 0.955) | 42 | | Finger | 0.781 (0.638, 0.871) | **0.792** (0.588, 0.933) | 43 | | Forearm | **0.899** (0.804, 0.960) | 0.814 (0.633, 0.942) | 44 | | Hand | 0.854 (0.676, 0.958) | **0.858** (0.658, 0.978) | 45 | | Humerus | **0.895** (0.774, 0.976) | 0.862 (0.709, 0.968) | 46 | | Shoulder | **0.925** (0.811, 0.989) | 0.857 (0.667, 0.974) | 47 | | Wrist | 0.958 (0.908, 0.988) | **0.968** (0.889, 1.000) | 48 | | **Aggregate F1** | **0.884** (0.843, 0.918) | 0.859 (0.804, 0.905) | 49 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | from os import environ, getcwd 4 | from os.path import join 5 | 6 | import keras 7 | import numpy as np 8 | import pandas as pd 9 | import sklearn as skl 10 | import tensorflow as tf 11 | from keras.applications import NASNetMobile 12 | from 
keras.layers import Dense, GlobalAveragePooling2D 13 | from keras.metrics import binary_accuracy, binary_crossentropy 14 | from keras.models import Model 15 | from keras.optimizers import Adam 16 | from keras.preprocessing.image import ImageDataGenerator 17 | 18 | from mura import Mura 19 | 20 | pd.set_option('display.max_rows', 20) 21 | pd.set_option('precision', 4) 22 | np.set_printoptions(precision=4) 23 | 24 | environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Shut up tensorflow! 25 | print("tf : {}".format(tf.__version__)) 26 | print("keras : {}".format(keras.__version__)) 27 | print("numpy : {}".format(np.__version__)) 28 | print("pandas : {}".format(pd.__version__)) 29 | print("sklearn : {}".format(skl.__version__)) 30 | 31 | # Hyper-parameters / Globals 32 | BATCH_SIZE = 512 # tweak to your GPUs capacity 33 | IMG_HEIGHT = 224 # ResNetInceptionv2 & Xception like 299, ResNet50/VGG/Inception 224, NASM 331 34 | IMG_WIDTH = IMG_HEIGHT 35 | CHANNELS = 3 36 | DIMS = (IMG_HEIGHT, IMG_WIDTH, CHANNELS) # blame theano 37 | MODEL_TO_EVAL = './models/NASNetMobile.hdf5' 38 | DATA_DIR = 'MURA-v1.0/' 39 | EVAL_CSV = 'valid.csv' 40 | EVAL_DIR = 'data/val' 41 | 42 | # load up our csv with validation factors 43 | data_dir = join(getcwd(), DATA_DIR) 44 | eval_csv = join(data_dir, EVAL_CSV) 45 | df = pd.read_csv(eval_csv, names=['img', 'label'], header=None) 46 | eval_imgs = df.img.values.tolist() 47 | eval_labels = df.label.values.tolist() 48 | 49 | eval_datagen = ImageDataGenerator(rescale=1. 
/ 255) 50 | eval_generator = eval_datagen.flow_from_directory( 51 | EVAL_DIR, class_mode='binary', shuffle=False, target_size=(IMG_HEIGHT, IMG_WIDTH), batch_size=BATCH_SIZE) 52 | n_samples = eval_generator.samples 53 | base_model = NASNetMobile(input_shape=DIMS, weights='imagenet', include_top=False) 54 | x = base_model.output 55 | x = GlobalAveragePooling2D(name='avg_pool')(x) # comment for RESNET 56 | x = Dense(1, activation='sigmoid', name='predictions')(x) 57 | model = Model(inputs=base_model.input, outputs=x) 58 | model.load_weights(MODEL_TO_EVAL) 59 | model.compile(optimizer=Adam(lr=1e-3), loss=binary_crossentropy, metrics=['binary_accuracy']) 60 | score, acc = model.evaluate_generator(eval_generator, n_samples / BATCH_SIZE) 61 | print(model.metrics_names) 62 | print('==> Metrics with eval') 63 | print("loss :{:0.4f} \t Accuracy:{:0.4f}".format(score, acc)) 64 | y_pred = model.predict_generator(eval_generator, n_samples / BATCH_SIZE) 65 | mura = Mura(eval_generator.filenames, y_true=eval_generator.classes, y_pred=y_pred) 66 | print('==> Metrics with predict') 67 | print(mura.metrics()) 68 | print(mura.metrics_by_encounter()) 69 | # print(mura.metrics_by_study_type()) 70 | -------------------------------------------------------------------------------- /download_and_convert_mura.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import re 5 | import os 6 | from os import getcwd 7 | from os.path import exists, isdir, isfile, join 8 | import shutil 9 | import numpy as np 10 | import pandas as pd 11 | 12 | 13 | class ImageString(object): 14 | _patient_re = re.compile(r'patient(\d+)') 15 | _study_re = re.compile(r'study(\d+)') 16 | _image_re = re.compile(r'image(\d+)') 17 | _study_type_re = re.compile(r'XR_(\w+)') 18 | 19 | def __init__(self, img_filename): 20 | self.img_filename = img_filename 21 | self.patient = 
self._parse_patient() 22 | self.study = self._parse_study() 23 | self.image_num = self._parse_image() 24 | self.study_type = self._parse_study_type() 25 | self.image = self._parse_image() 26 | self.normal = self._parse_normal() 27 | 28 | def flat_file_name(self): 29 | return "{}_{}_patient{}_study{}_image{}.png".format(self.normal, self.study_type, self.patient, self.study, 30 | self.image, self.normal) 31 | 32 | def _parse_patient(self): 33 | return int(self._patient_re.search(self.img_filename).group(1)) 34 | 35 | def _parse_study(self): 36 | return int(self._study_re.search(self.img_filename).group(1)) 37 | 38 | def _parse_image(self): 39 | return int(self._image_re.search(self.img_filename).group(1)) 40 | 41 | def _parse_study_type(self): 42 | return self._study_type_re.search(self.img_filename).group(1) 43 | 44 | def _parse_normal(self): 45 | return "normal" if ("negative" in self.img_filename) else "abnormal" 46 | 47 | 48 | # processed 49 | # data 50 | # ├── train 51 | # │   ├── abnormal 52 | # │   └── normal 53 | # └── val 54 | # ├── abnormal 55 | # └── normal 56 | proc_data_dir = join(getcwd(), 'data') 57 | proc_train_dir = join(proc_data_dir, 'train') 58 | proc_val_dir = join(proc_data_dir, 'val') 59 | 60 | # Data loading code 61 | orig_data_dir = join(getcwd(), 'MURA-v1.0') 62 | train_dir = join(orig_data_dir, 'train') 63 | train_csv = join(orig_data_dir, 'train.csv') 64 | val_dir = join(orig_data_dir, 'valid') 65 | val_csv = join(orig_data_dir, 'valid.csv') 66 | test_dir = join(orig_data_dir, 'test') 67 | assert isdir(orig_data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir) 68 | assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv) 69 | 70 | df = pd.read_csv(train_csv, names=['img', 'label'], header=None) 71 | # imgs = df.img.values.tolist() 72 | # labels = df.label.values.tolist() 73 | # following datasets/folder.py's weird convention here... 
74 | samples = [tuple(x) for x in df.values] 75 | for img, label in samples: 76 | assert ("negative" in img) is (label is 0) 77 | enc = ImageString(img) 78 | cat_dir = join(proc_train_dir, enc.normal) 79 | if not os.path.exists(cat_dir): 80 | os.mkdir(cat_dir) 81 | shutil.copy2(enc.img_filename, join(cat_dir, enc.flat_file_name())) 82 | 83 | df = pd.read_csv(val_csv, names=['img', 'label'], header=None) 84 | samples = [tuple(x) for x in df.values] 85 | for img, label in samples: 86 | assert ("negative" in img) is (label is 0) 87 | enc = ImageString(img) 88 | cat_dir = join(proc_val_dir, enc.normal) 89 | if not os.path.exists(cat_dir): 90 | os.mkdir(cat_dir) 91 | shutil.copy2(enc.img_filename, join(cat_dir, enc.flat_file_name())) 92 | -------------------------------------------------------------------------------- /mura.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import re 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from sklearn.metrics import (accuracy_score, cohen_kappa_score, f1_score, precision_score, recall_score) 8 | 9 | pd.set_option('display.max_rows', 20) 10 | pd.set_option('precision', 4) 11 | np.set_printoptions(precision=4) 12 | 13 | 14 | class Mura(object): 15 | """`MURA `_ Dataset : 16 | Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs. 
17 | """ 18 | url = "https://cs.stanford.edu/group/mlgroup/mura-v1.0.zip" 19 | filename = "mura-v1.0.zip" 20 | md5_checksum = '4c36feddb7f5698c8bf291b912c438b1' 21 | _patient_re = re.compile(r'patient(\d+)') 22 | _study_re = re.compile(r'study(\d+)') 23 | _image_re = re.compile(r'image(\d+)') 24 | _study_type_re = re.compile(r'_(\w+)_patient') 25 | 26 | def __init__(self, image_file_names, y_true, y_pred=None): 27 | self.imgs = image_file_names 28 | df_img = pd.Series(np.array(image_file_names), name='img') 29 | self.y_true = y_true 30 | df_true = pd.Series(np.array(y_true), name='y_true') 31 | self.y_pred = y_pred 32 | # number of unique classes 33 | self.patient = [] 34 | self.study = [] 35 | self.study_type = [] 36 | self.image_num = [] 37 | self.encounter = [] 38 | for img in image_file_names: 39 | self.patient.append(self._parse_patient(img)) 40 | self.study.append(self._parse_study(img)) 41 | self.image_num.append(self._parse_image(img)) 42 | self.study_type.append(self._parse_study_type(img)) 43 | self.encounter.append("{}_{}_{}".format( 44 | self._parse_study_type(img), 45 | self._parse_patient(img), 46 | self._parse_study(img), )) 47 | 48 | self.classes = np.unique(self.y_true) 49 | df_patient = pd.Series(np.array(self.patient), name='patient') 50 | df_study = pd.Series(np.array(self.study), name='study') 51 | df_image_num = pd.Series(np.array(self.image_num), name='image_num') 52 | df_study_type = pd.Series(np.array(self.study_type), name='study_type') 53 | df_encounter = pd.Series(np.array(self.encounter), name='encounter') 54 | 55 | self.data = pd.concat( 56 | [ 57 | df_img, 58 | df_encounter, 59 | df_true, 60 | df_patient, 61 | df_patient, 62 | df_study, 63 | df_image_num, 64 | df_study_type, 65 | ], axis=1) 66 | 67 | if self.y_pred is not None: 68 | self.y_pred_probability = self.y_pred.flatten() 69 | self.y_pred = self.y_pred_probability.round().astype(int) 70 | df_y_pred = pd.Series(self.y_pred, name='y_pred') 71 | df_y_pred_probability = 
pd.Series(self.y_pred_probability, name='y_pred_probs') 72 | self.data = pd.concat((self.data, df_y_pred, df_y_pred_probability), axis=1) 73 | 74 | def __len__(self): 75 | return len(self.imgs) 76 | 77 | def _parse_patient(self, img_filename): 78 | return int(self._patient_re.search(img_filename).group(1)) 79 | 80 | def _parse_study(self, img_filename): 81 | return int(self._study_re.search(img_filename).group(1)) 82 | 83 | def _parse_image(self, img_filename): 84 | return int(self._image_re.search(img_filename).group(1)) 85 | 86 | def _parse_study_type(self, img_filename): 87 | return self._study_type_re.search(img_filename).group(1) 88 | 89 | def metrics(self): 90 | return "per image metrics:\n\taccuracy : {:.2f}\tf1 : {:.2f}\tprecision : {:.2f}\trecall : {:.2f}\tcohen_kappa : {:.2f}".format( 91 | accuracy_score(self.y_true, self.y_pred), 92 | f1_score(self.y_true, self.y_pred), 93 | precision_score(self.y_true, self.y_pred), 94 | recall_score(self.y_true, self.y_pred), 95 | cohen_kappa_score(self.y_true, self.y_pred), ) 96 | 97 | def metrics_by_encounter(self): 98 | y_pred = self.data.groupby(['encounter'])['y_pred_probs'].mean().round() 99 | y_true = self.data.groupby(['encounter'])['y_true'].mean().round() 100 | return "per encounter metrics:\n\taccuracy : {:.2f}\tf1 : {:.2f}\tprecision : {:.2f}\trecall : {:.2f}\tcohen_kappa : {:.2f}".format( 101 | accuracy_score(y_true, y_pred), 102 | f1_score(y_true, y_pred), 103 | precision_score(y_true, y_pred), 104 | recall_score(y_true, y_pred), 105 | cohen_kappa_score(self.y_true, self.y_pred), ) 106 | 107 | # def metrics_by_study_type(self): 108 | # y_pred = self.data.groupby(['study_type', 'encounter'])['y_pred_probs'].mean().round() 109 | # y_true = self.data.groupby(['study_type', 'encounter'])['y_true'].mean().round() 110 | # return "per study_type metrics:\n\taccuracy : {:.2f}\tf1 : {:.2f}\tprecision : {:.2f}\trecall : {:.2f}".format( 111 | # accuracy_score(y_true, y_pred), 112 | # f1_score(y_true, y_pred), 113 | 
# precision_score(y_true, y_pred), 114 | # recall_score(y_true, y_pred), ) 115 | -------------------------------------------------------------------------------- /pytorch/dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import re 4 | from os import getcwd 5 | from os.path import join 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import torch.utils.data as data 10 | from PIL import Image 11 | 12 | from torchvision.datasets.utils import check_integrity, download_url 13 | 14 | 15 | class MuraDataset(data.Dataset): 16 | """`MURA `_ Dataset : 17 | Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs. 18 | """ 19 | url = "https://cs.stanford.edu/group/mlgroup/mura-v1.0.zip" 20 | filename = "mura-v1.0.zip" 21 | md5_checksum = '4c36feddb7f5698c8bf291b912c438b1' 22 | _patient_re = re.compile(r'patient(\d+)') 23 | _study_re = re.compile(r'study(\d+)') 24 | _image_re = re.compile(r'image(\d+)') 25 | _study_type_re = re.compile(r'XR_(\w+)') 26 | 27 | def __init__(self, csv_f, transform=None, download=False): 28 | self.df = pd.read_csv(csv_f, names=['img', 'label'], header=None) 29 | self.imgs = self.df.img.values.tolist() 30 | self.labels = self.df.label.values.tolist() 31 | # following datasets/folder.py's weird convention here... 
32 | self.samples = [tuple(x) for x in self.df.values] 33 | # number of unique classes 34 | self.classes = np.unique(self.labels) 35 | self.balanced_weights = self.balance_class_weights() 36 | 37 | self.transform = transform 38 | 39 | def __len__(self): 40 | return len(self.imgs) 41 | 42 | def _parse_patient(self, img_filename): 43 | return int(self._patient_re.search(img_filename).group(1)) 44 | 45 | def _parse_study(self, img_filename): 46 | return int(self._study_re.search(img_filename).group(1)) 47 | 48 | def _parse_image(self, img_filename): 49 | return int(self._image_re.search(img_filename).group(1)) 50 | 51 | def _parse_study_type(self, img_filename): 52 | return self._study_type_re.search(img_filename).group(1) 53 | 54 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 55 | """Downloads the `tarball_url` and uncompresses it locally. 56 | Args: 57 | tarball_url: The URL of a tarball file. 58 | dataset_dir: The directory where the temporary files are stored. 59 | """ 60 | filename = tarball_url.split('/')[-1] 61 | filepath = os.path.join(dataset_dir, filename) 62 | 63 | def _progress(count, block_size, total_size): 64 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 65 | float(count * block_size) / float(total_size) * 100.0)) 66 | sys.stdout.flush() 67 | 68 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 69 | print() 70 | statinfo = os.stat(filepath) 71 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 72 | if ".zip" in filename: 73 | print("zipfile:{}".format(filepath)) 74 | with zipfile.ZipFile(filepath, "r") as zip_ref: 75 | zip_ref.extractall(dataset_dir) 76 | else: 77 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 78 | 79 | def balance_class_weights(self): 80 | count = [0] * len(self.classes) 81 | for item in self.samples: 82 | count[item[1]] += 1 83 | weight_per_class = [0.] 
* len(self.classes) 84 | N = float(sum(count)) 85 | for i in range(len(self.classes)): 86 | weight_per_class[i] = N / float(count[i]) 87 | weight = [0] * len(self.samples) 88 | for idx, val in enumerate(self.samples): 89 | weight[idx] = weight_per_class[val[1]] 90 | return weight 91 | 92 | def __getitem__(self, idx): 93 | img_filename = join(self.imgs[idx]) 94 | patient = self._parse_patient(img_filename) 95 | study = self._parse_study(img_filename) 96 | image_num = self._parse_image(img_filename) 97 | study_type = self._parse_study_type(img_filename) 98 | 99 | # todo(bdd) : inconsistent right now, need param for grayscale / RGB 100 | # todo(bdd) : 'L' -> gray, 'RGB' -> Colors 101 | image = Image.open(img_filename).convert('RGB') 102 | label = self.labels[idx] 103 | 104 | if self.transform is not None: 105 | image = self.transform(image) 106 | 107 | meta_data = { 108 | 'y_true': label, 109 | 'img_filename': img_filename, 110 | 'patient': patient, 111 | 'study': study, 112 | 'study_type': study_type, 113 | 'image_num': image_num, 114 | 'encounter': "{}_{}_{}".format(study_type, patient, study) 115 | } 116 | return image, label, meta_data 117 | 118 | 119 | if __name__ == '__main__': 120 | import torchvision.transforms as transforms 121 | import pprint 122 | 123 | data_dir = join(getcwd(), 'MURA-v1.0') 124 | val_csv = join(data_dir, 'valid.csv') 125 | val_loader = data.DataLoader( 126 | MuraDataset(val_csv, 127 | transforms.Compose([ 128 | transforms.Resize(224), 129 | transforms.CenterCrop(224), 130 | transforms.ToTensor(), 131 | ])), 132 | batch_size=1, 133 | shuffle=False, 134 | num_workers=1, 135 | pin_memory=False) 136 | 137 | for i, (image, label, meta_data) in enumerate(val_loader): 138 | pprint.pprint(meta_data.cpu()) 139 | if i == 40: 140 | break 141 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
absolute_import, division, print_function 2 | 3 | import argparse 4 | from datetime import datetime 5 | from os import environ 6 | 7 | import keras 8 | import numpy as np 9 | import pandas as pd 10 | import tensorflow as tf 11 | from keras.applications import MobileNet 12 | from keras.callbacks import (EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard) 13 | from keras.layers import Dense, GlobalAveragePooling2D 14 | from keras.metrics import binary_accuracy, binary_crossentropy 15 | from keras.models import Model 16 | from keras.optimizers import SGD, Adam 17 | from keras.preprocessing.image import ImageDataGenerator 18 | from sklearn.utils import class_weight 19 | 20 | environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Shut up tensorflow! 21 | print("tf : {}".format(tf.__version__)) 22 | print("keras : {}".format(keras.__version__)) 23 | print("numpy : {}".format(np.__version__)) 24 | print("pandas : {}".format(pd.__version__)) 25 | 26 | parser = argparse.ArgumentParser(description='Hyperparameters') 27 | parser.add_argument('--classes', default=1, type=int) 28 | parser.add_argument('--workers', default=4, type=int) 29 | parser.add_argument('--epochs', default=120, type=int) 30 | parser.add_argument('-b', '--batch-size', default=32, type=int, help='mini-batch size') 31 | parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float) 32 | parser.add_argument('--lr-wait', default=10, type=int, help='how long to wait on plateu') 33 | parser.add_argument('--decay', default=1e-4, type=float) 34 | parser.add_argument('--momentum', default=0.9, type=float) 35 | parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint') 36 | parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model') 37 | parser.add_argument('--seed', default=1337, type=int, help='random seed') 38 | parser.add_argument('--img_channels', default=3, type=int) 39 | parser.add_argument('--img_size', default=224, 
type=int) 40 | parser.add_argument('--early_stop', default=20, type=int) 41 | 42 | 43 | def train(): 44 | global args 45 | args = parser.parse_args() 46 | img_shape = (args.img_size, args.img_size, args.img_channels) # blame theano 47 | now_iso = datetime.now().strftime('%Y-%m-%dT%H:%M:%S%z') 48 | 49 | # We then scale the variable-sized images to 224x224 50 | # We augment .. by applying random lateral inversions and rotations. 51 | train_datagen = ImageDataGenerator( 52 | rescale=1. / 255, 53 | rotation_range=45, 54 | # width_shift_range=0.2, 55 | # height_shift_range=0.2, 56 | zoom_range=0.2, 57 | horizontal_flip=True) 58 | 59 | train_generator = train_datagen.flow_from_directory( 60 | 'data/train', 61 | shuffle=True, 62 | target_size=(args.img_size, args.img_size), 63 | class_mode='binary', 64 | batch_size=args.batch_size, ) 65 | 66 | val_datagen = ImageDataGenerator(rescale=1. / 255) 67 | val_generator = val_datagen.flow_from_directory( 68 | 'data/val', 69 | shuffle=True, # otherwise we get distorted batch-wise metrics 70 | class_mode='binary', 71 | target_size=(args.img_size, args.img_size), 72 | batch_size=args.batch_size, ) 73 | 74 | classes = len(train_generator.class_indices) 75 | assert classes > 0 76 | assert classes is len(val_generator.class_indices) 77 | n_of_train_samples = train_generator.samples 78 | n_of_val_samples = val_generator.samples 79 | 80 | # Architectures 81 | base_model = MobileNet(input_shape=img_shape, weights='imagenet', include_top=False) 82 | x = base_model.output # Recast classification layer 83 | # x = Flatten()(x) # Uncomment for Resnet based models 84 | x = GlobalAveragePooling2D(name='predictions_avg_pool')(x) # comment for RESNET models 85 | # n_classes; softmax for multi-class, sigmoid for binary 86 | x = Dense(args.classes, activation='sigmoid', name='predictions')(x) 87 | model = Model(inputs=base_model.input, outputs=x) 88 | 89 | # checkpoints 90 | # 91 | checkpoint = ModelCheckpoint(filepath='./models/MobileNet.hdf5', 
verbose=1, save_best_only=True) 92 | early_stop = EarlyStopping(patience=args.early_stop) 93 | tensorboard = TensorBoard(log_dir='./logs/MobileNet/{}/'.format(now_iso)) 94 | # reduce_lr = ReduceLROnPlateau(factor=0.03, cooldown=0, patience=args.lr_wait, min_lr=0.1e-6) 95 | callbacks = [checkpoint, tensorboard, checkpoint] 96 | 97 | # Calculate class weights 98 | weights = class_weight.compute_class_weight('balanced', np.unique(train_generator.classes), train_generator.classes) 99 | weights = {0: weights[0], 1: weights[1]} 100 | # for layer in base_model.layers: 101 | # layer.set_trainable = False 102 | 103 | # print(model.summary()) 104 | # for i, layer in enumerate(base_model.layers): 105 | # print(i, layer.name) 106 | if args.resume: 107 | model.load_weights(args.resume) 108 | for layer in model.layers: 109 | layer.set_trainable = True 110 | 111 | # if TRAIN_FULL: 112 | # print("=> retrain all layers of network") 113 | # for layer in model.layers: 114 | # set_trainable = True 115 | # else: 116 | # print("=> retraining only bottleneck and fc layers") 117 | # import pdb 118 | # pdb.set_trace() 119 | # set_trainable = False 120 | # for layer in base_model.layers: 121 | # if "block12" in layer.name: # what block do we want to start unfreezing 122 | # set_trainable = True 123 | # if set_trainable: 124 | # layer.trainable = True 125 | # else: 126 | # layer.trainable = False 127 | 128 | # The network is trained end-to-end using Adam with default parameters 129 | model.compile( 130 | optimizer=Adam(lr=args.lr, decay=args.decay), 131 | # optimizer=SGD(lr=args.lr, decay=args.decay,momentum=args.momentum, nesterov=True), 132 | loss=binary_crossentropy, 133 | metrics=[binary_accuracy], ) 134 | 135 | model_out = model.fit_generator( 136 | train_generator, 137 | steps_per_epoch=n_of_train_samples // args.batch_size, 138 | epochs=args.epochs, 139 | validation_data=val_generator, 140 | validation_steps=n_of_val_samples // args.batch_size, 141 | class_weight=weights, 142 | 
workers=args.workers, 143 | use_multiprocessing=True, 144 | callbacks=callbacks) 145 | 146 | 147 | if __name__ == '__main__': 148 | train() 149 | -------------------------------------------------------------------------------- /pytorch/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import random 5 | import shutil 6 | from os import getcwd 7 | from os.path import exists, isdir, isfile, join 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.optim as optim 16 | import torch.utils.data as data 17 | from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score) 18 | from tensorboardX import SummaryWriter 19 | from torch.autograd import Variable 20 | from torch.optim.lr_scheduler import ReduceLROnPlateau 21 | from tqdm import tqdm 22 | 23 | import torchvision 24 | import torchvision.models as models 25 | import torchvision.transforms as transforms 26 | from dataloader import MuraDataset 27 | 28 | print("torch : {}".format(torch.__version__)) 29 | print("torch vision : {}".format(torchvision.__version__)) 30 | print("numpy : {}".format(np.__version__)) 31 | print("pandas : {}".format(pd.__version__)) 32 | model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__")) 33 | 34 | parser = argparse.ArgumentParser(description='Hyperparameters') 35 | parser.add_argument('--data_dir', default='MURA-v1.0', metavar='DIR', help='path to dataset') 36 | parser.add_argument('--arch', default='densenet121', choices=model_names, help='nn architecture') 37 | parser.add_argument('--classes', default=2, type=int) 38 | parser.add_argument('--workers', default=4, type=int) 39 | parser.add_argument('--epochs', default=90, type=int) 40 | 
parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', default=.1, type=float, help='weight decay')
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
parser.add_argument('--seed', default=1337, type=int, help='random seed')

# NOTE(review): despite the name, this holds the best *validation F1 score*
# (validate() returns f1_s and the LR scheduler runs in 'max' mode), so
# higher is better.  The name and the checkpoint key are kept unchanged so
# previously saved checkpoints still load.
best_val_loss = 0

tb_writer = SummaryWriter()


def main():
    """Build the model and data loaders, then train/validate for args.epochs.

    Side effects: writes tensorboard events via the module-level `tb_writer`,
    saves `checkpoint.pth.tar` (and `model_best.pth.tar`) each epoch, and
    mutates the module-level `args` / `best_val_loss` globals.
    """
    global args, best_val_loss
    args = parser.parse_args()

    # Seed every RNG in play so runs are reproducible.
    print("=> setting random seed to '{}'".format(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        # Freeze the ImageNet-pretrained feature extractor; only the freshly
        # created classification head below is trainable (unless --fullretrain).
        for param in model.parameters():
            param.requires_grad = False

        if 'resnet' in args.arch:
            # resnets expose their classification head as `fc`
            model.fc = nn.Linear(2048, args.classes)

        if 'dense' in args.arch:
            # densenets expose their head as `classifier`; in_features differs
            # per variant (densenet121: 1024, densenet169: 1664)
            if '121' in args.arch:
                model.classifier = nn.Linear(1024, args.classes)
            elif '169' in args.arch:
                model.classifier = nn.Linear(1664, args.classes)
            else:
                # unsupported densenet variant — bail out early
                return
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = torch.nn.DataParallel(model).cuda()

    # Optionally resume from a checkpoint written by save_checkpoint().
    if args.resume:
        if isfile(args.resume):
            print("=> found checkpoint")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            model.load_state_dict(checkpoint['state_dict'])

            # --epochs is interpreted as "additional epochs" when resuming.
            args.epochs = args.epochs + args.start_epoch
            print("=> loading checkpoint '{}' with acc of '{}'".format(
                args.resume,
                checkpoint['best_val_loss'], ))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading: expects a MURA-style layout with train/valid dirs + csvs.
    data_dir = join(getcwd(), args.data_dir)
    train_dir = join(data_dir, 'train')
    train_csv = join(data_dir, 'train.csv')
    val_dir = join(data_dir, 'valid')
    val_csv = join(data_dir, 'valid.csv')
    test_dir = join(data_dir, 'test')
    assert isdir(data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir)
    assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv)

    # Before feeding images into the network, normalize each image to the
    # mean and std of the ImageNet training set (the pretrained weights
    # were trained against that distribution).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Scale the variable-sized images to 224x224 and augment with random
    # lateral inversions.
    train_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = MuraDataset(train_csv, transform=train_transforms)
    # Oversample the minority class with a weighted sampler instead of
    # plain shuffling (the dataset is heavily imbalanced).
    weights = torch.DoubleTensor(train_data.balanced_weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    train_loader = data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
        pin_memory=True)
    val_loader = data.DataLoader(
        MuraDataset(val_csv,
                    transforms.Compose([
                        transforms.Resize(224),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda()

    if args.fullretrain:
        print("=> optimizing all layers")
        for param in model.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        print("=> optimizing fc/classifier layers")
        # BUG FIX: densenets name their head `classifier`, not `fc`; the
        # original unconditionally used `model.module.fc`, which raises
        # AttributeError for every densenet arch in this mode.
        if hasattr(model.module, 'fc'):
            head = model.module.fc
        else:
            head = model.module.classifier
        optimizer = optim.Adam(head.parameters(), args.lr, weight_decay=args.weight_decay)

    # validate() returns the encounter-level F1 score, hence 'max' mode:
    # reduce the LR when validation F1 plateaus and keep the best-F1 model.
    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=10, verbose=True)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set (returns F1 score; higher is better)
        val_loss = validate(val_loader, model, criterion, epoch)
        scheduler.step(val_loss)
        # remember the best score so far and save a checkpoint
        is_best = val_loss > best_val_loss
        best_val_loss = max(val_loss, best_val_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_val_loss': best_val_loss,
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    """Run one optimization epoch over `train_loader`.

    Logs the epoch-average loss and top-1 accuracy to tensorboard.
    """
    losses = AverageMeter()
    acc = AverageMeter()

    # ensure model is in train mode
    model.train()
    pbar = tqdm(train_loader)
    for i, (images, target, meta) in enumerate(pbar):
        # BUG FIX: `async` became a reserved word in Python 3.7;
        # torch>=0.4 renamed the kwarg to `non_blocking`.
        target = target.cuda(non_blocking=True)
        image_var = Variable(images)
        label_var = Variable(target)

        # pass this batch through our model and get y_pred
        y_pred = model(image_var)

        # update loss metric (`loss.item()` replaces the deprecated
        # `loss.data[0]`, which errors on the torch versions that
        # support `non_blocking`)
        loss = criterion(y_pred, label_var)
        losses.update(loss.item(), images.size(0))

        # update top-1 accuracy metric (the original's
        # `prec1, prec1 = accuracy(..., topk=(1, 1))` double unpack
        # was a typo; only one value is needed)
        (prec1,) = accuracy(y_pred.data, target, topk=(1,))
        acc.update(float(prec1), images.size(0))

        # compute gradient and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description("EPOCH[{0}][{1}/{2}]".format(epoch, i, len(train_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))

    tb_writer.add_scalar('train/loss', losses.avg, epoch)
    tb_writer.add_scalar('train/acc', acc.avg, epoch)
    return


def validate(val_loader, model, criterion, epoch):
    """Evaluate on `val_loader`; return the encounter-level F1 score.

    Per-image abnormality probabilities are averaged per encounter (study),
    then rounded to a binary prediction before scoring, mirroring the MURA
    evaluation protocol.
    """
    model.eval()
    acc = AverageMeter()
    losses = AverageMeter()
    meta_data = []

    # softmax over the class dimension; hoisted out of the loop since it is
    # loop-invariant (dim pinned explicitly — implicit dim is deprecated)
    softmax = nn.Softmax(dim=1)

    pbar = tqdm(val_loader)
    for i, (images, target, meta) in enumerate(pbar):
        # BUG FIX: `async` is reserved from Python 3.7 on; use `non_blocking`
        target = target.cuda(non_blocking=True)
        image_var = Variable(images, volatile=True)
        label_var = Variable(target, volatile=True)

        y_pred = model(image_var)
        # update loss metric
        loss = criterion(y_pred, label_var)
        losses.update(loss.item(), images.size(0))

        # update top-1 accuracy metric
        (prec1,) = accuracy(y_pred.data, target, topk=(1,))
        acc.update(float(prec1), images.size(0))

        sm_pred = softmax(y_pred).data.cpu().numpy()
        y_pred_probs = sm_pred[:, 1]  # p(abnormal); column 0 is p(normal)

        meta_data.append(
            pd.DataFrame({
                'img_filename': meta['img_filename'],
                'y_true': meta['y_true'].numpy(),
                'y_pred_probs': y_pred_probs,
                'patient': meta['patient'].numpy(),
                'study': meta['study'].numpy(),
                'image_num': meta['image_num'].numpy(),
                'encounter': meta['encounter'],
            }))

        pbar.set_description("VALIDATION[{}/{}]".format(i, len(val_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))

    df = pd.concat(meta_data)
    # BUG FIX (deprecation): selecting groupby columns with a bare tuple was
    # removed in pandas 1.x; pass a list of column names instead.
    ab = df.groupby(['encounter'])[['y_pred_probs', 'y_true']].mean()
    ab['y_pred_round'] = ab.y_pred_probs.round()
    ab['y_pred_round'] = pd.to_numeric(ab.y_pred_round, downcast='integer')

    f1_s = f1_score(ab.y_true, ab.y_pred_round)
    prec_s = precision_score(ab.y_true, ab.y_pred_round)
    rec_s = recall_score(ab.y_true, ab.y_pred_round)
    acc_s = accuracy_score(ab.y_true, ab.y_pred_round)
    tb_writer.add_scalar('val/f1_score', f1_s, epoch)
    tb_writer.add_scalar('val/precision', prec_s, epoch)
    tb_writer.add_scalar('val/recall', rec_s, epoch)
    tb_writer.add_scalar('val/accuracy', acc_s, epoch)
    # return the metric we want to evaluate this model's performance by
    return f1_s
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Persist `state` to `filename`; also copy it to `model_best.pth.tar`
    when `is_best` marks it as the best-scoring epoch so far."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the running statistics."""
        self.val = 0    # most recent value
        self.avg = 0    # running weighted average
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight (e.g. number of samples)

    def update(self, val, n=1):
        """Fold in `val`, weighted by `n` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(y_pred, y_actual, topk=(1, )):
    """Computes the precision@k for the specified values of k.

    Args:
        y_pred: (batch, classes) score/logit tensor.
        y_actual: (batch,) tensor of true class indices.
        topk: tuple of k values to compute precision@k for.

    Returns:
        A list of size-1 tensors holding percentages, one per k in `topk`
        (size-1 rather than 0-dim so callers may index with `[0]`).
    """
    maxk = max(topk)
    batch_size = y_actual.size(0)

    # top-maxk predicted class indices: (batch, maxk) -> (maxk, batch)
    _, pred = y_pred.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(y_actual.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # BUG FIX: `correct[:k]` is a slice of a transposed tensor and thus
        # non-contiguous, so a bare .view(-1) raises for k > 1;
        # keepdim=True keeps the sum 1-D (size 1) on every torch version,
        # so callers' `prec[0]` indexing keeps working.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))

    return res


if __name__ == '__main__':
    main()