├── .style.yapf
├── .gitattributes
├── scripts
│   ├── NASNetMobile.sh
│   ├── MobileNet.sh
│   └── DenseNet169.sh
├── models
│   ├── MobileNet.hdf5
│   ├── DenseNet169.hdf5
│   └── NASNetMobile.hdf5
├── .gitignore
├── LICENSE
├── metrics.py
├── README.md
├── eval.py
├── download_and_convert_mura.py
├── mura.py
├── pytorch
│   ├── dataloader.py
│   └── train.py
└── train.py
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = pep8
3 | indent_width = 4
4 | column_limit = 120
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.psd filter=lfs diff=lfs merge=lfs -text
2 | *.hdf5 filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/scripts/NASNetMobile.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Resume NASNetMobile training from a dated checkpoint for up to 120 epochs.
# NOTE(review): train.py as committed builds a MobileNet and saves to
# ./models/MobileNet.hdf5 -- presumably it was edited to NASNetMobile when
# this checkpoint was produced; confirm before reusing this script.
python train.py \
    --epochs 120 \
    --resume 'models/86_NASNetMobile_2018-01-04T15:15:37.hdf5'
--------------------------------------------------------------------------------
/models/MobileNet.hdf5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:229a1c4d0414337aa21f848fffb6ad9f7acb5dd96443053d613486c264e4d356
3 | size 38935688
4 |
--------------------------------------------------------------------------------
/models/DenseNet169.hdf5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a42d7a17f412646634ad128135403b38fdbab533647ce6df28669e5ff24d8fad
3 | size 152693520
4 |
--------------------------------------------------------------------------------
/models/NASNetMobile.hdf5:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:fe5f54a66ce1db8afc138cbdaf44d09a41e5006e7ac06b704bc9d64c95f1ef1c
3 | size 55106408
4 |
--------------------------------------------------------------------------------
/scripts/MobileNet.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Staged MobileNet training: a short initial run, then three fine-tuning
# passes that resume from the best checkpoint with a shrinking learning rate.
# (`--b` abbreviates the `--batch-size` option of train.py.)

# Stage 1: initial training at the default learning rate.
python train.py \
    --epochs 5 \
    --b 32

# Stage 2: resume from best checkpoint, lr = 1e-4.
python train.py \
    --b 32 \
    --resume ./models/MobileNet.hdf5 \
    --lr 1e-4

# Stage 3: resume again, lr = 1e-5.
python train.py \
    --b 32 \
    --resume ./models/MobileNet.hdf5 \
    --lr 1e-5

# Stage 4: final fine-tune, lr = 1e-6.
python train.py \
    --b 32 \
    --resume ./models/MobileNet.hdf5 \
    --lr 1e-6
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | */*.pyc
2 | */**/*.pyc
3 | */**/**/*.pyc
4 | */**/**/**/*.pyc
5 | */**/**/**/**/*.pyc
6 | */*.so*
7 | */**/*.so*
8 | */**/*.dylib*
9 |
10 | # IPython notebook checkpoints
11 | .ipynb_checkpoints
12 |
13 | # Editor temporaries
14 | *.swn
15 | *.swo
16 | *.swp
17 | *~
18 |
19 | # macOS dir files
20 | .DS_Store
21 |
22 | MURA-v1.0/
23 | runs/
24 | checkpoint.pth.tar
25 | *.pth.tar
26 | data/
27 | *.h5
28 | logs/
29 | .fuse_*
--------------------------------------------------------------------------------
/scripts/DenseNet169.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Staged DenseNet169 training with a small batch size (the model is large):
# initial run, then three resumed fine-tuning passes at decreasing learning
# rates and increasing epoch budgets.
# NOTE(review): train.py as committed builds a MobileNet and checkpoints to
# ./models/MobileNet.hdf5 -- confirm it was switched to DenseNet169 when this
# script was used.

# Stage 1: initial training at the default learning rate.
python train.py \
    --epochs 10 \
    --b 8

# Stage 2: resume from best checkpoint, lr = 1e-4.
python train.py \
    --epochs 15 \
    --b 8 \
    --resume ./models/DenseNet169.hdf5 \
    --lr 1e-4

# Stage 3: resume again, lr = 1e-5.
python train.py \
    --epochs 30 \
    --b 8 \
    --resume ./models/DenseNet169.hdf5 \
    --lr 1e-5

# Stage 4: final fine-tune, lr = 1e-6.
python train.py \
    --epochs 60 \
    --b 8 \
    --resume ./models/DenseNet169.hdf5 \
    --lr 1e-6
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Bobby D DeSimone
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import keras
5 | import numpy as np
6 | from sklearn.metrics import *
7 |
8 |
class SKLearnMetrics(keras.callbacks.Callback):
    """Keras callback that records sklearn classification metrics each epoch.

    Metrics are computed on ``self.validation_data`` and appended to
    per-metric history lists, so the full trajectory is available after
    training. Unfortunately this does not work when fitting with generators,
    because Keras does not populate ``validation_data`` in that case.
    """

    def on_train_begin(self, logs=None):
        # One history list per metric, appended to at the end of every epoch.
        # (`logs=None` instead of the original mutable `logs={}` default.)
        self.confusion = []
        self.precision = []
        self.recall = []
        self.f1s = []
        self.kappa = []
        self.auc = []

    def on_epoch_end(self, epoch, logs=None):
        # Predict ONCE and reuse the result: the original called
        # model.predict twice per epoch, doubling the cost (and the two
        # calls could differ for models with stochastic layers).
        score = np.asarray(self.model.predict(self.validation_data[0]))
        predict = np.round(score)  # threshold probabilities at 0.5
        target = self.validation_data[1]

        self.auc.append(roc_auc_score(target, score))
        self.confusion.append(confusion_matrix(target, predict))
        self.precision.append(precision_score(target, predict))
        self.recall.append(recall_score(target, predict))
        self.f1s.append(f1_score(target, predict))
        self.kappa.append(cohen_kappa_score(target, predict))
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Musculoskeletal Radiographs Abnormality Classifier
2 |
3 | ## Experiments
4 |
5 | | Network | Accuracy (encounter) | Precision (encounter) | Recall (encounter) | F1 (encounter) | Kappa (encounter) |
6 | | ---------------------- | -------------------- | --------------------- | ------------------ | -------------- | ----------------- |
7 | | DenseNet169 (baseline) | .83 (.84) | .82 (.82) | .87 (.90) | .84 (.86) | .65 (.65) |
8 | | MobileNet | .81 (.83) | .80 (.82) | .85 (.89) | .82 (.85) | .62 (.62) |
9 | | NASNetMobile | .82 (.83) | .78 (.80) | .89 (.92) | .83 (.86) | .63 (.63) |
10 |
There is also a ResNet50 implementation in PyTorch (see `pytorch/`), which achieved equivalent results.
12 |
13 | ## The [Mura](https://arxiv.org/abs/1712.06957) Dataset
14 |
15 | ```latex
16 | @misc{1712.06957,
17 | Author = {Pranav Rajpurkar and Jeremy Irvin and Aarti Bagul and Daisy Ding and Tony Duan and Hershel Mehta and Brandon Yang and Kaylie Zhu and Dillon Laird and Robyn L. Ball and Curtis Langlotz and Katie Shpanskaya and Matthew P. Lungren and Andrew Ng},
18 | Title = {MURA Dataset: Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs},
19 | Year = {2017},
20 | Eprint = {arXiv:1712.06957},}
21 | ```
22 |
23 | | Study | Normal | Abnormal | Total |
24 | | --------: | :-------- | :-------- | ---------: |
25 | | Elbow | 1,203 | 768 | 1,971 |
26 | | Finger | 1,389 | 753 | 2,142 |
27 | | Forearm | 677 | 380 | 1,057 |
28 | | Hand | 1,613 | 602 | 2,215 |
29 | | Humerus | 411 | 367 | 778 |
30 | | Shoulder | 1,479 | 1,594 | 3,073 |
31 | | Wrist | 2,295 | 1,451 | 3,746 |
32 | | **Total** | **9,067** | **5,915** | **14,982** |
33 |
34 | - Each study contains 1-N views (images)
35 | - 40,895 multi-view radiographic images
36 |
37 | ### Their results (DenseNet169)
38 |
39 | | | Radiologists (95% CI) | Model (95% CI) |
40 | | ---------------: | :----------------------- | :----------------------- |
41 | | Elbow | **0.858** (0.707, 0.959) | 0.848 (0.691, 0.955) |
42 | | Finger | 0.781 (0.638, 0.871) | **0.792** (0.588, 0.933) |
43 | | Forearm | **0.899** (0.804, 0.960) | 0.814 (0.633, 0.942) |
44 | | Hand | 0.854 (0.676, 0.958) | **0.858** (0.658, 0.978) |
45 | | Humerus | **0.895** (0.774, 0.976) | 0.862 (0.709, 0.968) |
46 | | Shoulder | **0.925** (0.811, 0.989) | 0.857 (0.667, 0.974) |
47 | | Wrist | 0.958 (0.908, 0.988) | **0.968** (0.889, 1.000) |
48 | | **Aggregate F1** | **0.884** (0.843, 0.918) | 0.859 (0.804, 0.905) |
49 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | from os import environ, getcwd
4 | from os.path import join
5 |
6 | import keras
7 | import numpy as np
8 | import pandas as pd
9 | import sklearn as skl
10 | import tensorflow as tf
11 | from keras.applications import NASNetMobile
12 | from keras.layers import Dense, GlobalAveragePooling2D
13 | from keras.metrics import binary_accuracy, binary_crossentropy
14 | from keras.models import Model
15 | from keras.optimizers import Adam
16 | from keras.preprocessing.image import ImageDataGenerator
17 |
18 | from mura import Mura
19 |
from math import ceil

# Display/formatting defaults for the console output below.
pd.set_option('display.max_rows', 20)
pd.set_option('precision', 4)
np.set_printoptions(precision=4)

environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Shut up tensorflow!
print("tf : {}".format(tf.__version__))
print("keras : {}".format(keras.__version__))
print("numpy : {}".format(np.__version__))
print("pandas : {}".format(pd.__version__))
print("sklearn : {}".format(skl.__version__))

# Hyper-parameters / Globals
BATCH_SIZE = 512  # tweak to your GPUs capacity
IMG_HEIGHT = 224  # ResNetInceptionv2 & Xception like 299, ResNet50/VGG/Inception 224, NASM 331
IMG_WIDTH = IMG_HEIGHT
CHANNELS = 3
DIMS = (IMG_HEIGHT, IMG_WIDTH, CHANNELS)  # blame theano
# NOTE(review): NASNetMobile is evaluated at 224 here although the comment
# above suggests 331 for NASNet -- confirm the checkpoint was trained at 224.
MODEL_TO_EVAL = './models/NASNetMobile.hdf5'
DATA_DIR = 'MURA-v1.0/'
EVAL_CSV = 'valid.csv'
EVAL_DIR = 'data/val'

# load up our csv with validation factors
data_dir = join(getcwd(), DATA_DIR)
eval_csv = join(data_dir, EVAL_CSV)
df = pd.read_csv(eval_csv, names=['img', 'label'], header=None)
eval_imgs = df.img.values.tolist()
eval_labels = df.label.values.tolist()

# Plain rescaling only -- no augmentation at evaluation time.
eval_datagen = ImageDataGenerator(rescale=1. / 255)
eval_generator = eval_datagen.flow_from_directory(
    EVAL_DIR, class_mode='binary', shuffle=False, target_size=(IMG_HEIGHT, IMG_WIDTH), batch_size=BATCH_SIZE)
n_samples = eval_generator.samples

# Rebuild the training-time architecture, then load fine-tuned weights on top.
base_model = NASNetMobile(input_shape=DIMS, weights='imagenet', include_top=False)
x = base_model.output
x = GlobalAveragePooling2D(name='avg_pool')(x)  # comment for RESNET
x = Dense(1, activation='sigmoid', name='predictions')(x)
model = Model(inputs=base_model.input, outputs=x)
model.load_weights(MODEL_TO_EVAL)
model.compile(optimizer=Adam(lr=1e-3), loss=binary_crossentropy, metrics=['binary_accuracy'])

# Round the step count UP so the final partial batch is evaluated too; the
# original passed the raw float `n_samples / BATCH_SIZE` as the step count.
steps = ceil(n_samples / BATCH_SIZE)
score, acc = model.evaluate_generator(eval_generator, steps)
print(model.metrics_names)
print('==> Metrics with eval')
print("loss :{:0.4f} \t Accuracy:{:0.4f}".format(score, acc))
y_pred = model.predict_generator(eval_generator, steps)
mura = Mura(eval_generator.filenames, y_true=eval_generator.classes, y_pred=y_pred)
print('==> Metrics with predict')
print(mura.metrics())
print(mura.metrics_by_encounter())
# print(mura.metrics_by_study_type())
70 |
--------------------------------------------------------------------------------
/download_and_convert_mura.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import re
5 | import os
6 | from os import getcwd
7 | from os.path import exists, isdir, isfile, join
8 | import shutil
9 | import numpy as np
10 | import pandas as pd
11 |
12 |
class ImageString(object):
    """Parses one MURA image path
    (e.g. ``MURA-v1.0/train/XR_ELBOW/patient00011/study1_negative/image1.png``)
    into its components and builds a flattened, unique file name for it."""

    _patient_re = re.compile(r'patient(\d+)')
    _study_re = re.compile(r'study(\d+)')
    _image_re = re.compile(r'image(\d+)')
    _study_type_re = re.compile(r'XR_(\w+)')

    def __init__(self, img_filename):
        self.img_filename = img_filename
        self.patient = self._parse_patient()
        self.study = self._parse_study()
        self.image_num = self._parse_image()
        self.study_type = self._parse_study_type()
        # `image` is kept as an alias of `image_num` for backward
        # compatibility (the original parsed the filename a second time).
        self.image = self.image_num
        self.normal = self._parse_normal()

    def flat_file_name(self):
        """Return a flat file name encoding label, study type and all ids.

        The original passed ``self.normal`` a second, surplus time; the extra
        argument was silently ignored by str.format.
        """
        return "{}_{}_patient{}_study{}_image{}.png".format(self.normal, self.study_type, self.patient, self.study,
                                                            self.image)

    def _parse_patient(self):
        # e.g. 'patient00011' -> 11
        return int(self._patient_re.search(self.img_filename).group(1))

    def _parse_study(self):
        return int(self._study_re.search(self.img_filename).group(1))

    def _parse_image(self):
        return int(self._image_re.search(self.img_filename).group(1))

    def _parse_study_type(self):
        # e.g. 'XR_ELBOW' -> 'ELBOW'
        return self._study_type_re.search(self.img_filename).group(1)

    def _parse_normal(self):
        # MURA encodes the label in the study directory name ('_negative' == normal).
        return "normal" if ("negative" in self.img_filename) else "abnormal"
46 |
47 |
# processed
# data
# ├── train
# │   ├── abnormal
# │   └── normal
# └── val
#     ├── abnormal
#     └── normal
proc_data_dir = join(getcwd(), 'data')
proc_train_dir = join(proc_data_dir, 'train')
proc_val_dir = join(proc_data_dir, 'val')

# Data loading code
orig_data_dir = join(getcwd(), 'MURA-v1.0')
train_dir = join(orig_data_dir, 'train')
train_csv = join(orig_data_dir, 'train.csv')
val_dir = join(orig_data_dir, 'valid')
val_csv = join(orig_data_dir, 'valid.csv')
test_dir = join(orig_data_dir, 'test')
assert isdir(orig_data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir)
assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv)


def flatten_split(csv_path, dest_dir):
    """Copy every image listed in `csv_path` into `dest_dir/<label>/` under a
    flattened, unique file name."""
    df = pd.read_csv(csv_path, names=['img', 'label'], header=None)
    # following datasets/folder.py's weird convention here...
    samples = [tuple(x) for x in df.values]
    for img, label in samples:
        # Sanity check: the CSV label must agree with the path encoding.
        # The original used `is`, i.e. identity -- unreliable for the numpy
        # ints that pandas produces; `==` compares by value.
        assert ("negative" in img) == (label == 0)
        enc = ImageString(img)
        cat_dir = join(dest_dir, enc.normal)
        # makedirs (vs the original mkdir) also creates the missing
        # data/train or data/val parent directory on first use.
        if not os.path.exists(cat_dir):
            os.makedirs(cat_dir)
        shutil.copy2(enc.img_filename, join(cat_dir, enc.flat_file_name()))


flatten_split(train_csv, proc_train_dir)
flatten_split(val_csv, proc_val_dir)
92 |
--------------------------------------------------------------------------------
/mura.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import re
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.metrics import (accuracy_score, cohen_kappa_score, f1_score, precision_score, recall_score)
8 |
9 | pd.set_option('display.max_rows', 20)
10 | pd.set_option('precision', 4)
11 | np.set_printoptions(precision=4)
12 |
13 |
class Mura(object):
    """`MURA <https://stanfordmlgroup.github.io/competitions/mura/>`_ helper:
    Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs.

    Collects per-image labels/predictions into a DataFrame and reports metrics
    per image, per encounter (patient visit) and per study type. Expects the
    flattened file names produced by download_and_convert_mura.py
    (``<label>_<TYPE>_patient<N>_study<N>_image<N>.png``).
    """
    url = "https://cs.stanford.edu/group/mlgroup/mura-v1.0.zip"
    filename = "mura-v1.0.zip"
    md5_checksum = '4c36feddb7f5698c8bf291b912c438b1'
    _patient_re = re.compile(r'patient(\d+)')
    _study_re = re.compile(r'study(\d+)')
    _image_re = re.compile(r'image(\d+)')
    _study_type_re = re.compile(r'_(\w+)_patient')

    def __init__(self, image_file_names, y_true, y_pred=None):
        """
        Args:
            image_file_names: iterable of flattened image file names.
            y_true: ground-truth binary labels, one per image.
            y_pred: optional array of predicted probabilities, one per image.
        """
        self.imgs = image_file_names
        df_img = pd.Series(np.array(image_file_names), name='img')
        self.y_true = y_true
        df_true = pd.Series(np.array(y_true), name='y_true')
        self.y_pred = y_pred

        # Parse file-name components for every image.
        self.patient = []
        self.study = []
        self.study_type = []
        self.image_num = []
        self.encounter = []
        for img in image_file_names:
            patient = self._parse_patient(img)
            study = self._parse_study(img)
            study_type = self._parse_study_type(img)
            self.patient.append(patient)
            self.study.append(study)
            self.image_num.append(self._parse_image(img))
            self.study_type.append(study_type)
            # An encounter groups all views taken during one patient visit.
            self.encounter.append("{}_{}_{}".format(study_type, patient, study))

        # number of unique classes
        self.classes = np.unique(self.y_true)

        # BUGFIX: the original concatenated the patient column twice,
        # producing a duplicate 'patient' column in the DataFrame.
        self.data = pd.concat(
            [
                df_img,
                pd.Series(np.array(self.encounter), name='encounter'),
                df_true,
                pd.Series(np.array(self.patient), name='patient'),
                pd.Series(np.array(self.study), name='study'),
                pd.Series(np.array(self.image_num), name='image_num'),
                pd.Series(np.array(self.study_type), name='study_type'),
            ], axis=1)

        if self.y_pred is not None:
            # Keep both the raw probabilities (for aggregation) and the
            # thresholded 0/1 predictions (for per-image metrics).
            self.y_pred_probability = self.y_pred.flatten()
            self.y_pred = self.y_pred_probability.round().astype(int)
            df_y_pred = pd.Series(self.y_pred, name='y_pred')
            df_y_pred_probability = pd.Series(self.y_pred_probability, name='y_pred_probs')
            self.data = pd.concat((self.data, df_y_pred, df_y_pred_probability), axis=1)

    def __len__(self):
        return len(self.imgs)

    def _parse_patient(self, img_filename):
        # e.g. 'patient00011' -> 11
        return int(self._patient_re.search(img_filename).group(1))

    def _parse_study(self, img_filename):
        return int(self._study_re.search(img_filename).group(1))

    def _parse_image(self, img_filename):
        return int(self._image_re.search(img_filename).group(1))

    def _parse_study_type(self, img_filename):
        # e.g. 'abnormal_ELBOW_patient11_...' -> 'ELBOW'
        return self._study_type_re.search(img_filename).group(1)

    def metrics(self):
        """Per-image metrics on the thresholded predictions."""
        return "per image metrics:\n\taccuracy : {:.2f}\tf1 : {:.2f}\tprecision : {:.2f}\trecall : {:.2f}\tcohen_kappa : {:.2f}".format(
            accuracy_score(self.y_true, self.y_pred),
            f1_score(self.y_true, self.y_pred),
            precision_score(self.y_true, self.y_pred),
            recall_score(self.y_true, self.y_pred), 
            cohen_kappa_score(self.y_true, self.y_pred), )

    def metrics_by_encounter(self):
        """Metrics after averaging predicted probabilities over each encounter."""
        y_pred = self.data.groupby(['encounter'])['y_pred_probs'].mean().round()
        y_true = self.data.groupby(['encounter'])['y_true'].mean().round()
        # BUGFIX: kappa now uses the aggregated labels like the other metrics;
        # the original computed it on the per-image labels.
        return "per encounter metrics:\n\taccuracy : {:.2f}\tf1 : {:.2f}\tprecision : {:.2f}\trecall : {:.2f}\tcohen_kappa : {:.2f}".format(
            accuracy_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            precision_score(y_true, y_pred),
            recall_score(y_true, y_pred),
            cohen_kappa_score(y_true, y_pred), )

    def metrics_by_study_type(self):
        """Metrics after averaging predicted probabilities per (study type, encounter)."""
        y_pred = self.data.groupby(['study_type', 'encounter'])['y_pred_probs'].mean().round()
        y_true = self.data.groupby(['study_type', 'encounter'])['y_true'].mean().round()
        return "per study_type metrics:\n\taccuracy : {:.2f}\tf1 : {:.2f}\tprecision : {:.2f}\trecall : {:.2f}\tcohen_kappa : {:.2f}".format(
            accuracy_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            precision_score(y_true, y_pred),
            recall_score(y_true, y_pred),
            cohen_kappa_score(y_true, y_pred), )
115 |
--------------------------------------------------------------------------------
/pytorch/dataloader.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import re
4 | from os import getcwd
5 | from os.path import join
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch.utils.data as data
10 | from PIL import Image
11 |
12 | from torchvision.datasets.utils import check_integrity, download_url
13 |
14 |
class MuraDataset(data.Dataset):
    """`MURA <https://stanfordmlgroup.github.io/competitions/mura/>`_ Dataset :
    Towards Radiologist-Level Abnormality Detection in Musculoskeletal Radiographs.
    """
    url = "https://cs.stanford.edu/group/mlgroup/mura-v1.0.zip"
    filename = "mura-v1.0.zip"
    md5_checksum = '4c36feddb7f5698c8bf291b912c438b1'
    _patient_re = re.compile(r'patient(\d+)')
    _study_re = re.compile(r'study(\d+)')
    _image_re = re.compile(r'image(\d+)')
    _study_type_re = re.compile(r'XR_(\w+)')

    def __init__(self, csv_f, transform=None, download=False):
        """
        Args:
            csv_f: path to a MURA csv of `img,label` rows (no header).
            transform: optional torchvision-style transform applied per image.
            download: currently unused; kept for API compatibility.
                TODO(review): wire up to download_and_uncompress_tarball or drop.
        """
        self.df = pd.read_csv(csv_f, names=['img', 'label'], header=None)
        self.imgs = self.df.img.values.tolist()
        self.labels = self.df.label.values.tolist()
        # following datasets/folder.py's weird convention here...
        self.samples = [tuple(x) for x in self.df.values]
        # number of unique classes
        self.classes = np.unique(self.labels)
        self.balanced_weights = self.balance_class_weights()

        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def _parse_patient(self, img_filename):
        # e.g. 'patient00011' -> 11
        return int(self._patient_re.search(img_filename).group(1))

    def _parse_study(self, img_filename):
        return int(self._study_re.search(img_filename).group(1))

    def _parse_image(self, img_filename):
        return int(self._image_re.search(img_filename).group(1))

    def _parse_study_type(self, img_filename):
        # e.g. 'XR_ELBOW' -> 'ELBOW'
        return self._study_type_re.search(img_filename).group(1)

    @staticmethod
    def download_and_uncompress_tarball(tarball_url, dataset_dir):
        """Downloads the `tarball_url` and uncompresses it locally.

        Args:
            tarball_url: The URL of a tarball (or .zip) file.
            dataset_dir: The directory where the temporary files are stored.
        """
        # BUGFIX: this helper previously referenced os/sys/urllib/zipfile/
        # tarfile without importing them (NameError if called) and, as an
        # instance method, its first parameter swallowed `self`. It is now a
        # staticmethod with local imports so it is callable either way.
        import os
        import sys
        import tarfile
        import urllib.request
        import zipfile

        filename = tarball_url.split('/')[-1]
        filepath = os.path.join(dataset_dir, filename)

        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
                                                             float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
        if ".zip" in filename:
            print("zipfile:{}".format(filepath))
            with zipfile.ZipFile(filepath, "r") as zip_ref:
                zip_ref.extractall(dataset_dir)
        else:
            tarfile.open(filepath, 'r:gz').extractall(dataset_dir)

    def balance_class_weights(self):
        """Return one weight per sample, inversely proportional to its class
        frequency (suitable for WeightedRandomSampler)."""
        count = [0] * len(self.classes)
        for item in self.samples:
            count[item[1]] += 1
        weight_per_class = [0.] * len(self.classes)
        N = float(sum(count))
        for i in range(len(self.classes)):
            weight_per_class[i] = N / float(count[i])
        weight = [0] * len(self.samples)
        for idx, val in enumerate(self.samples):
            weight[idx] = weight_per_class[val[1]]
        return weight

    def __getitem__(self, idx):
        img_filename = join(self.imgs[idx])
        patient = self._parse_patient(img_filename)
        study = self._parse_study(img_filename)
        image_num = self._parse_image(img_filename)
        study_type = self._parse_study_type(img_filename)

        # todo(bdd) : inconsistent right now, need param for grayscale / RGB
        # todo(bdd) : 'L' -> gray, 'RGB' -> Colors
        image = Image.open(img_filename).convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        meta_data = {
            'y_true': label,
            'img_filename': img_filename,
            'patient': patient,
            'study': study,
            'study_type': study_type,
            'image_num': image_num,
            'encounter': "{}_{}_{}".format(study_type, patient, study)
        }
        return image, label, meta_data
117 |
118 |
if __name__ == '__main__':
    import torchvision.transforms as transforms
    import pprint

    # Smoke test: iterate a few validation samples and dump their metadata.
    data_dir = join(getcwd(), 'MURA-v1.0')
    val_csv = join(data_dir, 'valid.csv')
    val_loader = data.DataLoader(
        MuraDataset(val_csv,
                    transforms.Compose([
                        transforms.Resize(224),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                    ])),
        batch_size=1,
        shuffle=False,
        num_workers=1,
        pin_memory=False)

    for i, (image, label, meta_data) in enumerate(val_loader):
        # BUGFIX: `meta_data` is a collated dict, not a tensor -- it has no
        # `.cpu()` method, so the original line raised AttributeError.
        pprint.pprint(meta_data)
        if i == 40:
            break
141 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import argparse
4 | from datetime import datetime
5 | from os import environ
6 |
7 | import keras
8 | import numpy as np
9 | import pandas as pd
10 | import tensorflow as tf
11 | from keras.applications import MobileNet
12 | from keras.callbacks import (EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard)
13 | from keras.layers import Dense, GlobalAveragePooling2D
14 | from keras.metrics import binary_accuracy, binary_crossentropy
15 | from keras.models import Model
16 | from keras.optimizers import SGD, Adam
17 | from keras.preprocessing.image import ImageDataGenerator
18 | from sklearn.utils import class_weight
19 |
environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Shut up tensorflow!
# Log library versions up front so training runs are easier to reproduce/debug.
print("tf : {}".format(tf.__version__))
print("keras : {}".format(keras.__version__))
print("numpy : {}".format(np.__version__))
print("pandas : {}".format(pd.__version__))
25 |
# Command-line hyperparameters; defaults match the committed training scripts.
parser = argparse.ArgumentParser(description='Hyperparameters')
parser.add_argument('--classes', default=1, type=int)
parser.add_argument('--workers', default=4, type=int)
parser.add_argument('--epochs', default=120, type=int)
parser.add_argument('-b', '--batch-size', default=32, type=int, help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float)
parser.add_argument('--lr-wait', default=10, type=int, help='how long to wait on plateau')
parser.add_argument('--decay', default=1e-4, type=float)
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
parser.add_argument('--seed', default=1337, type=int, help='random seed')
parser.add_argument('--img_channels', default=3, type=int)
parser.add_argument('--img_size', default=224, type=int)
parser.add_argument('--early_stop', default=20, type=int)
41 |
42 |
def train():
    """Train a MobileNet-based binary abnormality classifier on the flattened
    MURA layout (data/train, data/val) produced by download_and_convert_mura.py."""
    global args
    args = parser.parse_args()
    img_shape = (args.img_size, args.img_size, args.img_channels)  # blame theano
    now_iso = datetime.now().strftime('%Y-%m-%dT%H:%M:%S%z')

    # We then scale the variable-sized images to 224x224
    # We augment .. by applying random lateral inversions and rotations.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=45,
        # width_shift_range=0.2,
        # height_shift_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)

    train_generator = train_datagen.flow_from_directory(
        'data/train',
        shuffle=True,
        target_size=(args.img_size, args.img_size),
        class_mode='binary',
        batch_size=args.batch_size, )

    val_datagen = ImageDataGenerator(rescale=1. / 255)
    val_generator = val_datagen.flow_from_directory(
        'data/val',
        shuffle=True,  # otherwise we get distorted batch-wise metrics
        class_mode='binary',
        target_size=(args.img_size, args.img_size),
        batch_size=args.batch_size, )

    classes = len(train_generator.class_indices)
    assert classes > 0
    # `==` rather than `is`: identity comparison of ints only happens to work
    # for CPython's small-int cache.
    assert classes == len(val_generator.class_indices)
    n_of_train_samples = train_generator.samples
    n_of_val_samples = val_generator.samples

    # Architectures
    base_model = MobileNet(input_shape=img_shape, weights='imagenet', include_top=False)
    x = base_model.output  # Recast classification layer
    # x = Flatten()(x)  # Uncomment for Resnet based models
    x = GlobalAveragePooling2D(name='predictions_avg_pool')(x)  # comment for RESNET models
    # n_classes; softmax for multi-class, sigmoid for binary
    x = Dense(args.classes, activation='sigmoid', name='predictions')(x)
    model = Model(inputs=base_model.input, outputs=x)

    # Callbacks: best-only checkpointing, tensorboard logs and early stopping.
    checkpoint = ModelCheckpoint(filepath='./models/MobileNet.hdf5', verbose=1, save_best_only=True)
    early_stop = EarlyStopping(patience=args.early_stop)
    tensorboard = TensorBoard(log_dir='./logs/MobileNet/{}/'.format(now_iso))
    # reduce_lr = ReduceLROnPlateau(factor=0.03, cooldown=0, patience=args.lr_wait, min_lr=0.1e-6)
    # BUGFIX: the original list was [checkpoint, tensorboard, checkpoint] --
    # the checkpoint was registered twice and early_stop was never used.
    callbacks = [checkpoint, tensorboard, early_stop]

    # Class weights compensate for the normal/abnormal imbalance in MURA.
    weights = class_weight.compute_class_weight('balanced', np.unique(train_generator.classes), train_generator.classes)
    weights = {0: weights[0], 1: weights[1]}

    if args.resume:
        model.load_weights(args.resume)
        # Unfreeze everything when resuming for fine-tuning.
        # BUGFIX: the original wrote `layer.set_trainable = True`, which just
        # creates a new attribute; Keras's flag is `layer.trainable`.
        for layer in model.layers:
            layer.trainable = True

    # The network is trained end-to-end using Adam with default parameters
    model.compile(
        optimizer=Adam(lr=args.lr, decay=args.decay),
        # optimizer=SGD(lr=args.lr, decay=args.decay, momentum=args.momentum, nesterov=True),
        loss=binary_crossentropy,
        metrics=[binary_accuracy], )

    model.fit_generator(
        train_generator,
        steps_per_epoch=n_of_train_samples // args.batch_size,
        epochs=args.epochs,
        validation_data=val_generator,
        validation_steps=n_of_val_samples // args.batch_size,
        class_weight=weights,
        workers=args.workers,
        use_multiprocessing=True,
        callbacks=callbacks)
145 |
146 |
# Script entry point: hyperparameters are parsed inside train().
if __name__ == '__main__':
    train()
149 |
--------------------------------------------------------------------------------
/pytorch/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 |
3 | import argparse
4 | import random
5 | import shutil
6 | from os import getcwd
7 | from os.path import exists, isdir, isfile, join
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | import torch.nn as nn
14 | import torch.nn.parallel
15 | import torch.optim as optim
16 | import torch.utils.data as data
17 | from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score)
18 | from tensorboardX import SummaryWriter
19 | from torch.autograd import Variable
20 | from torch.optim.lr_scheduler import ReduceLROnPlateau
21 | from tqdm import tqdm
22 |
23 | import torchvision
24 | import torchvision.models as models
25 | import torchvision.transforms as transforms
26 | from dataloader import MuraDataset
27 |
# Log library versions up front so training runs are easy to reproduce/debug.
print("torch : {}".format(torch.__version__))
print("torch vision : {}".format(torchvision.__version__))
print("numpy : {}".format(np.__version__))
print("pandas : {}".format(pd.__version__))
# Every lowercase, public attribute of torchvision.models is a model constructor.
model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__"))

parser = argparse.ArgumentParser(description='Hyperparameters')
parser.add_argument('--data_dir', default='MURA-v1.0', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', default='densenet121', choices=model_names, help='nn architecture')
parser.add_argument('--classes', default=2, type=int)
parser.add_argument('--workers', default=4, type=int)
parser.add_argument('--epochs', default=90, type=int)
parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', default=.1, type=float, help='weight decay')
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
parser.add_argument('--seed', default=1337, type=int, help='random seed')

# NOTE(review): despite the name, this tracks the best validation F1 score
# (validate() returns F1 and main() maximizes it), not a loss.
best_val_loss = 0

# Shared TensorBoard writer used by both train() and validate().
tb_writer = SummaryWriter()
53 |
54 |
def main():
    """Build the network, wire up the MURA data loaders, and run the
    train/validate loop.

    A checkpoint is written after every epoch; the one with the highest
    encounter-level validation F1 (stored under the historical key
    'best_val_loss') is mirrored to model_best.pth.tar.
    """
    global args, best_val_loss
    args = parser.parse_args()
    print("=> setting random seed to '{}'".format(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        # Freeze the ImageNet backbone; only the replaced head trains unless
        # --fullretrain re-enables gradients below.
        for param in model.parameters():
            param.requires_grad = False

        if 'resnet' in args.arch:
            # Swap the 1000-way ImageNet head for an args.classes-way one.
            model.fc = nn.Linear(2048, args.classes)

        if 'dense' in args.arch:
            if '121' in args.arch:
                # (classifier): Linear(in_features=1024)
                model.classifier = nn.Linear(1024, args.classes)
            elif '169' in args.arch:
                # (classifier): Linear(in_features=1664)
                model.classifier = nn.Linear(1664, args.classes)
            else:
                # Other DenseNet variants are not wired up.
                return

    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = torch.nn.DataParallel(model).cuda()
    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> found checkpoint")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            model.load_state_dict(checkpoint['state_dict'])
            # Continue for args.epochs MORE epochs beyond the checkpoint.
            args.epochs = args.epochs + args.start_epoch
            print("=> loading checkpoint '{}' with acc of '{}'".format(
                args.resume,
                checkpoint['best_val_loss'], ))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    data_dir = join(getcwd(), args.data_dir)
    train_dir = join(data_dir, 'train')
    train_csv = join(data_dir, 'train.csv')
    val_dir = join(data_dir, 'valid')
    val_csv = join(data_dir, 'valid.csv')
    test_dir = join(data_dir, 'test')
    assert isdir(data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir)
    assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv)

    # Before feeding images into the network, we normalize each image to have
    # the same mean and standard deviation of images in the ImageNet training set.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # We then scale the variable-sized images to 224 x 224 and augment with
    # random lateral inversions.
    train_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = MuraDataset(train_csv, transform=train_transforms)
    # Oversample per-image weights so each batch is roughly class-balanced.
    weights = train_data.balanced_weights
    weights = torch.DoubleTensor(weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

    train_loader = data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,  # a sampler is mutually exclusive with shuffle=True
        pin_memory=True)
    val_loader = data.DataLoader(
        MuraDataset(val_csv,
                    transforms.Compose([
                        transforms.Resize(224),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda()
    # The learning rate is decayed by the scheduler each time the validation
    # metric plateaus, and we keep the model with the best encounter-level F1.
    if args.fullretrain:
        print("=> optimizing all layers")
        for param in model.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        print("=> optimizing fc/classifier layers")
        # BUGFIX: DenseNet (the default arch) exposes its head as `classifier`
        # while ResNet uses `fc`; the original always used `.fc`, which raised
        # AttributeError for densenet runs. Pick whichever head exists.
        module = model.module
        head = module.classifier if hasattr(module, 'classifier') else module.fc
        optimizer = optim.Adam(head.parameters(), args.lr, weight_decay=args.weight_decay)

    # validate() returns an F1 score, so the scheduler runs in 'max' mode.
    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=10, verbose=True)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set (encounter-level F1; higher is better)
        val_f1 = validate(val_loader, model, criterion, epoch)
        scheduler.step(val_f1)
        # remember the best F1 seen so far and save a checkpoint
        is_best = val_f1 > best_val_loss
        best_val_loss = max(val_f1, best_val_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_val_loss': best_val_loss,
        }, is_best)
190 |
191 |
def train(train_loader, model, criterion, optimizer, epoch):
    """Train for one epoch, logging mean loss/accuracy to TensorBoard."""
    losses = AverageMeter()
    acc = AverageMeter()

    # ensure model is in train mode
    model.train()
    pbar = tqdm(train_loader)
    for i, (images, target, meta) in enumerate(pbar):
        # NOTE(review): `async` became a reserved word in Python 3.7; on newer
        # Python/torch this must be target.cuda(non_blocking=True).
        target = target.cuda(async=True)
        image_var = Variable(images)
        label_var = Variable(target)

        # pass this batch through our model and get y_pred
        y_pred = model(image_var)

        # update loss metric (loss.data[0] is the pre-torch-0.4 spelling of .item())
        loss = criterion(y_pred, label_var)
        losses.update(loss.data[0], images.size(0))

        # update accuracy metric; topk=(1, 1) returns top-1 twice, so the
        # duplicate unpack simply keeps a single copy of top-1 accuracy.
        prec1, prec1 = accuracy(y_pred.data, target, topk=(1, 1))
        acc.update(prec1[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description("EPOCH[{0}][{1}/{2}]".format(epoch, i, len(train_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))

    # Log per-epoch averages once the epoch completes.
    tb_writer.add_scalar('train/loss', losses.avg, epoch)
    tb_writer.add_scalar('train/acc', acc.avg, epoch)
    return
228 |
229 |
def validate(val_loader, model, criterion, epoch):
    """Run one validation epoch and return the encounter-level F1 score.

    Per-image abnormality probabilities are averaged within each encounter
    (study), rounded to a 0/1 prediction, and scored against the encounter
    label.  NOTE: the return value is an F1 score to be MAXIMIZED, not a
    loss, despite how some callers name it.
    """
    model.eval()
    acc = AverageMeter()
    losses = AverageMeter()
    # One DataFrame per batch; concatenated below for encounter aggregation.
    meta_data = []
    pbar = tqdm(val_loader)
    for i, (images, target, meta) in enumerate(pbar):
        # NOTE(review): `async` became a reserved word in Python 3.7; on newer
        # Python/torch this must be target.cuda(non_blocking=True).
        target = target.cuda(async=True)
        # volatile=True is the pre-torch-0.4 way to disable autograd in eval.
        image_var = Variable(images, volatile=True)
        label_var = Variable(target, volatile=True)

        y_pred = model(image_var)
        # update loss metric (loss.data[0] is the pre-torch-0.4 spelling of .item())
        loss = criterion(y_pred, label_var)
        losses.update(loss.data[0], images.size(0))

        # update accuracy metric on the GPU; topk=(1, 1) returns top-1 twice,
        # so the duplicate unpack simply keeps a single copy.
        prec1, prec1 = accuracy(y_pred.data, target, topk=(1, 1))
        acc.update(prec1[0], images.size(0))

        # Convert logits to class probabilities.  NOTE(review): Softmax()
        # without a dim argument is deprecated in newer torch; dim=1 appears
        # intended here -- confirm.
        sm = nn.Softmax()
        sm_pred = sm(y_pred).data.cpu().numpy()
        # y_norm_probs = sm_pred[:, 0] # p(normal)
        y_pred_probs = sm_pred[:, 1]  # p(abnormal)

        meta_data.append(
            pd.DataFrame({
                'img_filename': meta['img_filename'],
                'y_true': meta['y_true'].numpy(),
                'y_pred_probs': y_pred_probs,
                'patient': meta['patient'].numpy(),
                'study': meta['study'].numpy(),
                'image_num': meta['image_num'].numpy(),
                'encounter': meta['encounter'],
            }))

        pbar.set_description("VALIDATION[{}/{}]".format(i, len(val_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))
    # Aggregate image-level probabilities/labels to one mean per encounter.
    # NOTE(review): tuple column selection after groupby was removed in modern
    # pandas -- newer versions require [['y_pred_probs', 'y_true']].
    df = pd.concat(meta_data)
    ab = df.groupby(['encounter'])['y_pred_probs', 'y_true'].mean()
    ab['y_pred_round'] = ab.y_pred_probs.round()
    ab['y_pred_round'] = pd.to_numeric(ab.y_pred_round, downcast='integer')

    f1_s = f1_score(ab.y_true, ab.y_pred_round)
    prec_s = precision_score(ab.y_true, ab.y_pred_round)
    rec_s = recall_score(ab.y_true, ab.y_pred_round)
    acc_s = accuracy_score(ab.y_true, ab.y_pred_round)
    tb_writer.add_scalar('val/f1_score', f1_s, epoch)
    tb_writer.add_scalar('val/precision', prec_s, epoch)
    tb_writer.add_scalar('val/recall', rec_s, epoch)
    tb_writer.add_scalar('val/accuracy', acc_s, epoch)
    # return the metric we want to evaluate this model's performance by
    return f1_s
285 |
286 |
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize *state* to *filename*; mirror it to 'model_best.pth.tar'
    when it is the best checkpoint observed so far."""
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
291 |
292 |
class AverageMeter(object):
    """Running-mean tracker: remembers the most recent value plus the
    count-weighted average of everything seen since the last reset."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all accumulated statistics."""
        self.val, self.avg = 0, 0
        self.sum, self.count = 0, 0

    def update(self, val, n=1):
        """Fold in *val*, observed *n* times, and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
310 |
311 |
312 | def accuracy(y_pred, y_actual, topk=(1, )):
313 | """Computes the precision@k for the specified values of k"""
314 | maxk = max(topk)
315 | batch_size = y_actual.size(0)
316 |
317 | _, pred = y_pred.topk(maxk, 1, True, True)
318 | pred = pred.t()
319 | correct = pred.eq(y_actual.view(1, -1).expand_as(pred))
320 |
321 | res = []
322 | for k in topk:
323 | correct_k = correct[:k].view(-1).float().sum(0)
324 | res.append(correct_k.mul_(100.0 / batch_size))
325 |
326 | return res
327 |
328 |
if __name__ == '__main__':
    # Script entry point: run the full PyTorch training pipeline.
    main()
331 |
--------------------------------------------------------------------------------