├── .DS_Store
├── 4-Data.xlsx
├── config
│   ├── .DS_Store
│   ├── __init__.py
│   └── opts.py
├── data_loader
│   ├── .DS_Store
│   ├── __init__.py
│   ├── dataset.py
│   └── utils.py
├── main.py
├── model
│   ├── .DS_Store
│   ├── bl_model.py
│   ├── feature_extractors
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   └── mnasnet.py
│   ├── model.py
│   ├── nn_layers
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── attn_layers.py
│   │   ├── eesp.py
│   │   ├── espnet_utils.py
│   │   ├── ffn.py
│   │   ├── multi_head_attn.py
│   │   └── transformer.py
│   └── yx_model.py
├── train_and_eval.py
└── utils
    ├── .DS_Store
    ├── __init__.py
    ├── build_backbone.py
    ├── build_criterion.py
    ├── build_dataloader.py
    ├── build_model.py
    ├── build_optimizer.py
    ├── criterions
    │   ├── __pycache__
    │   │   ├── blyx_loss.cpython-37.pyc
    │   │   ├── blyx_loss.cpython-39.pyc
    │   │   ├── focal_loss.cpython-37.pyc
    │   │   ├── focal_loss.cpython-39.pyc
    │   │   ├── smoothing_loss.cpython-37.pyc
    │   │   ├── smoothing_loss.cpython-39.pyc
    │   │   ├── survival_loss.cpython-37.pyc
    │   │   └── survival_loss.cpython-39.pyc
    │   ├── blyx_loss.py
    │   ├── focal_loss.py
    │   ├── smoothing_loss.py
    │   └── survival_loss.py
    ├── lr_scheduler.py
    ├── metric_utils.py
    ├── print_utils.py
    ├── roc_utils.py
    └── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/.DS_Store
--------------------------------------------------------------------------------
/4-Data.xlsx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/4-Data.xlsx
--------------------------------------------------------------------------------
/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/config/.DS_Store
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/config/__init__.py
--------------------------------------------------------------------------------
/config/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | import time
5 | import os
6 | from utils.utils import setup_seed
7 | from utils.build_criterion import get_criterion_opts
8 | # from utils.build_backbone import get_backbone_opts
9 | from utils.build_optimizer import get_optimizer_opts
10 | from utils.build_model import get_model_opts
11 | from utils.build_dataloader import get_dataset_opts
12 | from utils.lr_scheduler import get_scheduler_opts
13 | 
14 | def general_opts(parser):
15 |     group = parser.add_argument_group('General Options')
16 | 
17 |     group.add_argument('--log-interval', type=int, default=5, help='Print logs every this many iterations')
18 |     group.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
19 |     group.add_argument('--seed', type=int, default=1, help='Random seed')
20 |     group.add_argument('--config-file', type=str, default='', help='Config file, if one exists')
21 |     group.add_argument('--msc-eval', action='store_true', default=False, help='Multi-scale evaluation')
22 |     group.add_argument('--save-dir', type=str, default="../Results", help="Path to save results")
23 |     group.add_argument('--attnmap-weight-dir', type=str, default="None")
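    # Added doctest-style sketch (not part of the repository): these general
    # options can be parsed on their own; get_opts() below merges them with the
    # optimizer, criterion, model, dataset, and scheduler option groups.
    # >>> p = general_opts(argparse.ArgumentParser())
    # >>> args = p.parse_args(['--epochs', '50', '--seed', '42'])
    # >>> (args.epochs, args.seed, args.finetune)
    # (50, 42, False)
24 |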
group.add_argument('--feat-dir', type=str, default="None") 25 | group.add_argument('--finetune', action='store_true', default=False) 26 | 27 | return parser 28 | 29 | def visualization_opts(parser): 30 | group = parser.add_argument_group('Visualization options') 31 | group.add_argument('--im-or-file', type=str, required=True, help='Name of the image or list of images in file to be visualized') 32 | group.add_argument('--is-type-file', action='store_true', default=False, help='Is it a file? ') 33 | group.add_argument('--img-extn-vis', type=str, required=True, help='Image extension without dot (example is png)') 34 | group.add_argument('--vis-res-dir', type=str, default='results_vis', help='Results after visualization') 35 | group.add_argument('--no-pt-files', action='store_true', default=False, help='Do not save data using torch.save') 36 | return parser 37 | 38 | def get_opts(parser): 39 | '''General options''' 40 | parser = general_opts(parser) 41 | parser = get_optimizer_opts(parser) 42 | parser = get_criterion_opts(parser) 43 | parser = get_model_opts(parser) 44 | parser = get_dataset_opts(parser) 45 | parser = get_scheduler_opts(parser) 46 | # parser = get_backbone_opts(parser) 47 | return parser 48 | 49 | def get_config(is_visualization=False): 50 | parser = argparse.ArgumentParser(description='M3') 51 | parser = get_opts(parser) 52 | if is_visualization: 53 | parser = visualization_opts(parser) 54 | args = parser.parse_args() 55 | setup_seed(args.seed) 56 | #torch.set_num_threads(args.data_workers) 57 | timestr = time.strftime("%Y%m%d-%H%M%S") 58 | name = os.path.basename(args.train_file).split(".")[0].split("_")[-1] 59 | if args.dataset == 'yingxiang': 60 | args.save_dir = '{}/{}/{}_{}_{}/sch_{}/{}/'.format(args.save_dir, 61 | args.dataset, 62 | args.model, 63 | args.yx_model, 64 | args.yx_cnn_name, 65 | name, 66 | timestr) 67 | elif args.dataset == 'bingli': 68 | args.save_dir = '{}/{}/{}_{}_{}/sch_{}/{}/'.format(args.save_dir, 69 | args.dataset, 70 | args.model, 71 | args.bl_model, 72 | args.bl_cnn_name, 73 | name, 74 | timestr) 75 | else: 76 | args.save_dir = '{}/{}/{}_{}_{}_{}_{}_{}/{}/{}/'.format(args.save_dir, 77 | args.dataset, 78 | args.model, 79 | args.blyx_model, 80 | args.bl_model, 81 | args.bl_cnn_name, 82 | args.yx_model, 83 | args.yx_cnn_name, 84 | name, 85 | timestr) 86 | os.makedirs(args.save_dir, exist_ok=True) 87 | return args, parser -------------------------------------------------------------------------------- /data_loader/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/data_loader/.DS_Store -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | from data_loader.dataset import * -------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | from PIL import Image 4 | import numpy as np 5 | from data_loader.utils import * 6 | from utils.print_utils import * 7 | import pandas as pd 8 | from copy import deepcopy 9 | import pickle 10 | 11 | BLACK_LST = [] 12 | YX_BLACK_LST = [] 13 | BL_BLACK_LST = [] 14 | yx_black_lst_v1 = [] 15 | bl_black_lst_v1 = [] 16 | 17 | class BLYXDataset(Dataset): 18 | def __init__(self, opts, 
split, split_file, printer, ignore=True, cohort=None):
19 |         super().__init__()
20 | 
21 |         self.split = split
22 |         self.opts = opts
23 |         self.is_training = ("train" == split)
24 | 
25 |         self.bl_img_dir = opts.bl_img_dir
26 |         self.bl_num_bags = opts.bl_num_bags
27 |         self.bl_bag_size = opts.bl_bag_size
28 |         self.bl_img_extn = opts.bl_img_extn
29 |         self.bl_word_size = opts.bl_word_size
30 | 
31 |         self.yx_img_dir = opts.yx_img_dir
32 |         self.yx_img_extn = opts.yx_img_extn
33 |         self.yx_num_lesions = opts.yx_num_lesions
34 |         self.yx_lesion_size = opts.yx_lesion_size
35 | 
36 |         pd_data = pd.read_excel(split_file)
37 |         self.bl_num, self.yx_num, self.blyx_num = 0, 0, 0
38 |         self.id_lst, self.name_lst, self.bl_pid_lst, self.yx_pid_lst, self.liaoxiao_lst = [], [], [], [], []
39 |         self.OS_lst, self.OSCensor_lst, self.PFS_lst, self.PFSCensor_lst = [], [], [], []
40 |         self.fangan_lst, self.label_lst, self.time_lst = [], [], []
41 |         self.sex_lst, self.age_lst, self.buwei_lst, self.xianshu_lst = [], [], [], []
42 |         self.ln_dis_lst = []
43 |         self.zl_ln_lst, self.zl_ln_pos_lst, self.zl_multi_lst, self.zl_per_lst = [], [], [], []
44 |         self.fenhua_lst, self.fenxing_lst, self.tils_lst, self.her2_lst, self.tumor_lst = [], [], [], [], []
45 |         self.yx_flag_lst, self.bl_flag_lst = [], []
46 |         for (id, name, bl_pid, yx_pid, start_time, liaoxiao, PFS, PFSCensor, OS, OSCensor, fangan,
47 |              sex, age, buwei, xianshu,
48 |              zl_ln, zl_ln_pos, zl_multi, zl_per,
49 |              fenhua, fenxing, tils, her2, tumor, ln_dis,
50 |              yx_flag, bl_flag
51 |              ) in zip(
52 |                 pd_data["住院号"], pd_data["姓名"], pd_data["病理勾画编号"], pd_data["影像勾画编号"], pd_data["开始抗HER2治疗日期"], pd_data["最佳疗效"],
53 |                 pd_data["PFS"], pd_data["PFSCensor"], pd_data["OS"], pd_data["OSCensor"], pd_data["联合免疫(是=1,否=0)"],
54 |                 pd_data["性别"], pd_data["年龄"], pd_data["肿瘤部位"], pd_data["治疗线数"],
55 |                 pd_data["转移淋巴结(0 没有转移淋巴结 1存在转移淋巴结,但不融合 2存在融合的转移淋巴结)"],
56 |                 pd_data["转移淋巴结位置(0没有转移淋巴结 1 存在局域转移淋巴结 2 存在M分期的腹腔或腹膜后淋巴结转移 3存在纵隔或锁骨上淋巴结转移 4存在其他少见远隔部位淋巴结转移(如腋窝、颈旁、腹股沟等等区域)"],
57 |                 pd_data["肝或肺多发转移,多发为≥3个病灶(0没有多发肝转移和多发肺转移 1仅存在多发肝转移 2仅存在多发肺转移 3两者均有)"],
58 |                 pd_data["腹膜转移(0 无腹膜转移 1存在腹膜转移)"],
59 |                 pd_data["分化程度"], pd_data["LAUREN分型"], pd_data["肿瘤相关淋巴细胞TILs"], pd_data["HER2表达异质性"], pd_data["肿瘤占比"], pd_data["转移部位整理"],
60 |                 pd_data["影像采样时间"], pd_data["病理采样时间"]):
61 | 
62 |             if id in BLACK_LST:
63 |                 continue
64 | 
65 |             if isinstance(yx_pid, str):
66 |                 yx_pid = int(yx_pid.replace(",", ""))
67 |             label = convert_label(self.opts.label_type, OS, OSCensor, liaoxiao, PFS=PFS, PFSCensor=PFSCensor)
68 |             if label < -1 or (yx_pid in YX_BLACK_LST) or (bl_pid in BL_BLACK_LST):
69 |                 continue
70 |             if str(bl_pid) != str(np.nan) and bl_flag == -1:
71 |                 continue
72 |             if str(yx_pid) != str(np.nan) and yx_flag == -1:
73 |                 continue
74 | 
75 |             if cohort is not None:
76 |                 assert cohort in ["her2", "ci", "all"], (cohort)
77 |                 if cohort == "ci" and fangan == 0:
78 |                     continue
79 |                 elif cohort == "her2" and fangan == 1:
80 |                     continue
81 |                 elif cohort == "all":
82 |                     pass
83 | 
84 |             if opts.model == "bingli" and (str(bl_pid) == str(np.nan) or bl_flag == -1):
85 |                 continue
86 |             elif opts.model == "yingxiang" and (str(yx_pid) == str(np.nan) or yx_flag == -1):
87 |                 continue
88 | 
89 |             self.bl_num += int(str(bl_pid) != str(np.nan))
90 |             self.yx_num += int(str(yx_pid) != str(np.nan))
91 |             self.blyx_num += int(str(bl_pid) != str(np.nan) and (str(yx_pid) != str(np.nan)))
92 | 
93 |             self.id_lst.append(id)
94 |             self.name_lst.append(name)
95 |             self.bl_pid_lst.append(str(bl_pid).replace(".0", "") if str(bl_pid) != str(np.nan) else str(np.nan))
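            # Added note (sketch, not part of the repository): missing IDs are
            # detected throughout this file with the string idiom
            # str(x) != str(np.nan), which works because str() maps float NaN
            # to the fixed text 'nan':
            # >>> str(np.nan)
            # 'nan'
            # >>> str(np.nan) != str(np.nan)
            # False
96 |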
self.yx_pid_lst.append(str(int(yx_pid)) if not np.isnan(yx_pid) else str(np.nan))
97 |             self.liaoxiao_lst.append("NA" if str(liaoxiao)==str(np.nan) else liaoxiao)
98 | 
99 |             self.OS_lst.append(OS)
100 |             self.OSCensor_lst.append(OSCensor)
101 |             self.PFS_lst.append(PFS)
102 |             self.PFSCensor_lst.append(PFSCensor)
103 | 
104 |             self.fangan_lst.append(fangan)
105 |             self.label_lst.append(label)
106 | 
107 |             self.sex_lst.append(sex)
108 |             self.age_lst.append(age)
109 |             self.buwei_lst.append(buwei)
110 |             self.xianshu_lst.append(xianshu)
111 |             self.time_lst.append(str(start_time).split("/")[0])
112 | 
113 |             self.zl_ln_lst.append(zl_ln)
114 |             self.zl_ln_pos_lst.append(zl_ln_pos)
115 |             self.zl_multi_lst.append(zl_multi)
116 |             self.zl_per_lst.append(zl_per)
117 | 
118 |             self.fenhua_lst.append(fenhua)
119 |             self.fenxing_lst.append(fenxing)
120 |             self.tils_lst.append(tils)
121 |             self.her2_lst.append(her2)
122 |             self.tumor_lst.append(tumor)
123 |             self.ln_dis_lst.append(ln_dis)
124 | 
125 |             self.bl_flag_lst.append(bl_flag)
126 |             self.yx_flag_lst.append(yx_flag)
127 | 
128 |         self.diag_labels = deepcopy(self.label_lst)
129 |         self.n_classes = len(np.unique(self.diag_labels))
130 |         self.printer = printer
131 | 
132 |         print_info_message('Samples in {}: {}\t(bl={}\tyx={}\tblyx={} ({:.2f}%))'.format(
133 |             split_file, self.__len__(), self.bl_num, self.yx_num, self.blyx_num, 100.0*self.blyx_num/self.__len__()), self.printer)
134 |         print_info_message('-- {} ({:.2f}%) Non-response | {} ({:.2f}%) Response | {} ({:.2f}%) Others'.format(
135 |             sum(np.asarray(self.label_lst)==0), 100.0*sum(np.asarray(self.label_lst)==0)/self.__len__(),
136 |             sum(np.asarray(self.label_lst)==1), 100.0*sum(np.asarray(self.label_lst)==1)/self.__len__(),
137 |             sum(np.asarray(self.label_lst)==-1), 100.0*sum(np.asarray(self.label_lst)==-1)/self.__len__(),
138 |             ), self.printer)
139 | 
140 |     def __len__(self):
141 |         return len(self.bl_pid_lst)
142 | 
143 |     def _generate_mask_bags_label(self, mask_bags):
144 |         # mask_bags: (N_B, B_H, B_W)
145 |         mask_bags_label = []
146 |         mask_bags = mask_bags.reshape((mask_bags.shape[0], -1))
147 |         for nb in range(mask_bags.shape[0]):
148 |             mask_bags_label.append(np.argmax(np.bincount(mask_bags[nb])))
149 |         mask_bags_label = torch.LongTensor(mask_bags_label) # (N_B,)
150 |         return mask_bags_label
151 | 
152 |     def _generate_mask_words_label(self, mask_words):
153 |         # mask_words: (N_B, N_W, W_H, W_W)
154 |         mask_words_label = []
155 |         mask_words = mask_words.reshape((mask_words.shape[0], mask_words.shape[1], -1))
156 |         for nb in range(mask_words.shape[0]):
157 |             mask_words_label_tmp = []
158 |             for nw in range(mask_words.shape[1]):
159 |                 mask_words_label_tmp.append(np.argmax(np.bincount(mask_words[nb, nw])))
160 |             mask_words_label.append(mask_words_label_tmp)
161 |         mask_words_label = torch.LongTensor(mask_words_label) # (N_B, N_W)
162 |         return mask_words_label
163 | 
164 |     def _load_bl_data(self, index):
165 |         num_words_per_bag = (self.opts.bl_bag_size // self.opts.bl_word_size) ** 2
166 |         if self.bl_pid_lst[index] != str(np.nan):
167 |             bl_pid = sorted(self.bl_pid_lst[index].strip().split("+"))[-1] # if there are two slides, take the last one
168 |             # large patches (bags): (N_B, B_H, B_W, C) | masks: (N_B, B_H, B_W)
169 |             bags, masks, keys = load_all_bags_with_masks(os.path.join(self.bl_img_dir, bl_pid), self.bl_bag_size,
170 |                 self.bl_bag_size, self.bl_img_extn, self.bl_num_bags, split=self.split)
171 |             mask_bags_label = self._generate_mask_bags_label(masks)
172 |             # small patches (words): (N_B, N_W, C, W_H, W_W)
173 |             feat_words = bags_to_words(bags, self.bl_word_size,
self.bl_word_size, bl_pid).float()
174 |             # masks corresponding to the small patches: (N_B, N_W, W_H, W_W)
175 |             mask_words = masks_to_words(masks, self.bl_word_size, self.bl_word_size, bl_pid).long()
176 |             mask_words_label = self._generate_mask_words_label(mask_words)
177 | 
178 |             radiomics_file = os.path.join(self.opts.bl_rad_dir, f"radiomics_{bl_pid}_norm.csv")
179 |             radiomics_data = pd.read_csv(radiomics_file).values
180 |             radiomics_dict = {}
181 |             for line in radiomics_data:
182 |                 radiomics_dict[str(line[0])] = torch.FloatTensor(np.nan_to_num(np.asarray(line[1:], dtype=np.float32), 0.0)).float() # (736,)
183 |             radiomics_feat = torch.stack([radiomics_dict[key] for key in keys], dim=0) # (N_B, 736)
184 |             flag = 1
185 |         else:
186 |             feat_words = torch.zeros(max(1, self.opts.bl_num_bags), num_words_per_bag, 3, self.opts.bl_word_size, self.opts.bl_word_size).float()
187 |             radiomics_feat = torch.zeros(max(1, self.opts.bl_num_bags), 736).float()
188 |             mask_bags_label = torch.full((max(1, self.opts.bl_num_bags),), -1).long()
189 |             mask_words_label = torch.full((max(1, self.opts.bl_num_bags), num_words_per_bag), -1).long()
190 |             flag = 0
191 |         return feat_words, radiomics_feat, mask_bags_label, mask_words_label, flag
192 | 
193 |     def _load_yx_data(self, index):
194 |         if self.yx_pid_lst[index] != str(np.nan):
195 |             yx_pid = self.yx_pid_lst[index]
196 |             lesions, keys, lesions_label = load_lesions(self.yx_img_dir, yx_pid,
197 |                 self.yx_lesion_size, self.yx_lesion_size,
198 |                 self.yx_img_extn, is_training=self.is_training, num_lesions=self.opts.yx_num_lesions, split=self.split)
199 | 
200 |             radiomics_file = os.path.join(self.opts.yx_rad_dir, f"radiomics_{yx_pid}_norm.csv")
201 |             radiomics_data = pd.read_csv(radiomics_file).values
202 |             radiomics_dict = {}
203 |             for line in radiomics_data:
204 |                 radiomics_dict[str(line[0])] = torch.FloatTensor(np.nan_to_num(np.asarray(line[1:], dtype=np.float32), 0.0)).float() # (736,)
205 |             radiomics_feat = torch.stack([radiomics_dict[key] for key in keys], dim=0) # (N_B, 736)
206 |             flag = 1
207 |         else:
208 |             lesions = torch.zeros(max(1, self.opts.yx_num_lesions), 3, self.opts.yx_lesion_size, self.opts.yx_lesion_size).float()
209 |             radiomics_feat = torch.zeros(max(1, self.opts.yx_num_lesions), 736).float()
210 |             lesions_label = torch.full((max(1, self.opts.yx_num_lesions),), -1).long()
211 |             flag = 0
212 |         return lesions, radiomics_feat, lesions_label, flag
213 | 
214 |     def __getitem__(self, index):
215 |         feat_words, bl_radiomics_feat, mask_bags_label, mask_words_label, bl_flag = self._load_bl_data(index)
216 |         lesions, yx_radiomics_feat, lesions_label, yx_flag = self._load_yx_data(index)
217 | 
218 |         #print(self.yx_pid_lst[index], self.bl_pid_lst[index])
219 |         assert bl_flag or yx_flag, (self.id_lst[index], self.yx_pid_lst[index], self.bl_pid_lst[index])
220 | 
221 |         return {
222 |             "id": self.id_lst[index],
223 |             "name": self.name_lst[index],
224 | 
225 |             "feat_words": feat_words,
226 |             "mask_bags_label": mask_bags_label,
227 |             "mask_words_label": mask_words_label,
228 |             "bl_radiomics_feat": bl_radiomics_feat,
229 |             "bl_pid": self.bl_pid_lst[index],
230 |             "bl_flag": bl_flag,
231 | 
232 |             "lesions": lesions,
233 |             "lesions_label": lesions_label,
234 |             "yx_radiomics_feat": yx_radiomics_feat,
235 |             "yx_pid": self.yx_pid_lst[index],
236 |             "yx_flag": yx_flag,
237 | 
238 |             "liaoxiao": self.liaoxiao_lst[index],
239 |             "os": self.OS_lst[index],
240 |             "os_censor": self.OSCensor_lst[index],
241 |             "pfs": self.PFS_lst[index],
242 |             "pfs_censor": self.PFSCensor_lst[index],
243 | 
244 |             "fangan":
self.fangan_lst[index], 245 | "label": self.label_lst[index], 246 | 247 | 248 | "clinical_sex": ccd_sex(self.sex_lst[index]), # (2,) 249 | "clinical_age": ccd_age(self.age_lst[index]), # (2,) 250 | "clinical_buwei": ccd_buwei(self.buwei_lst[index]), # (2,) 251 | "clinical_xianshu": ccd_xianshu(self.xianshu_lst[index]), # (2,) 252 | "clinical_time": ccd_time(self.time_lst[index]), # (3,) 253 | "clinical_fenxing": ccd_fenxing(self.fenxing_lst[index]), # (3,) 254 | "clinical_fenhua": ccd_fenhua(self.fenhua_lst[index]), # (4,) 255 | "clinical_ln_dis": ccd_ln_dis(self.ln_dis_lst[index]), # (13,) 256 | "clinical_yx_flag": ccd_yx_flag(self.yx_flag_lst[index]), # (3,) 257 | "clinical_bl_flag": ccd_bl_flag(self.bl_flag_lst[index]), # (3,) 258 | 259 | 260 | "clinical_bl_tils": ccd_bl_tils(self.tils_lst[index]), # (10,) 261 | "clinical_bl_her2": ccd_bl_her2(self.her2_lst[index]), # (4,) 262 | "clinical_bl_tumor": ccd_bl_tumor(self.tumor_lst[index]), # (10,) 263 | 264 | "clinical_yx_stomach": ccd_yx_stomach(os.path.join(self.yx_img_dir, self.yx_pid_lst[index])), # (2,) 265 | #"clinical_yx_ln_dis": ccd_yx_ln_dis(os.path.join(self.yx_img_dir, self.yx_pid_lst[index])), # (9,) 266 | "clinical_yx_ln_num": ccd_yx_ln_num(os.path.join(self.yx_img_dir, self.yx_pid_lst[index])), # (7,) 267 | "clinical_yx_zl_ln": ccd_yx_zl_ln(self.zl_ln_lst[index]), # (3,) 268 | "clinical_yx_zl_ln_pos": ccd_yx_zl_ln_pos(self.zl_ln_pos_lst[index]), # (5,) 269 | "clinical_yx_zl_multi": ccd_yx_zl_multi(self.zl_multi_lst[index]), # (4,) 270 | "clinical_yx_zl_per": ccd_yx_zl_per(self.zl_per_lst[index]), # (2,) 271 | } -------------------------------------------------------------------------------- /data_loader/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import cv2 4 | import torch 5 | import gc 6 | from glob import glob 7 | import os 8 | from PIL import Image, ImageFile, ImageEnhance 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | Image.MAX_IMAGE_PIXELS = 1000000000000 12 | 13 | MEAN = [0.485, 0.456, 0.406] 14 | STD = [0.229, 0.224, 0.225] 15 | yx_MEAN = [0.485, 0.456, 0.406] 16 | yx_STD = [0.229, 0.224, 0.225] 17 | WHITE_BALANCE_LST = ["NF-10", "NF-26", "NF-20", "NF-19", "NF-17", "NF-1", "NF-27", "NF-25", "NF-6", "NF-8", "NF-5", "NF-9", "NF-29", "NF-4", "NF-28"] 18 | 19 | def map_s(s): 20 | if '一线' in s: 21 | return torch.Tensor([0.0, 1.0]).float() 22 | else: 23 | return torch.Tensor([1.0, 0.0]).float() 24 | 25 | def map_c(c): 26 | if c <= 0: 27 | return torch.Tensor([1.0, 0.0]).float() 28 | else: 29 | return torch.Tensor([0.0, 1.0]).float() 30 | 31 | def normalize_words_np(words_np): 32 | # (N_B, N_W*N_H, W_H, W_W, C) 33 | words_np = words_np.astype(float) 34 | words_np /= 255.0 35 | words_np -= MEAN 36 | words_np /= STD 37 | # (N_B, N_W*N_H, W_H, W_W, C) -> (N_B, N_W*N_H, C, W_H, W_W) 38 | words_np = words_np.transpose(0, 1, 4, 2, 3) 39 | return words_np 40 | 41 | lesion_to_label = { 42 | 'LN': 0, 43 | 'stomach': 1, 44 | 'Liver': 2, 45 | 'Peritoneum': 3, 46 | 'Other': 4, 47 | 'Spleen': 4, 48 | 'Bone': 4, 49 | 'Soft': 4, 50 | 'Aden': 4, 51 | } 52 | 53 | def load_lesions(data_dir, pid, lesion_height, lesion_width, img_extn, is_training, num_lesions=4, split="train"): 54 | files = glob(os.path.join(data_dir, "0", pid, f"*.{img_extn}")) 55 | files = [file for file in files if not file.endswith(f"_mask.{img_extn}")] 56 | lesions = [] 57 | keys = [] 58 | lesions_label = [] 59 | if len(files) == 0: 60 | print(data_dir, pid) 61 | 62 | if 
num_lesions > 0:
63 |         files = np.random.choice(files, num_lesions, replace=True)
64 | 
65 |     for file in files:
66 |         key = os.path.basename(file).split('.')[0]
67 |         lesion = np.stack([
68 |             np.asarray(Image.open(file.replace("/0/", "/-1/")).convert("L")),
69 |             np.asarray(Image.open(file).convert("L")),
70 |             np.asarray(Image.open(file.replace("/0/", "/1/")).convert("L")),
71 |         ], axis=-1)
72 |         if lesion_height != lesion.shape[0] or lesion_width != lesion.shape[1]:
73 |             lesion = cv2.resize(lesion, (lesion_width, lesion_height))
74 |         if is_training:
75 |             lesion = random_transform_np(lesion, max_rotation=30, pad_value=0)
76 |         lesions.append(lesion)
77 |         keys.append(os.path.basename(file).split(".")[0])
78 |         lesions_label.append(lesion_to_label[os.path.basename(file).split('_')[-1].split('.')[0].replace('1', '').replace('2', '')])
79 |     if len(lesions) == 0:
80 |         print(data_dir, pid, "+++++++") # debug
81 |         return torch.Tensor(1)
82 |     lesions = np.stack(lesions, axis=0)
83 |     lesions = lesions.astype(float)
84 |     lesions /= 255.0
85 |     lesions -= yx_MEAN
86 |     lesions /= yx_STD
87 |     lesions = torch.Tensor(lesions.transpose(0, 3, 1, 2)).float() # (N_l, 3, 224, 224)
88 |     lesions_label = torch.LongTensor(lesions_label)
89 |     return lesions, keys, lesions_label
90 | 
91 | def random_transform_np(img_np, max_rotation=10, pad_value=255):
92 |     h, w = img_np.shape[:2]
93 |     # flip the image
94 |     if random.random() < 0.5:
95 |         flip_code = random.choice([0, 1]) # 0 for horizontal and 1 for vertical
96 |         img_np = cv2.flip(img_np, flip_code)
97 | 
98 |     # rotate the image
99 |     if random.random() < 0.5:
100 |         angle = random.choice(np.arange(-max_rotation, max_rotation + 1).tolist())
101 |         # note that these functions take the size argument as (w, h)
102 |         rot_mat = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
103 |         img_np = cv2.warpAffine(img_np, rot_mat, (w, h),
104 |             borderValue=(pad_value, pad_value, pad_value)) # 255 because that corresponds to the background in WSIs
105 | 
106 |     # random crop and scale
107 |     if random.random() < 0.5:
108 |         x = random.randint(0, w - w // 4)
109 |         y = random.randint(0, h - h // 4)
110 |         img_np = img_np[y:, x:]
111 |         img_np = cv2.resize(img_np, (w, h))
112 | 
113 |     return img_np
114 | 
115 | def load_all_bags_with_masks(data_dir, bag_height, bag_width, img_extn, num_bags, split="train"):
116 |     files = glob(os.path.join(data_dir, f"*-img.{img_extn}"))
117 |     bags = []
118 |     masks = []
119 |     keys = []
120 |     if num_bags > 0:
121 |         files = np.random.choice(files, num_bags)
122 | 
123 |     for file in files:
124 |         pid = file.split("/")[-2]
125 |         if pid in WHITE_BALANCE_LST:
126 |             #print(pid)
127 |             file = file.replace("GuoZhongBingLiData", "GuoZhongBingLiDataWB")
128 |         # (B_H, B_W, C)
129 |         bag = np.asarray(Image.open(file).convert("RGB"))
130 |         if bag_height != bag.shape[0] or bag_width != bag.shape[1]:
131 |             bag = cv2.resize(bag, (bag_width, bag_height))
132 |         bags.append(bag)
133 | 
134 |         mask_file = file.replace("GuoZhongBingLiDataWB", "GuoZhongBingLiData").replace('-img', '-mask')
135 |         # mask_file = file.replace('-img', '-mask')
136 |         mask = np.asarray(Image.open(mask_file).convert('L'))
137 | 
138 |         mask[mask==1] = 1
139 |         mask[mask==2] = 1
140 |         mask[mask==3] = 1
141 |         mask[mask==4] = 2
142 |         mask[mask==5] = 3
143 | 
144 |         if bag_height != mask.shape[0] or bag_width != mask.shape[1]:
145 |             mask = cv2.resize(mask, (bag_width, bag_height))
146 |         masks.append(mask)
147 | 
148 |         keys.append(os.path.basename(file).replace("-img.png", "").replace("-mask.png", ""))
149 |     if len(bags) == 0:
150 |         return
torch.Tensor(1)
151 |     bags = np.stack(bags, axis=0) # (N_B, B_H, B_W, C)
152 |     masks = np.stack(masks, axis=0) # (N_B, B_H, B_W)
153 |     keys = np.asarray(keys) # (N_B)
154 |     return bags, masks, keys
155 | 
156 | def bags_to_words(bags, word_height, word_width, pid):
157 |     # bags: (N_B, B_H, B_W, C)
158 |     num_bags, bag_height, bag_width, channel = bags.shape
159 |     # (N_B, B_H, B_W, C) -> (N_B, B_H, N_W, W_W, C)
160 |     words = np.reshape(bags, (num_bags, bag_height, -1, word_width, channel))
161 |     # (N_B, B_H, N_W, W_W, C) -> (N_B, N_W, B_H, W_W, C)
162 |     words = words.transpose(0, 2, 1, 3, 4)
163 |     # (N_B, N_W, B_H, W_W, C) -> (N_B, N_W*N_H, W_H, W_W, C)
164 |     words = np.reshape(words, (num_bags, -1, word_height, word_width, channel))
165 | 
166 |     words = normalize_words_np(words)
167 |     words_torch = torch.from_numpy(words).float()
168 |     del words
169 |     gc.collect()
170 | 
171 |     return words_torch
172 | 
173 | def masks_to_words(masks, word_height, word_width, pid):
174 |     # masks: (N_B, B_H, B_W)
175 |     num_bags, bag_height, bag_width = masks.shape
176 |     # (N_B, B_H, B_W) -> (N_B, B_H, N_W, W_W)
177 |     words = np.reshape(masks, (num_bags, bag_height, -1, word_width))
178 |     # (N_B, B_H, N_W, W_W) -> (N_B, N_W, B_H, W_W)
179 |     words = words.transpose(0, 2, 1, 3)
180 |     # (N_B, N_W, B_H, W_W) -> (N_B, N_W*N_H, W_H, W_W)
181 |     words = np.reshape(words, (num_bags, -1, word_height, word_width))
182 | 
183 |     words_torch = torch.from_numpy(words).long()
184 |     del words
185 | 
186 |     return words_torch
187 | 
188 | 
189 | # Imaging clinical features (seven features)
190 | ln_to_id = {
191 |     "stomach": 0,
192 |     "LN": 1,
193 |     "Liver": 2,
194 |     "Aden": 3,
195 |     "Soft": 4,
196 |     "Peritoneum": 5,
197 |     "Other": 6,
198 |     "Spleen": 7,
199 |     "Bone": 8
200 | }
201 | 
202 | def ccd_yx_stomach(pid_dir):
203 |     # Imaging: presence of a primary lesion
204 |     if os.path.isdir(pid_dir):
205 |         stomach_files = [file for file in glob(os.path.join(pid_dir, "v_stomach*.jpg")) if not file.endswith("_mask.jpg")]
206 |         if len(stomach_files):
207 |             return torch.FloatTensor([0.0, 1.0]) # primary lesion present
208 |         else:
209 |             return torch.FloatTensor([1.0, 0.0]) # no primary lesion
210 |     else:
211 |         return torch.FloatTensor([0.0, 0.0])
212 | 
213 | def ccd_yx_ln_num(pid_dir):
214 |     # Imaging: number of metastatic lesions
215 |     ln_num = [0.0] * 7
216 |     if os.path.isdir(pid_dir):
217 |         stomach_files = [file for file in glob(os.path.join(pid_dir, "v_stomach*.jpg")) if not file.endswith("_mask.jpg")]
218 |         files = [file for file in glob(os.path.join(pid_dir, "v_*.jpg")) if not file.endswith("_mask.jpg")]
219 |         ln_num[len(files)-len(stomach_files)] = 1.0
220 |         return torch.FloatTensor(ln_num)
221 |     else:
222 |         return torch.FloatTensor(ln_num)
223 | 
224 | def ccd_yx_zl_ln(x):
225 |     # Imaging: metastatic lymph nodes (0 = none, 1 = present but not fused, 2 = fused metastatic lymph nodes)
226 |     zl_ln = [0.0] * 3
227 |     if str(x) != str(np.nan):
228 |         zl_ln[int(x)] = 1.0
229 |     return torch.FloatTensor(zl_ln)
230 | 
231 | def ccd_yx_zl_ln_pos(x):
232 |     # Imaging: location of metastatic lymph nodes (0 = none, 1 = regional, 2 = M-stage abdominal or retroperitoneal, 3 = mediastinal or supraclavicular, 4 = other rare distant sites such as axillary, paracervical, or inguinal regions)
233 |     zl_ln_pos = [0.0] * 5
234 |     if str(x) != str(np.nan):
235 |         for i in str(int(x)):
236 |             zl_ln_pos[int(i)] = 1.0
237 |     return torch.FloatTensor(zl_ln_pos)
238 | 
239 | def ccd_yx_zl_multi(x):
240 |     # Imaging: multiple liver or lung metastases, where "multiple" means >=3 lesions (0 = neither, 1 = multiple liver only, 2 = multiple lung only, 3 = both)
241 |     zl_multi = [0.0] * 4
242 |     if str(x) != str(np.nan):
243 |         zl_multi[int(x)] = 1.0
244 |     return torch.FloatTensor(zl_multi)
245 | 
246 | def ccd_yx_zl_per(x):
247 |     # Imaging: peritoneal metastasis (0 = absent, 1 = present)
248 |     zl_per = [0.0] * 2
249 |     if str(x) != str(np.nan):
250 |         zl_per[int(x)] = 1.0
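    # Added doctest-style sketch (not part of the repository): each ccd_*
    # helper one-hot encodes one clinical variable and returns all zeros for
    # missing (NaN) values, e.g. for peritoneal metastasis:
    # >>> ccd_yx_zl_per(1)
    # tensor([0., 1.])
    # >>> ccd_yx_zl_per(np.nan)
    # tensor([0., 0.])
251 |     return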
torch.FloatTensor(zl_per)
252 | 
253 | def ccd_bl_tils(x):
254 |     # Pathology: tumor-infiltrating lymphocytes (TILs)
255 |     tils = [0.0] * 10
256 |     if str(x) != str(np.nan):
257 |         tils[int(x*100)//10] = 1.0
258 |     return torch.FloatTensor(tils)
259 | 
260 | def ccd_bl_her2(x):
261 |     # Pathology: HER2 expression heterogeneity
262 |     her2 = [0.0] * 4
263 |     if str(x) != str(np.nan):
264 |         if x.startswith("无"):
265 |             her2[int(x[3:4])] = 1.0
266 |         elif x.startswith("异质性"):
267 |             for t in x[4:-1].split(","):
268 |                 try:
269 |                     her2[int(t[:1])] = 1.0*int(t[2:-1])/100.0
270 |                 except:
271 |                     pass
272 |         else:
273 |             raise NotImplementedError
274 |     return torch.FloatTensor(her2)
275 | 
276 | def ccd_bl_tumor(x):
277 |     tumor = [0.0] * 10
278 |     if str(x) != str(np.nan):
279 |         tumor[int(x*100)//10] = 1.0
280 |     return torch.FloatTensor(tumor)
281 | 
282 | # Patient-level clinical features (four features)
283 | 
284 | def ccd_sex(x):
285 |     # Patient: sex
286 |     x = str(x).strip()
287 |     if x in ["男"]:
288 |         return torch.FloatTensor([0.0, 1.0])
289 |     elif x in ["女"]:
290 |         return torch.FloatTensor([1.0, 0.0])
291 |     else:
292 |         raise NotImplementedError
293 | 
294 | def ccd_age(x, split_age=60):
295 |     # Patient: age
296 |     if x <= split_age:
297 |         return torch.FloatTensor([1.0, 0.0])
298 |     elif x > split_age:
299 |         return torch.FloatTensor([0.0, 1.0])
300 |     else:
301 |         raise NotImplementedError
302 | 
303 | def ccd_buwei(x):
304 |     # Patient: tumor location
305 |     x = str(x).strip()
306 |     if x in ["GEJ"]:
307 |         return torch.FloatTensor([1.0, 0.0])
308 |     elif x in ["non-GJE"]:
309 |         return torch.FloatTensor([0.0, 1.0])
310 |     else:
311 |         raise NotImplementedError
312 | 
313 | def ccd_xianshu(x):
314 |     # Patient: line of therapy
315 |     x = str(x).strip()
316 |     if "一线" in x:
317 |         return torch.FloatTensor([1.0, 0.0])
318 |     else:
319 |         return torch.FloatTensor([0.0, 1.0])
320 | 
321 | def ccd_time(x):
322 |     # Patient: treatment start time
323 |     if int(x) >= 2007 and int(x) < 2012:
324 |         return torch.FloatTensor([1.0, 0.0, 0.0])
325 |     elif int(x) >= 2012 and int(x) < 2017:
326 |         return torch.FloatTensor([0.0, 1.0, 0.0])
327 |     elif int(x) >= 2017 and int(x) <= 2023:
328 |         return torch.FloatTensor([0.0, 0.0, 1.0])
329 |     else:
330 |         raise NotImplementedError
331 | 
332 | def ccd_fenxing(x):
333 |     # Pathology: Lauren classification
334 |     if x == "肠型":
335 |         return torch.FloatTensor([1.0, 0.0, 0.0])
336 |     elif x == "弥漫型":
337 |         return torch.FloatTensor([0.0, 1.0, 0.0])
338 |     elif x == "混合型":
339 |         return torch.FloatTensor([0.0, 0.0, 1.0])
340 |     elif str(x) == str(np.nan):
341 |         return torch.FloatTensor([0.0, 0.0, 0.0])
342 |     else:
343 |         raise NotImplementedError
344 | 
345 | def ccd_fenhua(x):
346 |     # Pathology: degree of differentiation
347 |     if x in ["低分化"]:
348 |         return torch.FloatTensor([1.0, 0.0, 0.0, 0.0])
349 |     elif x in ["中分化"]:
350 |         return torch.FloatTensor([0.0, 1.0, 0.0, 0.0])
351 |     elif x in ["高分化"]:
352 |         return torch.FloatTensor([0.0, 0.0, 1.0, 0.0])
353 |     elif x in ["弥漫型"]:
354 |         return torch.FloatTensor([0.0, 0.0, 0.0, 1.0])
355 |     elif str(x) == str(np.nan):
356 |         return torch.FloatTensor([0.0, 0.0, 0.0, 0.0])
357 |     else:
358 |         return torch.FloatTensor([0.0, 0.0, 0.0, 0.0])
359 |         raise NotImplementedError
360 | 
361 | def ccd_ln_dis(x):
362 |     # Imaging: lesion distribution
363 |     x = str(x)
364 |     ln_dis = [0.0] * 13
365 |     for i in range(len(x)):
366 |         ln_dis[i] = float(int(x[i]))
367 |     return torch.FloatTensor(ln_dis)
368 | 
369 | def ccd_yx_flag(x):
370 |     yx_flag = [0.0] * 3
371 |     if x >= 0: yx_flag[x] = 1.0
372 |     return torch.FloatTensor(yx_flag)
373 | 
374 | def ccd_bl_flag(x):
375 |     bl_flag = [0.0] * 3
376 |     if x >= 0: bl_flag[x] = 1.0
377 |     return torch.FloatTensor(bl_flag)
378 | 
379 | def convert_label(label_type, OS, OSCensor, liaoxiao, PFS=None,
PFSCensor=None): 380 | if label_type == "ORR": 381 | if liaoxiao in ["CR", "PR"]: 382 | return 1 383 | elif liaoxiao in ["SD", "PD"]: 384 | return 0 385 | else: 386 | return -1 387 | elif label_type == "ORR_OS180": 388 | if liaoxiao in ["CR", "PR"]: 389 | return 1 390 | elif liaoxiao in ["PD",]: 391 | return 0 392 | else: 393 | cutoff = 180 # about six months 394 | if OS > cutoff: 395 | return 1 396 | elif OSCensor == 1 and OS <= cutoff: 397 | return 0 398 | else: 399 | return -1 400 | elif label_type == "ORR_PFS240": 401 | if liaoxiao in ["CR", "PR"]: 402 | return 1 403 | elif liaoxiao in ["PD",]: 404 | return 0 405 | else: 406 | cutoff = 240 # about eight months 407 | if PFS > cutoff: 408 | return 1 409 | elif PFSCensor == 1 and PFS <= cutoff: 410 | return 0 411 | else: 412 | return -1 413 | elif label_type == "ORR_PFS300": 414 | if liaoxiao in ["CR", "PR"]: 415 | return 1 416 | elif liaoxiao in ["PD",]: 417 | return 0 418 | else: 419 | cutoff = 300 # about ten months 420 | if PFS > cutoff: 421 | return 1 422 | elif PFSCensor == 1 and PFS <= cutoff: 423 | return 0 424 | else: 425 | return -1 426 | elif label_type == "PFS240": 427 | cutoff = 240 # about eight months 428 | if PFS > cutoff: 429 | return 1 430 | elif PFSCensor == 1 and PFS <= cutoff: 431 | return 0 432 | else: 433 | return -1 434 | elif label_type == "PFS300": 435 | cutoff = 300 # about ten months 436 | if PFS > cutoff: 437 | return 1 438 | elif PFSCensor == 1 and PFS <= cutoff: 439 | return 0 440 | else: 441 | return -1 442 | else: 443 | raise NotImplementedError -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from config.opts import get_config 3 | from train_and_eval import Trainer 4 | from utils.utils import * 5 | import os 6 | from utils.print_utils import * 7 | import json 8 | 9 | if __name__ == '__main__': 10 | opts, parser = get_config() 11 | torch.set_default_dtype(torch.float32) 12 | 13 | logger = build_logging(os.path.join(opts.save_dir, "log.log")) 14 | printer = logger.info 15 | 16 | print_log_message('Arguments', printer) 17 | printer(json.dumps(vars(opts), indent=4, sort_keys=True)) 18 | 19 | trainer = Trainer(opts=opts, printer=printer) 20 | trainer.run() -------------------------------------------------------------------------------- /model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/model/.DS_Store -------------------------------------------------------------------------------- /model/bl_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from model.nn_layers.ffn import FFN 5 | from model.nn_layers.attn_layers import * 6 | from typing import NamedTuple, Optional 7 | from torch import Tensor 8 | import math 9 | from utils.print_utils import * 10 | import os 11 | import numpy as np 12 | import torch.nn.functional as F 13 | import random 14 | from torchvision import models 15 | from model.feature_extractors.mnasnet import MNASNet 16 | from model.nn_layers.transformer import * 17 | from tqdm import tqdm 18 | import os 19 | 20 | class BLModel(nn.Module): 21 | def __init__(self, opts, *args, **kwargs): 22 | super().__init__() 23 | self.opts = opts 24 | 25 | if opts.bl_cnn_name == "mnasnet": 26 | s = 1.0 27 | weight = 
'checkpoints/mnasnet_s_1.0_imagenet_224x224.pth'
28 |             backbone = MNASNet(alpha=s)
29 |             pretrained_dict = torch.load(weight, map_location=torch.device('cpu'))
30 |             backbone.load_state_dict(pretrained_dict)
31 |             del backbone.classifier
32 |             self.cnn = backbone
33 |         else:
34 |             backbone = eval(f"models.{opts.bl_cnn_name}")(pretrained=opts.bl_cnn_pretrained)
35 |             # if opts.bl_cnn_name == "mnasnet1_0":
36 |             if "mnasnet" in opts.bl_cnn_name:
37 |                 self.cnn = nn.Sequential(*[*list(backbone.children())[:-1], nn.AdaptiveAvgPool2d(1)])
38 |             else:
39 |                 self.cnn = nn.Sequential(*list(backbone.children())[:-1])
40 | 
41 |         self.attn_layer = nn.Conv2d(opts.bl_cnn_features+3, 1, kernel_size=1, padding=0, stride=1)
42 | 
43 |         self.project_words = nn.Linear(opts.bl_cnn_features, opts.bl_out_features)
44 | 
45 |         self.attn_over_words = nn.MultiheadAttention(embed_dim=opts.bl_out_features,
46 |             num_heads=1, dropout=opts.bl_dropout, batch_first=True)
47 |         self.attn_over_bags = nn.MultiheadAttention(embed_dim=opts.bl_out_features,
48 |             num_heads=1, dropout=opts.bl_dropout, batch_first=True)
49 |         self.attn_dropout = nn.Dropout(p=opts.bl_attn_dropout)
50 | 
51 |         self.bags_classifier = nn.Linear(opts.bl_out_features, 4)
52 |         self.words_classifier = nn.Linear(opts.bl_out_features, 4)
53 | 
54 |         self.ffn_attn_w2b_lst = FFN(input_dim=opts.bl_out_features, scale=2, p=opts.bl_dropout)
55 |         self.ffn_b2i = FFN(input_dim=opts.bl_out_features, scale=2, p=opts.bl_dropout)
56 | 
57 |         self.attn_fn = nn.Softmax(dim=-1)
58 | 
59 |         self.words_weight_fn = nn.Linear(opts.bl_out_features, 1, bias=False)
60 |         self.bags_weight_fn = nn.Linear(opts.bl_out_features, 1, bias=False)
61 | 
62 |         # Fuse clinical features
63 |         # self.fenxing_fc = nn.Linear(3, opts.bl_out_features)
64 |         # self.fenhua_fc = nn.Linear(4, opts.bl_out_features)
65 |         self.tils_fc = nn.Linear(10, opts.bl_out_features)
66 |         self.her2_fc = nn.Linear(4, opts.bl_out_features)
67 |         self.tumor_fc = nn.Linear(10, opts.bl_out_features)
68 |         self.clinical_image_attn = nn.MultiheadAttention(embed_dim=opts.bl_out_features, num_heads=1, dropout=opts.bl_dropout, batch_first=True)
69 | 
70 |         # Fuse radiomics features
71 |         self.radiomics_fc = nn.Linear(736, opts.bl_out_features)
72 |         self.radiomics_image_attn = nn.MultiheadAttention(embed_dim=opts.bl_out_features, num_heads=1, dropout=opts.bl_dropout, batch_first=True)
73 | 
74 |         #self.bl_classifier = nn.Linear(opts.bl_out_features, opts.n_classes)
75 | 
76 |     def energy_function(self, x, weight_fn, need_attn=False):
77 |         # x: (B, N, C)
78 |         x = weight_fn(x).squeeze(dim=-1) # (B, N)
79 |         energy: Optional[Tensor] = None
80 |         if need_attn:
81 |             energy = x
82 |         x = self.attn_fn(x)
83 |         x = self.attn_dropout(x)
84 |         x = x.unsqueeze(dim=-1) # (B, N, 1)
85 |         return x, energy
86 | 
87 |     def parallel_radiomics_clinical(self, image_from_bags, radiomics_feat, clinical_feat):
88 |         radiomics_image_feat, radiomics_attnmap = self.radiomics_image_attn(key=radiomics_feat,
89 |             query=image_from_bags.unsqueeze(dim=1), value=radiomics_feat)
90 |         clinical_image_feat, clinical_attnmap = self.clinical_image_attn(key=clinical_feat,
91 |             query=image_from_bags.unsqueeze(dim=1), value=clinical_feat)
92 |         image_from_bags = image_from_bags \
93 |             + clinical_image_feat.squeeze(dim=1) \
94 |             + radiomics_image_feat.squeeze(dim=1)
95 |         self.info_dict["clinical_weight"] = clinical_attnmap[0, 0].detach().cpu().numpy()
96 |         return image_from_bags
97 | 
98 |     def series_radiomics_clinical(self, image_from_bags, radiomics_feat, clinical_feat):
99 |         radiomics_image_feat, _ =
self.radiomics_image_attn(key=radiomics_feat,
100 |             query=image_from_bags.unsqueeze(dim=1), value=radiomics_feat)
101 |         image_from_bags = image_from_bags + radiomics_image_feat.squeeze(dim=1)
102 |         clinical_image_feat, _ = self.clinical_image_attn(key=clinical_feat,
103 |             query=image_from_bags.unsqueeze(dim=1), value=clinical_feat)
104 |         image_from_bags = image_from_bags + clinical_image_feat.squeeze(dim=1)
105 |         return image_from_bags
106 | 
107 |     def series_clinical_radiomics(self, image_from_bags, radiomics_feat, clinical_feat):
108 |         clinical_image_feat, _ = self.clinical_image_attn(key=clinical_feat,
109 |             query=image_from_bags.unsqueeze(dim=1), value=clinical_feat)
110 |         image_from_bags = image_from_bags + clinical_image_feat.squeeze(dim=1)
111 |         radiomics_image_feat, _ = self.radiomics_image_attn(key=radiomics_feat,
112 |             query=image_from_bags.unsqueeze(dim=1), value=radiomics_feat)
113 |         image_from_bags = image_from_bags + radiomics_image_feat.squeeze(dim=1)
114 |         return image_from_bags
115 | 
116 |     def forward(self, batch, *args, **kwargs):
117 |         # if not self.training:
118 |         #     return self.incremental_inference(batch)
119 |         words = batch['feat_words']
120 |         bl_flag = batch["clinical_bl_flag"] # (B, 3,)
121 | 
122 |         # STEP1: Project CNN encoded words
123 |         # (B, N_b, N_w, C, H_w, W_w) --> (B, N_b, N_w, d)
124 |         B, N_b, N_w, C, H, W = words.shape
125 |         words_cnn = self.cnn(words.view(B*N_b*N_w, C, H, W))
126 |         bl_flag = bl_flag.view(B, 1, 1, 3, 1, 1).repeat(1, N_b, N_w, 1, *words_cnn.shape[-2:])
127 |         bl_flag = bl_flag.view(B*N_b*N_w, 3, *words_cnn.shape[-2:])
128 |         attn_words_cnn = torch.cat([words_cnn, bl_flag], dim=1)
129 |         attn_words_cnn = torch.sigmoid(self.attn_layer(attn_words_cnn))
130 |         words_cnn = words_cnn * attn_words_cnn
131 |         words_cnn = words_cnn.view(B, N_b, N_w, -1)
132 |         words_cnn = self.project_words(words_cnn)
133 | 
134 |         self.info_dict = {
135 |             "id": batch["id"][0],
136 |             "bags_label": batch["mask_bags_label"][0].detach().cpu().numpy(),
137 |             "words_label": batch["mask_words_label"][0].detach().cpu().numpy(),
138 |         }
139 | 
140 |         # STEP2: Words to Bags (Attn words | CNN words)
141 |         words_cnn = words_cnn.view(B*N_b, N_w, -1)
142 | 
143 |         # import os
144 |         # feat_dir = "../Analysis/BLFeat"
145 |         # os.makedirs(feat_dir, exist_ok=True)
146 |         # np.save(os.path.join(feat_dir, self.info_dict["id"]+".npz"), words_cnn.view(B, N_b, N_w, -1).detach().cpu().numpy())
147 | 
148 |         words_attn, words_attnmap = self.attn_over_words(key=words_cnn, query=words_cnn, value=words_cnn)
149 |         words_attn = words_attn.view(B, N_b, N_w, -1)
150 |         self.info_dict["words_attnmap"] = words_attnmap[0].detach().cpu().numpy()
151 |         words_attn_energy, words_attn_energy_unnorm = self.energy_function(words_attn, self.words_weight_fn)
152 |         # (B, N_B, N_W, C) * (B, N_B, N_W, 1) --> (B, N_B, C)
153 |         self.info_dict["words_weight"] = words_attn_energy[0, ..., 0].detach().cpu().numpy()
154 |         bags_from_attn_words = torch.sum(words_attn * words_attn_energy, dim=-2)
155 |         bags_from_attn_words = self.ffn_attn_w2b_lst(bags_from_attn_words)
156 | 
157 |         mask_words_pred = self.words_classifier(words_attn) # (B, N_B, N_W, 4)
158 | 
159 |         # STEP3: Bags to Image
160 |         bags_attn, bags_attnmap = self.attn_over_bags(key=bags_from_attn_words, query=bags_from_attn_words, value=bags_from_attn_words)
161 |         self.info_dict["bags_attnmap"] = bags_attnmap[0].detach().cpu().numpy()
162 |         bags_energy, bags_energy_unnorm = self.energy_function(bags_attn, self.bags_weight_fn)
163 |         self.info_dict["bags_weight"] = bags_energy[0, ...,
0].detach().cpu().numpy()
164 |         image_from_bags = torch.sum(bags_attn * bags_energy, dim=-2)
165 |         image_from_bags = self.ffn_b2i(image_from_bags)
166 | 
167 |         mask_bags_pred = self.bags_classifier(bags_attn) # (B, N_B, 4)
168 | 
169 |         clinical_feat = torch.stack([
170 |             self.tils_fc(batch["clinical_bl_tils"]),
171 |             self.her2_fc(batch["clinical_bl_her2"]),
172 |             self.tumor_fc(batch["clinical_bl_tumor"]),
173 |         ], dim=1) # (B, 3, C)
174 | 
175 |         radiomics_feat = self.radiomics_fc(batch["bl_radiomics_feat"]) # (B, M, C)
176 | 
177 |         if self.opts.feat_fusion_mode == "parallel":
178 |             image_from_bags = self.parallel_radiomics_clinical(image_from_bags, radiomics_feat, clinical_feat)
179 |         elif self.opts.feat_fusion_mode == "series_rc":
180 |             image_from_bags = self.series_radiomics_clinical(image_from_bags, radiomics_feat, clinical_feat)
181 |         elif self.opts.feat_fusion_mode == "series_cr":
182 |             image_from_bags = self.series_clinical_radiomics(image_from_bags, radiomics_feat, clinical_feat)
183 |         else:
184 |             raise NotImplementedError
185 | 
186 |         if self.opts.attnmap_weight_dir not in [None, "None"]:
187 |             npz_dir = os.path.join(self.opts.attnmap_weight_dir, "BL")
188 |             os.makedirs(npz_dir, exist_ok=True)
189 |             if batch["bl_flag"][0].item():
190 |                 np.savez(os.path.join(npz_dir, self.info_dict["id"]+".npz"),
191 |                     bags_attnmap=self.info_dict["bags_attnmap"],
192 |                     bags_weight=self.info_dict["bags_weight"],
193 |                     words_attnmap=self.info_dict["words_attnmap"],
194 |                     words_weight=self.info_dict["words_weight"],
195 |                     clinical_weight=self.info_dict["clinical_weight"],
196 |                     bags_label=self.info_dict["bags_label"],
197 |                     words_label=self.info_dict["words_label"])
198 | 
199 |         return {
200 |             "feat": image_from_bags,
201 |             "mask_words_pred": mask_words_pred,
202 |             "mask_bags_pred": mask_bags_pred,
203 |             #"bl_pred": self.bl_classifier(image_from_bags)
204 |         }
205 | 
206 |     def incremental_inference(self, batch, max_bags_gpu0=64, *args, **kwargs):
207 |         words = batch['feat_words']
208 | 
209 |         # STEP1: Project CNN encoded words
210 |         # (B, N_b, N_w, C, H_w, W_w) --> (B, N_b, N_w, d)
211 |         B, N_b, N_w, C, H, W = words.shape
212 |         words = words.view(B*N_b*N_w, C, H, W)
213 |         words_cnn = []
214 |         N = words.shape[0]
215 |         indexes = np.arange(0, N, max_bags_gpu0)
216 |         for i in range(len(indexes)):
217 |             start = indexes[i]
218 |             if i < len(indexes) - 1:
219 |                 end = indexes[i+1]
220 |                 words_batch = words[start:end]
221 |             else:
222 |                 words_batch = words[start:]
223 |             words_cnn.append(self.cnn(words_batch))
224 |         words_cnn = torch.cat(words_cnn, dim=0)
225 |         words_cnn = words_cnn.view(B, N_b, N_w, -1)
226 |         words_cnn = self.project_words(words_cnn)
227 | 
228 |         self.info_dict = {
229 |             "id": batch["id"][0],
230 |             "bags_label": batch["mask_bags_label"][0].detach().cpu().numpy(),
231 |             "words_label": batch["mask_words_label"][0].detach().cpu().numpy(),
232 |         }
233 | 
234 |         # STEP2: Words to Bags (Attn words | CNN words)
235 |         words_cnn = words_cnn.view(B*N_b, N_w, -1)
236 |         words_attn = []
237 |         words_attnmap = []
238 |         N = words_cnn.shape[0]
239 |         indexes = np.arange(0, N, max_bags_gpu0)
240 |         for i in range(len(indexes)):
241 |             start = indexes[i]
242 |             if i < len(indexes) - 1:
243 |                 end = indexes[i+1]
244 |                 words_cnn_batch = words_cnn[start:end]
245 |             else:
246 |                 words_cnn_batch = words_cnn[start:]
247 | 
248 |             words_attn_batch, words_attnmap_batch = self.attn_over_words(key=words_cnn_batch, query=words_cnn_batch, value=words_cnn_batch)
249 |             words_attn.append(words_attn_batch)
250 |             words_attnmap.append(words_attnmap_batch)
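        # Added note (sketch, not part of the repository): the loop above
        # bounds GPU memory by attending over at most max_bags_gpu0 items per
        # call; np.arange yields the chunk starts, e.g. for N=10, chunk size 4:
        # >>> np.arange(0, 10, 4)
        # array([0, 4, 8])
        # i.e. slices [0:4], [4:8] and [8:].
251 |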
words_attn = torch.cat(words_attn, dim=0).view(B, N_b, N_w, -1)
252 |         words_attnmap = torch.cat(words_attnmap, dim=0).view(B, N_b, N_w, N_w)
253 |         self.info_dict["words_attnmap"] = words_attnmap[0].detach().cpu().numpy()
254 |         words_attn_energy, words_attn_energy_unnorm = self.energy_function(words_attn, self.words_weight_fn)
255 |         # (B, N_B, N_W, C) * (B, N_B, N_W, 1) --> (B, N_B, C)
256 |         self.info_dict["words_weight"] = words_attn_energy[0, ..., 0].detach().cpu().numpy()
257 |         bags_from_attn_words = torch.sum(words_attn * words_attn_energy, dim=-2)
258 |         bags_from_attn_words = self.ffn_attn_w2b_lst(bags_from_attn_words)
259 | 
260 |         mask_words_pred = self.words_classifier(words_attn) # (B, N_B, N_W, 4)
261 | 
262 |         # STEP3: Bags to Image
263 |         bags_attn, bags_attnmap = self.attn_over_bags(key=bags_from_attn_words, query=bags_from_attn_words, value=bags_from_attn_words)
264 |         self.info_dict["bags_attnmap"] = bags_attnmap[0].detach().cpu().numpy()
265 |         bags_energy, bags_energy_unnorm = self.energy_function(bags_attn, self.bags_weight_fn)
266 |         self.info_dict["bags_weight"] = bags_energy[0, ..., 0].detach().cpu().numpy()
267 |         image_from_bags = torch.sum(bags_attn * bags_energy, dim=-2)
268 |         image_from_bags = self.ffn_b2i(image_from_bags)
269 | 
270 |         mask_bags_pred = self.bags_classifier(bags_attn) # (B, N_B, 4)
271 | 
272 |         clinical_feat = torch.stack([
273 |             self.tils_fc(batch["clinical_bl_tils"]),
274 |             self.her2_fc(batch["clinical_bl_her2"]),
275 |             self.tumor_fc(batch["clinical_bl_tumor"]),
276 |         ], dim=1) # (B, 3, C)
277 | 
278 |         radiomics_feat = self.radiomics_fc(batch["bl_radiomics_feat"]) # (B, M, C)
279 | 
280 |         if self.opts.feat_fusion_mode == "parallel":
281 |             image_from_bags = self.parallel_radiomics_clinical(image_from_bags, radiomics_feat, clinical_feat)
282 |         elif self.opts.feat_fusion_mode == "series_rc":
283 |             image_from_bags = self.series_radiomics_clinical(image_from_bags, radiomics_feat, clinical_feat)
284 |         elif self.opts.feat_fusion_mode == "series_cr":
285 |             image_from_bags = self.series_clinical_radiomics(image_from_bags, radiomics_feat, clinical_feat)
286 |         else:
287 |             raise NotImplementedError
288 | 
289 | 
290 |         npz_dir = "../Analysis/Attnmap_weight/BL"
291 |         os.makedirs(npz_dir, exist_ok=True)
292 |         np.savez(os.path.join(npz_dir, self.info_dict["id"]+".npz"),
293 |             bags_attnmap=self.info_dict["bags_attnmap"],
294 |             bags_weight=self.info_dict["bags_weight"],
295 |             bags_label=self.info_dict["bags_label"],
296 |             words_attnmap=self.info_dict["words_attnmap"],
297 |             words_weight=self.info_dict["words_weight"],
298 |             words_label=self.info_dict["words_label"],
299 |             clinical_weight=self.info_dict["clinical_weight"])
300 | 
301 |         return {
302 |             "feat": image_from_bags,
303 |             "mask_words_pred": mask_words_pred,
304 |             "mask_bags_pred": mask_bags_pred,
305 |             #"bl_pred": self.bl_classifier(image_from_bags)
306 |         }
307 | 
308 | 
309 | class BLModel_OnlyHer2(nn.Module):
310 |     def __init__(self, opts, *args, **kwargs):
311 |         super().__init__()
312 |         self.opts = opts
313 |         self.her2_fc = nn.Linear(4, opts.bl_out_features)
314 | 
315 |     def forward(self, batch, *args, **kwargs):
316 |         return {
317 |             "feat": self.her2_fc(batch["clinical_bl_her2"]),
318 |             "mask_words_pred": None,
319 |             "mask_bags_pred": None,
320 |         }
321 | 
322 | 
323 | class BLModelOnly(BLModel):
324 |     def __init__(self, opts, *args, **kwargs):
325 |         super().__init__(opts, *args, **kwargs)
326 |         self.classifier = nn.Linear(opts.bl_out_features, opts.n_classes)
327 | 
328 |     def forward(self, batch, *args, **kwargs):
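        # Added sketch (not part of the repository): BLModelOnly puts a linear
        # head on top of BLModel's fused image-level feature, mapping a
        # BLYXDataset batch dict to class logits (opts/batch are placeholders
        # for a config/opts.py namespace and a dataloader batch):
        # >>> model = BLModelOnly(opts)
        # >>> out = model(batch)
        # >>> out["pred"].shape    # (B, opts.n_classes)
329 |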
bl_results = super().forward(batch, *args, **kwargs)
330 |         bl_pred = self.classifier(bl_results["feat"])
331 |         return {
332 |             "pred": bl_pred,
333 |             "feat": bl_results["feat"],
334 |             "mask_words_pred": bl_results["mask_words_pred"],
335 |             "mask_bags_pred": bl_results["mask_bags_pred"],
336 |         }
337 | 
338 | 
339 | class BLModelNoneClinicalRadiomics(BLModel):
340 |     def forward(self, batch, *args, **kwargs):
341 |         # if not self.training:
342 |         #     return self.incremental_inference(batch)
343 |         words = batch['feat_words']
344 |         bl_flag = batch["clinical_bl_flag"] # (B, 3,)
345 | 
346 |         # STEP1: Project CNN encoded words
347 |         # (B, N_b, N_w, C, H_w, W_w) --> (B, N_b, N_w, d)
348 |         B, N_b, N_w, C, H, W = words.shape
349 |         words_cnn = self.cnn(words.view(B*N_b*N_w, C, H, W))
350 |         bl_flag = bl_flag.view(B, 1, 1, 3, 1, 1).repeat(1, N_b, N_w, 1, *words_cnn.shape[-2:])
351 |         bl_flag = bl_flag.view(B*N_b*N_w, 3, *words_cnn.shape[-2:])
352 |         attn_words_cnn = torch.cat([words_cnn, bl_flag], dim=1)
353 |         attn_words_cnn = torch.sigmoid(self.attn_layer(attn_words_cnn))
354 |         words_cnn = words_cnn * attn_words_cnn
355 |         words_cnn = words_cnn.view(B, N_b, N_w, -1)
356 |         words_cnn = self.project_words(words_cnn)
357 | 
358 |         # STEP2: Words to Bags (Attn words | CNN words)
359 |         words_cnn = words_cnn.view(B*N_b, N_w, -1)
360 | 
361 |         words_attn, words_attnmap = self.attn_over_words(key=words_cnn, query=words_cnn, value=words_cnn)
362 |         words_attn = words_attn.view(B, N_b, N_w, -1)
363 |         words_attn_energy, words_attn_energy_unnorm = self.energy_function(words_attn, self.words_weight_fn)
364 |         # (B, N_B, N_W, C) * (B, N_B, N_W, 1) --> (B, N_B, C)
365 |         bags_from_attn_words = torch.sum(words_attn * words_attn_energy, dim=-2)
366 |         bags_from_attn_words = self.ffn_attn_w2b_lst(bags_from_attn_words)
367 | 
368 |         mask_words_pred = self.words_classifier(words_attn) # (B, N_B, N_W, 4)
369 | 
370 |         # STEP3: Bags to Image
371 |         bags_attn, bags_attnmap = self.attn_over_bags(key=bags_from_attn_words, query=bags_from_attn_words, value=bags_from_attn_words)
372 |         bags_energy, bags_energy_unnorm = self.energy_function(bags_attn, self.bags_weight_fn)
373 |         image_from_bags = torch.sum(bags_attn * bags_energy, dim=-2)
374 |         image_from_bags = self.ffn_b2i(image_from_bags)
375 | 
376 |         mask_bags_pred = self.bags_classifier(bags_attn) # (B, N_B, 4)
377 | 
378 |         return {
379 |             "feat": image_from_bags,
380 |             "mask_words_pred": mask_words_pred,
381 |             "mask_bags_pred": mask_bags_pred,
382 |             #"bl_pred": self.bl_classifier(image_from_bags)
383 |         }
--------------------------------------------------------------------------------
/model/feature_extractors/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/model/feature_extractors/.DS_Store
--------------------------------------------------------------------------------
/model/feature_extractors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/model/feature_extractors/__init__.py
--------------------------------------------------------------------------------
/model/feature_extractors/mnasnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | '''
8 | The code is taken from the official PyTorch repo
9 |
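Added doctest-style sketch (not part of the original file), illustrating the
depth-scaling helpers defined below:
>>> _round_to_multiple_of(83, 8)
80
>>> _round_to_multiple_of(84, 8)
88
>>> _get_depths(0.5)
[16, 8, 16, 24, 40, 48, 96, 160]

10 |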
https://github.com/pytorch/vision/blob/master/torchvision/models/mnasnet.py
11 | '''
12 | 
13 | 
14 | # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
15 | # 1.0 - tensorflow.
16 | _BN_MOMENTUM = 1 - 0.9997
17 | 
18 | 
19 | class _InvertedResidual(nn.Module):
20 | 
21 |     def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor,
22 |                  bn_momentum=0.1):
23 |         super(_InvertedResidual, self).__init__()
24 |         assert stride in [1, 2]
25 |         assert kernel_size in [3, 5]
26 |         mid_ch = in_ch * expansion_factor
27 |         self.apply_residual = (in_ch == out_ch and stride == 1)
28 |         self.layers = nn.Sequential(
29 |             # Pointwise
30 |             nn.Conv2d(in_ch, mid_ch, 1, bias=False),
31 |             nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
32 |             nn.ReLU(inplace=True),
33 |             # Depthwise
34 |             nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=kernel_size // 2,
35 |                       stride=stride, groups=mid_ch, bias=False),
36 |             nn.BatchNorm2d(mid_ch, momentum=bn_momentum),
37 |             nn.ReLU(inplace=True),
38 |             # Linear pointwise. Note that there's no activation.
39 |             nn.Conv2d(mid_ch, out_ch, 1, bias=False),
40 |             nn.BatchNorm2d(out_ch, momentum=bn_momentum))
41 | 
42 |     def forward(self, input):
43 |         if self.apply_residual:
44 |             return self.layers(input) + input
45 |         else:
46 |             return self.layers(input)
47 | 
48 | 
49 | def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats,
50 |            bn_momentum):
51 |     """ Creates a stack of inverted residuals. """
52 |     assert repeats >= 1
53 |     # First one has no skip, because feature map size changes.
54 |     first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor,
55 |                               bn_momentum=bn_momentum)
56 |     remaining = []
57 |     for _ in range(1, repeats):
58 |         remaining.append(
59 |             _InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor,
60 |                               bn_momentum=bn_momentum))
61 |     return nn.Sequential(first, *remaining)
62 | 
63 | 
64 | def _round_to_multiple_of(val, divisor, round_up_bias=0.9):
65 |     """ Asymmetric rounding to make `val` divisible by `divisor`. With default
66 |     bias, will round up, unless the number is no more than 10% greater than the
67 |     smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88. """
68 |     assert 0.0 < round_up_bias < 1.0
69 |     new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
70 |     return new_val if new_val >= round_up_bias * val else new_val + divisor
71 | 
72 | 
73 | def _get_depths(alpha):
74 |     """ Scales tensor depths as in reference MobileNet code, prefers rounding up
75 |     rather than down. """
76 |     depths = [32, 16, 24, 40, 80, 96, 192, 320]
77 |     return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
78 | 
79 | 
80 | class MNASNet(torch.nn.Module):
81 |     """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
82 |     implements the B1 variant of the model; forward() below returns pooled features.
83 |     >>> model = MNASNet(alpha=1.0)
84 |     >>> x = torch.rand(1, 3, 224, 224)
85 |     >>> y = model(x)
86 |     >>> y.dim()
87 |     2
88 |     >>> y.nelement()
89 |     1280
90 |     """
91 |     # Version 2 adds depth scaling in the initial stages of the network.
92 |     _version = 2
93 | 
94 |     def __init__(self, alpha, num_classes=1000, dropout=0.2):
95 |         super(MNASNet, self).__init__()
96 |         assert alpha > 0.0
97 |         self.alpha = alpha
98 |         self.num_classes = num_classes
99 |         depths = _get_depths(alpha)
100 |         layers = [
101 |             # First layer: regular conv.
102 |             nn.Conv2d(3, depths[0], 3, padding=1, stride=2, bias=False),
103 |             nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM),
104 |             nn.ReLU(inplace=True),
105 |             # Depthwise separable, no skip.
106 | nn.Conv2d(depths[0], depths[0], 3, padding=1, stride=1, 107 | groups=depths[0], bias=False), 108 | nn.BatchNorm2d(depths[0], momentum=_BN_MOMENTUM), 109 | nn.ReLU(inplace=True), 110 | nn.Conv2d(depths[0], depths[1], 1, padding=0, stride=1, bias=False), 111 | nn.BatchNorm2d(depths[1], momentum=_BN_MOMENTUM), 112 | # MNASNet blocks: stacks of inverted residuals. 113 | _stack(depths[1], depths[2], 3, 2, 3, 3, _BN_MOMENTUM), 114 | _stack(depths[2], depths[3], 5, 2, 3, 3, _BN_MOMENTUM), 115 | _stack(depths[3], depths[4], 5, 2, 6, 3, _BN_MOMENTUM), 116 | _stack(depths[4], depths[5], 3, 1, 6, 2, _BN_MOMENTUM), 117 | _stack(depths[5], depths[6], 5, 2, 6, 4, _BN_MOMENTUM), 118 | _stack(depths[6], depths[7], 3, 1, 6, 1, _BN_MOMENTUM), 119 | # Final mapping to classifier input. 120 | nn.Conv2d(depths[7], 1280, 1, padding=0, stride=1, bias=False), 121 | nn.BatchNorm2d(1280, momentum=_BN_MOMENTUM), 122 | nn.ReLU(inplace=True), 123 | ] 124 | self.layers = nn.Sequential(*layers) 125 | self.classifier = nn.Sequential(nn.Dropout(p=dropout, inplace=True), 126 | nn.Linear(1280, num_classes)) 127 | self._initialize_weights() 128 | self.last_channel = 1280 129 | 130 | def forward(self, x): 131 | x = self.layers(x) 132 | # Equivalent to global avgpool and removing H and W dimensions. 133 | x = x.mean([2, 3]) 134 | return x 135 | # return self.classifier(x) 136 | 137 | def _initialize_weights(self): 138 | for m in self.modules(): 139 | if isinstance(m, nn.Conv2d): 140 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 141 | nonlinearity="relu") 142 | if m.bias is not None: 143 | nn.init.zeros_(m.bias) 144 | elif isinstance(m, nn.BatchNorm2d): 145 | nn.init.ones_(m.weight) 146 | nn.init.zeros_(m.bias) 147 | elif isinstance(m, nn.Linear): 148 | nn.init.kaiming_uniform_(m.weight, mode="fan_out", 149 | nonlinearity="sigmoid") 150 | nn.init.zeros_(m.bias) 151 | 152 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 153 | missing_keys, unexpected_keys, error_msgs): 154 | version = local_metadata.get("version", None) 155 | assert version in [1, 2] 156 | 157 | if version == 1 and not self.alpha == 1.0: 158 | # In the initial version of the model (v1), stem was fixed-size. 159 | # All other layer configurations were the same. This will patch 160 | # the model so that it's identical to v1. Model with alpha 1.0 is 161 | # unaffected. 162 | depths = _get_depths(self.alpha) 163 | v1_stem = [ 164 | nn.Conv2d(3, 32, 3, padding=1, stride=2, bias=False), 165 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 166 | nn.ReLU(inplace=True), 167 | nn.Conv2d(32, 32, 3, padding=1, stride=1, groups=32, 168 | bias=False), 169 | nn.BatchNorm2d(32, momentum=_BN_MOMENTUM), 170 | nn.ReLU(inplace=True), 171 | nn.Conv2d(32, 16, 1, padding=0, stride=1, bias=False), 172 | nn.BatchNorm2d(16, momentum=_BN_MOMENTUM), 173 | _stack(16, depths[2], 3, 2, 3, 3, _BN_MOMENTUM), 174 | ] 175 | for idx, layer in enumerate(v1_stem): 176 | self.layers[idx] = layer 177 | 178 | # The model is now identical to v1, and must be saved as such. 179 | self._version = 1 180 | warnings.warn( 181 | "A new version of MNASNet model has been implemented. " 182 | "Your checkpoint was saved using the previous version. 
" 183 | "This checkpoint will load and work as before, but " 184 | "you may want to upgrade by training a newer model or " 185 | "transfer learning from an updated ImageNet checkpoint.", 186 | UserWarning) 187 | 188 | super(MNASNet, self)._load_from_state_dict( 189 | state_dict, prefix, local_metadata, strict, missing_keys, 190 | unexpected_keys, error_msgs) 191 | 192 | 193 | if __name__ == '__main__': 194 | model = MNASNet(alpha=1) 195 | print(model) -------------------------------------------------------------------------------- /model/nn_layers/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/model/nn_layers/.DS_Store -------------------------------------------------------------------------------- /model/nn_layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/model/nn_layers/__init__.py -------------------------------------------------------------------------------- /model/nn_layers/attn_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from model.nn_layers.ffn import FFN 4 | from model.nn_layers.multi_head_attn import MultiHeadAttn 5 | 6 | 7 | ''' 8 | Adapted from OpenNMT-Py 9 | ''' 10 | 11 | class SelfAttention(nn.Module): 12 | ''' 13 | This class implements the transformer block with multi-head attention and Feed forward network 14 | ''' 15 | def __init__(self, in_dim, num_heads=8, p=0.1, *args, **kwargs): 16 | super(SelfAttention, self).__init__() 17 | self.self_attn = MultiHeadAttn(input_dim=in_dim, out_dim=in_dim, num_heads=num_heads) 18 | 19 | self.ffn = FFN(in_dim, scale=4, p=p, expansion=True) 20 | 21 | self.layer_norm_1 = nn.LayerNorm(in_dim, eps=1e-6) 22 | self.drop = nn.Dropout(p=p) 23 | 24 | def forward(self, x, need_attn=False): 25 | ''' 26 | :param x: Input (bags or words) 27 | :param need_attn: Need attention weights or not 28 | :return: returns the self attention output and attention weights (optional) 29 | ''' 30 | x_norm = self.layer_norm_1(x) 31 | 32 | context, attn = self.self_attn(x_norm, x_norm, x_norm, need_attn=need_attn) 33 | 34 | out = self.drop(context) + x 35 | return self.ffn(out), attn 36 | 37 | class MultiWaySelfAttention(nn.Module): 38 | ''' 39 | This class implements the transformer block with multi-head attention and Feed forward network 40 | ''' 41 | def __init__(self, in_dim, num_heads=8, p=0.1, num_way=4, *args, **kwargs): 42 | super(MultiWaySelfAttention, self).__init__() 43 | self.self_attn = MultiHeadAttn(input_dim=in_dim, out_dim=in_dim, num_heads=num_heads) 44 | 45 | self.way_cls_fn = nn.Linear(in_dim, num_way) 46 | self.ffn_lst = nn.ModuleList([FFN(in_dim, scale=4, p=p, expansion=True) for _ in range(num_way)]) 47 | 48 | self.layer_norm_1 = nn.LayerNorm(in_dim, eps=1e-6) 49 | self.drop = nn.Dropout(p=p) 50 | 51 | self.num_way = num_way 52 | 53 | def forward(self, x, w, need_attn=False): 54 | ''' 55 | :param x: Input (bags or words) 56 | :param need_attn: Need attention weights or not 57 | :return: returns the self attention output and attention weights (optional) 58 | ''' 59 | # x: (B, N, C) 60 | # w: (B, N) 61 | x_norm = self.layer_norm_1(x) 62 | 63 | context, attn = self.self_attn(x_norm, x_norm, x_norm, need_attn=need_attn) 64 | 65 | out = self.drop(context) + x 66 | 67 | ffn_out = None 68 | 
for ci in range(self.num_way): 69 | ci_ffn_out = self.ffn_lst[ci](out) 70 | ci_mask = (w == ci).float().unsqueeze(dim=-1) 71 | if ffn_out is None: 72 | ffn_out = ci_ffn_out 73 | else: 74 | ffn_out = ffn_out * (1.0 - ci_mask) + ci_ffn_out * ci_mask 75 | 76 | # if len(x.shape) == 3: # (B, N, C) 77 | # ffn_out = [] 78 | # for b in range(out.shape[0]): 79 | # b_ffn_out = [] 80 | # for n in range(out.shape[1]): 81 | # b_ffn_out.append(self.ffn_lst[int(w[b, n].item())](out[b:b+1,n:n+1])) 82 | # b_ffn_out = torch.cat(b_ffn_out, dim=1) 83 | # ffn_out.append(b_ffn_out) 84 | # ffn_out = torch.cat(ffn_out, dim=0) 85 | # elif len(x.shape) == 4: # (B, N1, N2, C) 86 | # ffn_out = [] 87 | # for b in range(out.shape[0]): 88 | # b_ffn_out = [] 89 | # for n1 in range(out.shape[1]): 90 | # b_n1_ffn_out = [] 91 | # for n2 in range(out.shape[2]): 92 | # b_n1_ffn_out.append(self.ffn_lst[int(w[b, n1, n2].item())](out[b:b+1,n1:n1+1,n2:n2+1])) 93 | # b_n1_ffn_out = torch.cat(b_n1_ffn_out, dim=2) 94 | # b_ffn_out.append(b_n1_ffn_out) 95 | # b_ffn_out = torch.cat(b_ffn_out, dim=1) 96 | # ffn_out.append(b_ffn_out) 97 | # ffn_out = torch.cat(ffn_out, dim=0) 98 | return ffn_out, self.way_cls_fn(out), attn 99 | 100 | 101 | class ContextualAttention(torch.nn.Module): 102 | ''' 103 | This class implements the contextual attention. 104 | For example, we used this class to compute bag-to-bag attention where 105 | one set of bag is directly from CNN, while the other set of bag is obtained after self-attention 106 | ''' 107 | def __init__(self, in_dim, num_heads=8, p=0.1, *args, **kwargs): 108 | super(ContextualAttention, self).__init__() 109 | self.self_attn = MultiHeadAttn(input_dim=in_dim, out_dim=in_dim, num_heads=num_heads) 110 | 111 | self.context_norm = nn.LayerNorm(in_dim) 112 | self.context_attn = MultiHeadAttn(input_dim=in_dim, out_dim=in_dim, num_heads=num_heads) 113 | self.ffn = FFN(in_dim, scale=4, p=p, expansion=True) 114 | 115 | self.input_norm = nn.LayerNorm(in_dim, eps=1e-6) 116 | self.query_norm = nn.LayerNorm(in_dim, eps=1e-6) 117 | self.drop = nn.Dropout(p=p) 118 | 119 | def forward(self, input, context, need_attn=False): 120 | ''' 121 | :param input: Tensor of shape (B x N_b x N_w x CNN_DIM) or (B x N_b x CNN_DIM) 122 | :param context: Tensor of shape (B x N_b x N_w x hist_dim) or (B x N_b x hist_dim) 123 | :return: 124 | ''' 125 | 126 | # Self attention on Input features 127 | input_norm = self.input_norm(input) 128 | query, _ = self.self_attn(input_norm, input_norm, input_norm, need_attn=need_attn) 129 | query = self.drop(query) + input 130 | query_norm = self.query_norm(query) 131 | 132 | # Contextual attention 133 | context_norm = self.context_norm(context) 134 | mid, contextual_attn = self.context_attn(context_norm, context_norm, query_norm, need_attn= need_attn) 135 | output = self.ffn(self.drop(mid) + input) 136 | 137 | return output, contextual_attn 138 | 139 | class MultiWayContextualAttention(torch.nn.Module): 140 | ''' 141 | This class implements the contextual attention. 
142 | For example, we used this class to compute bag-to-bag attention where 143 | one set of bag is directly from CNN, while the other set of bag is obtained after self-attention 144 | ''' 145 | def __init__(self, in_dim, num_heads=8, p=0.1, num_way=6, *args, **kwargs): 146 | super(MultiWayContextualAttention, self).__init__() 147 | self.self_attn = MultiHeadAttn(input_dim=in_dim, out_dim=in_dim, num_heads=num_heads) 148 | 149 | self.context_norm = nn.LayerNorm(in_dim) 150 | self.context_attn = MultiHeadAttn(input_dim=in_dim, out_dim=in_dim, num_heads=num_heads) 151 | 152 | self.way_cls_fn = nn.Linear(in_dim, num_way) 153 | self.ffn_lst = nn.ModuleList([FFN(in_dim, scale=4, p=p, expansion=True) for _ in range(num_way)]) 154 | 155 | self.input_norm = nn.LayerNorm(in_dim, eps=1e-6) 156 | self.query_norm = nn.LayerNorm(in_dim, eps=1e-6) 157 | self.drop = nn.Dropout(p=p) 158 | 159 | self.num_way = num_way 160 | 161 | def forward(self, input, context, w, need_attn=False): 162 | ''' 163 | :param input: Tensor of shape (B x N_b x N_w x CNN_DIM) or (B x N_b x CNN_DIM) 164 | :param context: Tensor of shape (B x N_b x N_w x hist_dim) or (B x N_b x hist_dim) 165 | :return: 166 | ''' 167 | 168 | # Self attention on Input features 169 | input_norm = self.input_norm(input) 170 | query, _ = self.self_attn(input_norm, input_norm, input_norm, need_attn=need_attn) 171 | query = self.drop(query) + input 172 | query_norm = self.query_norm(query) 173 | 174 | # Contextual attention 175 | context_norm = self.context_norm(context) 176 | mid, contextual_attn = self.context_attn(context_norm, context_norm, query_norm, need_attn= need_attn) 177 | 178 | #output = self.ffn(self.drop(mid) + input) 179 | tmp = self.drop(mid) + input 180 | 181 | output = None 182 | for ci in range(self.num_way): 183 | ci_output = self.ffn_lst[ci](tmp) # (B, N, C) 184 | ci_mask = (w == ci).float().unsqueeze(dim=-1) # (B, N, 1) 185 | if output is None: 186 | output = ci_output 187 | else: 188 | output = output * (1.0 - ci_mask) + ci_output * ci_mask 189 | 190 | # output = [] 191 | # for b in range(tmp.shape[0]): 192 | # b_output = [] 193 | # for n in range(tmp.shape[1]): 194 | # b_output.append(self.ffn_lst[int(w[b, n].item())](tmp[b:b+1,n:n+1])) 195 | # b_output = torch.cat(b_output, dim=1) 196 | # output.append(b_output) 197 | # output = torch.cat(output, dim=0) 198 | 199 | return output, self.way_cls_fn(tmp), contextual_attn 200 | -------------------------------------------------------------------------------- /model/nn_layers/eesp.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from model.nn_layers.espnet_utils import * 3 | import torch 4 | from model.feature_extractors import espnetv2_config as config 5 | 6 | ''' 7 | Adapted from https://github.com/sacmehta/EdgeNets 8 | 9 | This file implements ESPNetv2 blocks 10 | ''' 11 | 12 | config_inp_reinf = config.config_inp_reinf 13 | 14 | class EESP(nn.Module): 15 | ''' 16 | This class defines the EESP block, which is based on the following principle 17 | REDUCE ---> SPLIT ---> TRANSFORM --> MERGE 18 | ''' 19 | 20 | def __init__(self, nIn, nOut, stride=1, k=4, r_lim=7, down_method='esp'): #down_method --> ['avg' or 'esp'] 21 | ''' 22 | :param nIn: number of input channels 23 | :param nOut: number of output channels 24 | :param stride: factor by which we should skip (useful for down-sampling). 
If 2, then down-samples the feature map by 2 25 | :param k: # of parallel branches 26 | :param r_lim: A maximum value of receptive field allowed for EESP block 27 | :param down_method: Downsample or not (equivalent to saying whether stride is 2 or not) 28 | ''' 29 | super().__init__() 30 | self.stride = stride 31 | n = int(nOut / k) 32 | n1 = nOut - (k - 1) * n 33 | assert down_method in ['avg', 'esp'], 'One of these is supported (avg or esp)' 34 | assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1) 35 | self.proj_1x1 = CBR(nIn, n, 1, stride=1, groups=k) 36 | 37 | # (For convenience) Mapping between dilation rate and receptive field for a 3x3 kernel 38 | map_receptive_ksize = {3: 1, 5: 2, 7: 3, 9: 4, 11: 5, 13: 6, 15: 7, 17: 8} 39 | self.k_sizes = list() 40 | for i in range(k): 41 | ksize = int(3 + 2 * i) 42 | # After reaching the receptive field limit, fall back to the base kernel size of 3 with a dilation rate of 1 43 | ksize = ksize if ksize <= r_lim else 3 44 | self.k_sizes.append(ksize) 45 | # sort (in ascending order) these kernel sizes based on their receptive field 46 | # This enables us to ignore the kernels (3x3 in our case) with the same effective receptive field in hierarchical 47 | # feature fusion, because kernels with a 3x3 effective receptive field do not produce gridding artifacts. 48 | self.k_sizes.sort() 49 | self.spp_dw = nn.ModuleList() 50 | for i in range(k): 51 | d_rate = map_receptive_ksize[self.k_sizes[i]] 52 | self.spp_dw.append(CDilated(n, n, kSize=3, stride=stride, groups=n, d=d_rate)) 53 | # Performing a group convolution with K groups is the same as performing K point-wise convolutions 54 | self.conv_1x1_exp = CB(nOut, nOut, 1, 1, groups=k) 55 | self.br_after_cat = BR(nOut) 56 | self.module_act = nn.PReLU(nOut) 57 | self.downAvg = True if down_method == 'avg' else False 58 | 59 | def forward(self, input): 60 | ''' 61 | :param input: input feature map 62 | :return: transformed feature map 63 | ''' 64 | 65 | # Reduce --> project high-dimensional feature maps to low-dimensional space 66 | output1 = self.proj_1x1(input) 67 | output = [self.spp_dw[0](output1)] 68 | # compute the output for each branch and hierarchically fuse them 69 | # i.e. Split --> Transform --> HFF 70 | for k in range(1, len(self.spp_dw)): 71 | out_k = self.spp_dw[k](output1) 72 | # HFF 73 | out_k = out_k + output[k - 1] 74 | output.append(out_k) 75 | # Merge 76 | expanded = self.conv_1x1_exp( # learn linear combinations using group point-wise convolutions 77 | self.br_after_cat( # apply batch normalization followed by activation function (PRelu in this case) 78 | torch.cat(output, 1) # concatenate the output of different branches 79 | ) 80 | ) 81 | del output 82 | # if down-sampling, then return the concatenated vector 83 | # because Downsampling function will combine it with avg. pooled feature map and then threshold it 84 | if self.stride == 2 and self.downAvg: 85 | return expanded 86 | 87 | # if dimensions of input and concatenated vector are the same, add them (RESIDUAL LINK) 88 | if expanded.size() == input.size(): 89 | expanded = expanded + input 90 | 91 | # Threshold the feature map using activation function (PReLU in this case) 92 | return self.module_act(expanded) 93 | 94 | 95 | class DownSampler(nn.Module): 96 | ''' 97 | Down-sampling function that has three parallel branches: (1) avg pooling, 98 | (2) EESP block with stride of 2 and (3) efficient long-range connection with the input. 
99 | The output feature maps of branches from (1) and (2) are concatenated and then additively fused with (3) to produce 100 | the final output. 101 | ''' 102 | 103 | def __init__(self, nin, nout, k=4, r_lim=9, reinf=True): 104 | ''' 105 | :param nin: number of input channels 106 | :param nout: number of output channels 107 | :param k: # of parallel branches 108 | :param r_lim: A maximum value of receptive field allowed for EESP block 109 | :param reinf: Use long range shortcut connection with the input or not. 110 | ''' 111 | super().__init__() 112 | nout_new = nout - nin 113 | self.eesp = EESP(nin, nout_new, stride=2, k=k, r_lim=r_lim, down_method='avg') 114 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2) 115 | if reinf: 116 | self.inp_reinf = nn.Sequential( 117 | CBR(config_inp_reinf, config_inp_reinf, 3, 1), 118 | CB(config_inp_reinf, nout, 1, 1) 119 | ) 120 | self.act = nn.PReLU(nout) 121 | 122 | def forward(self, input, input2=None): 123 | ''' 124 | :param input: input feature map 125 | :return: feature map down-sampled by a factor of 2 126 | ''' 127 | avg_out = self.avg(input) 128 | eesp_out = self.eesp(input) 129 | output = torch.cat([avg_out, eesp_out], 1) 130 | 131 | if input2 is not None: 132 | #assuming the input is a square image 133 | # Shortcut connection with the input image 134 | w1 = avg_out.size(2) 135 | while True: 136 | input2 = F.avg_pool2d(input2, kernel_size=3, padding=1, stride=2) 137 | w2 = input2.size(2) 138 | if w2 == w1: 139 | break 140 | output = output + self.inp_reinf(input2) 141 | 142 | return self.act(output) -------------------------------------------------------------------------------- /model/nn_layers/espnet_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | ''' 4 | Adapted from https://github.com/sacmehta/EdgeNets 5 | ''' 6 | 7 | class CBR(nn.Module): 8 | ''' 9 | This class defines the convolution layer with batch normalization and PReLU activation 10 | ''' 11 | 12 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 13 | ''' 14 | 15 | :param nIn: number of input channels 16 | :param nOut: number of output channels 17 | :param kSize: kernel size 18 | :param stride: stride rate for down-sampling. 
Default is 1 19 | ''' 20 | super().__init__() 21 | padding = int((kSize - 1) / 2) 22 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, groups=groups) 23 | self.bn = nn.BatchNorm2d(nOut) 24 | self.act = nn.PReLU(nOut) 25 | 26 | def forward(self, input): 27 | ''' 28 | :param input: input feature map 29 | :return: transformed feature map 30 | ''' 31 | output = self.conv(input) 32 | # output = self.conv1(output) 33 | output = self.bn(output) 34 | output = self.act(output) 35 | return output 36 | 37 | 38 | class BR(nn.Module): 39 | ''' 40 | This class groups the batch normalization and PReLU activation 41 | ''' 42 | 43 | def __init__(self, nOut): 44 | ''' 45 | :param nOut: output feature maps 46 | ''' 47 | super().__init__() 48 | self.bn = nn.BatchNorm2d(nOut) 49 | self.act = nn.PReLU(nOut) 50 | 51 | def forward(self, input): 52 | ''' 53 | :param input: input feature map 54 | :return: normalized and thresholded feature map 55 | ''' 56 | output = self.bn(input) 57 | output = self.act(output) 58 | return output 59 | 60 | 61 | class CB(nn.Module): 62 | ''' 63 | This class groups the convolution and batch normalization 64 | ''' 65 | 66 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 67 | ''' 68 | :param nIn: number of input channels 69 | :param nOut: number of output channels 70 | :param kSize: kernel size 71 | :param stride: optional stride for down-sampling 72 | ''' 73 | super().__init__() 74 | padding = int((kSize - 1) / 2) 75 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 76 | groups=groups) 77 | self.bn = nn.BatchNorm2d(nOut) 78 | 79 | def forward(self, input): 80 | ''' 81 | 82 | :param input: input feature map 83 | :return: transformed feature map 84 | ''' 85 | output = self.conv(input) 86 | output = self.bn(output) 87 | return output 88 | 89 | 90 | class C(nn.Module): 91 | ''' 92 | This class is for a convolutional layer. 93 | ''' 94 | 95 | def __init__(self, nIn, nOut, kSize, stride=1, groups=1): 96 | ''' 97 | 98 | :param nIn: number of input channels 99 | :param nOut: number of output channels 100 | :param kSize: kernel size 101 | :param stride: optional stride rate for down-sampling 102 | ''' 103 | super().__init__() 104 | padding = int((kSize - 1) / 2) 105 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 106 | groups=groups) 107 | 108 | def forward(self, input): 109 | ''' 110 | :param input: input feature map 111 | :return: transformed feature map 112 | ''' 113 | output = self.conv(input) 114 | return output 115 | 116 | 117 | class CDilated(nn.Module): 118 | ''' 119 | This class defines the dilated convolution. 120 | ''' 121 | 122 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1): 123 | ''' 124 | :param nIn: number of input channels 125 | :param nOut: number of output channels 126 | :param kSize: kernel size 127 | :param stride: optional stride rate for down-sampling 128 | :param d: optional dilation rate 129 | ''' 130 | super().__init__() 131 | padding = int((kSize - 1) / 2) * d 132 | self.conv = nn.Conv2d(nIn, nOut, kSize, stride=stride, padding=padding, bias=False, 133 | dilation=d, groups=groups) 134 | 135 | def forward(self, input): 136 | ''' 137 | :param input: input feature map 138 | :return: transformed feature map 139 | ''' 140 | output = self.conv(input) 141 | return output 142 | 143 | class CDilatedB(nn.Module): 144 | ''' 145 | This class defines the dilated convolution with batch normalization. 
146 | ''' 147 | 148 | def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1): 149 | ''' 150 | :param nIn: number of input channels 151 | :param nOut: number of output channels 152 | :param kSize: kernel size 153 | :param stride: optional stride rate for down-sampling 154 | :param d: optional dilation rate 155 | ''' 156 | super().__init__() 157 | padding = int((kSize - 1) / 2) * d 158 | self.conv = nn.Conv2d(nIn, nOut,kSize, stride=stride, padding=padding, bias=False, 159 | dilation=d, groups=groups) 160 | self.bn = nn.BatchNorm2d(nOut) 161 | 162 | def forward(self, input): 163 | ''' 164 | :param input: input feature map 165 | :return: transformed feature map 166 | ''' 167 | return self.bn(self.conv(input)) 168 | -------------------------------------------------------------------------------- /model/nn_layers/ffn.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This file implements the Feed forward network 3 | Adapted from OpenNMT-Py 4 | ''' 5 | 6 | from torch import nn 7 | 8 | class FFN(nn.Module): 9 | def __init__(self, input_dim, scale, output_dim=None, p=0.1, expansion=False): 10 | super(FFN, self).__init__() 11 | output_dim = input_dim if output_dim is None else output_dim 12 | 13 | proj_features = input_dim * scale if expansion else input_dim // scale 14 | self.w_1 = nn.Linear(input_dim, proj_features) 15 | self.w_2 = nn.Linear(proj_features, output_dim) 16 | 17 | self.layer_norm = nn.LayerNorm(input_dim, eps=1e-6) 18 | self.dropout_1 = nn.Dropout(p) 19 | self.relu = nn.ReLU() 20 | self.dropout_2 = nn.Dropout(p) 21 | self.residual = True if input_dim == output_dim else False 22 | 23 | def forward(self, x): 24 | """Layer definition. 25 | Args: 26 | x: ``(batch_size, input_len, model_dim)`` 27 | Returns: 28 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 
29 | """ 30 | 31 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 32 | output = self.dropout_2(self.w_2(inter)) 33 | return output + x if self.residual else output -------------------------------------------------------------------------------- /model/nn_layers/multi_head_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import numpy as np 5 | from torch import Tensor 6 | from typing import Optional 7 | 8 | ''' 9 | This file implements the Multi-head attention 10 | Adapted from OpenNMT-Py 11 | ''' 12 | 13 | 14 | class MultiHeadAttn(torch.nn.Module): 15 | def __init__(self, input_dim, out_dim, num_heads=8, dropout=0.1, *args, **kwargs): 16 | super(MultiHeadAttn, self).__init__() 17 | assert input_dim % num_heads == 0 18 | 19 | self.num_heads = num_heads 20 | self.dim_per_head = input_dim // num_heads 21 | 22 | self.linear_keys = nn.Linear(input_dim, num_heads * self.dim_per_head) 23 | self.linear_values = nn.Linear(input_dim, num_heads * self.dim_per_head) 24 | self.linear_query = nn.Linear(input_dim, num_heads * self.dim_per_head) 25 | 26 | self.softmax = nn.Softmax(dim=-1) 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | self.final_linear = nn.Linear(input_dim, out_dim) 30 | 31 | self.scaling_factor = math.sqrt(self.dim_per_head) 32 | 33 | def forward(self, key, value, query, need_attn=False): 34 | ''' 35 | :param key: A tensor of shape [B x N_b x N_w x d] or [B x N_b x d] 36 | :param value: A tensor of shape [B x N_b x N_w x d] or [B x N_b x d] 37 | :param query: A tensor of shape [B x N_b x N_w x d] or [B x N_b x d] 38 | :param need_attn: Need attention weights or not 39 | :return: Tuple containing output and mean attention scores across all heads (optional) 40 | Output size is [B x N_b x N_w x d'] or [B x N_b x d'] 41 | Attention score size is [B x N_b*N_w x N_b*N_w] or [B x N_b x N_b] 42 | ''' 43 | dim_size = key.size() 44 | reshape=False 45 | if key.dim() == 4: 46 | # [B x N_b x N_w x d] --> [B x N_b*N_w x d] 47 | key = key.view(dim_size[0], -1, dim_size[3]) 48 | value = key.view(dim_size[0], -1, dim_size[3]) 49 | query = key.view(dim_size[0], -1, dim_size[3]) 50 | reshape = True 51 | 52 | batch_size = key.size(0) 53 | 54 | dim_per_head = self.dim_per_head 55 | head_count = self.num_heads 56 | 57 | key = self.linear_keys(key) 58 | value = self.linear_values(value) 59 | query = self.linear_query(query) 60 | 61 | query = query / self.scaling_factor 62 | 63 | # [B x N_b*N_w x d] --> [B x N_b*N_w x h x d_h] --> [B x h x N_b*N_w x d_h] 64 | key = ( 65 | key.contiguous() 66 | .view(batch_size, -1, head_count, dim_per_head) 67 | .transpose(1, 2) 68 | ) 69 | 70 | value = ( 71 | value.contiguous() 72 | .view(batch_size, -1, head_count, dim_per_head) 73 | .transpose(1, 2) 74 | ) 75 | 76 | query = ( 77 | query.contiguous() 78 | .view(batch_size, -1, head_count, dim_per_head) 79 | .transpose(1, 2) 80 | ) 81 | 82 | # compute attention scores 83 | # [B x h x N_b*N_w x d_h] x [B x h x d_h x N_b*N_w] --> [B x h x N_b*N_w x N_b*N_w] 84 | scores = torch.matmul(query, key.transpose(2, 3)).float() 85 | 86 | attn = self.softmax(scores).to(query.dtype) 87 | drop_attn = self.dropout(attn) 88 | 89 | # [B x h x N_b*N_w x N_b*N_w] x [B x h x N_b*N_w x d_h] --> [B x h x N_b*N_w x d_h] 90 | context = torch.matmul(drop_attn, value) 91 | 92 | # [B x h x N_b*N_w x d_h] --> [B x N_b*N_w x h x d_h] --> [B x N_b*N_w x h*d_h] 93 | context = ( 94 | context.transpose(1, 2) 95 | 
.contiguous().view(batch_size, -1, head_count * dim_per_head) 96 | ) 97 | 98 | output = self.final_linear(context) 99 | 100 | attn_scores: Optional[Tensor] = None 101 | if need_attn: 102 | attn_scores = torch.mean(scores, dim=1) 103 | 104 | if reshape: 105 | # [B x N_b*N_w x d] --> [B x N_b x N_w x d] 106 | output = output.contiguous().view(dim_size[0], dim_size[1], dim_size[2], -1).contiguous() 107 | 108 | return output, attn_scores -------------------------------------------------------------------------------- /model/nn_layers/transformer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torchvision.models as models 4 | import os 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from torch.nn import init 8 | 9 | def split_last(x, shape): 10 | "split the last dimension to given shape" 11 | shape = list(shape) 12 | assert shape.count(-1) <= 1 13 | if -1 in shape: 14 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 15 | return x.view(*x.size()[:-1], *shape) 16 | 17 | def merge_last(x, n_dims): 18 | "merge the last n_dims to a dimension" 19 | s = x.size() 20 | assert n_dims > 1 and n_dims < len(s) 21 | return x.view(*s[:-n_dims], -1) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, out_dim, dropout): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.LayerNorm(dim), 28 | nn.Linear(dim, out_dim), 29 | nn.GELU(), 30 | nn.Dropout(dropout), 31 | nn.Linear(out_dim, out_dim), 32 | nn.Dropout(dropout), 33 | ) 34 | 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | class MultiHeadedSelfAttention(nn.Module): 39 | """Multi-Headed Dot Product Attention""" 40 | def __init__(self, feat_dim, dim, num_heads, dropout): 41 | super().__init__() 42 | self.proj_q = nn.Linear(feat_dim, dim) 43 | self.proj_k = nn.Linear(feat_dim, dim) 44 | self.proj_v = nn.Linear(feat_dim, dim) 45 | self.drop = nn.Dropout(dropout) 46 | self.n_heads = num_heads 47 | self.scores = None # for visualization 48 | 49 | def forward(self, x, mask): 50 | """ 51 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 52 | mask : (B(batch_size) x S(seq_len)) 53 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 54 | """ 55 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 56 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 57 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v]) 58 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 59 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 60 | #scores = torch.cosine_similarity(q[:, :, :, None, :], q[:, :, None, :, :], dim=-1) 61 | if mask is not None: 62 | mask = mask[:, None, None, :].float() 63 | scores -= 10000.0 * (1.0 - mask) 64 | attn_map = F.softmax(scores, dim=-1) 65 | scores = self.drop(attn_map) 66 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 67 | h = (scores @ v).transpose(1, 2).contiguous() 68 | # -merge-> (B, S, D) 69 | h = merge_last(h, 2) 70 | self.scores = attn_map 71 | return h 72 | 73 | class Block(nn.Module): 74 | """Transformer Block (post-norm: LayerNorm is applied after each sublayer, then the residual is added)""" 75 | def __init__(self, dim, num_heads, dropout): 76 | super().__init__() 77 | self.attn = MultiHeadedSelfAttention(dim, dim, num_heads, dropout) 78 | self.attn_norm = nn.LayerNorm(dim) 79 | self.mlp = nn.Linear(dim, dim) 80 | self.mlp_norm = nn.LayerNorm(dim) 81 | self.drop = nn.Dropout(dropout) 82 | 83 | def forward(self, feat_x, mask): 84
| attn_x = self.attn(feat_x, mask) 85 | attn_x = self.attn_norm(attn_x) 86 | attn_x = attn_x + feat_x 87 | mlp_x = self.mlp(attn_x) 88 | mlp_x = self.mlp_norm(mlp_x) 89 | mlp_x = self.drop(mlp_x) 90 | out_x = mlp_x + attn_x 91 | return out_x 92 | 93 | class SimpleTransformer(nn.Module): 94 | def __init__(self, 95 | in_dim, 96 | num_head, 97 | dropout, 98 | num_attn, 99 | merge_token=False): 100 | super().__init__() 101 | self.merge_token = merge_token 102 | if self.merge_token: 103 | self.token = nn.Parameter(torch.zeros(1, 1, in_dim).float()) 104 | self.pe_token = nn.Parameter(torch.zeros(1, 1, in_dim).float()) 105 | else: 106 | self.weight_fc = nn.Linear(in_dim, 1, bias=True) 107 | self.weight = None 108 | 109 | self.attn_layer_lst = nn.ModuleList([ 110 | Block(in_dim, num_head, dropout) for _ in range(num_attn) 111 | ]) 112 | 113 | self.init_weights() 114 | 115 | @torch.no_grad() 116 | def init_weights(self): 117 | def _init(m): 118 | if isinstance(m, nn.Linear): 119 | nn.init.xavier_uniform_(m.weight) 120 | if hasattr(m, 'bias') and m.bias is not None: 121 | nn.init.normal_(m.bias, std=1e-6) # nn.init.constant(m.bias, 0) 122 | self.apply(_init) 123 | if self.merge_token: 124 | nn.init.constant_(self.token, 0) 125 | nn.init.constant_(self.pe_token, 0) 126 | 127 | def forward(self, x, mask, pe): 128 | # x: (B, T, C) 129 | # mask: (B, T) 130 | if self.merge_token: 131 | x = torch.cat([self.token.expand(x.shape[0], 1, -1).to(x.device), x], dim=1) 132 | mask = torch.cat([torch.ones(mask.shape[0], 1).float().to(mask.device), mask], dim=1) 133 | if pe is not None: 134 | pe = torch.cat([self.pe_token.expand(pe.shape[0], 1, -1).to(pe.device), pe], dim=1) 135 | for attn_layer in self.attn_layer_lst: 136 | if pe is not None: 137 | x = x + pe 138 | x = attn_layer(x, mask) 139 | if self.merge_token: 140 | return x[:, 0] 141 | else: 142 | return x -------------------------------------------------------------------------------- /model/yx_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from model.nn_layers.ffn import FFN 5 | from model.nn_layers.attn_layers import * 6 | from typing import NamedTuple, Optional 7 | from torch import Tensor 8 | from torchvision import models 9 | import torch.nn.functional as F 10 | from copy import deepcopy 11 | from model.feature_extractors.mnasnet import MNASNet 12 | from model.nn_layers.transformer import * 13 | 14 | class YXModel(nn.Module): 15 | def __init__(self, opts, *args, **kwargs): 16 | super().__init__() 17 | self.opts = opts 18 | 19 | if opts.yx_cnn_name == "mnasnet": 20 | s = 1.0 21 | weight = 'checkpoints/mnasnet_s_1.0_imagenet_224x224.pth' 22 | backbone = MNASNet(alpha=s) 23 | pretrained_dict = torch.load(weight, map_location=torch.device('cpu')) 24 | backbone.load_state_dict(pretrained_dict) 25 | del backbone.classifier 26 | self.cnn = backbone 27 | else: 28 | backbone = eval(f"models.{opts.yx_cnn_name}")(pretrained=opts.yx_cnn_pretrained) 29 | # if opts.yx_cnn_name == "mnasnet1_0": 30 | if "mnasnet" in opts.yx_cnn_name: 31 | self.cnn = nn.Sequential(*[*list(backbone.children())[:-1], nn.AdaptiveAvgPool2d(1)]) 32 | else: 33 | self.cnn = nn.Sequential(*list(backbone.children())[:-1]) 34 | 35 | self.attn_layer = nn.Conv2d(opts.yx_cnn_features+3, 1, kernel_size=1, padding=0, stride=1) 36 | 37 | self.cnn_project = nn.Linear(opts.yx_cnn_features, opts.yx_out_features) 38 | 39 | self.attn_over_lesions = 
nn.MultiheadAttention(embed_dim=opts.yx_out_features, 40 | num_heads=1, dropout=opts.yx_dropout, batch_first=True) 41 | self.ffn_attn_l2p = FFN(input_dim=opts.yx_out_features, scale=2, p=opts.yx_dropout) 42 | self.lesions_weight_fn = nn.Linear(opts.yx_out_features, 1, bias=False) 43 | self.attn_dropout = nn.Dropout(p=opts.yx_attn_dropout) 44 | self.attn_fn = nn.Softmax(dim=-1) 45 | 46 | self.lesions_classifier = nn.Linear(opts.yx_out_features, 5) 47 | 48 | # Incorporate clinical features 49 | self.stomach_fc = nn.Linear(2, opts.yx_out_features) 50 | # self.ln_dis_fc = nn.Linear(9, opts.yx_out_features) 51 | self.ln_num_fc = nn.Linear(7, opts.yx_out_features) 52 | self.zl_ln_fc = nn.Linear(3, opts.yx_out_features) 53 | self.zl_ln_pos_fc = nn.Linear(5, opts.yx_out_features) 54 | self.zl_multi_fc = nn.Linear(4, opts.yx_out_features) 55 | self.zl_per_fc = nn.Linear(2, opts.yx_out_features) 56 | self.clinical_image_attn = nn.MultiheadAttention(embed_dim=opts.yx_out_features, num_heads=1, dropout=opts.yx_dropout, batch_first=True) 57 | 58 | # Incorporate radiomics features 59 | self.radiomics_fc = nn.Linear(736, opts.yx_out_features) 60 | self.radiomics_image_attn = nn.MultiheadAttention(embed_dim=opts.yx_out_features, num_heads=1, dropout=opts.yx_dropout, batch_first=True) 61 | 62 | #self.yx_classifier = nn.Linear(opts.yx_out_features, opts.n_classes) 63 | 64 | def energy_function(self, x, weight_fn, need_attn=False): 65 | # x: (B, N, C) 66 | x = weight_fn(x).squeeze(dim=-1) # (B, N) 67 | energy: Optional[Tensor] = None 68 | if need_attn: 69 | energy = x 70 | x = self.attn_fn(x) 71 | x = self.attn_dropout(x) 72 | x = x.unsqueeze(dim=-1) # (B, N, 1) 73 | return x, energy 74 | 75 | def parallel_radiomics_clinical(self, patient_from_lesions, radiomics_feat, clinical_feat): 76 | radiomics_image_feat, radiomics_attnmap = self.radiomics_image_attn(key=radiomics_feat, 77 | query=patient_from_lesions.unsqueeze(dim=1), value=radiomics_feat) 78 | clinical_image_feat, clinical_attnmap = self.clinical_image_attn(key=clinical_feat, 79 | query=patient_from_lesions.unsqueeze(dim=1), value=clinical_feat) 80 | patient_from_lesions = patient_from_lesions \ 81 | + clinical_image_feat.squeeze(dim=1) \ 82 | + radiomics_image_feat.squeeze(dim=1) 83 | 84 | self.info_dict["clinical_weight"] = clinical_attnmap[0, 0].detach().cpu().numpy() 85 | return patient_from_lesions 86 | 87 | def series_radiomics_clinical(self, patient_from_lesions, radiomics_feat, clinical_feat): 88 | radiomics_image_feat, _ = self.radiomics_image_attn(key=radiomics_feat, 89 | query=patient_from_lesions.unsqueeze(dim=1), value=radiomics_feat) 90 | patient_from_lesions = patient_from_lesions + radiomics_image_feat.squeeze(dim=1) 91 | clinical_image_feat, _ = self.clinical_image_attn(key=clinical_feat, 92 | query=patient_from_lesions.unsqueeze(dim=1), value=clinical_feat) 93 | patient_from_lesions = patient_from_lesions + clinical_image_feat.squeeze(dim=1) 94 | return patient_from_lesions 95 | 96 | def series_clinical_radiomics(self, patient_from_lesions, radiomics_feat, clinical_feat): 97 | clinical_image_feat, _ = self.clinical_image_attn(key=clinical_feat, 98 | query=patient_from_lesions.unsqueeze(dim=1), value=clinical_feat) 99 | patient_from_lesions = patient_from_lesions + clinical_image_feat.squeeze(dim=1) 100 | radiomics_image_feat, _ = self.radiomics_image_attn(key=radiomics_feat, 101 | query=patient_from_lesions.unsqueeze(dim=1), value=radiomics_feat) 102 | patient_from_lesions = patient_from_lesions + radiomics_image_feat.squeeze(dim=1) 103 | return patient_from_lesions 104 | 
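    # NOTE (added summary; the option strings are confirmed in forward() below):
    #   parallel : feat + attn(radiomics) + attn(clinical)      - both branches query the same lesion feature
    #   series_rc: feat -> +attn(radiomics) -> +attn(clinical)  - clinical attends to the radiomics-updated feature
    #   series_cr: feat -> +attn(clinical) -> +attn(radiomics)
    # forward() selects one of these via opts.feat_fusion_mode.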
def forward(self, batch, *args, **kwargs): 106 | lesions = batch["lesions"] # (B, N_l, 3, H, W) 107 | yx_flag = batch["clinical_yx_flag"] # (B, 3) 108 | 109 | B, N_l, C, H, W = lesions.shape 110 | # (B, N_l, 3, H, W) --> (B, N_l, C) 111 | lesions_cnn = self.cnn(lesions.view(B*N_l, C, H, W)) 112 | yx_flag = yx_flag.view(B, 1, 3, 1, 1).repeat(1, N_l, 1, *lesions_cnn.shape[-2:]) 113 | yx_flag = yx_flag.view(B*N_l, 3, *lesions_cnn.shape[-2:]) 114 | attn_lesions_cnn = torch.cat([lesions_cnn, yx_flag], dim=1) 115 | attn_lesions_cnn = torch.sigmoid(self.attn_layer(attn_lesions_cnn)) 116 | lesions_cnn = lesions_cnn * attn_lesions_cnn 117 | lesions_cnn = lesions_cnn.view(B, N_l, -1) 118 | lesions_cnn = self.cnn_project(lesions_cnn) 119 | 120 | self.info_dict = { 121 | "id": batch["id"][0], 122 | "lesions_label": batch["lesions_label"][0].detach().cpu().numpy(), 123 | } 124 | 125 | lesions_attn, lesions_attnmap = self.attn_over_lesions(key=lesions_cnn, query=lesions_cnn, value=lesions_cnn) 126 | self.info_dict["lesions_attnmap"] = lesions_attnmap[0].detach().cpu().numpy() 127 | lesions_attn_energy, lesions_attn_energy_unnorm = self.energy_function(lesions_attn, self.lesions_weight_fn) 128 | # (B, N_l, C) x (B, N_l, 1) --> (B, C) 129 | self.info_dict["lesions_weight"] = lesions_attn_energy[0, ..., 0].detach().cpu().numpy() 130 | patient_from_lesions = torch.sum(lesions_attn * lesions_attn_energy, dim=1) 131 | patient_from_lesions = self.ffn_attn_l2p(patient_from_lesions) 132 | 133 | lesions_pred = self.lesions_classifier(lesions_attn) 134 | 135 | self.lesions_attn_energy_unnorm = lesions_attn_energy_unnorm 136 | 137 | clinical_feat = torch.stack([ 138 | self.stomach_fc(batch["clinical_yx_stomach"]), 139 | # self.ln_dis_fc(batch["clinical_yx_ln_dis"]), 140 | self.ln_num_fc(batch["clinical_yx_ln_num"]), 141 | self.zl_ln_fc(batch["clinical_yx_zl_ln"]), 142 | self.zl_ln_pos_fc(batch["clinical_yx_zl_ln_pos"]), 143 | self.zl_multi_fc(batch["clinical_yx_zl_multi"]), 144 | self.zl_per_fc(batch["clinical_yx_zl_per"]), 145 | ], dim=1) # (B, 7, C) 146 | 147 | radiomics_feat = self.radiomics_fc(batch["yx_radiomics_feat"]) # (B, N_l, 736) --> (B, N_l, C) 148 | 149 | if self.opts.feat_fusion_mode == "parallel": 150 | patient_from_lesions = self.parallel_radiomics_clinical(patient_from_lesions, radiomics_feat, clinical_feat) 151 | elif self.opts.feat_fusion_mode == "series_rc": 152 | patient_from_lesions = self.series_radiomics_clinical(patient_from_lesions, radiomics_feat, clinical_feat) 153 | elif self.opts.feat_fusion_mode == "series_cr": 154 | patient_from_lesions = self.series_clinical_radiomics(patient_from_lesions, radiomics_feat, clinical_feat) 155 | else: 156 | raise NotImplementedError 157 | 158 | if self.opts.attnmap_weight_dir not in [None, "None"]: 159 | npz_dir = os.path.join(self.opts.attnmap_weight_dir, "YX") 160 | os.makedirs(npz_dir, exist_ok=True) 161 | if batch["yx_flag"][0].item(): 162 | np.savez(os.path.join(npz_dir, self.info_dict["id"]+".npz"), 163 | lesions_label=self.info_dict["lesions_label"], 164 | lesions_attnmap=self.info_dict["lesions_attnmap"], 165 | lesions_weight=self.info_dict["lesions_weight"], 166 | clinical_weight=self.info_dict["clinical_weight"]) 167 | 168 | return { 169 | "feat": patient_from_lesions, 170 | "lesions_pred": lesions_pred, 171 | #"yx_pred": self.yx_classifier(patient_from_lesions) 172 | } 173 | 174 | class YXModelOnly(YXModel): 175 | def __init__(self, opts, *args, **kwargs): 176 | super().__init__(opts, *args, **kwargs) 177 | self.classifier = 
nn.Linear(opts.yx_out_features, opts.n_classes) 178 | 179 | def forward(self, batch, *args, **kwargs): 180 | yx_results = super().forward(batch, *args, **kwargs) 181 | yx_pred = self.classifier(yx_results["feat"]) 182 | return { 183 | "pred": yx_pred, 184 | "feat": yx_results["feat"], 185 | "lesions_pred": yx_results["lesions_pred"], 186 | } 187 | 188 | class YXModelNoneClinicalRadiomics(YXModel): 189 | def forward(self, batch, *args, **kwargs): 190 | lesions = batch["lesions"] # (B, N_l, 3, H, W) 191 | yx_flag = batch["clinical_yx_flag"] # (B, 3) 192 | 193 | B, N_l, C, H, W = lesions.shape 194 | # (B, N_l, 3, H, W) --> (B, N_l, C) 195 | lesions_cnn = self.cnn(lesions.view(B*N_l, C, H, W)) 196 | yx_flag = yx_flag.view(B, 1, 3, 1, 1).repeat(1, N_l, 1, *lesions_cnn.shape[-2:]) 197 | yx_flag = yx_flag.view(B*N_l, 3, *lesions_cnn.shape[-2:]) 198 | attn_lesions_cnn = torch.cat([lesions_cnn, yx_flag], dim=1) 199 | attn_lesions_cnn = torch.sigmoid(self.attn_layer(attn_lesions_cnn)) 200 | lesions_cnn = lesions_cnn * attn_lesions_cnn 201 | lesions_cnn = lesions_cnn.view(B, N_l, -1) 202 | lesions_cnn = self.cnn_project(lesions_cnn) 203 | 204 | lesions_attn, lesions_attnmap = self.attn_over_lesions(key=lesions_cnn, query=lesions_cnn, value=lesions_cnn) 205 | lesions_attn_energy, lesions_attn_energy_unnorm = self.energy_function(lesions_attn, self.lesions_weight_fn) 206 | # (B, N_l, C) x (B, N_l, 1) --> (B, C) 207 | patient_from_lesions = torch.sum(lesions_attn * lesions_attn_energy, dim=1) 208 | patient_from_lesions = self.ffn_attn_l2p(patient_from_lesions) 209 | 210 | lesions_pred = self.lesions_classifier(lesions_attn) 211 | 212 | self.lesions_attn_energy_unnorm = lesions_attn_energy_unnorm 213 | 214 | return { 215 | "feat": patient_from_lesions, 216 | "lesions_pred": lesions_pred, 217 | #"yx_pred": self.yx_classifier(patient_from_lesions) 218 | } -------------------------------------------------------------------------------- /train_and_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from utils.print_utils import * 4 | import os 5 | from utils.lr_scheduler import get_lr_scheduler 6 | from utils.metric_utils import * 7 | import gc 8 | from utils.utils import * 9 | from utils.build_dataloader import build_data_loader 10 | from utils.build_model import build_model 11 | from utils.build_optimizer import build_optimizer, update_optimizer, read_lr_from_optimzier 12 | from utils.build_criterion import build_criterion 13 | from utils.build_backbone import BaseFeatureExtractor 14 | import numpy as np 15 | import math 16 | import json 17 | from sklearn.metrics import roc_auc_score 18 | from lifelines.utils import concordance_index 19 | import pandas as pd 20 | from copy import deepcopy 21 | import torch 22 | # torch.autograd.set_detect_anomaly(True) 23 | 24 | def _concordance_index(T, P, E): 25 | P = [p for i, p in enumerate(P) if not np.isnan(T[i])] 26 | E = [e for i, e in enumerate(E) if not np.isnan(T[i])] 27 | T = [t for i, t in enumerate(T) if not np.isnan(T[i])] 28 | return concordance_index(T, P, E) 29 | 30 | def my_concordance_index(data): 31 | try: 32 | os_cindex = _concordance_index(data['OS'].values, data['pred'].values, data['OSCensor'].values) 33 | pfs_cindex = _concordance_index(data['PFS'].values, data['pred'].values, data['PFSCensor'].values) 34 | except: 35 | os_cindex = -1 36 | pfs_cindex = -1 37 | return os_cindex, pfs_cindex 38 | 39 | def my_roc_auc_score(data): 40 | label = data['label'].values 41 | 
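    # NOTE (added sketch): a label of -1 marks a sample without a usable endpoint and is
    # masked out below before AUC is computed, e.g. with made-up values:
    #   label = np.array([1, 0, -1, 1]); pred = np.array([0.9, 0.2, 0.5, 0.7])
    #   ind = np.where(label != -1)   # AUC uses indices 0, 1 and 3 only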
pred = data['pred'].values 42 | ind = np.where(label != -1) 43 | try: 44 | auc = roc_auc_score(label[ind], pred[ind]) 45 | except: 46 | auc = -1 47 | return auc 48 | 49 | def compute_metric(output, target, is_show=False): 50 | with torch.no_grad(): 51 | y_pred = output.detach().cpu().numpy() 52 | y_true = target.detach().cpu().numpy() 53 | ind = np.where(y_true != -1) 54 | y_pred = y_pred[ind] 55 | y_true = y_true[ind] 56 | 57 | if is_show: 58 | y_pred_true = [] 59 | for i in range(y_pred.shape[0]): 60 | y_pred_true.append((y_pred[i], y_true[i])) 61 | y_pred_true = sorted(y_pred_true, key=lambda x: x[0]) 62 | for yp, yt in y_pred_true: 63 | print(yp, yt) 64 | try: 65 | return roc_auc_score(y_true, y_pred) 66 | except: 67 | return -1.0 68 | 69 | class Trainer(object): 70 | '''This class implements the training and validation functionality for ML models on medical imaging data''' 71 | 72 | def __init__(self, opts, printer): 73 | super().__init__() 74 | self.opts = opts 75 | self.best_auc = 0 76 | self.start_epoch = 1 77 | self.printer = printer 78 | self.global_setter() 79 | 80 | def global_setter(self): 81 | # self.setup_logger() 82 | self.setup_device() 83 | self.setup_dataloader() 84 | self.setup_model_optimizer_lossfn() 85 | self.setup_lr_scheduler() 86 | 87 | def setup_device(self): 88 | num_gpus = torch.cuda.device_count() 89 | self.num_gpus = num_gpus 90 | if num_gpus > 0: 91 | print_log_message('Using {} GPUs'.format(num_gpus), self.printer) 92 | else: 93 | print_log_message('Using CPU', self.printer) 94 | 95 | self.device = torch.device("cuda:0" if num_gpus > 0 else "cpu") 96 | self.use_multi_gpu = True if num_gpus > 1 else False 97 | 98 | if torch.backends.cudnn.is_available(): 99 | import torch.backends.cudnn as cudnn 100 | cudnn.benchmark = True 101 | cudnn.deterministic = True 102 | 103 | def setup_lr_scheduler(self): 104 | # fetch learning rate scheduler 105 | #self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.98, verbose=True) 106 | self.lr_scheduler = get_lr_scheduler(self.opts, printer=self.printer) 107 | 108 | def setup_dataloader(self): 109 | train_loader, val_loader, diag_classes, class_weights = build_data_loader(opts=self.opts, printer=self.printer) 110 | 111 | self.train_loader = train_loader 112 | self.val_loader = val_loader 113 | self.diag_classes = diag_classes 114 | self.class_weights = torch.from_numpy(class_weights) 115 | 116 | def setup_model_optimizer_lossfn(self): 117 | # Build Model 118 | mi_model = build_model(opts=self.opts, printer=self.printer) 119 | 120 | mi_model = mi_model.to(device=self.device) 121 | if self.use_multi_gpu: 122 | mi_model = torch.nn.DataParallel(mi_model) 123 | self.mi_model = mi_model 124 | 125 | # Build Loss function 126 | criterion = build_criterion(opts=self.opts, class_weights=self.class_weights.float(), printer=self.printer) 127 | self.criterion = criterion.to(device=self.device) 128 | 129 | # Build optimizer 130 | self.optimizer = build_optimizer(model=self.mi_model, opts=self.opts, printer=self.printer) 131 | 132 | def training(self, epoch, epochs, lr, *args, **kwargs): 133 | train_stats = Statistics(printer=self.printer) 134 | 135 | self.mi_model.train() 136 | self.optimizer.zero_grad() 137 | 138 | num_samples = len(self.train_loader) 139 | epoch_start_time = time.time() 140 | 141 | pred_diag_labels_lst, true_diag_labels_lst, loss_lst = [], [], [] 142 | #P_risk_lst, T_risk_lst, E_risk_lst = [], [], [] 143 | for batch_id, batch in enumerate(self.train_loader): 144 | for key in batch: 145 | 
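                # NOTE (added comment): batch values are either metadata strings (ids,
                # names) or tensors; only tensors are cast to float and moved to the
                # device below. Checking batch[key][0] assumes each field is homogeneous
                # across the batch.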
if not isinstance(batch[key][0], str): 146 | batch[key] = batch[key].float().to(device=self.device) 147 | # if epoch == 6: 148 | # print_log_message(f"{key} {batch[key].min().item()} {batch[key].max().item()}", printer=self.printer) 149 | 150 | true_diag_labels = batch['label'] 151 | results = self.mi_model(batch, opts=self.opts) 152 | pred_diag_labels = results['pred'] 153 | #print_log_message(f"{results['pred'].min().item()},{results['pred'].max().item()}", printer=self.printer) 154 | batch['epoch'] = epoch 155 | batch['epochs'] = epochs 156 | 157 | loss = self.criterion(batch, results) 158 | if loss is not None: 159 | #with torch.autograd.detect_anomaly(): 160 | (loss / self.opts.log_interval).backward() 161 | torch.nn.utils.clip_grad_norm_(self.mi_model.parameters(), max_norm=20, norm_type=2) 162 | loss_lst.append(loss.item()) 163 | pred_diag_labels_lst.append(torch.softmax(pred_diag_labels, dim=1)[:, 1]) 164 | true_diag_labels_lst.append(true_diag_labels) 165 | 166 | if (batch_id+1) % self.opts.log_interval == 0 or (batch_id+1) == len(self.train_loader): 167 | self.optimizer.step() 168 | self.optimizer.zero_grad() 169 | auc = compute_metric(torch.cat(pred_diag_labels_lst, dim=0), torch.cat(true_diag_labels_lst, dim=0)) 170 | train_stats.update(loss=np.mean(loss_lst), auc=auc) 171 | train_stats.output(epoch=epoch, batch=batch_id+1, n_batches=num_samples, start=epoch_start_time, lr=lr) 172 | 173 | return train_stats.avg_auc(), train_stats.avg_loss() 174 | 175 | def validation(self, epoch, lr, *args, **kwargs): 176 | val_stats = Statistics(printer=self.printer) 177 | self.mi_model.eval() 178 | num_samples = len(self.val_loader) 179 | 180 | black_lst = ["feat_words", "diss_words", "diss_bags", "lesions", "diss_lesions"] 181 | pred_save_dir = os.path.join(self.opts.save_dir, str(self.opts.seed), "pred") 182 | os.makedirs(pred_save_dir, exist_ok=True) 183 | pred_diag_labels_lst, true_diag_labels_lst, loss_lst = [], [], [] 184 | info_lst = {} 185 | with torch.no_grad(): 186 | epoch_start_time = time.time() 187 | for batch_id, batch in enumerate(self.val_loader): 188 | for key in batch: 189 | if not isinstance(batch[key][0], str): 190 | batch[key] = batch[key].float().to(device=self.device) 191 | 192 | true_diag_labels = batch['label'] 193 | results = self.mi_model(batch, opts=self.opts) 194 | pred_diag_labels = results['pred'] 195 | 196 | loss = self.criterion(batch, results) 197 | if loss is not None: 198 | loss_lst.append(loss.item()) 199 | 200 | pred_diag_labels_lst.append(torch.softmax(pred_diag_labels, dim=1)[:, 1]) 201 | true_diag_labels_lst.append(true_diag_labels) 202 | for key, value in batch.items(): 203 | if key in black_lst: 204 | continue 205 | if key not in info_lst: 206 | info_lst[key] = [] 207 | if isinstance(value[0], str): 208 | info_lst[key].append(value[0]) 209 | else: 210 | info_lst[key].append(value.detach().cpu().numpy()[0]) 211 | 212 | #print(batch_id, batch["id"], pred_diag_labels_lst[-1][0].item(), true_diag_labels[0].item()) 213 | 214 | torch.cuda.empty_cache() 215 | gc.collect() 216 | 217 | f = open(os.path.join(os.path.join(pred_save_dir, f'{epoch:03d}.csv')), 'w') 218 | f.write("id,name,blid,yxid,liaoxiao,xianshu,lianhe,PFS,PFSCensor,OS,OSCensor,label,pred\n") 219 | pred_diag_labels_lst_ = torch.cat(pred_diag_labels_lst, dim=0).detach().cpu().numpy() 220 | true_diag_labels_lst_ = torch.cat(true_diag_labels_lst, dim=0).detach().cpu().numpy() 221 | for key in info_lst: 222 | info_lst[key] = np.asarray(info_lst[key]) 223 | for i in range(len(pred_diag_labels_lst)): 
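            # NOTE (added comment): one CSV row per patient. The three branches below
            # differ only in which modality ids exist (both pathology and imaging ids,
            # imaging only, or pathology only); missing ids are written as the literal
            # string "None", and xianshu/lianhe are placeholder columns fixed at -1.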
224 | xianshu = -1 225 | lianhe = -1 226 | if "id" in info_lst: 227 | f.write(f'{info_lst["id"][i]},{info_lst["name"][i]},{info_lst["bl_pid"][i]},{info_lst["yx_pid"][i]},{info_lst["liaoxiao"][i]},{xianshu},{lianhe},{info_lst["pfs"][i]},{info_lst["pfs_censor"][i]},{info_lst["os"][i]},{info_lst["os_censor"][i]},{true_diag_labels_lst_[i]},{pred_diag_labels_lst_[i]}\n') 228 | elif "yxid" in info_lst: 229 | f.write(f'None,{info_lst["name"][i]},None,{info_lst["yx_pid"][i]},{info_lst["liaoxiao"][i]},{xianshu},{lianhe},{info_lst["pfs"][i]},{info_lst["pfs_censor"][i]},{info_lst["os"][i]},{info_lst["os_censor"][i]},{true_diag_labels_lst_[i]},{pred_diag_labels_lst_[i]}\n') 230 | else: 231 | f.write(f'None,{info_lst["name"][i]},{info_lst["bl_pid"][i]},None,{info_lst["liaoxiao"][i]},{xianshu},{lianhe},{info_lst["pfs"][i]},{info_lst["pfs_censor"][i]},{info_lst["os"][i]},{info_lst["os_censor"][i]},{true_diag_labels_lst_[i]},{pred_diag_labels_lst_[i]}\n') 232 | f.close() 233 | 234 | tmp_data = pd.read_csv(os.path.join(os.path.join(pred_save_dir, f'{epoch:03d}.csv'))) 235 | os_cindex, pfs_cindex = my_concordance_index(tmp_data) 236 | auc = my_roc_auc_score(tmp_data) 237 | #auc = roc_auc_score(tmp_data['label'].values, tmp_data['pred'].values) 238 | avg_loss = np.mean(loss_lst) 239 | 240 | print_log_message('* Validation Stats', printer=self.printer) 241 | print_log_message('* Loss: {:.3f}, AUC: {:3.3f}, C-index(OS): {:.3f}, C-index(PFS): {:.3f}'.format( 242 | avg_loss, auc, os_cindex, pfs_cindex), printer=self.printer) 243 | print_log_message('Minv: {:.3f}, Maxv: {:.3f}'.format(torch.cat(pred_diag_labels_lst, dim=0).detach().cpu().numpy().min(), 244 | torch.cat(pred_diag_labels_lst, dim=0).detach().cpu().numpy().max()), printer=self.printer,) 245 | 246 | return auc, avg_loss 247 | 248 | def run(self, *args, **kwargs): 249 | kwargs['need_attn'] = False 250 | 251 | # if self.opts.warm_up: 252 | # self.warm_up(args=args, kwargs=kwargs) 253 | 254 | eval_stats_dict = dict() 255 | res_dict = { 256 | "TrainingLoss": [], 257 | "TrainingAUC": [], 258 | "ValidationLoss": [], 259 | "ValidationAUC": [], 260 | } 261 | 262 | self.validation(epoch=-1, lr=self.opts.lr, args=args, kwargs=kwargs) 263 | for epoch in range(self.start_epoch, self.opts.epochs+1): 264 | epoch_lr = self.lr_scheduler.step(epoch) 265 | 266 | self.optimizer = update_optimizer(optimizer=self.optimizer, lr_value=epoch_lr) 267 | 268 | # Uncomment this line if you want to check the optimizer's LR is updated correctly 269 | # assert read_lr_from_optimzier(self.optimizer) == epoch_lr 270 | 271 | train_auc, train_loss = self.training(epoch=epoch, lr=epoch_lr, epochs=self.opts.epochs, args=args, kwargs=kwargs) 272 | val_auc, val_loss = self.validation(epoch=epoch, lr=epoch_lr, args=args, kwargs=kwargs) 273 | eval_stats_dict[epoch] = val_auc 274 | gc.collect() 275 | 276 | # remember best accuracy and save checkpoint for best model 277 | is_best = val_auc >= self.best_auc 278 | self.best_auc = max(val_auc, self.best_auc) 279 | 280 | model_state = self.mi_model.module.state_dict() if isinstance(self.mi_model, torch.nn.DataParallel) \ 281 | else self.mi_model.state_dict() 282 | 283 | optimizer_state = self.optimizer.state_dict() 284 | 285 | save_checkpoint(epoch=epoch, 286 | model_state=model_state, 287 | optimizer_state=optimizer_state, 288 | best_perf=self.best_auc, 289 | save_dir=self.opts.save_dir, 290 | is_best=is_best, 291 | keep_best_k_models=self.opts.keep_best_k_models, 292 | printer=self.printer, 293 | metric=val_auc, 294 | ) 295 | 296 | # if epoch % 
10 == 0: 297 | # save_checkpoint(epoch=epoch, 298 | # model_state=model_state, 299 | # optimizer_state=optimizer_state, 300 | # best_perf=self.best_auc, 301 | # save_dir=self.opts.save_dir, 302 | # is_best=is_best, 303 | # keep_best_k_models=self.opts.keep_best_k_models, 304 | # printer=self.printer, 305 | # ) 306 | 307 | res_dict["TrainingLoss"].append(train_loss) 308 | res_dict["TrainingAUC"].append(train_auc) 309 | res_dict["ValidationLoss"].append(val_loss) 310 | res_dict["ValidationAUC"].append(val_auc) 311 | plot_results(res_dict, os.path.join(self.opts.save_dir, "plot.jpg")) 312 | -------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czifan/MuMo/3a4d33fbba57c1421af442fd116936e294caca05/utils/__init__.py -------------------------------------------------------------------------------- /utils/build_backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from utils.print_utils import * 4 | import os 5 | 6 | class BaseFeatureExtractor(torch.nn.Module): 7 | ''' 8 | This class calls different base feature extractors 9 | ''' 10 | def __init__(self, opts, printer=print): 11 | ''' 12 | :param opts: Argument list 13 | ''' 14 | super(BaseFeatureExtractor, self).__init__() 15 | 16 | print(opts.base_extractor, 'resnet' in opts.base_extractor) 17 | 18 | self.printer = printer 19 | if opts.base_extractor == 'espnetv2': 20 | from model.feature_extractors.espnetv2 import EESPNet 21 | self.base_model = EESPNet(opts) 22 | self.initialize_base_model(opts.weights) 23 | output_feature_sz = self.base_model.classifier.in_features 24 | del self.base_model.classifier 25 | elif opts.base_extractor == 'mobilenetv2': 26 | from model.feature_extractors.mobilenetv2 import MobileNetV2 27 | self.base_model = MobileNetV2() 28 | self.initialize_base_model(opts.weights) 29 | output_feature_sz = self.base_model.last_channel 30 | del self.base_model.classifier 31 | elif opts.base_extractor == 'mnasnet': 32 | from model.feature_extractors.mnasnet import MNASNet 33 | assert opts.s == 1.0, 'We are currently supporting models with scale = 1.0. If you are interested in ' \ 34 | 'exploring more models, download those from PyTorch repo and use it after uncommenting ' \ 35 | 'this assertion. 
' 36 | self.base_model = MNASNet(alpha=opts.s) 37 | self.initialize_base_model(opts.weights) 38 | output_feature_sz = self.base_model.last_channel 39 | del self.base_model.classifier 40 | elif 'resnet' in opts.base_extractor: 41 | from torchvision.models import resnet 42 | self.base_model = eval(f'resnet.{opts.base_extractor}')(pretrained=True) 43 | output_feature_sz = self.base_model.fc.in_features 44 | del self.base_model.fc 45 | else: 46 | print_error_message('{} model not yet supported'.format(opts.base_extractor), self.printer) 47 | 48 | self.output_feature_sz = output_feature_sz 49 | 50 | def initialize_base_model(self, wts_loc): 51 | ''' 52 | This function initializes the base model 53 | 54 | :param wts_loc: Location of the weights file 55 | ''' 56 | # initialize CNN model 57 | if not os.path.isfile(wts_loc): 58 | print_error_message('No file exists here: {}'.format(wts_loc), self.printer) 59 | 60 | print_log_message('Loading ImageNet trained weights', self.printer) 61 | pretrained_dict = torch.load(wts_loc, map_location=torch.device('cpu')) 62 | self.base_model.load_state_dict(pretrained_dict) 63 | print_log_message('Loading over', self.printer) 64 | 65 | def forward(self, words): 66 | ''' 67 | :param words: Word tensor of shape (N_w x C x w x h) 68 | :return: Features vector for words (N_w x F) 69 | ''' 70 | assert words.dim() == 4, 'Input should be 4 dimensional tensor (B x 3 x H x W)' 71 | words = self.base_model(words) 72 | return words 73 | 74 | 75 | def get_backbone_opts(parser): 76 | '''Base feature extractor CNN Model details''' 77 | group = parser.add_argument_group('CNN Model Details') 78 | group.add_argument('--base-extractor', default='espnetv2', help='Which CNN model? Default is espnetv2') 79 | group.add_argument('--s', type=float, default=2.0, 80 | help='Factor by which channels will be scaled. Default is 2.0 for espnetv2') 81 | group.add_argument('--weights', type=str, default='model/model_zoo/espnetv2/espnetv2_s_2.0_imagenet_224x224.pth', 82 | help='Location of imagenet pretrained weights') 83 | group.add_argument('--num_classes', type=int, default=1000, help='Number of classes in the base feature extractor.' 84 | ' Default is 1000 for the ImageNet pretrained model') 85 | group.add_argument('--channels', type=int, default=3, 86 | help='Number of input image channels. 
/utils/build_criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from utils.print_utils import *
4 | from utils.criterions.smoothing_loss import *
5 | from utils.criterions.blyx_loss import *
6 |
7 | def build_criterion(opts, class_weights, printer=print):
8 |     criterion = None
9 |     if opts.loss_fn == 'ce':
10 |         if opts.label_smoothing:
11 |             criterion = CrossEntropyWithLabelSmoothing(ls_eps=opts.label_smoothing_eps)
12 |             print_log_message('Using label smoothing value of: \n\t{}'.format(opts.label_smoothing_eps), printer)
13 |         else:
14 |             if opts.loss_weight:
15 |                 criterion = nn.CrossEntropyLoss(weight=class_weights)
16 |                 class_wts_str = '\n\t'.join(['{} --> {:.3f}'.format(cl_id, class_weights[cl_id]) for cl_id in range(class_weights.size(0))])
17 |                 print_log_message('Using class-weights: \n\t{}'.format(class_wts_str), printer)
18 |             else:
19 |                 criterion = nn.CrossEntropyLoss()
20 |     elif opts.loss_fn == 'bce':
21 |         criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
22 |         class_wts_str = '\n\t'.join(
23 |             ['{} --> {:.3f}'.format(cl_id, class_weights[cl_id]) for cl_id in range(class_weights.size(0))])
24 |         print_log_message('Using class-weights: \n\t{}'.format(class_wts_str), printer)
25 |     elif "BLYX" in opts.loss_fn:
26 |         criterion = eval(opts.loss_fn)()
27 |     elif "YX" in opts.loss_fn:
28 |         criterion = eval(opts.loss_fn)()
29 |     elif "BL" in opts.loss_fn:
30 |         criterion = eval(opts.loss_fn)()
31 |     else:
32 |         print_error_message('{} criterion not yet supported'.format(opts.loss_fn), printer)
33 |
34 |     if criterion is None:
35 |         print_error_message('Criterion function cannot be None. Please check', printer)
36 |
37 |     return criterion
38 |
39 | def get_criterion_opts(parser):
40 |     group = parser.add_argument_group("Criterion options")
41 |     group.add_argument("--loss-fn", default="ce", help="Loss function")
42 |     group.add_argument("--loss-weight", action="store_true", default=False, help="Weighted loss or not")
43 |     group.add_argument("--label-smoothing", action="store_true", default=False, help="Smooth labels or not")
44 |     group.add_argument("--label-smoothing-eps", default=0.1, type=float, help="Epsilon for label smoothing")
45 |     return parser
--------------------------------------------------------------------------------
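Note: with --label-smoothing, the 'ce' branch above returns CrossEntropyWithLabelSmoothing (defined in utils/criterions/smoothing_loss.py later in this repo). Per sample it computes (1 - eps) * CE(target) + (eps / C) * sum over all C classes of CE(c), a convex mix of the usual cross-entropy and a uniform-target cross-entropy. A quick numerical cross-check with toy logits (a sketch; values are made up):

import torch
import torch.nn.functional as F
from utils.criterions.smoothing_loss import CrossEntropyWithLabelSmoothing

pred = torch.tensor([[2.0, 0.5, -1.0]])   # one sample, three classes
target = torch.tensor([0])
eps = 0.1
log_p = F.log_softmax(pred, dim=-1)
manual = (1 - eps) * (-log_p[0, 0]) + (eps / 3) * (-log_p.sum())
auto = CrossEntropyWithLabelSmoothing(ls_eps=eps)(pred, target)
assert torch.allclose(manual, auto)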
/utils/build_dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import numpy as np
4 | from utils.print_utils import *
5 | from torch.utils.data import DataLoader
6 | from data_loader import *
7 |
8 | def worker_init_fn(worker_id):
9 |     # give every dataloader worker a distinct, deterministic seed
10 |     worker_seed = torch.initial_seed() % 2 ** 32
11 |     np.random.seed(worker_seed)
12 |     random.seed(worker_seed)
13 |
14 | # every supported dataset uses the same loader recipe and differs only in
15 | # the Dataset class and the validation cohort
16 | DATASET_SPECS = {
17 |     "her2": (BLYXDataset, "her2"),
18 |     "her2_test": (BLYXDatasetV1, "her2"),
19 |     "ci": (BLYXDataset, "ci"),
20 |     "her2_yx": (YXDataset, "her2"),
21 |     "ci_yx": (YXDataset, "ci"),
22 |     "her2_bl": (BLDataset, "her2"),
23 |     "ci_bl": (BLDataset, "ci"),
24 | }
25 |
26 | def build_data_loader(opts, printer=print):
27 |     if opts.dataset not in DATASET_SPECS:
28 |         print_error_message('{} dataset not supported yet'.format(opts.dataset), printer)
29 |     dataset_cls, val_cohort = DATASET_SPECS[opts.dataset]
30 |
31 |     train_ds = dataset_cls(opts, split="train", split_file=opts.train_file, printer=print, cohort=opts.train_cohort)
32 |     val_ds = dataset_cls(opts, split="val", split_file=opts.val_file, printer=print, cohort=val_cohort)
33 |
34 |     diag_classes = train_ds.n_classes
35 |     #assert diag_classes == opts.n_classes, (diag_classes, opts.n_classes)
36 |     diag_labels = train_ds.diag_labels
37 |
38 |     train_dl = DataLoader(train_ds, batch_size=opts.batch_size, shuffle=True,
39 |                           pin_memory=True, num_workers=opts.data_workers,
40 |                           worker_init_fn=worker_init_fn)
41 |     val_dl = DataLoader(val_ds, batch_size=1, shuffle=False,
42 |                         pin_memory=True, num_workers=opts.data_workers,
43 |                         worker_init_fn=worker_init_fn)
44 |
45 |     # compute class-weights for balancing dataset
46 |     if opts.class_weights:
47 |         class_weights = np.histogram(diag_labels, bins=diag_classes)[0]
48 |         class_weights = np.array(class_weights) / sum(class_weights)
49 |         for i in range(diag_classes):
50 |             class_weights[i] = round(np.log(1.0 / class_weights[i]), 5)
51 |     else:
52 |         class_weights = np.ones(diag_classes, dtype=float)  # np.float was removed in NumPy 1.24
53 |
54 |     return train_dl, val_dl, diag_classes, class_weights
55 |
56 | def get_dataset_opts(parser):
57 |     group = parser.add_argument_group('Dataset general details')
58 |     group.add_argument('--dataset', type=str, default='bingli', help='Dataset name')
59 |     group.add_argument('--label-type', type=str, default='response')
60 |     group.add_argument('--train-cohort', type=str, default=None)
61 |     group.add_argument('--train-file', type=str, default=None)
62 |     group.add_argument('--val-file', type=str, default=None)
63 |     group.add_argument('--bl-img-dir', type=str, default='/public/share/chenzifan/Journal22-GuoZhong/GuoZhongBingLiData/', help='Dataset location')
64 |     group.add_argument('--bl-rad-dir', type=str, default='../Data/BL_radiomics', help='Dataset location')
65 |     group.add_argument('--bl-img-extn', type=str, default='jpg', help='Extension of WSIs. Default is jpg')
66 |     group.add_argument('--bl-num-bags', type=int, default=10, help='Number of bags for running')
67 |     group.add_argument('--bl-bag-size', type=int, default=2048, help='Bag size.')
68 |     group.add_argument('--bl-word-size', type=int, default=256, help='Word size.')
69 |     group.add_argument('--yx-img-dir', type=str, default='/public/share/chenzifan/Journal22-GuoZhong/YingXiangCropData/', help='Dataset location')
70 |     group.add_argument('--yx-rad-dir', type=str, default='../Data/YX_radiomics', help='Dataset location')
71 |     group.add_argument('--yx-img-extn', type=str, default='jpg', help='Extension of lesion crops. Default is jpg')
72 |     group.add_argument('--yx-num-lesions', type=int, default=4, help='Number of lesions for running')
73 |     group.add_argument('--yx-lesion-size', type=int, default=224, help='Lesion size.')
74 |     group.add_argument('--split-file', type=str, default='../Data/SplitData.xlsx',
75 |                        help='Text file with training image ids and labels')
76 |     group.add_argument('--batch-size', type=int, default=1, help='Batch size')
77 |     group.add_argument('--data-workers', type=int, default=1, help='Number of workers for data loading')
78 |     group.add_argument('--class-weights', action='store_true', default=False,
79 |                        help='Compute normalized class weights to address class imbalance')
80 |     return parser
--------------------------------------------------------------------------------
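Note: the class-weighting in build_data_loader above is w_c = round(ln(1 / p_c), 5), where p_c is the empirical frequency of class c in the training labels, so rarer classes get larger weights. A worked example with toy labels (pure NumPy, mirroring the code):

import numpy as np

diag_labels = [0] * 75 + [1] * 25              # imbalanced two-class toy labels
counts = np.histogram(diag_labels, bins=2)[0]  # array([75, 25])
freqs = counts / counts.sum()                  # array([0.75, 0.25])
weights = np.round(np.log(1.0 / freqs), 5)     # array([0.28768, 1.38629])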
/utils/build_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from model.bl_model import *
4 | from model.yx_model import *
5 | from model.model import *
6 | from utils.print_utils import *
7 |
8 | def _load_from_checkpoint(opts, model, checkpoint, printer):
9 |     if os.path.isfile(checkpoint):
10 |         state_dict = torch.load(checkpoint, map_location='cpu')
11 |         model_state_dict = model.state_dict()
12 |         suc = 0
13 |         frozen_keys = []
14 |         for key, value in state_dict.items():
15 |             if opts.finetune:
16 |                 if "classifier" in key:
17 |                     continue
18 |             if key in model_state_dict and model_state_dict[key].shape == value.shape:
19 |                 model_state_dict[key] = value
20 |                 frozen_keys.append(key)
21 |                 suc += 1
22 |         res = model.load_state_dict(model_state_dict)
23 |         print_info_message('Load from {} ({}) {}/{}'.format(checkpoint, res, suc, len(model_state_dict)), printer)
24 |     return model
25 |
26 | def build_model(opts, printer=print):
27 |     model = None
28 |     if opts.model == 'bingli':
29 |         model = eval(opts.bl_model)(opts)
30 |         if os.path.isfile(opts.bl_pretrained):
31 |             printer(f'Loading pretrained weights from {opts.bl_pretrained}')
32 |             state_dict = torch.load(opts.bl_pretrained)
33 |             model_state_dict = model.state_dict()
34 |             suc = 0
35 |             frozen_keys = []
36 |             for key, value in state_dict.items():
37 |                 if key in model_state_dict:
38 |                     model_state_dict[key] = value
39 |                     frozen_keys.append(key)
40 |                     suc += 1
41 |             model.load_state_dict(model_state_dict)
42 |             # for key, param in model.named_parameters():
43 |             #     if key in frozen_keys:
44 |             #         param.requires_grad = False
45 |             printer(f'Loaded {suc}/{len(list(state_dict.keys()))} keys')
46 |         model = _load_from_checkpoint(opts, model, opts.bl_checkpoint, printer)
47 |     elif opts.model == 'yingxiang':
48 |         model = eval(opts.yx_model)(opts)
49 |         if os.path.isfile(opts.yx_pretrained):
50 |             printer(f'Loading pretrained weights from {opts.yx_pretrained}')
51 |             state_dict = torch.load(opts.yx_pretrained)
52 |             model_state_dict = model.state_dict()
53 |             suc = 0
54 |             frozen_keys = []
55 |             for key, value in state_dict.items():
56 |                 if key in model_state_dict:
57 |                     model_state_dict[key] = value
58 |                     frozen_keys.append(key)
59 |                     suc += 1
60 |             model.load_state_dict(model_state_dict)
61 |             # for key, param in model.named_parameters():
62 |             #     if key in frozen_keys:
63 |             #         param.requires_grad = False
64 |             printer(f'Loaded {suc}/{len(list(state_dict.keys()))} keys')
65 |         model = _load_from_checkpoint(opts, model, opts.yx_checkpoint, printer)
66 |         # for key, param in model.cnn.named_parameters():
67 |         #     param.requires_grad = False
68 |     elif opts.model == 'bingliyingxiang':
69 |         base_model = opts.model
70 |         opts.model = 'bingli'
71 |         bl_model = build_model(opts, printer=printer)
72 |
73 |         opts.model = 'yingxiang'
74 |         yx_model = build_model(opts, printer=printer)
75 |
76 |         opts.model = base_model
77 |         model = eval(opts.blyx_model)(opts=opts, bl_model=bl_model, yx_model=yx_model)
78 |         model = _load_from_checkpoint(opts, model, opts.blyx_checkpoint, printer)
79 |     else:
80 |         print_error_message('{} model not yet supported'.format(opts.model), printer)
81 |
82 |     # sanity check to ensure that everything is fine
83 |     if model is None:
84 |         print_error_message('Model cannot be None. Please check', printer)
85 |
86 |     return model
87 |
88 | def get_model_opts(parser):
89 |     group = parser.add_argument_group('Medical Imaging Model Details')
90 |
91 |     group.add_argument('--model', default="bingli", type=str, help='Name of model')
92 |     group.add_argument('--feat-fusion-mode', default="parallel", type=str)
93 |
94 |     group.add_argument('--yx-model', default="YXModelv1", type=str, help='Name of YingXiangModel')
95 |     group.add_argument('--yx-pretrained', default='', type=str)
96 |     group.add_argument('--yx-cnn-name', default="resnet18", type=str, help='Name of backbone')
97 |     group.add_argument('--yx-cnn-pretrained', action='store_true', default=False)
98 |     group.add_argument('--yx-cnn-features', type=int, default=512,
99 |                        help='Number of cnn features extracted by the backbone')
100 |     group.add_argument('--yx-out-features', type=int, default=128,
101 |                        help='Number of output features after merging bags and words')
102 |     group.add_argument('--yx-attn-heads', default=2, type=int, help='Number of attention heads')
103 |     group.add_argument('--yx-dropout', default=0.4, type=float, help='Dropout value')
104 |     group.add_argument('--yx-attn-dropout', default=0.2, type=float, help='Dropout value for attention')
105 |     group.add_argument('--yx-attn-fn', type=str, default='softmax', choices=['tanh', 'sigmoid', 'softmax'],
106 |                        help='Attention activation function')
107 |     group.add_argument('--yx-num-way', type=int, default=4)
108 |
109 |     group.add_argument('--bl-model', default="BLModelV1", type=str, help='Name of BingLiModel')
110 |     group.add_argument('--bl-pretrained', default='', type=str)
111 |     group.add_argument('--bl-cnn-name', default="resnet18", type=str, help='Name of backbone')
112 |     group.add_argument('--bl-cnn-s', type=float, default=2.0,
113 |                        help='Factor by which channels will be scaled. Default is 2.0 for espnetv2')
114 |     group.add_argument('--bl-cnn-pretrained', action='store_true', default=False)
115 |     group.add_argument('--bl-cnn-weight', default=False, type=str, help='Pretrained model')
116 |     group.add_argument('--bl-cnn-features', type=int, default=512,
117 |                        help='Number of cnn features extracted by the backbone')
118 |     group.add_argument('--bl-out-features', type=int, default=128,
119 |                        help='Number of output features after merging bags and words')
120 |     group.add_argument('--bl-attn-heads', default=2, type=int, help='Number of attention heads')
121 |     group.add_argument('--bl-dropout', default=0.4, type=float, help='Dropout value')
122 |     group.add_argument('--bl-max-bsz-cnn-gpu0', type=int, default=100, help='Max. batch size on GPU0')
123 |     group.add_argument('--bl-attn-dropout', type=float, default=0.2, help='Probability to drop bag and word attention weights')
124 |     group.add_argument('--bl-attn-fn', type=str, default='softmax', choices=['tanh', 'sigmoid', 'softmax'],
125 |                        help='Attention activation function')
126 |     group.add_argument('--keep-best-k-models', type=int, default=-1)
127 |     group.add_argument('--bl-num-way', type=int, default=6)
128 |
129 |     group.add_argument('--n-classes', default=2, type=int, help='Number of classes')
130 |     group.add_argument('--blyx-model', default="BLYXModelv1", type=str, help='Name of BingLiYingXiangModel')
131 |     group.add_argument('--blyx-out-features', type=int, default=128,
132 |                        help='Number of output features after merging bags and words')
133 |     group.add_argument('--blyx-dropout', default=0.4, type=float, help='Dropout value')
134 |
135 |
136 |     group.add_argument('--resume', action='store_true', default=False)
137 |     group.add_argument('--blyx-checkpoint', default="", type=str, help='Checkpoint for resuming')
138 |     group.add_argument('--bl-checkpoint', default="", type=str, help='Checkpoint for resuming')
139 |     group.add_argument('--yx-checkpoint', default="", type=str, help='Checkpoint for resuming')
140 |     return parser
--------------------------------------------------------------------------------
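Note: _load_from_checkpoint above restores only tensors whose name and shape both match the current model, which is what lets a checkpoint survive head swaps or small architecture edits. The same idiom, reduced to a self-contained sketch (the helper name load_partial is mine, not the repo's):

import torch

def load_partial(model, checkpoint_path):
    state = torch.load(checkpoint_path, map_location='cpu')
    own = model.state_dict()
    matched = {k: v for k, v in state.items()
               if k in own and own[k].shape == v.shape}
    own.update(matched)
    model.load_state_dict(own)
    return len(matched), len(own)  # tensors restored vs. total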
/utils/build_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim
3 | from utils.print_utils import *
4 | import os
5 |
6 | def build_optimizer(opts, model, printer=print):
7 |     optimizer = None
8 |
9 |     if opts.finetune:
10 |         params = [
11 |             {"params": [p for n, p in model.named_parameters() if "classifier." in n], "lr": opts.lr * 10.0},
12 |             {"params": [p for n, p in model.named_parameters() if "classifier." not in n], "lr": opts.lr},
13 |         ]
14 |         # params = [
15 |         #     {"params": [p for n, p in model.named_parameters() if not(("bl_model." in n) or ("yx_model." in n))], "lr": opts.lr * 100.0},
16 |         #     {"params": [p for n, p in model.named_parameters() if "bl_model." in n], "lr": opts.lr},
17 |         #     {"params": [p for n, p in model.named_parameters() if "yx_model." in n], "lr": opts.lr},
18 |         # ]
19 |     else:
20 |         params = [p for p in model.parameters() if p.requires_grad]
21 |
22 |     if opts.optim == 'sgd':
23 |         print_info_message('Using SGD optimizer', printer)
24 |         optimizer = optim.SGD(params, lr=opts.lr, momentum=opts.momentum, weight_decay=opts.weight_decay)
25 |     elif opts.optim == 'adam':
26 |         print_info_message('Using ADAM optimizer', printer)
27 |         beta1 = 0.9 if opts.adam_beta1 is None else opts.adam_beta1
28 |         beta2 = 0.999 if opts.adam_beta2 is None else opts.adam_beta2
29 |         optimizer = optim.Adam(
30 |             params,
31 |             lr=opts.lr,
32 |             betas=(beta1, beta2),
33 |             weight_decay=opts.weight_decay,
34 |             eps=1e-9)
35 |     elif opts.optim == "adamw":
36 |         print_info_message('Using ADAMW optimizer', printer)
37 |         beta1 = 0.9 if opts.adam_beta1 is None else opts.adam_beta1
38 |         beta2 = 0.999 if opts.adam_beta2 is None else opts.adam_beta2
39 |         optimizer = optim.AdamW(
40 |             params,
41 |             lr=opts.lr,
42 |             betas=(beta1, beta2),
43 |             weight_decay=opts.weight_decay,
44 |             eps=1e-9)
45 |     else:
46 |         print_error_message('{} optimizer not yet supported'.format(opts.optim), printer)
47 |
48 |     # sanity check to ensure that everything is fine
49 |     if optimizer is None:
50 |         print_error_message('Optimizer cannot be None. Please check', printer)
51 |
52 |     return optimizer
53 |
54 | def update_optimizer(optimizer, lr_value):
55 |     optimizer.param_groups[0]['lr'] = lr_value
56 |     return optimizer
57 |
58 | def read_lr_from_optimzier(optimizer):
59 |     return optimizer.param_groups[0]['lr']
60 |
61 | def get_optimizer_opts(parser):
62 |     '''Optimizer details'''
63 |     group = parser.add_argument_group('Optimizer options')
64 |     group.add_argument('--optim', default='sgd', type=str, help='Optimizer')
65 |     group.add_argument('--momentum', default=0.8, type=float, help='Momentum for SGD')
66 |     group.add_argument('--adam-beta1', default=0.9, type=float, help='Beta1 for ADAM')
67 |     group.add_argument('--adam-beta2', default=0.999, type=float, help='Beta2 for ADAM')
68 |     group.add_argument('--lr', default=0.0005, type=float, help='Initial learning rate for the optimizer')
69 |     group.add_argument('--weight-decay', default=4e-6, type=float, help='Weight decay')
70 |
71 |     group = parser.add_argument_group('Optimizer accumulation options')
72 |     group.add_argument('--accum-count', type=int, default=1, help='After how many iterations shall we update the weights')
73 |
74 |     return parser
75 |
--------------------------------------------------------------------------------
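Note: the finetune branch of build_optimizer gives the classifier head ten times the base learning rate of the pretrained trunk. A self-contained sketch of that param-group idiom (the toy model and LR values are made up, not repo defaults):

import torch
from torch import nn, optim

model = nn.Sequential(nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, 2))
head = {'2.weight', '2.bias'}  # treat the last Linear as the "classifier"
params = [
    {"params": [p for n, p in model.named_parameters() if n in head], "lr": 5e-3},
    {"params": [p for n, p in model.named_parameters() if n not in head], "lr": 5e-4},
]
optimizer = optim.SGD(params, lr=5e-4, momentum=0.8, weight_decay=4e-6)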
/utils/criterions/blyx_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from utils.criterions.focal_loss import *
6 | from utils.criterions.survival_loss import *
7 |
8 | def sigmoid_rampup(current, rampup_length):
9 |     if rampup_length == 0:
10 |         return 1.0
11 |     else:
12 |         current = np.clip(current, 0.0, rampup_length)
13 |         phase = 1.0 - current / rampup_length
14 |         return float(np.exp(-5.0 * phase * phase))
15 |
16 | def get_current_consistency_weight(epoch, consistency=1., consistency_rampup=40):
17 |     return consistency * sigmoid_rampup(epoch, consistency_rampup)
18 |
19 | class BLYXLoss(nn.Module):
20 |     def __init__(self):
21 |         super().__init__()
22 |         self.ce_criterion = FocalLoss(num_class=2)
23 |         self.surv_criterion = DeepSurvLoss()
24 |         self.align_criterion = FocalLoss(num_class=2)
25 |
26 |         self.lesion_criterion = FocalLoss(num_class=5)
27 |         self.bag_criterion = FocalLoss(num_class=4)
28 |         self.word_criterion = FocalLoss(num_class=4)
29 |
30 |     def forward(self, batch, results):
31 |         ind = torch.where(batch['label'] != -1)
32 |         loss = self.ce_criterion(results['pred'][ind], batch['label'][ind].long())
33 |
34 |         loss1 = loss.item()
35 |
36 |         if results['pred'].shape[0] == 1:
37 |             return loss
38 |
39 |         P_risk = torch.softmax(results['pred'], dim=1)[:, 0]
40 |         T = batch['os']
41 |         E = batch['os_censor']
42 |         # T = torch.round(batch['pfs'] / 30.5).float()
43 |         # E = batch['pfs_censor']
44 |         loss2 = self.surv_criterion(P_risk, T, E)
45 |         if not np.isnan(loss2.item()):
46 |             loss += loss2
47 |
48 |         blyx_flag = results["blyx_flag"].squeeze(dim=1) # (B,)
49 |         bl_com_feat = results["bl_com_feat"][torch.where(blyx_flag)] # (B', C)
50 |         yx_com_feat = results["yx_com_feat"][torch.where(blyx_flag)] # (B', C)
51 |         cos_pred = torch.cosine_similarity(bl_com_feat.unsqueeze(dim=1), yx_com_feat.unsqueeze(dim=0), dim=-1).view(-1) # (B'*B')
52 |         cos_pred = torch.stack([1.0-cos_pred, cos_pred], dim=1) # (B'*B', 2)
53 |         cos_label = torch.eye(bl_com_feat.shape[0]).long().view(-1).to(cos_pred.device) # (B'*B')
54 |         loss3 = self.align_criterion(cos_pred, cos_label)
55 |         if not np.isnan(loss3.item()):
56 |             loss += loss3
57 |
58 |         cls_weight = 0.1 * (1.0 - get_current_consistency_weight(batch["epoch"], consistency=1.0, consistency_rampup=batch["epochs"]))
59 |
60 |         if results["lesions_pred"] is not None:
61 |             ind = torch.where(batch["lesions_label"].view(-1) != -1)
62 |             lesions_pred = results["lesions_pred"].view(-1, results["lesions_pred"].shape[-1])[ind]
63 |             lesions_label = batch["lesions_label"].view(-1)[ind]
64 |             loss4 = self.lesion_criterion(lesions_pred, lesions_label)
65 |             loss += loss4 * cls_weight
66 |
67 |         if results["mask_bags_pred"] is not None:
68 |             ind =
torch.where(batch["mask_bags_label"].view(-1) != -1) 69 | mask_bags_pred = results["mask_bags_pred"].view(-1, results["mask_bags_pred"].shape[-1])[ind] 70 | mask_bags_label = batch["mask_bags_label"].view(-1)[ind] 71 | loss5 = self.bag_criterion(mask_bags_pred, mask_bags_label) 72 | loss += loss5 * cls_weight 73 | 74 | if results["mask_words_pred"] is not None: 75 | ind = torch.where(batch["mask_words_label"].view(-1) != -1) 76 | mask_words_pred = results["mask_words_pred"].view(-1, results["mask_words_pred"].shape[-1])[ind] 77 | mask_words_label = batch["mask_words_label"].view(-1)[ind] 78 | loss6 = self.word_criterion(mask_words_pred, mask_words_label) 79 | loss += loss6 * cls_weight 80 | 81 | return loss 82 | 83 | 84 | class BLYXLossFN(nn.Module): 85 | def __init__(self): 86 | super().__init__() 87 | self.ce_criterion = FocalLoss(num_class=2) 88 | self.surv_criterion = DeepSurvLoss() 89 | 90 | def forward(self, batch, results): 91 | ind = torch.where(batch['label'] != -1) 92 | loss = self.ce_criterion(results['pred'][ind], batch['label'][ind].long()) 93 | 94 | if results['pred'].shape[0] == 1: 95 | return loss 96 | 97 | P_risk = torch.softmax(results['pred'], dim=1)[:, 0] 98 | T = batch["pfs"] 99 | E = batch['pfs_censor'] 100 | loss2 = self.surv_criterion(P_risk, T, E) 101 | if not np.isnan(loss2.item()): 102 | loss += loss2 103 | 104 | return loss 105 | 106 | class BLLoss(nn.Module): 107 | def __init__(self): 108 | super().__init__() 109 | self.ce_criterion = FocalLoss(num_class=2) 110 | self.surv_criterion = DeepSurvLoss() 111 | self.bag_criterion = FocalLoss(num_class=4) 112 | self.word_criterion = FocalLoss(num_class=4) 113 | 114 | def forward(self, batch, results): 115 | ind = torch.where(batch['label'] != -1) 116 | loss = self.ce_criterion(results['pred'][ind], batch['label'][ind].long()) 117 | 118 | loss1 = loss.item() 119 | 120 | if results['pred'].shape[0] == 1: 121 | return loss 122 | 123 | P_risk = torch.softmax(results['pred'], dim=1)[:, 0] 124 | T = batch['pfs'] 125 | E = batch['pfs_censor'] 126 | # T = torch.round(batch['pfs'] / 30.5).float() 127 | # E = batch['pfs_censor'] 128 | loss2 = self.surv_criterion(P_risk, T, E) 129 | if not np.isnan(loss2.item()): 130 | loss += loss2 131 | 132 | cls_weight = 0.1 * (1.0 - get_current_consistency_weight(batch["epoch"], consistency=1.0, consistency_rampup=batch["epochs"])) 133 | 134 | if results["mask_bags_pred"] is not None: 135 | ind = torch.where(batch["mask_bags_label"].view(-1) != -1) 136 | mask_bags_pred = results["mask_bags_pred"].view(-1, results["mask_bags_pred"].shape[-1])[ind] 137 | mask_bags_label = batch["mask_bags_label"].view(-1)[ind] 138 | loss5 = self.bag_criterion(mask_bags_pred, mask_bags_label) 139 | loss += loss5 * cls_weight 140 | 141 | if results["mask_words_pred"] is not None: 142 | ind = torch.where(batch["mask_words_label"].view(-1) != -1) 143 | mask_words_pred = results["mask_words_pred"].view(-1, results["mask_words_pred"].shape[-1])[ind] 144 | mask_words_label = batch["mask_words_label"].view(-1)[ind] 145 | loss6 = self.word_criterion(mask_words_pred, mask_words_label) 146 | loss += loss6 * cls_weight 147 | 148 | return loss 149 | 150 | class YXLoss(nn.Module): 151 | def __init__(self): 152 | super().__init__() 153 | self.ce_criterion = FocalLoss(num_class=2) 154 | self.surv_criterion = DeepSurvLoss() 155 | 156 | self.lesion_criterion = FocalLoss(num_class=5) 157 | 158 | def forward(self, batch, results): 159 | ind = torch.where(batch['label'] != -1) 160 | loss = self.ce_criterion(results['pred'][ind], 
batch['label'][ind].long()) 161 | 162 | loss1 = loss.item() 163 | 164 | if results['pred'].shape[0] == 1: 165 | return loss 166 | 167 | P_risk = torch.softmax(results['pred'], dim=1)[:, 0] 168 | T = batch['os'] 169 | E = batch['os_censor'] 170 | # T = torch.round(batch['pfs'] / 30.5).float() 171 | # E = batch['pfs_censor'] 172 | loss2 = self.surv_criterion(P_risk, T, E) 173 | if not np.isnan(loss2.item()): 174 | loss += loss2 175 | 176 | cls_weight = 0.1 * (1.0 - get_current_consistency_weight(batch["epoch"], consistency=1.0, consistency_rampup=batch["epochs"])) 177 | 178 | if results["lesions_pred"] is not None: 179 | ind = torch.where(batch["lesions_label"].view(-1) != -1) 180 | lesions_pred = results["lesions_pred"].view(-1, results["lesions_pred"].shape[-1])[ind] 181 | lesions_label = batch["lesions_label"].view(-1)[ind] 182 | loss4 = self.lesion_criterion(lesions_pred, lesions_label) 183 | loss += loss4 * cls_weight 184 | 185 | return loss -------------------------------------------------------------------------------- /utils/criterions/focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | 6 | class FocalLoss(nn.Module): 7 | def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1, smooth=None, size_average=True): 8 | super(FocalLoss, self).__init__() 9 | self.num_class = num_class 10 | self.alpha = alpha 11 | self.gamma = gamma 12 | self.smooth = smooth 13 | self.size_average = size_average 14 | 15 | if self.alpha is None: 16 | self.alpha = torch.ones(self.num_class, 1) 17 | elif isinstance(self.alpha, (list, np.ndarray)): 18 | assert len(self.alpha) == self.num_class 19 | self.alpha = torch.FloatTensor(alpha).view(self.num_class, 1) 20 | self.alpha = self.alpha / self.alpha.sum() 21 | elif isinstance(self.alpha, float): 22 | alpha = torch.ones(self.num_class, 1) 23 | alpha = alpha * (1 - self.alpha) 24 | alpha[balance_index] = self.alpha 25 | self.alpha = alpha 26 | else: 27 | raise TypeError('Not support alpha type') 28 | 29 | if self.smooth is not None: 30 | if self.smooth < 0 or self.smooth > 1.0: 31 | raise ValueError('smooth value should be in [0,1]') 32 | 33 | def forward(self, input, target): 34 | # ind = (target != self.ignore_index) 35 | # input = input[ind, ...] 36 | # target = target[ind] 37 | 38 | logit = F.softmax(input, dim=1) 39 | 40 | if logit.dim() > 2: 41 | # N,C,d1,d2 -> N,C,m (m=d1*d2*...) 
42 | logit = logit.view(logit.size(0), logit.size(1), -1) 43 | logit = logit.permute(0, 2, 1).contiguous() 44 | logit = logit.view(-1, logit.size(-1)) 45 | target = target.view(-1, 1) 46 | 47 | epsilon = 1e-10 48 | alpha = self.alpha 49 | if alpha.device != input.device: 50 | alpha = alpha.to(input.device) 51 | 52 | idx = target.cpu().long() 53 | one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_() 54 | one_hot_key = one_hot_key.scatter_(1, idx, 1) 55 | if one_hot_key.device != logit.device: 56 | one_hot_key = one_hot_key.to(logit.device) 57 | 58 | if self.smooth: 59 | one_hot_key = torch.clamp( 60 | one_hot_key, self.smooth, 1.0 - self.smooth) 61 | pt = (one_hot_key * logit).sum(1) + epsilon 62 | logpt = pt.log() 63 | 64 | gamma = self.gamma 65 | 66 | alpha = alpha[idx] 67 | loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt 68 | 69 | if self.size_average: 70 | loss = loss.mean() 71 | else: 72 | loss = loss.sum() 73 | return loss -------------------------------------------------------------------------------- /utils/criterions/smoothing_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | # adapted from Fairseq 5 | 6 | class CrossEntropyWithLabelSmoothing(nn.Module): 7 | def __init__(self, ls_eps=0.1, ignore_idx=None, reduce=True, reduction='mean', *args, **kwargs): 8 | super(CrossEntropyWithLabelSmoothing, self).__init__() 9 | self.ls_eps = ls_eps 10 | self.ignore_idx = ignore_idx 11 | self.reduce = reduce 12 | self.reduction = reduction 13 | 14 | def compute_loss(self, log_probs, target): 15 | if target.dim() == log_probs.dim() - 1: 16 | target = target.unsqueeze(-1) 17 | nll_loss = -log_probs.gather(dim=-1, index=target) 18 | smooth_loss = -log_probs.sum(dim=-1, keepdim=True) 19 | if self.ignore_idx is not None: 20 | pad_mask = target.eq(self.ignore_idx) 21 | if pad_mask.any(): 22 | nll_loss.masked_fill_(pad_mask, 0.) 23 | smooth_loss.masked_fill_(pad_mask, 0.) 24 | else: 25 | nll_loss = nll_loss.squeeze(-1) 26 | smooth_loss = smooth_loss.squeeze(-1) 27 | if self.reduce: 28 | nll_loss = nll_loss.sum() 29 | smooth_loss = smooth_loss.sum() 30 | eps_i = self.ls_eps / log_probs.size(-1) 31 | loss = (1. 
- self.ls_eps) * nll_loss + eps_i * smooth_loss
32 |         return loss
33 |
34 |     def forward(self, pred, target):
35 |         assert pred.dim() == 2, 'Should be B x C'
36 |         B, C = pred.size()
37 |         log_probs = F.log_softmax(pred, dim=-1)
38 |         log_probs = log_probs.view(-1, C)
39 |         target = target.view(-1, 1)
40 |         loss = self.compute_loss(log_probs, target)
41 |         if self.reduction == 'mean':
42 |             loss /= B
43 |         return loss
44 |
--------------------------------------------------------------------------------
/utils/criterions/survival_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class DeepSurvLoss(nn.Module):
5 |     def __init__(self):
6 |         super().__init__()
7 |
8 |     def _compute_loss(self, P, T, E, M, mode):
9 |         P_exp = torch.exp(P) # (B,)
10 |         P_exp_B = torch.stack([P_exp for _ in range(P.shape[0])], dim=0) # (B, B)
11 |         if mode == 'risk':
12 |             E = E.float() * (M.sum(dim=1) > 0).float()
13 |         elif mode == 'surv':
14 |             E = (M.sum(dim=1) > 0).float()
15 |         else:
16 |             raise NotImplementedError
17 |         P_exp_sum = (P_exp_B * M.float()).sum(dim=1)
18 |         P_tmp = P_exp / (P_exp_sum+1e-6)
19 |         loss = -torch.sum(torch.log(P_tmp.clip(1e-6, P_tmp.max().item())) * E) / torch.sum(E)
20 |         return loss
21 |
22 |     def forward(self, P_risk, T, E):
23 |         # P: (B,)
24 |         # T: (B,)
25 |         # E: (B,) \in {0, 1}
26 |         M_risk = T.unsqueeze(dim=1) < T.unsqueeze(dim=0) # (B, B)
27 |         loss_risk = self._compute_loss(P_risk, T, E, M_risk, mode='risk')
28 |         return loss_risk
--------------------------------------------------------------------------------
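Note: DeepSurvLoss above is a negative log Cox partial-likelihood objective. Writing $P_i$ for the predicted risk score, $E_i \in \{0,1\}$ for the event indicator, and $R_i = \{j : T_j > T_i\}$ for the strict risk set that the mask $M$ encodes, the quantity computed is, up to the $10^{-6}$ clamps,

$$\mathcal{L} = -\frac{1}{\sum_i E_i} \sum_i E_i \,\log \frac{\exp(P_i)}{\sum_{j \in R_i} \exp(P_j)},$$

with $E_i$ additionally zeroed for samples whose risk set is empty (and the NaN guards at the call sites covering batches with no usable events). Note that the textbook Cox partial likelihood uses $R_i = \{j : T_j \ge T_i\}$, which includes sample $i$ itself in the denominator; the strict inequality is a reading of the mask on line 26 of this file, not a claim about the original DeepSurv formulation.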
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import math
3 | from utils.print_utils import *
4 |
5 | class CyclicLR(object):
6 |     '''
7 |     Class that defines a cyclic learning rate with warm restarts: the learning rate decays linearly till the end of a cycle and then restarts
8 |     at the maximum value.
9 |     See https://arxiv.org/abs/1811.11431 for more details
10 |     '''
11 |
12 |     def __init__(self, min_lr=0.1, cycle_len=5, steps=[51, 101, 131, 161, 191, 221, 251, 281], gamma=0.5, step=True):
13 |         super(CyclicLR, self).__init__()
14 |         assert len(steps) > 0, 'Please specify step intervals.'
15 |         assert 0 < gamma <= 1, 'Learning rate decay factor should be between 0 and 1'
16 |         self.min_lr = min_lr # minimum learning rate
17 |         self.m = cycle_len
18 |         self.steps = steps
19 |         self.warm_up_interval = 1 # we do not start from the max value for the first epoch, because sometimes it diverges
20 |         self.counter = 0
21 |         self.decayFactor = gamma # factor by which we should decay learning rate
22 |         self.count_cycles = 0
23 |         self.step_counter = 0
24 |         self.stepping = step
25 |
26 |     def step(self, epoch):
27 |         if epoch % self.steps[self.step_counter] == 0 and epoch > 1 and self.stepping:
28 |             self.min_lr = self.min_lr * self.decayFactor
29 |             self.count_cycles = 0
30 |             if self.step_counter < len(self.steps) - 1:
31 |                 self.step_counter += 1
32 |             else:
33 |                 self.stepping = False
34 |         current_lr = self.min_lr
35 |         # warm-up or cool-down phase
36 |         if self.count_cycles < self.warm_up_interval:
37 |             self.count_cycles += 1
38 |             # We do not need warm-up after the first step,
39 |             # so we set the warm-up interval to 0 after the first step
40 |             if self.count_cycles == self.warm_up_interval:
41 |                 self.warm_up_interval = 0
42 |         else:
43 |             # Cyclic learning rate with warm restarts
44 |             # max_lr (= min_lr * step_size) is decreased to min_lr using linear decay before
45 |             # it is set to max value at the end of cycle.
46 |             if self.counter >= self.m:
47 |                 self.counter = 0
48 |             current_lr = round((self.min_lr * self.m) - (self.counter * self.min_lr), 5)
49 |             self.counter += 1
50 |             self.count_cycles += 1
51 |         return current_lr
52 |
53 |     def __repr__(self):
54 |         fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n'
55 |         fmt_str += ' Min. base LR: {}\n'.format(self.min_lr)
56 |         fmt_str += ' Max. base LR: {}\n'.format(self.min_lr * self.m)
57 |         fmt_str += ' Step interval: {}\n'.format(self.steps)
58 |         fmt_str += ' Decay lr at each step by {}\n'.format(self.decayFactor)
59 |         return fmt_str
60 |
61 |
62 | class MultiStepLR(object):
63 |     '''
64 |     Fixed LR scheduler with steps
65 |     '''
66 |
67 |     def __init__(self, base_lr=0.1, steps=[30, 60, 90], gamma=0.1, step=True):
68 |         super(MultiStepLR, self).__init__()
69 |         assert len(steps) >= 1, 'Please specify step intervals.'
70 |         self.base_lr = base_lr
71 |         self.steps = steps
72 |         self.decayFactor = gamma # factor by which we should decay learning rate
73 |         self.stepping = step
74 |         print('Using Fixed LR Scheduler')
75 |
76 |     def step(self, epoch):
77 |         return round(self.base_lr * (self.decayFactor ** bisect.bisect(self.steps, epoch)), 5)
78 |
79 |     def __repr__(self):
80 |         fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n'
81 |         fmt_str += ' Base LR: {}\n'.format(self.base_lr)
82 |         fmt_str += ' Step interval: {}\n'.format(self.steps)
83 |         fmt_str += ' Decay lr at each step by {}\n'.format(self.decayFactor)
84 |         return fmt_str
85 |
86 |
87 | class PolyLR(object):
88 |     '''
89 |     Polynomial LR scheduler with steps
90 |     '''
91 |
92 |     def __init__(self, base_lr, max_epochs, power=0.99):
93 |         super(PolyLR, self).__init__()
94 |         assert 0 < power < 1
95 |         self.base_lr = base_lr
96 |         self.power = power
97 |         self.max_epochs = max_epochs
98 |
99 |     def step(self, epoch):
100 |         curr_lr = self.base_lr * (1 - (float(epoch) / self.max_epochs)) ** self.power
101 |         return round(curr_lr, 6)
102 |
103 |     def __repr__(self):
104 |         fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n'
105 |         fmt_str += ' Total Epochs: {}\n'.format(self.max_epochs)
106 |         fmt_str += ' Base LR: {}\n'.format(self.base_lr)
107 |         fmt_str += ' Power: {}\n'.format(self.power)
108 |         return fmt_str
109 |
110 |
111 | class LinearLR(object):
112 |     def __init__(self, base_lr, max_epochs):
113 |         super(LinearLR, self).__init__()
114 |         self.base_lr = base_lr
115 |         self.max_epochs = max_epochs
116 |
117 |     def step(self, epoch):
118 |         curr_lr = self.base_lr - (self.base_lr * (epoch / (self.max_epochs)))
119 |         return round(curr_lr, 6)
120 |
121 |     def __repr__(self):
122 |         fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n'
123 |         fmt_str += ' Total Epochs: {}\n'.format(self.max_epochs)
124 |         fmt_str += ' Base LR: {}\n'.format(self.base_lr)
125 |         return fmt_str
126 |
127 |
128 | class HybirdLR(object):
129 |     def __init__(self, base_lr, clr_max, max_epochs, cycle_len=5):
130 |         super(HybirdLR, self).__init__()
131 |         self.linear_epochs = max_epochs - clr_max + 1
132 |         steps = [clr_max]
133 |         self.clr = CyclicLR(min_lr=base_lr, cycle_len=cycle_len, steps=steps, gamma=1)
134 |         self.decay_lr = LinearLR(base_lr=base_lr, max_epochs=self.linear_epochs)
135 |         self.cyclic_epochs = clr_max
136 |
137 |
self.base_lr = base_lr 138 | self.max_epochs = max_epochs 139 | self.clr_max = clr_max 140 | self.cycle_len = cycle_len 141 | 142 | def step(self, epoch): 143 | if epoch < self.cyclic_epochs: 144 | curr_lr = self.clr.step(epoch) 145 | else: 146 | curr_lr = self.decay_lr.step(epoch - self.cyclic_epochs + 1) 147 | return round(curr_lr, 6) 148 | 149 | def __repr__(self): 150 | fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n' 151 | fmt_str += ' Total Epochs: {}\n'.format(self.max_epochs) 152 | fmt_str += ' Cycle with length of {}: {}\n'.format(self.cycle_len, int(self.clr_max / self.cycle_len)) 153 | fmt_str += ' Base LR with {} cycle length: {}\n'.format(self.cycle_len, self.base_lr) 154 | fmt_str += ' Cycle with length of {}: {}\n'.format(self.linear_epochs, 1) 155 | fmt_str += ' Base LR with {} cycle length: {}\n'.format(self.linear_epochs, self.base_lr) 156 | return fmt_str 157 | 158 | class CosineLR(object): 159 | def __init__(self, base_lr, max_epochs): 160 | super(CosineLR, self).__init__() 161 | self.base_lr = base_lr 162 | self.max_epochs = max_epochs 163 | 164 | def step(self, epoch): 165 | return round(self.base_lr * (1 + math.cos(math.pi * epoch / self.max_epochs)) / 2, 6) 166 | 167 | def __repr__(self): 168 | fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n' 169 | fmt_str += ' Total Epochs: {}\n'.format(self.max_epochs) 170 | fmt_str += ' Base LR : {}\n'.format(self.base_lr) 171 | return fmt_str 172 | 173 | class FixedLR(object): 174 | def __init__(self, base_lr): 175 | self.base_lr = base_lr 176 | 177 | def step(self, epoch): 178 | return self.base_lr 179 | 180 | def __repr__(self): 181 | fmt_str = 'Scheduler ' + self.__class__.__name__ + '\n' 182 | fmt_str += ' Base LR : {}\n'.format(self.base_lr) 183 | return fmt_str 184 | 185 | def get_lr_scheduler(opts, printer=print): 186 | if opts.scheduler == 'multistep': 187 | step_size = opts.step_size if isinstance(opts.step_size, list) else [opts.step_size] 188 | if len(step_size) == 1: 189 | step_size = step_size[0] 190 | step_sizes = [step_size * i for i in range(1, int(math.ceil(opts.epochs / step_size)))] 191 | else: 192 | step_sizes = step_size 193 | lr_scheduler = MultiStepLR(base_lr=opts.lr, steps=step_sizes, gamma=opts.lr_decay) 194 | elif opts.scheduler == 'fixed': 195 | lr_scheduler = FixedLR(base_lr=opts.lr) 196 | elif opts.scheduler == 'clr': 197 | step_size = opts.step_size if isinstance(opts.step_size, list) else [opts.step_size] 198 | if len(step_size) == 1: 199 | step_size = step_size[0] 200 | step_sizes = [step_size * i for i in range(1, int(math.ceil(opts.epochs / step_size)))] 201 | else: 202 | step_sizes = step_size 203 | lr_scheduler = CyclicLR(min_lr=opts.lr, cycle_len=opts.cycle_len, steps=step_sizes, gamma=opts.lr_decay) 204 | elif opts.scheduler == 'poly': 205 | lr_scheduler = PolyLR(base_lr=opts.lr, max_epochs=opts.epochs, power=opts.power) 206 | elif opts.scheduler == 'hybrid': 207 | lr_scheduler = HybirdLR(base_lr=opts.lr, max_epochs=opts.epochs, clr_max=opts.clr_max, 208 | cycle_len=opts.cycle_len) 209 | elif opts.scheduler == 'linear': 210 | lr_scheduler = LinearLR(base_lr=opts.lr, max_epochs=opts.epochs) 211 | elif opts.scheduler == 'cos': 212 | lr_scheduler = CosineLR(base_lr=opts.lr, max_epochs=opts.epochs) 213 | else: 214 | print_error_message('{} scheduler Not supported'.format(opts.scheduler), printer) 215 | 216 | print_info_message(lr_scheduler, printer) 217 | return lr_scheduler 218 | 219 | 220 | def get_scheduler_opts(parser): 221 | ''' Scheduler Details''' 222 | 223 | group = 
parser.add_argument_group('Learning rate scheduler') 224 | group.add_argument('--scheduler', default='hybrid', help='Learning rate scheduler (e.g. fixed, clr, poly)') 225 | group.add_argument('--step-size', default=[51], type=int, nargs="+", help='Step sizes') 226 | group.add_argument('--lr-decay', default=0.5, type=float, help='factor by which lr should be decreased') 227 | 228 | group = parser.add_argument_group('CLR relating settings') 229 | group.add_argument('--cycle-len', default=5, type=int, help='Cycle length') 230 | group.add_argument('--clr-max', default=61, type=int, 231 | help='Max number of epochs for cylic LR before changing last cycle to linear') 232 | 233 | group = parser.add_argument_group('Poly LR related settings') 234 | group.add_argument('--power', default=0.9, type=float, help='power factor for Polynomial LR') 235 | 236 | group = parser.add_argument_group('Warm-up settings') 237 | group.add_argument('--warm-up', action='store_true', default=False, help='Warm-up') 238 | group.add_argument('--warm-up-min-lr', default=1e-7, help='Warm-up minimum lr') 239 | group.add_argument('--warm-up-iterations', default=2000, type=int, help='Number of warm-up iterations') 240 | 241 | return parser 242 | -------------------------------------------------------------------------------- /utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from utils.print_utils import print_log_message 4 | import torch 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from typing import NamedTuple 8 | 9 | CMResults = NamedTuple( 10 | "CMResults", 11 | [ 12 | ('overall_accuracy', float), 13 | # 14 | ("sensitivity_micro", float), 15 | ("sensitivity_macro", float), 16 | ("sensitivity_class", list), 17 | # 18 | ("specificity_micro", float), 19 | ("specificity_macro", float), 20 | ("specificity_class", float), 21 | ("precision_micro", float), 22 | ("precision_macro", float), 23 | ("precision_class", float), 24 | # 25 | ("recall_micro", float), 26 | ("recall_macro", float), 27 | ("recall_class", float), 28 | # 29 | ("f1_micro", float), 30 | ("f1_macro", float), 31 | ("f1_class", float), 32 | # 33 | ("accuracy_micro", float), 34 | ("accuracy_macro", float), 35 | ("accuracy_class", float), 36 | # 37 | ('true_positive_rate_micro', float), 38 | ('true_positive_rate_macro', float), 39 | ('true_positive_rate_class', float), 40 | # 41 | ('true_negative_rate_micro', float), 42 | ('true_negative_rate_macro', float), 43 | ('true_negative_rate_class', float), 44 | # 45 | ('false_positive_rate_micro', float), 46 | ('false_positive_rate_macro', float), 47 | ('false_positive_rate_class', float), 48 | # 49 | ('false_negative_rate_micro', float), 50 | ('false_negative_rate_macro', float), 51 | ('false_negative_rate_class', float), 52 | # 53 | ('positive_pred_value_micro', float), 54 | ('positive_pred_value_macro', float), 55 | ('positive_pred_value_class', float), 56 | # 57 | ('negative_pred_value_micro', float), 58 | ('negative_pred_value_macro', float), 59 | ('negative_pred_value_class', float), 60 | # 61 | ('negative_likelihood_ratio_micro', float), 62 | ('negative_likelihood_ratio_macro', float), 63 | ('negative_likelihood_ratio_class', float), 64 | # 65 | ('positive_likelihood_ratio_micro', float), 66 | ('positive_likelihood_ratio_macro', float), 67 | ('positive_likelihood_ratio_class', float), 68 | # 69 | ('diagnostic_odd_ratio_micro', float), 70 | ('diagnostic_odd_ratio_macro', float), 71 | 
('diagnostic_odd_ratio_class', float), 72 | # 73 | ('younden_index_micro', float), 74 | ('younden_index_macro', float), 75 | ('younden_index_class', float) 76 | ], 77 | ) 78 | 79 | 80 | def compute_micro_stats(values_a, values_b, eps=1e-8): 81 | sum_a = np.sum(values_a) 82 | sum_b = np.sum(values_b) 83 | 84 | micro_sc = sum_a / (sum_a + sum_b + eps) 85 | 86 | return micro_sc 87 | 88 | 89 | def compute_macro_stats(values): 90 | return np.mean(values) 91 | 92 | 93 | class CMMetrics(object): 94 | ''' 95 | Metrics defined here: https://www.sciencedirect.com/science/article/pii/S2210832718301546 96 | ''' 97 | 98 | def __init__(self): 99 | super(CMMetrics, self).__init__() 100 | self.eps = 1e-8 101 | 102 | def compute_precision(self, tp, fp): 103 | ''' 104 | Precision = TP/(TP + FP) 105 | ''' 106 | class_wise = tp / (tp + fp + self.eps) 107 | 108 | micro = compute_micro_stats(tp, fp) 109 | macro = compute_macro_stats(class_wise) 110 | 111 | return micro, macro, class_wise 112 | 113 | def compute_senstivity(self, tp, fn): 114 | ''' 115 | Sensitivity = TP/(TP + FN) 116 | ''' 117 | class_wise = tp / (tp + fn + self.eps) 118 | micro = compute_micro_stats(tp, fn) 119 | macro = compute_macro_stats(class_wise) 120 | 121 | return micro, macro, class_wise 122 | 123 | def compute_specificity(self, tn, fp): 124 | class_wise = (tn / (tn + fp + self.eps)) 125 | micro = compute_micro_stats(tn, fp) 126 | macro = compute_macro_stats(class_wise) 127 | 128 | return micro, macro, class_wise 129 | 130 | def compute_recall(self, tp, fn): 131 | # same as sensitivity 132 | class_wise = (tp / (tp + fn + self.eps)) 133 | micro = compute_micro_stats(tp, fn) 134 | macro = compute_macro_stats(class_wise) 135 | 136 | return micro, macro, class_wise 137 | 138 | def compute_f1(self, precision, recall): 139 | return (2.0 * precision * recall) / (precision + recall) 140 | 141 | def compute_acc(self, tp, tn, fp, fn): 142 | class_wise = ((tp + tn) / (tp + tn + fp + fn + self.eps)) 143 | micro = compute_micro_stats(tp + tn, tp + tn + fp + fn) 144 | macro = compute_macro_stats(class_wise) 145 | return micro, macro, class_wise 146 | 147 | def compute_overall_acc(self, tp, N): 148 | return tp.sum() / (N + self.eps) 149 | 150 | def compute_tpr(self, tp, fn): 151 | # True positive rate 152 | # same as senstivity and recall 153 | class_wise = (tp / (tp + fn + self.eps)) 154 | 155 | micro = compute_micro_stats(tp, fn) 156 | macro = compute_macro_stats(class_wise) 157 | 158 | return micro, macro, class_wise 159 | 160 | def compute_tnr(self, tn, fp): 161 | # True negative rate 162 | # same as specificity 163 | class_wise = (tn / (tn + fp + self.eps)) 164 | micro = compute_micro_stats(tn, fp) 165 | macro = compute_macro_stats(class_wise) 166 | return micro, macro, class_wise 167 | 168 | def compute_fpr(self, fp, tn): 169 | # False posistive rate 170 | class_wise = (fp / (fp + tn + self.eps)) 171 | micro = compute_micro_stats(fp, tn) 172 | 173 | macro = compute_macro_stats(class_wise) 174 | return micro, macro, class_wise 175 | 176 | def compute_fnr(self, fn, tp): 177 | # false negative rate 178 | # fnr_micro 179 | class_wise = (fn / (fn + tp + self.eps)) 180 | 181 | micro = compute_micro_stats(fn, tp) 182 | macro = compute_macro_stats(class_wise) 183 | 184 | return micro, macro, class_wise 185 | 186 | def compute_ppv(self, tp, fp): 187 | # Positive prediction value 188 | return self.compute_precision(tp=tp, fp=fp) 189 | 190 | def compute_npv(self, tn, fn): 191 | # Negative predictive value 192 | class_wise = (tn / (tn + fn + self.eps)) 193 
| micro = compute_micro_stats(tn, fn) 194 | macro = compute_macro_stats(class_wise) 195 | 196 | return micro, macro, class_wise 197 | 198 | def compute_neg_lr(self, tpr, tnr): 199 | # negative likelihood ratio 200 | return (1.0 - tpr) / (tnr + self.eps) 201 | 202 | def compute_pos_lr(self, tpr, tnr): 203 | # positive likelihood ratio 204 | return tpr / (1.0 - tnr + self.eps) 205 | 206 | def compute_dor(self, tp, tn, fp, fn): 207 | # Diagnostic odds ratio 208 | class_wise = ((tp * tn) / (fp * fn + self.eps)) 209 | 210 | micro = compute_micro_stats(tp * tn, fp * fn) 211 | macro = compute_macro_stats(class_wise) 212 | 213 | return micro, macro, class_wise 214 | 215 | def compute_younden_index(self, tpr, tnr): 216 | # Youden's index 217 | return tpr + tnr - 1.0 218 | 219 | def compute_metrics(self, conf_mat): 220 | num_samples = conf_mat.sum() 221 | if conf_mat.shape[0] > 2: 222 | true_positives = np.diag(conf_mat) 223 | false_positives = conf_mat.sum(axis=0) - true_positives 224 | false_negatives = conf_mat.sum(axis=1) - true_positives 225 | true_negatives = conf_mat.sum() - (false_positives + false_negatives + true_positives) 226 | else: 227 | true_negatives, false_positives, false_negatives, true_positives = conf_mat.ravel() 228 | 229 | false_positives = false_positives.astype(float) 230 | false_negatives = false_negatives.astype(float) 231 | true_positives = true_positives.astype(float) 232 | true_negatives = true_negatives.astype(float) 233 | 234 | #print(true_positives, true_negatives, false_positives, false_negatives) 235 | 236 | sensitivity_micro, sensitivity_macro, sensitivity_class = self.compute_senstivity(tp=true_positives, 237 | fn=false_negatives) 238 | specificity_micro, specificity_macro, specificity_class = self.compute_specificity(tn=true_negatives, 239 | fp=false_positives) 240 | precision_micro, precision_macro, precision_class = self.compute_precision(tp=true_positives, 241 | fp=false_positives) 242 | recall_micro, recall_macro, recall_class = self.compute_recall(tp=true_positives, fn=false_negatives) 243 | f1_micro = self.compute_f1(precision=precision_micro, recall=recall_micro) 244 | f1_macro = compute_macro_stats(self.compute_f1(precision=precision_class, recall=recall_class)) 245 | f1_class = self.compute_f1(precision=precision_class, recall=recall_class) 246 | 247 | acc_micro, acc_macro, acc_class = self.compute_acc(tp=true_positives, tn=true_negatives, fp=false_positives, 248 | fn=false_negatives) 249 | overall_acc = self.compute_overall_acc(tp=true_positives, N=num_samples) 250 | 251 | tpr_micro, tpr_macro, tpr_class = self.compute_tpr(tp=true_positives, fn=false_negatives) 252 | tnr_micro, tnr_macro, tnr_class = self.compute_tnr(tn=true_negatives, fp=false_positives) 253 | fpr_micro, fpr_macro, fpr_class = self.compute_fpr(fp=false_positives, tn=true_negatives) 254 | fnr_micro, fnr_macro, fnr_class = self.compute_fnr(fn=false_negatives, tp=true_positives) 255 | 256 | ppv_micro, ppv_macro, ppv_class = self.compute_ppv(tp=true_positives, fp=false_positives) 257 | npv_micro, npv_macro, npv_class = self.compute_npv(tn=true_negatives, fn=false_negatives) 258 | neg_lr_micro = self.compute_neg_lr(tpr=tpr_micro, tnr=tnr_micro) 259 | neg_lr_class = self.compute_neg_lr(tpr=tpr_class, tnr=tnr_class) 260 | neg_lr_macro = compute_macro_stats(self.compute_neg_lr(tpr=tpr_class, tnr=tnr_class)) 261 | 262 | pos_lr_micro = self.compute_pos_lr(tpr=tpr_micro, tnr=tnr_micro) 263 | pos_lr_class = self.compute_pos_lr(tpr=tpr_class, tnr=tnr_class) 264 | pos_lr_macro = 
compute_macro_stats(self.compute_pos_lr(tpr=tpr_class, tnr=tnr_class)) 265 | 266 | dor_micro, dor_macro, dor_class = self.compute_dor(tp=true_positives, tn=true_negatives, fp=false_positives, 267 | fn=false_negatives) 268 | 269 | yi_micro = self.compute_younden_index(tpr=tpr_micro, tnr=tnr_micro) 270 | yi_class = self.compute_younden_index(tpr=tpr_class, tnr=tnr_class) 271 | yi_macro = compute_macro_stats(self.compute_younden_index(tpr=tpr_class, tnr=tnr_class)) 272 | 273 | return CMResults( 274 | overall_accuracy=overall_acc, 275 | sensitivity_micro=sensitivity_micro, 276 | sensitivity_macro=sensitivity_macro, 277 | sensitivity_class=sensitivity_class, 278 | # 279 | specificity_micro=specificity_micro, 280 | specificity_macro=specificity_macro, 281 | specificity_class=specificity_class, 282 | # 283 | precision_micro=precision_micro, 284 | precision_macro=precision_macro, 285 | precision_class=precision_class, 286 | # 287 | recall_micro=recall_micro, 288 | recall_macro=recall_macro, 289 | recall_class=recall_class, 290 | # 291 | f1_micro=f1_micro, 292 | f1_macro=f1_macro, 293 | f1_class=f1_class, 294 | # 295 | accuracy_micro=acc_micro, 296 | accuracy_macro=acc_macro, 297 | accuracy_class=acc_class, 298 | # 299 | true_positive_rate_micro=tpr_micro, 300 | true_positive_rate_macro=tpr_macro, 301 | true_positive_rate_class=tpr_class, 302 | # 303 | true_negative_rate_micro=tnr_micro, 304 | true_negative_rate_macro=tnr_macro, 305 | true_negative_rate_class=tnr_class, 306 | # 307 | false_positive_rate_micro=fpr_micro, 308 | false_positive_rate_macro=fpr_macro, 309 | false_positive_rate_class=fpr_class, 310 | # 311 | false_negative_rate_micro=fnr_micro, 312 | false_negative_rate_macro=fnr_macro, 313 | false_negative_rate_class=fnr_class, 314 | # 315 | positive_pred_value_micro=ppv_micro, 316 | positive_pred_value_macro=ppv_macro, 317 | positive_pred_value_class=ppv_class, 318 | # 319 | negative_pred_value_micro=npv_micro, 320 | negative_pred_value_macro=npv_macro, 321 | negative_pred_value_class=npv_class, 322 | # 323 | negative_likelihood_ratio_micro=neg_lr_micro, 324 | negative_likelihood_ratio_macro=neg_lr_macro, 325 | negative_likelihood_ratio_class=neg_lr_class, 326 | # 327 | positive_likelihood_ratio_micro=pos_lr_micro, 328 | positive_likelihood_ratio_macro=pos_lr_macro, 329 | positive_likelihood_ratio_class=pos_lr_class, 330 | # 331 | diagnostic_odd_ratio_micro=dor_micro, 332 | diagnostic_odd_ratio_macro=dor_macro, 333 | diagnostic_odd_ratio_class=dor_class, 334 | # 335 | younden_index_micro=yi_micro, 336 | younden_index_macro=yi_macro, 337 | younden_index_class=yi_class 338 | ) 339 | 340 | def accuracy(output, target, topk=(1,)): 341 | """Computes the precision@k for the specified values of k""" 342 | with torch.no_grad(): 343 | maxk = max(topk) 344 | batch_size = target.size(0) 345 | 346 | _, pred = output.topk(maxk, 1, True, True) 347 | pred = pred.t() 348 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 349 | 350 | res = [] 351 | for k in topk: 352 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 353 | res.append(correct_k.mul_(100.0 / batch_size)) 354 | return res 355 | 356 | 357 | def compute_f1(y_pred: torch.Tensor, y_true: torch.Tensor, n_classes=4, epsilon=1e-7, is_one_hot=False): 358 | if is_one_hot: 359 | # B x C 360 | assert y_pred.dim() == y_true.dim() 361 | else: 362 | assert len(y_pred.size()) == 2 # B x C 363 | assert len(y_true.size()) == 1 # B 364 | 365 | with torch.no_grad(): 366 | y_true = y_true.to(torch.float32) if is_one_hot else 
357 | def compute_f1(y_pred: torch.Tensor, y_true: torch.Tensor, n_classes=4, epsilon=1e-7, is_one_hot=False):
358 |     if is_one_hot:
359 |         # B x C
360 |         assert y_pred.dim() == y_true.dim()
361 |     else:
362 |         assert len(y_pred.size()) == 2  # B x C
363 |         assert len(y_true.size()) == 1  # B
364 |
365 |     with torch.no_grad():
366 |         y_true = y_true.to(torch.float32) if is_one_hot else F.one_hot(y_true.to(torch.int64), n_classes).to(torch.float32)
367 |         y_pred = y_pred.argmax(dim=1)
368 |         y_pred = F.one_hot(y_pred.to(torch.int64), n_classes).to(torch.float32)
369 |
370 |         tp = (y_true * y_pred).sum().to(torch.float32)
371 |         tn = ((1 - y_true) * (1 - y_pred)).sum().to(torch.float32)
372 |         fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
373 |         fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
374 |
375 |         precision = tp / (tp + fp + epsilon)
376 |         recall = tp / (tp + fn + epsilon)
377 |
378 |         f1 = 2 * (precision * recall) / (precision + recall + epsilon)
379 |         return torch.mean(f1) * 100
380 |
381 | class Statistics(object):
382 |     '''
383 |     This class stores the running training and validation statistics
384 |     '''
385 |     def __init__(self, printer=print):
386 |         super(Statistics, self).__init__()
387 |         self.loss = 0
388 |         self.auc = 0
389 |         self.eps = 1e-9
390 |         self.counter = 0
391 |         self.printer = printer
392 |
393 |     def update(self, loss, auc):
394 |         '''
395 |         :param loss: Loss at the ith time step
396 |         :param auc: AUC at the ith time step
397 |         :return:
398 |         '''
399 |         self.loss += loss
400 |         self.auc += auc
401 |         self.counter += 1
402 |
403 |     def __str__(self):
404 |         return 'Loss: {}'.format(self.loss)
405 |
406 |     def avg_auc(self):
407 |         '''
408 |         :return: Average AUC
409 |         '''
410 |         return self.auc / self.counter
411 |
412 |
413 |     def avg_loss(self):
414 |         '''
415 |         :return: Average loss
416 |         '''
417 |         return self.loss / self.counter
418 |
419 |     def output(self, epoch, batch, n_batches, start, lr):
420 |         '''
421 |         Prints the running averages for the current epoch
422 |         :param epoch: Epoch number
423 |         :param batch: Batch number
424 |         :param n_batches: Total number of batches in the dataset
425 |         :param start: Epoch start time
426 |         :param lr: Current LR
427 |         :return:
428 |         '''
429 |         print_log_message(
430 |             "Epoch: {:3d} [{:4d}/{:4d}], "
431 |             "Loss: {:5.3f}, "
432 |             "AUC: {:3.3f}, "
433 |             "LR: {:1.6f}, "
434 |             "Elapsed time: {:5.2f} seconds".format(
435 |                 epoch, batch, n_batches,
436 |                 self.avg_loss(),
437 |                 self.avg_auc(),
438 |                 lr,
439 |                 time.time() - start
440 |             ),
441 |             printer=self.printer
442 |         )
443 |         sys.stdout.flush()
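A minimal sketch of how Statistics is meant to track running epoch averages; the loop bounds and the loss/AUC values are placeholders for a real training step:

import time

stats = Statistics()
start = time.time()
for batch_idx in range(1, 11):           # stand-in for iterating a dataloader
    loss, auc = 0.42, 0.87               # stand-in values from one training step
    stats.update(loss=loss, auc=auc)
    if batch_idx % 5 == 0:
        stats.output(epoch=1, batch=batch_idx, n_batches=10, start=start, lr=1e-4)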
--------------------------------------------------------------------------------
/utils/print_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | text_colors = {
4 |     'logs': '\033[34m',
5 |     'info': '\033[32m',
6 |     'warning': '\033[33m',
7 |     'error': '\033[31m',
8 |     'bold': '\033[1m',
9 |     'end_color': '\033[0m'
10 | }
11 |
12 |
13 | def get_curr_time_stamp():
14 |     return time.strftime("%Y-%m-%d %H:%M:%S")
15 |
16 |
17 | def print_error_message(message, printer=print):
18 |     time_stamp = get_curr_time_stamp()
19 |     error_str = text_colors['error'] + text_colors['bold'] + 'ERROR ' + text_colors['end_color']
20 |     printer('{} - {} - {}'.format(time_stamp, error_str, message))
21 |     printer('{} - {} - {}'.format(time_stamp, error_str, 'Exiting!!!'))
22 |     exit(-1)
23 |
24 |
25 | def print_log_message(message, printer=print):
26 |     time_stamp = get_curr_time_stamp()
27 |     log_str = text_colors['logs'] + text_colors['bold'] + 'LOGS ' + text_colors['end_color']
28 |     printer('{} - {} - {}'.format(time_stamp, log_str, message))
29 |
30 |
31 | def print_warning_message(message, printer=print):
32 |     time_stamp = get_curr_time_stamp()
33 |     warn_str = text_colors['warning'] + text_colors['bold'] + 'WARNING' + text_colors['end_color']
34 |     printer('{} - {} - {}'.format(time_stamp, warn_str, message))
35 |
36 |
37 | def print_info_message(message, printer=print):
38 |     time_stamp = get_curr_time_stamp()
39 |     info_str = text_colors['info'] + text_colors['bold'] + 'INFO ' + text_colors['end_color']
40 |     printer('{} - {} - {}'.format(time_stamp, info_str, message))
41 |
42 |
43 | if __name__ == '__main__':
44 |     print_log_message('Testing')
45 |     print_warning_message('Testing')
46 |     print_info_message('Testing')
47 |     print_error_message('Testing')
--------------------------------------------------------------------------------
/utils/roc_utils.py:
--------------------------------------------------------------------------------
1 | from utils.utils import ColorEncoder
2 | from sklearn.metrics import roc_curve, auc
3 | from matplotlib import pyplot as plt
4 | from numpy import interp
5 | import numpy as np
6 | import itertools
7 | from matplotlib import rcParams
8 | import pandas as pd
9 | import seaborn as sn
10 |
11 | '''
12 | This file defines functions for plotting ROC curves and confusion matrices
13 | '''
14 |
15 | rcParams['font.family'] = 'monospace'
16 | rcParams['font.size'] = 12
17 |
18 | font_main_axis = {
19 |     'weight': 'bold',
20 |     'size': 12
21 | }
22 |
23 | LINE_WIDTH = 1.5
24 |
25 | MICRO_COLOR = 'k'  # (255/255.0, 127/255.0, 0/255.0)
26 | MACRO_COLOR = 'k'  # (255/255.0, 255/255.0, 51/255.0)
27 | MICRO_LINE_STYLE = 'dashed'
28 | MACRO_LINE_STYLE = 'solid'
29 |
30 | CLASS_LINE_WIDTH = 2
31 |
32 | GRID_COLOR = (204 / 255.0, 204 / 255.0, 204 / 255.0)
33 | GRID_LINE_WIDTH = 0.25
34 | GRID_LINE_STYLE = ':'
35 |
36 |
37 | def plot_roc(ground_truth, pred_probs, n_classes, save_loc='./', file_name='dummy',
38 |              class_names=None, dataset_name='bbwsi'):
39 |     class_colors, class_linestyles = ColorEncoder().get_colors(dataset_name=dataset_name)
40 |
41 |     fpr = dict()
42 |     tpr = dict()
43 |     roc_auc = dict()
44 |     # compute ROC curve class-wise
45 |     for i in range(n_classes):
46 |         fpr[i], tpr[i], _ = roc_curve(ground_truth[:, i], pred_probs[:, i])
47 |         roc_auc[i] = auc(fpr[i], tpr[i])
48 |
49 |     # compute micro-average ROC curve and ROC area
50 |     fpr["micro"], tpr["micro"], _ = roc_curve(ground_truth.ravel(), pred_probs.ravel())
51 |     roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
52 |
53 |     # compute macro-average ROC curve and ROC area
54 |
55 |     # First aggregate all false positive rates
56 |     all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
57 |
58 |     # Then interpolate all ROC curves at these points
59 |     mean_tpr = np.zeros_like(all_fpr)
60 |     for i in range(n_classes):
61 |         mean_tpr += interp(all_fpr, fpr[i], tpr[i])
62 |
63 |     # Finally average it and compute AUC
64 |     mean_tpr /= n_classes
65 |
66 |     fpr["macro"] = all_fpr
67 |     tpr["macro"] = mean_tpr
68 |     roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
69 |
70 |     # plot the curves
71 |     micro_label = 'Micro avg. (AUC={0:0.2f})'.format(roc_auc["micro"])
72 |     plt.plot(fpr["micro"], tpr["micro"], label=micro_label, color=MICRO_COLOR,
73 |              linestyle=MICRO_LINE_STYLE, linewidth=LINE_WIDTH)
74 |
75 |     macro_label = 'Macro avg. (AUC={0:0.2f})'.format(roc_auc["macro"])
76 |     plt.plot(fpr["macro"], tpr["macro"], label=macro_label, color=MACRO_COLOR,
77 |              linestyle=MACRO_LINE_STYLE, linewidth=LINE_WIDTH)
78 |
79 |     if class_names is not None:
80 |         assert len(class_names) == n_classes
81 |         for i, c_name in enumerate(class_names):
82 |             label = "{0} (AUC={1:0.2f})".format(c_name, roc_auc[i])
83 |             plt.plot(fpr[i], tpr[i], color=class_colors[i],
84 |                      lw=CLASS_LINE_WIDTH, label=label, linestyle=class_linestyles[i])
85 |     else:
86 |         for i, color in zip(range(n_classes), class_colors):
87 |             label = 'Class {0} (AUC={1:0.2f})'.format(i, roc_auc[i])
88 |             plt.plot(fpr[i], tpr[i], color=color, lw=CLASS_LINE_WIDTH,
89 |                      label=label, linestyle=class_linestyles[i])
90 |
91 |     plt.plot([0, 1], [0, 1], 'tab:gray', linestyle='--', linewidth=1)
92 |     plt.grid(color=GRID_COLOR, linestyle=GRID_LINE_STYLE, linewidth=GRID_LINE_WIDTH)
93 |     plt.xlim([0.0, 1.0])
94 |     plt.ylim([0.0, 1.05])
95 |     plt.xlabel('False Positive Rate', fontdict=font_main_axis)
96 |     plt.ylabel('True Positive Rate', fontdict=font_main_axis)
97 |     plt.legend(edgecolor='black', loc="best")
98 |     # plt.tight_layout()
99 |     plt.savefig('{}/{}.pdf'.format(save_loc, file_name), dpi=300, bbox_inches='tight')
100 |     plt.close()
101 |
102 |
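plot_roc expects one-hot ground truth and per-class probabilities. Note that ColorEncoder (defined in utils/utils.py below) currently supplies only two colors, and only for dataset_name='bingli', so a two-class sketch is the safe illustration; all data below is synthetic:

import numpy as np

labels = np.array([0, 1, 1, 0, 1])                   # synthetic integer labels
ground_truth = np.eye(2)[labels]                     # one-hot, shape (5, 2)
pred_probs = np.random.rand(5, 2)
pred_probs /= pred_probs.sum(axis=1, keepdims=True)  # rows sum to 1
plot_roc(ground_truth, pred_probs, n_classes=2, save_loc='.', file_name='roc_demo',
         class_names=['negative', 'positive'], dataset_name='bingli')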
103 | def plot_confusion_matrix(cmat_array, class_names=None, save_loc='./', file_name='demo'):
104 |     class_names = range(cmat_array.shape[0]) if class_names is None else class_names
105 |     cmat_array = cmat_array / cmat_array.astype(float).sum(axis=1)[:, np.newaxis]  # row-normalize (np.float was removed in NumPy 1.24)
106 |
107 |     df_cm = pd.DataFrame(cmat_array, columns=class_names, index=class_names)
108 |     sn.heatmap(df_cm, cmap="Blues",
109 |                xticklabels=class_names,
110 |                annot=True,
111 |                annot_kws=font_main_axis,
112 |                square=True)
113 |
114 |     plt.yticks(np.arange(len(class_names)) + 0.5, class_names, va="center")
115 |     plt.ylim([len(class_names), 0])
116 |     plt.ylabel('True label', fontdict=font_main_axis)
117 |     plt.xlabel('Predicted label', fontdict=font_main_axis)
118 |     plt.savefig('{}/{}.pdf'.format(save_loc, file_name), dpi=300, bbox_inches='tight')
119 |     plt.close()
120 |
121 |
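The input matrix can come straight from scikit-learn; a small sketch with synthetic labels (rows are true labels, columns are predictions, matching the axis labels above):

from sklearn.metrics import confusion_matrix

y_true = [0, 1, 1, 0, 1, 1]
y_pred = [0, 1, 0, 0, 1, 1]
cmat = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cmat, class_names=['neg', 'pos'], save_loc='.', file_name='cm_demo')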
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from utils.print_utils import *
4 | import argparse
5 | import glob
6 | import torch
7 | import random
8 | import time
9 | import logging
10 | import matplotlib.pyplot as plt
11 | import os
12 |
13 | def setup_seed(seed):
14 |     np.random.seed(seed)
15 |     random.seed(seed)
16 |
17 |     os.environ['PYTHONHASHSEED'] = str(seed)
18 |     os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
19 |
20 |     torch.cuda.manual_seed(seed)
21 |     torch.cuda.manual_seed_all(seed)
22 |     torch.manual_seed(seed)
23 |
24 |     torch.use_deterministic_algorithms(True)
25 |     torch.backends.cudnn.deterministic = True
26 |     torch.backends.cudnn.enabled = False
27 |     torch.backends.cudnn.benchmark = False
28 |
29 | def build_logging(filename):
30 |     logging.basicConfig(level=logging.DEBUG,
31 |                         format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
32 |                         datefmt='%m-%d %H:%M',
33 |                         filename=filename,
34 |                         filemode='w')
35 |     console = logging.StreamHandler()
36 |     console.setLevel(logging.INFO)
37 |     formatter = logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s')
38 |     console.setFormatter(formatter)
39 |     logging.getLogger('PIL').setLevel(logging.INFO)
40 |     logging.getLogger('').addHandler(console)
41 |     return logging
42 |
43 | class NumpyEncoder(json.JSONEncoder):
44 |     def default(self, obj):
45 |         if isinstance(obj, np.integer):
46 |             return int(obj)
47 |         elif isinstance(obj, np.floating):
48 |             return float(obj)
49 |         elif isinstance(obj, np.ndarray):
50 |             return obj.tolist()
51 |         else:
52 |             return super(NumpyEncoder, self).default(obj)
53 |
54 | class ColorEncoder(object):
55 |     def __init__(self):
56 |         super().__init__()
57 |
58 |     def get_colors(self, dataset_name):
59 |         if dataset_name == 'bingli':
60 |             class_colors = [
61 |                 (228 / 255.0, 26 / 255.0, 28 / 255.0),
62 |                 (55 / 255.0, 126 / 255.0, 184 / 255.0),
63 |                 # (77 / 255.0, 175 / 255.0, 74 / 255.0),
64 |                 # (152 / 255.0, 78 / 255.0, 163 / 255.0)
65 |             ]
66 |
67 |             class_linestyle = ['solid', 'solid']
68 |
69 |             return class_colors, class_linestyle
70 |         else:
71 |             raise NotImplementedError
72 |
73 | class DictWriter(object):
74 |     def __init__(self, file_name, format='csv'):
75 |         super().__init__()
76 |         assert format in ['csv', 'json', 'txt']
77 |
78 |         self.file_name = '{}.{}'.format(file_name, format)
79 |         self.format = format
80 |
81 |     def write(self, data_dict: dict):
82 |         if self.format == 'csv':
83 |             import csv
84 |             with open(self.file_name, 'w', newline="") as csv_file:
85 |                 writer = csv.writer(csv_file)
86 |                 for key, value in data_dict.items():
87 |                     writer.writerow([key, value])
88 |         elif self.format == 'json':
89 |             import json
90 |             with open(self.file_name, 'w') as fp:
91 |                 json.dump(data_dict, fp, indent=4, sort_keys=True)
92 |         else:
93 |             with open(self.file_name, 'w') as txt_file:
94 |                 for key, value in data_dict.items():
95 |                     line = '{} : {}\n'.format(key, value)
96 |                     txt_file.write(line)
97 |
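Since DictWriter appends the extension itself, the file name should be passed without one; a minimal example:

writer = DictWriter(file_name='./metrics', format='json')
writer.write({'auc': 0.83, 'f1': 0.76})  # creates ./metrics.json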
98 | def save_checkpoint(epoch, model_state, optimizer_state, best_perf, save_dir, is_best, metric, keep_best_k_models=-1, printer=print):
99 |     best_perf = round(best_perf, 3)
100 |     checkpoint = {
101 |         'epoch': epoch,
102 |         'state_dict': model_state,
103 |         'optim_dict': optimizer_state,
104 |         'best_perf': best_perf
105 |     }
106 |     # overwrite the last checkpoint every time
107 |     ckpt_fname = '{}/checkpoint_last.pth'.format(save_dir)
108 |     torch.save(checkpoint, ckpt_fname)
109 |
110 |     # if epoch % 10 == 0:
111 |     if metric >= 0.8:
112 |         # write an extra checkpoint whenever the metric is good enough
113 |         ep_ckpt_fname = '{}/model_{:03d}.pth'.format(save_dir, epoch)
114 |         torch.save(checkpoint['state_dict'], ep_ckpt_fname)
115 |
116 |     if keep_best_k_models > 0:
117 |         checkpoint_files = glob.glob('{}/model_best_*'.format(save_dir))  # the .format() call was missing here
118 |         n_best_chkpts = len(checkpoint_files)
119 |         if n_best_chkpts >= keep_best_k_models:
120 |             # extract the performance value of the existing best checkpoints
121 |             perf_tie = dict()
122 |             for f_name in checkpoint_files:
123 |                 # first split on directory
124 |                 # second split on _
125 |                 # third split on .pth
126 |                 perf = float(f_name.split('/')[-1].split('_')[-1].split('.pth')[0])
127 |                 # in case multiple models have the same perf value
128 |                 if perf not in perf_tie:
129 |                     perf_tie[perf] = [f_name]
130 |                 else:
131 |                     perf_tie[perf].append(f_name)
132 |
133 |             min_perf_k_checks = min(list(perf_tie.keys()))
134 |
135 |             if best_perf >= min_perf_k_checks:
136 |                 best_ckpt_fname = '{}/model_best_{}_{}.pth'.format(save_dir, epoch, best_perf)
137 |                 torch.save(checkpoint['state_dict'], best_ckpt_fname)
138 |
139 |                 min_check_loc = perf_tie[min_perf_k_checks][0]  # was `min_auc`, an undefined name
140 |                 if os.path.isfile(min_check_loc):
141 |                     os.remove(min_check_loc)
142 |         else:
143 |             best_ckpt_fname = '{}/model_best_{}_{}.pth'.format(save_dir, epoch, best_perf)
144 |             torch.save(checkpoint['state_dict'], best_ckpt_fname)
145 |
146 |     # save the best checkpoint
147 |     if is_best:
148 |         best_model_fname = '{}/model_best.pth'.format(save_dir)
149 |         torch.save(model_state, best_model_fname)
150 |         print_info_message('Checkpoint saved at: {}'.format(best_model_fname), printer)
151 |
152 |     # print_info_message('Checkpoint saved at: {}'.format(ep_ckpt_fname), printer)
153 |
154 |
155 | def load_checkpoint(ckpt_fname, device='cpu'):
156 |     # ckpt_fname = '{}/checkpoint_last.pth'.format(checkpoint_dir)
157 |     model_state = torch.load(ckpt_fname, map_location=device)
158 |     return model_state
159 |
160 |     # epoch = checkpoint['epoch']
161 |     # model_state = checkpoint['state_dict']
162 |     # optim_state = checkpoint['optim_dict']
163 |     # best_perf = checkpoint['best_perf']
164 |     # return epoch, model_state, optim_state, best_perf
165 |
166 | def save_arguments(args, save_loc, json_file_name='arguments.json', printer=print):
167 |     argparse_dict = vars(args)
168 |     arg_fname = '{}/{}'.format(save_loc, json_file_name)
169 |     writer = DictWriter(file_name=arg_fname.rsplit('.', 1)[0], format='json')  # DictWriter appends '.json' itself
170 |     writer.write(argparse_dict)
171 |     print_log_message('Arguments are dumped here: {}'.format(arg_fname), printer)
172 |
173 |
174 | def load_arguments(parser, dumped_arg_loc, json_file_name='arguments.json'):
175 |     arg_fname = '{}/{}'.format(dumped_arg_loc, json_file_name)
176 |     parser = argparse.ArgumentParser(parents=[parser], add_help=False)
177 |     with open(arg_fname, 'r') as fp:
178 |         json_dict = json.load(fp)
179 |         parser.set_defaults(**json_dict)
180 |
181 |     updated_args = parser.parse_args()
182 |
183 |     return updated_args
184 |
185 |
186 | def load_arguments_file(parser, arg_fname):
187 |     parser = argparse.ArgumentParser(parents=[parser], add_help=False)
188 |     with open(arg_fname, 'r') as fp:
189 |         json_dict = json.load(fp)
190 |         parser.set_defaults(**json_dict)
191 |     updated_args = parser.parse_args()
192 |
193 |     return updated_args
194 |
195 | def plot_results(res_dict, plot_file):
196 |     N = len(list(res_dict.keys()))
197 |     _, axarr = plt.subplots(1, N, figsize=(5*N, 5))
198 |     axarr = np.atleast_1d(axarr)  # plt.subplots returns a bare Axes (not an array) when N == 1
199 |     for i, (key, value) in enumerate(res_dict.items()):
200 |         axarr[i].plot(range(len(value)), value, label=key)
201 |         axarr[i].legend()
202 |     plt.savefig(plot_file)
203 |     plt.close()
--------------------------------------------------------------------------------
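To close, a minimal usage sketch for plot_results from utils/utils.py above; the tracked series are placeholders:

history = {'train_loss': [0.9, 0.7, 0.55], 'val_auc': [0.61, 0.70, 0.74]}
plot_results(history, plot_file='./curves.png')  # one panel per tracked series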