├── Data ├── data.py ├── hzhu_gen.py └── hzhu_data_raw.py ├── LICENSE ├── Module ├── hzhu_data.py ├── hzhu_net.py ├── RUN.py ├── hzhu_metrics_class.py ├── hzhu_gen.py ├── hzhu_metrics_saliency.py ├── hzhu_MTL_UNet.py └── hzhu_learn.py └── README.md /Data/data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | from hzhu_data_raw import * 8 | import os 9 | from hzhu_gen import * 10 | import argparse 11 | 12 | 13 | # In[2]: 14 | 15 | 16 | if __name__ == '__main__': 17 | QH = QuickHelper() 18 | 19 | print('Running on my laptop') 20 | gaze_path = 'D:/Gaze Dataset/Eye gaze data for chest X-rays/extracted' 21 | cxr_path = 'D:/Gaze Dataset/MIMIC-CXR & GAZE (master)/RAW/CXR' 22 | save_path = os.getcwd() 23 | fraction = 0.01 24 | 25 | print('Data preparation completed') 26 | print(QH) 27 | 28 | downsample = 5 29 | blur = 500 30 | path_str = 'data' 31 | 32 | local_save_path = save_path+'/'+path_str 33 | create_folder(local_save_path) 34 | 35 | DATA = MasterDataHandle(gaze_path=gaze_path, cxr_path=cxr_path, blur=blur) 36 | DATA.save_all(root_path=local_save_path, downsample=downsample, fraction=fraction, seed=0) 37 | 38 | print('%s generation completed'%path_str) 39 | print(QH) 40 | 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Hongzhi Zhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 
12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Module/hzhu_data.py: -------------------------------------------------------------------------------- 1 | from hzhu_gen import * 2 | 3 | import torch, pickle, copy 4 | from torch.utils.data import Dataset, DataLoader 5 | from torch.nn import functional as F 6 | import pandas as pd 7 | import numpy as np 8 | 9 | from matplotlib import pyplot as plt 10 | import os 11 | 12 | class DataHandle(Dataset): 13 | 14 | def __init__(self, path): 15 | self.path = path 16 | self.init() 17 | 18 | def init(self): 19 | self.file_list = ls_file(path=self.path) 20 | 21 | def __len__(self): 22 | return len(self.file_list) 23 | 24 | def __getitem__(self, i): 25 | local_path = self.path+'/'+self.file_list[i] 26 | return torch.load(local_path) 27 | 28 | def plot(self, i): 29 | data = self[i] 30 | plt.figure(figsize=(8,10)) 31 | for i, key in enumerate(data): 32 | plt.subplot(2,3,1+i) 33 | shape = data[key].shape 34 | if len(shape)==3: 35 | plt.imshow(data[key][0,:,:]) 36 | elif len(shape)==2: 37 | plt.imshow(data[key]) 38 | plt.title('%s\n%s\n%.3f\n%.3f'%(data[key].dtype, data[key].shape, data[key].min(), data[key].max())) 39 | 40 | class DataMaster: 41 | 42 | def __init__(self, path, batch_size): 43 | self.path = path 44 | self.batch_size = batch_size 45 | 46 | name_list = 
['Train','Test','Valid'] 47 | num_workers = torch.get_num_threads()-1 if torch.get_num_threads()<=9 else 8 48 | self.handle = {item:DataHandle(self.path+'/'+item) for item in name_list} 49 | self.dataLoader = {item:DataLoader( 50 | self.handle[item], batch_size=self.batch_size, shuffle=True, 51 | num_workers=num_workers, pin_memory=True, prefetch_factor=4) for item in name_list} 52 | 53 | def __call__(self, key): 54 | return self.dataLoader[key] 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MT-UNet 2 | 3 | This repo contains python codes used for **"Multi-task UNet: Jointly Boosting Saliency Prediction and Disease Classification on Chest X-ray Images"**, Hongzhi Zhu et al., submitted to Medical Imaging with Deep learning (MIDL) 2022. 4 | 5 | - Folder _Data_ contain files used for data pre-processing (before training). 6 | - Folder _Module_ contain files used for network training and evaluation. 7 | 8 | The following packages are used with Python 3.7: 9 | - PyTorch 1.8.0 10 | - torchvision 0.9.0 11 | - tensorboard 2.3.0 12 | - pandas 1.1.0 13 | - matplotlib 3.3.1 14 | - numpy 1.19.2 15 | - scipy 1.5.2 16 | - sklearn 0.23.2 17 | - opencv 4.0.1 18 | - pydicom 2.2.0 19 | 20 | ## _Data_ folder 21 | 22 | To run the code, simply execute data.py through “python data.py” in command line or with other Python IDEs. After execution, new folders containing per-processed and split datasets ready for training will be created in the execution directory. The raw dataset is partitioned into training (70%), validation (10%) and testing (20%) subsets. Seeding (value 0) is used for reproducibility. 23 | 24 | ## _Module_ folder 25 | 26 | To run the code, simply execute RUN.py through “python RUN.py” in command line or with other Python IDEs. After execution, a new folder _run_ will be created (if not already exists) in the execution directory. 
For each execution of RUN.py, a new folder with a random name is created inside folder _run_ to temporarily store network parameters and to record training details and evaluation results. The files inside this folder are listed below: 27 | - params_MTL_UNet_preset_.txt records parameters and data related to network training. 28 | - params_NetLearn_.txt records hyper-parameters for the network. 29 | - QuickHelper_summary.txt records other data collected during execution. 30 | - NET_XXXX/classification_report.json records classification evaluation metrics; here XXXX stands for random characters generated at run time. 31 | - NET_XXXX/classification_results.csv records the raw classification output for each test image. 32 | - NET_XXXX/prediction_report.json records classification and saliency prediction performance metrics. 33 | - NET_XXXX/NET.pt stores the parameters of the best-performing network during or after training. 34 | - NET_XXXX/training_process.png visualizes the change of learning rate, losses, and validation metrics during the training process. 35 | During execution, RUN.py also prints training details and evaluation results to the console.
36 | -------------------------------------------------------------------------------- /Module/hzhu_net.py: -------------------------------------------------------------------------------- 1 | import torch, json 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | class Module(nn.Module): 8 | 9 | def visualize(self, x): 10 | # tensorboard --logdir=runs 11 | writer = SummaryWriter('runs/%s'%self.__class__.__name__) 12 | writer.add_graph(self, x) 13 | writer.close() 14 | print('tensorboard --logdir=runs') 15 | 16 | def save_params(self, path, name): 17 | content = {} 18 | self.total_param = self.get_total_param() 19 | for attr in dir(self): 20 | if attr[0]=='_': 21 | continue 22 | attr_instance = getattr(self, attr) 23 | if callable(attr_instance): 24 | continue 25 | if isinstance(attr_instance, (float, int, bool, list, str)): 26 | content[attr] = attr_instance 27 | 28 | with open(path+'/'+'params_%s_'%(self.__class__.__name__)+name+'.txt', 'w') as file: 29 | try: 30 | json.dump(content, file, indent=4) 31 | except: 32 | print('Exception occured at hzhu_resnet::%s.save_params(..): content cannot be dumped!'%self.__class__.__name__) 33 | file.write(str(content)) 34 | 35 | def get_total_param(self): 36 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 37 | 38 | def conv_bn_acti_drop( 39 | in_channels, 40 | out_channels, 41 | kernel_size, 42 | activation=nn.ReLU, 43 | normalize=nn.BatchNorm2d, 44 | dropout_rate=0.0, 45 | sequential=None, 46 | name='', 47 | stride=1, 48 | padding=0, 49 | dilation=1, 50 | groups=1, 51 | bias=True, 52 | padding_mode='zeros'): 53 | 54 | if sequential is None: 55 | r = nn.Sequential() 56 | else: 57 | r = sequential 58 | 59 | if len(name)==0: 60 | connector = '' 61 | else: 62 | connector = '_' 63 | 64 | r.add_module( 65 | name+connector+'conv2d', 66 | nn.Conv2d( 67 | in_channels=in_channels, 68 | out_channels=out_channels, 69 | 
kernel_size=kernel_size, 70 | stride=stride, 71 | padding=padding, 72 | dilation=dilation, 73 | groups=groups, 74 | bias=bias, 75 | padding_mode=padding_mode)) 76 | 77 | if normalize is not None: 78 | norm_layer = normalize(out_channels) 79 | r.add_module( 80 | name+connector+norm_layer.__class__.__name__, 81 | norm_layer) 82 | 83 | if activation is not None: 84 | acti = activation() 85 | r.add_module( 86 | name+connector+acti.__class__.__name__, 87 | acti) 88 | 89 | if dropout_rate>0.0: 90 | r.add_module( 91 | name+connector+'dropout', 92 | nn.Dropout(p=dropout_rate)) 93 | 94 | return r -------------------------------------------------------------------------------- /Module/RUN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # In[1]: 5 | 6 | 7 | import sys 8 | 9 | from hzhu_metrics_class import * 10 | from hzhu_metrics_saliency import * 11 | from hzhu_data import * 12 | from hzhu_learn import * 13 | from hzhu_MTL_UNet import * 14 | from hzhu_gen import * 15 | 16 | import torch 17 | from torch import nn as nn 18 | import torch.optim as optim 19 | import torchvision 20 | 21 | import matplotlib.pyplot as plt 22 | import matplotlib 23 | import numpy as np 24 | import os 25 | import copy 26 | 27 | matplotlib.use('Agg') 28 | plt.rcParams['axes.facecolor'] = 'white' 29 | 30 | import argparse 31 | 32 | 33 | # In[2]: 34 | 35 | 36 | if __name__ == '__main__': 37 | print('torch.get_num_threads()=%d'%torch.get_num_threads()) 38 | 39 | lr, patience_reduce_lr = 1e-4, 40 40 | optimizer_dict = {'optimizer':optim.Adam, 'param':{}, 'name':'Adam'} 41 | lr_factor = 0.1 42 | lr_min = 1.0e-8 43 | epoch_max = 1024 44 | duration_max = 23.5*60*60 #seconds 10.5hour 45 | patience_early_stop = patience_reduce_lr*2+3 46 | batch_size = 6 47 | 48 | lg_sigma_image = None 49 | lg_sigma_class = 0.0 50 | 51 | down = 5 52 | blur = 500 53 | 54 | classification_loss = nn.CrossEntropyLoss() 55 | 
saliency_pred_loss = nn.KLDivLoss(reduction='batchmean') 56 | 57 | Metrics = {'class':MetricsHandle_Class, 'saliency':MetricsHandle_Saliency} 58 | Model = MTL_UNet_preset 59 | 60 | name = 'NET' 61 | folder_string = 'run' 62 | qH = QuickHelper(path=os.getcwd()+'/'+folder_string) 63 | print('New Folder name: %s'%qH.ID) 64 | print(folder_string) 65 | 66 | data_timer = QuickTimer() 67 | path = 'D:/Gaze Dataset/MIMIC-CXR & GAZE (master)/Data_Attention/test_down5_blur500' 68 | batch_size = 2 69 | epoch_max = 5 70 | 71 | dataAll = DataMaster(path=path, batch_size=batch_size) 72 | print('Data Preparing time: %fsec'%data_timer()) 73 | 74 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 75 | Net = Model( 76 | device=device, 77 | out_dict={'class':3, 'image':1}, 78 | loss_dict={'class':lg_sigma_class, 'image':lg_sigma_image}) 79 | Net.save_params(name='', path=qH()) 80 | 81 | netLearn = NetLearn( 82 | net=Net, 83 | dataAll=dataAll, 84 | criterion={'class':classification_loss, 'saliency':saliency_pred_loss}, 85 | optimizer_dict=optimizer_dict, 86 | lr=lr, 87 | lr_min=lr_min, 88 | lr_factor=lr_factor, 89 | epoch_max=epoch_max, 90 | duration_max=duration_max, 91 | patience_reduce_lr=patience_reduce_lr, 92 | patience_early_stop=patience_early_stop, 93 | device=device, 94 | metrics=Metrics, 95 | name=name, 96 | path=qH()) 97 | 98 | netLearn.train() 99 | 100 | print(netLearn.evaluate()) 101 | #netLearn.remove_saved() 102 | netLearn.remove_saved_optim() 103 | netLearn.remove_saved_sched() 104 | netLearn.save_params(name='', path=qH()) 105 | qH.summary() 106 | 107 | -------------------------------------------------------------------------------- /Module/hzhu_metrics_class.py: -------------------------------------------------------------------------------- 1 | import sys, json 2 | 3 | import torch 4 | from torch import nn as nn 5 | 6 | import copy 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import 
matplotlib 12 | matplotlib.use('Agg') 13 | from matplotlib.backends.backend_pdf import PdfPages 14 | 15 | import sklearn.metrics as M 16 | 17 | class MetricsHandle_Class: 18 | 19 | def __init__(self): 20 | self.data = [] 21 | 22 | def __len__(self): 23 | return len(self.data) 24 | 25 | def add_data(self, Y, Y_pred): 26 | 27 | Y = Y.detach().clone().cpu() 28 | Y_pred = nn.Softmax(dim=1)(Y_pred.detach().clone().cpu()) 29 | 30 | batch_num = Y_pred.shape[0] 31 | 32 | for i in range(batch_num): 33 | local_Y = Y[i:i+1] 34 | local_Y_pred = Y_pred[i,:] 35 | 36 | local_data = {} 37 | local_data['Y'] = local_Y 38 | local_data['Y_pred'] = local_Y_pred 39 | 40 | self.data.append(local_data) 41 | 42 | def __getitem__(self, i): 43 | return self.data[i] 44 | 45 | def compute_classification_report(self): 46 | Y_true = torch.cat([self[i]['Y'] for i in range(len(self))], dim=0).float().numpy() 47 | Y_true_onehot = one_hot_encoding(Y_true, class_num=3) 48 | Y_score = torch.stack([self[i]['Y_pred'] for i in range(len(self))], dim=0).numpy() 49 | Y_pred = np.argmax(Y_score, axis=1) 50 | 51 | self.classification_report = M.classification_report(Y_true, Y_pred, output_dict=True) 52 | 53 | for i in range(3): 54 | class_name = 'class_%d'%(i) 55 | self.classification_report[class_name+'_fpr'], self.classification_report[class_name+'_tpr'], _ = \ 56 | M.roc_curve(y_score=Y_score[:,i], y_true=Y_true_onehot[:,i]) 57 | self.classification_report[class_name+'_ROC_AUC'] = \ 58 | M.auc(self.classification_report[class_name+'_fpr'], self.classification_report[class_name+'_tpr']) 59 | 60 | self.classification_report['micro_fpr'], self.classification_report['micro_tpr'], _ = \ 61 | M.roc_curve(y_score=Y_score.ravel(), y_true=Y_true_onehot.ravel()) 62 | self.classification_report['micro_ROC_AUC'] = \ 63 | M.auc(self.classification_report['micro_fpr'], self.classification_report['micro_tpr']) 64 | 65 | for item in self.classification_report: 66 | if isinstance(self.classification_report[item], 
np.ndarray): 67 | self.classification_report[item] = self.classification_report[item].tolist() 68 | 69 | def get_evaluation(self): 70 | if not hasattr(self, 'classification_report'): 71 | self.compute_classification_report() 72 | return self.classification_report 73 | 74 | def get_key_evaluation(self): 75 | Y_true = torch.cat([self[i]['Y'] for i in range(len(self))], dim=0).float().numpy() 76 | Y_score = torch.stack([self[i]['Y_pred'] for i in range(len(self))], dim=0).numpy() 77 | Y_pred = np.argmax(Y_score, axis=1) 78 | return M.accuracy_score(Y_true, Y_pred) 79 | 80 | def save_outputs(self, name, path): 81 | r = [] 82 | for i in range(len(self)): 83 | item = self[i] 84 | r.append({'Y':item['Y'].tolist(), 'Y_pred':item['Y_pred'].tolist()}) 85 | 86 | r = pd.DataFrame(r) 87 | r.to_csv(path+'/'+name+'.csv') 88 | 89 | def save_classification_report(self, name, path): 90 | if not hasattr(self, 'classification_report'): 91 | self.compute_classification_report() 92 | with open(path+'/'+name+'.json', 'w') as f: 93 | json.dump(self.classification_report, f, ensure_ascii=False, indent=4) 94 | 95 | def one_hot_encoding(x, class_num=None): 96 | if class_num is None: 97 | class_num = np.max(x)+1 98 | r = [] 99 | for item in x: 100 | tmp = np.zeros((1,class_num), dtype=np.int32) 101 | tmp[0,int(item)] = 1.0 102 | r.append(tmp) 103 | return np.concatenate(r, axis=0) -------------------------------------------------------------------------------- /Data/hzhu_gen.py: -------------------------------------------------------------------------------- 1 | import sys, copy, os, datetime, re, uuid, time, json, random, string 2 | 3 | global global_random_seed 4 | global global_ID 5 | global_random_seed = int(time.time()*1e7) 6 | global_ID = uuid.uuid4() 7 | random.seed(global_random_seed) 8 | 9 | def ls_file(path=os.getcwd()): 10 | assert isinstance(path, str), 'path type error @hzhu_gen::ls_file(path)' 11 | files = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))] 12 
| return files 13 | 14 | def ls_dir(path=os.getcwd()): 15 | assert isinstance(path, str), 'path type error @hzhu_gen::ls_dir(path)' 16 | files = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))] 17 | return files 18 | 19 | def ls_all(path=os.getcwd()): 20 | assert isinstance(path, str), 'path type error @hzhu_gen::ls_all(path)' 21 | return os.listdir(path) 22 | 23 | def str_criteria(word, criteria): 24 | assert isinstance(word, str), 'word type error @hzhu_gen::str_criteria(word, criteria)' 25 | if isinstance(criteria, str): 26 | if criteria in word: return True 27 | else: return False 28 | if isinstance(criteria, list): 29 | for item in criteria: 30 | J = str_criteria(word, item) 31 | if J==True: return True 32 | return False 33 | if isinstance(criteria, tuple): 34 | for item in criteria: 35 | J = str_criteria(word, item) 36 | if J==False: return False 37 | return True 38 | assert False, 'criteria type error @hzhu_gen::str_criteria(word, criteria)' 39 | 40 | def ls_name(path=os.getcwd(), name=None): 41 | ls = ls_file(path) 42 | if name is None: 43 | return ls 44 | return [f for f in ls if str_criteria(f,name)] 45 | 46 | def disp(data): 47 | print(to_str(data)) 48 | 49 | def to_str(data): 50 | try: 51 | return json.dumps(data, indent=4) 52 | except: 53 | return str(data) 54 | 55 | def random_str(length=5): 56 | letters = string.ascii_letters+string.digits 57 | return ''.join(random.choice(letters) for i in range(length)) 58 | 59 | def extract_number(s): 60 | return [float(number) for number in re.findall(r"[-+]?\d*\.?\d+|[-+]?\d+", s)] 61 | 62 | def create_folder(path): 63 | if not os.path.exists(path): 64 | os.makedirs(path) 65 | return True 66 | else: 67 | return False 68 | 69 | def read_file(name, path=os.getcwd()): 70 | file_name = path+'/'+name 71 | with open(file_name, 'r') as file: 72 | return file.read() 73 | 74 | def read_file_by_line(name, path=os.getcwd()): 75 | file_name = path+'/'+name 76 | r = [] 77 | with open(file_name, 'r') as 
file: 78 | for line in file: 79 | r.append(line) 80 | return r 81 | 82 | class QuickTimer: 83 | 84 | def __init__(self): 85 | self.start_time = time.perf_counter() 86 | def __call__(self): 87 | return time.perf_counter()-self.start_time 88 | def start(self): 89 | self.__init__() 90 | 91 | class QuickHelper: 92 | 93 | def __init__(self, path=None, name=None, ID_length=5): 94 | if path is None: self.path = os.getcwd() 95 | else: self.path = path 96 | if name is None: self.name = '' 97 | else: self.name = name 98 | 99 | self.ID_length = ID_length 100 | self.init() 101 | 102 | def init(self): 103 | counter = 0 104 | while True: 105 | self.ID = random_str(self.ID_length) 106 | J = create_folder(self.path+'/'+self.name+'_'+self.ID) 107 | if J: 108 | self.dir = self.path+'/'+self.name+'_'+self.ID 109 | break 110 | else: 111 | print('Folder already exists! Creating new ID.') 112 | counter += 1 113 | if counter>=10: 114 | self.ID_length += 1 115 | counter = 0 116 | self.timer = QuickTimer() 117 | 118 | def time_elapsed(self): 119 | return self.timer() 120 | 121 | def __call__(self): 122 | return self.dir 123 | 124 | def __str__(self): 125 | return '- QuickHelper:\n - ID = %s\n - path = %s\n - elapsed time = %f(sec)\n'%(self.ID, self.dir, self.time_elapsed()) 126 | 127 | def summary(self): 128 | content = self.__str__() 129 | print(content) 130 | with open(self.dir+'/QuickHelper_summary.txt','w') as file: 131 | file.write(content) -------------------------------------------------------------------------------- /Module/hzhu_gen.py: -------------------------------------------------------------------------------- 1 | import sys, copy, os, datetime, re, uuid, time, json, random, string 2 | 3 | global global_random_seed 4 | global global_ID 5 | global_random_seed = int(time.time()*1e7) 6 | global_ID = uuid.uuid4() 7 | random.seed(global_random_seed) 8 | 9 | def ls_file(path=os.getcwd()): 10 | assert isinstance(path, str), 'path type error @hzhu_gen::ls_file(path)' 11 | files = [f 
for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))] 12 | return files 13 | 14 | def ls_dir(path=os.getcwd()): 15 | assert isinstance(path, str), 'path type error @hzhu_gen::ls_dir(path)' 16 | files = [f for f in os.listdir(path) if os.path.isdir(os.path.join(path, f))] 17 | return files 18 | 19 | def ls_all(path=os.getcwd()): 20 | assert isinstance(path, str), 'path type error @hzhu_gen::ls_all(path)' 21 | return os.listdir(path) 22 | 23 | def str_criteria(word, criteria): 24 | assert isinstance(word, str), 'word type error @hzhu_gen::str_criteria(word, criteria)' 25 | if isinstance(criteria, str): 26 | if criteria in word: return True 27 | else: return False 28 | if isinstance(criteria, list): 29 | for item in criteria: 30 | J = str_criteria(word, item) 31 | if J==True: return True 32 | return False 33 | if isinstance(criteria, tuple): 34 | for item in criteria: 35 | J = str_criteria(word, item) 36 | if J==False: return False 37 | return True 38 | assert False, 'criteria type error @hzhu_gen::str_criteria(word, criteria)' 39 | 40 | def ls_name(path=os.getcwd(), name=None): 41 | ls = ls_file(path) 42 | if name is None: 43 | return ls 44 | return [f for f in ls if str_criteria(f,name)] 45 | 46 | def disp(data): 47 | print(to_str(data)) 48 | 49 | def to_str(data): 50 | try: 51 | return json.dumps(data, indent=4) 52 | except: 53 | return str(data) 54 | 55 | def random_str(length=5): 56 | letters = string.ascii_letters+string.digits 57 | return ''.join(random.choice(letters) for i in range(length)) 58 | 59 | def extract_number(s): 60 | return [float(number) for number in re.findall(r"[-+]?\d*\.?\d+|[-+]?\d+", s)] 61 | 62 | def create_folder(path): 63 | if not os.path.exists(path): 64 | os.makedirs(path) 65 | return True 66 | else: 67 | return False 68 | 69 | def read_file(name, path=os.getcwd()): 70 | file_name = path+'/'+name 71 | with open(file_name, 'r') as file: 72 | return file.read() 73 | 74 | def read_file_by_line(name, path=os.getcwd()): 75 | 
file_name = path+'/'+name 76 | r = [] 77 | with open(file_name, 'r') as file: 78 | for line in file: 79 | r.append(line) 80 | return r 81 | 82 | class QuickTimer: 83 | 84 | def __init__(self): 85 | self.start_time = time.perf_counter() 86 | def __call__(self): 87 | return time.perf_counter()-self.start_time 88 | def start(self): 89 | self.__init__() 90 | 91 | class QuickHelper: 92 | 93 | def __init__(self, path=None, name=None, ID_length=5): 94 | if path is None: self.path = os.getcwd() 95 | else: self.path = path 96 | if name is None: self.name = '' 97 | else: self.name = name 98 | 99 | self.ID_length = ID_length 100 | self.init() 101 | 102 | def init(self): 103 | counter = 0 104 | while True: 105 | self.ID = random_str(self.ID_length) 106 | J = create_folder(self.path+'/'+self.name+'_'+self.ID) 107 | if J: 108 | self.dir = self.path+'/'+self.name+'_'+self.ID 109 | break 110 | else: 111 | print('Folder already exists! Creating new ID.') 112 | counter += 1 113 | if counter>=10: 114 | self.ID_length += 1 115 | counter = 0 116 | self.timer = QuickTimer() 117 | 118 | def time_elapsed(self): 119 | return self.timer() 120 | 121 | def __call__(self): 122 | return self.dir 123 | 124 | def __str__(self): 125 | return '- QuickHelper:\n - ID = %s\n - path = %s\n - elapsed time = %f(sec)\n'%(self.ID, self.dir, self.time_elapsed()) 126 | 127 | def summary(self): 128 | content = self.__str__() 129 | print(content) 130 | with open(self.dir+'/QuickHelper_summary.txt','w') as file: 131 | file.write(content) -------------------------------------------------------------------------------- /Module/hzhu_metrics_saliency.py: -------------------------------------------------------------------------------- 1 | import sys, json 2 | 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | import copy 8 | import os 9 | import numpy as np 10 | import pandas as pd 11 | import matplotlib.pyplot as plt 12 | import matplotlib 13 | matplotlib.use('Agg') 14 
| from matplotlib.backends.backend_pdf import PdfPages 15 | 16 | import sklearn.metrics as M 17 | import scipy 18 | 19 | class MetricsHandle_Saliency: 20 | 21 | def __init__(self): 22 | self.data = [] 23 | 24 | def __len__(self): 25 | return len(self.data) 26 | 27 | def add_data(self, Y, Y_pred): 28 | 29 | #X = X.detach().clone().cpu() 30 | Y = Y.detach().clone().cpu() 31 | Y_pred = Y_pred.detach().clone().cpu() 32 | 33 | batch_num = Y_pred.shape[0] 34 | 35 | for i in range(batch_num): 36 | local_Y = Y[i,:,:,:] 37 | local_Y_pred = Y_pred[i,:,:,:] 38 | 39 | local_data = {} 40 | local_data['Y'] = local_Y 41 | local_data['Y_pred_log'] = local_Y_pred 42 | local_data['Y_pred'] = torch.exp(local_data['Y_pred_log']) 43 | 44 | check_sum = local_data['Y_pred'].sum() 45 | if check_sum>1.0+1e-3 or check_sum<1.0-1e-3: 46 | print('Y_pred check sum failed with %e'%check_sum) 47 | local_data['Y_pred'] /= check_sum 48 | 49 | check_sum = local_data['Y'].sum() 50 | if check_sum>1.0+1e-3 or check_sum<1.0-1e-3: 51 | print('Y check sum failed with %e'%check_sum) 52 | local_data['Y'] /= check_sum 53 | 54 | self.data.append(local_data) 55 | 56 | def __getitem__(self, i): 57 | return self.data[i] 58 | 59 | def compute_prediction_report(self): 60 | self.KL_loss_list = [] 61 | self.CC_list = [] 62 | 63 | self.EMD_list = [] 64 | self.histogram_similarity_list = [] 65 | 66 | with torch.no_grad(): 67 | for item in self.data: 68 | 69 | KL_loss = F.kl_div(item['Y_pred_log'], item['Y'], reduction='batchmean') 70 | self.KL_loss_list.append(KL_loss) 71 | 72 | CC, p = scipy.stats.pearsonr(item['Y_pred'].flatten().numpy(), item['Y'].flatten().numpy()) 73 | self.CC_list.append(CC) 74 | 75 | HI = torch.minimum(item['Y_pred'], item['Y']).sum() 76 | if HI>1.0+1e-5: 77 | print('Invalid HI encountered', HI) 78 | else: 79 | self.histogram_similarity_list.append(HI) 80 | 81 | self.KL_loss_list = np.array(self.KL_loss_list) 82 | self.histogram_similarity_list = np.array(self.histogram_similarity_list) 83 | 
self.CC_list = np.array(self.CC_list) 84 | 85 | self.prediction_report = { 86 | 'KL_mean': float(self.KL_loss_list.mean()), 87 | 'KL_median': float(np.median(self.KL_loss_list)), 88 | 'KL_std': float(self.KL_loss_list.std()), 89 | 90 | 'CC_mean': float(self.CC_list.mean()), 91 | 'CC_median': float(np.median(self.CC_list)), 92 | 'CC_std': float(self.CC_list.std()), 93 | 94 | 'HS_mean': float(self.histogram_similarity_list.mean()), 95 | 'HS_median': float(np.median(self.histogram_similarity_list)), 96 | 'HS_std': float(self.histogram_similarity_list.std())} 97 | 98 | def get_evaluation(self): 99 | if not hasattr(self, 'prediction_report'): 100 | self.compute_prediction_report() 101 | return self.prediction_report 102 | 103 | def get_key_evaluation(self): 104 | self.metrics_list = [] 105 | 106 | with torch.no_grad(): 107 | for item in self.data: 108 | KL_loss = F.kl_div(item['Y_pred_log'], item['Y'], reduction='batchmean') 109 | self.metrics_list.append(KL_loss) 110 | 111 | self.metrics_list = np.array(self.metrics_list) 112 | return float(self.metrics_list.mean()) 113 | 114 | def save_prediction_report(self, name, path): 115 | if not hasattr(self, 'prediction_report'): 116 | self.compute_prediction_report() 117 | with open(path+'/'+name+'.json', 'w') as f: 118 | json.dump(self.prediction_report, f, ensure_ascii=False, indent=4) -------------------------------------------------------------------------------- /Data/hzhu_data_raw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 as cv 3 | import matplotlib.pyplot as plt 4 | import pandas as pd 5 | import torch, json, pickle, copy 6 | 7 | from torch.nn import functional as F 8 | 9 | from hzhu_gen import * 10 | from pydicom import dcmread 11 | from scipy.sparse import coo_matrix 12 | 13 | from datetime import datetime 14 | 15 | class GazeDataHandle: 16 | 17 | def __init__(self, root_path): 18 | self.root_path = root_path 19 | self.master_sheet = 
pd.read_csv(self.root_path+'/'+'master_sheet.csv') 20 | self.eye_gaze = pd.read_csv(self.root_path+'/'+'eye_gaze.csv') 21 | 22 | self.selected_df = copy.deepcopy(self.master_sheet) 23 | self.selected_DICOM_ID = self.selected_df['dicom_id'].tolist() 24 | 25 | self.process() 26 | 27 | def process(self): 28 | self.groups = self.eye_gaze.groupby(['DICOM_ID']) 29 | self.data_gaze = {item[0]:item[1] for item in self.groups if len(item[0])==44 and item[0] in self.selected_DICOM_ID} 30 | 31 | def __getitem__(self, i): 32 | ID = self.selected_DICOM_ID[i] 33 | info = self.selected_df[self.selected_df['dicom_id']==ID] 34 | gaze = self.data_gaze[ID] 35 | 36 | return {'info':info, 'gaze_raw':gaze} 37 | 38 | def __len__(self): 39 | return len(self.selected_DICOM_ID) 40 | 41 | class CXRDataHandle: 42 | 43 | def __init__(self, root_path): 44 | self.root_path = root_path 45 | 46 | def __getitem__(self, info): 47 | path = info['path'].iloc[0][10:] 48 | image = dcmread(self.root_path+'/'+path).pixel_array.astype(np.float32) 49 | return image 50 | 51 | class SegDataHandle: 52 | 53 | def __init__(self, root_path): 54 | self.root_path = root_path 55 | 56 | def __getitem__(self, info): 57 | ID = info['dicom_id'].iloc[0] 58 | left_lung = cv.imread(self.root_path+'/audio_segmentation_transcripts/'+ID+'/left_lung.png', cv.IMREAD_UNCHANGED).astype(np.bool_) 59 | mediastanum = cv.imread(self.root_path+'/audio_segmentation_transcripts/'+ID+'/mediastanum.png', cv.IMREAD_UNCHANGED).astype(np.bool_) 60 | right_lung = cv.imread(self.root_path+'/audio_segmentation_transcripts/'+ID+'/right_lung.png', cv.IMREAD_UNCHANGED).astype(np.bool_) 61 | 62 | return {'lung':(left_lung+right_lung).astype(np.bool_), 'heart':mediastanum} 63 | 64 | class MasterDataHandle: 65 | 66 | def __init__(self, gaze_path, cxr_path, blur): 67 | self.blur = blur 68 | self.gaze_path = gaze_path 69 | self.cxr_path = cxr_path 70 | self.gazeData = GazeDataHandle(self.gaze_path) 71 | self.segData = SegDataHandle(self.gaze_path) 72 | 
self.cxrData = CXRDataHandle(self.cxr_path) 73 | 74 | def __len__(self): 75 | return len(self.gazeData) 76 | 77 | def __getitem__(self, i): 78 | r = self.gazeData[i] 79 | r['cxr'] = self.cxrData[r['info']] 80 | seg = self.segData[r['info']] 81 | r['lung'] = seg['lung'] 82 | r['heart'] = seg['heart'] 83 | r['gaze'] = get_gaze_heatmap(r, self.blur) 84 | 85 | if r['info']['Normal'].iloc[0]==1 and r['info']['CHF'].iloc[0]==0 and r['info']['pneumonia'].iloc[0]==0: 86 | r['Y'] = 0 87 | elif r['info']['Normal'].iloc[0]==0 and r['info']['CHF'].iloc[0]==1 and r['info']['pneumonia'].iloc[0]==0: 88 | r['Y'] = 1 89 | elif r['info']['Normal'].iloc[0]==0 and r['info']['CHF'].iloc[0]==0 and r['info']['pneumonia'].iloc[0]==1: 90 | r['Y'] = 2 91 | else: 92 | assert False 93 | 94 | return r 95 | 96 | def plot(self, i, blur): 97 | data = self[i] 98 | self.process(data, downsample=5) 99 | 100 | plt.figure(figsize=(10,8)) 101 | plt.subplot(2,3,1) 102 | plt.imshow(data['cxr']) 103 | plt.title(data['Y']) 104 | 105 | plt.subplot(2,3,2) 106 | plt.imshow(data['gaze']) 107 | plt.title(data['gaze'].shape) 108 | 109 | plt.subplot(2,3,3) 110 | plt.imshow(data['gaze']/torch.max(data['gaze'])+data['cxr']/torch.max(data['cxr'])) 111 | 112 | plt.subplot(2,3,4) 113 | plt.imshow(data['heart']) 114 | 115 | plt.subplot(2,3,5) 116 | plt.imshow(data['lung']) 117 | 118 | 119 | 120 | def save(self, data, path, downsample): 121 | local_name = '%s'%data['info']['dicom_id'].iloc[0] 122 | 123 | item = ['gaze', 'cxr', 'Y', 'heart', 'lung'] 124 | r = {} 125 | for key in item: 126 | r[key] = data[key] 127 | 128 | torch.save(r, path+'/'+local_name+'.pt') 129 | 130 | def process(self, data, downsample): 131 | 132 | data['gaze'] = torch.tensor(data['gaze'], requires_grad=False, dtype=torch.float32) 133 | data['cxr'] = torch.tensor(data['cxr'], requires_grad=False, dtype=torch.float32) 134 | data['lung'] = torch.tensor(data['lung'], requires_grad=False, dtype=torch.bool) 135 | data['heart'] = 
torch.tensor(data['heart'], requires_grad=False, dtype=torch.bool) 136 | data['Y'] = torch.tensor(data['Y'], requires_grad=False, dtype=torch.int8) 137 | 138 | shape = data['cxr'].shape 139 | if shape[0]<=shape[1]: 140 | torch.transpose(data['cxr'], 0, 1) 141 | torch.transpose(data['gaze'], 0, 1) 142 | torch.transpose(data['lung'], 0, 1) 143 | torch.transpose(data['heart'], 0, 1) 144 | 145 | shape = data['cxr'].shape 146 | H = (int(3056/downsample/32)+1)*32*downsample 147 | W = (int(2544/downsample/32)+1)*32*downsample 148 | if shape[0]<=H or shape[1]<=W: 149 | padding_left = int((W-shape[1])/2) 150 | padding_right = W-shape[1]-padding_left 151 | padding_top = int((H-shape[0])/2) 152 | padding_bottom = H-shape[0]-padding_top 153 | data['cxr'] = F.pad(data['cxr'], (padding_left, padding_right, padding_top, padding_bottom)) 154 | data['gaze'] = F.pad(data['gaze'], (padding_left, padding_right, padding_top, padding_bottom)) 155 | data['lung'] = F.pad(data['lung'], (padding_left, padding_right, padding_top, padding_bottom)) 156 | data['heart'] = F.pad(data['heart'], (padding_left, padding_right, padding_top, padding_bottom)) 157 | 158 | data['gaze'] = data['gaze'][0:H+1:downsample,0:W+1:downsample] 159 | data['gaze'] -= data['gaze'].min() 160 | data['gaze'] /= data['gaze'].max() 161 | 162 | data['cxr'] = data['cxr'][0:H+1:downsample,0:W+1:downsample] 163 | data['cxr'] /= data['cxr'].max() 164 | 165 | data['lung'] = data['lung'][0:H+1:downsample,0:W+1:downsample] 166 | data['heart'] = data['heart'][0:H+1:downsample,0:W+1:downsample] 167 | 168 | def save_all(self, root_path, downsample, fraction, seed): 169 | folders = ['Test','Valid','Train'] 170 | counter_full = {item:{i:0 for i in range(3)} for item in folders} 171 | counter = {i:0 for i in range(3)} 172 | folder_path = [root_path+'/'+item for item in folders] 173 | 174 | for item in folder_path: 175 | create_folder(item) 176 | 177 | N = int(len(self)*fraction) 178 | index_use, index_n = torch.utils.data.random_split( 
179 | range(len(self)), [N, len(self)-N], generator=torch.Generator().manual_seed(seed)) 180 | 181 | for idx, i in enumerate(index_use): 182 | 183 | data = self[i] 184 | self.process(data=data, downsample=downsample) 185 | Y = int(data['Y']) 186 | counter[Y] += 1 187 | select = counter[Y]%10 188 | if select<=1: 189 | self.save(data=data, path=folder_path[0], downsample=downsample) 190 | counter_full[folders[0]][Y] += 1 191 | elif select==2: 192 | self.save(data=data, path=folder_path[1], downsample=downsample) 193 | counter_full[folders[1]][Y] += 1 194 | else: 195 | self.save(data=data, path=folder_path[2], downsample=downsample) 196 | counter_full[folders[2]][Y] += 1 197 | 198 | if idx%50==0: 199 | print(idx, '%f%%'%((idx+1)/len(index_use)*100)) 200 | print("-- Current Time =", datetime.now()) 201 | 202 | disp(counter_full) 203 | 204 | 205 | def get_gaze_dot(data): 206 | X = data['gaze_raw']['X_ORIGINAL'].to_numpy(dtype=np.int32, copy=True) 207 | Y = data['gaze_raw']['Y_ORIGINAL'].to_numpy(dtype=np.int32, copy=True) 208 | shape = data['cxr'].shape 209 | JX = np.logical_and(X>=0, X=0, Y=127: 218 | print('overflow @ get_gaze_dot(data)') 219 | r = r.astype(np.int8) 220 | 221 | return r 222 | 223 | def get_gaze_heatmap(data, blur): 224 | blur = blur+1 if blur%2==0 else blur 225 | r = get_gaze_dot(data).astype(np.float32) 226 | r = cv.GaussianBlur(r, (blur, blur), 0, 0) 227 | r = r/np.max(r) 228 | return r -------------------------------------------------------------------------------- /Module/hzhu_MTL_UNet.py: -------------------------------------------------------------------------------- 1 | from hzhu_net import * 2 | 3 | import torch, os, copy 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def conv_block(in_ch, out_ch): 8 | conv = nn.Sequential( 9 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 10 | nn.BatchNorm2d(out_ch), 11 | nn.ReLU(inplace=True), 12 | nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, 
padding=1, bias=True), 13 | nn.BatchNorm2d(out_ch), 14 | nn.ReLU(inplace=True)) 15 | return conv 16 | 17 | def up_conv(in_ch, out_ch): 18 | up = nn.Sequential( 19 | nn.Upsample(scale_factor=2), 20 | nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True), 21 | nn.BatchNorm2d(out_ch), 22 | nn.ReLU(inplace=True) 23 | ) 24 | 25 | return up 26 | 27 | 28 | def forward(self, g, x): 29 | g1 = self.W_g(g) 30 | x1 = self.W_x(x) 31 | psi = self.relu(g1 + x1) 32 | psi = self.psi(psi) 33 | out = x * psi 34 | return out, psi 35 | 36 | def classification_head(in_features, mid_features, out_features, dropout_rate): 37 | if mid_features is not None: 38 | r = nn.Sequential() 39 | r.add_module('linear_1', nn.Linear(in_features=in_features, out_features=mid_features)) 40 | if dropout_rate is not None: 41 | if dropout_rate>0.0: 42 | r.add_module('dropout', nn.Dropout(p=dropout_rate)) 43 | r.add_module('relu_1', nn.ReLU()) 44 | r.add_module('linear_2', nn.Linear(in_features=mid_features, out_features=out_features)) 45 | return r 46 | else: 47 | return nn.Linear(in_features=in_features, out_features=out_features) 48 | 49 | class UNet_Chunk(Module): 50 | def __init__(self, in_channels, filter_list): 51 | super().__init__() 52 | 53 | self.in_channels = in_channels 54 | self.filter_list = filter_list 55 | 56 | self.Maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 57 | self.Maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 58 | self.Maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2) 59 | self.Maxpool4 = nn.MaxPool2d(kernel_size=2, stride=2) 60 | 61 | self.Conv1 = conv_block(self.in_channels, self.filter_list[0]) 62 | self.Conv2 = conv_block(self.filter_list[0], self.filter_list[1]) 63 | self.Conv3 = conv_block(self.filter_list[1], self.filter_list[2]) 64 | self.Conv4 = conv_block(self.filter_list[2], self.filter_list[3]) 65 | self.Conv5 = conv_block(self.filter_list[3], self.filter_list[4]) 66 | 67 | self.Up5 = up_conv(self.filter_list[4], self.filter_list[3]) 68 | 
self.Up_conv5 = conv_block(self.filter_list[4], self.filter_list[3]) 69 | 70 | self.Up4 = up_conv(self.filter_list[3], self.filter_list[2]) 71 | self.Up_conv4 = conv_block(self.filter_list[3], self.filter_list[2]) 72 | 73 | self.Up3 = up_conv(self.filter_list[2], self.filter_list[1]) 74 | self.Up_conv3 = conv_block(self.filter_list[2], self.filter_list[1]) 75 | 76 | self.Up2 = up_conv(self.filter_list[1], self.filter_list[0]) 77 | self.Up_conv2 = conv_block(self.filter_list[1], self.filter_list[0]) 78 | 79 | def forward(self, x): 80 | 81 | e1 = self.Conv1(x) 82 | 83 | e2 = self.Maxpool1(e1) 84 | e2 = self.Conv2(e2) 85 | 86 | e3 = self.Maxpool2(e2) 87 | e3 = self.Conv3(e3) 88 | 89 | e4 = self.Maxpool3(e3) 90 | e4 = self.Conv4(e4) 91 | 92 | e5 = self.Maxpool4(e4) 93 | e5 = self.Conv5(e5) 94 | 95 | d5 = self.Up5(e5) 96 | 97 | d5 = torch.cat((e4, d5), dim=1) 98 | d5 = self.Up_conv5(d5) 99 | 100 | d4 = self.Up4(d5) 101 | d4 = torch.cat((e3, d4), dim=1) 102 | d4 = self.Up_conv4(d4) 103 | 104 | d3 = self.Up3(d4) 105 | d3 = torch.cat((e2, d3), dim=1) 106 | d3 = self.Up_conv3(d3) 107 | 108 | d2 = self.Up2(d3) 109 | d2 = torch.cat((e1, d2), dim=1) 110 | d2 = self.Up_conv2(d2) 111 | 112 | return e5, d2 113 | 114 | class MTL_UNet(UNet_Chunk): 115 | 116 | def __init__(self, in_channels, filter_list, out_dict): 117 | super().__init__(in_channels, filter_list) 118 | self.out_dict = out_dict 119 | self.init() 120 | 121 | def init(self): 122 | self.dummy_tensor = nn.Parameter(torch.tensor(0), requires_grad=False) 123 | 124 | if self.out_dict is None: 125 | self.out_conv = nn.Conv2d(self.filter_list[0], 1, kernel_size=1, stride=1, padding=0) 126 | else: 127 | if 'class' in self.out_dict: 128 | if self.out_dict['class']>0: 129 | self.out_classification = classification_head( 130 | in_features=self.filter_list[0]+self.filter_list[-1], 131 | mid_features=self.filter_list[0], out_features=self.out_dict['class'], 132 | dropout_rate=0.25) 133 | if 'image' in self.out_dict: 134 | if 
self.out_dict['image']>0: 135 | self.out_conv_image = conv_bn_acti_drop( 136 | in_channels=self.filter_list[0], 137 | out_channels=self.filter_list[0], 138 | kernel_size=3, 139 | activation=nn.ReLU, 140 | normalize=nn.BatchNorm2d, 141 | padding=1, 142 | dropout_rate=0.0, 143 | sequential=None) 144 | self.out_conv_image.add_module( 145 | 'conv_last', nn.Conv2d(self.filter_list[0], self.out_dict['image'], kernel_size=1, stride=1, padding=0)) 146 | 147 | def forward(self, x): 148 | e5, d2 = super().forward(x) 149 | 150 | if self.out_dict is None: 151 | y = self.out_conv(d2) 152 | return self.dummy_tensor, y 153 | else: 154 | r = [] 155 | if 'class' in self.out_dict: 156 | if self.out_dict['class']>0: 157 | average_pool_e5 = e5.mean(dim=(-2,-1)) 158 | average_pool_d2 = d2.mean(dim=(-2,-1)) 159 | average_pool = torch.cat((average_pool_e5, average_pool_d2), dim=1) 160 | y_class = self.out_classification(average_pool) 161 | r.append(y_class) 162 | else: 163 | r.append(self.dummy_tensor) 164 | else: 165 | r.append(self.dummy_tensor) 166 | 167 | if 'image' in self.out_dict: 168 | if self.out_dict['image']>0: 169 | y_image = self.out_conv_image(d2) 170 | r.append(y_image) 171 | else: 172 | r.append(self.dummy_tensor) 173 | else: 174 | r.append(self.dummy_tensor) 175 | 176 | return tuple(r) 177 | 178 | class MTL_UNet_preset(MTL_UNet): 179 | 180 | def __init__(self, device, out_dict, loss_dict): 181 | self.device = device 182 | base = 64 if os.getcwd()[0] == '/' else 2 183 | super().__init__(in_channels=1, filter_list=[base*(2**i) for i in range(5)], out_dict=out_dict) 184 | 185 | self.loss_dict = loss_dict 186 | self.mt_param_init() 187 | 188 | self.to(self.device) 189 | 190 | def mt_param_init(self): 191 | if 'class' in self.out_dict: 192 | if self.out_dict['class']>0: 193 | if self.loss_dict['class'] is not None: 194 | self.lg_sigma_class = nn.Parameter(torch.tensor(self.loss_dict['class'], device=self.device, dtype=torch.float32)) 195 | else: 196 | self.lg_sigma_class = 
torch.tensor(0.0, device=self.device, dtype=torch.float32) 197 | 198 | if 'image' in self.out_dict: 199 | if self.out_dict['image']>0: 200 | if not isinstance(self.loss_dict['image'], (list, tuple)): 201 | self.loss_dict['image'] = [self.loss_dict['image'],] 202 | for item in self.loss_dict['image']: 203 | if item is not None: 204 | self.lg_sigma_image = nn.Parameter(torch.tensor(item, device=self.device, dtype=torch.float32)) 205 | else: 206 | self.lg_sigma_image = torch.tensor(0.0, device=self.device, dtype=torch.float32) 207 | 208 | def compute_loss_class(self, y_pred, y_true, loss_function): 209 | 210 | sigma = torch.exp(self.lg_sigma_class) 211 | loss_raw = loss_function(y_pred, y_true) 212 | loss_weighted = loss_raw/sigma/sigma+torch.log(sigma+1.0) 213 | 214 | return sigma, loss_raw, loss_weighted 215 | 216 | def compute_loss_image(self, y_pred, y_true, loss_function, idx): 217 | 218 | sigma = torch.exp(self.lg_sigma_image) 219 | loss_raw = loss_function(y_pred, y_true) 220 | loss_weighted = loss_raw/sigma/sigma/2.0+torch.log(sigma+1.0) 221 | 222 | return sigma, loss_raw, loss_weighted 223 | 224 | def compute_loss(self, y_class_pred, y_image_pred, y_class_true, y_image_true, loss_class, loss_image_list): 225 | 226 | class_sigma, class_loss_raw, class_loss_weighted = self.compute_loss_class( 227 | y_pred=y_class_pred, y_true=y_class_true, loss_function=loss_class) 228 | 229 | image_sigma, image_loss_raw, image_loss_weighted = self.compute_loss_image( 230 | y_pred=y_image_pred, y_true=y_image_true, loss_function=loss_image_list[0], idx=0) 231 | 232 | loss_sum = class_loss_weighted+image_loss_weighted 233 | 234 | r = {'loss_sum':loss_sum, 235 | 'class_loss_raw':class_loss_raw, 236 | 'image_loss_raw':image_loss_raw} 237 | 238 | return r 239 | 240 | def get_status(self): 241 | r = [] 242 | r.append(torch.exp(self.lg_sigma_class).detach().clone().cpu()) 243 | r.append(torch.exp(self.lg_sigma_image).detach().clone().cpu()) 244 | return r 245 | 246 | def 
get_status_str(self): 247 | stats = self.get_status() 248 | r = '' 249 | for item in stats: 250 | r += '%.2e '%item 251 | 252 | return r 253 | -------------------------------------------------------------------------------- /Module/hzhu_learn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import numpy as np 4 | import PIL 5 | import matplotlib.pyplot as plt 6 | import matplotlib 7 | import json 8 | matplotlib.use('Agg') 9 | 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader 12 | from torchvision import transforms, utils 13 | from torch import optim 14 | 15 | from hzhu_gen import * 16 | from hzhu_data import * 17 | from hzhu_metrics_class import * 18 | 19 | class NetLearn: 20 | 21 | def __init__( 22 | self, 23 | net, 24 | dataAll, 25 | criterion, 26 | optimizer_dict, 27 | lr, 28 | lr_min, 29 | lr_factor, 30 | epoch_max, 31 | duration_max, 32 | patience_reduce_lr, 33 | patience_early_stop, 34 | device, 35 | metrics, 36 | name, 37 | path): 38 | 39 | self.quickTimer = QuickTimer() 40 | self.net = net 41 | self.dataAll = dataAll 42 | 43 | self.optimizer_dict = optimizer_dict 44 | self.lr = lr 45 | self.lr_min = lr_min 46 | self.lr_factor = lr_factor 47 | self.duration_max = duration_max 48 | self.epoch_max = epoch_max 49 | self.criterion = criterion 50 | 51 | self.device = device 52 | self.patience_reduce_lr = patience_reduce_lr 53 | self.patience_early_stop = patience_early_stop 54 | 55 | self.train_loss_list = [] 56 | self.valid_loss_list = [] 57 | self.test_loss_list = [] 58 | self.metrics_list = [] 59 | self.lr_list = [] 60 | 61 | self.name = name 62 | self.path = path 63 | self.ID = self.name+'_'+random_str() 64 | self.epoch = 0 65 | 66 | self.metrics = metrics 67 | 68 | self.set_optimizer() 69 | self.set_scheduler() 70 | 71 | self.model_name = 'NET.pt' 72 | self.optim_name = 'OPT.pt' 73 | self.sched_name = 'SCH.pt' 74 | 75 | self.create_save_path() 76 | 77 | 
print('ID:', self.ID) 78 | 79 | def set_optimizer(self): 80 | self.optimizer = self.optimizer_dict['optimizer'](self.net.parameters(), lr=self.lr, **self.optimizer_dict['param']) 81 | 82 | def set_scheduler(self): 83 | self.scheduler = optim.lr_scheduler.ReduceLROnPlateau( 84 | self.optimizer, 85 | mode='min', 86 | factor=self.lr_factor, 87 | patience=self.patience_reduce_lr, 88 | eps=0, 89 | verbose=False) 90 | 91 | def train_iterate(self, dataLoader): 92 | self.epoch += 1 93 | self.net.train() 94 | loss_list = [] 95 | 96 | for data in dataLoader: 97 | X = data['cxr'].to(self.device).unsqueeze(1) 98 | 99 | Y_class = data['Y'].to(self.device).long() 100 | Y_saliency = data['gaze'].to(self.device).unsqueeze(1) 101 | Y_saliency = Y_saliency/Y_saliency.sum(dim=(-2,-1), keepdim=True) 102 | 103 | self.optimizer.zero_grad() 104 | Y_class_pred, Y_saliency_pred = self.net(X) 105 | Y_saliency_pred_shape = Y_saliency_pred.shape 106 | Y_saliency_pred = F.log_softmax(Y_saliency_pred.flatten(start_dim=-2, end_dim=-1), dim=-1).reshape(Y_saliency_pred_shape) 107 | 108 | net_list = self.net.compute_loss( 109 | y_class_pred=Y_class_pred, 110 | y_image_pred=Y_saliency_pred, 111 | y_class_true=Y_class, 112 | y_image_true=Y_saliency, 113 | loss_class=self.criterion['class'], 114 | loss_image_list=[self.criterion['saliency'],]) 115 | 116 | loss = net_list['loss_sum'] 117 | loss.backward() 118 | 119 | self.optimizer.step() 120 | loss_list.append(loss.detach().clone().cpu()) 121 | 122 | del data, X, Y_class, Y_saliency, Y_class_pred, Y_saliency_pred, net_list, loss 123 | 124 | return loss_list 125 | 126 | def eval_iterate(self, dataLoader): 127 | self.net.eval() 128 | loss_list = [] 129 | 130 | metrics_class = self.metrics['class']() 131 | metrics_saliency = self.metrics['saliency']() 132 | 133 | with torch.no_grad(): 134 | for data in dataLoader: 135 | X = data['cxr'].to(self.device).unsqueeze(1) 136 | 137 | Y_class = data['Y'].to(self.device).long() 138 | Y_saliency = 
data['gaze'].to(self.device).unsqueeze(1) 139 | Y_saliency = Y_saliency/Y_saliency.sum(dim=(-2,-1), keepdim=True) 140 | 141 | Y_class_pred, Y_saliency_pred = self.net(X) 142 | Y_saliency_pred_shape = Y_saliency_pred.shape 143 | Y_saliency_pred = F.log_softmax(Y_saliency_pred.flatten(start_dim=-2, end_dim=-1), dim=-1).reshape(Y_saliency_pred_shape) 144 | 145 | net_list = self.net.compute_loss( 146 | y_class_pred=Y_class_pred, 147 | y_image_pred=Y_saliency_pred, 148 | y_class_true=Y_class, 149 | y_image_true=Y_saliency, 150 | loss_class=self.criterion['class'], 151 | loss_image_list=[self.criterion['saliency'],]) 152 | 153 | metrics_class.add_data(Y_class, Y_class_pred) 154 | metrics_saliency.add_data(Y=Y_saliency, Y_pred=Y_saliency_pred) 155 | 156 | for item in net_list: 157 | tmp = {} 158 | tmp[item] = net_list[item].detach().clone().cpu() 159 | loss_list.append(tmp) 160 | 161 | del data, X, Y_class, Y_saliency, Y_class_pred, Y_saliency_pred, net_list 162 | 163 | return {'metrics_class':metrics_class, 'metrics_saliency':metrics_saliency, 'loss':pd.DataFrame(loss_list)} 164 | 165 | 166 | def save_net(self, path): 167 | torch.save(self.net.state_dict(), path+'/'+self.model_name) 168 | torch.save(self.optimizer.state_dict(), path+'/'+self.optim_name) 169 | torch.save(self.scheduler.state_dict(), path+'/'+self.sched_name) 170 | 171 | def load_net(self, path): 172 | self.net.load_state_dict(torch.load(path+'/'+self.model_name)) 173 | self.net.eval() 174 | self.optimizer.load_state_dict(torch.load(path+'/'+self.optim_name)) 175 | for pg in self.optimizer.param_groups: 176 | if len(self.lr_list)>0: 177 | pg['lr'] = self.lr_list[-1] 178 | self.scheduler.load_state_dict(torch.load(path+'/'+self.sched_name)) 179 | 180 | def create_save_path(self): 181 | self.save_path = self.path+'/'+self.ID 182 | 183 | if not os.path.exists(self.save_path): 184 | os.makedirs(self.save_path) 185 | else: 186 | print('Training folder already exists!') 187 | 188 | def train(self): 189 | 
self.valid_loss_min = -np.Inf 190 | self.epoch_best = 0 191 | while self.epoch=self.patience_early_stop: 248 | print('- Early stopping: max non-improving epoch reached at %d'%(self.epoch-self.epoch_best)) 249 | return True 250 | 251 | if self.quickTimer()>=self.duration_max: 252 | print('- Early stopping: Max duration reached %f>=%f (sec)'%(self.quickTimer(), self.duration_max)) 253 | return True 254 | 255 | def save_training_process(self): 256 | plt.figure() 257 | plt.subplot(3,1,1) 258 | plt.plot(self.train_loss_list) 259 | plt.plot(self.valid_loss_list) 260 | plt.title('loss') 261 | 262 | plt.subplot(3,1,2) 263 | plt.plot(self.lr_list) 264 | plt.title('learning rate') 265 | 266 | plt.subplot(3,1,3) 267 | plt.plot(self.metrics_list) 268 | plt.title('metrics') 269 | plt.savefig(self.save_path+'/training_process.png') 270 | 271 | plt.close() 272 | 273 | def remove_saved_net(self): 274 | if not hasattr(self, 'model_name'): 275 | print("The net file does not exist") 276 | return 277 | 278 | if not hasattr(self, 'save_path'): 279 | print("The net file does not exist") 280 | return 281 | 282 | if os.path.exists(self.save_path+'/'+self.model_name): 283 | os.remove(self.save_path+'/'+self.model_name) 284 | print("Saved network file deleted successfully") 285 | else: 286 | print("The net file does not exist") 287 | 288 | def remove_saved_optim(self): 289 | if not hasattr(self, 'optim_name'): 290 | print("The optim file does not exist") 291 | return 292 | 293 | if not hasattr(self, 'save_path'): 294 | print("The optim file does not exist") 295 | return 296 | 297 | if os.path.exists(self.save_path+'/'+self.optim_name): 298 | os.remove(self.save_path+'/'+self.optim_name) 299 | print("Saved optim file deleted successfully") 300 | else: 301 | print("The optim file does not exist") 302 | 303 | def remove_saved_sched(self): 304 | if not hasattr(self, 'sched_name'): 305 | print("The sched file does not exist") 306 | return 307 | 308 | if not hasattr(self, 'save_path'): 309 | 
print("The sched file does not exist") 310 | return 311 | 312 | if os.path.exists(self.save_path+'/'+self.sched_name): 313 | os.remove(self.save_path+'/'+self.sched_name) 314 | print("Saved sched file deleted successfully") 315 | else: 316 | print("The sched file does not exist") 317 | 318 | def remove_saved(self): 319 | self.remove_saved_net() 320 | self.remove_saved_optim() 321 | self.remove_saved_sched() 322 | 323 | def save_params(self, name, path): 324 | 325 | attr_list = [attr for attr in dir(self) if isinstance(getattr(self, attr), (list, tuple, dict, int, float, bool)) and not attr.startswith("_")] 326 | content = {} 327 | for attr in attr_list: 328 | content[attr] = getattr(self, attr) 329 | 330 | with open(path+'/'+'params_%s_'%(self.__class__.__name__)+name+'.txt', 'w') as file: 331 | try: 332 | json.dump(content, file, indent=4) 333 | except: 334 | print('Exception occured at hzhu_learn::NetLearn.save_params(..): content cannot be dumped!') 335 | file.write(str(content)) 336 | 337 | def evaluate(self): 338 | 339 | eval_test = self.eval_iterate(self.dataAll('Test')) 340 | eval_test['metrics_class'].compute_classification_report() 341 | eval_test['metrics_class'].save_classification_report('classification_report', self.save_path) 342 | eval_test['metrics_class'].save_outputs('classification_results', self.save_path) 343 | 344 | r = {key:eval_test['metrics_class'].classification_report[key]\ 345 | for key in eval_test['metrics_class'].classification_report if 'ROC_AUC' in key} 346 | r['accuracy'] = eval_test['metrics_class'].classification_report['accuracy'] 347 | 348 | eval_test['metrics_saliency'].compute_prediction_report() 349 | eval_test['metrics_saliency'].save_prediction_report('prediction_report', self.save_path) 350 | #eval_test['metrics'].save_outputs('prediction_results', self.save_path) 351 | 352 | r = {**eval_test['metrics_saliency'].prediction_report, **r} 353 | 354 | return json.dumps(r, indent=4) 355 | 356 | def index_expand(idx, image, n): 
357 | a, b = idx 358 | r = [] 359 | for i in range(a-n,a+n+1): 360 | for j in range(b-n,b+n+1): 361 | if i>=0 and i=0 and j