├── trainer ├── arch.JPG ├── trainer_small.py ├── config.gin ├── config_small.gin ├── README.md ├── plot_tsne.py ├── dataset.py ├── main.py ├── utils.py ├── models.py ├── grad-cam.py └── trainer.py /trainer: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arch.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SahilC/multitask-eye-disease-recognition/HEAD/arch.JPG -------------------------------------------------------------------------------- /trainer_small.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.tensorboard import SummaryWriter 6 | import tqdm 7 | import os 8 | from datetime import datetime 9 | from collections import defaultdict 10 | from utils import compute_bleu, compute_topk, accuracy_recall_precision_f1, calculate_confusion_matrix 11 | 12 | -------------------------------------------------------------------------------- /config.gin: -------------------------------------------------------------------------------- 1 | run.batch_size = 64 2 | run.epochs = 20 3 | run.val_split = 0.15 4 | run.num_workers = 32 5 | run.print_every = 100 6 | 7 | # run.trainval_csv_path = 'merged_combined_30_sept.csv' 8 | run.trainval_csv_path = 'trainset_with_normal_30_sept.csv' 9 | # run.trainval_csv_path = 'verified_oiscapture_trained_labels_with_normal.csv' 10 | run.test_csv_path = 'testset_filtered.csv' 11 | # run.test_csv_path = 'testset.csv' 12 | # run.test_csv_path = 'extra_normal_unclassified.csv' 13 | # run.trainval_csv_path = 'trainset.csv' 14 | # trainval_csv_path = 'trainset_with_normal.csv' 15 | # test_csv_path = 'trainset_filtered.csv' 16 | # trainval_csv_path = 'self-training-set_filtered.csv' 17 | # run.trainval_csv_path = 'self-training_images.csv' 18 | run.tasks = [0, 1, 2] 19 | 20 | run.lr = 1e-3 21 | run.weight_decay = 1e-6 22 | run.momentum = 0.9 23 | run.dataset_dir = '/data2/fundus_images/' 24 | run.model_type = 'resnet50' 25 | 26 | MultiTaskModel.in_feats = 2048 # 1024 -> densenet121, 2048 -> resnet50, 512 -> resnet34 27 | -------------------------------------------------------------------------------- /config_small.gin: -------------------------------------------------------------------------------- 1 | run.batch_size = 32 2 | run.epochs = 15 3 | run.val_split = 0.15 4 | run.num_workers = 32 5 | run.print_every = 100 6 | 7 | run.trainval_csv_path = 'trainset_with_normal_30_sept.csv' 8 | run.test_csv_path = 'merged_combined_30_sept.csv' 9 | # run.trainval_csv_path = 'trainset_with_normal_30_sept.csv' 10 | # run.trainval_csv_path = 'verified_oiscapture_trained_labels_with_normal.csv' 11 | # run.test_csv_path = 'testset_filtered.csv' 12 | # run.test_csv_path = 'testset.csv' 13 | # run.test_csv_path = 'extra_normal_unclassified.csv' 14 | # run.trainval_csv_path = 'trainset.csv' 15 | # trainval_csv_path = 'trainset_with_normal.csv' 16 | # run.test_csv_path = 'trainset_filtered.csv' 17 | # trainval_csv_path = 'self-training-set_filtered.csv' 18 | # run.trainval_csv_path = 'self-training_images.csv' 19 | run.tasks = [0, 1, 2] 20 | 21 | run.lr = 1e-3 22 | run.weight_decay = 1e-6 23 | run.momentum = 0.9 24 | run.dataset_dir = '/data2/fundus_images/' 25 | run.model_type = 'resnet50' 26 | 27 | # MultiTaskModel.in_feats = 2048 # 1024 -> densenet121, 2048 -> resnet50, 512 -> resnet34 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multitask Eye Disease Recognition 2 | Multitask learning for eye disease recognition. 3 | 4 | Work done in Microsoft AI Research. Published in ACPR'20. 5 | 6 | Recently, deep learning techniques have been widely used for medical image analysis. While there exists some work on deep learning for ophthalmology, there is little work on multi-disease predictions from retinal fundus images. Also, most of the work is based on small datasets. In this work, given a fundus image, we focus on three tasks related to eye disease prediction: (1) predicting one of the four broad disease categories – diabetic retinopathy, age-related macular degeneration, glaucoma, and melanoma, (2) predicting one of the 320 fine disease sub-categories, (3) generating a textual diagnosis. We model these three tasks under a multi-task learning setup using ResNet, a popular deep convolutional neural network architecture. Our experiments on a large dataset of 40658 images across 3502 patients provides ∼86% accuracy for task 1, ∼67% top-5 accuracy for task 2, and ∼32 BLEU for the diagnosis captioning task. 7 | 8 | Link to paper:- https://link.springer.com/chapter/10.1007/978-3-030-41299-9_57 9 | 10 | Architecture Diagram 11 | 12 | 13 | 14 | Run the code with:- 15 | ``` 16 | python main.py 17 | ``` 18 | 19 | Configuration can be modified in 20 | 21 | ``` 22 | config.gin 23 | ``` 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /plot_tsne.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | z = np.load('tsne.npy') 5 | # x = z[4:,0] 6 | # y = z[4:, 1] 7 | 8 | plt.axis([-1, 1, -1, 1]) 9 | x = [] 10 | y = [] 11 | key = {4: 'retinal', 5: 12 | 'macroaneurysm', 6: 'cystoid', 7: 'macular', 8: 'edema', 9: 13 | 'nonexudative', 10: 'senile', 11: 'degeneration', 14: 'type', 15: '2', 16: 'diabetes', 17: 'mellitus',19: 'retinopathy', 20: 'choroidal', 21: 'neovascular', 22: 14 | 'membrane', 23: 'hemorrhage', 24: 'exudative', 25: 'age-related', 31: 'non-proliferative', 32: 'diabetic', 33: 'glaucoma', 34: 'suspect', 15 | 35: 'inactive', 39:'atrophic', 40: 'subfoveal', 41: 'involvement', 42: 'mild', 43: 16 | 'nonproliferative', 44: 'associated', 45: 'drusen', 46: 'macula', 48: 'detachment', 60: 'tension', 63: 'chronic', 64: 'angle-closure', 68: 'low', 76:'narrow', 77: 'angle', 78: 'pseudoexfoliation', 79: 'uveitic', 80: 17 | 'recession', 85: 'inflammations', 88: 'high', 89: 'cotton', 90: 'wool', 91: 'spots', 92: 'degenerative', 93: 18 | 'malignant', 94: 'melanoma', 95: 'intermediate', 98: 'ophthalmic', 99: 'manifestations', 100: 'uncontrolled', 105: 'pigment', 106: 19 | 'epithelium', 107: 'hypertrophy', 108: 'underlying', 109: 'condition', 20 | 113: 'choroid', 118: 'presence', 119: 'uvea', 120: 'drug', 121: 21 | 'chemical', 122: 'induced', 125: 'uveal', 126: 'anterior', 22 | 127:'subretinal', 134: 'atrophy', 135: 'iris', 136: 23 | 'oculopathy', 137: 'resolved', 140:'posterior', 141: 'cataract', 142: 'dm', 146: 'juvenile', 147: 'central', 148: 'geographic', 149:'hemorrhagic', 152: 'combined', 153: 24 | 'rhegmatogenous', 154: 'clinically', 155: 'significant', 156:'insulin', 25 | 157: 'involving', 161: 'epitheliopathy', 162: 'quiescent', 165: 26 | 'optic', 166: 'papillopathy', 167: 'exudates', 172: 'detachments', 173: 27 | 'maculae', 175: 'traumatic', 28 | 179: 'syndrome', 181: 'inflammation', 183: 'disorders', 184: 'increased',185: 'pressure', 187: 'closed-angle' 29 | } 30 | nam = [] 31 | for i in key.keys(): 32 | nam.append(key[i]) 33 | x.append(z[i,0]) 34 | y.append(z[i, 1]) 35 | 36 | x = np.array(x) 37 | y = np.array(y) 38 | 39 | print(x,y) 40 | plt.scatter(x,y) 41 | for i in range(len(key.keys())): 42 | plt.annotate(nam[i], (x[i], y[i])) 43 | 44 | plt.savefig('abc1.png') 45 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Generic imports 2 | import pandas as pd 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | import pdb 7 | import random 8 | import torch 9 | import csv 10 | import nltk 11 | from collections import defaultdict 12 | 13 | # Torch imports 14 | from torchvision import transforms 15 | from torch.utils.data.dataset import Dataset # For custom datasets 16 | 17 | from utils import readLangs, indexFromSentence 18 | class CustomDatasetFromImages(Dataset): 19 | def __init__(self, csv_path, data_dir='/data/sachelar/fundus_images'): 20 | """ 21 | Args: 22 | csv_path (string): path to csv file 23 | img_path (string): path to the folder where images are 24 | transform: pytorch transforms for transforms and tensor conversion 25 | """ 26 | self.label2idx1 = {'melanoma':0, 'glaucoma':1, 'amd':2, 'diabetic retinopathy':3, 'normal':4} 27 | # self.label2idx1 = {'not applicable':0, 'not classified':1, 'diabetes no retinopathy':2} 28 | # 541 classes 29 | # self.label2idx2 = {j.strip().lower(): (int(i.strip().lower()) -1) for i, j in list(csv.reader(open('labels.txt', 'r'), delimiter='\t'))} 30 | self.label2idx2 = {j.strip().lower(): (int(i.strip().lower()) - 1) for 31 | i, j in list(csv.reader(open('labels2.txt', 'r'), delimiter='\t'))} 32 | 33 | self.to_tensor = transforms.Compose([ 34 | transforms.Resize((224, 224)), 35 | transforms.RandomHorizontalFlip(p=0.5), 36 | transforms.RandomVerticalFlip(p=0.5), 37 | transforms.ToTensor(), 38 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 39 | self.data_info = pd.read_csv(csv_path, header=None) 40 | # change to -4? 41 | self.image_arr = np.asarray([os.path.join(data_dir,i.split('/')[-1].replace('%','')) for i in self.data_info.iloc[:,0]]) 42 | self.label_arr1 = [self.label2idx1[i.lower()] for i in np.asarray(self.data_info.iloc[:, 1])] 43 | self.label_arr2 = [] 44 | self.lang, self.pairs = readLangs(self.data_info.iloc[:, 2], 15) 45 | 46 | for i,z in enumerate(np.asarray(self.data_info.iloc[:, 2])): 47 | self.label_arr2.append(self.label2idx2[z.strip().lower()]) 48 | # self.label_arr2 = [self.label2idx2[i] for i in np.asarray(self.data_info.iloc[:, -1])] 49 | # self.operation_arr = np.asarray(self.data_info.iloc[:, 2]) 50 | self.data_len = len(self.data_info.index) 51 | 52 | def get_lang(self): 53 | return self.lang 54 | 55 | def __getitem__(self, index): 56 | single_image_name = self.image_arr[index] 57 | img_as_img = Image.open(single_image_name).convert('RGB') 58 | img_as_tensor = self.to_tensor(img_as_img) 59 | single_image_label = self.label_arr1[index] 60 | fine_grained_label = self.label_arr2[index] 61 | text, length = indexFromSentence(self.lang, self.data_info.iloc[index, 2]) 62 | text = torch.LongTensor(text).view(-1, 1) 63 | return (single_image_name, img_as_tensor, single_image_label, fine_grained_label, text) 64 | 65 | def __len__(self): 66 | return self.data_len 67 | 68 | class GradedDatasetFromImages(Dataset): 69 | def __init__(self, csv_path, data_dir='/data/sachelar/fundus_images'): 70 | """ 71 | Args: 72 | csv_path (string): path to csv file 73 | img_path (string): path to the folder where images are 74 | transform: pytorch transforms for transforms and tensor conversion 75 | """ 76 | self.label2idx1 = {'melanoma':0, 'glaucoma':1, 'amd':2, 'diabetic retinopathy':3, 'normal':4} 77 | 78 | self.to_tensor = transforms.Compose([ 79 | transforms.Resize((224, 224)), 80 | transforms.RandomHorizontalFlip(p=0.5), 81 | transforms.RandomVerticalFlip(p=0.5), 82 | transforms.ToTensor(), 83 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 84 | 85 | self.data_info = pd.read_csv(csv_path, header=None) 86 | self.image_arr = np.asarray([os.path.join(data_dir, i.replace('%','')) for i in self.data_info.iloc[:,0]]) 87 | self.label_arr1 = [self.label2idx1[i.lower()] for i in np.asarray(self.data_info.iloc[:, 1])] 88 | self.data_len = len(self.data_info.index) 89 | 90 | def __getitem__(self, index): 91 | single_image_name = self.image_arr[index] 92 | img_as_img = Image.open(single_image_name).convert('RGB') 93 | img_as_tensor = self.to_tensor(img_as_img) 94 | single_image_label = self.label_arr1[index] 95 | return img_as_tensor, single_image_label 96 | 97 | def __len__(self): 98 | return self.data_len 99 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gin 3 | import numpy as np 4 | import copy 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision.models as models 9 | 10 | from trainer import MultiTaskTrainer 11 | from models import MultiTaskModel 12 | from dataset import CustomDatasetFromImages 13 | from dataset import GradedDatasetFromImages 14 | 15 | from torch.optim.lr_scheduler import ReduceLROnPlateau 16 | 17 | # Hacks for Reproducibility 18 | seed = 3 19 | torch.manual_seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | np.random.seed(seed) 23 | 24 | # from cnn_model import MnistCNNModel 25 | @gin.configurable 26 | def run(batch_size, epochs, val_split, num_workers, print_every, 27 | trainval_csv_path, test_csv_path, model_type, tasks, lr, weight_decay, 28 | momentum, dataset_dir): 29 | 30 | all_dataset = CustomDatasetFromImages(trainval_csv_path, data_dir = dataset_dir) 31 | # test_dataset = CustomDatasetFromImages(test_csv_path, data_dir = dataset_dir) 32 | val_from_images = GradedDatasetFromImages(test_csv_path, data_dir = dataset_dir) 33 | 34 | dset_len = len(all_dataset) 35 | val_size = int(val_split * dset_len) 36 | test_size = int(0.15 * dset_len) 37 | train_size = dset_len - val_size 38 | 39 | 40 | train_dataset, val_dataset = torch.utils.data.random_split(all_dataset, 41 | [train_size, 42 | val_size]) 43 | 44 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 45 | batch_size=2 * batch_size, 46 | pin_memory=False, 47 | drop_last=True, 48 | shuffle=True, 49 | num_workers=num_workers) 50 | val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 51 | batch_size=batch_size, 52 | pin_memory=False, 53 | drop_last=True, 54 | shuffle=True, 55 | num_workers=num_workers) 56 | test_loader = torch.utils.data.DataLoader(dataset=val_from_images, 57 | batch_size=batch_size, 58 | pin_memory=False, 59 | drop_last=True, 60 | shuffle=True, 61 | num_workers=num_workers) 62 | 63 | lang = train_dataset.dataset.get_lang() 64 | 65 | if model_type == 'densenet121': 66 | model = models.densenet121(pretrained=False) 67 | elif model_type == 'resnet101': 68 | model = models.resnet101(pretrained=False) 69 | elif model_type == 'resnet50': 70 | model = models.resnet50(pretrained=False) 71 | elif model_type == 'resnet34': 72 | model = models.resnet34(pretrained=False) 73 | elif model_type == 'vgg19': 74 | model = models.vgg19(pretrained=False) 75 | 76 | model = MultiTaskModel(model, vocab_size=lang.n_words, model_type = model_type) 77 | 78 | model = nn.DataParallel(model) 79 | 80 | print(model) 81 | 82 | model = model.to('cuda') 83 | 84 | criterion = nn.CrossEntropyLoss() 85 | 86 | optimizer = torch.optim.SGD(model.parameters(), 87 | weight_decay=weight_decay, 88 | momentum=momentum, 89 | lr=lr, 90 | nesterov=True) 91 | 92 | scheduler = ReduceLROnPlateau(optimizer, 93 | factor=0.5, 94 | patience=3, 95 | min_lr=1e-7, 96 | verbose=True) 97 | 98 | trainer = MultiTaskTrainer(model, optimizer, scheduler, criterion, tasks, epochs, lang, print_every = print_every) 99 | 100 | trainer.train(train_loader, val_loader) 101 | 102 | val_loss, total_d_acc, total_acc, bleu, total_f1,total_recall, total_precision, sent_gt, sent_pred, total_topk,per_disease_topk, per_disease_bleu, total_cm = trainer.validate(test_loader) 103 | with open(trainer.output_log, 'a+') as out: 104 | print('Test Loss:{:.8f}\tAcc:{:.8f}\tDAcc:{:.8f}\tBLEU:{:.8f}'.format(val_loss, total_acc, total_d_acc, bleu), file=out) 105 | print('total_topk',total_topk, file=out) 106 | print('per_disease_topk', per_disease_topk, file=out) 107 | print('per_disease_bleu', per_disease_bleu, file=out) 108 | print(total_cm, file=out) 109 | for k in np.random.choice(list(range(len(sent_gt))), size=10, replace=False): 110 | print(sent_gt[k], file=out) 111 | print(sent_pred[k], file=out) 112 | print('---------------------', file=out) 113 | trainer.test(test_loader) 114 | 115 | if __name__ == "__main__": 116 | task_configs =[[0],[1],[2],[0,1], [1,2],[0,2],[0, 1, 2]] 117 | for i, t in enumerate(task_configs): 118 | print("Running", t) 119 | gin.parse_config_file('config.gin') 120 | gin.bind_parameter('run.tasks', t) 121 | run() 122 | gin.clear_config() 123 | 124 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from nltk import word_tokenize 3 | import re 4 | import unicodedata 5 | from nltk.translate.bleu_score import sentence_bleu 6 | import numpy as np 7 | from sklearn.metrics import recall_score, precision_score, f1_score, classification_report, confusion_matrix 8 | 9 | import torch 10 | 11 | from sklearn.preprocessing import LabelEncoder 12 | from sklearn.metrics import recall_score, precision_score, f1_score, classification_report 13 | import torch 14 | 15 | 16 | SOS_token = 0 17 | EOS_token = 1 18 | PAD_token = 2 19 | UNK_token = 3 20 | 21 | class Lang: 22 | def __init__(self, name): 23 | self.name = name 24 | self.word2index = {"UNK":3} 25 | self.word2count = {} 26 | self.index2word = {0: "SOS", 1: "EOS", 2: "PAD", 3:"UNK"} 27 | self.n_words = 4 28 | 29 | def addSentence(self, sentence): 30 | for word in word_tokenize(sentence): 31 | self.addWord(word) 32 | 33 | def addWord(self, word): 34 | if word not in self.word2index: 35 | self.word2index[word] = self.n_words 36 | self.word2count[word] = 1 37 | self.index2word[self.n_words] = word 38 | self.n_words += 1 39 | else: 40 | self.word2count[word] += 1 41 | 42 | def unicodeToAscii(s): 43 | return ''.join( 44 | c for c in unicodedata.normalize('NFD', s) 45 | if unicodedata.category(c) != 'Mn' 46 | ) 47 | 48 | # Lowercase, trim, and remove non-letter characters 49 | def normalizeString(s): 50 | s = unicodeToAscii(s.lower().strip()) 51 | # s = re.sub(r"([.!?])", r" \1", s) 52 | # s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) 53 | return s 54 | 55 | def readLangs(lines, max_length, lang1='eng'): 56 | print("Reading lines...") 57 | input_lang = Lang(lang1) 58 | 59 | pairs = [] 60 | for e,l in enumerate(lines): 61 | val = [] 62 | words_list = word_tokenize(l) 63 | if len(words_list) >= max_length: 64 | val.append(normalizeString(' '.join(words_list[:max_length - 1]))) 65 | else: 66 | val.append(normalizeString(l)) 67 | input_lang.addSentence(val[0]) 68 | pairs.append(val) 69 | 70 | print("Read %s sentence pairs" % len(pairs)) 71 | print("Counted words:") 72 | print(input_lang.name, input_lang.n_words) 73 | print(input_lang.word2count) 74 | return input_lang, pairs 75 | 76 | def indexFromSentence(lang, sentence, unk_ratio = 15, max_length = 15): 77 | out = [] 78 | out.append(SOS_token) 79 | # out = [lang.word2index[word] for word in word_tokenize(sentence)] 80 | 81 | # If word freq is small then replace with UNK 82 | for word in word_tokenize(normalizeString(sentence)): 83 | try: 84 | if lang.word2count[word] > unk_ratio: 85 | out.append(lang.word2index[word]) 86 | else: 87 | out.append(lang.word2index['UNK']) 88 | except: 89 | pass 90 | # print("error while processing word", word) 91 | # out.append(lang.word2index['UNK']) 92 | out.append(EOS_token) 93 | 94 | if len(out) > max_length: 95 | sentence_length = max_length 96 | else: 97 | sentence_length = len(out) 98 | 99 | # If sentence is small then pad 100 | if len(out) < max_length: 101 | for i in range(max_length - len(out)): 102 | out.append(PAD_token) 103 | elif len(out) > max_length: 104 | out = out[:max_length] 105 | return out, sentence_length 106 | 107 | def variableFromSentence(lang, sentence, max_length = 15): 108 | indexes, sentence_length = indexesFromSentence(lang, sentence, max_length = max_length) 109 | result = torch.LongTensor(indexes).view(-1, 1) 110 | return result, sentence_length 111 | 112 | def variablesFromPair(pair): 113 | input_variable = variableFromSentence(input_lang, pair[0]) 114 | return input_variable 115 | 116 | #Metrics 117 | def accuracy_recall_precision_f1(y_pred, y_target): 118 | 119 | """Computes the accuracy, recall, precision and f1 score for given predictions and targets 120 | Args: 121 | y_pred: Logits of the predictions for each class 122 | y_target: Target values 123 | """ 124 | 125 | predictions = y_pred.cpu().detach().numpy() 126 | y_target = y_target.cpu().numpy() 127 | 128 | correct = np.sum(predictions == y_target) 129 | accuracy = correct / len(predictions) 130 | 131 | recall = recall_score(y_target, predictions, average=None) #average=None (the scores for each class are returned) 132 | precision = precision_score(y_target, predictions, average=None) 133 | f1 = f1_score(y_target, predictions, average=None) 134 | 135 | return accuracy, recall, precision, f1 136 | 137 | def calculate_confusion_matrix(y_pred, y_target): 138 | 139 | predictions = y_pred.cpu().detach().numpy() 140 | y_target = y_target.cpu().numpy() 141 | 142 | #Confusion matrix 143 | cm = confusion_matrix(y_target, predictions) 144 | 145 | #multi_cm = multilabel_confusion_matrix(y_target, predictions) 146 | #print(multi_cm) 147 | #print(confusion_matrix(y_target, predictions)) 148 | 149 | #Classification report 150 | #print(classification_report(y_target, predictions)) 151 | 152 | return cm 153 | 154 | def compute_topk(topk_vals, gt, k): 155 | _, preds = topk_vals.topk(k = k, dim = 1) 156 | topk_acc = 0 157 | for i in range(preds.size(1)): 158 | topk_acc += preds[:, i].eq(gt).sum().item() 159 | return (topk_acc / topk_vals.size(0)) 160 | 161 | def compute_bleu(lang, text1, preds1, disease, per_disease_bleu): 162 | ind2word = lang.index2word 163 | bleu = 0 164 | sents_gt = [] 165 | sents_pred = [] 166 | for k in range(len(text1)): 167 | sent1 = [] 168 | sent2 = [] 169 | weights = (0.25, 0.25, 0.25, 0.25) 170 | for j in range(len(text1[k])): 171 | if text1[k][j] != 0 and text1[k][j] != 1 and text1[k][j] != 2: 172 | sent1.append(ind2word[text1[k][j]]) 173 | if preds1[k][j] != 0 and preds1[k][j] != 1 and preds1[k][j] != 2: 174 | sent2.append(ind2word[preds1[k][j]]) 175 | if len(sent2) > 0 and len(sent2) < 4 and weights == (0.25, 0.25, 0.25, 0.25): 176 | weights = (1 / len(sent2),) * len(sent2) 177 | c_bleu = sentence_bleu([sent1], sent2, weights = weights) 178 | per_disease_bleu[disease[k].item()].append(c_bleu) 179 | sents_gt.append(sent1) 180 | sents_pred.append(sent2) 181 | bleu += c_bleu 182 | return (bleu/len(text1)), sents_gt, sents_pred 183 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import random 7 | 8 | class LanguageModel(nn.Module): 9 | def __init__(self, vocab_size = 193, embed_size = 256, inp_size = 1024, hidden_size = 512, 10 | num_layers = 1, dropout_p = 0.1): 11 | super(LanguageModel, self).__init__() 12 | self.hidden_size = hidden_size 13 | self.embed_size = embed_size 14 | self.dropout = nn.Dropout(dropout_p) 15 | self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=2) 16 | self.project = nn.Linear(inp_size, hidden_size) 17 | self.gru = nn.GRU(input_size=embed_size, hidden_size=hidden_size, 18 | num_layers=num_layers, batch_first=True) 19 | # self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True) 20 | self.linear = nn.Linear(hidden_size, vocab_size) 21 | self.max_length = 15 22 | self.teacher_forcing_ratio = 0.5 23 | self.init_weights() 24 | 25 | def init_weights(self): 26 | self.embedding.weight.data.uniform_(-0.1, 0.1) 27 | self.linear.weight.data.uniform_(-0.1, 0.1) 28 | self.linear.bias.data.fill_(0) 29 | 30 | def forward(self, img_feats, text): 31 | text = text.squeeze() 32 | preds = [] 33 | self.gru.flatten_parameters() 34 | decoder_input = text[:, 0] 35 | state = self.project(img_feats).unsqueeze(0) 36 | for i in range(1, self.max_length): 37 | use_teacher_forcing = True if (self.training and random.random() < self.teacher_forcing_ratio) else False 38 | embeddings = self.dropout(self.embedding(decoder_input)).squeeze() 39 | feats, state = self.gru(embeddings.unsqueeze(1), state) 40 | pred = self.linear(state).squeeze() 41 | 42 | if use_teacher_forcing: 43 | decoder_input = text[:,i] 44 | else: 45 | output = F.log_softmax(pred, dim=-1) 46 | decoder_input = torch.argmax(output, dim=-1) 47 | preds.append(pred.unsqueeze(1)) 48 | 49 | return torch.cat(preds, 1) 50 | 51 | @gin.configurable 52 | class AutoEncoder(nn.Module): 53 | def __init__(self, model_type, model = None): 54 | super(AutoEncoder, self).__init__() 55 | self.model_type = model_type 56 | if model_type == 'self': 57 | self.conv = nn.Sequential(nn.Conv2d(3, 32, (4, 4), 2, 1), # 224 -> 112 58 | nn.ReLU(), 59 | nn.Conv2d(32, 64, (4, 4), 2,1), # 112 -> 56 60 | nn.ReLU(), 61 | nn.Conv2d(64, 128, (4, 4), 2, 1), # 56 -> 28 62 | nn.ReLU(), 63 | nn.Conv2d(128, 128, (4, 4), 2, 1), # 28 -> 14 64 | nn.ReLU(), 65 | nn.Conv2d(128, 64, (4, 4), 2, 1), # 14 -> 7 66 | nn.ReLU(), 67 | nn.Conv2d(64, 5, (7, 7), 1, 0)) 68 | else: 69 | self.conv = torch.nn.Sequential(*list(model.children())[:-1]) 70 | self.lin = nn.Linear(2048, 256) 71 | self.deconv = nn.Sequential( 72 | nn.ConvTranspose2d(256, 256, 7, 1, 0), 73 | nn.BatchNorm2d(256), 74 | nn.ReLU(), 75 | # state size: (ngf * 8) x 4 x 4 76 | nn.ConvTranspose2d(256, 256, 4, 2, 1, bias=False), 77 | nn.BatchNorm2d(256), 78 | nn.ReLU(), 79 | # state size: ngf x 32 x 32 80 | nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False), 81 | nn.BatchNorm2d(128), 82 | nn.ReLU(), 83 | # state size: (ngf * 4) x 8 x 8 84 | nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), 85 | nn.BatchNorm2d(64), 86 | nn.ReLU(), 87 | # state size: (ngf * 2) x 16 x 16 88 | nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False), 89 | nn.BatchNorm2d(32), 90 | nn.ReLU(), 91 | # state size: ngf x 32 x 32 92 | nn.ConvTranspose2d(32, 3, 4, 2, 1), 93 | nn.Tanh() 94 | ) 95 | 96 | def forward(self, x): 97 | out = self.conv(x).squeeze() 98 | if self.model_type != 'self': 99 | out = self.lin(F.relu(out)) 100 | out = self.deconv(out.view(-1, out.size(1), 1, 1)) 101 | return out 102 | 103 | @gin.configurable 104 | class AbnormalNet(nn.Module): 105 | def __init__(self, model_type, model = None): 106 | super(AbnormalNet, self).__init__() 107 | self.model_type = model_type 108 | if model_type == 'self': 109 | self.conv = nn.Sequential(nn.Conv2d(3, 32, (4, 4), 2, 1), # 224 -> 112 110 | nn.ReLU(), 111 | nn.Conv2d(32, 64, (4, 4), 2,1), # 112 -> 56 112 | nn.ReLU(), 113 | nn.Conv2d(64, 128, (4, 4), 2, 1), # 56 -> 28 114 | nn.ReLU(), 115 | nn.Conv2d(128, 128, (4, 4), 2, 1), # 28 -> 14 116 | nn.ReLU(), 117 | nn.Conv2d(128, 64, (4, 4), 2, 1), # 14 -> 7 118 | nn.ReLU(), 119 | nn.Conv2d(64, 5, (7, 7), 1, 0)) 120 | else: 121 | self.conv = torch.nn.Sequential(*list(model.children())[:-1]) 122 | self.lin = nn.Linear(2048, 5) 123 | 124 | def forward(self, x): 125 | out = self.conv(x).squeeze() 126 | if self.model_type != 'self': 127 | out = self.lin(F.relu(out)) 128 | return out 129 | 130 | @gin.configurable 131 | class MultiTaskModel(nn.Module): 132 | def __init__(self, model, vocab_size, model_type = 'densenet121', in_feats = gin.REQUIRED): 133 | super(MultiTaskModel, self).__init__() 134 | self.model_type = model_type 135 | if self.model_type == 'densenet121': 136 | self.feature_extract = model.features 137 | else: 138 | self.feature_extract = torch.nn.Sequential(*list(model.children())[:-1]) 139 | 140 | self.disease_classifier = nn.Sequential(nn.Linear(in_feats, 512), 141 | nn.ReLU(), nn.Linear(512, 5)) 142 | self.fine_disease_classifier = nn.Sequential(nn.Linear(in_feats, 512), 143 | nn.ReLU(), nn.Linear(512, 321)) 144 | self.language_classifier = LanguageModel(inp_size = in_feats, vocab_size = vocab_size) 145 | 146 | def forward(self, data, text): 147 | features = self.feature_extract(data).squeeze() 148 | out = F.relu(features) 149 | if self.model_type == 'densenet121': 150 | out = F.adaptive_avg_pool2d(out, (1, 1)) 151 | out = torch.flatten(out, 1) 152 | 153 | return (self.disease_classifier(out), 154 | self.fine_disease_classifier(out), 155 | self.language_classifier(out, text) 156 | ) 157 | -------------------------------------------------------------------------------- /grad-cam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from torch.autograd import Function 7 | from torchvision import models 8 | from torchvision import utils 9 | import cv2 10 | import sys 11 | import numpy as np 12 | import argparse 13 | from models import AbnormalNet 14 | class FeatureExtractor(): 15 | """ Class for extracting activations and 16 | registering gradients from targetted intermediate layers """ 17 | def __init__(self, model, target_layers): 18 | self.model = model 19 | self.target_layers = target_layers 20 | self.gradients = [] 21 | 22 | def save_gradient(self, grad): 23 | self.gradients.append(grad) 24 | 25 | def __call__(self, x): 26 | outputs = [] 27 | self.gradients = [] 28 | # print(list(self.model._modules.items())) 29 | for name, module in self.model._modules.items(): 30 | x = module(x) 31 | if name in self.target_layers: 32 | x.register_hook(self.save_gradient) 33 | outputs += [x] 34 | return outputs, x 35 | 36 | class ModelOutputs(): 37 | """ Class for making a forward pass, and getting: 38 | 1. The network output. 39 | 2. Activations from intermeddiate targetted layers. 40 | 3. Gradients from intermeddiate targetted layers. """ 41 | def __init__(self, model, target_layers): 42 | self.model = model 43 | self.model.features = self.model.conv 44 | self.model.classifier = self.model.lin 45 | self.feature_extractor = FeatureExtractor(self.model.features, target_layers) 46 | 47 | def get_gradients(self): 48 | return self.feature_extractor.gradients 49 | 50 | def __call__(self, x): 51 | target_activations, output = self.feature_extractor(x) 52 | # output = output.view(output.size(0), -1) 53 | output = F.relu(output.squeeze()) 54 | output = self.model.classifier(output) 55 | return target_activations, F.softmax(output, dim=-1) 56 | 57 | def preprocess_image(img): 58 | means=[0.485, 0.456, 0.406] 59 | stds=[0.229, 0.224, 0.225] 60 | 61 | preprocessed_img = img.copy()[: , :, ::-1] 62 | for i in range(3): 63 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] - means[i] 64 | preprocessed_img[:, :, i] = preprocessed_img[:, :, i] / stds[i] 65 | preprocessed_img = \ 66 | np.ascontiguousarray(np.transpose(preprocessed_img, (2, 0, 1))) 67 | preprocessed_img = torch.from_numpy(preprocessed_img) 68 | preprocessed_img.unsqueeze_(0) 69 | input = Variable(preprocessed_img, requires_grad = True) 70 | return input 71 | 72 | def show_cam_on_image(img, mask): 73 | heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET) 74 | heatmap = np.float32(heatmap) / 255 75 | cam = heatmap + np.float32(img) 76 | cam = cam / np.max(cam) 77 | return cam 78 | 79 | class GradCam: 80 | def __init__(self, model, target_layer_names, use_cuda): 81 | self.model = model 82 | self.model.eval() 83 | self.cuda = use_cuda 84 | if self.cuda: 85 | self.model = model.cuda() 86 | 87 | self.extractor = ModelOutputs(self.model, target_layer_names) 88 | 89 | def forward(self, input): 90 | return self.model(input) 91 | 92 | def __call__(self, input, index = None): 93 | if self.cuda: 94 | features, output = self.extractor(input.cuda()) 95 | else: 96 | features, output = self.extractor(input) 97 | 98 | if index == None: 99 | index = np.argmax(output.cpu().data.numpy()) 100 | 101 | one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32) 102 | one_hot[0][index] = 1 103 | one_hot = Variable(torch.from_numpy(one_hot), requires_grad = True) 104 | if self.cuda: 105 | one_hot = torch.sum(one_hot.cuda() * output) 106 | else: 107 | one_hot = torch.sum(one_hot * output) 108 | 109 | self.model.features.zero_grad() 110 | self.model.classifier.zero_grad() 111 | one_hot.backward(retain_graph=True) 112 | 113 | grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy() 114 | 115 | target = features[-1] 116 | target = target.cpu().data.numpy()[0, :] 117 | 118 | weights = np.mean(grads_val, axis = (2, 3))[0, :] 119 | cam = np.zeros(target.shape[1 : ], dtype = np.float32) 120 | 121 | for i, w in enumerate(weights): 122 | cam += w * target[i, :, :] 123 | 124 | cam = np.maximum(cam, 0) 125 | cam = cv2.resize(cam, (224, 224)) 126 | cam = cam - np.min(cam) 127 | cam = cam / np.max(cam) 128 | return cam, index 129 | 130 | class GuidedBackpropReLU(Function): 131 | 132 | def forward(self, input): 133 | positive_mask = (input > 0).type_as(input) 134 | output = torch.addcmul(torch.zeros(input.size()).type_as(input), input, positive_mask) 135 | self.save_for_backward(input, output) 136 | return output 137 | 138 | def backward(self, grad_output): 139 | input, output = self.saved_tensors 140 | grad_input = None 141 | 142 | positive_mask_1 = (input > 0).type_as(grad_output) 143 | positive_mask_2 = (grad_output > 0).type_as(grad_output) 144 | grad_input = torch.addcmul(torch.zeros(input.size()).type_as(input), torch.addcmul(torch.zeros(input.size()).type_as(input), grad_output, positive_mask_1), positive_mask_2) 145 | 146 | return grad_input 147 | 148 | class GuidedBackpropReLUModel: 149 | def __init__(self, model, use_cuda): 150 | self.model = model 151 | self.model.features = model.conv 152 | self.model.eval() 153 | self.cuda = use_cuda 154 | if self.cuda: 155 | self.model = model.cuda() 156 | 157 | # replace ReLU with GuidedBackpropReLU 158 | for idx, module in self.model.features._modules.items(): 159 | if module.__class__.__name__ == 'ReLU': 160 | self.model.features._modules[idx] = GuidedBackpropReLU() 161 | 162 | def forward(self, input): 163 | return self.model(input) 164 | 165 | def __call__(self, input, index = None): 166 | if self.cuda: 167 | output = self.forward(input.cuda()) 168 | else: 169 | output = self.forward(input) 170 | 171 | if index == None: 172 | index = np.argmax(output.cpu().data.numpy()) 173 | 174 | one_hot = np.zeros((1, output.size()[-1]), dtype = np.float32) 175 | one_hot[0][index] = 1 176 | one_hot = Variable(torch.from_numpy(one_hot), requires_grad = True) 177 | if self.cuda: 178 | one_hot = torch.sum(one_hot.cuda() * output) 179 | else: 180 | one_hot = torch.sum(one_hot * output) 181 | 182 | # self.model.features.zero_grad() 183 | # self.model.classifier.zero_grad() 184 | one_hot.backward(retain_graph=True) 185 | 186 | output = input.grad.cpu().data.numpy() 187 | output = output[0,:,:,:] 188 | 189 | return output 190 | 191 | def get_args(): 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument('--use-cuda', action='store_true', default=False, 194 | help='Use NVIDIA GPU acceleration') 195 | args = parser.parse_args() 196 | args.use_cuda = args.use_cuda and torch.cuda.is_available() 197 | if args.use_cuda: 198 | print("Using GPU for acceleration") 199 | else: 200 | print("Using CPU for computation") 201 | 202 | return args 203 | 204 | if __name__ == '__main__': 205 | """ python grad_cam.py 206 | 1. Loads an image with opencv. 207 | 2. Preprocesses it for VGG19 and converts to a pytorch variable. 208 | 3. Makes a forward pass to find the category index with the highest score, 209 | and computes intermediate activations. 210 | Makes the visualization. """ 211 | 212 | args = get_args() 213 | 214 | # Can work with any model, but it assumes that the model has a 215 | # feature method, and a classifier method, 216 | # as in the VGG models in torchvision. 217 | idx2label = {0:'melanoma', 1:'glaucoma', 2:'amd', 3:'diabetic retinopathy', 4:'normal'} 218 | task_list = [0.15, 0.25, 0.4, 0.55, 0.7, 0.85] 219 | for t in task_list: 220 | for i in os.listdir('Inked/'): 221 | spl = round(1 - t - 0.15, 2) 222 | # file_name = i.replace('Inked','').replace('_LI.jpg','').replace('%','').replace('_','5C', 1) 223 | file_name = 'Batch2079-805C1.2.826.0.1.3680043.2.110.1192826410552750.1.200_0000_000000_1559691433052d.jpg' 224 | model = models.resnet50(pretrained=True) 225 | model = nn.DataParallel(AbnormalNet('resnet50', model)) 226 | 227 | # model.load_state_dict(torch.load('/data2/sachelar/kd_models/kd_only-{:.2f}/best_model.pt'.format(spl))) 228 | model.load_state_dict(torch.load('small_models/{:.2f}-resnet/best_model.pt'.format(spl))) 229 | 230 | grad_cam = GradCam(model = model.module, \ 231 | target_layer_names = ["7"], use_cuda=args.use_cuda) 232 | try: 233 | img = cv2.imread(os.path.join('/data2/fundus_images',file_name), 1) 234 | img = np.float32(cv2.resize(img, (224, 224))) / 255 235 | input = preprocess_image(img) 236 | 237 | # If None, returns the map for the highest scoring category. 238 | # Otherwise, targets the requested index. 239 | target_index = None 240 | 241 | mask, index = grad_cam(input, target_index) 242 | 243 | 244 | cam = show_cam_on_image(img, mask) 245 | if not os.path.exists(os.path.join("outputs", file_name.replace('.jpg',''))): 246 | os.mkdir(os.path.join("outputs", file_name.replace('.jpg',''))) 247 | cv2.imwrite(os.path.join("outputs", file_name.replace('.jpg','') + '/'+"{:.2f}".format(spl) + '_'+idx2label[index] + "_cam.jpg"), np.uint8(255 * cam)) 248 | 249 | gb_model = GuidedBackpropReLUModel(model = model.module, use_cuda=args.use_cuda) 250 | gb = gb_model(input, index=target_index) 251 | # utils.save_image(torch.from_numpy(gb),"outputs/{:.2f}".format(spl) +'_'+ file_name.replace('.jpg','')+ '_gb.jpg') 252 | 253 | cam_mask = np.zeros(gb.shape) 254 | for i in range(0, gb.shape[0]): 255 | cam_mask[i, :, :] = mask 256 | 257 | cam_gb = np.multiply(cam_mask, gb) 258 | # utils.save_image(torch.from_numpy(cam_gb),"outputs/{:.2f}".format(spl) +'_'+ file_name.replace('.jpg','') + '_cam_gb.jpg') 259 | except: 260 | print("ERROR", file_name) 261 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import gin 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | import os 8 | from datetime import datetime 9 | from collections import defaultdict 10 | from utils import compute_bleu, compute_topk, accuracy_recall_precision_f1, calculate_confusion_matrix 11 | class BaseTrainer(object): 12 | def __init__(self, model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss = 100): 13 | self.model = model 14 | self.optimizer = optimizer 15 | self.scheduler = scheduler 16 | self.epochs = epochs 17 | self.criterion = criterion 18 | self.print_every = print_every 19 | self.min_val_loss = min_val_loss 20 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | # Save experiment configuration 22 | self.save_location_dir = os.path.join('models', str(datetime.now()).replace(' ','')) 23 | 24 | def init_saves(self): 25 | if not os.path.exists(self.save_location_dir): 26 | os.mkdir(self.save_location_dir) 27 | with open(os.path.join(self.save_location_dir,'config.gin'), 'w') as conf: 28 | conf.write(gin.operative_config_str()) 29 | self.output_log = os.path.join(self.save_location_dir,'output_log.txt') 30 | self.save_path = os.path.join(self.save_location_dir, 'best_model.pt') 31 | self.summary_writer = SummaryWriter(os.path.join(self.save_location_dir, 'logs'), 300) 32 | 33 | def train(self, train_loader, val_loader): 34 | raise NotImplementedError 35 | 36 | def train_iteration(self, train_loader, val_loader): 37 | raise NotImplementedError 38 | 39 | def validate(self, train_loader, val_loader): 40 | raise NotImplementedError 41 | 42 | def test(self, test_loader): 43 | raise NotImplementedError 44 | 45 | 46 | class MultiTaskTrainer(BaseTrainer): 47 | def __init__(self, model, optimizer, scheduler, criterion, tasks, epochs, lang, print_every = 100, min_val_loss = 100): 48 | super(MultiTaskTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss) 49 | self.lang = lang 50 | self.tasks = tasks 51 | self.save_location_dir = os.path.join('models', '_'.join(str(t) for t in self.tasks) +'-'+ str(datetime.now()).replace(' ','')) 52 | self.init_saves() 53 | 54 | def train(self, train_loader, val_loader): 55 | for e in range(self.epochs): 56 | self.model.train() 57 | total_train_loss, total_tl1, total_tl2, total_tl3, total_disease_acc, accuracy, bleu = self.train_iteration(train_loader) 58 | print("Epoch", e) 59 | self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e) 60 | self.summary_writer.add_scalar('training/total_t1_loss', total_tl1, e) 61 | self.summary_writer.add_scalar('training/total_t2_loss', total_tl2, e) 62 | self.summary_writer.add_scalar('training/total_t3_loss', total_tl3, e) 63 | self.summary_writer.add_scalar('training/t1_acc', total_disease_acc, e) 64 | self.summary_writer.add_scalar('training/t2_acc', accuracy, e) 65 | self.summary_writer.add_scalar('training/t3_bleu', bleu, e) 66 | val_loss, total_d_acc, total_acc, bleu, total_f1, total_recall, total_precision, sent_gt, sent_pred, total_topk, per_disease_topk, per_disease_bleu, total_cm = self.validate(val_loader) 67 | with open(self.output_log, 'a+') as out: 68 | print('Epoch: {}\tVal Loss:{:.8f}\tAcc:{:.8f}\tDAcc:{:.8f}\tBLEU:{:.8f}'.format(e,val_loss, total_acc, total_d_acc, bleu), file=out) 69 | print('total_topk',total_topk, file=out) 70 | print('per_disease_topk', per_disease_topk, file=out) 71 | print('per_disease_bleu', per_disease_bleu, file=out) 72 | print(total_cm, file=out) 73 | for k in np.random.choice(list(range(len(sent_gt))), size=10, replace=False): 74 | print(sent_gt[k], file=out) 75 | print(sent_pred[k], file=out) 76 | print('---------------------', file=out) 77 | 78 | self.summary_writer.add_scalar('validation/val_loss', val_loss, e) 79 | self.summary_writer.add_scalar('validation/t1_acc', total_d_acc, e) 80 | self.summary_writer.add_scalar('validation/t2_acc', total_acc, e) 81 | self.summary_writer.add_scalar('validation/BLEU', bleu, e) 82 | # self.summary_writer.add_scalars('validation/f1_scores', disease_f1, e) 83 | # self.summary_writer.add_scalars('validation/recall_scores', 84 | # disease_recall, e) 85 | # self.summary_writer.add_scalars('validation/precision_scores', 86 | # disease_precision, e) 87 | 88 | self.summary_writer.add_scalar('validation/f1_mean', total_f1, e) 89 | self.summary_writer.add_scalar('validation/recall_mean', total_recall, e) 90 | self.summary_writer.add_scalar('validation/precision_mean', total_precision, e) 91 | self.summary_writer.add_scalars('validation/topk', total_topk, e) 92 | for i in per_disease_topk: 93 | self.summary_writer.add_scalars('validation/topk_'+str(i), per_disease_topk[i], e) 94 | 95 | for i, k in enumerate(np.random.choice(list(range(len(sent_gt))), size=10, replace=False)): 96 | self.summary_writer.add_text('validation/sentence_gt'+str(i), 97 | ' '.join(sent_gt[k]), e) 98 | self.summary_writer.add_text('validation/sentence_pred'+str(i), 99 | ' '.join(sent_pred[k]), e) 100 | 101 | def test(self, test_loader): 102 | results = open('predictions.csv','w') 103 | ind2word = {0:'Melanoma',1:'Glaucoma',2:'AMD',3:'DR',4:'Normal'} 104 | with torch.no_grad(): 105 | for i, (name, images, labels, f_labels, text) in enumerate(test_loader): 106 | batch_size = images.size(0) 107 | images = images.to(self.device) 108 | labels = labels.to(self.device) 109 | f_labels = f_labels.to(self.device) 110 | text = text.to(self.device) 111 | diseases, fine_diseases, text_pred = self.model(images, text) 112 | pred = F.log_softmax(diseases, dim= -1).argmax(dim=-1) 113 | for j in range(diseases.size(0)): 114 | results.write(name[j]+','+ind2word[labels[j].item()] +','+ind2word[pred[j].item()] +'\n') 115 | 116 | 117 | def train_iteration(self, train_loader): 118 | train_loss = 0.0 119 | accuracy = 0.0 120 | total_disease_acc = 0.0 121 | bleu = 0.0 122 | total_tl1 = 0 123 | total_tl2 = 0 124 | total_tl3 = 0 125 | total_train_loss = 0.0 126 | loss = torch.tensor(0).to(self.device) 127 | for i, (_, images, labels, f_labels, text) in enumerate(train_loader): 128 | batch_size = images.size(0) 129 | images = images.to(self.device) 130 | labels = labels.to(self.device) 131 | f_labels = f_labels.to(self.device) 132 | text = text.to(self.device) 133 | self.optimizer.zero_grad() 134 | disease, f_disease, text_pred = self.model(images, text) 135 | loss1 = self.criterion(disease, labels) 136 | 137 | loss2 = self.criterion(f_disease, f_labels) 138 | 139 | loss3 = 0.0 140 | for k in range(text_pred.size(1)): 141 | loss3 += self.criterion(text_pred[:, k].squeeze(), text[:, k + 1].squeeze()) 142 | 143 | # Only consider tasks defined in the task list 144 | loss = torch.stack((loss1,loss2, loss3))[self.tasks].sum() 145 | 146 | loss.backward() 147 | self.optimizer.step() 148 | 149 | train_loss += loss.item() 150 | total_train_loss += loss.item() 151 | total_tl1 += loss1.item() 152 | total_tl2 += loss2.item() 153 | total_tl3 += loss3.item() 154 | 155 | pred = F.log_softmax(f_disease, dim = -1).argmax(dim=-1) 156 | accuracy += pred.eq(f_labels).sum().item() 157 | d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1) 158 | total_disease_acc += d_pred.eq(labels).sum().item() 159 | # preds = torch.argmax(F.log_softmax(text_pred,dim=-1), dim=-1) 160 | # text1 = text[:, 1:].squeeze().tolist() 161 | # preds1 = preds.tolist() 162 | # tbleu, _, _ = compute_bleu(self.lang, text1, preds1, labels, per_disease_bleu) 163 | # bleu += tbleu 164 | 165 | if i != 0 and i % self.print_every == 0: 166 | avg_loss = train_loss / self.print_every 167 | accuracy = accuracy / self.print_every 168 | total_disease_acc = total_disease_acc / self.print_every 169 | avg_text_loss = loss3 / self.print_every 170 | bleu = bleu / self.print_every 171 | total_train_loss = total_train_loss / self.print_every 172 | total_tl1 = total_tl1 / self.print_every 173 | total_tl2 = total_tl2 / self.print_every 174 | total_tl3 = total_tl3 / self.print_every 175 | 176 | print('Iter:{}\tTraining Loss:{:.8f}\tAcc:{:.8f}\tDAcc:{:.8f}\tBLEU:{:.8f}\tTextLoss:{:.8f}'.format(i, avg_loss, 177 | accuracy/batch_size, 178 | total_disease_acc / batch_size, 179 | bleu, 180 | loss3.item())) 181 | 182 | train_loss = 0.0 183 | text_loss = 0.0 184 | accuracy = 0.0 185 | total_disease_acc = 0.0 186 | bleu = 0.0 187 | total_tl1 = 0 188 | total_tl2 = 0 189 | total_tl3 = 0 190 | total_train_loss = 0.0 191 | return (total_train_loss, total_tl1, total_tl2, total_tl3, total_disease_acc/batch_size, accuracy/batch_size, bleu) 192 | 193 | def validate(self, val_loader, epoch = 0): 194 | self.model.eval() 195 | val_loss = 0.0 196 | total_acc = 0.0 197 | total_recall = 0.0 198 | total_precision = 0.0 199 | total_f1 = 0.0 200 | total_cm = 0 201 | total_d_acc = 0.0 202 | bleu = 0.0 203 | total_l1 = 0 204 | total_l2 = 0 205 | total_l3 = 0 206 | 207 | k_vals = [1, 2, 3, 4, 5] 208 | total_topk = {k:0.0 for k in k_vals} 209 | per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals}) 210 | per_disease_bleu = defaultdict(list) 211 | with torch.no_grad(): 212 | for i, (_, images, labels, f_labels, text) in enumerate(val_loader): 213 | batch_size = images.size(0) 214 | images = images.to(self.device) 215 | labels = labels.to(self.device) 216 | f_labels = f_labels.to(self.device) 217 | text = text.to(self.device) 218 | diseases, fine_diseases, text_pred = self.model(images, text) 219 | loss1 = self.criterion(diseases, labels) 220 | loss2 = self.criterion(fine_diseases, f_labels) 221 | text_loss = 0.0 222 | for k in range(text_pred.size(1)): 223 | text_loss += self.criterion(text_pred[:,k].squeeze(), text[:,k + 1].squeeze()) 224 | 225 | val_loss += torch.stack((loss1, loss2, text_loss))[self.tasks].sum() 226 | 227 | preds = F.log_softmax(fine_diseases, dim = -1) 228 | pred = preds.argmax(dim=-1) 229 | d_pred = F.log_softmax(diseases, dim = -1).argmax(dim=-1) 230 | 231 | # Evaluation of P, R, F1, CM, BLEU 232 | total_acc += (pred.eq(f_labels).sum().item() / batch_size) 233 | total_d_acc += (d_pred.eq(labels).sum().item() / batch_size) 234 | acc, recall, precision, f1 = accuracy_recall_precision_f1(d_pred, 235 | labels) 236 | cm = calculate_confusion_matrix(d_pred, labels) 237 | try: 238 | total_cm += (cm / batch_size) 239 | except: 240 | print("Error occured for this CM") 241 | print(cm / batch_size) 242 | 243 | # Top-k evaluation 244 | for k in k_vals: 245 | total_topk[k] += compute_topk(preds, f_labels, k) 246 | for d in [0, 1, 2, 3]: 247 | mask = labels.eq(d) 248 | if mask.sum() > 0: 249 | per_disease_topk[d][str(k)] += compute_topk(preds[mask], f_labels[mask], k) 250 | 251 | total_recall += np.mean(recall) 252 | total_precision += np.mean(precision) 253 | total_f1 += np.mean(f1) 254 | preds = torch.argmax(F.log_softmax(text_pred,dim=-1), dim=-1) 255 | text1 = text[:, 1:].squeeze().tolist() 256 | preds1 = preds.tolist() 257 | t_bleu, sent_gt, sent_pred = compute_bleu(self.lang, text1, preds1, labels, per_disease_bleu) 258 | 259 | # Book-keeping 260 | bleu += t_bleu 261 | total_l1 += loss1.item() 262 | total_l2 += loss2.item() 263 | total_l3 += text_loss.item() 264 | bleu = bleu / (len(val_loader)) 265 | val_loss = val_loss / len(val_loader) 266 | total_l1 /= len(val_loader) 267 | total_l2 /= len(val_loader) 268 | total_l3 /= len(val_loader) 269 | total_acc = total_acc / len(val_loader) 270 | total_d_acc = total_d_acc / len(val_loader) 271 | total_f1 = total_f1 / len(val_loader) 272 | total_precision = total_precision / len(val_loader) 273 | total_recall = total_recall / len(val_loader) 274 | total_cm = total_cm / len(val_loader) 275 | 276 | self.scheduler.step(val_loss) 277 | if val_loss <= self.min_val_loss: 278 | torch.save(self.model.state_dict(), self.save_path) 279 | self.min_val_loss = val_loss 280 | 281 | disease_f1 = {} 282 | disease_precision = {} 283 | disease_recall = {} 284 | 285 | #for i in range(len(total_f1)): 286 | # disease_f1[i] = total_f1[i] 287 | # disease_precision[i] = total_precision[i] 288 | # disease_recall[i] = total_recall[i] 289 | for d in per_disease_bleu: 290 | per_disease_bleu[d] = np.mean(per_disease_bleu[d]) 291 | 292 | total_topk = {str(k) : total_topk[k] / len(val_loader) for k in k_vals} 293 | for d in [0,1,2,3]: 294 | for k in k_vals: 295 | per_disease_topk[d][str(k)] = per_disease_topk[d][str(k)] / len(val_loader) 296 | 297 | return (val_loss, total_d_acc, total_acc, bleu, total_f1, total_recall, 298 | total_precision, sent_gt, sent_pred, total_topk, 299 | per_disease_topk, per_disease_bleu, total_cm) 300 | 301 | 302 | 303 | class SmallTrainer(BaseTrainer): 304 | def __init__(self, model, optimizer, scheduler, criterion, epochs, print_every = 100, min_val_loss = 100, trainset_split = 0.85): 305 | super(SmallTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss) 306 | self.save_location_dir = os.path.join('small_models', str(trainset_split)+'-'+str(datetime.now())) 307 | self.init_saves() 308 | 309 | def train(self, train_loader, val_loader): 310 | for e in range(self.epochs): 311 | self.model.train() 312 | total_train_loss, accuracy = self.train_iteration(train_loader) 313 | print("Epoch", e) 314 | self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e) 315 | self.summary_writer.add_scalar('training/acc', accuracy, e) 316 | with torch.no_grad(): 317 | val_loss, total_d_acc, total_f1, total_recall, total_precision, total_cm = self.validate(val_loader) 318 | 319 | self.summary_writer.add_scalar('validation/val_loss', val_loss, e) 320 | self.summary_writer.add_scalar('validation/t1_acc', total_d_acc, e) 321 | 322 | self.summary_writer.add_scalar('validation/f1_mean', total_f1, e) 323 | self.summary_writer.add_scalar('validation/recall_mean', total_recall, e) 324 | self.summary_writer.add_scalar('validation/precision_mean', total_precision, e) 325 | with open(self.output_log, 'a+') as out: 326 | print('Val Loss',val_loss, 'total_d_acc',total_d_acc, 'F1', 327 | total_f1, 'R', total_recall,'P', total_precision, 328 | file=out) 329 | print(total_cm, file=out) 330 | 331 | def test(self, test_loader): 332 | results = open('self_trained_extra_labels.csv','w') 333 | self.model.eval() 334 | # ind2disease = {0:'Disease',1:'Normal'} 335 | ind2disease = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR', 4:'Normal'} 336 | # ind2disease2 = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR'} 337 | ind2disease2 = {0:'not applicable' , 1: 'not classifed', 2: 'diabetes no retinopathy'} 338 | for i, data in tqdm.tqdm(enumerate(test_loader)): 339 | image_name = data[0] 340 | images = data[1] 341 | labels = data[2] 342 | batch_size = images.size(0) 343 | images = images.to(self.device) 344 | disease = self.model(images) 345 | d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1) 346 | probs, _ = F.softmax(disease, dim=-1).max(dim=-1) 347 | for j in range(d_pred.size(0)): 348 | results.write(image_name[j]+','+ '{:.8f}'.format(probs[j].item()) +',' + ind2disease2[labels[j].item()] +','+ind2disease[d_pred[j].item()]+'\n') 349 | 350 | def train_iteration(self, train_loader): 351 | train_loss = 0.0 352 | accuracy = 0.0 353 | total_disease_acc = 0.0 354 | total_train_loss = 0.0 355 | for i, (images, labels) in enumerate(train_loader): 356 | batch_size = images.size(0) 357 | images = images.to(self.device) 358 | labels = labels.to(self.device) 359 | 360 | self.optimizer.zero_grad() 361 | disease = self.model(images) 362 | loss = self.criterion(disease, labels) 363 | 364 | loss.backward() 365 | self.optimizer.step() 366 | 367 | train_loss += loss.item() 368 | total_train_loss += loss.item() 369 | 370 | d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1) 371 | total_disease_acc += d_pred.eq(labels).sum().item() 372 | 373 | if i != 0 and i % self.print_every == 0: 374 | avg_loss = train_loss / self.print_every 375 | total_disease_acc = total_disease_acc / self.print_every 376 | total_train_loss = total_train_loss / self.print_every 377 | 378 | print('Iter:{}\tTraining Loss:{:.8f}\tAcc:{:.8f}'.format(i, 379 | avg_loss, total_disease_acc / batch_size)) 380 | 381 | train_loss = 0.0 382 | total_disease_acc = 0.0 383 | return (total_train_loss, total_disease_acc/batch_size) 384 | 385 | def validate(self, val_loader, epoch = 0): 386 | self.model.eval() 387 | val_loss = 0.0 388 | total_acc = 0.0 389 | total_recall = 0.0 390 | total_precision = 0.0 391 | total_f1 = 0.0 392 | total_cm = 0 393 | total_d_acc = 0.0 394 | bleu = 0.0 395 | total_l1 = 0 396 | total_l2 = 0 397 | total_l3 = 0 398 | 399 | k_vals = [1, 2, 3, 4, 5] 400 | total_topk = {k:0.0 for k in k_vals} 401 | per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals}) 402 | losses = [] 403 | with torch.no_grad(): 404 | for i, (images, labels) in enumerate(val_loader): 405 | batch_size = images.size(0) 406 | images = images.to(self.device) 407 | labels = labels.to(self.device) 408 | diseases = self.model(images) 409 | loss1 = self.criterion(diseases, labels) 410 | 411 | val_loss += loss1.item() 412 | 413 | # Evaluation of P, R, F1, BLEU 414 | d_pred = F.log_softmax(diseases, dim = -1).argmax(dim=-1) 415 | total_d_acc += (d_pred.eq(labels).sum().item() / batch_size) 416 | acc, recall, precision, f1 = accuracy_recall_precision_f1(d_pred, 417 | labels) 418 | 419 | total_recall += np.mean(recall) 420 | total_precision += np.mean(precision) 421 | total_f1 += np.mean(f1) 422 | 423 | cm = calculate_confusion_matrix(d_pred, labels) 424 | try: 425 | total_cm += (cm / batch_size) 426 | except: 427 | print("error occured for this CM") 428 | print(cm / batch_size) 429 | val_loss = val_loss / len(val_loader) 430 | total_d_acc = total_d_acc / len(val_loader) 431 | total_f1 = total_f1 / len(val_loader) 432 | total_precision = total_precision / len(val_loader) 433 | total_recall = total_recall / len(val_loader) 434 | total_cm = total_cm / len(val_loader) 435 | 436 | self.scheduler.step(val_loss) 437 | if val_loss <= self.min_val_loss: 438 | torch.save(self.model.state_dict(), self.save_path) 439 | self.min_val_loss = val_loss 440 | 441 | disease_f1 = {} 442 | disease_precision = {} 443 | disease_recall = {} 444 | 445 | # for i in range(len(total_f1)): 446 | # disease_f1[i] = total_f1[i] 447 | # disease_precision[i] = total_precision[i] 448 | # disease_recall[i] = total_recall[i] 449 | 450 | return (val_loss, total_d_acc, total_f1, total_recall, 451 | total_precision, total_cm) 452 | 453 | class AutoTrainer(BaseTrainer): 454 | def __init__(self, model, optimizer, scheduler, criterion, epochs, 455 | print_every = 100, min_val_loss = 100000, trainset_split = 0.85): 456 | super(AutoTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss) 457 | self.save_location_dir = os.path.join('auto_models', str(trainset_split)+'-'+str(datetime.now())) 458 | self.init_saves() 459 | 460 | def train(self, train_loader, val_loader): 461 | for e in range(self.epochs): 462 | self.model.train() 463 | total_train_loss = self.train_iteration(train_loader) 464 | print("Epoch", e) 465 | self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e) 466 | with torch.no_grad(): 467 | val_loss = self.validate(val_loader) 468 | 469 | self.summary_writer.add_scalar('validation/val_loss', val_loss, e) 470 | with open(self.output_log, 'a+') as out: 471 | print('Val Loss',val_loss, file=out) 472 | 473 | def test(self, test_loader): 474 | results = open('self_trained_extra_labels.csv','w') 475 | self.model.eval() 476 | # ind2disease = {0:'Disease',1:'Normal'} 477 | ind2disease = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR', 4:'Normal'} 478 | # ind2disease2 = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR'} 479 | ind2disease2 = {0:'not applicable' , 1: 'not classifed', 2: 'diabetes no retinopathy'} 480 | for i, data in tqdm.tqdm(enumerate(test_loader)): 481 | image_name = data[0] 482 | images = data[1] 483 | labels = data[2] 484 | batch_size = images.size(0) 485 | images = images.to(self.device) 486 | disease = self.model(images) 487 | d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1) 488 | probs, _ = F.softmax(disease, dim=-1).max(dim=-1) 489 | for j in range(d_pred.size(0)): 490 | results.write(image_name[j]+','+ '{:.8f}'.format(probs[j].item()) +',' + ind2disease2[labels[j].item()] +','+ind2disease[d_pred[j].item()]+'\n') 491 | 492 | def train_iteration(self, train_loader): 493 | train_loss = 0.0 494 | accuracy = 0.0 495 | total_disease_acc = 0.0 496 | total_train_loss = 0.0 497 | for i, data in enumerate(train_loader): 498 | (_, images, labels, f_labels, text) = data 499 | batch_size = images.size(0) 500 | images = images.to(self.device) 501 | labels = labels.to(self.device) 502 | 503 | self.optimizer.zero_grad() 504 | disease = self.model(images) 505 | loss = self.criterion(disease, images).div(images.size(0)) 506 | 507 | loss.backward() 508 | self.optimizer.step() 509 | 510 | train_loss += loss.item() 511 | total_train_loss += loss.item() 512 | 513 | 514 | if i != 0 and i % self.print_every == 0: 515 | avg_loss = train_loss / self.print_every 516 | 517 | print('Iter:{}\tTraining Loss:{:.8f}'.format(i, 518 | avg_loss)) 519 | 520 | train_loss = 0.0 521 | return total_train_loss 522 | 523 | def validate(self, val_loader, epoch = 0): 524 | self.model.eval() 525 | val_loss = 0.0 526 | total_acc = 0.0 527 | total_recall = 0.0 528 | total_precision = 0.0 529 | total_f1 = 0.0 530 | total_cm = 0 531 | total_d_acc = 0.0 532 | bleu = 0.0 533 | total_l1 = 0 534 | total_l2 = 0 535 | total_l3 = 0 536 | 537 | k_vals = [1, 2, 3, 4, 5] 538 | total_topk = {k:0.0 for k in k_vals} 539 | per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals}) 540 | losses = [] 541 | with torch.no_grad(): 542 | for i, data in enumerate(val_loader): 543 | (_, images, labels, f_labels, text) = data 544 | batch_size = images.size(0) 545 | images = images.to(self.device) 546 | labels = labels.to(self.device) 547 | diseases = self.model(images) 548 | loss1 = self.criterion(diseases, images).div(images.size(0)) 549 | 550 | val_loss += loss1.item() 551 | 552 | val_loss = val_loss / len(val_loader) 553 | 554 | self.scheduler.step(val_loss) 555 | if val_loss <= self.min_val_loss: 556 | torch.save(self.model.state_dict(), self.save_path) 557 | self.min_val_loss = val_loss 558 | 559 | return val_loss 560 | 561 | 562 | class KDTrainer(BaseTrainer): 563 | def __init__(self, kd_model, model, optimizer, scheduler, criterion, epochs, print_every = 100, min_val_loss = 100, trainset_split = 0.85, kd_type = 'full'): 564 | super(KDTrainer, self).__init__(model, optimizer, scheduler, criterion, epochs, print_every, min_val_loss) 565 | self.criterion = self.distillation_loss 566 | self.kd_model = kd_model 567 | self.threshold = 0.9 568 | self.kd_type = kd_type 569 | self.kd_model.eval() 570 | self.save_location_dir = os.path.join('kdmodels', kd_type +'-'+str(trainset_split)+'-'+str(datetime.now())) 571 | self.init_saves() 572 | 573 | def distillation_loss(self, y, labels, teacher_scores, T = 5, alpha = 0.95, reduction_kd='mean', reduction_nll='mean'): 574 | if teacher_scores is not None: 575 | # d_loss = torch.nn.KLDivLoss(reduction=reduction_kd)(F.log_softmax(y/ T, dim= -1), F.softmax(teacher_scores / T, dim=-1)) * T * T 576 | preds = F.softmax(teacher_scores , dim=-1).argmax(dim=-1) 577 | d_loss = F.cross_entropy(y, labels, reduction=reduction_nll) 578 | else: 579 | assert alpha == 0, 'alpha cannot be {} when teacher scores are not provided'.format(alpha) 580 | d_loss = 0.0 581 | nll_loss = F.cross_entropy(y, labels, reduction=reduction_nll) 582 | if self.kd_type == 'full': 583 | tol_loss = alpha * d_loss + (1.0 - alpha) * nll_loss 584 | else: 585 | tol_loss = d_loss 586 | return tol_loss, d_loss, nll_loss 587 | 588 | def unpack_data(self, data): 589 | if self.kd_type != 'full': 590 | (_, images, labels, f_labels, text) = data 591 | else: 592 | images, labels = data 593 | return (images, labels) 594 | 595 | def train(self, train_loader, val_loader): 596 | for e in range(self.epochs): 597 | self.model.train() 598 | total_train_loss, accuracy = self.train_iteration(train_loader) 599 | break 600 | print("Epoch", e) 601 | self.summary_writer.add_scalar('training/total_train_loss', total_train_loss, e) 602 | self.summary_writer.add_scalar('training/acc', accuracy, e) 603 | with torch.no_grad(): 604 | val_loss, total_d_acc, total_f1, total_recall, total_precision, total_cm = self.validate(val_loader) 605 | 606 | self.summary_writer.add_scalar('validation/val_loss', val_loss, e) 607 | self.summary_writer.add_scalar('validation/t1_acc', total_d_acc, e) 608 | 609 | self.summary_writer.add_scalar('validation/f1_mean', total_f1, e) 610 | self.summary_writer.add_scalar('validation/recall_mean', total_recall, e) 611 | self.summary_writer.add_scalar('validation/precision_mean', total_precision, e) 612 | with open(self.output_log, 'a+') as out: 613 | print('Val Loss',val_loss, 'total_d_acc',total_d_acc, 'F1', 614 | total_f1, 'R', total_recall,'P', total_precision) 615 | print('Val Loss',val_loss, 'total_d_acc',total_d_acc, 'F1', 616 | total_f1, 'R', total_recall,'P', total_precision, 617 | file=out) 618 | print(total_cm, file=out) 619 | 620 | def test(self, test_loader): 621 | results = open('self_trained_extra_labels.csv','w') 622 | self.model.eval() 623 | # ind2disease = {0:'Disease',1:'Normal'} 624 | ind2disease = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR', 4:'Normal'} 625 | # ind2disease2 = {0:'Melanoma' , 1: 'Glaucoma', 2: 'AMD', 3:'DR'} 626 | ind2disease2 = {0:'not applicable' , 1: 'not classifed', 2: 'diabetes no retinopathy'} 627 | for i, data in tqdm.tqdm(enumerate(test_loader)): 628 | image_name = data[0] 629 | images = data[1] 630 | labels = data[2] 631 | batch_size = images.size(0) 632 | images = images.to(self.device) 633 | disease = self.model(images) 634 | d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1) 635 | probs, _ = F.softmax(disease, dim=-1).max(dim=-1) 636 | for j in range(d_pred.size(0)): 637 | results.write(image_name[j]+','+ '{:.8f}'.format(probs[j].item()) +',' + ind2disease2[labels[j].item()] +','+ind2disease[d_pred[j].item()]+'\n') 638 | 639 | def train_iteration(self, train_loader): 640 | train_loss = 0.0 641 | accuracy = 0.0 642 | total_disease_acc = 0.0 643 | total_train_loss = 0.0 644 | total_kd_loss = 0.0 645 | total_nll_loss = 0.0 646 | contr = 0 647 | for i, data in enumerate(train_loader): 648 | images, labels = self.unpack_data(data) 649 | batch_size = images.size(0) 650 | images = images.to(self.device) 651 | labels = labels.to(self.device) 652 | 653 | self.optimizer.zero_grad() 654 | 655 | teacher_scores = self.kd_model(images) 656 | val, pred = F.softmax(teacher_scores, dim=-1).max(dim=-1) 657 | 658 | index = val >= self.threshold 659 | if index.any().item(): 660 | # print(index, val) 661 | images = images[index] 662 | labels = labels[index] 663 | teacher_scores = teacher_scores[index] 664 | contr += images.size(0) 665 | continue 666 | disease = self.model.module(images) 667 | loss = self.criterion(disease, labels, teacher_scores) 668 | 669 | loss[0].backward() 670 | self.optimizer.step() 671 | 672 | train_loss += loss[0].item() 673 | total_train_loss += loss[0].item() 674 | total_kd_loss += loss[1].item() 675 | total_nll_loss += loss[2].item() 676 | 677 | d_pred = F.log_softmax(disease, dim= -1).argmax(dim=-1) 678 | total_disease_acc += d_pred.eq(labels).sum().item() 679 | 680 | if i != 0 and i % self.print_every == 0: 681 | avg_loss = train_loss / self.print_every 682 | total_disease_acc = total_disease_acc / self.print_every 683 | total_train_loss = total_train_loss / self.print_every 684 | 685 | print('Iter:{}\tTraining Loss:{:.8f}\tKD:{:.8f}\tNLL:{:.8f}\tAcc:{:.8f}'.format(i, 686 | avg_loss, total_kd_loss/ ((i+1)*batch_size), 687 | total_nll_loss / ((i+1) * batch_size ), 688 | total_disease_acc / batch_size)) 689 | 690 | train_loss = 0.0 691 | total_disease_acc = 0.0 692 | print("COUNT", contr) 693 | return (total_train_loss, total_disease_acc/batch_size) 694 | 695 | def validate(self, val_loader, epoch = 0): 696 | self.model.eval() 697 | val_loss = 0.0 698 | total_acc = 0.0 699 | total_recall = 0.0 700 | total_precision = 0.0 701 | total_f1 = 0.0 702 | total_cm = 0 703 | total_d_acc = 0.0 704 | bleu = 0.0 705 | total_l1 = 0 706 | total_l2 = 0 707 | total_l3 = 0 708 | 709 | k_vals = [1, 2, 3, 4, 5] 710 | total_topk = {k:0.0 for k in k_vals} 711 | per_disease_topk = defaultdict(lambda: {str(k):0.0 for k in k_vals}) 712 | losses = [] 713 | with torch.no_grad(): 714 | for i, (images, labels) in enumerate(val_loader): 715 | batch_size = images.size(0) 716 | images = images.to(self.device) 717 | labels = labels.to(self.device) 718 | teacher_scores = self.kd_model(images) 719 | val, pred = F.softmax(teacher_scores, dim=-1).max(dim=-1) 720 | 721 | index = val >= self.threshold 722 | if index.any().item(): 723 | images = images[index] 724 | labels = labels[index] 725 | teacher_scores = teacher_scores[index] 726 | diseases = self.model.module(images) 727 | 728 | loss1, _, _ = self.criterion(diseases, labels, teacher_scores) 729 | 730 | val_loss += loss1.item() 731 | 732 | # Evaluation of P, R, F1, BLEU 733 | d_pred = F.log_softmax(diseases, dim = -1).argmax(dim=-1) 734 | total_d_acc += (d_pred.eq(labels).sum().item() / batch_size) 735 | acc, recall, precision, f1 = accuracy_recall_precision_f1(d_pred, 736 | labels) 737 | 738 | total_recall += np.mean(recall) 739 | total_precision += np.mean(precision) 740 | total_f1 += np.mean(f1) 741 | 742 | cm = calculate_confusion_matrix(d_pred, labels) 743 | try: 744 | total_cm += (cm / batch_size) 745 | except: 746 | print("error occured for this CM") 747 | print(cm / batch_size) 748 | val_loss = val_loss / len(val_loader) 749 | total_d_acc = total_d_acc / len(val_loader) 750 | total_f1 = total_f1 / len(val_loader) 751 | total_precision = total_precision / len(val_loader) 752 | total_recall = total_recall / len(val_loader) 753 | total_cm = total_cm / len(val_loader) 754 | 755 | self.scheduler.step(val_loss) 756 | if val_loss <= self.min_val_loss: 757 | torch.save(self.model.state_dict(), self.save_path) 758 | self.min_val_loss = val_loss 759 | 760 | disease_f1 = {} 761 | disease_precision = {} 762 | disease_recall = {} 763 | 764 | # for i in range(len(total_f1)): 765 | # disease_f1[i] = total_f1[i] 766 | # disease_precision[i] = total_precision[i] 767 | # disease_recall[i] = total_recall[i] 768 | 769 | return (val_loss, total_d_acc, total_f1, total_recall, 770 | total_precision, total_cm) 771 | 772 | 773 | --------------------------------------------------------------------------------