├── Step1_EDA ├── .DS_Store ├── PRG007_train_test_split.py ├── PRG003_mimic_demog.py ├── PRG002_ICDcode.py ├── PRG006_tfdataset.py ├── PRG005_data_selection.py ├── PRG001_eda.py └── PRG004_mimic_race.py ├── README.md ├── LICENSE ├── Step3_DownstreamTask └── PRG021_downstream_task.py └── Step2_VEbuilding ├── PRG012_vector_embedding.py └── PRG011_VAE.py /Step1_EDA/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MIT-LCP/mimic-iv-ecg-ve/main/Step1_EDA/.DS_Store -------------------------------------------------------------------------------- /Step1_EDA/PRG007_train_test_split.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import shutil 4 | 5 | 6 | source_dir = "dataset/tfdata/tfrevised" 7 | test_dir = "dataset/tfdata/tftest" 8 | train_dir = "dataset/tfdata/tftrain" 9 | 10 | os.makedirs(test_dir, exist_ok=True) 11 | os.makedirs(train_dir, exist_ok=True) 12 | 13 | 14 | def copy_files(): 15 | files = tf.io.gfile.listdir(source_dir) 16 | 17 | for file in files: 18 | source_path = os.path.join(source_dir, file) 19 | 20 | if any(f"p{str(i)}" in file for i in range(1800, 2000)): 21 | target_path = os.path.join(test_dir, file) 22 | else: 23 | target_path = os.path.join(train_dir, file) 24 | 25 | tf.io.gfile.copy(source_path, target_path, overwrite=True) 26 | 27 | copy_files() 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mimic-iv-ecg-ve 2 | Vector Embedding Pipeline for MIMIC-IV-ECG 3 | 4 | This repository contains the Python files to create vector embeddings from 12-lead ECG data and to evaluate the effectiveness of vector embedding with downstream tasks. 5 | 6 | ### Step 1 7 | Data handling and data curation for VE model training and evaluation with downstream tasks. 
8 | 9 | 001: EDA for MIMIC-IV-ECG 10 | 002: Diagnosis extraction from MIMIC-IV using ICD 11 | 003: Extraction of demographic information from MIMI-IV-hosp 12 | 004: Data clustering for race 13 | 005: Data selection for VE model training 14 | 006: Creating TF dataset for faster training and evaluation 15 | 007: Splitting data for training and evaluation 16 | 17 | ### Step 2 18 | 011: VE model creation using VAE approach 19 | 012: VE creation using trained VE model 20 | 21 | ### Step 3 22 | 021: VE Evaluation with Downstream Tasks 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 MIT Laboratory for Computational Physiology 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Step1_EDA/PRG003_mimic_demog.py: -------------------------------------------------------------------------------- 1 | # load library 2 | import numpy as np 3 | import pandas as pd 4 | import wfdb 5 | import joblib 6 | import glob 7 | import tqdm 8 | from IPython.display import display 9 | from sklearn.model_selection import train_test_split 10 | 11 | # load demog data 12 | patients = pd.read_csv('data/mimic-iv-3.1/hosp/patients.csv.gz') 13 | patients['sex'] = patients['gender'].replace({"M":0,"F":1}) 14 | 15 | admission = pd.read_csv('data/mimic-iv-3.1/hosp/admissions.csv.gz') 16 | patients = patients.merge(admission,how='inner',on='subject_id') 17 | patients['age'] = patients['anchor_age'] + pd.to_datetime(patients['admittime']).dt.year - patients['anchor_year'] 18 | 19 | 20 | # load icd data 21 | dx_cnt = joblib.load("dataset/df/dx_cnt") 22 | 23 | # load ecg meta data 24 | ecg_data_df = joblib.load("dataset/df/mimic_iv_ecg_data_df") 25 | ecg_data_df["subject_id"] = ecg_data_df["comments"].str[0].str[14:].astype(int) 26 | 27 | # df merge 28 | merge_df = pd.merge(patients,dx_cnt,how='inner',on=['subject_id','hadm_id']) 29 | merge_df = pd.merge(merge_df,ecg_data_df,how='inner',on=['subject_id']) 30 | 31 | merge_df["admittime"] = pd.to_datetime(merge_df["admittime"]) 32 | merge_df["dischtime"] = pd.to_datetime(merge_df["dischtime"]) 33 | merge_df["ecgtime"] = pd.to_datetime(merge_df["base_date"].astype(str) + ' ' + merge_df["base_time"].astype(str)) 34 | 35 | # The ECG recorded in ER is also taken into account, the window sets 1 day before admission. 
36 | merge_df_sel = merge_df.copy()[(merge_df["admittime"] - pd.Timedelta(days=1) < merge_df["ecgtime"]) & 37 | (merge_df["ecgtime"] < merge_df["dischtime"])] 38 | merge_df_sel2 = merge_df_sel.sort_values("dischtime").drop_duplicates(subset=['ecgtime'],keep="last") 39 | 40 | joblib.dump(merge_df_sel2,"dataset/df/merge_df_sel2") 41 | # 481568 -------------------------------------------------------------------------------- /Step1_EDA/PRG002_ICDcode.py: -------------------------------------------------------------------------------- 1 | # load library 2 | import numpy as np 3 | import pandas as pd 4 | import wfdb 5 | import joblib 6 | import glob 7 | import tqdm 8 | from IPython.display import display 9 | from sklearn.model_selection import train_test_split 10 | 11 | icd_df = pd.read_csv('data/mimic-iv-3.1/hosp/d_icd_diagnoses.csv.gz') 12 | 13 | 14 | # https://www.cms.gov/icd10m/version36-fullcode-cms/fullcode_cms/P0467.html 15 | # https://en.wikipedia.org/wiki/List_of_ICD-9_codes_390%E2%80%93459:_diseases_of_the_circulatory_system#:~:text=428%20Heart%20failure&text=428.4%20Heart%20failure%2C%20combined%2C%20unspec. 16 | # https://www.bcbsnm.com/provider/education-reference/education/news/2021-archive/02-15-2021-atrial-fibrillation 17 | # https://en.wikipedia.org/wiki/List_of_ICD-9_codes_390%E2%80%93459:_diseases_of_the_circulatory_system#:~:text=428%20Heart%20failure&text=428.4%20Heart%20failure%2C%20combined%2C%20unspec. 
18 | 19 | icd_df.loc[icd_df.icd_code.str.startswith('I50'),'icd_dx']= 'hf' 20 | icd_df.loc[icd_df.icd_code.str.startswith('428'),'icd_dx']= 'hf' 21 | icd_df.loc[icd_df.icd_code.str.startswith('I48'),'icd_dx']= 'af' 22 | icd_df.loc[icd_df.icd_code.str.startswith('4273'),'icd_dx']= 'af' 23 | 24 | icd_df_sel = icd_df.copy()[~icd_df.icd_dx.isna()] 25 | 26 | 27 | mimic_icd_df = pd.read_csv('data/mimic-iv-3.1/hosp/diagnoses_icd.csv.gz') 28 | 29 | dx_list = icd_df_sel.drop_duplicates(subset='icd_dx').icd_dx.to_list() 30 | 31 | def Category_cnt(df): 32 | firstLoop = True 33 | for i in dx_list: 34 | df2=df.copy() 35 | recept_list = icd_df_sel[icd_df_sel['icd_dx']== i ]['icd_code'].to_list() 36 | df2[i]=df2['icd_code'].isin(recept_list).astype(int) 37 | df2=df2[['subject_id','hadm_id',i]].groupby(['subject_id','hadm_id']).agg({i:'max'}) 38 | if firstLoop: 39 | df3=df2.copy() 40 | firstLoop = False 41 | else: 42 | df3 = pd.merge(df3,df2,how='left',on=['subject_id','hadm_id']) 43 | return df3.reset_index(drop=False) 44 | 45 | dx_cnt = Category_cnt(mimic_icd_df) 46 | 47 | joblib.dump(dx_cnt,"dataset/df/dx_cnt") 48 | -------------------------------------------------------------------------------- /Step1_EDA/PRG006_tfdataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import joblib, tqdm 3 | import pandas as pd 4 | import tensorflow as tf 5 | import wfdb 6 | from multiprocessing import Pool 7 | import tqdm 8 | 9 | 10 | 11 | 12 | def CreateTensorflowReadFile(df, out_file): 13 | with tf.io.TFRecordWriter(out_file) as writer: 14 | csv_paths = df['path_wo_ext'].to_list() 15 | x_ECG = np.empty((len(csv_paths), 5000, 12, 1)) 16 | 17 | for i, (item_path) in enumerate(csv_paths): 18 | signals, _ = wfdb.rdsamp(item_path) 19 | x_ECG[i, :] = signals[:, :, np.newaxis] 20 | 21 | y_sex = df['sex'].astype(float).values # It's better to unify the type as float. 
22 | y_age = df['age'].astype(float).values # It's better to unify the type as float. 23 | y_hf = df['hf'].astype(float).values # It's better to unify the type as float. 24 | y_af = df['af'].astype(float).values # It's better to unify the type as float. 25 | 26 | # Convert data to binary file 27 | example = tf.train.Example(features=tf.train.Features(feature={ 28 | "x_ECG": tf.train.Feature(bytes_list=tf.train.BytesList(value=[x_ECG.tobytes()])), 29 | "y_sex": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y_sex.tobytes()])), 30 | "y_age": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y_age.tobytes()])), 31 | "y_hf": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y_hf.tobytes()])), 32 | "y_af": tf.train.Feature(bytes_list=tf.train.BytesList(value=[y_af.tobytes()])), 33 | })) 34 | 35 | # Write 36 | writer.write(example.SerializeToString()) 37 | 38 | 39 | # Load pandas DataFrame 40 | merge_df_sel4 = joblib.load("dataset/df/merge_df_sel4") 41 | 42 | 43 | 44 | # # To avoid running out of memory, create TensorFlow data for each record. 
45 | 46 | def process_row(i): 47 | # Avoid unnecessary dataframe copy 48 | select_df = merge_df_sel4.iloc[[i]] 49 | CreateTensorflowReadFile(select_df, "dataset/tfdata/tfrevised/ecg_tf_" + merge_df_sel4['ecg_pID'][i] +"_" + merge_df_sel4['ecg_sID'][i] ) 50 | 51 | if __name__ == "__main__": 52 | # Increasing chunk size for better parallelization 53 | chunk_size = 10 54 | with Pool(processes=8) as pool: 55 | list(tqdm.tqdm(pool.imap(process_row, range(len(merge_df_sel4)), chunksize=chunk_size), total=len(merge_df_sel4))) 56 | -------------------------------------------------------------------------------- /Step1_EDA/PRG005_data_selection.py: -------------------------------------------------------------------------------- 1 | # load library 2 | import numpy as np 3 | import pandas as pd 4 | import wfdb 5 | import joblib 6 | import glob 7 | import tqdm 8 | from IPython.display import display 9 | from sklearn.model_selection import train_test_split 10 | 11 | merge_df_sel3 = joblib.load("dataset/df/merge_df_sel3") 12 | merge_df_sel4 = merge_df_sel3.copy()[~merge_df_sel3.del_ecg] 13 | 14 | 15 | merge_df_sel4.loc[merge_df_sel4.subject_id<18000000,'total_set'] = 'train' 16 | merge_df_sel4.loc[merge_df_sel4.subject_id>=18000000,'total_set'] = 'test' 17 | 18 | merge_df_sel4.loc[(merge_df_sel4['total_set']=='test'),'ds50K'] = 'test' 19 | merge_df_sel4.loc[(merge_df_sel4['total_set']=='test'),'ds_m_50K'] = 'test' 20 | merge_df_sel4.loc[(merge_df_sel4['total_set']=='test'),'ds_w_50K'] = 'test' 21 | 22 | merge_df_sel4.loc[(merge_df_sel4[ merge_df_sel4['total_set'] != 'test'].sample(50000,random_state=0).index), 'ds50K']= 'train' 23 | merge_df_sel4.loc[(merge_df_sel4[ (merge_df_sel4['total_set'] != 'test') & (merge_df_sel4['gender']=="M")].sample(50000,random_state=0).index), 'ds_m_50K']= 'train' 24 | merge_df_sel4.loc[(merge_df_sel4[ (merge_df_sel4['total_set'] != 'test') & (merge_df_sel4['race2']=="White")].sample(50000,random_state=0).index), 'ds_w_50K']= 'train' 25 | 26 | 
merge_df_sel4['ecg_id'] = "ecg_" + merge_df_sel4['total_set'].replace( 27 | {'train': '1', 'test': '2'}).apply(lambda x: x if x in ["1", "2"] else '0') + \ 28 | merge_df_sel4['ds50K'].replace( 29 | {'train': '1', 'test': '2'}).apply(lambda x: x if x in ["1", "2"] else '0') + \ 30 | merge_df_sel4['ds_m_50K'].replace( 31 | {'train': '1', 'test': '2'}).apply(lambda x: x if x in ["1", "2"] else '0') + \ 32 | merge_df_sel4['ds_w_50K'].replace( 33 | {'train': '1', 'test': '2'}).apply(lambda x: x if x in ["1", "2"] else '0') + \ 34 | "_" + merge_df_sel4.index.astype(str).str.zfill(8) 35 | 36 | merge_df_sel4['ecg_sID']= merge_df_sel4['path'].apply(lambda x: os.path.basename(os.path.dirname(x))) 37 | merge_df_sel4['ecg_pID']= merge_df_sel4['path'].apply(lambda x: os.path.basename(os.path.dirname(os.path.dirname(x)))) 38 | 39 | merge_df_sel4 = merge_df_sel4.reset_index(drop=True) 40 | 41 | joblib.dump(merge_df_sel4,"dataset/df/merge_df_sel4") 42 | 43 | display(merge_df_sel4['total_set'].value_counts()) 44 | display(merge_df_sel4['ds50K'].value_counts()) 45 | display(merge_df_sel4['ds_m_50K'].value_counts()) 46 | display(merge_df_sel4['ds_w_50K'].value_counts()) 47 | 48 | 49 | # train 328895 50 | # test 143686 51 | # Name: total_set, dtype: int64 52 | 53 | 54 | # test 143686 55 | # train 50000 56 | # Name: ds50K, dtype: int64 -------------------------------------------------------------------------------- /Step3_DownstreamTask/PRG021_downstream_task.py: -------------------------------------------------------------------------------- 1 | # File and directory handling 2 | import os, random, glob, joblib 3 | # Data manipulation 4 | import pandas as pd 5 | import numpy as np 6 | # Scikit-learn for machine learning and metrics 7 | from sklearn.model_selection import GroupKFold, cross_validate 8 | from sklearn.linear_model import LogisticRegression 9 | from sklearn.preprocessing import StandardScaler 10 | from sklearn.pipeline import Pipeline 11 | from sklearn.metrics import 
make_scorer, recall_score, precision_score, fbeta_score, confusion_matrix 12 | 13 | 14 | merge_df = joblib.load("dataset/df/annotation_df.jb") 15 | merge_df = merge_df.drop('filename',axis=1) 16 | merge_df = merge_df.rename(columns={'filename':'ecg_path', "af":"af_ai"}) 17 | 18 | merge_df_sel4 = joblib.load("dataset/df/merge_df_sel4") 19 | ve = pd.read_csv("embeddings/embedding_128dim.csv") 20 | 21 | merge_df_sel4['filename'] = "dataset/tfdata/tftest/ecg_tf_" + merge_df_sel4['ecg_pID'] +"_" + merge_df_sel4['ecg_sID'] 22 | df_with_ve = pd.merge(merge_df_sel4,ve,how="inner", on=["filename"]) 23 | 24 | 25 | df_with_ve_with_af = pd.merge(merge_df, df_with_ve,how="inner", on=["ecg_sID","ecg_pID"]) 26 | 27 | df_with_ve_with_af_sel = df_with_ve_with_af.copy()[df_with_ve_with_af['annot_Dx'].isin(["AF","nonAF"])] 28 | 29 | df_with_ve_with_af_sel["af_dr"] = df_with_ve_with_af_sel['annot_Dx'].replace({"AF":1,"nonAF":0}) 30 | 31 | 32 | latent_dim = 128 33 | VE_vars = ['VE_' + str(item).zfill(2) for item in range(latent_dim)] 34 | 35 | # Preparation of five-part cross-validation per patient 36 | gkf = GroupKFold(n_splits=5) 37 | 38 | # Create pipeline (preprocessing + model) 39 | pipeline = Pipeline([ 40 | ('scaler', StandardScaler()), 41 | ('lasso_logreg', LogisticRegression(penalty='l2', solver='saga', max_iter=10000, C=10)) 42 | ]) 43 | # Create custom metrics 44 | def specificity_score(y_true, y_pred): 45 | tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel() 46 | return tn / (tn + fp) 47 | 48 | scoring = { 49 | 'AUC': 'roc_auc', 50 | 'Sensitivity': make_scorer(recall_score), 51 | 'Specificity': make_scorer(specificity_score), 52 | 'Precision': make_scorer(precision_score), 53 | 'F2': make_scorer(fbeta_score, beta=2) 54 | } 55 | 56 | def mean_sd(series,digit): 57 | return str(round(np.mean(series),digit)).zfill(digit)+' ± '+ str(round( 58 | np.std(series),digit)).zfill(digit) 59 | 60 | def five_cv_eval(df, outcome): 61 | cv_results = cross_validate(pipeline, 
df[VE_vars], 62 | df[outcome], cv=gkf.split(df[VE_vars], 63 | df[outcome], groups=df.subject_id), scoring=scoring, 64 | n_jobs=-1 ) 65 | print("outcome: ", outcome) 66 | print("AUC: ", mean_sd(cv_results['test_AUC'],2)) 67 | print("Sensitivity: ", mean_sd(cv_results['test_Sensitivity'],2)) 68 | print("Specificity: ", mean_sd(cv_results['test_Specificity'],2)) 69 | print("Precision: ", mean_sd(cv_results['test_Precision'],2)) 70 | print("F2 score: ", mean_sd(cv_results['test_F2'],2)) 71 | 72 | 73 | five_cv_eval(df_with_ve_with_af_sel,"af_ai") 74 | five_cv_eval(df_with_ve_with_af_sel,'af_dr') 75 | 76 | five_cv_eval(df_with_ve_with_af,"hf") 77 | five_cv_eval(df_with_ve_with_af,"sex") 78 | five_cv_eval(df_with_ve_with_af,"af_ai") 79 | 80 | 81 | 82 | # outcome: af_ai 83 | # AUC: 0.84 ± 0.01 84 | # Sensitivity: 0.52 ± 0.02 85 | # Specificity: 0.94 ± 0.0 86 | # Precision: 0.82 ± 0.02 87 | # F2 score: 0.56 ± 0.02 88 | # outcome: af_dr 89 | # AUC: 0.83 ± 0.01 90 | # Sensitivity: 0.46 ± 0.02 91 | # Specificity: 0.94 ± 0.01 92 | # Precision: 0.79 ± 0.02 93 | # F2 score: 0.5 ± 0.02 94 | 95 | 96 | # outcome: hf 97 | # AUC: 0.72 ± 0.0 98 | # Sensitivity: 0.22 ± 0.01 99 | # Specificity: 0.95 ± 0.0 100 | # Precision: 0.63 ± 0.03 101 | # F2 score: 0.25 ± 0.01 102 | # outcome: sex 103 | # AUC: 0.63 ± 0.01 104 | # Sensitivity: 0.57 ± 0.01 105 | # Specificity: 0.63 ± 0.01 106 | # Precision: 0.58 ± 0.01 107 | # F2 score: 0.57 ± 0.01 108 | # outcome: af_ai 109 | # AUC: 0.76 ± 0.01 110 | # Sensitivity: 0.13 ± 0.01 111 | # Specificity: 0.99 ± 0.0 112 | # Precision: 0.58 ± 0.04 113 | # F2 score: 0.15 ± 0.02 -------------------------------------------------------------------------------- /Step1_EDA/PRG001_eda.py: -------------------------------------------------------------------------------- 1 | # load library 2 | import numpy as np 3 | import pandas as pd 4 | import tensorflow as tf 5 | import wfdb 6 | import joblib 7 | import glob 8 | from joblib import Parallel, delayed 9 | import wfdb 
10 | from tqdm import tqdm 11 | from IPython.display import display 12 | gpus = tf.config.experimental.list_physical_devices('GPU') 13 | if gpus: 14 | try: 15 | for gpu in gpus: 16 | tf.config.experimental.set_memory_growth(gpu, True) 17 | print("Memory growth enabled") 18 | except RuntimeError as e: 19 | print(e) 20 | # listing the ECG paths 21 | ecg_path_list = glob.glob('data/mimic-iv-ecg-1.0/files/*/*/*/*.dat') 22 | 23 | print("ECG data count: ", len(ecg_path_list)) 24 | 25 | 26 | #### Check ECG data ##### 27 | # remove extention 28 | ecg_data_df = pd.DataFrame(ecg_path_list,columns=['path']) 29 | ecg_data_df['path_wo_ext'] = ecg_data_df['path'].str[:-4] 30 | 31 | 32 | signals, fields = wfdb.rdsamp(ecg_path_list[0][:-4]) 33 | np.std(signals,axis=0) 34 | np.mean(signals,axis=0) 35 | np.any(np.isnan(signals)) 36 | np.any(np.isinf(signals)) 37 | 38 | sig_name_list = [] 39 | units_list = [] 40 | fields_list = [] 41 | min_v_list = [] 42 | max_v_list = [] 43 | del_ecg_list = [] 44 | del_ecg_zero_list = [] 45 | del_ecg_nan_list = [] 46 | del_ecg_inf_list = [] 47 | for path in tqdm(ecg_data_df['path_wo_ext']): 48 | signals, fields = wfdb.rdsamp(path) 49 | sig_name_list.append(fields["sig_name"]) 50 | units_list.append(fields["units"]) 51 | # print(np.min(signals)) 52 | # print(np.max(signals)) 53 | min_v_list.append(np.min(signals)) 54 | max_v_list.append(np.max(signals)) 55 | tensor_signals = tf.convert_to_tensor(signals) 56 | if np.any(np.std(signals,axis=0)==0) or np.any(np.isnan(signals)) or np.any(np.isinf(signals)) or tf.reduce_any(tf.math.is_nan(tensor_signals)): 57 | del_ecg_list.append(path) 58 | 59 | if np.any(np.std(signals,axis=0)==0): 60 | del_ecg_zero_list.append(path) 61 | if np.any(np.isnan(signals)): 62 | del_ecg_nan_list.append(path) 63 | if np.any(np.isinf(signals)): 64 | del_ecg_inf_list.append(path) 65 | if tf.reduce_any(tf.math.is_nan(tensor_signals)): 66 | del_ecg_nan_list.append(path) 67 | 68 | del fields["sig_name"] 69 | del fields["units"] 70 
| fields_list.append(fields) 71 | 72 | 73 | print(len(del_ecg_list)) # 12358 74 | print(len(del_ecg_zero_list)) # 1951 75 | print(len(del_ecg_nan_list)) # 10554 -> 21108 76 | print(len(del_ecg_inf_list)) # 0 77 | 78 | pd.DataFrame(min_v_list).hist() 79 | 80 | fields_df = pd.DataFrame(fields_list) 81 | sig_name_df = pd.DataFrame(sig_name_list) 82 | units_df = pd.DataFrame(units_list) 83 | 84 | display('variation of sampling frequency : ',fields_df.fs.value_counts()) 85 | display('variation of signal_length : ',fields_df.sig_len.value_counts()) 86 | display('variation of n_signal : ',fields_df.n_sig.value_counts()) 87 | display('variation of sig_name : ',sig_name_df.value_counts()) 88 | display('variation of units : ',units_df.value_counts()) 89 | 90 | ecg_data_df = pd.concat([ecg_data_df,fields_df],axis=1) 91 | 92 | 93 | ecg_data_sel_df = ecg_data_df.copy()[~ecg_data_df["path_wo_ext"].isin(del_ecg_list)] 94 | 95 | ecg_data_df['del_ecg'] = ecg_data_df["path_wo_ext"].isin(del_ecg_list) 96 | ecg_data_df['del_ecg_zero'] = ecg_data_df["path_wo_ext"].isin(del_ecg_zero_list) 97 | ecg_data_df['del_ecg_nan_ecg'] = ecg_data_df["path_wo_ext"].isin(del_ecg_nan_list) 98 | 99 | joblib.dump(ecg_data_df,"dataset/df/mimic_iv_ecg_data_df") 100 | 101 | 102 | """ 103 | 'variation of sampling frequency : 'fs 104 | 500 800035 105 | Name: count, dtype: int64'variation of signal_length : 'sig_len 106 | 5000 800035 107 | Name: count, dtype: int64'variation of n_signal : 'n_sig 108 | 12 800035 109 | Name: count, dtype: int64'variation of sig_name : '0 1 2 3 4 5 6 7 8 9 10 11 110 | I II III aVR aVF aVL V1 V2 V3 V4 V5 V6 800035 111 | Name: count, dtype: int64'variation of units : '0 1 2 3 4 5 6 7 8 9 10 11 112 | mV mV mV mV mV mV mV mV mV mV mV mV 800035 113 | Name: count, dtype: int64 114 | """ -------------------------------------------------------------------------------- /Step1_EDA/PRG004_mimic_race.py: -------------------------------------------------------------------------------- 1 | # 
load library 2 | import numpy as np 3 | import pandas as pd 4 | import wfdb 5 | import joblib 6 | import glob 7 | import tqdm 8 | from IPython.display import display 9 | from sklearn.model_selection import train_test_split 10 | 11 | merge_df_sel2 = joblib.load("dataset/df/merge_df_sel2") 12 | 13 | 14 | merge_df_sel2.race.value_counts() 15 | 16 | # https://github.com/rmovva/granular-race-disparities_MLHC23/blob/main/analysis/race_categories.py 17 | 18 | # We exclude 'OTHER' from the analysis in the paper. 19 | coarse_races = [ 20 | 'WHITE', 21 | 'BLACK/AFRICAN AMERICAN', 22 | 'HISPANIC OR LATINO', 23 | 'ASIAN', 24 | # 'OTHER', 25 | ] 26 | 27 | coarse_abbrev = { 28 | 'WHITE': 'White', 29 | 'BLACK/AFRICAN AMERICAN': 'Black', 30 | 'HISPANIC OR LATINO': 'Hispanic/Latino', 31 | 'ASIAN': 'Asian', 32 | 'AVERAGE': 'Average', # this is for cases where we want to average over coarse groups 33 | } 34 | 35 | granular_to_coarse = { 36 | 'HISPANIC OR LATINO': 'HISPANIC OR LATINO', 37 | 'HISPANIC/LATINO - PUERTO RICAN': 'HISPANIC OR LATINO', 38 | 'HISPANIC/LATINO - DOMINICAN': 'HISPANIC OR LATINO', 39 | 'HISPANIC/LATINO - GUATEMALAN': 'HISPANIC OR LATINO', 40 | 'HISPANIC/LATINO - SALVADORAN': 'HISPANIC OR LATINO', 41 | 'HISPANIC/LATINO - MEXICAN': 'HISPANIC OR LATINO', 42 | 'HISPANIC/LATINO - COLUMBIAN': 'HISPANIC OR LATINO', 43 | 'HISPANIC/LATINO - HONDURAN': 'HISPANIC OR LATINO', 44 | 'HISPANIC/LATINO - CUBAN': 'HISPANIC OR LATINO', 45 | 'HISPANIC/LATINO - CENTRAL AMERICAN': 'HISPANIC OR LATINO', 46 | 'SOUTH AMERICAN': 'HISPANIC OR LATINO', 47 | 48 | 'ASIAN': 'ASIAN', 49 | 'ASIAN - CHINESE': 'ASIAN', 50 | 'ASIAN - SOUTH EAST ASIAN': 'ASIAN', 51 | 'ASIAN - ASIAN INDIAN': 'ASIAN', 52 | 'ASIAN - KOREAN': 'ASIAN', 53 | 54 | 'WHITE': 'WHITE', 55 | 'WHITE - OTHER EUROPEAN': 'WHITE', 56 | 'WHITE - RUSSIAN': 'WHITE', 57 | 'WHITE - EASTERN EUROPEAN': 'WHITE', 58 | 'WHITE - BRAZILIAN': 'WHITE', 59 | 'PORTUGUESE': 'WHITE', 60 | 61 | 'BLACK/AFRICAN AMERICAN': 'BLACK/AFRICAN AMERICAN', 62 | 
'BLACK/CAPE VERDEAN': 'BLACK/AFRICAN AMERICAN', 63 | 'BLACK/CARIBBEAN ISLAND': 'BLACK/AFRICAN AMERICAN', 64 | 'BLACK/AFRICAN': 'BLACK/AFRICAN AMERICAN', 65 | 66 | # 'AMERICAN INDIAN/ALASKA NATIVE': 'OTHER', 67 | # 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER': 'OTHER', 68 | # 'MULTIPLE RACE/ETHNICITY': 'OTHER', 69 | # 'UNKNOWN': 'OTHER', 70 | # 'PATIENT DECLINED TO ANSWER': 'OTHER', 71 | # 'UNABLE TO OBTAIN': 'OTHER', 72 | } 73 | 74 | granular_abbrev = { 75 | 'HISPANIC OR LATINO': 'HISPANIC OR LATINO*', 76 | 'HISPANIC/LATINO - PUERTO RICAN': 'PUERTO RICAN', 77 | 'HISPANIC/LATINO - DOMINICAN': 'DOMINICAN', 78 | 'HISPANIC/LATINO - GUATEMALAN': 'GUATEMALAN', 79 | 'HISPANIC/LATINO - SALVADORAN': 'SALVADORAN', 80 | 'HISPANIC/LATINO - MEXICAN': 'MEXICAN', 81 | 'HISPANIC/LATINO - COLUMBIAN': 'COLOMBIAN', 82 | 'HISPANIC/LATINO - HONDURAN': 'HONDURAN', 83 | 'HISPANIC/LATINO - CUBAN': 'CUBAN', 84 | 'HISPANIC/LATINO - CENTRAL AMERICAN': 'CENTRAL AMERICAN', 85 | 'SOUTH AMERICAN': 'SOUTH AMERICAN', 86 | 87 | 'ASIAN': 'ASIAN*', 88 | 'ASIAN - CHINESE': 'CHINESE', 89 | 'ASIAN - SOUTH EAST ASIAN': 'SE ASIAN', 90 | 'ASIAN - ASIAN INDIAN': 'INDIAN', 91 | 'ASIAN - KOREAN': 'KOREAN', 92 | 93 | 'WHITE': 'WHITE*', 94 | 'WHITE - OTHER EUROPEAN': 'OTHER EUR', 95 | 'WHITE - RUSSIAN': 'RUSSIAN', 96 | 'WHITE - EASTERN EUROPEAN': 'EASTERN EUR', 97 | 'WHITE - BRAZILIAN': 'BRAZILIAN', 98 | 'PORTUGUESE': 'PORTUGUESE', 99 | 100 | 'BLACK/AFRICAN AMERICAN': 'BLACK*', 101 | 'BLACK/CAPE VERDEAN': 'CAPE VERDEAN', 102 | 'BLACK/CARIBBEAN ISLAND': 'CARIBBEAN', 103 | 'BLACK/AFRICAN': 'AFRICAN', 104 | 105 | # 'AMERICAN INDIAN/ALASKA NATIVE': 'AMERICAN INDIAN', 106 | # 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER': 'PACIFIC ISLANDER', 107 | # 'MULTIPLE RACE/ETHNICITY': 'MULTIRACIAL', 108 | # 'UNKNOWN': 'UNKNOWN', 109 | # 'PATIENT DECLINED TO ANSWER': 'DECLINE TO ANSWER', 110 | # 'UNABLE TO OBTAIN': 'UNABLE TO OBTAIN', 111 | } 112 | 113 | coarse_to_granular = { 114 | 'WHITE': [ 115 | 'WHITE', 116 | 'WHITE - OTHER 
EUROPEAN', 117 | 'WHITE - RUSSIAN', 118 | 'WHITE - EASTERN EUROPEAN', 119 | 'WHITE - BRAZILIAN', 120 | 'PORTUGUESE', 121 | ], 122 | 'BLACK/AFRICAN AMERICAN': [ 123 | 'BLACK/AFRICAN AMERICAN', 124 | 'BLACK/CAPE VERDEAN', 125 | 'BLACK/CARIBBEAN ISLAND', 126 | 'BLACK/AFRICAN', 127 | ], 128 | 'HISPANIC OR LATINO': [ 129 | 'HISPANIC OR LATINO', 130 | 'HISPANIC/LATINO - PUERTO RICAN', 131 | 'HISPANIC/LATINO - DOMINICAN', 132 | 'HISPANIC/LATINO - GUATEMALAN', 133 | 'HISPANIC/LATINO - SALVADORAN', 134 | 'HISPANIC/LATINO - MEXICAN', 135 | 'HISPANIC/LATINO - COLUMBIAN', 136 | 'HISPANIC/LATINO - HONDURAN', 137 | 'HISPANIC/LATINO - CUBAN', 138 | 'HISPANIC/LATINO - CENTRAL AMERICAN', 139 | 'SOUTH AMERICAN' 140 | ], 141 | 'ASIAN': [ 142 | 'ASIAN', 143 | 'ASIAN - CHINESE', 144 | 'ASIAN - SOUTH EAST ASIAN', 145 | 'ASIAN - ASIAN INDIAN', 146 | 'ASIAN - KOREAN', 147 | ], 148 | # 'OTHER': [ 149 | # 'AMERICAN INDIAN/ALASKA NATIVE', 150 | # 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER', 151 | # 'MULTIPLE RACE/ETHNICITY', 152 | # ] 153 | } 154 | 155 | 156 | merge_df_sel2['race2'] = merge_df_sel2.race.replace(granular_to_coarse).replace(coarse_abbrev) 157 | 158 | 159 | joblib.dump(merge_df_sel2,"dataset/df/merge_df_sel3") 160 | 161 | # merge_df_sel2['race2'].value_counts() 162 | # race2 163 | # White 330563 164 | # Black 74700 165 | # Hispanic/Latino 25520 166 | # UNKNOWN 16707 167 | # OTHER 14217 168 | # Asian 13627 169 | # UNABLE TO OBTAIN 2712 170 | # PATIENT DECLINED TO ANSWER 1685 171 | # AMERICAN INDIAN/ALASKA NATIVE 964 172 | # MULTIPLE RACE/ETHNICITY 503 173 | # NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER 370 -------------------------------------------------------------------------------- /Step2_VEbuilding/PRG012_vector_embedding.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.saving import register_keras_serializable 3 | from tensorflow.keras import layers, Model 4 | from tensorflow.keras 
import regularizers 5 | from tqdm import tqdm 6 | import pandas as pd 7 | import joblib 8 | 9 | def load_tfrecord_with_filename(filename): 10 | dataset = tf.data.TFRecordDataset(filename) 11 | dataset = dataset.map(lambda record: (record, filename)) 12 | return dataset 13 | 14 | 15 | def parse_with_filename(record, filename): 16 | features = tf.io.parse_single_example( 17 | record, 18 | features={"x_ECG": tf.io.FixedLenFeature([], tf.string), 19 | "y_sex": tf.io.FixedLenFeature([], tf.string), 20 | "y_age": tf.io.FixedLenFeature([], tf.string), 21 | "y_hf": tf.io.FixedLenFeature([], tf.string), 22 | "y_af": tf.io.FixedLenFeature([], tf.string), 23 | }) 24 | 25 | x_ECG_ = tf.io.decode_raw(features["x_ECG"], tf.float64) 26 | x_ECG = tf.reshape(x_ECG_, tf.stack([5000, 12, 1])) 27 | return x_ECG, filename 28 | def preprocess_with_filename(x_ECG, filename): 29 | start = (5000 - 4096) // 2 30 | x_ECG = x_ECG[start:start+4096, :, :] 31 | return x_ECG, filename 32 | 33 | 34 | @register_keras_serializable() 35 | class Sampling(layers.Layer): 36 | """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.""" 37 | def __init__(self, **kwargs): 38 | super(Sampling, self).__init__(**kwargs) # 必要な初期化処理を追加 39 | def call(self, inputs): 40 | z_mean, z_log_var = inputs 41 | batch = tf.shape(z_mean)[0] 42 | dim = tf.shape(z_mean)[1] 43 | epsilon = tf.keras.backend.random_normal(shape=(batch, dim)) 44 | return z_mean + tf.exp(0.5 * z_log_var) * epsilon 45 | 46 | @register_keras_serializable() 47 | class ReconstructionLossLayer(tf.keras.layers.Layer): 48 | def __init__(self, **kwargs): 49 | super(ReconstructionLossLayer, self).__init__(**kwargs) 50 | def call(self, inputs): 51 | x_true, x_pred = inputs 52 | loss = tf.reduce_mean(tf.square(x_true - x_pred)) 53 | self.add_loss(loss) 54 | self.reconstruction_loss = loss 55 | return inputs 56 | 57 | @register_keras_serializable() 58 | class KLDivergenceLayer(tf.keras.layers.Layer): 59 | def __init__(self, **kwargs): 60 | 
super(KLDivergenceLayer, self).__init__(**kwargs) 61 | def call(self, inputs): 62 | z_mean, z_log_var = inputs 63 | loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) 64 | self.add_loss(loss) 65 | self.kl_loss = loss 66 | return inputs 67 | 68 | 69 | def build_vae(filter_num=32, latent_dim=128): 70 | inputs = layers.Input(shape=(4096, 12, 1)) 71 | 72 | # Encoder 73 | encoding_1 = layers.Conv2D(filter_num, (5, 12), padding='same', 74 | kernel_regularizer=regularizers.l2(0.001))(inputs) 75 | encoding_1 = layers.BatchNormalization()(encoding_1) 76 | encoding_1 = layers.Activation('swish')(encoding_1) 77 | encoding_1 = layers.MaxPooling2D((4, 12), padding='same')(encoding_1) 78 | 79 | encoding_2 = layers.Conv2D(filter_num, (5, 1), padding='same', 80 | kernel_regularizer=regularizers.l2(0.001))(encoding_1) 81 | encoding_2 = layers.BatchNormalization()(encoding_2) 82 | encoding_2 = layers.Activation('swish')(encoding_2) 83 | encoding_2 = layers.MaxPooling2D((4, 1), padding='same')(encoding_2) 84 | 85 | encoding_3 = layers.Conv2D(filter_num, (5, 1), padding='same', 86 | kernel_regularizer=regularizers.l2(0.001))(encoding_2) 87 | encoding_3 = layers.BatchNormalization()(encoding_3) 88 | encoding_3 = layers.Activation('swish')(encoding_3) 89 | encoding_3 = layers.MaxPooling2D((4, 1), padding='same')(encoding_3) 90 | 91 | encoding_4 = layers.Conv2D(filter_num, (5, 1), padding='same', 92 | kernel_regularizer=regularizers.l2(0.001))(encoding_3) 93 | encoding_4 = layers.BatchNormalization()(encoding_4) 94 | encoding_4 = layers.Activation('swish')(encoding_4) 95 | encoding_4 = layers.MaxPooling2D((4, 1), padding='same')(encoding_4) 96 | 97 | # Latent space (mean and variance) 98 | encoding_f = layers.Flatten()(encoding_4) 99 | z_mean = layers.Dense(latent_dim, kernel_regularizer=tf.keras.regularizers.l2(0.001))(encoding_f) 100 | z_log_var = layers.Dense(latent_dim, kernel_regularizer=tf.keras.regularizers.l2(0.001))(encoding_f) 101 | 102 | # 
sampling layer 103 | z_mean, z_log_var = KLDivergenceLayer()([z_mean, z_log_var]) 104 | z = Sampling()([z_mean, z_log_var]) 105 | 106 | # Decoder 107 | decoding_1 = layers.Dense(encoding_4.shape[1] * encoding_4.shape[2] * encoding_4.shape[3])(z) 108 | decoding_1 = layers.Reshape(encoding_4.shape[1:])(decoding_1) 109 | 110 | decoding_2 = layers.Conv2DTranspose(filter_num, (5, 1), strides=(4, 1), padding='same', 111 | kernel_regularizer=regularizers.l2(0.001))(decoding_1) 112 | decoding_2 = layers.Activation('swish')(decoding_2) 113 | 114 | decoding_3 = layers.Conv2DTranspose(filter_num, (5, 1), strides=(4, 1), padding='same', 115 | kernel_regularizer=regularizers.l2(0.001))(decoding_2) 116 | decoding_3 = layers.Activation('swish')(decoding_3) 117 | 118 | # UpSampling to adjust encoding_3 size to decoding_3 119 | encoding_3_up = layers.UpSampling2D(size=(4, 1))(encoding_3) 120 | decoding_3 = layers.Add()([decoding_3, encoding_3_up]) 121 | 122 | decoding_4 = layers.Conv2DTranspose(filter_num, (5, 1), strides=(4, 1), padding='same', 123 | kernel_regularizer=regularizers.l2(0.001))(decoding_3) 124 | decoding_4 = layers.Activation('swish')(decoding_4) 125 | 126 | decoding_5 = layers.Conv2DTranspose(1, (4, 12), strides=(4, 12), padding='same', 127 | kernel_regularizer=regularizers.l2(0.001))(decoding_4) 128 | decoding_5 = layers.Activation('swish')(decoding_5) 129 | 130 | decoding = ReconstructionLossLayer()([inputs, decoding_5]) 131 | 132 | 133 | vae = Model(inputs, decoding) 134 | encoder = Model(inputs, [z_mean, z_log_var, z]) 135 | vae.compile(optimizer='adam', loss=None, jit_compile=False) 136 | 137 | # print(vae.summary()) 138 | 139 | return vae, encoder 140 | 141 | filter_num=32 142 | latent_dim=128 143 | batch_size = 256 144 | vae, encoder = build_vae(filter_num=filter_num,latent_dim=latent_dim) 145 | encoder.load_weights('model/encoder_full_weight.weights.h5') 146 | 147 | 148 | file_pattern = "dataset/tfdata/tftest/ecg_tf_*" 149 | file_list = 
tf.data.Dataset.list_files(file_pattern) 150 | 151 | 152 | 153 | dataset = file_list.interleave( 154 | load_tfrecord_with_filename, 155 | cycle_length=16, 156 | block_length=4, 157 | num_parallel_calls=tf.data.experimental.AUTOTUNE 158 | ) 159 | 160 | 161 | 162 | 163 | 164 | dataset = ( 165 | dataset.map(parse_with_filename, num_parallel_calls=tf.data.experimental.AUTOTUNE) 166 | .map(preprocess_with_filename, num_parallel_calls=tf.data.experimental.AUTOTUNE) 167 | .batch(batch_size) 168 | .prefetch(tf.data.experimental.AUTOTUNE) 169 | ) 170 | 171 | VE_vars = ['VE_' + str(item).zfill(2) for item in range(latent_dim)] 172 | 173 | 174 | output_csv = "embeddings/embedding_128dim.csv" 175 | 176 | if not os.path.exists(output_csv): 177 | pd.DataFrame(columns=VE_vars + ['filename']).to_csv(output_csv, index=False) 178 | 179 | for batch, filenames in tqdm(dataset): 180 | predicted_value = encoder.predict(batch) 181 | df = pd.DataFrame(predicted_value[0], columns=VE_vars) 182 | df['filename'] = [f.decode('utf-8') for f in filenames.numpy()] 183 | 184 | df.to_csv(output_csv, mode='a', index=False, header=False) 185 | 186 | ve = pd.read_csv("embeddings/embedding_128dim.csv") 187 | 188 | ve.head() 189 | 190 | 191 | merge_df_sel4 = joblib.load("dataset/df/merge_df_sel4") 192 | merge_df_sel4['filename'] = "dataset/tfdata/tftest/ecg_tf_" + merge_df_sel4['ecg_pID'] +"_" + merge_df_sel4['ecg_sID'] 193 | 194 | df_with_ve = pd.merge(merge_df_sel4,ve,how="inner", on=["filename"]) 195 | 196 | joblib.dump(df_with_ve, "dataset/df/test_with_ve_df") -------------------------------------------------------------------------------- /Step2_VEbuilding/PRG011_VAE.py: -------------------------------------------------------------------------------- 1 | # File and directory handling 2 | import os, random, glob, joblib 3 | # Data manipulation 4 | import pandas as pd 5 | import numpy as np 6 | # TensorFlow and Keras for model building 7 | import tensorflow as tf 8 | from tensorflow.keras import 
layers, Model 9 | from tensorflow.keras import regularizers 10 | # Scikit-learn for machine learning and metrics 11 | from sklearn.model_selection import GroupKFold, cross_validate 12 | from sklearn.linear_model import LogisticRegression 13 | from sklearn.preprocessing import StandardScaler 14 | from sklearn.pipeline import Pipeline 15 | from sklearn.metrics import make_scorer, recall_score, precision_score, fbeta_score, confusion_matrix 16 | from tensorflow.keras.saving import register_keras_serializable 17 | from tensorflow.keras.callbacks import LearningRateScheduler 18 | 19 | ######## For reproducibility, include the below lines. ######### 20 | random.seed(42) 21 | tf.keras.utils.set_random_seed(1) 22 | tf.config.experimental.enable_op_determinism() 23 | ################################################################ 24 | 25 | # 0. Prevent tf from hogging the GPU. 26 | gpus = tf.config.experimental.list_physical_devices('GPU') 27 | if gpus: 28 | try: 29 | for gpu in gpus: 30 | tf.config.experimental.set_memory_growth(gpu, True) 31 | print("Memory growth enabled") 32 | except RuntimeError as e: 33 | print(e) 34 | 35 | # 1. 
# Parse function to decode binary data
def parse(example):
    """Decode one serialized ECG example for autoencoder training.

    ``x_ECG`` is decoded from raw float64 bytes and reshaped to (5000, 12, 1).
    The ``y_*`` features are declared in the spec but not used here.

    Returns:
        ``(x_ECG, x_ECG)``: input and target are the same signal.
    """
    features = tf.io.parse_single_example(
        example,
        features={"x_ECG": tf.io.FixedLenFeature([], tf.string),
                  "y_sex": tf.io.FixedLenFeature([], tf.string),
                  "y_age": tf.io.FixedLenFeature([], tf.string),
                  "y_hf": tf.io.FixedLenFeature([], tf.string),
                  "y_af": tf.io.FixedLenFeature([], tf.string),
                  })

    x_ECG_ = tf.io.decode_raw(features["x_ECG"], tf.float64)
    x_ECG = tf.reshape(x_ECG_, tf.stack([5000, 12, 1]))

    return x_ECG, x_ECG


def preprocess_data(x, y):
    """Randomly crop a (4096, 12, 1) window from ``x`` (augmentation).

    Only the first axis is shorter than the input, so the crop picks a random
    4096-sample time window across all 12 leads. ``y`` is ignored; the crop is
    returned as both input and target.
    """
    x = tf.image.random_crop(x, size=(4096, 12, 1))
    return x, x


# Sampling Layer for VAE
@register_keras_serializable()
class Sampling(layers.Layer):
    """Reparameterization step: z = z_mean + exp(0.5 * z_log_var) * eps.

    ``eps`` is drawn from a standard normal with the same (batch, latent_dim)
    shape as ``z_mean``; z is the sampled latent vector for each ECG.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer kwargs (name, dtype, ...) to the base class.
        super().__init__(**kwargs)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


@register_keras_serializable()
class ReconstructionLossLayer(tf.keras.layers.Layer):
    """Attach the reconstruction MSE via ``add_loss``; return inputs unchanged.

    The loss is the mean of squared differences over all elements of
    ``[x_true, x_pred]``; the layer output is that same pair.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        x_true, x_pred = inputs
        loss = tf.reduce_mean(tf.square(x_true - x_pred))
        self.add_loss(loss)
        self.reconstruction_loss = loss
        return inputs


@register_keras_serializable()
class KLDivergenceLayer(tf.keras.layers.Layer):
    """Attach the KL(N(z_mean, exp(z_log_var)) || N(0, I)) term via ``add_loss``.

    The closed-form per-dimension KL is averaged (not summed) over batch and
    latent dimensions. The inputs are returned unchanged.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        self.add_loss(loss)
        self.kl_loss = loss
        return inputs


def build_vae(filter_num=32, latent_dim=128):
    """Build and compile the convolutional VAE for 12-lead ECGs.

    Shapes (from the strides below, F = filter_num): input (4096, 12, 1) ->
    encoder pools to (1024, 1, F), (256, 1, F), (64, 1, F), (16, 1, F) ->
    Dense heads z_mean / z_log_var of size latent_dim -> decoder upsamples
    back to (4096, 12, 1).

    All losses (KL term, reconstruction MSE, L2 kernel regularizers) are
    attached to the graph, so the model is compiled with ``loss=None``.

    Args:
        filter_num: filters in every conv / transposed-conv layer.
        latent_dim: size of the latent (embedding) vector.

    Returns:
        (vae, encoder): the compiled full model and an encoder model whose
        outputs are [z_mean, z_log_var, z].
    """
    inputs = layers.Input(shape=(4096, 12, 1))

    # Encoder
    encoding_1 = layers.Conv2D(filter_num, (5, 12), padding='same',
                               kernel_regularizer=regularizers.l2(0.001))(inputs)
    encoding_1 = layers.BatchNormalization()(encoding_1)
    encoding_1 = layers.Activation('swish')(encoding_1)
    encoding_1 = layers.MaxPooling2D((4, 12), padding='same')(encoding_1)

    encoding_2 = layers.Conv2D(filter_num, (5, 1), padding='same',
                               kernel_regularizer=regularizers.l2(0.001))(encoding_1)
    encoding_2 = layers.BatchNormalization()(encoding_2)
    encoding_2 = layers.Activation('swish')(encoding_2)
    encoding_2 = layers.MaxPooling2D((4, 1), padding='same')(encoding_2)

    encoding_3 = layers.Conv2D(filter_num, (5, 1), padding='same',
                               kernel_regularizer=regularizers.l2(0.001))(encoding_2)
    encoding_3 = layers.BatchNormalization()(encoding_3)
    encoding_3 = layers.Activation('swish')(encoding_3)
    encoding_3 = layers.MaxPooling2D((4, 1), padding='same')(encoding_3)

    encoding_4 = layers.Conv2D(filter_num, (5, 1), padding='same',
                               kernel_regularizer=regularizers.l2(0.001))(encoding_3)
    encoding_4 = layers.BatchNormalization()(encoding_4)
    encoding_4 = layers.Activation('swish')(encoding_4)
    encoding_4 = layers.MaxPooling2D((4, 1), padding='same')(encoding_4)

    # Latent space (mean and log-variance)
    encoding_f = layers.Flatten()(encoding_4)
    z_mean = layers.Dense(latent_dim, kernel_regularizer=tf.keras.regularizers.l2(0.001))(encoding_f)
    z_log_var = layers.Dense(latent_dim, kernel_regularizer=tf.keras.regularizers.l2(0.001))(encoding_f)

    # KL term + sampling layer
    z_mean, z_log_var = KLDivergenceLayer()([z_mean, z_log_var])
    z = Sampling()([z_mean, z_log_var])

    # Decoder
    decoding_1 = layers.Dense(encoding_4.shape[1] * encoding_4.shape[2] * encoding_4.shape[3])(z)
    decoding_1 = layers.Reshape(encoding_4.shape[1:])(decoding_1)

    decoding_2 = layers.Conv2DTranspose(filter_num, (5, 1), strides=(4, 1), padding='same',
                                        kernel_regularizer=regularizers.l2(0.001))(decoding_1)
    decoding_2 = layers.Activation('swish')(decoding_2)

    decoding_3 = layers.Conv2DTranspose(filter_num, (5, 1), strides=(4, 1), padding='same',
                                        kernel_regularizer=regularizers.l2(0.001))(decoding_2)
    decoding_3 = layers.Activation('swish')(decoding_3)

    # UpSampling to adjust encoding_3 size to decoding_3 (64 -> 256 samples).
    # NOTE(review): this adds encoder features to the decoder without passing
    # through z, so reconstruction does not depend on the latent code alone.
    encoding_3_up = layers.UpSampling2D(size=(4, 1))(encoding_3)
    decoding_3 = layers.Add()([decoding_3, encoding_3_up])

    decoding_4 = layers.Conv2DTranspose(filter_num, (5, 1), strides=(4, 1), padding='same',
                                        kernel_regularizer=regularizers.l2(0.001))(decoding_3)
    decoding_4 = layers.Activation('swish')(decoding_4)

    decoding_5 = layers.Conv2DTranspose(1, (4, 12), strides=(4, 12), padding='same',
                                        kernel_regularizer=regularizers.l2(0.001))(decoding_4)
    # NOTE(review): swish is bounded below (minimum about -0.28), so the
    # reconstruction cannot go lower than that; confirm the input scaling.
    decoding_5 = layers.Activation('swish')(decoding_5)

    decoding = ReconstructionLossLayer()([inputs, decoding_5])

    vae = Model(inputs, decoding)
    encoder = Model(inputs, [z_mean, z_log_var, z])
    vae.compile(optimizer='adam', loss=None, jit_compile=False)

    # summary() prints the table itself and returns None; wrapping it in
    # print() only added a stray "None" line.
    vae.summary()

    return vae, encoder


filter_num = 32
latent_dim = 128
batch_size = 256
vae, encoder = build_vae(filter_num=filter_num, latent_dim=latent_dim)

# vae.layers

# # 11.
Dataset 172 | file_pattern = "dataset/tfdata/tftrain/ecg_tf_*" 173 | file_list = tf.data.Dataset.list_files(file_pattern) 174 | 175 | dataset = file_list.interleave( 176 | lambda filename: tf.data.TFRecordDataset(filename), 177 | cycle_length=16, 178 | block_length=4, 179 | num_parallel_calls=tf.data.experimental.AUTOTUNE 180 | ) 181 | 182 | dataset_size = len(glob.glob("dataset/tfdata/tftrain/ecg_tf_*")) 183 | test_size = dataset_size // 4 184 | train_size = dataset_size - test_size 185 | 186 | test_dataset = dataset.take(test_size) 187 | train_dataset = dataset.skip(test_size) 188 | 189 | train_dataset = train_dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE).map( 190 | preprocess_data, num_parallel_calls=tf.data.experimental.AUTOTUNE).shuffle(8192).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 191 | test_dataset = test_dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE).map( 192 | preprocess_data, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE) 193 | 194 | # 12. Model training 195 | history = vae.fit(train_dataset, validation_data=test_dataset, epochs=50) 196 | pd.DataFrame(history.history).plot() 197 | 198 | vae.save('model/vae_full.keras') 199 | vae.save_weights('model/vae_full_weight.weights.h5') 200 | encoder.save('model/encoder_full.keras') 201 | encoder.save_weights('model/encoder_full_weight.weights.h5') --------------------------------------------------------------------------------