├── trainer
├── arch.JPG
├── trainer_small.py
├── config.gin
├── config_small.gin
├── README.md
├── plot_tsne.py
├── dataset.py
├── main.py
├── utils.py
├── models.py
├── grad-cam.py
└── trainer.py
/trainer:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/arch.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SahilC/multitask-eye-disease-recognition/HEAD/arch.JPG
--------------------------------------------------------------------------------
/trainer_small.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.utils.tensorboard import SummaryWriter
6 | import tqdm
7 | import os
8 | from datetime import datetime
9 | from collections import defaultdict
10 | from utils import compute_bleu, compute_topk, accuracy_recall_precision_f1, calculate_confusion_matrix
11 |
12 |
--------------------------------------------------------------------------------
/config.gin:
--------------------------------------------------------------------------------
1 | run.batch_size = 64
2 | run.epochs = 20
3 | run.val_split = 0.15
4 | run.num_workers = 32
5 | run.print_every = 100
6 |
7 | # run.trainval_csv_path = 'merged_combined_30_sept.csv'
8 | run.trainval_csv_path = 'trainset_with_normal_30_sept.csv'
9 | # run.trainval_csv_path = 'verified_oiscapture_trained_labels_with_normal.csv'
10 | run.test_csv_path = 'testset_filtered.csv'
11 | # run.test_csv_path = 'testset.csv'
12 | # run.test_csv_path = 'extra_normal_unclassified.csv'
13 | # run.trainval_csv_path = 'trainset.csv'
14 | # trainval_csv_path = 'trainset_with_normal.csv'
15 | # test_csv_path = 'trainset_filtered.csv'
16 | # trainval_csv_path = 'self-training-set_filtered.csv'
17 | # run.trainval_csv_path = 'self-training_images.csv'
18 | run.tasks = [0, 1, 2]
19 |
20 | run.lr = 1e-3
21 | run.weight_decay = 1e-6
22 | run.momentum = 0.9
23 | run.dataset_dir = '/data2/fundus_images/'
24 | run.model_type = 'resnet50'
25 |
26 | MultiTaskModel.in_feats = 2048 # 1024 -> densenet121, 2048 -> resnet50, 512 -> resnet34
27 |
--------------------------------------------------------------------------------
/config_small.gin:
--------------------------------------------------------------------------------
1 | run.batch_size = 32
2 | run.epochs = 15
3 | run.val_split = 0.15
4 | run.num_workers = 32
5 | run.print_every = 100
6 |
7 | run.trainval_csv_path = 'trainset_with_normal_30_sept.csv'
8 | run.test_csv_path = 'merged_combined_30_sept.csv'
9 | # run.trainval_csv_path = 'trainset_with_normal_30_sept.csv'
10 | # run.trainval_csv_path = 'verified_oiscapture_trained_labels_with_normal.csv'
11 | # run.test_csv_path = 'testset_filtered.csv'
12 | # run.test_csv_path = 'testset.csv'
13 | # run.test_csv_path = 'extra_normal_unclassified.csv'
14 | # run.trainval_csv_path = 'trainset.csv'
15 | # trainval_csv_path = 'trainset_with_normal.csv'
16 | # run.test_csv_path = 'trainset_filtered.csv'
17 | # trainval_csv_path = 'self-training-set_filtered.csv'
18 | # run.trainval_csv_path = 'self-training_images.csv'
19 | run.tasks = [0, 1, 2]
20 |
21 | run.lr = 1e-3
22 | run.weight_decay = 1e-6
23 | run.momentum = 0.9
24 | run.dataset_dir = '/data2/fundus_images/'
25 | run.model_type = 'resnet50'
26 |
27 | # MultiTaskModel.in_feats = 2048 # 1024 -> densenet121, 2048 -> resnet50, 512 -> resnet34
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multitask Eye Disease Recognition
2 | Multitask learning for eye disease recognition.
3 |
4 | Work done in Microsoft AI Research. Published in ACPR'20.
5 |
6 | Recently, deep learning techniques have been widely used for medical image analysis. While there exists some work on deep learning for ophthalmology, there is little work on multi-disease predictions from retinal fundus images. Also, most of the work is based on small datasets. In this work, given a fundus image, we focus on three tasks related to eye disease prediction: (1) predicting one of the four broad disease categories – diabetic retinopathy, age-related macular degeneration, glaucoma, and melanoma, (2) predicting one of the 320 fine disease sub-categories, (3) generating a textual diagnosis. We model these three tasks under a multi-task learning setup using ResNet, a popular deep convolutional neural network architecture. Our experiments on a large dataset of 40658 images across 3502 patients provides ∼86% accuracy for task 1, ∼67% top-5 accuracy for task 2, and ∼32 BLEU for the diagnosis captioning task.
7 |
8 | Link to paper:- https://link.springer.com/chapter/10.1007/978-3-030-41299-9_57
9 |
10 | Architecture Diagram
11 |
12 |
13 |
14 | Run the code with:-
15 | ```
16 | python main.py
17 | ```
18 |
19 | Configuration can be modified in
20 |
21 | ```
22 | config.gin
23 | ```
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/plot_tsne.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | z = np.load('tsne.npy')
5 | # x = z[4:,0]
6 | # y = z[4:, 1]
7 |
8 | plt.axis([-1, 1, -1, 1])
9 | x = []
10 | y = []
11 | key = {4: 'retinal', 5:
12 | 'macroaneurysm', 6: 'cystoid', 7: 'macular', 8: 'edema', 9:
13 | 'nonexudative', 10: 'senile', 11: 'degeneration', 14: 'type', 15: '2', 16: 'diabetes', 17: 'mellitus',19: 'retinopathy', 20: 'choroidal', 21: 'neovascular', 22:
14 | 'membrane', 23: 'hemorrhage', 24: 'exudative', 25: 'age-related', 31: 'non-proliferative', 32: 'diabetic', 33: 'glaucoma', 34: 'suspect',
15 | 35: 'inactive', 39:'atrophic', 40: 'subfoveal', 41: 'involvement', 42: 'mild', 43:
16 | 'nonproliferative', 44: 'associated', 45: 'drusen', 46: 'macula', 48: 'detachment', 60: 'tension', 63: 'chronic', 64: 'angle-closure', 68: 'low', 76:'narrow', 77: 'angle', 78: 'pseudoexfoliation', 79: 'uveitic', 80:
17 | 'recession', 85: 'inflammations', 88: 'high', 89: 'cotton', 90: 'wool', 91: 'spots', 92: 'degenerative', 93:
18 | 'malignant', 94: 'melanoma', 95: 'intermediate', 98: 'ophthalmic', 99: 'manifestations', 100: 'uncontrolled', 105: 'pigment', 106:
19 | 'epithelium', 107: 'hypertrophy', 108: 'underlying', 109: 'condition',
20 | 113: 'choroid', 118: 'presence', 119: 'uvea', 120: 'drug', 121:
21 | 'chemical', 122: 'induced', 125: 'uveal', 126: 'anterior',
22 | 127:'subretinal', 134: 'atrophy', 135: 'iris', 136:
23 | 'oculopathy', 137: 'resolved', 140:'posterior', 141: 'cataract', 142: 'dm', 146: 'juvenile', 147: 'central', 148: 'geographic', 149:'hemorrhagic', 152: 'combined', 153:
24 | 'rhegmatogenous', 154: 'clinically', 155: 'significant', 156:'insulin',
25 | 157: 'involving', 161: 'epitheliopathy', 162: 'quiescent', 165:
26 | 'optic', 166: 'papillopathy', 167: 'exudates', 172: 'detachments', 173:
27 | 'maculae', 175: 'traumatic',
28 | 179: 'syndrome', 181: 'inflammation', 183: 'disorders', 184: 'increased',185: 'pressure', 187: 'closed-angle'
29 | }
30 | nam = []
31 | for i in key.keys():
32 | nam.append(key[i])
33 | x.append(z[i,0])
34 | y.append(z[i, 1])
35 |
36 | x = np.array(x)
37 | y = np.array(y)
38 |
39 | print(x,y)
40 | plt.scatter(x,y)
41 | for i in range(len(key.keys())):
42 | plt.annotate(nam[i], (x[i], y[i]))
43 |
44 | plt.savefig('abc1.png')
45 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # Generic imports
2 | import pandas as pd
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 | import pdb
7 | import random
8 | import torch
9 | import csv
10 | import nltk
11 | from collections import defaultdict
12 |
13 | # Torch imports
14 | from torchvision import transforms
15 | from torch.utils.data.dataset import Dataset # For custom datasets
16 |
17 | from utils import readLangs, indexFromSentence
18 | class CustomDatasetFromImages(Dataset):
19 | def __init__(self, csv_path, data_dir='/data/sachelar/fundus_images'):
20 | """
21 | Args:
22 | csv_path (string): path to csv file
23 | img_path (string): path to the folder where images are
24 | transform: pytorch transforms for transforms and tensor conversion
25 | """
26 | self.label2idx1 = {'melanoma':0, 'glaucoma':1, 'amd':2, 'diabetic retinopathy':3, 'normal':4}
27 | # self.label2idx1 = {'not applicable':0, 'not classified':1, 'diabetes no retinopathy':2}
28 | # 541 classes
29 | # self.label2idx2 = {j.strip().lower(): (int(i.strip().lower()) -1) for i, j in list(csv.reader(open('labels.txt', 'r'), delimiter='\t'))}
30 | self.label2idx2 = {j.strip().lower(): (int(i.strip().lower()) - 1) for
31 | i, j in list(csv.reader(open('labels2.txt', 'r'), delimiter='\t'))}
32 |
33 | self.to_tensor = transforms.Compose([
34 | transforms.Resize((224, 224)),
35 | transforms.RandomHorizontalFlip(p=0.5),
36 | transforms.RandomVerticalFlip(p=0.5),
37 | transforms.ToTensor(),
38 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
39 | self.data_info = pd.read_csv(csv_path, header=None)
40 | # change to -4?
41 | self.image_arr = np.asarray([os.path.join(data_dir,i.split('/')[-1].replace('%','')) for i in self.data_info.iloc[:,0]])
42 | self.label_arr1 = [self.label2idx1[i.lower()] for i in np.asarray(self.data_info.iloc[:, 1])]
43 | self.label_arr2 = []
44 | self.lang, self.pairs = readLangs(self.data_info.iloc[:, 2], 15)
45 |
46 | for i,z in enumerate(np.asarray(self.data_info.iloc[:, 2])):
47 | self.label_arr2.append(self.label2idx2[z.strip().lower()])
48 | # self.label_arr2 = [self.label2idx2[i] for i in np.asarray(self.data_info.iloc[:, -1])]
49 | # self.operation_arr = np.asarray(self.data_info.iloc[:, 2])
50 | self.data_len = len(self.data_info.index)
51 |
52 | def get_lang(self):
53 | return self.lang
54 |
55 | def __getitem__(self, index):
56 | single_image_name = self.image_arr[index]
57 | img_as_img = Image.open(single_image_name).convert('RGB')
58 | img_as_tensor = self.to_tensor(img_as_img)
59 | single_image_label = self.label_arr1[index]
60 | fine_grained_label = self.label_arr2[index]
61 | text, length = indexFromSentence(self.lang, self.data_info.iloc[index, 2])
62 | text = torch.LongTensor(text).view(-1, 1)
63 | return (single_image_name, img_as_tensor, single_image_label, fine_grained_label, text)
64 |
65 | def __len__(self):
66 | return self.data_len
67 |
68 | class GradedDatasetFromImages(Dataset):
69 | def __init__(self, csv_path, data_dir='/data/sachelar/fundus_images'):
70 | """
71 | Args:
72 | csv_path (string): path to csv file
73 | img_path (string): path to the folder where images are
74 | transform: pytorch transforms for transforms and tensor conversion
75 | """
76 | self.label2idx1 = {'melanoma':0, 'glaucoma':1, 'amd':2, 'diabetic retinopathy':3, 'normal':4}
77 |
78 | self.to_tensor = transforms.Compose([
79 | transforms.Resize((224, 224)),
80 | transforms.RandomHorizontalFlip(p=0.5),
81 | transforms.RandomVerticalFlip(p=0.5),
82 | transforms.ToTensor(),
83 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
84 |
85 | self.data_info = pd.read_csv(csv_path, header=None)
86 | self.image_arr = np.asarray([os.path.join(data_dir, i.replace('%','')) for i in self.data_info.iloc[:,0]])
87 | self.label_arr1 = [self.label2idx1[i.lower()] for i in np.asarray(self.data_info.iloc[:, 1])]
88 | self.data_len = len(self.data_info.index)
89 |
90 | def __getitem__(self, index):
91 | single_image_name = self.image_arr[index]
92 | img_as_img = Image.open(single_image_name).convert('RGB')
93 | img_as_tensor = self.to_tensor(img_as_img)
94 | single_image_label = self.label_arr1[index]
95 | return img_as_tensor, single_image_label
96 |
97 | def __len__(self):
98 | return self.data_len
99 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gin
3 | import numpy as np
4 | import copy
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torchvision.models as models
9 |
10 | from trainer import MultiTaskTrainer
11 | from models import MultiTaskModel
12 | from dataset import CustomDatasetFromImages
13 | from dataset import GradedDatasetFromImages
14 |
15 | from torch.optim.lr_scheduler import ReduceLROnPlateau
16 |
17 | # Hacks for Reproducibility
18 | seed = 3
19 | torch.manual_seed(seed)
20 | torch.backends.cudnn.deterministic = True
21 | torch.backends.cudnn.benchmark = False
22 | np.random.seed(seed)
23 |
24 | # from cnn_model import MnistCNNModel
25 | @gin.configurable
26 | def run(batch_size, epochs, val_split, num_workers, print_every,
27 | trainval_csv_path, test_csv_path, model_type, tasks, lr, weight_decay,
28 | momentum, dataset_dir):
29 |
30 | all_dataset = CustomDatasetFromImages(trainval_csv_path, data_dir = dataset_dir)
31 | # test_dataset = CustomDatasetFromImages(test_csv_path, data_dir = dataset_dir)
32 | val_from_images = GradedDatasetFromImages(test_csv_path, data_dir = dataset_dir)
33 |
34 | dset_len = len(all_dataset)
35 | val_size = int(val_split * dset_len)
36 | test_size = int(0.15 * dset_len)
37 | train_size = dset_len - val_size
38 |
39 |
40 | train_dataset, val_dataset = torch.utils.data.random_split(all_dataset,
41 | [train_size,
42 | val_size])
43 |
44 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
45 | batch_size=2 * batch_size,
46 | pin_memory=False,
47 | drop_last=True,
48 | shuffle=True,
49 | num_workers=num_workers)
50 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
51 | batch_size=batch_size,
52 | pin_memory=False,
53 | drop_last=True,
54 | shuffle=True,
55 | num_workers=num_workers)
56 | test_loader = torch.utils.data.DataLoader(dataset=val_from_images,
57 | batch_size=batch_size,
58 | pin_memory=False,
59 | drop_last=True,
60 | shuffle=True,
61 | num_workers=num_workers)
62 |
63 | lang = train_dataset.dataset.get_lang()
64 |
65 | if model_type == 'densenet121':
66 | model = models.densenet121(pretrained=False)
67 | elif model_type == 'resnet101':
68 | model = models.resnet101(pretrained=False)
69 | elif model_type == 'resnet50':
70 | model = models.resnet50(pretrained=False)
71 | elif model_type == 'resnet34':
72 | model = models.resnet34(pretrained=False)
73 | elif model_type == 'vgg19':
74 | model = models.vgg19(pretrained=False)
75 |
76 | model = MultiTaskModel(model, vocab_size=lang.n_words, model_type = model_type)
77 |
78 | model = nn.DataParallel(model)
79 |
80 | print(model)
81 |
82 | model = model.to('cuda')
83 |
84 | criterion = nn.CrossEntropyLoss()
85 |
86 | optimizer = torch.optim.SGD(model.parameters(),
87 | weight_decay=weight_decay,
88 | momentum=momentum,
89 | lr=lr,
90 | nesterov=True)
91 |
92 | scheduler = ReduceLROnPlateau(optimizer,
93 | factor=0.5,
94 | patience=3,
95 | min_lr=1e-7,
96 | verbose=True)
97 |
98 | trainer = MultiTaskTrainer(model, optimizer, scheduler, criterion, tasks, epochs, lang, print_every = print_every)
99 |
100 | trainer.train(train_loader, val_loader)
101 |
102 | val_loss, total_d_acc, total_acc, bleu, total_f1,total_recall, total_precision, sent_gt, sent_pred, total_topk,per_disease_topk, per_disease_bleu, total_cm = trainer.validate(test_loader)
103 | with open(trainer.output_log, 'a+') as out:
104 | print('Test Loss:{:.8f}\tAcc:{:.8f}\tDAcc:{:.8f}\tBLEU:{:.8f}'.format(val_loss, total_acc, total_d_acc, bleu), file=out)
105 | print('total_topk',total_topk, file=out)
106 | print('per_disease_topk', per_disease_topk, file=out)
107 | print('per_disease_bleu', per_disease_bleu, file=out)
108 | print(total_cm, file=out)
109 | for k in np.random.choice(list(range(len(sent_gt))), size=10, replace=False):
110 | print(sent_gt[k], file=out)
111 | print(sent_pred[k], file=out)
112 | print('---------------------', file=out)
113 | trainer.test(test_loader)
114 |
115 | if __name__ == "__main__":
116 | task_configs =[[0],[1],[2],[0,1], [1,2],[0,2],[0, 1, 2]]
117 | for i, t in enumerate(task_configs):
118 | print("Running", t)
119 | gin.parse_config_file('config.gin')
120 | gin.bind_parameter('run.tasks', t)
121 | run()
122 | gin.clear_config()
123 |
124 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from nltk import word_tokenize
3 | import re
4 | import unicodedata
5 | from nltk.translate.bleu_score import sentence_bleu
6 | import numpy as np
7 | from sklearn.metrics import recall_score, precision_score, f1_score, classification_report, confusion_matrix
8 |
9 | import torch
10 |
11 | from sklearn.preprocessing import LabelEncoder
12 | from sklearn.metrics import recall_score, precision_score, f1_score, classification_report
13 | import torch
14 |
15 |
16 | SOS_token = 0
17 | EOS_token = 1
18 | PAD_token = 2
19 | UNK_token = 3
20 |
21 | class Lang:
22 | def __init__(self, name):
23 | self.name = name
24 | self.word2index = {"UNK":3}
25 | self.word2count = {}
26 | self.index2word = {0: "SOS", 1: "EOS", 2: "PAD", 3:"UNK"}
27 | self.n_words = 4
28 |
29 | def addSentence(self, sentence):
30 | for word in word_tokenize(sentence):
31 | self.addWord(word)
32 |
33 | def addWord(self, word):
34 | if word not in self.word2index:
35 | self.word2index[word] = self.n_words
36 | self.word2count[word] = 1
37 | self.index2word[self.n_words] = word
38 | self.n_words += 1
39 | else:
40 | self.word2count[word] += 1
41 |
42 | def unicodeToAscii(s):
43 | return ''.join(
44 | c for c in unicodedata.normalize('NFD', s)
45 | if unicodedata.category(c) != 'Mn'
46 | )
47 |
48 | # Lowercase, trim, and remove non-letter characters
49 | def normalizeString(s):
50 | s = unicodeToAscii(s.lower().strip())
51 | # s = re.sub(r"([.!?])", r" \1", s)
52 | # s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
53 | return s
54 |
55 | def readLangs(lines, max_length, lang1='eng'):
56 | print("Reading lines...")
57 | input_lang = Lang(lang1)
58 |
59 | pairs = []
60 | for e,l in enumerate(lines):
61 | val = []
62 | words_list = word_tokenize(l)
63 | if len(words_list) >= max_length:
64 | val.append(normalizeString(' '.join(words_list[:max_length - 1])))
65 | else:
66 | val.append(normalizeString(l))
67 | input_lang.addSentence(val[0])
68 | pairs.append(val)
69 |
70 | print("Read %s sentence pairs" % len(pairs))
71 | print("Counted words:")
72 | print(input_lang.name, input_lang.n_words)
73 | print(input_lang.word2count)
74 | return input_lang, pairs
75 |
76 | def indexFromSentence(lang, sentence, unk_ratio = 15, max_length = 15):
77 | out = []
78 | out.append(SOS_token)
79 | # out = [lang.word2index[word] for word in word_tokenize(sentence)]
80 |
81 | # If word freq is small then replace with UNK
82 | for word in word_tokenize(normalizeString(sentence)):
83 | try:
84 | if lang.word2count[word] > unk_ratio:
85 | out.append(lang.word2index[word])
86 | else:
87 | out.append(lang.word2index['UNK'])
88 | except:
89 | pass
90 | # print("error while processing word", word)
91 | # out.append(lang.word2index['UNK'])
92 | out.append(EOS_token)
93 |
94 | if len(out) > max_length:
95 | sentence_length = max_length
96 | else:
97 | sentence_length = len(out)
98 |
99 | # If sentence is small then pad
100 | if len(out) < max_length:
101 | for i in range(max_length - len(out)):
102 | out.append(PAD_token)
103 | elif len(out) > max_length:
104 | out = out[:max_length]
105 | return out, sentence_length
106 |
107 | def variableFromSentence(lang, sentence, max_length = 15):
108 | indexes, sentence_length = indexesFromSentence(lang, sentence, max_length = max_length)
109 | result = torch.LongTensor(indexes).view(-1, 1)
110 | return result, sentence_length
111 |
112 | def variablesFromPair(pair):
113 | input_variable = variableFromSentence(input_lang, pair[0])
114 | return input_variable
115 |
116 | #Metrics
117 | def accuracy_recall_precision_f1(y_pred, y_target):
118 |
119 | """Computes the accuracy, recall, precision and f1 score for given predictions and targets
120 | Args:
121 | y_pred: Logits of the predictions for each class
122 | y_target: Target values
123 | """
124 |
125 | predictions = y_pred.cpu().detach().numpy()
126 | y_target = y_target.cpu().numpy()
127 |
128 | correct = np.sum(predictions == y_target)
129 | accuracy = correct / len(predictions)
130 |
131 | recall = recall_score(y_target, predictions, average=None) #average=None (the scores for each class are returned)
132 | precision = precision_score(y_target, predictions, average=None)
133 | f1 = f1_score(y_target, predictions, average=None)
134 |
135 | return accuracy, recall, precision, f1
136 |
137 | def calculate_confusion_matrix(y_pred, y_target):
138 |
139 | predictions = y_pred.cpu().detach().numpy()
140 | y_target = y_target.cpu().numpy()
141 |
142 | #Confusion matrix
143 | cm = confusion_matrix(y_target, predictions)
144 |
145 | #multi_cm = multilabel_confusion_matrix(y_target, predictions)
146 | #print(multi_cm)
147 | #print(confusion_matrix(y_target, predictions))
148 |
149 | #Classification report
150 | #print(classification_report(y_target, predictions))
151 |
152 | return cm
153 |
154 | def compute_topk(topk_vals, gt, k):
155 | _, preds = topk_vals.topk(k = k, dim = 1)
156 | topk_acc = 0
157 | for i in range(preds.size(1)):
158 | topk_acc += preds[:, i].eq(gt).sum().item()
159 | return (topk_acc / topk_vals.size(0))
160 |
161 | def compute_bleu(lang, text1, preds1, disease, per_disease_bleu):
162 | ind2word = lang.index2word
163 | bleu = 0
164 | sents_gt = []
165 | sents_pred = []
166 | for k in range(len(text1)):
167 | sent1 = []
168 | sent2 = []
169 | weights = (0.25, 0.25, 0.25, 0.25)
170 | for j in range(len(text1[k])):
171 | if text1[k][j] != 0 and text1[k][j] != 1 and text1[k][j] != 2:
172 | sent1.append(ind2word[text1[k][j]])
173 | if preds1[k][j] != 0 and preds1[k][j] != 1 and preds1[k][j] != 2:
174 | sent2.append(ind2word[preds1[k][j]])
175 | if len(sent2) > 0 and len(sent2) < 4 and weights == (0.25, 0.25, 0.25, 0.25):
176 | weights = (1 / len(sent2),) * len(sent2)
177 | c_bleu = sentence_bleu([sent1], sent2, weights = weights)
178 | per_disease_bleu[disease[k].item()].append(c_bleu)
179 | sents_gt.append(sent1)
180 | sents_pred.append(sent2)
181 | bleu += c_bleu
182 | return (bleu/len(text1)), sents_gt, sents_pred
183 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import gin
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import random
7 |
8 | class LanguageModel(nn.Module):
9 | def __init__(self, vocab_size = 193, embed_size = 256, inp_size = 1024, hidden_size = 512,
10 | num_layers = 1, dropout_p = 0.1):
11 | super(LanguageModel, self).__init__()
12 | self.hidden_size = hidden_size
13 | self.embed_size = embed_size
14 | self.dropout = nn.Dropout(dropout_p)
15 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=2)
16 | self.project = nn.Linear(inp_size, hidden_size)
17 | self.gru = nn.GRU(input_size=embed_size, hidden_size=hidden_size,
18 | num_layers=num_layers, batch_first=True)
19 | # self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
20 | self.linear = nn.Linear(hidden_size, vocab_size)
21 | self.max_length = 15
22 | self.teacher_forcing_ratio = 0.5
23 | self.init_weights()
24 |
25 | def init_weights(self):
26 | self.embedding.weight.data.uniform_(-0.1, 0.1)
27 | self.linear.weight.data.uniform_(-0.1, 0.1)
28 | self.linear.bias.data.fill_(0)
29 |
30 | def forward(self, img_feats, text):
31 | text = text.squeeze()
32 | preds = []
33 | self.gru.flatten_parameters()
34 | decoder_input = text[:, 0]
35 | state = self.project(img_feats).unsqueeze(0)
36 | for i in range(1, self.max_length):
37 | use_teacher_forcing = True if (self.training and random.random() < self.teacher_forcing_ratio) else False
38 | embeddings = self.dropout(self.embedding(decoder_input)).squeeze()
39 | feats, state = self.gru(embeddings.unsqueeze(1), state)
40 | pred = self.linear(state).squeeze()
41 |
42 | if use_teacher_forcing:
43 | decoder_input = text[:,i]
44 | else:
45 | output = F.log_softmax(pred, dim=-1)
46 | decoder_input = torch.argmax(output, dim=-1)
47 | preds.append(pred.unsqueeze(1))
48 |
49 | return torch.cat(preds, 1)
50 |
51 | @gin.configurable
52 | class AutoEncoder(nn.Module):
53 | def __init__(self, model_type, model = None):
54 | super(AutoEncoder, self).__init__()
55 | self.model_type = model_type
56 | if model_type == 'self':
57 | self.conv = nn.Sequential(nn.Conv2d(3, 32, (4, 4), 2, 1), # 224 -> 112
58 | nn.ReLU(),
59 | nn.Conv2d(32, 64, (4, 4), 2,1), # 112 -> 56
60 | nn.ReLU(),
61 | nn.Conv2d(64, 128, (4, 4), 2, 1), # 56 -> 28
62 | nn.ReLU(),
63 | nn.Conv2d(128, 128, (4, 4), 2, 1), # 28 -> 14
64 | nn.ReLU(),
65 | nn.Conv2d(128, 64, (4, 4), 2, 1), # 14 -> 7
66 | nn.ReLU(),
67 | nn.Conv2d(64, 5, (7, 7), 1, 0))
68 | else:
69 | self.conv = torch.nn.Sequential(*list(model.children())[:-1])
70 | self.lin = nn.Linear(2048, 256)
71 | self.deconv = nn.Sequential(
72 | nn.ConvTranspose2d(256, 256, 7, 1, 0),
73 | nn.BatchNorm2d(256),
74 | nn.ReLU(),
75 | # state size: (ngf * 8) x 4 x 4
76 | nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False),
77 | nn.BatchNorm2d(256),
78 | nn.ReLU(),
79 | # state size: ngf x 32 x 32
80 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
81 | nn.BatchNorm2d(128),
82 | nn.ReLU(),
83 | # state size: (ngf * 4) x 8 x 8
84 | nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
85 | nn.BatchNorm2d(64),
86 | nn.ReLU(),
87 | # state size: (ngf * 2) x 16 x 16
88 | nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
89 | nn.BatchNorm2d(32),
90 | nn.ReLU(),
91 | # state size: ngf x 32 x 32
92 | nn.ConvTranspose2d(32, 3, 4, 2, 1),
93 | nn.Tanh()
94 | )
95 |
96 | def forward(self, x):
97 | out = self.conv(x).squeeze()
98 | if self.model_type != 'self':
99 | out = self.lin(F.relu(out))
100 | out = self.deconv(out.view(-1, out.size(1), 1, 1))
101 | return out
102 |
103 | @gin.configurable
104 | class AbnormalNet(nn.Module):
105 | def __init__(self, model_type, model = None):
106 | super(AbnormalNet, self).__init__()
107 | self.model_type = model_type
108 | if model_type == 'self':
109 | self.conv = nn.Sequential(nn.Conv2d(3, 32, (4, 4), 2, 1), # 224 -> 112
110 | nn.ReLU(),
111 | nn.Conv2d(32, 64, (4, 4), 2,1), # 112 -> 56
112 | nn.ReLU(),
113 | nn.Conv2d(64, 128, (4, 4), 2, 1), # 56 -> 28
114 | nn.ReLU(),
115 | nn.Conv2d(128, 128, (4, 4), 2, 1), # 28 -> 14
116 | nn.ReLU(),
117 | nn.Conv2d(128, 64, (4, 4), 2, 1), # 14 -> 7
118 | nn.ReLU(),
119 | nn.Conv2d(64, 5, (7, 7), 1, 0))
120 | else:
121 | self.conv = torch.nn.Sequential(*list(model.children())[:-1])
122 | self.lin = nn.Linear(2048, 5)
123 |
124 | def forward(self, x):
125 | out = self.conv(x).squeeze()
126 | if self.model_type != 'self':
127 | out = self.lin(F.relu(out))
128 | return out
129 |
130 | @gin.configurable
131 | class MultiTaskModel(nn.Module):
132 | def __init__(self, model, vocab_size, model_type = 'densenet121', in_feats = gin.REQUIRED):
133 | super(MultiTaskModel, self).__init__()
134 | self.model_type = model_type
135 | if self.model_type == 'densenet121':
136 | self.feature_extract = model.features
137 | else:
138 | self.feature_extract = torch.nn.Sequential(*list(model.children())[:-1])
139 |
140 | self.disease_classifier = nn.Sequential(nn.Linear(in_feats, 512),
141 | nn.ReLU(), nn.Linear(512, 5))
142 | self.fine_disease_classifier = nn.Sequential(nn.Linear(in_feats, 512),
143 | nn.ReLU(), nn.Linear(512, 321))
144 | self.language_classifier = LanguageModel(inp_size = in_feats, vocab_size = vocab_size)
145 |
146 | def forward(self, data, text):
147 | features = self.feature_extract(data).squeeze()
148 | out = F.relu(features)
149 | if self.model_type == 'densenet121':
150 | out = F.adaptive_avg_pool2d(out, (1, 1))
151 | out = torch.flatten(out, 1)
152 |
153 | return (self.disease_classifier(out),
154 | self.fine_disease_classifier(out),
155 | self.language_classifier(out, text)
156 | )
157 |
--------------------------------------------------------------------------------
/grad-cam.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | from torch.autograd import Function
7 | from torchvision import models
8 | from torchvision import utils
9 | import cv2
10 | import sys
11 | import numpy as np
12 | import argparse
13 | from models import AbnormalNet
14 | class FeatureExtractor():
15 | """ Class for extracting activations and
16 | registering gradients from targetted intermediate layers """
17 | def __init__(self, model, target_layers):
18 | self.model = model
19 | self.target_layers = target_layers
20 | self.gradients = []
21 |
22 | def save_gradient(self, grad):
23 | self.gradients.append(grad)
24 |
25 | def __call__(self, x):
26 | outputs = []
27 | self.gradients = []
28 | # print(list(self.model._modules.items()))
29 | for name, module in self.model._modules.items():
30 | x = module(x)
31 | if name in self.target_layers:
32 | x.register_hook(self.save_gradient)
33 | outputs += [x]
34 | return outputs, x
35 |
36 | class ModelOutputs():
37 | """ Class for making a forward pass, and getting:
38 | 1. The network output.
39 | 2. Activations from intermeddiate targetted layers.
40 | 3. Gradients from intermeddiate targetted layers. """
41 | def __init__(self, model, target_layers):
42 | self.model = model
43 | self.model.features = self.model.conv
44 | self.model.classifier = self.model.lin
45 | self.feature_extractor = FeatureExtractor(self.model.features, target_layers)
46 |
47 | def get_gradients(self):
48 | return self.feature_extractor.gradients
49 |
50 | def __call__(self, x):
51 | target_activations, output = self.feature_extractor(x)
52 | # output = output.view(output.size(0), -1)
53 | output = F.relu(output.squeeze())
54 | output = self.model.classifier(output)
55 | return target_activations, F.softmax(output, dim=-1)
56 |
57 | def preprocess_image(img):
58 | means=[0.485, 0.456, 0.406]
59 | stds=[0.229, 0.224, 0.225]
60 |
61 | preprocessed_img = img.copy()[: , :, ::-1]
62 | for i in range(3):
63 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i]
64 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i]
65 | preprocessed_img = \
66 | np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1)))
67 | preprocessed_img = torch.from_numpy(preprocessed_img)
68 | preprocessed_img.unsqueeze_(0)
69 | input = Variable(preprocessed_img, requires_grad = True)
70 | return input
71 |
72 | def show_cam_on_image(img, mask):
73 | heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
74 | heatmap = np.float32(heatmap) / 255
75 | cam = heatmap + np.float32(img)
76 | cam = cam / np.max(cam)
77 | return cam
78 |
class GradCam:
    """Compute a Grad-CAM heatmap for a model exposing `.features` and
    `.classifier` (VGG-style), via the ModelOutputs activation extractor.

    Calling the instance returns (cam, index): a (224, 224) float map in
    [0, 1] and the class index it was computed for.
    """
    def __init__(self, model, target_layer_names, use_cuda):
        self.model = model
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        self.extractor = ModelOutputs(self.model, target_layer_names)

    def forward(self, input):
        return self.model(input)

    def __call__(self, input, index = None):
        if self.cuda:
            features, output = self.extractor(input.cuda())
        else:
            features, output = self.extractor(input)

        # Default to the highest-scoring class. `is None` (identity test)
        # replaces the buggy `== None` equality comparison.
        if index is None:
            index = np.argmax(output.cpu().data.numpy())

        # One-hot select the target class score so backward() yields
        # d(score[index]) / d(activations).
        one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
        one_hot[0][index] = 1
        one_hot = Variable(torch.from_numpy(one_hot), requires_grad = True)
        if self.cuda:
            one_hot = torch.sum(one_hot.cuda() * output)
        else:
            one_hot = torch.sum(one_hot * output)

        self.model.features.zero_grad()
        self.model.classifier.zero_grad()
        one_hot.backward(retain_graph=True)

        grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()

        target = features[-1]
        target = target.cpu().data.numpy()[0, :]

        # Grad-CAM: weight each activation channel by its spatially averaged
        # gradient, then sum channels.
        weights = np.mean(grads_val, axis = (2, 3))[0, :]
        cam = np.zeros(target.shape[1 : ], dtype = np.float32)

        for i, w in enumerate(weights):
            cam += w * target[i, :, :]

        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, (224, 224))
        cam = cam - np.min(cam)
        # Guard against an all-zero map: dividing by max == 0 produced NaNs.
        peak = np.max(cam)
        if peak > 0:
            cam = cam / peak
        return cam, index
129 |
class GuidedBackpropReLU(Function):
    # Guided-backprop ReLU: forward is a normal ReLU; backward propagates a
    # gradient only where BOTH the forward input and the incoming gradient
    # are positive.
    # NOTE(review): this uses the legacy (pre-0.4) autograd.Function API
    # (instance-level forward/backward, self.save_for_backward). Modern
    # PyTorch requires @staticmethod forward/backward taking a ctx argument.
    # Instances are swapped in for nn.ReLU modules by GuidedBackpropReLUModel
    # below — confirm the installed torch version still supports this.

    def forward(self, input):
        # ReLU via masked multiply: keep positive entries, zero the rest.
        positive_mask = (input > 0).type_as(input)
        output = torch.addcmul(torch.zeros(input.size()).type_as(input), input, positive_mask)
        self.save_for_backward(input, output)
        return output

    def backward(self, grad_output):
        input, output = self.saved_tensors
        grad_input = None

        # Mask 1: where the forward input was positive (standard ReLU grad).
        # Mask 2: where the incoming gradient is positive (guided backprop).
        positive_mask_1 = (input > 0).type_as(grad_output)
        positive_mask_2 = (grad_output > 0).type_as(grad_output)
        grad_input = torch.addcmul(torch.zeros(input.size()).type_as(input), torch.addcmul(torch.zeros(input.size()).type_as(input), grad_output, positive_mask_1), positive_mask_2)

        return grad_input
147 |
class GuidedBackpropReLUModel:
    """Wrap a model so that backprop through it yields guided-backprop
    gradients: every nn.ReLU in the conv stack is replaced with
    GuidedBackpropReLU, and __call__ returns d(score[index])/d(input).
    """
    def __init__(self, model, use_cuda):
        self.model = model
        # Project models expose their conv stack as `.conv`; alias it to
        # `.features` so the ReLU-replacement loop below can reach it.
        self.model.features = model.conv
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        # replace ReLU with GuidedBackpropReLU
        for idx, module in self.model.features._modules.items():
            if module.__class__.__name__ == 'ReLU':
                self.model.features._modules[idx] = GuidedBackpropReLU()

    def forward(self, input):
        return self.model(input)

    def __call__(self, input, index = None):
        """Return the guided-backprop gradient of class `index` w.r.t.
        `input`, as a (C, H, W) numpy array. `input` must require grad.
        """
        if self.cuda:
            output = self.forward(input.cuda())
        else:
            output = self.forward(input)

        # Default to the top-scoring class. `is None` replaces the buggy
        # `== None` equality comparison.
        if index is None:
            index = np.argmax(output.cpu().data.numpy())

        # One-hot select the target class score so backward() produces the
        # gradient of that single logit.
        one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32)
        one_hot[0][index] = 1
        one_hot = Variable(torch.from_numpy(one_hot), requires_grad = True)
        if self.cuda:
            one_hot = torch.sum(one_hot.cuda() * output)
        else:
            one_hot = torch.sum(one_hot * output)

        one_hot.backward(retain_graph=True)

        output = input.grad.cpu().data.numpy()
        output = output[0,:,:,:]

        return output
190 |
def get_args():
    """Parse command-line flags; CUDA is enabled only when requested AND available."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--use-cuda', action='store_true', default=False,
                        help='Use NVIDIA GPU acceleration')
    args = parser.parse_args()
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    message = "Using GPU for acceleration" if args.use_cuda else "Using CPU for computation"
    print(message)
    return args
203 |
if __name__ == '__main__':
    """ python grad_cam.py
    1. Loads an image with opencv.
    2. Preprocesses it for VGG19 and converts to a pytorch variable.
    3. Makes a forward pass to find the category index with the highest score,
    and computes intermediate activations.
    Makes the visualization. """

    args = get_args()

    # Can work with any model, but it assumes that the model has a
    # feature method, and a classifier method,
    # as in the VGG models in torchvision.
    idx2label = {0:'melanoma', 1:'glaucoma', 2:'amd', 3:'diabetic retinopathy', 4:'normal'}
    task_list = [0.15, 0.25, 0.4, 0.55, 0.7, 0.85]
    for t in task_list:
        for i in os.listdir('Inked/'):
            # spl encodes the trainset split the checkpoint was trained with.
            spl = round(1 - t - 0.15, 2)
            # file_name = i.replace('Inked','').replace('_LI.jpg','').replace('%','').replace('_','5C', 1)
            # NOTE(review): file_name is hard-coded, so the Inked/ loop re-runs
            # the same image; presumably leftover from debugging — confirm.
            file_name = 'Batch2079-805C1.2.826.0.1.3680043.2.110.1192826410552750.1.200_0000_000000_1559691433052d.jpg'
            model = models.resnet50(pretrained=True)
            model = nn.DataParallel(AbnormalNet('resnet50', model))

            # model.load_state_dict(torch.load('/data2/sachelar/kd_models/kd_only-{:.2f}/best_model.pt'.format(spl)))
            model.load_state_dict(torch.load('small_models/{:.2f}-resnet/best_model.pt'.format(spl)))

            grad_cam = GradCam(model = model.module, \
                            target_layer_names = ["7"], use_cuda=args.use_cuda)
            try:
                img = cv2.imread(os.path.join('/data2/fundus_images',file_name), 1)
                img = np.float32(cv2.resize(img, (224, 224))) / 255
                input = preprocess_image(img)

                # If None, returns the map for the highest scoring category.
                # Otherwise, targets the requested index.
                target_index = None

                mask, index = grad_cam(input, target_index)

                cam = show_cam_on_image(img, mask)
                if not os.path.exists(os.path.join("outputs", file_name.replace('.jpg',''))):
                    os.mkdir(os.path.join("outputs", file_name.replace('.jpg','')))
                cv2.imwrite(os.path.join("outputs", file_name.replace('.jpg','') + '/'+"{:.2f}".format(spl) + '_'+idx2label[index] + "_cam.jpg"), np.uint8(255 * cam))

                gb_model = GuidedBackpropReLUModel(model = model.module, use_cuda=args.use_cuda)
                gb = gb_model(input, index=target_index)
                # utils.save_image(torch.from_numpy(gb),"outputs/{:.2f}".format(spl) +'_'+ file_name.replace('.jpg','')+ '_gb.jpg')

                # Broadcast the CAM over every gradient channel.
                cam_mask = np.zeros(gb.shape)
                for i in range(0, gb.shape[0]):
                    cam_mask[i, :, :] = mask

                cam_gb = np.multiply(cam_mask, gb)
                # utils.save_image(torch.from_numpy(cam_gb),"outputs/{:.2f}".format(spl) +'_'+ file_name.replace('.jpg','') + '_cam_gb.jpg')
            except Exception as e:
                # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
                # still propagate, and the actual failure is reported.
                print("ERROR", file_name, e)
261 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
import os
from collections import defaultdict
from datetime import datetime

import gin
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch.utils.tensorboard import SummaryWriter

from utils import compute_bleu, compute_topk, accuracy_recall_precision_f1, calculate_confusion_matrix
class BaseTrainer(object):
    """Common scaffolding shared by the concrete trainers in this module.

    Holds the model/optimizer/scheduler/criterion, picks the device, and
    manages the per-run save directory (gin config snapshot, text log,
    best-model checkpoint, TensorBoard writer). Subclasses must implement
    train / train_iteration / validate / test.
    """
    def __init__(self, model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss = 100):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.epochs = epochs
        self.criterion = criterion
        self.print_every = print_every
        # Best validation loss seen so far; validate() checkpoints below it.
        self.min_val_loss = min_val_loss
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Save experiment configuration. Subclasses typically override
        # save_location_dir before calling init_saves().
        self.save_location_dir = os.path.join('models', str(datetime.now()).replace(' ',''))

    def init_saves(self):
        """Create the run directory and set up config/log/checkpoint paths."""
        # makedirs (not mkdir) so a missing parent ('models', 'small_models',
        # 'kdmodels', ...) does not crash the run; exist_ok preserves the
        # original create-if-absent semantics.
        os.makedirs(self.save_location_dir, exist_ok=True)
        with open(os.path.join(self.save_location_dir,'config.gin'), 'w') as conf:
            conf.write(gin.operative_config_str())
        self.output_log = os.path.join(self.save_location_dir,'output_log.txt')
        self.save_path = os.path.join(self.save_location_dir, 'best_model.pt')
        self.summary_writer = SummaryWriter(os.path.join(self.save_location_dir, 'logs'), 300)

    def train(self, train_loader, val_loader):
        raise NotImplementedError

    def train_iteration(self, train_loader, val_loader):
        raise NotImplementedError

    def validate(self, train_loader, val_loader):
        raise NotImplementedError

    def test(self, test_loader):
        raise NotImplementedError
45 |
class MultiTaskTrainer(BaseTrainer):
    """Trainer for the multi-task model: coarse disease class (t1),
    fine-grained disease class (t2) and report text generation (t3).

    `tasks` indexes into the stacked per-task losses, so only the listed
    task losses contribute to the optimized objective.
    """
    def __init__(self, model, optimizer, scheduler, criterion, tasks, epochs, lang, print_every = 100, min_val_loss = 100):
        super(MultiTaskTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss)
        # Vocabulary/language object used by compute_bleu to decode token ids.
        self.lang = lang
        self.tasks = tasks
        # Encode the active task ids in the run directory name.
        self.save_location_dir = os.path.join('models', '_'.join(str(t) for t in self.tasks) +'-'+ str(datetime.now()).replace(' ',''))
        self.init_saves()

    def train(self, train_loader, val_loader):
        """Full train/validate loop: log metrics to TensorBoard and append
        validation metrics plus sample generated sentences to the run log."""
        for e in range(self.epochs):
            self.model.train()
            total_train_loss, total_tl1, total_tl2, total_tl3, total_disease_acc, accuracy, bleu = self.train_iteration(train_loader)
            print("Epoch", e)
            self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e)
            self.summary_writer.add_scalar('training/total_t1_loss', total_tl1, e)
            self.summary_writer.add_scalar('training/total_t2_loss', total_tl2, e)
            self.summary_writer.add_scalar('training/total_t3_loss', total_tl3, e)
            self.summary_writer.add_scalar('training/t1_acc', total_disease_acc, e)
            self.summary_writer.add_scalar('training/t2_acc', accuracy, e)
            self.summary_writer.add_scalar('training/t3_bleu', bleu, e)
            val_loss, total_d_acc, total_acc, bleu, total_f1, total_recall, total_precision, sent_gt, sent_pred, total_topk, per_disease_topk, per_disease_bleu, total_cm = self.validate(val_loader)
            with open(self.output_log, 'a+') as out:
                print('Epoch: {}\tVal Loss:{:.8f}\tAcc:{:.8f}\tDAcc:{:.8f}\tBLEU:{:.8f}'.format(e,val_loss, total_acc, total_d_acc, bleu), file=out)
                print('total_topk',total_topk, file=out)
                print('per_disease_topk', per_disease_topk, file=out)
                print('per_disease_bleu', per_disease_bleu, file=out)
                print(total_cm, file=out)
                # Log 10 random ground-truth / prediction sentence pairs.
                for k in np.random.choice(list(range(len(sent_gt))), size=10, replace=False):
                    print(sent_gt[k], file=out)
                    print(sent_pred[k], file=out)
                    print('---------------------', file=out)

            self.summary_writer.add_scalar('validation/val_loss', val_loss, e)
            self.summary_writer.add_scalar('validation/t1_acc', total_d_acc, e)
            self.summary_writer.add_scalar('validation/t2_acc', total_acc, e)
            self.summary_writer.add_scalar('validation/BLEU', bleu, e)
            # self.summary_writer.add_scalars('validation/f1_scores', disease_f1, e)
            # self.summary_writer.add_scalars('validation/recall_scores',
            # disease_recall, e)
            # self.summary_writer.add_scalars('validation/precision_scores',
            # disease_precision, e)

            self.summary_writer.add_scalar('validation/f1_mean', total_f1, e)
            self.summary_writer.add_scalar('validation/recall_mean', total_recall, e)
            self.summary_writer.add_scalar('validation/precision_mean', total_precision, e)
            self.summary_writer.add_scalars('validation/topk', total_topk, e)
            for i in per_disease_topk:
                self.summary_writer.add_scalars('validation/topk_'+str(i), per_disease_topk[i], e)

            # NOTE(review): a second independent random sample of sentence
            # pairs, different from the ones written to the text log above.
            for i, k in enumerate(np.random.choice(list(range(len(sent_gt))), size=10, replace=False)):
                self.summary_writer.add_text('validation/sentence_gt'+str(i),
                        ' '.join(sent_gt[k]), e)
                self.summary_writer.add_text('validation/sentence_pred'+str(i),
                        ' '.join(sent_pred[k]), e)

    def test(self, test_loader):
        """Write per-image (name, true label, predicted label) rows to predictions.csv.

        NOTE(review): the results file handle is never closed/flushed
        explicitly, and self.model.eval() is not called here — confirm.
        """
        results = open('predictions.csv','w')
        ind2word = {0:'Melanoma',1:'Glaucoma',2:'AMD',3:'DR',4:'Normal'}
        with torch.no_grad():
            for i, (name, images, labels, f_labels, text) in enumerate(test_loader):
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                f_labels = f_labels.to(self.device)
                text = text.to(self.device)
                diseases, fine_diseases, text_pred = self.model(images, text)
                pred = F.log_softmax(diseases, dim= -1).argmax(dim=-1)
                for j in range(diseases.size(0)):
                    results.write(name[j]+','+ind2word[labels[j].item()] +','+ind2word[pred[j].item()] +'\n')


    def train_iteration(self, train_loader):
        """One epoch over train_loader; returns the running (per-print-window)
        loss and accuracy counters from the final partial window.

        The running counters are reset every `print_every` iterations, so the
        returned values reflect only the tail of the epoch.
        """
        train_loss = 0.0
        accuracy = 0.0
        total_disease_acc = 0.0
        bleu = 0.0
        total_tl1 = 0
        total_tl2 = 0
        total_tl3 = 0
        total_train_loss = 0.0
        loss = torch.tensor(0).to(self.device)
        for i, (_, images, labels, f_labels, text) in enumerate(train_loader):
            batch_size = images.size(0)
            images = images.to(self.device)
            labels = labels.to(self.device)
            f_labels = f_labels.to(self.device)
            text = text.to(self.device)
            self.optimizer.zero_grad()
            disease, f_disease, text_pred = self.model(images, text)
            # Task 1: coarse disease classification loss.
            loss1 = self.criterion(disease, labels)

            # Task 2: fine-grained disease classification loss.
            loss2 = self.criterion(f_disease, f_labels)

            # Task 3: per-timestep text loss; target is shifted by one token.
            loss3 = 0.0
            for k in range(text_pred.size(1)):
                loss3 += self.criterion(text_pred[:, k].squeeze(), text[:, k + 1].squeeze())

            # Only consider tasks defined in the task list
            loss = torch.stack((loss1,loss2, loss3))[self.tasks].sum()

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            total_train_loss += loss.item()
            total_tl1 += loss1.item()
            total_tl2 += loss2.item()
            total_tl3 += loss3.item()

            pred = F.log_softmax(f_disease, dim = -1).argmax(dim=-1)
            accuracy += pred.eq(f_labels).sum().item()
            d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1)
            total_disease_acc += d_pred.eq(labels).sum().item()
            # preds = torch.argmax(F.log_softmax(text_pred,dim=-1), dim=-1)
            # text1 = text[:, 1:].squeeze().tolist()
            # preds1 = preds.tolist()
            # tbleu, _, _ = compute_bleu(self.lang, text1, preds1, labels, per_disease_bleu)
            # bleu += tbleu

            # Periodic console report; counters are averaged then reset.
            if i != 0 and i % self.print_every == 0:
                avg_loss = train_loss / self.print_every
                accuracy = accuracy / self.print_every
                total_disease_acc = total_disease_acc / self.print_every
                # NOTE(review): avg_text_loss is computed but never used.
                avg_text_loss = loss3 / self.print_every
                bleu = bleu / self.print_every
                total_train_loss = total_train_loss / self.print_every
                total_tl1 = total_tl1 / self.print_every
                total_tl2 = total_tl2 / self.print_every
                total_tl3 = total_tl3 / self.print_every

                print('Iter:{}\tTraining Loss:{:.8f}\tAcc:{:.8f}\tDAcc:{:.8f}\tBLEU:{:.8f}\tTextLoss:{:.8f}'.format(i, avg_loss,
                    accuracy/batch_size,
                    total_disease_acc / batch_size,
                    bleu,
                    loss3.item()))

                train_loss = 0.0
                text_loss = 0.0
                accuracy = 0.0
                total_disease_acc = 0.0
                bleu = 0.0
                total_tl1 = 0
                total_tl2 = 0
                total_tl3 = 0
                total_train_loss = 0.0
        return (total_train_loss, total_tl1, total_tl2, total_tl3, total_disease_acc/batch_size, accuracy/batch_size, bleu)

    def validate(self, val_loader, epoch = 0):
        """Evaluate on val_loader; checkpoint when val loss improves.

        Returns loader-averaged loss, accuracies, BLEU, P/R/F1, sample
        sentences, top-k stats (overall and per coarse disease 0-3) and the
        accumulated confusion matrix.
        """
        self.model.eval()
        val_loss = 0.0
        total_acc = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0
        bleu = 0.0
        total_l1 = 0
        total_l2 = 0
        total_l3 = 0

        k_vals = [1, 2, 3, 4, 5]
        total_topk = {k:0.0 for k in k_vals}
        per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals})
        per_disease_bleu = defaultdict(list)
        with torch.no_grad():
            for i, (_, images, labels, f_labels, text) in enumerate(val_loader):
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                f_labels = f_labels.to(self.device)
                text = text.to(self.device)
                diseases, fine_diseases, text_pred = self.model(images, text)
                loss1 = self.criterion(diseases, labels)
                loss2 = self.criterion(fine_diseases, f_labels)
                # Per-timestep text loss, same shifted-target scheme as training.
                text_loss = 0.0
                for k in range(text_pred.size(1)):
                    text_loss += self.criterion(text_pred[:,k].squeeze(), text[:,k + 1].squeeze())

                val_loss += torch.stack((loss1, loss2, text_loss))[self.tasks].sum()

                preds = F.log_softmax(fine_diseases, dim = -1)
                pred = preds.argmax(dim=-1)
                d_pred = F.log_softmax(diseases, dim = -1).argmax(dim=-1)

                # Evaluation of P, R, F1, CM, BLEU
                total_acc += (pred.eq(f_labels).sum().item() / batch_size)
                total_d_acc += (d_pred.eq(labels).sum().item() / batch_size)
                acc, recall, precision, f1 = accuracy_recall_precision_f1(d_pred,
                        labels)
                cm = calculate_confusion_matrix(d_pred, labels)
                # NOTE(review): bare except hides the real failure mode when
                # the batch CM shape differs (e.g. a class missing from the
                # last small batch) — confirm intended.
                try:
                    total_cm += (cm / batch_size)
                except:
                    print("Error occured for this CM")
                    print(cm / batch_size)

                # Top-k evaluation
                for k in k_vals:
                    total_topk[k] += compute_topk(preds, f_labels, k)
                    for d in [0, 1, 2, 3]:
                        mask = labels.eq(d)
                        if mask.sum() > 0:
                            per_disease_topk[d][str(k)] += compute_topk(preds[mask], f_labels[mask], k)

                total_recall += np.mean(recall)
                total_precision += np.mean(precision)
                total_f1 += np.mean(f1)
                preds = torch.argmax(F.log_softmax(text_pred,dim=-1), dim=-1)
                text1 = text[:, 1:].squeeze().tolist()
                preds1 = preds.tolist()
                t_bleu, sent_gt, sent_pred = compute_bleu(self.lang, text1, preds1, labels, per_disease_bleu)

                # Book-keeping
                bleu += t_bleu
                total_l1 += loss1.item()
                total_l2 += loss2.item()
                total_l3 += text_loss.item()
            # Average all accumulators over the number of batches.
            bleu = bleu / (len(val_loader))
            val_loss = val_loss / len(val_loader)
            total_l1 /= len(val_loader)
            total_l2 /= len(val_loader)
            total_l3 /= len(val_loader)
            total_acc = total_acc / len(val_loader)
            total_d_acc = total_d_acc / len(val_loader)
            total_f1 = total_f1 / len(val_loader)
            total_precision = total_precision / len(val_loader)
            total_recall = total_recall / len(val_loader)
            total_cm = total_cm / len(val_loader)

            # Plateau scheduler steps on validation loss; checkpoint on improvement.
            self.scheduler.step(val_loss)
            if val_loss <= self.min_val_loss:
                torch.save(self.model.state_dict(), self.save_path)
                self.min_val_loss = val_loss

            disease_f1 = {}
            disease_precision = {}
            disease_recall = {}

            #for i in range(len(total_f1)):
            #    disease_f1[i] = total_f1[i]
            #    disease_precision[i] = total_precision[i]
            #    disease_recall[i] = total_recall[i]
            for d in per_disease_bleu:
                per_disease_bleu[d] = np.mean(per_disease_bleu[d])

            total_topk = {str(k) : total_topk[k] / len(val_loader) for k in k_vals}
            for d in [0,1,2,3]:
                for k in k_vals:
                    per_disease_topk[d][str(k)] = per_disease_topk[d][str(k)] / len(val_loader)

            return (val_loss, total_d_acc, total_acc, bleu, total_f1, total_recall,
                    total_precision, sent_gt, sent_pred, total_topk,
                    per_disease_topk, per_disease_bleu, total_cm)
300 |
301 |
302 |
class SmallTrainer(BaseTrainer):
    """Single-task trainer for a plain image classifier (the "small"/student
    model); batches are (image, label) pairs."""
    def __init__(self, model, optimizer, scheduler, criterion, epochs, print_every = 100, min_val_loss = 100, trainset_split = 0.85):
        super(SmallTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss)
        # Run directory encodes the fraction of the trainset used.
        self.save_location_dir = os.path.join('small_models', str(trainset_split)+'-'+str(datetime.now()))
        self.init_saves()

    def train(self, train_loader, val_loader):
        """Epoch loop: train, validate, log to TensorBoard and the run log."""
        for e in range(self.epochs):
            self.model.train()
            total_train_loss, accuracy = self.train_iteration(train_loader)
            print("Epoch", e)
            self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e)
            self.summary_writer.add_scalar('training/acc', accuracy, e)
            with torch.no_grad():
                val_loss, total_d_acc, total_f1, total_recall, total_precision, total_cm = self.validate(val_loader)

            self.summary_writer.add_scalar('validation/val_loss', val_loss, e)
            self.summary_writer.add_scalar('validation/t1_acc', total_d_acc, e)

            self.summary_writer.add_scalar('validation/f1_mean', total_f1, e)
            self.summary_writer.add_scalar('validation/recall_mean', total_recall, e)
            self.summary_writer.add_scalar('validation/precision_mean', total_precision, e)
            with open(self.output_log, 'a+') as out:
                print('Val Loss',val_loss, 'total_d_acc',total_d_acc, 'F1',
                        total_f1, 'R', total_recall,'P', total_precision,
                        file=out)
                print(total_cm, file=out)

    def test(self, test_loader):
        """Self-training labeler: write (name, confidence, given label,
        predicted label) rows to self_trained_extra_labels.csv.

        NOTE(review): the results file handle is never closed explicitly.
        """
        results = open('self_trained_extra_labels.csv','w')
        self.model.eval()
        # ind2disease = {0:'Disease',1:'Normal'}
        ind2disease = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR', 4:'Normal'}
        # ind2disease2 = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR'}
        ind2disease2 = {0:'not applicable' , 1: 'not classifed', 2: 'diabetes no retinopathy'}
        for i, data in tqdm.tqdm(enumerate(test_loader)):
            image_name = data[0]
            images = data[1]
            labels = data[2]
            batch_size = images.size(0)
            images = images.to(self.device)
            disease = self.model(images)
            d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1)
            # Max softmax probability as a confidence score per sample.
            probs, _ = F.softmax(disease, dim=-1).max(dim=-1)
            for j in range(d_pred.size(0)):
                results.write(image_name[j]+','+ '{:.8f}'.format(probs[j].item()) +',' + ind2disease2[labels[j].item()] +','+ind2disease[d_pred[j].item()]+'\n')

    def train_iteration(self, train_loader):
        """One epoch; returns running loss/accuracy from the final
        print-window (counters reset every `print_every` iterations)."""
        train_loss = 0.0
        accuracy = 0.0
        total_disease_acc = 0.0
        total_train_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            batch_size = images.size(0)
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            disease = self.model(images)
            loss = self.criterion(disease, labels)

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            total_train_loss += loss.item()

            d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1)
            total_disease_acc += d_pred.eq(labels).sum().item()

            # Periodic console report; counters are averaged then reset.
            if i != 0 and i % self.print_every == 0:
                avg_loss = train_loss / self.print_every
                total_disease_acc = total_disease_acc / self.print_every
                total_train_loss = total_train_loss / self.print_every

                print('Iter:{}\tTraining Loss:{:.8f}\tAcc:{:.8f}'.format(i,
                    avg_loss, total_disease_acc / batch_size))

                train_loss = 0.0
                total_disease_acc = 0.0
        return (total_train_loss, total_disease_acc/batch_size)

    def validate(self, val_loader, epoch = 0):
        """Evaluate; checkpoint when the loader-averaged loss improves.

        Returns (val_loss, accuracy, mean F1, mean recall, mean precision,
        averaged confusion matrix).
        """
        self.model.eval()
        val_loss = 0.0
        total_acc = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0
        bleu = 0.0
        total_l1 = 0
        total_l2 = 0
        total_l3 = 0

        # NOTE(review): top-k containers are initialized but never filled here
        # (single-task model) — presumably copied from MultiTaskTrainer.
        k_vals = [1, 2, 3, 4, 5]
        total_topk = {k:0.0 for k in k_vals}
        per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals})
        losses = []
        with torch.no_grad():
            for i, (images, labels) in enumerate(val_loader):
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                diseases = self.model(images)
                loss1 = self.criterion(diseases, labels)

                val_loss += loss1.item()

                # Evaluation of P, R, F1, BLEU
                d_pred = F.log_softmax(diseases, dim = -1).argmax(dim=-1)
                total_d_acc += (d_pred.eq(labels).sum().item() / batch_size)
                acc, recall, precision, f1 = accuracy_recall_precision_f1(d_pred,
                        labels)

                total_recall += np.mean(recall)
                total_precision += np.mean(precision)
                total_f1 += np.mean(f1)

                cm = calculate_confusion_matrix(d_pred, labels)
                # NOTE(review): bare except hides shape mismatches in the
                # batch confusion matrix — confirm intended.
                try:
                    total_cm += (cm / batch_size)
                except:
                    print("error occured for this CM")
                    print(cm / batch_size)
            # Average all accumulators over the number of batches.
            val_loss = val_loss / len(val_loader)
            total_d_acc = total_d_acc / len(val_loader)
            total_f1 = total_f1 / len(val_loader)
            total_precision = total_precision / len(val_loader)
            total_recall = total_recall / len(val_loader)
            total_cm = total_cm / len(val_loader)

            # Plateau scheduler steps on validation loss; checkpoint on improvement.
            self.scheduler.step(val_loss)
            if val_loss <= self.min_val_loss:
                torch.save(self.model.state_dict(), self.save_path)
                self.min_val_loss = val_loss

            disease_f1 = {}
            disease_precision = {}
            disease_recall = {}

            # for i in range(len(total_f1)):
            #    disease_f1[i] = total_f1[i]
            #    disease_precision[i] = total_precision[i]
            #    disease_recall[i] = total_recall[i]

            return (val_loss, total_d_acc, total_f1, total_recall,
                    total_precision, total_cm)
452 |
class AutoTrainer(BaseTrainer):
    """Autoencoder trainer: the criterion compares the model output against
    the INPUT images (reconstruction loss), not against class labels."""
    def __init__(self, model, optimizer, scheduler, criterion, epochs,
            print_every = 100, min_val_loss = 100000, trainset_split = 0.85):
        super(AutoTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss)
        self.save_location_dir = os.path.join('auto_models', str(trainset_split)+'-'+str(datetime.now()))
        self.init_saves()

    def train(self, train_loader, val_loader):
        """Epoch loop: train, validate on reconstruction loss, log."""
        for e in range(self.epochs):
            self.model.train()
            total_train_loss = self.train_iteration(train_loader)
            print("Epoch", e)
            self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e)
            with torch.no_grad():
                val_loss = self.validate(val_loader)

            self.summary_writer.add_scalar('validation/val_loss', val_loss, e)
            with open(self.output_log, 'a+') as out:
                print('Val Loss',val_loss, file=out)

    def test(self, test_loader):
        """Write (name, confidence, given label, predicted label) rows to
        self_trained_extra_labels.csv.

        NOTE(review): duplicated verbatim from SmallTrainer.test and treats
        the model output as class logits, which conflicts with the
        reconstruction objective used in training here — confirm.
        """
        results = open('self_trained_extra_labels.csv','w')
        self.model.eval()
        # ind2disease = {0:'Disease',1:'Normal'}
        ind2disease = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR', 4:'Normal'}
        # ind2disease2 = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR'}
        ind2disease2 = {0:'not applicable' , 1: 'not classifed', 2: 'diabetes no retinopathy'}
        for i, data in tqdm.tqdm(enumerate(test_loader)):
            image_name = data[0]
            images = data[1]
            labels = data[2]
            batch_size = images.size(0)
            images = images.to(self.device)
            disease = self.model(images)
            d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1)
            probs, _ = F.softmax(disease, dim=-1).max(dim=-1)
            for j in range(d_pred.size(0)):
                results.write(image_name[j]+','+ '{:.8f}'.format(probs[j].item()) +',' + ind2disease2[labels[j].item()] +','+ind2disease[d_pred[j].item()]+'\n')

    def train_iteration(self, train_loader):
        """One epoch of reconstruction training; returns running loss from the
        final print-window (reset every `print_every` iterations)."""
        train_loss = 0.0
        accuracy = 0.0
        total_disease_acc = 0.0
        total_train_loss = 0.0
        for i, data in enumerate(train_loader):
            (_, images, labels, f_labels, text) = data
            batch_size = images.size(0)
            images = images.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            disease = self.model(images)
            # Reconstruction loss against the input, normalized by batch size.
            loss = self.criterion(disease, images).div(images.size(0))

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            total_train_loss += loss.item()


            if i != 0 and i % self.print_every == 0:
                avg_loss = train_loss / self.print_every

                print('Iter:{}\tTraining Loss:{:.8f}'.format(i,
                    avg_loss))

                train_loss = 0.0
        return total_train_loss

    def validate(self, val_loader, epoch = 0):
        """Loader-averaged reconstruction loss; checkpoints on improvement."""
        self.model.eval()
        val_loss = 0.0
        total_acc = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0
        bleu = 0.0
        total_l1 = 0
        total_l2 = 0
        total_l3 = 0

        # NOTE(review): classification-metric containers below are unused for
        # the autoencoder — leftover from the classifier trainers.
        k_vals = [1, 2, 3, 4, 5]
        total_topk = {k:0.0 for k in k_vals}
        per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals})
        losses = []
        with torch.no_grad():
            for i, data in enumerate(val_loader):
                (_, images, labels, f_labels, text) = data
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                diseases = self.model(images)
                loss1 = self.criterion(diseases, images).div(images.size(0))

                val_loss += loss1.item()

            val_loss = val_loss / len(val_loader)

            self.scheduler.step(val_loss)
            if val_loss <= self.min_val_loss:
                torch.save(self.model.state_dict(), self.save_path)
                self.min_val_loss = val_loss

            return val_loss
560 |
561 |
562 | class KDTrainer(BaseTrainer):
    def __init__(self, kd_model, model, optimizer, scheduler, criterion, epochs, print_every = 100, min_val_loss = 100, trainset_split = 0.85, kd_type = 'full'):
        """Knowledge-distillation trainer: `kd_model` is the (frozen) teacher,
        `model` the student being trained.

        kd_type == 'full' mixes the distillation and hard-label terms; any
        other value optimizes the distillation term alone (see
        distillation_loss).
        """
        super(KDTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss)
        # The criterion passed in is replaced by the KD loss method.
        self.criterion = self.distillation_loss
        self.kd_model = kd_model
        # Teacher-confidence cutoff used to filter samples in train_iteration.
        self.threshold = 0.9
        self.kd_type = kd_type
        self.kd_model.eval()
        self.save_location_dir = os.path.join('kdmodels', kd_type +'-'+str(trainset_split)+'-'+str(datetime.now()))
        self.init_saves()
572 |
573 | def distillation_loss(self, y, labels, teacher_scores, T = 5, alpha = 0.95, reduction_kd='mean', reduction_nll='mean'):
574 | if teacher_scores is not None:
575 | # d_loss = torch.nn.KLDivLoss(reduction=reduction_kd)(F.log_softmax(y/ T, dim= -1), F.softmax(teacher_scores / T, dim=-1)) * T * T
576 | preds = F.softmax(teacher_scores , dim=-1).argmax(dim=-1)
577 | d_loss = F.cross_entropy(y, labels, reduction=reduction_nll)
578 | else:
579 | assert alpha == 0, 'alpha cannot be {} when teacher scores are not provided'.format(alpha)
580 | d_loss = 0.0
581 | nll_loss = F.cross_entropy(y, labels, reduction=reduction_nll)
582 | if self.kd_type == 'full':
583 | tol_loss = alpha * d_loss + (1.0 - alpha) * nll_loss
584 | else:
585 | tol_loss = d_loss
586 | return tol_loss, d_loss, nll_loss
587 |
588 | def unpack_data(self, data):
589 | if self.kd_type != 'full':
590 | (_, images, labels, f_labels, text) = data
591 | else:
592 | images, labels = data
593 | return (images, labels)
594 |
595 | def train(self, train_loader, val_loader):
596 | for e in range(self.epochs):
597 | self.model.train()
598 | total_train_loss, accuracy = self.train_iteration(train_loader)
599 | break
600 | print("Epoch", e)
601 | self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e)
602 | self.summary_writer.add_scalar('training/acc', accuracy, e)
603 | with torch.no_grad():
604 | val_loss, total_d_acc, total_f1, total_recall, total_precision, total_cm = self.validate(val_loader)
605 |
606 | self.summary_writer.add_scalar('validation/val_loss', val_loss, e)
607 | self.summary_writer.add_scalar('validation/t1_acc', total_d_acc, e)
608 |
609 | self.summary_writer.add_scalar('validation/f1_mean', total_f1, e)
610 | self.summary_writer.add_scalar('validation/recall_mean', total_recall, e)
611 | self.summary_writer.add_scalar('validation/precision_mean', total_precision, e)
612 | with open(self.output_log, 'a+') as out:
613 | print('Val Loss',val_loss, 'total_d_acc',total_d_acc, 'F1',
614 | total_f1, 'R', total_recall,'P', total_precision)
615 | print('Val Loss',val_loss, 'total_d_acc',total_d_acc, 'F1',
616 | total_f1, 'R', total_recall,'P', total_precision,
617 | file=out)
618 | print(total_cm, file=out)
619 |
    def test(self, test_loader):
        """Write (name, confidence, given label, predicted label) rows to
        self_trained_extra_labels.csv using the student model.

        NOTE(review): duplicated verbatim from SmallTrainer.test; the results
        file handle is never closed explicitly.
        """
        results = open('self_trained_extra_labels.csv','w')
        self.model.eval()
        # ind2disease = {0:'Disease',1:'Normal'}
        ind2disease = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR', 4:'Normal'}
        # ind2disease2 = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR'}
        ind2disease2 = {0:'not applicable' , 1: 'not classifed', 2: 'diabetes no retinopathy'}
        for i, data in tqdm.tqdm(enumerate(test_loader)):
            image_name = data[0]
            images = data[1]
            labels = data[2]
            batch_size = images.size(0)
            images = images.to(self.device)
            disease = self.model(images)
            d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1)
            # Max softmax probability as a per-sample confidence score.
            probs, _ = F.softmax(disease, dim=-1).max(dim=-1)
            for j in range(d_pred.size(0)):
                results.write(image_name[j]+','+ '{:.8f}'.format(probs[j].item()) +',' + ind2disease2[labels[j].item()] +','+ind2disease[d_pred[j].item()]+'\n')
638 |
def train_iteration(self, train_loader):
    """One epoch of knowledge-distillation self-training.

    The teacher (``self.kd_model``) scores each batch; only samples whose
    max teacher softmax reaches ``self.threshold`` are kept, and the
    student (``self.model``) is trained on that confident subset with
    ``self.criterion``, which returns a (total, kd, nll) loss triple.
    Mirrors the confidence filtering done in ``validate``.

    BUG FIX(review): a stray ``continue`` previously sat *inside* the
    ``index.any()`` branch, so every batch containing teacher-confident
    samples was skipped and the student never took an optimizer step on
    them.  The guard is now inverted: batches with no confident samples
    are skipped, confident subsets are trained on.

    Returns (accumulated windowed training loss, accuracy estimate for
    the final window) -- the raw values the caller logs to TensorBoard.
    """
    train_loss = 0.0
    total_disease_acc = 0.0
    total_train_loss = 0.0
    total_kd_loss = 0.0
    total_nll_loss = 0.0
    contr = 0  # running count of teacher-confident samples trained on
    batch_size = 1  # guards the final division if the loader is empty
    for i, data in enumerate(train_loader):
        images, labels = self.unpack_data(data)
        batch_size = images.size(0)
        images = images.to(self.device)
        labels = labels.to(self.device)

        self.optimizer.zero_grad()

        teacher_scores = self.kd_model(images)
        val, _ = F.softmax(teacher_scores, dim=-1).max(dim=-1)

        # Keep only samples the teacher is confident about.
        index = val >= self.threshold
        if not index.any().item():
            continue  # nothing confident in this batch -- skip the step
        images = images[index]
        labels = labels[index]
        teacher_scores = teacher_scores[index]
        contr += images.size(0)

        # self.model is presumably wrapped in DataParallel (hence .module)
        # -- TODO confirm against the constructor.
        disease = self.model.module(images)
        loss = self.criterion(disease, labels, teacher_scores)

        loss[0].backward()
        self.optimizer.step()

        train_loss += loss[0].item()
        total_train_loss += loss[0].item()
        total_kd_loss += loss[1].item()
        total_nll_loss += loss[2].item()

        d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1)
        total_disease_acc += d_pred.eq(labels).sum().item()

        if i != 0 and i % self.print_every == 0:
            avg_loss = train_loss / self.print_every
            total_disease_acc = total_disease_acc / self.print_every
            total_train_loss = total_train_loss / self.print_every
            # NOTE(review): these normalizations mix per-window and
            # per-sample denominators; kept as-is to preserve the
            # numbers this script has always reported.
            print('Iter:{}\tTraining Loss:{:.8f}\tKD:{:.8f}\tNLL:{:.8f}\tAcc:{:.8f}'.format(i,
                avg_loss, total_kd_loss/ ((i+1)*batch_size),
                total_nll_loss / ((i+1) * batch_size ),
                total_disease_acc / batch_size))

            train_loss = 0.0
            total_disease_acc = 0.0
    print("COUNT", contr)
    return (total_train_loss, total_disease_acc/batch_size)
def validate(self, val_loader, epoch = 0):
    """Evaluate the student on teacher-confident validation samples.

    Mirrors the train-time filtering: each batch is reduced to the
    samples whose max teacher softmax reaches ``self.threshold`` before
    the student is scored.  Steps the LR scheduler on the mean loss and
    checkpoints the model whenever validation loss improves.

    Returns (val_loss, accuracy, f1, recall, precision,
    confusion_matrix), each averaged over the batches of ``val_loader``.

    Cleanup(review): removed a set of dead accumulators (top-k, BLEU,
    per-disease dicts) and commented-out code, and narrowed a bare
    ``except:`` to the ``ValueError`` that confusion-matrix accumulation
    can actually raise.
    """
    self.model.eval()
    val_loss = 0.0
    total_recall = 0.0
    total_precision = 0.0
    total_f1 = 0.0
    total_cm = 0
    total_d_acc = 0.0
    with torch.no_grad():
        for i, (images, labels) in enumerate(val_loader):
            batch_size = images.size(0)
            images = images.to(self.device)
            labels = labels.to(self.device)
            teacher_scores = self.kd_model(images)
            val, _ = F.softmax(teacher_scores, dim=-1).max(dim=-1)

            # Keep only teacher-confident samples (if any qualify).
            index = val >= self.threshold
            if index.any().item():
                images = images[index]
                labels = labels[index]
                teacher_scores = teacher_scores[index]
            diseases = self.model.module(images)

            loss1, _, _ = self.criterion(diseases, labels, teacher_scores)

            val_loss += loss1.item()

            # Evaluation of P, R, F1
            d_pred = F.log_softmax(diseases, dim = -1).argmax(dim=-1)
            # NOTE(review): denominator is the pre-filter batch size even
            # though labels may have been filtered above; kept as-is to
            # preserve the reported metric.
            total_d_acc += (d_pred.eq(labels).sum().item() / batch_size)
            acc, recall, precision, f1 = accuracy_recall_precision_f1(d_pred,
                    labels)

            total_recall += np.mean(recall)
            total_precision += np.mean(precision)
            total_f1 += np.mean(f1)

            cm = calculate_confusion_matrix(d_pred, labels)
            try:
                total_cm += (cm / batch_size)
            except ValueError:
                # Batches missing some classes can produce a smaller
                # matrix that cannot be added to the accumulator; log and
                # skip it (previously a bare except hid all errors here).
                print("error occured for this CM")
                print(cm / batch_size)
    val_loss = val_loss / len(val_loader)
    total_d_acc = total_d_acc / len(val_loader)
    total_f1 = total_f1 / len(val_loader)
    total_precision = total_precision / len(val_loader)
    total_recall = total_recall / len(val_loader)
    total_cm = total_cm / len(val_loader)

    self.scheduler.step(val_loss)
    # Checkpoint on (ties included) best validation loss so far.
    if val_loss <= self.min_val_loss:
        torch.save(self.model.state_dict(), self.save_path)
        self.min_val_loss = val_loss

    return (val_loss, total_d_acc, total_f1, total_recall,
            total_precision, total_cm)
772 |
773 |
--------------------------------------------------------------------------------