├── README.md
├── config.json
├── ecg_preprocessing.py
├── model_training.py
└── modelbuild.py


/README.md:
--------------------------------------------------------------------------------
# Python Scripts

These are the Python scripts for the data preprocessing and modelling of the following paper in *The Lancet Digital Health*:

Zhu, Hongling, et al. "Automatic multilabel electrocardiogram diagnosis of heart rhythm or conduction abnormalities with deep learning: a cohort study." The Lancet Digital Health (2020). [Open access.](https://www.thelancet.com/journals/landig/article/PIIS2589-7500(20)30107-2/fulltext)

BibTeX citation:

    @article{zhu2020automatic,
      title={Automatic multilabel electrocardiogram diagnosis of heart rhythm or conduction abnormalities with deep learning: a cohort study},
      author={Zhu, Hongling and Cheng, Cheng and Yin, Hang and Li, Xingyi and Zuo, Ping and Ding, Jia and Lin, Fan and Wang, Jingyi and Zhou, Beitong and Li, Yonge and others},
      journal={The Lancet Digital Health},
      year={2020},
      publisher={Elsevier}
    }


## Files

* `model_training.py`
  * Script for training the diagnosis network
* `ecg_preprocessing.py`
  * Script for the pre-processing of the ECG recordings
* `modelbuild.py`
  * Network structure of the multi-label diagnosis model
* `config.json`
  * Root directories and hyperparameters


## Test dataset

The test dataset from Tongji Hospital of this study is publicly available at [Mendeley Data](https://data.mendeley.com/datasets/6jd4rn2z9x/1).

The independent China Physiological Signal Challenge dataset is publicly available at [http://2018.icbeb.org/Challenge.html](http://2018.icbeb.org/Challenge.html).
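
## Quick start

A minimal sketch of how the pieces fit together, mirroring the ten-fold loop in `model_training.py` (it assumes the `.npy` data and label folders referenced in `config.json` have already been produced with `ecg_preprocessing.py`):

    import json
    from ecg_preprocessing import all_split
    from modelbuild import model_build, model_train

    params = json.load(open('config.json', 'r'))

    # ten folds of file paths and {path: multi-hot label} dictionaries
    all_id, all_label = all_split(params["disease_num"], params)
    val_id, val_label = all_id[0], all_label[0]
    train_id, train_label = [], {}
    for i in range(1, 10):
        train_id += all_id[i]
        train_label.update(all_label[i])

    model, parallel_model = model_build(params)
    model, parallel_model = model_train(model, parallel_model, train_id,
                                        train_label, val_id, val_label, params)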
--------------------------------------------------------------------------------


/config.json:
--------------------------------------------------------------------------------
{
    "disease_num": 21,
    "conv_filter_size": 15,
    "conv_num_filters": [64, 128, 256, 512],
    "identity_block_num": [1, 1, 1, 1],
    "conv_init": "he_normal",
    "dropout": 0.6,
    "leaky_relu": 0.1,
    "dense_neurons": 512,

    "learning_rate": 0.001,
    "batch_size": 64,
    "l2": 0.01,

    "ecg_root_path": "/mnt/data/ECG_data/ECG/",

    "multilabel_save_folder": "/mnt/data/ECG_data/multilabel_new/",
    "multilabel_data_folder": "/mnt/data/ECG_data/multilabel_new/data/",
    "multilabel_label_folder": "/mnt/data/ECG_data/multilabel_new/label/",
    "multilabel_cardiologist_test_data": "/mnt/data/ECG_data/multilabel_test/data/",
    "multilabel_cardiologist_test_label": "/mnt/data/ECG_data/multilabel_test/label/",

    "multiclass_save_folder": "/mnt/data/ECG_data/multiclass/",
    "multiclass_data_folder": "/mnt/data/ECG_data/multiclass/data/",
    "multiclass_label_folder": "/mnt/data/ECG_data/multiclass/label/",
    "multiclass_cardiologist_test_data": "/mnt/data/ECG_data/multiclass_test/data/",
    "multiclass_cardiologist_test_label": "/mnt/data/ECG_data/multiclass_test/label/",

    "abbr_list": ["Normal", "ST", "SB", "PAC", "AR", "AT", "AFlutter", "AFib", "PJC", "JR", "PSVT", "PVC", "IVR", "VT", "AAPR", "AVPR", "LBBB", "1st Deg AVB", "Mobitz I AVB", "WPW-B", "WPW-A"],

    "gpu": 2,

    "n_folds": 5
}
--------------------------------------------------------------------------------


/ecg_preprocessing.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import xlrd
from tqdm import tqdm_notebook
from multiprocessing import Manager, Pool
import os
import datetime
from sklearn.utils import shuffle
from keras.utils import to_categorical
import random


class load_ecg:
    def __init__(self, ecg_dir):
        # Manager lists are shared with the worker processes in the pool
        manager = Manager()
        self.data = manager.list()
        self.i_list = manager.list()
        self.file_path = ecg_dir
        self.file_list = os.listdir(self.file_path)

    def read_data_with_skiprows(self, work):
        wb = xlrd.open_workbook(work[1] + self.file_list[work[0]], logfile=open(os.devnull, 'w'))
        temp = np.array(pd.read_excel(wb, skiprows=[0, 1, 2], engine='xlrd'))
        self.data.append(temp)
        self.i_list.append(work[0])

    def read_data_without_skiprows(self, work):
        wb = xlrd.open_workbook(work[1] + self.file_list[work[0]], logfile=open(os.devnull, 'w'))
        temp = np.array(pd.read_excel(wb, engine='xlrd'))
        self.data.append(temp)
        self.i_list.append(work[0])

    def load_data(self, saved_name, normalized, skiprows, save):
        def normalize(data):
            # Min-max normalization per lead (column)
            mx = np.max(data, axis=0)
            mn = np.min(data, axis=0)
            return (data - mn) / (mx - mn)

        print('Start loading data from ' + self.file_path + '...')
        items = zip(range(len(self.file_list)), [self.file_path] * len(self.file_list))
        p = Pool(30)
        if skiprows:
            list(tqdm_notebook(p.imap(self.read_data_with_skiprows, items), total=len(self.file_list)))
        else:
            list(tqdm_notebook(p.imap(self.read_data_without_skiprows, items), total=len(self.file_list)))
        p.close()
        p.join()
        print('Loading data done...')

        # The pool returns results out of order; restore the original file order
        data_set = [None] * len(self.file_list)
        for i in range(len(self.file_list)):
            data_set[self.i_list[i]] = self.data[i]

        # Keep the raw data and build a separate normalized copy
        norm_data_set = data_set
        if normalized:
            print("Start normalizing...")
            norm_data_set = [normalize(d) for d in tqdm_notebook(data_set)]
            print("Normalizing done...")

        if save:
            print('Start saving files...')
            np.save(saved_name + '_' + datetime.date.today().strftime("%m%d"), data_set)
            np.save(saved_name + '_norm_' + datetime.date.today().strftime("%m%d"), norm_data_set)
            np.save(saved_name + '_filelist', self.file_list)
            print('Saving files done...')

        return np.array(data_set), np.array(norm_data_set), self.file_list


def get_dataset(dir_list, saved_name_list, skiprows, saved_bool, normalized=True):
    for i in tqdm_notebook(range(len(dir_list))):
        data_loader = load_ecg(dir_list[i])
        data_loader.load_data(saved_name_list[i], normalized=normalized,
                              skiprows=bool(skiprows[i]), save=saved_bool[i])
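
# Usage sketch for get_dataset (the folder and archive names here are
# illustrative placeholders, not files shipped with the repository):
#
#   get_dataset(dir_list=['/mnt/data/ECG_data/ECG/0-Normal/',
#                         '/mnt/data/ECG_data/ECG/1-ST/'],
#               saved_name_list=['normal', 'st'],
#               skiprows=[True, False],   # whether the first 3 sheet rows are headers
#               saved_bool=[True, True])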

def multilabel_ecg_loader(ecg_dir, label_dir, n_classes):
    def random_split(data_set, file_list):
        # Hold out a random 5% of each class as the test split
        index = np.arange(len(data_set))
        np.random.shuffle(index)
        n_test = int(len(data_set) * 0.05)
        test_index = index[:n_test]
        train_index = index[n_test:]
        test = data_set[test_index]
        train = data_set[train_index]
        test_file_list = file_list[test_index]
        train_file_list = file_list[train_index]
        return train, train_file_list, test, test_file_list

    def over_sampling(data_set, file_list, n_target):
        # Repeat minority-class recordings until n_target samples are reached
        delta = n_target - data_set.shape[0]
        fold = delta // len(data_set)
        index = np.repeat(np.arange(len(data_set)), fold + 2)
        np.random.shuffle(index)
        return data_set[index[:n_target]], file_list[index[:n_target]]

    def multilabel_marker(target, file_list_dirs, n_classes):
        # Build the multi-hot label matrix: label[i][j] = 1 if recording i
        # appears in the file list of class j
        note_taker = []
        for j in range(n_classes):
            note_taker.append(np.load(file_list_dirs[j]))

        label = np.zeros((target.shape[0], n_classes))
        for i in range(target.shape[0]):
            for j in range(n_classes):
                if target[i] in note_taker[j]:
                    label[i][j] = 1

        return label

    def nan_sweeper(data_set, label):
        # Drop recordings that contain NaN values
        nan_index = []
        for i in range(data_set.shape[0]):
            if np.isnan(data_set[i]).any():
                nan_index.append(i)
        data_set = np.delete(data_set, nan_index, axis=0)
        label = np.delete(label, nan_index, axis=0)

        return data_set, label

    x_train = np.array([])
    x_test = np.array([])
    y_train = np.array([])
    y_test = np.array([])

    print('Start loading ECGs...')
    for i in tqdm_notebook(range(len(ecg_dir))):
        print('Loading ECGs from ' + ecg_dir[i] + '...')
        data_set = np.load(ecg_dir[i], mmap_mode='r')
        file_list = np.load(label_dir[i], mmap_mode='r')
        train, train_file_list, test, test_file_list = random_split(data_set, file_list)
        if train.shape[0] < 10000:
            train, train_file_list = over_sampling(train, train_file_list, n_target=10000)
        x_train = np.append(x_train, train)
        x_test = np.append(x_test, test)
        y_train = np.append(y_train, multilabel_marker(train_file_list, label_dir, n_classes))
        y_test = np.append(y_test, multilabel_marker(test_file_list, label_dir, n_classes))

    x_train, y_train = nan_sweeper(x_train.reshape(-1, 5000, 12), y_train.reshape(-1, n_classes))
    x_test, y_test = nan_sweeper(x_test.reshape(-1, 5000, 12), y_test.reshape(-1, n_classes))

    x_train, y_train = shuffle(x_train, y_train)
    print('Loading ECGs done...')

    return x_train, y_train, x_test, y_test
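
# Usage sketch for multilabel_ecg_loader (archive names are placeholders):
# ecg_dir holds the per-class ECG archives saved by get_dataset, label_dir the
# matching *_filelist.npy archives used to mark the multi-hot labels.
#
#   x_train, y_train, x_test, y_test = multilabel_ecg_loader(
#       ecg_dir=['normal_norm_0805.npy', 'st_norm_0805.npy'],
#       label_dir=['normal_filelist.npy', 'st_filelist.npy'],
#       n_classes=2)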

def multiclasses_ecg_loader(ecg_dir):
    def random_split(data_set):
        # Hold out a random 5% of each class as the test split
        index = np.arange(len(data_set))
        np.random.shuffle(index)
        n_test = int(len(data_set) * 0.05)
        test_index = index[:n_test]
        train_index = index[n_test:]
        test = data_set[test_index]
        train = data_set[train_index]
        return train, test

    def over_sampling(data_set, n_target):
        # Repeat minority-class recordings until n_target samples are reached
        delta = n_target - data_set.shape[0]
        fold = delta // len(data_set)
        index = np.repeat(np.arange(len(data_set)), fold + 2)
        np.random.shuffle(index)
        return data_set[index[:n_target]]

    def nan_sweeper(data_set, label):
        # Drop recordings that contain NaN values
        nan_index = []
        for i in range(data_set.shape[0]):
            if np.isnan(data_set[i]).any():
                nan_index.append(i)
        data_set = np.delete(data_set, nan_index, axis=0)
        label = np.delete(label, nan_index, axis=0)

        return data_set, label

    x_train = np.array([])
    x_test = np.array([])
    y_train = np.array([])
    y_test = np.array([])

    print('Start loading ECGs...')
    for i in tqdm_notebook(range(len(ecg_dir))):
        print('Loading ECGs from ' + ecg_dir[i] + '...')
        data_set = np.load(ecg_dir[i], mmap_mode='r')

        train, test = random_split(data_set)
        if train.shape[0] < 10000:
            train = over_sampling(train, n_target=10000)
        x_train = np.append(x_train, train)
        x_test = np.append(x_test, test)
        # The class label of every recording in archive i is simply i
        y_train = np.append(y_train, np.ones(train.shape[0], dtype=np.int64) * i)
        y_test = np.append(y_test, np.ones(test.shape[0], dtype=np.int64) * i)
    y_train = to_categorical(y_train, num_classes=len(ecg_dir))
    y_test = to_categorical(y_test, num_classes=len(ecg_dir))

    x_train, y_train = nan_sweeper(x_train.reshape(-1, 5000, 12), y_train.reshape(-1, len(ecg_dir)))
    x_test, y_test = nan_sweeper(x_test.reshape(-1, 5000, 12), y_test.reshape(-1, len(ecg_dir)))

    x_train, y_train = shuffle(x_train, y_train)
    print('Loading ECGs done...')

    return x_train, y_train, x_test, y_test

def val_choose(num, data):
    xmlfolder = os.listdir('/mnt/data/ECG_data/ECG_xml')
    datafolder = os.listdir('/mnt/data/ECG_data/ECG')
    xmlfolder.sort()
    datafolder.sort()
    xml_list = os.listdir('/mnt/data/ECG_data/ECG_xml/' + xmlfolder[num])
    data_list = os.listdir('/mnt/data/ECG_data/ECG/' + datafolder[num])
    dic = {}
    data_batch = []
    if num == 0:
        # ECGs from the GE machine: take a random 10% as validation
        random.shuffle(data_list)
        for i in data_list:
            if len(data_batch) <= len(data_list) * 0.1:
                data_batch.append(i[:-5] + '.npy')
            else:
                break
        return data_batch

        # Unreached alternative: group GE recordings by the patient id in the
        # corresponding XML file name, so one patient never spans both splits
        for l in xml_list:
            s = l[:-4].split('_')
            dic[s[1]] = dic.get(s[1], [])
            dic[s[1]].append(s[0] + 'ECG.npy')

        values = list(dic.values())
        random.shuffle(values)
        for value in values:
            if len(data_batch) <= len(data_list) * 0.1:
                for i in value:
                    if i in data:
                        data_batch.append(i)
            else:
                break
    else:
        # ECGs from the Holter machine: group recordings by patient id so one
        # patient never spans both splits
        dic = {}
        for i in data_list:
            if "_" in i:
                s = i[:-4].split('_')
                dic[s[0]] = dic.get(s[0], [])
                dic[s[0]].append(i[:-4] + '.npy')
            elif "-" in i:
                s = i[:-5].split('-')
                dic[s[0]] = dic.get(s[0], [])
                dic[s[0]].append(i[:-5] + '.npy')
        values = list(dic.values())
        random.shuffle(values)
        for value in values:
            if len(data_batch) <= len(data_list) * 0.1:
                for i in value:
                    if i in data:
                        data_batch.append(i)
            else:
                break
        else:
            # The loop ran out of patients before the 10% batch was filled
            print(len(data_batch))
            print(datafolder[num])
            print(xmlfolder[num])
            raise RuntimeError('Could not fill the validation batch for ' + datafolder[num])
    return data_batch


def val_split(class_num, params):
    data = os.listdir(params["multilabel_data_folder"])
    val_data = []
    for i in tqdm_notebook(range(class_num)):
        val_data = val_data + val_choose(i, data)
    # Iterate over a copy, since val_data may be mutated inside the loop
    for i in tqdm_notebook(val_data[:]):
        if i in data:
            data.remove(i)
        else:
            val_data.remove(i)

    train_label = np.array([np.load(params["multilabel_label_folder"] + i) for i in tqdm_notebook(data)])
    val_label = np.array([np.load(params["multilabel_label_folder"] + i) for i in tqdm_notebook(val_data)])

    train_id = [params["multilabel_data_folder"] + i for i in data]
    val_id = [params["multilabel_data_folder"] + i for i in val_data]

    train_label = dict(zip(train_id, train_label))
    val_label = dict(zip(val_id, val_label))

    return train_id, train_label, val_id, val_label


def multiclass_val_split(class_num, params):
    data = os.listdir(params["multilabel_data_folder"])
    val_data = []
    for i in tqdm_notebook(range(class_num)):
        val_data = val_data + val_choose(i, data)
    for i in tqdm_notebook(val_data[:]):
        if i in data:
            data.remove(i)
        else:
            val_data.remove(i)

    # Keep only recordings whose multilabel annotation is one-hot, i.e. exactly
    # one diagnosis, so they can serve as multiclass samples
    train_label = []
    train_data_onehot = []
    print('loading train_label')
    for i in tqdm_notebook(data):
        temp_label = np.load(params["multilabel_label_folder"] + i)
        if sum(temp_label) == 1:
            train_label.append(temp_label)
            train_data_onehot.append(i)

    val_label = []
    val_data_onehot = []
    print('loading val_label')
    for i in tqdm_notebook(val_data):
        temp_label = np.load(params["multilabel_label_folder"] + i)
        if sum(temp_label) == 1:
            val_label.append(temp_label)
            val_data_onehot.append(i)

    train_label = np.array(train_label)
    val_label = np.array(val_label)

    train_id = [params["multilabel_data_folder"] + i for i in train_data_onehot]
    val_id = [params["multilabel_data_folder"] + i for i in val_data_onehot]

    train_label = dict(zip(train_id, train_label))
    val_label = dict(zip(val_id, val_label))

    return train_id, train_label, val_id, val_label
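
# Usage sketch: val_split builds a roughly 9:1 train/validation split keyed by
# file path, e.g.
#
#   train_id, train_label, val_id, val_label = val_split(21, params)
#   # train_id: list of .npy paths; train_label: {path: multi-hot label}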

# 9:1
def all_choose(num, data):
    xmlfolder = os.listdir('/mnt/data/ECG_data/ECG_xml')
    datafolder = os.listdir('/mnt/data/ECG_data/ECG')
    xmlfolder.sort()
    datafolder.sort()
    xml_list = os.listdir('/mnt/data/ECG_data/ECG_xml/' + xmlfolder[num])
    data_list = os.listdir('/mnt/data/ECG_data/ECG/' + datafolder[num])
    dic = {}
    counter = 0
    # Distribute this class's recordings round-robin over 10 folds
    data_batch = [[] for _ in range(10)]
    normdatasum = 0
    tensecondsdatasum = 0
    print(datafolder[num])
    if num == 0:
        # ECGs from the GE machine
        random.shuffle(data_list)
        for i in data_list:
            # Skip folds that already hold more than 10% of this class
            while len(data_batch[counter % 10]) > len(data_list) * 0.1:
                counter = counter + 1
            data_batch[counter % 10].append(i[:-5] + '.npy')
            normdatasum = normdatasum + 1
            counter = counter + 1

        return data_batch

        # Unreached alternative: group GE recordings by XML patient id
        for l in xml_list:
            s = l[:-4].split('_')
            dic[s[1]] = dic.get(s[1], [])
            dic[s[1]].append(s[0] + 'ECG.npy')

        values = list(dic.values())
        random.shuffle(values)
        for value in values:
            while len(data_batch[counter % 10]) > len(data_list) * 0.1:
                counter = counter + 1
            for i in value:
                if i in data:
                    data_batch[counter % 10].append(i)
                    normdatasum = normdatasum + 1
            counter = counter + 1
    else:
        # ECGs from the Holter machine: keep one patient inside one fold
        dic = {}
        for i in data_list:
            if "_" in i:
                s = i[:-4].split('_')
                dic[s[0]] = dic.get(s[0], [])
                dic[s[0]].append(i[:-4] + '.npy')
            elif "-" in i:
                s = i[:-5].split('-')
                dic[s[0]] = dic.get(s[0], [])
                dic[s[0]].append(i[:-5] + '.npy')
        values = list(dic.values())
        random.shuffle(values)
        for value in values:
            while len(data_batch[counter % 10]) > len(data_list) * 0.1:
                counter = counter + 1
            for i in value:
                if i in data:
                    data_batch[counter % 10].append(i)
                    tensecondsdatasum = tensecondsdatasum + 1
            counter = counter + 1

    return data_batch


def all_split(class_num, params):
    data = os.listdir('/mnt/data/ECG_data/multilabel/data')
    all_data = [[] for _ in range(10)]
    all_label = [[] for _ in range(10)]
    all_id = [[] for _ in range(10)]
    for i in tqdm_notebook(range(class_num)):
        temp = all_choose(i, data)
        for j in tqdm_notebook(range(10)):
            # Iterate over a copy, since the fold may be mutated inside the loop
            for onedata in temp[j][:]:
                if onedata not in data:
                    temp[j].remove(onedata)
                    print('warning:', onedata)
            all_data[j] = all_data[j] + temp[j]

    for i in tqdm_notebook(range(10)):
        all_id[i] = [params["multilabel_data_folder"] + name for name in all_data[i]]
        all_label[i] = np.array([np.load(params["multilabel_label_folder"] + name) for name in all_data[i]])
        all_label[i] = dict(zip(all_id[i], all_label[i]))

    return all_id, all_label
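
# Usage sketch: all_split partitions every recording into 10 folds (recordings
# of one Holter patient stay inside one fold), which model_training.py rotates
# through as held-out sets:
#
#   all_id, all_label = all_split(21, params)
#   all_id[0]     # list of .npy paths in fold 0
#   all_label[0]  # {path: multi-hot label} for fold 0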

def load_cardiologist_test_set(data_dir, label_dir):
    ctd = [data_dir + i for i in sorted(os.listdir(data_dir))]
    ctl = [np.load(label_dir + i) for i in sorted(os.listdir(label_dir))]
    ctl = dict(zip(ctd, ctl))

    X_cdolg_test = []
    y_cdolg_test = []
    for i in ctd:
        X_cdolg_test.append(np.load(i))
    for i in range(len(ctd)):
        y_cdolg_test.append(ctl[ctd[i]])

    return np.array(X_cdolg_test), np.array(y_cdolg_test)
--------------------------------------------------------------------------------


/model_training.py:
--------------------------------------------------------------------------------
%matplotlib inline
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
from keras_tqdm import TQDMNotebookCallback
from keras_radam import RAdam
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score, classification_report
from modelbuild import model_build, model_train, model_save, model_load, plot_roc, plot_confusion_matrix
from ecg_preprocessing import val_split, load_cardiologist_test_set, multiclass_val_split, all_split
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import sem, t

import json
import time
from collections import Counter
from datetime import date
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'  # "Best-fit with coalescing", simplified from a version of dlmalloc
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))

tf.test.is_gpu_available()

# Load parameters
params = json.load(open('config.json', 'r'))

# Load the data and labels: ten folds of file IDs and label dictionaries
all_id, all_label = all_split(21, params)

# Model training (the model is trained 10 times, once per held-out fold,
# to calculate the mean and confidence intervals of the metrics)
for j in tqdm_notebook(range(10)):
    val_id = all_id[j]
    val_label = all_label[j]
    train_id = []
    train_label = {}
    for i in tqdm_notebook(range(10)):
        if i != j:
            train_id = train_id + all_id[i]
            train_label.update(all_label[i])
    model, parallel_model = model_build(params)
    model, parallel_model = model_train(model, parallel_model, train_id, train_label, val_id, val_label, params)
    model.save('multilabel_model_' + str(j + 1) + '.h5')  # save the model of this run
    time.sleep(1800)  # pause 30 minutes between runs
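
# A sketch, not part of the original pipeline: the imported scipy.stats
# helpers sem and t can turn the 10 per-run scores (e.g. AUCs) into a mean
# and a 95% confidence interval.
#
#   aucs = np.array([...])  # one score per trained model
#   h = sem(aucs) * t.ppf(0.975, len(aucs) - 1)
#   print('mean %.4f, 95%% CI [%.4f, %.4f]' % (aucs.mean(), aucs.mean() - h, aucs.mean() + h))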
--------------------------------------------------------------------------------


/modelbuild.py:
--------------------------------------------------------------------------------
import keras.backend as K
from keras.models import Model, load_model
from keras.layers import Dense, Dropout, Flatten, Activation, MaxPooling1D, Conv1D, Input, add, BatchNormalization, AveragePooling1D, LeakyReLU
from keras.optimizers import Adam, SGD
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, TensorBoard, Callback
from keras.utils import multi_gpu_model, Sequence, to_categorical
from keras.losses import binary_crossentropy
from keras.regularizers import l2
from sklearn.metrics import roc_curve, auc, confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split, StratifiedKFold
import os
import json
from collections import Counter
from datetime import date

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
from keras_tqdm import TQDMNotebookCallback
from keras_radam import RAdam


def model_build(params):
    def conv_block(input_data, n_filters, filter_size, index):
        # Main path: two convolutions, the first downsampling by stride 2
        x = Conv1D(n_filters, filter_size, strides=2, padding='same',
                   name='conv_block' + str(index) + '_' + 'conv_1',
                   kernel_initializer=params["conv_init"])(input_data)
        x = BatchNormalization(name='conv_block' + str(index) + '_' + 'BN_1')(x)
        x = Activation('relu', name='conv_block' + str(index) + '_' + 'relu_1')(x)
        x = Conv1D(n_filters, filter_size, strides=1, padding='same',
                   name='conv_block' + str(index) + '_' + 'conv_2',
                   kernel_initializer=params["conv_init"])(x)
        x = BatchNormalization(name='conv_block' + str(index) + '_' + 'BN_2')(x)
        x = Activation('relu', name='conv_block' + str(index) + '_' + 'relu_2')(x)

        # Shortcut path: a strided convolution so the shapes match for the add
        shortcut = Conv1D(n_filters, filter_size, strides=2, padding='same',
                          name='conv_block' + str(index) + '_' + 'shortcut_conv',
                          kernel_initializer=params["conv_init"])(input_data)
        shortcut = BatchNormalization(name='conv_block' + str(index) + '_' + 'shortcut_BN')(shortcut)
        x = add([x, shortcut], name='conv_block' + str(index) + '_' + 'add')
        x = Activation('relu', name='conv_block' + str(index) + '_' + 'relu_3')(x)

        return x

    def identity_block(input_data, n_filters, filter_size, index):
        # Residual block that keeps the temporal resolution unchanged
        x = Conv1D(n_filters, filter_size, strides=1, padding='same',
                   name='identity_block' + str(index) + '_' + 'conv_1',
                   kernel_initializer=params["conv_init"])(input_data)
        x = BatchNormalization(name='identity_block' + str(index) + '_' + 'BN_1')(x)
        x = Activation('relu', name='identity_block' + str(index) + '_' + 'relu_1')(x)
        x = add([x, input_data], name='identity_block' + str(index) + '_' + 'add')
        x = Activation('relu', name='identity_block' + str(index) + '_' + 'relu_2')(x)

        return x

    # Input: 10-second, 12-lead ECG recordings (5000 samples per lead)
    input_ecg = Input(shape=(5000, 12), name='input')
    x = Conv1D(filters=params["conv_num_filters"][0], kernel_size=15,
               strides=2, padding='same', kernel_initializer=params["conv_init"], name='conv_2')(input_ecg)
    x = BatchNormalization(name='BN_2')(x)
    x = Activation('relu', name='relu_2')(x)
    x = MaxPooling1D(name='max_pooling_1')(x)

    for i in range(4):
        x = conv_block(x, n_filters=params["conv_num_filters"][i], filter_size=params["conv_filter_size"], index=i + 1)
        x = MaxPooling1D(name='max_pooling_' + str(i + 2))(x)
        x = identity_block(x, n_filters=params["conv_num_filters"][i],
                           filter_size=params["conv_filter_size"], index=i + 1)

    x = AveragePooling1D(name='average_pooling')(x)
    x = Flatten(name='flatten')(x)
    x = Dense(params["dense_neurons"], kernel_regularizer=l2(params["l2"]), name='FC1')(x)
    x = Activation('relu', name='relu_3')(x)
    x = Dropout(rate=params["dropout"])(x)
    x = Dense(params["dense_neurons"], kernel_regularizer=l2(params["l2"]), name='FC2')(x)
    x = Activation('relu', name='relu_4')(x)
    x = Dropout(rate=params["dropout"])(x)
    # One sigmoid per abnormality for the multi-label output
    x = Dense(params["disease_num"], activation='sigmoid', name='output')(x)

    model = Model(inputs=input_ecg, outputs=x)
    parallel_model = multi_gpu_model(model, params["gpu"])

    return model, parallel_model
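
# Quick sanity check for the architecture (assumes config.json is in the
# working directory and params["gpu"] GPUs are visible to multi_gpu_model):
#
#   params = json.load(open('config.json', 'r'))
#   model, parallel_model = model_build(params)
#   model.summary()            # input (None, 5000, 12)
#   print(model.output_shape)  # (None, 21): one sigmoid score per label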

def multilabel_loss(y_true, y_pred):
    # Sum the per-label binary cross-entropies over the batch
    return K.sum(binary_crossentropy(y_true, y_pred))


def model_train(model, parallel_model, train_id, train_label, val_id, val_label, params):

    class DataGenerator(Sequence):
        """
        Generate data for fit_generator.
        """
        def __init__(self, data_ids, labels, batch_size, n_classes, shuffle=True):
            self.data_ids = data_ids
            self.labels = labels
            self.batch_size = batch_size
            self.n_classes = n_classes
            self.shuffle = shuffle
            self.on_epoch_end()

        def __len__(self):
            """
            Denote the number of batches per epoch.
            """
            return int(len(self.data_ids) / self.batch_size)

        def __getitem__(self, index):
            """
            Generate one batch of data.
            """
            # Generate the indexes of the batch
            indexes = self.indexes[index * self.batch_size: (index + 1) * self.batch_size]

            # Find the list of IDs
            data_ids_temp = [self.data_ids[k] for k in indexes]

            # Generate data
            X, y = self.__data_generation(data_ids_temp)

            return X, y

        def on_epoch_end(self):
            """
            Update indexes after each epoch.
            """
            self.indexes = np.arange(len(self.data_ids))
            if self.shuffle:
                np.random.shuffle(self.indexes)

        def __data_generation(self, data_ids_temp):
            """
            Generate data containing batch_size samples.
            """
            X = np.empty((self.batch_size, 5000, 12))
            y = np.empty((self.batch_size, self.n_classes), dtype=int)
            for i, ID in enumerate(data_ids_temp):
                X[i] = np.load(ID)
                y[i] = self.labels[ID]

            return X, y

    # Get class_weight to compensate for the class imbalance; the class folders
    # are named like "<index>-<abbr>", so the leading number is the class id
    tmp = {i: len(os.listdir(os.path.join(params["ecg_root_path"], j))) for i, j in enumerate(sorted(
        os.listdir(params["ecg_root_path"]), key=lambda x: int(x[0]) if x[1] == '-' else int(x[:2])))}
    counter = Counter(tmp)
    max_val = float(max(counter.values()))
    class_weight = {class_id: max_val / num_ecg for class_id, num_ecg in counter.items()}

    parallel_model.compile(loss=multilabel_loss, optimizer=RAdam(lr=params["learning_rate"]),
                           metrics=['accuracy'])
    my_callbacks = [EarlyStopping(monitor='val_loss', patience=8, verbose=2),
                    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=4,
                                      min_lr=1e-8, verbose=1),
                    TQDMNotebookCallback(leave_inner=True, leave_outer=True)]
                    # TensorBoard(log_dir='./logs', histogram_freq=0, write_graph=True, write_images=True)]

    parallel_model.fit_generator(generator=DataGenerator(train_id, train_label,
                                                         batch_size=params["batch_size"],
                                                         n_classes=params["disease_num"]),
                                 use_multiprocessing=True,
                                 workers=45,
                                 epochs=30,
                                 validation_data=DataGenerator(val_id, val_label,
                                                               batch_size=params["batch_size"],
                                                               n_classes=params["disease_num"]),
                                 steps_per_epoch=int(len(train_id) / params["batch_size"]),
                                 callbacks=my_callbacks,
                                 verbose=0,
                                 class_weight=class_weight)

    return model, parallel_model

def model_save(model):
    today = date.today()
    # Save the whole model
    model.save('multilabel_model_' + today.strftime("%m%d") + '.h5')
    # Save the weights as well
    model.save_weights('multilabel_model_weights_' + today.strftime("%m%d") + '.h5')


def model_load(h5_name):
    # Load the whole model
    model = load_model(h5_name)
    # If necessary, you can instead create a new model from the saved weights:
    # first build a new model, then load the weights, e.g.
    # model.load_weights('model_weights_0805.h5')
    return model


def model_eval(model, test_id, test_label, params):
    def plot_roc(name, labels, predict_prob, cur_clr):
        fp_rate, tp_rate, thresholds = roc_curve(labels, predict_prob)
        roc_auc = auc(fp_rate, tp_rate)
        plt.title('ROC')
        plt.plot(fp_rate, tp_rate, cur_clr, label=name + "'s AUC = %0.4f" % roc_auc)
        plt.legend(loc='lower right')
        plt.plot([0, 1], [0, 1], 'r--')
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.ylabel('TPR')
        plt.xlabel('FPR')

    def plot_confusion_matrix(name, cm, title='Confusion Matrix', cmap='Blues'):
        labels = ['Non-' + name, name]
        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        xlocations = np.array(range(len(labels)))
        plt.xticks(xlocations, labels, rotation=30)
        plt.yticks(xlocations, labels)
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

        cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        ind_array = np.arange(len(labels))
        x, y = np.meshgrid(ind_array, ind_array)

        for x_val, y_val in zip(x.flatten(), y.flatten()):
            c = cm_normalized[y_val][x_val]
            plt.text(x_val, y_val, "%0.4f" % (c,), color='black', fontsize=15, va='center', ha='center')

    # Visualize the classification result
    # First load the test set into memory
    X_test = []
    y_test = []
    for i in test_id:
        X_test.append(np.load(i))
    for i in range(len(test_id)):
        y_test.append(test_label[test_id[i]])

    X_test = np.array(X_test)
    y_test = np.array(y_test)

    test_pos_predict = model.predict(X_test)
    # Binarize the sigmoid outputs at a threshold of 0.5
    test_predict_onehot = (test_pos_predict >= 0.5).astype(int)

    abbr_list = params["abbr_list"]

    today = date.today()
    # ROC & AUC, one subplot per class
    plt.figure(figsize=(24, 20))
    for i in range(len(abbr_list)):
        plt.subplot(5, 6, i + 1)
        plot_roc(abbr_list[i], y_test[:, i], test_pos_predict[:, i], 'blue')

    plt.tight_layout()
    plt.savefig('multilabel_roc_' + today.strftime("%m%d") + '.png')

    # Confusion matrix, one subplot per class
    conf_matrix = []
    for i in range(len(abbr_list)):
        conf_matrix.append(confusion_matrix(y_test[:, i], test_predict_onehot[:, i]))
    plt.figure(figsize=(42, 35))
    for i in range(len(abbr_list)):
        plt.subplot(5, 6, i + 1)
        plot_confusion_matrix(abbr_list[i], conf_matrix[i])

    plt.tight_layout()
    plt.savefig('multilabel_conf_' + today.strftime("%m%d") + '.png')
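
# Usage sketch: evaluate a saved model on a held-out split, e.g. the
# validation part of val_split from ecg_preprocessing:
#
#   model = model_load('multilabel_model_1.h5')
#   _, _, val_id, val_label = val_split(params["disease_num"], params)
#   model_eval(model, val_id, val_label, params)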

def plot_roc(name, labels, predict_prob, cur_clr):
    fp_rate, tp_rate, thresholds = roc_curve(labels, predict_prob)
    roc_auc = auc(fp_rate, tp_rate)
    plt.title('ROC')
    plt.plot(fp_rate, tp_rate, cur_clr, label=name + "'s AUC = %0.4f" % roc_auc)
    plt.legend(loc='lower right')
    plt.plot([0, 1], [0, 1], 'r--')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.ylabel('TPR')
    plt.xlabel('FPR')


def plot_confusion_matrix(name, cm, title='', cmap='Blues'):
    labels = ['Non-' + name, name]
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=25)
    plt.colorbar()
    xlocations = np.array(range(len(labels)))
    plt.xticks(xlocations, labels, rotation=30, fontsize=25)
    plt.yticks(xlocations, labels, rotation=30, fontsize=25)
    plt.ylabel('Committee consensus label', fontsize=25)
    plt.xlabel('Model predicted label', fontsize=25)

    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    ind_array = np.arange(len(labels))
    x, y = np.meshgrid(ind_array, ind_array)

    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm_normalized[y_val][x_val]
        plt.text(x_val, y_val, "%0.3f" % (c,), color='black', fontsize=25, va='center', ha='center')
--------------------------------------------------------------------------------