├── README.md ├── config.py ├── convert_cpsc.py ├── data_process.py ├── dataset.py ├── main_distillation.py ├── main_train.py ├── minirocket_train.py ├── models ├── __init__.py ├── acnet.py ├── ati_cnn.py ├── attention.py ├── bi_lstm.py ├── fcn_wang.py ├── inceptiontime.py ├── minrocket.py ├── mobilenet_v3.py ├── model.py ├── resnet1d_wang.py ├── vit.py └── xresnet1d101.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # A Multi-View Multi-Scale Neural Network for Multi-Label ECG Classification 2 | 3 | This repository contains the code for the paper "A Multi-View Multi-Scale Neural Network for Multi-Label ECG Classification". 4 | 5 | # Dependencies 6 | 7 | - python>=3.7 8 | - pytorch>=1.7.0 9 | - torchvision>=0.8.1 10 | - numpy>=1.19.5 11 | - tqdm>=4.62.0 12 | - scipy>=1.5.4 13 | - wfdb>=3.2.0 14 | - scikit-learn>=0.24.2 15 | 16 | # Usage 17 | 18 | ## Configuration 19 | 20 | All training and test options are set in the configuration file "config.py". 21 | 22 | ## Stage 1: Training 23 | 24 | After setting the configuration, start training by running 25 | 26 | > python main_train.py 27 | 28 | Since MiniRocket's training strategy differs slightly from that of the other models, train it by running 29 | 30 | > python minirocket_train.py 31 | 32 | ## Stage 2: Knowledge Distillation 33 | 34 | The multi-view network trained in the first stage is used as the teacher to train the single-view network. Run 35 | 36 | > python main_distillation.py 37 | 38 | # Dataset 39 | 40 | The PTB-XL dataset can be downloaded from [PTB-XL, a large publicly available electrocardiography dataset v1.0.1 (physionet.org)](https://www.physionet.org/content/ptb-xl/1.0.1/). 41 | 42 | The CPSC2018 dataset can be downloaded from [The China Physiological Signal Challenge 2018 (icbeb.org)](http://2018.icbeb.org/Challenge.html). 43 | 44 | The HFHC dataset can be downloaded from https://tianchi.aliyun.com/competition/entrance/231754/information 45 | 46 | # Citation 47 | 48 | If you find this idea useful in your research, please consider citing: 49 | 50 | ``` 51 | @article{ 52 | title={A Multi-View Multi-Scale Neural Network for Multi-Label ECG Classification}, 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | @time: 2021/4/16 18:45 4 | 5 | @ author: 6 | ''' 7 | 8 | class Config: 9 | 10 | seed = 10 11 | 12 | # path 13 | datafolder = '../data/ptbxl/' 14 | # datafolder = '../data/CPSC/' 15 | # datafolder = '../data/hf/' 16 | 17 | # 18 | ''' 19 | experiment = exp0, exp1, exp1.1, exp1.1.1, exp2, exp3 20 | ''' 21 | experiment = 'exp0' 22 | 23 | # for train 24 | ''' 25 | MyNet6View, resnet1d_wang, xresnet1d101, inceptiontime, fcn_wang, lstm, lstm_bidir, vit, mobilenetv3_small 26 | ''' 27 | model_name = 'MyNet6View' 28 | 29 | model_name2 = 'MyNet' 30 | 31 | batch_size = 64 32 | 33 | max_epoch = 100 34 | 35 | lr = 0.001 36 | 37 | device_num = 1 38 | 39 | # eg: MyNet6View_all_checkpoint_best_tpr.pth 40 | checkpoints = 'MyNet6View_exp0_checkpoint_best_auc.pth' 41 | 42 | # knowledge distillation params 43 | alpha = 0.5 44 | temperature = 2 45 | 46 | 47 | config = Config() 48 | -------------------------------------------------------------------------------- /convert_cpsc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import wfdb 4 |
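# [Editor's note] `stratisfy_df`, imported below, is assumed to be the stratified
# fold-assignment helper (stratisfy.py) that ships with the PTB-XL benchmarking
# code; it is not included in this listing and must sit alongside this script.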
from tqdm import tqdm 5 | import numpy as np 6 | from scipy.ndimage import zoom 7 | from scipy.io import loadmat 8 | from stratisfy import stratisfy_df 9 | 10 | output_folder = 'data/CPSC/' 11 | output_datafolder_100 = output_folder+ '/records100/' 12 | output_datafolder_500 = output_folder+ '/records500/' 13 | if not os.path.exists(output_folder): 14 | os.mkdir(output_folder) 15 | if not os.path.exists(output_datafolder_100): 16 | os.makedirs(output_datafolder_100) 17 | if not os.path.exists(output_datafolder_500): 18 | os.makedirs(output_datafolder_500) 19 | 20 | def store_as_wfdb(signame, data, sigfolder, fs): 21 | channel_itos=['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'] 22 | wfdb.wrsamp(signame, 23 | fs=fs, 24 | sig_name=channel_itos, 25 | p_signal=data, 26 | units=['mV']*len(channel_itos), 27 | fmt = ['16']*len(channel_itos), 28 | write_dir=sigfolder) 29 | 30 | df_reference = pd.read_csv('tmp_data/REFERENCE.csv') 31 | 32 | label_dict = {1:'NORM', 2:'AFIB', 3:'1AVB', 4:'CLBBB', 5:'CRBBB', 6:'PAC', 7:'VPC', 8:'STD_', 9:'STE_'} 33 | 34 | data = {'ecg_id':[], 'filename':[], 'validation':[], 'age':[], 'sex':[], 'scp_codes':[]} 35 | 36 | ecg_counter = 0 37 | for folder in ['TrainingSet1', 'TrainingSet2', 'TrainingSet3']: 38 | filenames = os.listdir('tmp_data/'+folder) 39 | for filename in tqdm(filenames): 40 | if filename.split('.')[1] == 'mat': 41 | ecg_counter += 1 42 | name = filename.split('.')[0] 43 | sex, age, sig = loadmat('tmp_data/'+folder+'/'+filename)['ECG'][0][0] 44 | data['ecg_id'].append(ecg_counter) 45 | data['filename'].append(name) 46 | data['validation'].append(False) 47 | data['age'].append(age[0][0]) 48 | data['sex'].append(1 if sex[0] == 'Male' else 0) 49 | labels = df_reference[df_reference.Recording == name][['First_label' ,'Second_label' ,'Third_label']].values.flatten() 50 | labels = labels[~np.isnan(labels)].astype(int) 51 | data['scp_codes'].append({label_dict[key]:100 for key in labels}) 52 | store_as_wfdb(str(ecg_counter), sig.T, output_datafolder_500, 500) 53 | down_sig = np.array([zoom(channel, .2) for channel in sig]) 54 | store_as_wfdb(str(ecg_counter), down_sig.T, output_datafolder_100, 100) 55 | 56 | df = pd.DataFrame(data) 57 | df['patient_id'] = df.ecg_id 58 | df = stratisfy_df(df, 'strat_fold') 59 | df.to_csv(output_folder+'cpsc_database.csv') 60 | 61 | 62 | -------------------------------------------------------------------------------- /data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import pandas as pd 4 | import numpy as np 5 | from tqdm import tqdm 6 | import wfdb 7 | import ast 8 | from scipy.signal import resample 9 | from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer 10 | from sklearn import preprocessing 11 | 12 | # DATA PROCESSING STUFF 13 | def load_dataset(path, sampling_rate, release=False): 14 | if path.split('/')[-2] == 'ptbxl': 15 | # load and convert annotation data 16 | Y = pd.read_csv(path + 'ptbxl_database.csv', index_col='ecg_id') 17 | Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) 18 | 19 | # Load raw signal data 20 | X = load_raw_data_ptbxl(Y, sampling_rate, path) 21 | 22 | elif path.split('/')[-2] == 'CPSC': 23 | # load and convert annotation data 24 | Y = pd.read_csv(path + 'cpsc_database.csv', index_col='ecg_id') 25 | Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) 26 | 27 | # Load raw signal data 28 | X = load_raw_data_cpsc(Y, sampling_rate, path) 29 | 30 | 
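    # [Editor's note] Caching detail: despite the .npy suffix, raw100.npy / raw500.npy
    # are written with pickle.dump (see load_raw_data_* below); np.load(...,
    # allow_pickle=True) still reads them because it falls back to pickle.load for
    # files that are not in .npy format.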
return X, Y 31 | 32 | 33 | def load_raw_data_cpsc(df, sampling_rate, path): 34 | if sampling_rate == 100: 35 | if os.path.exists(path + 'raw100.npy'): 36 | data = np.load(path + 'raw100.npy', allow_pickle=True) 37 | else: 38 | data = [wfdb.rdsamp(path + 'records100/' + str(f)) for f in tqdm(df.index)] 39 | data = np.array([signal for signal, meta in data]) 40 | pickle.dump(data, open(path + 'raw100.npy', 'wb'), protocol=4) 41 | elif sampling_rate == 500: 42 | if os.path.exists(path + 'raw500.npy'): 43 | data = np.load(path + 'raw500.npy', allow_pickle=True) 44 | else: 45 | data = [wfdb.rdsamp(path + 'records500/' + str(f)) for f in tqdm(df.index)] 46 | data = np.array([signal for signal, meta in data]) 47 | pickle.dump(data, open(path + 'raw500.npy', 'wb'), protocol=4) 48 | return data 49 | 50 | 51 | def load_raw_data_ptbxl(df, sampling_rate, path): 52 | if sampling_rate == 100: 53 | if os.path.exists(path + 'raw100.npy'): 54 | data = np.load(path + 'raw100.npy', allow_pickle=True) 55 | else: 56 | data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_lr)] 57 | data = np.array([signal for signal, meta in data]) 58 | pickle.dump(data, open(path + 'raw100.npy', 'wb'), protocol=4) 59 | elif sampling_rate == 500: 60 | if os.path.exists(path + 'raw500.npy'): 61 | data = np.load(path + 'raw500.npy', allow_pickle=True) 62 | else: 63 | data = [wfdb.rdsamp(path + f) for f in tqdm(df.filename_hr)] 64 | data = np.array([signal for signal, meta in data]) 65 | pickle.dump(data, open(path + 'raw500.npy', 'wb'), protocol=4) 66 | return data 67 | 68 | 69 | def compute_label_aggregations(df, folder, ctype): 70 | df['scp_codes_len'] = df.scp_codes.apply(lambda x: len(x)) 71 | 72 | aggregation_df = pd.read_csv(folder + 'scp_statements.csv', index_col=0) 73 | 74 | if ctype in ['diagnostic', 'subdiagnostic', 'superdiagnostic']: 75 | 76 | def aggregate_all_diagnostic(y_dic): 77 | tmp = [] 78 | for key in y_dic.keys(): 79 | if key in diag_agg_df.index: 80 | tmp.append(key) 81 | return list(set(tmp)) 82 | 83 | def aggregate_subdiagnostic(y_dic): 84 | tmp = [] 85 | for key in y_dic.keys(): 86 | if key in diag_agg_df.index: 87 | c = diag_agg_df.loc[key].diagnostic_subclass 88 | if str(c) != 'nan': 89 | tmp.append(c) 90 | return list(set(tmp)) 91 | 92 | def aggregate_diagnostic(y_dic): 93 | tmp = [] 94 | for key in y_dic.keys(): 95 | if key in diag_agg_df.index: 96 | c = diag_agg_df.loc[key].diagnostic_class 97 | if str(c) != 'nan': 98 | tmp.append(c) 99 | return list(set(tmp)) 100 | 101 | diag_agg_df = aggregation_df[aggregation_df.diagnostic == 1.0] 102 | if ctype == 'diagnostic': 103 | df['diagnostic'] = df.scp_codes.apply(aggregate_all_diagnostic) 104 | df['diagnostic_len'] = df.diagnostic.apply(lambda x: len(x)) 105 | elif ctype == 'subdiagnostic': 106 | df['subdiagnostic'] = df.scp_codes.apply(aggregate_subdiagnostic) 107 | df['subdiagnostic_len'] = df.subdiagnostic.apply(lambda x: len(x)) 108 | elif ctype == 'superdiagnostic': 109 | df['superdiagnostic'] = df.scp_codes.apply(aggregate_diagnostic) 110 | df['superdiagnostic_len'] = df.superdiagnostic.apply(lambda x: len(x)) 111 | elif ctype == 'form': 112 | form_agg_df = aggregation_df[aggregation_df.form == 1.0] 113 | 114 | def aggregate_form(y_dic): 115 | tmp = [] 116 | for key in y_dic.keys(): 117 | if key in form_agg_df.index: 118 | c = key 119 | if str(c) != 'nan': 120 | tmp.append(c) 121 | return list(set(tmp)) 122 | 123 | df['form'] = df.scp_codes.apply(aggregate_form) 124 | df['form_len'] = df.form.apply(lambda x: len(x)) 125 | elif ctype == 'rhythm': 
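        # [Editor's note] This branch mirrors the 'form' branch above: it keeps only
        # the scp_codes flagged rhythm == 1.0 in scp_statements.csv. Hypothetical
        # example: scp_codes {'NORM': 100.0, 'SR': 0.0} aggregates to ['SR'] here,
        # since SR is a rhythm statement while NORM is diagnostic.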
126 | rhythm_agg_df = aggregation_df[aggregation_df.rhythm == 1.0] 127 | 128 | def aggregate_rhythm(y_dic): 129 | tmp = [] 130 | for key in y_dic.keys(): 131 | if key in rhythm_agg_df.index: 132 | c = key 133 | if str(c) != 'nan': 134 | tmp.append(c) 135 | return list(set(tmp)) 136 | 137 | df['rhythm'] = df.scp_codes.apply(aggregate_rhythm) 138 | df['rhythm_len'] = df.rhythm.apply(lambda x: len(x)) 139 | elif ctype == 'all': 140 | df['all_scp'] = df.scp_codes.apply(lambda x: list(set(x.keys()))) 141 | 142 | return df 143 | 144 | 145 | def select_data(XX, YY, ctype, min_samples): 146 | # convert multilabel to multi-hot 147 | mlb = MultiLabelBinarizer() 148 | 149 | if ctype == 'diagnostic': 150 | X = XX[YY.diagnostic_len > 0] 151 | Y = YY[YY.diagnostic_len > 0] 152 | mlb.fit(Y.diagnostic.values) 153 | y = mlb.transform(Y.diagnostic.values) 154 | elif ctype == 'subdiagnostic': 155 | counts = pd.Series(np.concatenate(YY.subdiagnostic.values)).value_counts() 156 | counts = counts[counts > min_samples] 157 | YY.subdiagnostic = YY.subdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 158 | YY['subdiagnostic_len'] = YY.subdiagnostic.apply(lambda x: len(x)) 159 | X = XX[YY.subdiagnostic_len > 0] 160 | Y = YY[YY.subdiagnostic_len > 0] 161 | mlb.fit(Y.subdiagnostic.values) 162 | y = mlb.transform(Y.subdiagnostic.values) 163 | elif ctype == 'superdiagnostic': 164 | counts = pd.Series(np.concatenate(YY.superdiagnostic.values)).value_counts() 165 | counts = counts[counts > min_samples] 166 | YY.superdiagnostic = YY.superdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 167 | YY['superdiagnostic_len'] = YY.superdiagnostic.apply(lambda x: len(x)) 168 | X = XX[YY.superdiagnostic_len > 0] 169 | Y = YY[YY.superdiagnostic_len > 0] 170 | mlb.fit(Y.superdiagnostic.values) 171 | y = mlb.transform(Y.superdiagnostic.values) 172 | elif ctype == 'form': 173 | # filter 174 | counts = pd.Series(np.concatenate(YY.form.values)).value_counts() 175 | counts = counts[counts > min_samples] 176 | YY.form = YY.form.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 177 | YY['form_len'] = YY.form.apply(lambda x: len(x)) 178 | # select 179 | X = XX[YY.form_len > 0] 180 | Y = YY[YY.form_len > 0] 181 | mlb.fit(Y.form.values) 182 | y = mlb.transform(Y.form.values) 183 | elif ctype == 'rhythm': 184 | # filter 185 | counts = pd.Series(np.concatenate(YY.rhythm.values)).value_counts() 186 | counts = counts[counts > min_samples] 187 | YY.rhythm = YY.rhythm.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 188 | YY['rhythm_len'] = YY.rhythm.apply(lambda x: len(x)) 189 | # select 190 | X = XX[YY.rhythm_len > 0] 191 | Y = YY[YY.rhythm_len > 0] 192 | mlb.fit(Y.rhythm.values) 193 | y = mlb.transform(Y.rhythm.values) 194 | elif ctype == 'all': 195 | # filter 196 | counts = pd.Series(np.concatenate(YY.all_scp.values)).value_counts() 197 | counts = counts[counts > min_samples] 198 | YY.all_scp = YY.all_scp.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 199 | YY['all_scp_len'] = YY.all_scp.apply(lambda x: len(x)) 200 | # select 201 | X = XX[YY.all_scp_len > 0] 202 | Y = YY[YY.all_scp_len > 0] 203 | mlb.fit(Y.all_scp.values) 204 | y = mlb.transform(Y.all_scp.values) 205 | else: 206 | pass 207 | 208 | return X, Y, y, mlb 209 | 210 | 211 | def preprocess_signals(X_train, X_validation, X_test): 212 | # Standardize data such that mean 0 and variance 1 213 | ss = StandardScaler() 214 | ss.fit(np.vstack(X_train).flatten()[:, 
np.newaxis].astype(float)) 215 | 216 | return apply_standardizer(X_train, ss), apply_standardizer(X_validation, ss), apply_standardizer(X_test, ss) 217 | 218 | 219 | def apply_standardizer(X, ss): 220 | X_tmp = [] 221 | for x in X: 222 | x_shape = x.shape 223 | X_tmp.append(ss.transform(x.flatten()[:, np.newaxis]).reshape(x_shape)) 224 | X_tmp = np.array(X_tmp) 225 | return X_tmp 226 | 227 | 228 | def data_slice(data): 229 | data_process = [] 230 | for dat in data: 231 | if dat.shape[0] < 1000: 232 | # dat = np.pad(dat, (0, 1000 - dat.shape[0]), 'constant', constant_values=0) 233 | dat = resample(dat, 1000, axis=0) 234 | elif dat.shape[0] > 1000: 235 | dat = dat[:1000, :] 236 | # dat = resample(dat, 1000, axis=0) 237 | if dat.shape[1] != 12: 238 | dat = dat[:, 0:12] 239 | 240 | data_process.append(dat) 241 | return np.array(data_process) 242 | 243 | 244 | # hf 245 | def name2index(path): 246 | list_name = [] 247 | for line in open(path, encoding='utf-8'): 248 | list_name.append(line.strip()) 249 | name2indx = {name: i for i, name in enumerate(list_name)} 250 | return name2indx 251 | 252 | 253 | 254 | def file2index(path, name2idx): 255 | file2index = dict() 256 | for line in open(path, encoding='utf-8'): 257 | arr = line.strip().split('\t') 258 | id = arr[0] 259 | labels = [name2idx[name] for name in arr[3:]] 260 | file2index[id] = labels 261 | return file2index 262 | 263 | 264 | def load_raw_data_hf(root='../data/hf/', resample_num=1000, num_classes=34): 265 | if os.path.exists(root + 'raw100_data.npy'): 266 | data = np.load(root + 'raw100_data.npy', allow_pickle=True) 267 | y = np.load(root + 'raw100_label.npy', allow_pickle=True) 268 | else: 269 | name2idx = name2index(root + 'hf_round2_arrythmia.txt') 270 | file2idx = file2index(root + 'hf_round2_label.txt', name2idx) 271 | data, label = [], [] 272 | for file, list_idx in file2idx.items(): 273 | temp = np.zeros([5000, 12]) 274 | df = pd.read_csv(root + 'hf_round2_train' + '/' + file, sep=' ').values 275 | temp[:, 2] = df[:, 1] - df[:, 0] 276 | temp[:, 3] = -(df[:, 0] + df[:, 1]) / 2 277 | temp[:, 4] = df[:, 0] - df[:, 1] / 2 278 | temp[:, 5] = df[:, 1] - df[:, 0] / 2 279 | temp[:, 0:2] = df[:, 0:2] 280 | temp[:, 6:12] = df[:, 2:8] 281 | sig = resample(temp, resample_num) 282 | min_max_scaler = preprocessing.MinMaxScaler() 283 | ecg = min_max_scaler.fit_transform(sig) 284 | data.append(ecg) 285 | label.append(tuple(list_idx)) 286 | data = np.array(data) 287 | pickle.dump(data, open(root + 'raw100_data.npy', 'wb'), protocol=4) 288 | mlb = MultiLabelBinarizer(classes=[i for i in range(num_classes)]) 289 | y = mlb.fit_transform(label) 290 | y = np.array(y) 291 | pickle.dump(y, open(root + 'raw100_label.npy', 'wb'), protocol=4) 292 | return data, y 293 | 294 | 295 | def hf_dataset(root='../data/hf/', resample_num=1000, num_classes=34): 296 | data, label = load_raw_data_hf(root, resample_num, num_classes) 297 | data_num = len(label) 298 | shuffle_ix = np.random.permutation(np.arange(data_num)) 299 | data = data[shuffle_ix] 300 | labels = label[shuffle_ix] 301 | 302 | X_train = data[int(data_num * 0.2):int(data_num * 0.8)] 303 | y_train = labels[int(data_num * 0.2):int(data_num * 0.8)] 304 | 305 | X_val = data[int(data_num * 0.8):] 306 | y_val = labels[int(data_num * 0.8):] 307 | 308 | X_test = data[:int(data_num * 0.2)] 309 | y_test = labels[:int(data_num * 0.2)] 310 | 311 | return X_train, y_train, X_val, y_val, X_test, y_test 312 | 313 | -------------------------------------------------------------------------------- /dataset.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset 4 | from data_process import load_dataset, compute_label_aggregations, select_data, preprocess_signals, data_slice, hf_dataset 5 | from config import config 6 | 7 | 8 | class ECGDataset(Dataset): 9 | """ 10 | A generic ECG data loader: each sample is a (sequence_length, num_leads) signal paired with a multi-hot label vector. 11 | """ 12 | 13 | def __init__(self, signals: np.ndarray, labels: np.ndarray): 14 | super(ECGDataset, self).__init__() 15 | self.data = signals 16 | self.label = labels 17 | self.num_classes = self.label.shape[1] 18 | 19 | self.cls_num_list = np.sum(self.label, axis=0) 20 | 21 | def __getitem__(self, index): 22 | x = self.data[index] 23 | y = self.label[index] 24 | 25 | x = x.transpose() 26 | 27 | x = torch.tensor(x.copy(), dtype=torch.float) 28 | 29 | y = torch.tensor(y, dtype=torch.float) 30 | y = y.squeeze() 31 | return x, y 32 | 33 | def __len__(self): 34 | return len(self.data) 35 | 36 | 37 | class DownLoadECGData: 38 | ''' 39 | Load and preprocess the data for all experiments 40 | ''' 41 | 42 | def __init__(self, experiment_name, task, datafolder, sampling_frequency=100, min_samples=0, 43 | train_fold=8, val_fold=9, test_fold=10): 44 | self.min_samples = min_samples 45 | self.task = task 46 | self.train_fold = train_fold 47 | self.val_fold = val_fold 48 | self.test_fold = test_fold 49 | self.experiment_name = experiment_name 50 | self.datafolder = datafolder 51 | self.sampling_frequency = sampling_frequency 52 | 53 | def preprocess_data(self): 54 | # Load raw signals and annotations (PTB-XL or CPSC) 55 | data, raw_labels = load_dataset(self.datafolder, self.sampling_frequency) 56 | # Preprocess label data 57 | labels = compute_label_aggregations(raw_labels, self.datafolder, self.task) 58 | 59 | # Select relevant data and convert labels to multi-hot 60 | data, labels, Y, _ = select_data(data, labels, self.task, self.min_samples) 61 | 62 | if self.datafolder == '../data/CPSC/': 63 | data = data_slice(data) 64 | 65 | # 10th fold for testing 66 | X_test = data[labels.strat_fold == self.test_fold] 67 | y_test = Y[labels.strat_fold == self.test_fold] 68 | # 9th fold for validation 69 | X_val = data[labels.strat_fold == self.val_fold] 70 | y_val = Y[labels.strat_fold == self.val_fold] 71 | # rest for training 72 | X_train = data[labels.strat_fold <= self.train_fold] 73 | y_train = Y[labels.strat_fold <= self.train_fold] 74 | 75 | # Preprocess signal data 76 | X_train, X_val, X_test = preprocess_signals(X_train, X_val, X_test) 77 | 78 | return X_train, y_train, X_val, y_val, X_test, y_test 79 | 80 | 81 | def load_datasets(datafolder=None, experiment=None): 82 | ''' 83 | Load the final dataset 84 | ''' 85 | 86 | 87 | if datafolder == '../data/ptbxl/': 88 | experiments = { 89 | 'exp0': ('exp0', 'all'), 90 | 'exp1': ('exp1', 'diagnostic'), 91 | 'exp1.1': ('exp1.1', 'subdiagnostic'), 92 | 'exp1.1.1': ('exp1.1.1', 'superdiagnostic'), 93 | 'exp2': ('exp2', 'form'), 94 | 'exp3': ('exp3', 'rhythm') 95 | } 96 | name, task = experiments[experiment] 97 | ded = DownLoadECGData(name, task, datafolder) 98 | X_train, y_train, X_val, y_val, X_test, y_test = ded.preprocess_data() 99 | elif datafolder == '../data/CPSC/': 100 | ded = DownLoadECGData('exp_CPSC', 'all', datafolder) 101 | X_train, y_train, X_val, y_val, X_test, y_test = ded.preprocess_data() 102 | else: 103 | X_train, y_train, X_val, y_val, X_test, y_test = hf_dataset(datafolder) 104 | 105 | ds_train = ECGDataset(X_train,
y_train) 106 | ds_val = ECGDataset(X_val, y_val) 107 | ds_test = ECGDataset(X_test, y_test) 108 | 109 | num_classes = ds_train.num_classes 110 | train_dataloader = DataLoader(ds_train, batch_size=config.batch_size, shuffle=True) 111 | val_dataloader = DataLoader(ds_val, batch_size=config.batch_size, shuffle=False) 112 | test_dataloader = DataLoader(ds_test, batch_size=config.batch_size, shuffle=False) 113 | 114 | return train_dataloader, val_dataloader, test_dataloader, num_classes 115 | 116 | 117 | -------------------------------------------------------------------------------- /main_distillation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2021/4/15 15:40 4 | 5 | @ author: 6 | """ 7 | import torch, time, os 8 | import models, utils 9 | from torch import optim 10 | from dataset import load_datasets 11 | from config import config 12 | 13 | from sklearn.metrics import roc_auc_score 14 | 15 | import numpy as np 16 | import random 17 | import pandas as pd 18 | 19 | 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | 22 | 23 | def setup_seed(seed): 24 | print('seed: ', seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | random.seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | 32 | def train_epoch(model_large, model_small, optimizer, criterion, train_dataloader): 33 | model_large.train(), model_small.train() 34 | loss_meter, it_count = 0, 0 35 | outputs = [] 36 | targets = [] 37 | for inputs, target in train_dataloader: 38 | 39 | inputs = inputs + torch.randn_like(inputs) * 0.1 40 | 41 | inputs = inputs.to(device) 42 | target = target.to(device) 43 | # zero the parameter gradients 44 | optimizer.zero_grad() 45 | # forward 46 | with torch.no_grad(): 47 | target1 = model_large(inputs) 48 | 49 | output = model_small(inputs) 50 | 51 | loss = criterion(output, target, target1) 52 | 53 | loss.backward() 54 | optimizer.step() 55 | loss_meter += loss.item() 56 | it_count += 1 57 | 58 | output = torch.sigmoid(output) 59 | for i in range(len(output)): 60 | outputs.append(output[i].cpu().detach().numpy()) 61 | targets.append(target[i].cpu().detach().numpy()) 62 | auc = roc_auc_score(targets, outputs) 63 | TPR = utils.compute_TPR(targets, outputs) 64 | print('train_loss: %.4f, macro_auc: %.4f, TPR: %.4f' % (loss_meter / it_count, auc, TPR)) 65 | return loss_meter / it_count, auc, TPR 66 | 67 | 68 | def test_epoch(model_large, model_small, criterion, val_dataloader): 69 | model_large.eval(), model_small.eval() 70 | loss_meter, it_count = 0, 0 71 | outputs = [] 72 | targets = [] 73 | with torch.no_grad(): 74 | for inputs, target in val_dataloader: 75 | 76 | inputs = inputs + torch.randn_like(inputs) * 0.1 77 | 78 | inputs = inputs.to(device) 79 | target = target.to(device) 80 | 81 | target1 = model_large(inputs) 82 | 83 | output = model_small(inputs) 84 | 85 | loss = criterion(output, target, target1) 86 | 87 | loss_meter += loss.item() 88 | it_count += 1 89 | 90 | output = torch.sigmoid(output) 91 | for i in range(len(output)): 92 | outputs.append(output[i].cpu().detach().numpy()) 93 | targets.append(target[i].cpu().detach().numpy()) 94 | 95 | auc = roc_auc_score(targets, outputs) 96 | TPR = utils.compute_TPR(targets, outputs) 97 | 98 | print('test_loss: %.4f, macro_auc: %.4f, TPR: %.4f' % (loss_meter / it_count, auc, TPR)) 99 | return loss_meter / it_count, auc, TPR 100 | 101 | 102 | def train(config=config): 103 | # seed 
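    # -----------------------------------------------------------------------
    # [Editor's note] utils.KdLoss (instantiated below) lives in utils.py,
    # which is not part of this listing, so its exact definition is unknown.
    # The hypothetical sketch below shows what a multi-label distillation
    # loss with the signature KdLoss(alpha, temperature) commonly looks like:
    # BCE on the hard labels plus temperature-scaled BCE against the
    # teacher's soft targets. Illustrative only, not the repo's implementation.
    import torch.nn as nn

    class KdLossSketch(nn.Module):  # hypothetical name, for illustration only
        def __init__(self, alpha, temperature):
            super(KdLossSketch, self).__init__()
            self.alpha = alpha              # weight of the hard-label term
            self.T = temperature            # temperature for softening logits
            self.bce = nn.BCEWithLogitsLoss()

        def forward(self, student_logits, hard_targets, teacher_logits):
            # hard term: ordinary multi-label BCE against the ground truth
            hard = self.bce(student_logits, hard_targets)
            # soft term: match the teacher's temperature-softened probabilities;
            # the T**2 factor keeps gradient magnitudes comparable across T
            soft_targets = torch.sigmoid(teacher_logits.detach() / self.T)
            soft = self.bce(student_logits / self.T, soft_targets) * (self.T ** 2)
            return self.alpha * hard + (1 - self.alpha) * soft
    # -----------------------------------------------------------------------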
104 | setup_seed(config.seed) 105 | print('torch.cuda.is_available:', torch.cuda.is_available()) 106 | 107 | # datasets 108 | train_dataloader, val_dataloader, test_dataloader, num_classes = load_datasets( 109 | datafolder=config.datafolder, 110 | experiment=config.experiment, 111 | ) 112 | 113 | # mode 114 | print('model_name:{}, num_classes={}'.format(config.model_name, num_classes)) 115 | model_large = getattr(models, config.model_name)(num_classes=num_classes) 116 | model_small = getattr(models, config.model_name2)(num_classes=num_classes) 117 | 118 | model_large = model_large.to(device) 119 | model_small = model_small.to(device) 120 | 121 | # optimizer and loss 122 | optimizer = optim.Adam(model_small.parameters(), lr=config.lr) 123 | criterion = utils.KdLoss(config.alpha, config.temperature) 124 | 125 | if config.checkpoints is not None: 126 | checkpoints = torch.load(os.path.join('checkpoints', config.checkpoints)) 127 | model_dict = model_large.state_dict() 128 | state_dict = {k: v for k, v in checkpoints['model_state_dict'].items() if k in model_dict.keys()} 129 | model_dict.update(state_dict) 130 | model_large.load_state_dict(model_dict) 131 | print('best_auc: ', checkpoints['best_auc'])  # key name matches save_checkpoint() in main_train.py 132 | 133 | # =========>train<========= 134 | for epoch in range(1, config.max_epoch + 1): 135 | print('#epoch: {} batch_size: {} Current Learning Rate: {}'.format(epoch, config.batch_size, 136 | config.lr)) 137 | 138 | since = time.time() 139 | train_loss, train_auc, train_TPR = train_epoch(model_large, model_small, optimizer, criterion, train_dataloader) 140 | 141 | val_loss, val_auc, val_TPR = test_epoch(model_large, model_small, criterion, val_dataloader) 142 | 143 | test_loss, test_auc, test_TPR = test_epoch(model_large, model_small, criterion, test_dataloader) 144 | 145 | 146 | result_list = [ 147 | [epoch, train_loss, train_auc, train_TPR, 148 | val_loss, val_auc, val_TPR, 149 | test_loss, test_auc, test_TPR]] 150 | if epoch == 1: 151 | columns = ['epoch', 'train_loss', 'train_auc', 'train_TPR', 152 | 'val_loss', 'val_auc', 'val_TPR', 153 | 'test_loss', 'test_auc', 'test_TPR'] 154 | else: 155 | columns = ['', '', '', '', '', '', '', '', '', ''] 156 | dt = pd.DataFrame(result_list, columns=columns) 157 | dt.to_csv(config.model_name + config.experiment + 'result.csv', mode='a') 158 | print('time:%s\n' % (utils.print_time_cost(since))) 159 | 160 | 161 | if __name__ == '__main__': 162 | train(config) 163 | 164 | -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | @time: 2021/4/15 15:40 4 | 5 | @ author: 6 | """ 7 | import torch, time, os 8 | import models, utils 9 | from torch import nn, optim 10 | from dataset import load_datasets 11 | from config import config 12 | from sklearn.metrics import roc_auc_score 13 | import numpy as np 14 | import random 15 | import pandas as pd 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | 20 | def setup_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | random.seed(seed) 25 | torch.backends.cudnn.deterministic = True 26 | 27 | 28 | def save_checkpoint(best_auc, model, optimizer, epoch): 29 | print('Model Saving...') 30 | if config.device_num > 1: 31 | model_state_dict = model.module.state_dict() 32 | else: 33 | model_state_dict = model.state_dict() 34 | 35 | torch.save({ 36 | 'model_state_dict':
model_state_dict, 37 | 'global_epoch': epoch, 38 | 'optimizer_state_dict': optimizer.state_dict(), 39 | 'best_auc': best_auc, 40 | }, os.path.join('checkpoints', config.model_name + '_' + config.experiment + '_checkpoint_best.pth')) 41 | 42 | 43 | def train_epoch(model, optimizer, criterion, train_dataloader): 44 | model.train() 45 | loss_meter, it_count = 0, 0 46 | outputs = [] 47 | targets = [] 48 | for inputs, target in train_dataloader: 49 | 50 | inputs = inputs + torch.randn_like(inputs) * 0.1 51 | 52 | inputs = inputs.to(device) 53 | target = target.to(device) 54 | # zero the parameter gradients 55 | optimizer.zero_grad() 56 | # forward 57 | output = model(inputs) 58 | loss = criterion(output, target) 59 | loss.backward() 60 | optimizer.step() 61 | loss_meter += loss.item() 62 | it_count += 1 63 | 64 | output = torch.sigmoid(output) 65 | for i in range(len(output)): 66 | outputs.append(output[i].cpu().detach().numpy()) 67 | targets.append(target[i].cpu().detach().numpy()) 68 | 69 | auc = roc_auc_score(targets, outputs) 70 | TPR = utils.compute_TPR(targets, outputs) 71 | print('train_loss: %.4f, macro_auc: %.4f, TPR: %.4f' % (loss_meter / it_count, auc, TPR)) 72 | return loss_meter / it_count, auc, TPR 73 | 74 | 75 | # val and test 76 | def test_epoch(model, criterion, val_dataloader): 77 | model.eval() 78 | loss_meter, it_count = 0, 0 79 | outputs = [] 80 | targets = [] 81 | with torch.no_grad(): 82 | for inputs, target in val_dataloader: 83 | 84 | inputs = inputs + torch.randn_like(inputs) * 0.1 85 | 86 | inputs = inputs.to(device) 87 | target = target.to(device) 88 | output = model(inputs) 89 | loss = criterion(output, target) 90 | loss_meter += loss.item() 91 | it_count += 1 92 | 93 | output = torch.sigmoid(output) 94 | for i in range(len(output)): 95 | outputs.append(output[i].cpu().detach().numpy()) 96 | targets.append(target[i].cpu().detach().numpy()) 97 | 98 | auc = roc_auc_score(targets, outputs) 99 | TPR = utils.compute_TPR(targets, outputs) 100 | 101 | print('test_loss: %.4f, macro_auc: %.4f, TPR: %.4f' % (loss_meter / it_count, auc, TPR)) 102 | return loss_meter / it_count, auc, TPR 103 | 104 | 105 | def train(config=config): 106 | # seed 107 | setup_seed(config.seed) 108 | print('torch.cuda.is_available:', torch.cuda.is_available()) 109 | 110 | # datasets 111 | train_dataloader, val_dataloader, test_dataloader, num_classes = load_datasets( 112 | datafolder=config.datafolder, 113 | experiment=config.experiment, 114 | ) 115 | 116 | # mode 117 | model = getattr(models, config.model_name)(num_classes=num_classes) 118 | print('model_name:{}, num_classes={}'.format(config.model_name, num_classes)) 119 | model = model.to(device) 120 | 121 | # optimizer and loss 122 | optimizer = optim.Adam(model.parameters(), lr=config.lr) 123 | criterion = nn.BCEWithLogitsLoss() 124 | 125 | if not os.path.isdir('checkpoints'): 126 | os.mkdir('checkpoints') 127 | 128 | # =========>train<========= 129 | for epoch in range(1, config.max_epoch + 1): 130 | print('#epoch: {} batch_size: {} Current Learning Rate: {}'.format(epoch, config.batch_size, 131 | config.lr)) 132 | 133 | since = time.time() 134 | train_loss, train_auc, train_TPR = train_epoch(model, optimizer, criterion, 135 | train_dataloader) 136 | 137 | val_loss, val_auc, val_TPR = test_epoch(model, criterion, val_dataloader) 138 | 139 | test_loss, test_auc, test_TPR = test_epoch(model, criterion, test_dataloader) 140 | 141 | save_checkpoint(test_auc, model, optimizer, epoch) 142 | 143 | result_list = [[epoch, train_loss, train_auc, 
train_TPR, 144 | val_loss, val_auc, val_TPR, 145 | test_loss, test_auc, test_TPR]] 146 | 147 | if epoch == 1: 148 | columns = ['epoch', 'train_loss', 'train_auc', 'train_TPR', 149 | 'val_loss', 'val_auc', 'val_TPR', 150 | 'test_loss', 'test_auc', 'test_TPR'] 151 | 152 | else: 153 | columns = ['', '', '', '', '', '', '', '', '', ''] 154 | 155 | dt = pd.DataFrame(result_list, columns=columns) 156 | dt.to_csv(config.model_name + config.experiment + 'result.csv', mode='a') 157 | 158 | print('time:%s\n' % (utils.print_time_cost(since))) 159 | 160 | 161 | if __name__ == '__main__': 162 | # train() 163 | for exp in ['exp0', 'exp1', 'exp1.1', 'exp1.1.1', 'exp2', 'exp3']: 164 | if exp == 'exp0': 165 | config.seed = 10 166 | elif exp == 'exp1': 167 | config.seed = 20 168 | elif exp == 'exp1.1': 169 | config.seed = 20 170 | elif exp == 'exp1.1.1': 171 | config.seed = 20 172 | elif exp == 'exp2': 173 | config.seed = 7 174 | elif exp == 'exp3': 175 | config.seed = 10 176 | config.experiment = exp 177 | train(config) 178 | 179 | config.datafolder = '../data/CPSC/' 180 | config.experiment = 'cpsc' 181 | config.seed = 7 182 | train(config) 183 | 184 | config.datafolder = '../data/hf/' 185 | config.experiment = 'hf' 186 | config.seed = 9 187 | train(config) 188 | -------------------------------------------------------------------------------- /minirocket_train.py: -------------------------------------------------------------------------------- 1 | # Angus Dempster, Daniel F Schmidt, Geoffrey I Webb 2 | 3 | # MiniRocket: A Very Fast (Almost) Deterministic Transform for Time Series 4 | # Classification 5 | 6 | # https://arxiv.org/abs/2012.08791 7 | import utils 8 | from sklearn.metrics import roc_auc_score 9 | import copy 10 | import numpy as np 11 | import torch, torch.nn as nn, torch.optim as optim 12 | from models.minrocket import fit, transform 13 | from dataset import DownLoadECGData 14 | import random 15 | from dataset import hf_dataset 16 | 17 | 18 | def setup_seed(seed): 19 | print('seed: ', seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | random.seed(seed) 24 | torch.backends.cudnn.deterministic = True 25 | 26 | 27 | def train(num_classes, training_size, X_training, Y_training, X_validation, Y_validation, **kwargs): 28 | # -- init ------------------------------------------------------------------ 29 | 30 | # default hyperparameters are reusable for any dataset 31 | args = \ 32 | { 33 | "num_features": 10_000, 34 | "minibatch_size": 256, 35 | "lr": 1e-4, 36 | "max_epochs": 50, 37 | "patience_lr": 5, # 50 minibatches 38 | "patience": 10, # 100 minibatches 39 | "cache_size": training_size # set to 0 to prevent caching 40 | } 41 | args = {**args, **kwargs} 42 | 43 | _num_features = 84 * (args["num_features"] // 84) 44 | 45 | def init(layer): 46 | if isinstance(layer, nn.Linear): 47 | nn.init.constant_(layer.weight.data, 0) 48 | nn.init.constant_(layer.bias.data, 0) 49 | 50 | # -- model ----------------------------------------------------------------- 51 | 52 | model = nn.Sequential(nn.Linear(_num_features, num_classes)) 53 | loss_function = nn.BCEWithLogitsLoss() 54 | optimizer = optim.Adam(model.parameters(), lr=args["lr"]) 55 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, min_lr=1e-8, patience=args["patience_lr"]) 56 | model.apply(init) 57 | 58 | # -- data ------------------------------------------------------- 59 | X_training, Y_training = X_training.astype(np.float32), torch.FloatTensor(Y_training) 60 | X_validation, 
Y_validation = X_validation.astype(np.float32), torch.FloatTensor(Y_validation) 61 | # -- run ------------------------------------------------------------------- 62 | 63 | minibatch_count = 0 64 | best_validation_loss = np.inf 65 | stall_count = 0 66 | stop = False 67 | 68 | print("Training... (faster once caching is finished)") 69 | 70 | for epoch in range(args["max_epochs"]): 71 | 72 | print(f"Epoch {epoch + 1}...".ljust(80, " "), end="\r", flush=True) 73 | 74 | if epoch == 0: 75 | parameters = fit(X_training, args["num_features"]) 76 | 77 | # transform validation data 78 | X_validation_transform = transform(X_validation, parameters) 79 | 80 | # transform training data 81 | X_training_transform = transform(X_training, parameters) 82 | 83 | if epoch == 0: 84 | # per-feature mean and standard deviation 85 | f_mean = X_training_transform.mean(0) 86 | f_std = X_training_transform.std(0) + 1e-8 87 | 88 | # normalise validation features 89 | X_validation_transform = (X_validation_transform - f_mean) / f_std 90 | X_validation_transform = torch.FloatTensor(X_validation_transform) 91 | 92 | # normalise training features 93 | X_training_transform = (X_training_transform - f_mean) / f_std 94 | X_training_transform = torch.FloatTensor(X_training_transform) 95 | 96 | minibatches = torch.randperm(len(X_training_transform)).split(args["minibatch_size"]) 97 | 98 | # train on transformed features 99 | for minibatch_index, minibatch in enumerate(minibatches): 100 | 101 | if epoch > 0 and stop: 102 | break 103 | 104 | if minibatch_index > 0 and len(minibatch) < args["minibatch_size"]: 105 | break 106 | 107 | # -- training -------------------------------------------------- 108 | 109 | optimizer.zero_grad() 110 | _Y_training = model(X_training_transform[minibatch]) 111 | training_loss = loss_function(_Y_training, Y_training[minibatch]) 112 | training_loss.backward() 113 | optimizer.step() 114 | 115 | minibatch_count += 1 116 | 117 | if minibatch_count % 10 == 0: 118 | 119 | _Y_validation = model(X_validation_transform) 120 | validation_loss = loss_function(_Y_validation, Y_validation) 121 | 122 | scheduler.step(validation_loss) 123 | 124 | if validation_loss.item() >= best_validation_loss: 125 | stall_count += 1 126 | if stall_count >= args["patience"]: 127 | stop = True 128 | print(f"\n") 129 | else: 130 | best_validation_loss = validation_loss.item() 131 | best_model = copy.deepcopy(model) 132 | if not stop: 133 | stall_count = 0 134 | 135 | return parameters, best_model, f_mean, f_std 136 | 137 | 138 | def predict(parameters, model, f_mean, f_std, X_test, Y_test, **kwargs): 139 | predictions = [] 140 | 141 | X_test = X_test.astype(np.float32) 142 | 143 | X_test_transform = transform(X_test, parameters) 144 | X_test_transform = (X_test_transform - f_mean) / f_std 145 | X_test_transform = torch.FloatTensor(X_test_transform) 146 | 147 | _predictions = torch.sigmoid(model(X_test_transform)).cpu().detach().numpy() 148 | predictions.append(_predictions) 149 | predictions = np.array(predictions).squeeze(axis=0) 150 | auc = roc_auc_score(Y_test, predictions) 151 | TPR = utils.compute_TPR(Y_test, predictions) 152 | print("AUC = ", auc, "TPR = ", TPR) 153 | 154 | 155 | def main(data_name='ptbxl'): 156 | setup_seed(7) 157 | if data_name == 'ptbxl': 158 | # eg. 
['exp0', 'exp1', 'exp1.1', 'exp1.1.1', 'exp2', 'exp3'] 159 | ded = DownLoadECGData('exp0', 'rhythm', '../data/ptbxl/') 160 | X_training, Y_training, X_validation, Y_validation, X_test, Y_test = ded.preprocess_data() 161 | elif data_name == 'cpsc': 162 | ded = DownLoadECGData('exp_cpsc', 'all', '../data/CPSC/') 163 | X_training, Y_training, X_validation, Y_validation, X_test, Y_test = ded.preprocess_data() 164 | else: 165 | X_training, Y_training, X_validation, Y_validation, X_test, Y_test = hf_dataset() 166 | 167 | parameters, best_model, f_mean, f_std = train(len(Y_training[0]), len(X_training), 168 | X_training, Y_training, 169 | X_validation, Y_validation) 170 | 171 | predict(parameters, best_model, f_mean, f_std, X_test, Y_test) 172 | 173 | if __name__ == '__main__': 174 | main() 175 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | @time: 2021/4/17 15:32 4 | 5 | @ author: 6 | ''' 7 | from .model import MyNet, MyNet6View 8 | from .resnet1d_wang import resnet1d_wang 9 | from .xresnet1d101 import xresnet1d101, xresnet1d50 10 | from .inceptiontime import inceptiontime 11 | from .fcn_wang import fcn_wang 12 | from .bi_lstm import lstm, lstm_bidir 13 | from .vit import vit 14 | from .mobilenet_v3 import mobilenetv3_small, mobilenetv3_large 15 | from .acnet import dccacb  # dccacb is defined in acnet.py; there is no dccacb module in the tree 16 | from .ati_cnn import ATI_CNN 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /models/acnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Dccacblock(nn.Module): 7 | def __init__(self, channels, kernel_size): 8 | super(Dccacblock, self).__init__() 9 | self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=1) 10 | self.relu1 = nn.LeakyReLU(0.3) 11 | self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=1) 12 | self.relu2 = nn.LeakyReLU(0.3) 13 | self.conv3 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=2, padding=kernel_size//2-1) 14 | self.relu3 = nn.LeakyReLU(0.3) 15 | self.dp = nn.Dropout(0.2) 16 | 17 | def forward(self, x): 18 | return self.dp(self.relu3(self.conv3(self.relu2(self.conv2(self.relu1(self.conv1(x))))))) 19 | 20 | 21 | class Dccacb(nn.Module): 22 | def __init__(self, input_channels, num_classes): 23 | super(Dccacb, self).__init__() 24 | self.block1 = Dccacblock(12, 24) 25 | self.block2 = Dccacblock(12, 24) 26 | self.block3 = Dccacblock(12, 24) 27 | self.block4 = Dccacblock(12, 24) 28 | self.block5 = Dccacblock(12, 48) 29 | self.rnn = nn.GRU(12, 12, bidirectional=True) 30 | self.relu = nn.LeakyReLU(0.3) 31 | self.dp = nn.Dropout(0.2) 32 | 33 | self.attention_layer = nn.Sequential( 34 | nn.Linear(12, 12), 35 | nn.ReLU(inplace=True) 36 | ) 37 | self.bn = nn.BatchNorm1d(12) 38 | self.relu1 = nn.LeakyReLU(0.3) 39 | 40 | self.fc = nn.Linear(12, num_classes) 41 | 42 | def forward(self, x): 43 | x = self.block5(self.block4(self.block3(self.block2(self.block1(x))))) 44 | x = x.transpose(0, 2).transpose(1, 2) 45 | 46 | x, _ = self.rnn(x) 47 | x = self.dp(self.relu(x)) 48 | 49 | x = x.transpose(0, 1) 50 | x = self.dp(self.relu1(self.bn(self.attention_net_with_w(x)))) 51 | 52 | x = self.fc(x) 53 | 54 | return x 55 | 56 | def attention_net_with_w(self, lstm_out): 57 | lstm_tmp_out =
torch.chunk(lstm_out, 2, -1) 58 | h = lstm_tmp_out[0] + lstm_tmp_out[1] 59 | atten_w = self.attention_layer(h) 60 | m = nn.Tanh()(h) 61 | atten_context = torch.bmm(m, atten_w.transpose(1, 2)) 62 | 63 | softmax_w = F.softmax(atten_context, dim=-1) 64 | 65 | context = torch.bmm(h.transpose(1, 2), softmax_w) 66 | result = torch.sum(context, dim=-1) 67 | result = nn.Dropout(0.)(result) 68 | return result 69 | 70 | 71 | 72 | def dccacb(**kwargs): 73 | return Dccacb(input_channels=12, **kwargs) 74 | 75 | 76 | # if __name__ == '__main__': 77 | # model = Dccacb(10, 12) 78 | # x = torch.randn(64, 12, 1000) 79 | 80 | -------------------------------------------------------------------------------- /models/ati_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ATICNN(nn.Module): 6 | def __init__(self, features, num_classes=1000, init_weights=False): 7 | super(ATICNN, self).__init__() 8 | self.L = 32 9 | self.D = 128 10 | self.K = num_classes 11 | 12 | self.features = features 13 | self.lstm = nn.LSTM(num_layers=2, input_size=512, hidden_size=32, dropout=0.2) 14 | 15 | self.attention = nn.Sequential( 16 | nn.Linear(self.L, self.D), # L -> D (32 -> 128) 17 | nn.Tanh(), 18 | nn.Linear(self.D, self.K) # D -> K (128 -> num_classes) 19 | ) 20 | if init_weights: 21 | self._initialize_weights() 22 | 23 | def forward(self, x): 24 | x = self.features(x) # (N, 512, L') feature maps from the 1-D VGG-style extractor 25 | x = x.transpose(1, 2) 26 | x = x.transpose(0, 1) 27 | x = self.lstm(x) 28 | x = x[0].transpose(0, 1) # (batch, seq_len, hidden_size) 29 | # x = x.transpose(1, 2) 30 | A = self.attention(x) # (batch, seq_len, K) 31 | # A = torch.transpose(A, 1, 0) # KxN 32 | A = F.softmax(A, dim=-1) # softmax over the class dimension 33 | M = torch.bmm(x.transpose(1, 2), A) # (batch, hidden_size, K) 34 | # print(x.shape, A.shape) 35 | M = torch.sum(M, dim=1) 36 | return M 37 | 38 | def _initialize_weights(self): 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | nn.init.xavier_uniform_(m.weight) 42 | if m.bias is not None: 43 | nn.init.constant_(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | nn.init.xavier_uniform_(m.weight) 46 | nn.init.constant_(m.bias, 0) 47 | 48 | 49 | 50 | def make_features(input_channels: int, cfg: list): 51 | layers = [] 52 | in_channels = input_channels 53 | for v in cfg: 54 | if v == "M": 55 | layers += [nn.MaxPool1d(kernel_size=1, stride=2)] 56 | else: 57 | conv1d = nn.Conv1d(in_channels, v, kernel_size=3, padding=1) 58 | layers += [conv1d, nn.BatchNorm1d(v), nn.ReLU(True)] 59 | in_channels = v 60 | return nn.Sequential(*layers) 61 | 62 | 63 | def ATI_CNN(input_channels=12, num_classes=9, **kwargs): 64 | config = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 65 | model = ATICNN(make_features(input_channels, config), num_classes = num_classes, **kwargs) 66 | return model 67 | 
 -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # SEnet 6 | class SELayer(nn.Module): 7 | def __init__(self, planes): 8 | super(SELayer, self).__init__() 9 | self.relu = nn.ReLU(inplace=True) 10 | self.GAP = nn.AdaptiveAvgPool1d(1) 11 | self.fc1 = nn.Linear(in_features=planes, out_features=round(planes / 16)) 12 | self.fc2 = nn.Linear(in_features=round(planes / 16), out_features=planes) 13 | self.sigmoid = nn.Sigmoid() 14 | 15 | def
forward(self, x): 16 | original_out = x 17 | out = self.GAP(x) 18 | out = out.view(out.size(0), -1) 19 | out = self.fc1(out) 20 | out = self.relu(out) 21 | out = self.fc2(out) 22 | out = self.sigmoid(out) 23 | out = out.view(out.size(0), out.size(1), 1) 24 | out = out * original_out 25 | return out 26 | 27 | 28 | # CoordAtt 29 | class h_sigmoid(nn.Module): 30 | def __init__(self, inplace=True): 31 | super(h_sigmoid, self).__init__() 32 | self.relu = nn.ReLU6(inplace=inplace) 33 | 34 | def forward(self, x): 35 | return self.relu(x + 3) / 6 36 | 37 | 38 | class h_swish(nn.Module): 39 | def __init__(self, inplace=True): 40 | super(h_swish, self).__init__() 41 | self.sigmoid = h_sigmoid(inplace=inplace) 42 | 43 | def forward(self, x): 44 | return x * self.sigmoid(x) 45 | 46 | 47 | class CoordAtt(nn.Module): 48 | def __init__(self, inp, oup, reduction=16): 49 | super(CoordAtt, self).__init__() 50 | self.pool_h = nn.AdaptiveAvgPool2d((None, 1)) 51 | # self.pool_w = nn.AdaptiveAvgPool2d((1, None)) 52 | 53 | mip = max(8, inp // reduction) 54 | 55 | self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0) 56 | self.bn1 = nn.BatchNorm2d(mip) 57 | # self.act = MetaAconC(mip) 58 | self.act = h_swish() 59 | 60 | self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 61 | self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0) 62 | 63 | def forward(self, x): 64 | x = x.unsqueeze(2) 65 | 66 | identity = x 67 | 68 | n, c, h, w = x.size() 69 | x_h = self.pool_h(x) 70 | 71 | # x_w = self.pool_w(x).permute(0, 1, 3, 2) 72 | x_w = x.permute(0, 1, 3, 2) 73 | 74 | y = torch.cat([x_h, x_w], dim=2) 75 | y = self.conv1(y) 76 | y = self.bn1(y) 77 | y = self.act(y) 78 | 79 | x_h, x_w = torch.split(y, [h, w], dim=2) 80 | x_w = x_w.permute(0, 1, 3, 2) 81 | 82 | a_h = self.conv_h(x_h).sigmoid() 83 | a_w = self.conv_w(x_w).sigmoid() 84 | 85 | out = identity * a_w * a_h 86 | 87 | out = out.squeeze(2) 88 | return out 89 | 90 | 91 | # CBAM 92 | class ChannelAttention(nn.Module): 93 | def __init__(self, in_planes, ratio=16): 94 | super(ChannelAttention, self).__init__() 95 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 96 | self.max_pool = nn.AdaptiveMaxPool1d(1) 97 | 98 | self.fc1 = nn.Conv1d(in_planes, in_planes // ratio, 1, bias=False) 99 | self.relu1 = nn.ReLU() 100 | self.fc2 = nn.Conv1d(in_planes // ratio, in_planes, 1, bias=False) 101 | self.sigmoid = nn.Sigmoid() 102 | 103 | def forward(self, x): 104 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 105 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 106 | out = avg_out + max_out 107 | return self.sigmoid(out) 108 | 109 | 110 | class SpatialAttention(nn.Module): 111 | def __init__(self, kernel_size=1, padding=0): 112 | super(SpatialAttention, self).__init__() 113 | self.conv1 = nn.Conv1d(2, 1, kernel_size, padding=padding, bias=False) # after concatenation the channel dimension is 2 114 | self.sigmoid = nn.Sigmoid() 115 | 116 | def forward(self, x): 117 | avg_out = torch.mean(x, dim=1, keepdim=True) # mean and max along the channel dimension 118 | max_out, _ = torch.max(x, dim=1, keepdim=True) 119 | x = torch.cat([avg_out, max_out], dim=1) # concatenate along the channel dimension 120 | x = self.conv1(x) 121 | return self.sigmoid(x) 122 | 123 | 124 | class CBAM(nn.Module): 125 | def __init__(self, channel): 126 | super(CBAM, self).__init__() 127 | self.channel_attention = ChannelAttention(channel) 128 | self.spatial_attention = SpatialAttention() 129 | 130 | def forward(self, x): 131 | out = self.channel_attention(x) * x 132 | out = self.spatial_attention(out) * out
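        # [Editor's note] CBAM applies channel attention first, then spatial
        # attention, each as a multiplicative gate on the feature map
        # (Woo et al., "CBAM: Convolutional Block Attention Module", ECCV 2018).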
return out 134 | 135 | 136 | -------------------------------------------------------------------------------- /models/bi_lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class RNN1d(nn.Sequential): 5 | def __init__(self, num_classes, input_channels=12, hidden_dim=256, num_layers=2, bidirectional=False): 6 | super(RNN1d, self).__init__() 7 | self.lstm = nn.LSTM(input_size=input_channels, hidden_size=hidden_dim, num_layers=num_layers, 8 | bidirectional=bidirectional) 9 | 10 | self.avgpool = nn.AdaptiveAvgPool1d(1) 11 | if not bidirectional: 12 | self.fc = nn.Linear(256, num_classes) 13 | else: 14 | self.fc = nn.Linear(512, num_classes) 15 | 16 | def forward(self, x): 17 | x = x.transpose(1, 2) 18 | x = x.transpose(0, 1) 19 | output = self.lstm(x) 20 | output = output[0].transpose(0, 1) 21 | output = output.transpose(1, 2) 22 | output = self.avgpool(output).squeeze(-1) 23 | output = self.fc(output) 24 | return output 25 | 26 | 27 | def lstm(**kwargs): 28 | return RNN1d(bidirectional=False, **kwargs) 29 | 30 | 31 | def lstm_bidir(**kwargs): 32 | return RNN1d(bidirectional=True, **kwargs) 33 | 34 | 35 | -------------------------------------------------------------------------------- /models/fcn_wang.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def conv1d(in_planes, out_planes, kernel_size=3, stride=1, dilation=1): 5 | lst = [] 6 | lst.append(nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size - 1) // 2, 7 | dilation=dilation, bias=False)) 8 | lst.append(nn.BatchNorm1d(out_planes)) 9 | lst.append(nn.ReLU(True)) 10 | return nn.Sequential(*lst) 11 | 12 | 13 | class FCN(nn.Sequential): 14 | def __init__(self, input_channels=12, num_classes=5): 15 | super(FCN, self).__init__() 16 | self.conv1 = conv1d(input_channels, 128, kernel_size=8, stride=1) 17 | self.conv2 = conv1d(128, 256, kernel_size=5, stride=1) 18 | self.conv3 = conv1d(256, 128, kernel_size=3, stride=1) 19 | self.Avgpool = nn.AdaptiveAvgPool1d(1) 20 | self.fc = nn.Linear(128, num_classes) 21 | 22 | def forward(self, x): 23 | output = self.conv1(x) 24 | output = self.conv2(output) 25 | output = self.conv3(output) 26 | output = self.Avgpool(output) 27 | output = output.view(output.size(0), -1) 28 | output = self.fc(output) 29 | return output 30 | 31 | 32 | def fcn_wang(**kwargs): 33 | return FCN(input_channels=12, **kwargs) 34 | 35 | -------------------------------------------------------------------------------- /models/inceptiontime.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | 6 | class BaseBlock(nn.Module): 7 | def __init__(self, in_planes): 8 | super(BaseBlock, self).__init__() 9 | 10 | self.bottleneck = nn.Conv1d(in_planes, 32, kernel_size=1, stride=1, bias=False) 11 | self.conv4 = nn.Conv1d(32, 32, kernel_size=39, stride=1, padding=19, bias=False) 12 | self.conv3 = nn.Conv1d(32, 32, kernel_size=19, stride=1, padding=9, bias=False) 13 | self.conv2 = nn.Conv1d(32, 32, kernel_size=9, stride=1, padding=4, bias=False) 14 | 15 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False) 16 | self.conv1 = nn.Conv1d(in_planes, 32, kernel_size=1, stride=1, bias=False) 17 | 18 | self.bn = nn.BatchNorm1d(32 * 4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) 19 | self.relu = nn.ReLU(inplace=True) 20 | 21 | 
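    # [Editor's note] Each BaseBlock is an Inception-style module: a 1x1 bottleneck
    # down to 32 channels feeds three parallel convolutions (kernel sizes 9, 19, 39),
    # while a fourth branch applies max-pooling and a 1x1 conv to the raw input; the
    # four 32-channel outputs are concatenated to 128 channels and pass through a
    # shared BatchNorm + ReLU.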
def forward(self, x): 22 | output = self.bottleneck(x) 23 | output4 = self.conv4(output) 24 | output3 = self.conv3(output) 25 | output2 = self.conv2(output) 26 | 27 | output1 = self.maxpool(x) 28 | output1 = self.conv1(output1) 29 | 30 | x_out = self.relu(self.bn(torch.cat((output1, output2, output3, output4), dim=1))) 31 | return x_out 32 | 33 | 34 | class InceptionTime(nn.Module): 35 | def __init__(self, in_channel=12, num_classes=10): 36 | super(InceptionTime, self).__init__() 37 | 38 | self.BaseBlock1 = BaseBlock(in_channel) 39 | self.BaseBlock2 = BaseBlock(128) 40 | self.BaseBlock3 = BaseBlock(128) 41 | 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv1 = nn.Conv1d(in_channel, 128, kernel_size=1, stride=1, bias=False) 44 | self.bn1 = nn.BatchNorm1d(128) 45 | 46 | self.BaseBlock4 = BaseBlock(128) 47 | self.BaseBlock5 = BaseBlock(128) 48 | self.BaseBlock6 = BaseBlock(128) 49 | 50 | self.conv2 = nn.Conv1d(128, 128, kernel_size=1, stride=1, bias=False) 51 | self.bn2 = nn.BatchNorm1d(128) 52 | 53 | self.Avgpool = nn.AdaptiveAvgPool1d(1) 54 | self.fc = nn.Linear(128, num_classes) 55 | 56 | def forward(self, x): 57 | shortcut1 = self.bn1(self.conv1(x)) 58 | 59 | output1 = self.BaseBlock1(x) 60 | output1 = self.BaseBlock2(output1) 61 | output1 = self.BaseBlock3(output1) 62 | output1 = self.relu(output1 + shortcut1) 63 | 64 | shortcut2 = self.bn2(self.conv2(output1)) 65 | 66 | output2 = self.BaseBlock4(output1) 67 | output2 = self.BaseBlock5(output2) 68 | output2 = self.BaseBlock6(output2) 69 | output2 = self.relu(output2 + shortcut2) 70 | 71 | output = self.Avgpool(output2) 72 | output = output.view(output.size(0), -1) 73 | output = self.fc(output) 74 | return output 75 | 76 | 77 | def inceptiontime(**kwargs): 78 | return InceptionTime(in_channel=12, **kwargs) 79 | 80 | 81 | -------------------------------------------------------------------------------- /models/minrocket.py: -------------------------------------------------------------------------------- 1 | # Angus Dempster, Daniel F Schmidt, Geoffrey I Webb 2 | 3 | # MiniRocket: A Very Fast (Almost) Deterministic Transform for Time Series 4 | # Classification 5 | 6 | # https://arxiv.org/abs/2012.08791 7 | 8 | # ** This is a naive extension of MiniRocket to multivariate time series. 
** 9 | 10 | from numba import njit, prange, vectorize 11 | import numpy as np 12 | 13 | @njit("float32[:](float32[:,:,:],int32[:],int32[:],int32[:],int32[:],float32[:])", fastmath = True, parallel = False, cache = True) 14 | def _fit_biases(X, num_channels_per_combination, channel_indices, dilations, num_features_per_dilation, quantiles): 15 | 16 | num_examples, num_channels, input_length = X.shape 17 | 18 | # equivalent to: 19 | # >>> from itertools import combinations 20 | # >>> indices = np.array([_ for _ in combinations(np.arange(9), 3)], dtype = np.int32) 21 | indices = np.array(( 22 | 0,1,2,0,1,3,0,1,4,0,1,5,0,1,6,0,1,7,0,1,8, 23 | 0,2,3,0,2,4,0,2,5,0,2,6,0,2,7,0,2,8,0,3,4, 24 | 0,3,5,0,3,6,0,3,7,0,3,8,0,4,5,0,4,6,0,4,7, 25 | 0,4,8,0,5,6,0,5,7,0,5,8,0,6,7,0,6,8,0,7,8, 26 | 1,2,3,1,2,4,1,2,5,1,2,6,1,2,7,1,2,8,1,3,4, 27 | 1,3,5,1,3,6,1,3,7,1,3,8,1,4,5,1,4,6,1,4,7, 28 | 1,4,8,1,5,6,1,5,7,1,5,8,1,6,7,1,6,8,1,7,8, 29 | 2,3,4,2,3,5,2,3,6,2,3,7,2,3,8,2,4,5,2,4,6, 30 | 2,4,7,2,4,8,2,5,6,2,5,7,2,5,8,2,6,7,2,6,8, 31 | 2,7,8,3,4,5,3,4,6,3,4,7,3,4,8,3,5,6,3,5,7, 32 | 3,5,8,3,6,7,3,6,8,3,7,8,4,5,6,4,5,7,4,5,8, 33 | 4,6,7,4,6,8,4,7,8,5,6,7,5,6,8,5,7,8,6,7,8 34 | ), dtype = np.int32).reshape(84, 3) 35 | 36 | num_kernels = len(indices) 37 | num_dilations = len(dilations) 38 | 39 | num_features = num_kernels * np.sum(num_features_per_dilation) 40 | 41 | biases = np.zeros(num_features, dtype = np.float32) 42 | 43 | feature_index_start = 0 44 | 45 | combination_index = 0 46 | num_channels_start = 0 47 | 48 | for dilation_index in range(num_dilations): 49 | 50 | dilation = dilations[dilation_index] 51 | padding = ((9 - 1) * dilation) // 2 52 | 53 | num_features_this_dilation = num_features_per_dilation[dilation_index] 54 | 55 | for kernel_index in range(num_kernels): 56 | 57 | feature_index_end = feature_index_start + num_features_this_dilation 58 | 59 | num_channels_this_combination = num_channels_per_combination[combination_index] 60 | 61 | num_channels_end = num_channels_start + num_channels_this_combination 62 | 63 | channels_this_combination = channel_indices[num_channels_start:num_channels_end] 64 | 65 | _X = X[np.random.randint(num_examples)][channels_this_combination] 66 | 67 | A = -_X # A = alpha * X = -X 68 | G = _X + _X + _X # G = gamma * X = 3X 69 | 70 | C_alpha = np.zeros((num_channels_this_combination, input_length), dtype = np.float32) 71 | C_alpha[:] = A 72 | 73 | C_gamma = np.zeros((9, num_channels_this_combination, input_length), dtype = np.float32) 74 | C_gamma[9 // 2] = G 75 | 76 | start = dilation 77 | end = input_length - padding 78 | 79 | for gamma_index in range(9 // 2): 80 | 81 | C_alpha[:, -end:] = C_alpha[:, -end:] + A[:, :end] 82 | C_gamma[gamma_index, :, -end:] = G[:, :end] 83 | 84 | end += dilation 85 | 86 | for gamma_index in range(9 // 2 + 1, 9): 87 | 88 | C_alpha[:, :-start] = C_alpha[:, :-start] + A[:, start:] 89 | C_gamma[gamma_index, :, :-start] = G[:, start:] 90 | 91 | start += dilation 92 | 93 | index_0, index_1, index_2 = indices[kernel_index] 94 | 95 | C = C_alpha + C_gamma[index_0] + C_gamma[index_1] + C_gamma[index_2] 96 | C = np.sum(C, axis = 0) 97 | 98 | biases[feature_index_start:feature_index_end] = np.quantile(C, quantiles[feature_index_start:feature_index_end]) 99 | 100 | feature_index_start = feature_index_end 101 | 102 | combination_index += 1 103 | num_channels_start = num_channels_end 104 | 105 | return biases 106 | 107 | def _fit_dilations(input_length, num_features, max_dilations_per_kernel): 108 | 109 | num_kernels = 84 110 | 111 | 
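    # [Editor's note] Worked example with the defaults from minirocket_train.py
    # (num_features=10_000, max_dilations_per_kernel=32) and input_length=1000:
    # num_features_per_kernel = 10_000 // 84 = 119, true_max_dilations_per_kernel = 32,
    # max_exponent = log2(999 / 8) ~= 6.96, so the 32 log-spaced dilations fall in
    # [1, 124]; after the int32 cast and np.unique, the 119 features per kernel are
    # spread across the remaining distinct dilations.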
num_features_per_kernel = num_features // num_kernels 112 | true_max_dilations_per_kernel = min(num_features_per_kernel, max_dilations_per_kernel) 113 | multiplier = num_features_per_kernel / true_max_dilations_per_kernel 114 | 115 | max_exponent = np.log2((input_length - 1) / (9 - 1)) 116 | dilations, num_features_per_dilation = \ 117 | np.unique(np.logspace(0, max_exponent, true_max_dilations_per_kernel, base = 2).astype(np.int32), return_counts = True) 118 | num_features_per_dilation = (num_features_per_dilation * multiplier).astype(np.int32) # this is a vector 119 | 120 | remainder = num_features_per_kernel - np.sum(num_features_per_dilation) 121 | i = 0 122 | while remainder > 0: 123 | num_features_per_dilation[i] += 1 124 | remainder -= 1 125 | i = (i + 1) % len(num_features_per_dilation) 126 | 127 | return dilations, num_features_per_dilation 128 | 129 | # low-discrepancy sequence to assign quantiles to kernel/dilation combinations 130 | def _quantiles(n): 131 | return np.array([(_ * ((np.sqrt(5) + 1) / 2)) % 1 for _ in range(1, n + 1)], dtype = np.float32) 132 | 133 | def fit(X, num_features = 10_000, max_dilations_per_kernel = 32): 134 | 135 | _, num_channels, input_length = X.shape 136 | 137 | num_kernels = 84 138 | 139 | dilations, num_features_per_dilation = _fit_dilations(input_length, num_features, max_dilations_per_kernel) 140 | 141 | num_features_per_kernel = np.sum(num_features_per_dilation) 142 | 143 | quantiles = _quantiles(num_kernels * num_features_per_kernel) 144 | 145 | num_dilations = len(dilations) 146 | num_combinations = num_kernels * num_dilations 147 | 148 | max_num_channels = min(num_channels, 9) 149 | max_exponent = np.log2(max_num_channels + 1) 150 | 151 | num_channels_per_combination = (2 ** np.random.uniform(0, max_exponent, num_combinations)).astype(np.int32) 152 | 153 | channel_indices = np.zeros(num_channels_per_combination.sum(), dtype = np.int32) 154 | 155 | num_channels_start = 0 156 | for combination_index in range(num_combinations): 157 | num_channels_this_combination = num_channels_per_combination[combination_index] 158 | num_channels_end = num_channels_start + num_channels_this_combination 159 | channel_indices[num_channels_start:num_channels_end] = np.random.choice(num_channels, num_channels_this_combination, replace = False) 160 | 161 | num_channels_start = num_channels_end 162 | 163 | biases = _fit_biases(X, num_channels_per_combination, channel_indices, dilations, num_features_per_dilation, quantiles) 164 | 165 | return num_channels_per_combination, channel_indices, dilations, num_features_per_dilation, biases 166 | 167 | # _PPV(C, b).mean() returns PPV for vector C (convolution output) and scalar b (bias) 168 | @vectorize("float32(float32,float32)", nopython = True, cache = True) 169 | def _PPV(a, b): 170 | if a > b: 171 | return 1 172 | else: 173 | return 0 174 | 175 | @njit("float32[:,:](float32[:,:,:],Tuple((int32[:],int32[:],int32[:],int32[:],float32[:])))", fastmath = True, parallel = True, cache = True) 176 | def transform(X, parameters): 177 | 178 | num_examples, num_channels, input_length = X.shape 179 | 180 | num_channels_per_combination, channel_indices, dilations, num_features_per_dilation, biases = parameters 181 | 182 | # equivalent to: 183 | # >>> from itertools import combinations 184 | # >>> indices = np.array([_ for _ in combinations(np.arange(9), 3)], dtype = np.int32) 185 | indices = np.array(( 186 | 0,1,2,0,1,3,0,1,4,0,1,5,0,1,6,0,1,7,0,1,8, 187 | 0,2,3,0,2,4,0,2,5,0,2,6,0,2,7,0,2,8,0,3,4, 188 | 
0,3,5,0,3,6,0,3,7,0,3,8,0,4,5,0,4,6,0,4,7, 189 | 0,4,8,0,5,6,0,5,7,0,5,8,0,6,7,0,6,8,0,7,8, 190 | 1,2,3,1,2,4,1,2,5,1,2,6,1,2,7,1,2,8,1,3,4, 191 | 1,3,5,1,3,6,1,3,7,1,3,8,1,4,5,1,4,6,1,4,7, 192 | 1,4,8,1,5,6,1,5,7,1,5,8,1,6,7,1,6,8,1,7,8, 193 | 2,3,4,2,3,5,2,3,6,2,3,7,2,3,8,2,4,5,2,4,6, 194 | 2,4,7,2,4,8,2,5,6,2,5,7,2,5,8,2,6,7,2,6,8, 195 | 2,7,8,3,4,5,3,4,6,3,4,7,3,4,8,3,5,6,3,5,7, 196 | 3,5,8,3,6,7,3,6,8,3,7,8,4,5,6,4,5,7,4,5,8, 197 | 4,6,7,4,6,8,4,7,8,5,6,7,5,6,8,5,7,8,6,7,8 198 | ), dtype = np.int32).reshape(84, 3) 199 | 200 | num_kernels = len(indices) 201 | num_dilations = len(dilations) 202 | 203 | num_features = num_kernels * np.sum(num_features_per_dilation) 204 | 205 | features = np.zeros((num_examples, num_features), dtype = np.float32) 206 | 207 | for example_index in prange(num_examples): 208 | 209 | _X = X[example_index] 210 | 211 | A = -_X # A = alpha * X = -X 212 | G = _X + _X + _X # G = gamma * X = 3X 213 | 214 | feature_index_start = 0 215 | 216 | combination_index = 0 217 | num_channels_start = 0 218 | 219 | for dilation_index in range(num_dilations): 220 | 221 | _padding0 = dilation_index % 2 222 | 223 | dilation = dilations[dilation_index] 224 | padding = ((9 - 1) * dilation) // 2 225 | 226 | num_features_this_dilation = num_features_per_dilation[dilation_index] 227 | 228 | C_alpha = np.zeros((num_channels, input_length), dtype = np.float32) 229 | C_alpha[:] = A 230 | 231 | C_gamma = np.zeros((9, num_channels, input_length), dtype = np.float32) 232 | C_gamma[9 // 2] = G 233 | 234 | start = dilation 235 | end = input_length - padding 236 | 237 | for gamma_index in range(9 // 2): 238 | 239 | C_alpha[:, -end:] = C_alpha[:, -end:] + A[:, :end] 240 | C_gamma[gamma_index, :, -end:] = G[:, :end] 241 | 242 | end += dilation 243 | 244 | for gamma_index in range(9 // 2 + 1, 9): 245 | 246 | C_alpha[:, :-start] = C_alpha[:, :-start] + A[:, start:] 247 | C_gamma[gamma_index, :, :-start] = G[:, start:] 248 | 249 | start += dilation 250 | 251 | for kernel_index in range(num_kernels): 252 | 253 | feature_index_end = feature_index_start + num_features_this_dilation 254 | 255 | num_channels_this_combination = num_channels_per_combination[combination_index] 256 | 257 | num_channels_end = num_channels_start + num_channels_this_combination 258 | 259 | channels_this_combination = channel_indices[num_channels_start:num_channels_end] 260 | 261 | _padding1 = (_padding0 + kernel_index) % 2 262 | 263 | index_0, index_1, index_2 = indices[kernel_index] 264 | 265 | C = C_alpha[channels_this_combination] + \ 266 | C_gamma[index_0][channels_this_combination] + \ 267 | C_gamma[index_1][channels_this_combination] + \ 268 | C_gamma[index_2][channels_this_combination] 269 | C = np.sum(C, axis = 0) 270 | 271 | if _padding1 == 0: 272 | for feature_count in range(num_features_this_dilation): 273 | features[example_index, feature_index_start + feature_count] = _PPV(C, biases[feature_index_start + feature_count]).mean() 274 | else: 275 | for feature_count in range(num_features_this_dilation): 276 | features[example_index, feature_index_start + feature_count] = _PPV(C[padding:-padding], biases[feature_index_start + feature_count]).mean() 277 | 278 | feature_index_start = feature_index_end 279 | 280 | combination_index += 1 281 | num_channels_start = num_channels_end 282 | 283 | return features 284 | 285 | 286 | -------------------------------------------------------------------------------- /models/mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
Creates a MobileNetV3 Model as defined in: 3 | Andrew Howard, Mark Sandler, Grace Chu, Liang-Chieh Chen, Bo Chen, Mingxing Tan, Weijun Wang, Yukun Zhu, Ruoming Pang, Vijay Vasudevan, Quoc V. Le, Hartwig Adam. (2019). 4 | Searching for MobileNetV3 5 | arXiv preprint arXiv:1905.02244. 6 | """ 7 | import torch.nn as nn 8 | import math 9 | 10 | 11 | __all__ = ['mobilenetv3_large', 'mobilenetv3_small'] 12 | 13 | 14 | def _make_divisible(v, divisor, min_value=None): 15 | """ 16 | This function is taken from the original tf repo. 17 | It ensures that all layers have a channel number that is divisible by 8 18 | It can be seen here: 19 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 20 | :param v: 21 | :param divisor: 22 | :param min_value: 23 | :return: 24 | """ 25 | if min_value is None: 26 | min_value = divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 29 | if new_v < 0.9 * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | class h_sigmoid(nn.Module): 35 | def __init__(self, inplace=True): 36 | super(h_sigmoid, self).__init__() 37 | self.relu = nn.ReLU6(inplace=inplace) 38 | 39 | def forward(self, x): 40 | return self.relu(x + 3) / 6 41 | 42 | 43 | class h_swish(nn.Module): 44 | def __init__(self, inplace=True): 45 | super(h_swish, self).__init__() 46 | self.sigmoid = h_sigmoid(inplace=inplace) 47 | 48 | def forward(self, x): 49 | return x * self.sigmoid(x) 50 | 51 | 52 | class SELayer(nn.Module): 53 | def __init__(self, channel, reduction=4): 54 | super(SELayer, self).__init__() 55 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 56 | self.fc = nn.Sequential( 57 | nn.Linear(channel, _make_divisible(channel // reduction, 8)), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(_make_divisible(channel // reduction, 8), channel), 60 | h_sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | b, c, _ = x.size() 65 | y = self.avg_pool(x).view(b, c) 66 | y = self.fc(y).view(b, c, 1) 67 | return x * y 68 | 69 | 70 | def conv_3x3_bn(inp, oup, stride): 71 | return nn.Sequential( 72 | nn.Conv1d(inp, oup, 15, stride, 7, bias=False), 73 | nn.BatchNorm1d(oup), 74 | h_swish() 75 | ) 76 | 77 | 78 | def conv_1x1_bn(inp, oup): 79 | return nn.Sequential( 80 | nn.Conv1d(inp, oup, 1, 1, 0, bias=False), 81 | nn.BatchNorm1d(oup), 82 | h_swish() 83 | ) 84 | 85 | 86 | class InvertedResidual(nn.Module): 87 | def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs): 88 | super(InvertedResidual, self).__init__() 89 | assert stride in [1, 2] 90 | 91 | self.identity = stride == 1 and inp == oup 92 | 93 | if inp == hidden_dim: 94 | self.conv = nn.Sequential( 95 | # dw 96 | nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 97 | nn.BatchNorm1d(hidden_dim), 98 | h_swish() if use_hs else nn.ReLU(inplace=True), 99 | # Squeeze-and-Excite 100 | SELayer(hidden_dim) if use_se else nn.Identity(), 101 | # pw-linear 102 | nn.Conv1d(hidden_dim, oup, 1, 1, 0, bias=False), 103 | nn.BatchNorm1d(oup), 104 | ) 105 | else: 106 | self.conv = nn.Sequential( 107 | # pw 108 | nn.Conv1d(inp, hidden_dim, 1, 1, 0, bias=False), 109 | nn.BatchNorm1d(hidden_dim), 110 | h_swish() if use_hs else nn.ReLU(inplace=True), 111 | # dw 112 | nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride, (kernel_size - 1) // 2, groups=hidden_dim, bias=False), 113 | nn.BatchNorm1d(hidden_dim), 114 | # Squeeze-and-Excite 115 | SELayer(hidden_dim) if use_se else 
nn.Identity(), 116 | h_swish() if use_hs else nn.ReLU(inplace=True), 117 | # pw-linear 118 | nn.Conv1d(hidden_dim, oup, 1, 1, 0, bias=False), 119 | nn.BatchNorm1d(oup), 120 | ) 121 | 122 | def forward(self, x): 123 | if self.identity: 124 | return x + self.conv(x) 125 | else: 126 | return self.conv(x) 127 | 128 | 129 | class MobileNetV3(nn.Module): 130 | def __init__(self, cfgs, mode, in_channel=12, num_classes=10, width_mult=1.): 131 | super(MobileNetV3, self).__init__() 132 | # setting of inverted residual blocks 133 | self.cfgs = cfgs 134 | assert mode in ['large', 'small'] 135 | 136 | # building first layer 137 | input_channel = _make_divisible(16 * width_mult, 8) 138 | layers = [conv_3x3_bn(in_channel, input_channel, 2)] 139 | # building inverted residual blocks 140 | block = InvertedResidual 141 | for k, t, c, use_se, use_hs, s in self.cfgs: 142 | output_channel = _make_divisible(c * width_mult, 8) 143 | exp_size = _make_divisible(input_channel * t, 8) 144 | layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs)) 145 | input_channel = output_channel 146 | self.features = nn.Sequential(*layers) 147 | # building last several layers 148 | self.conv = conv_1x1_bn(input_channel, exp_size) 149 | self.avgpool = nn.AdaptiveAvgPool1d(1) 150 | output_channel = {'large': 1280, 'small': 1024} 151 | output_channel = _make_divisible(output_channel[mode] * width_mult, 8) if width_mult > 1.0 else output_channel[mode] 152 | self.classifier = nn.Sequential( 153 | nn.Linear(exp_size, output_channel), 154 | h_swish(), 155 | nn.Dropout(0.2), 156 | nn.Linear(output_channel, num_classes), 157 | ) 158 | 159 | self._initialize_weights() 160 | 161 | def forward(self, x): 162 | x = self.features(x) 163 | x = self.conv(x) 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | x = self.classifier(x) 167 | return x 168 | 169 | def _initialize_weights(self): 170 | for m in self.modules(): 171 | if isinstance(m, nn.Conv2d): 172 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 173 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 174 | if m.bias is not None: 175 | m.bias.data.zero_() 176 | elif isinstance(m, nn.BatchNorm2d): 177 | m.weight.data.fill_(1) 178 | m.bias.data.zero_() 179 | elif isinstance(m, nn.Linear): 180 | m.weight.data.normal_(0, 0.01) 181 | m.bias.data.zero_() 182 | 183 | 184 | def mobilenetv3_large(**kwargs): 185 | """ 186 | Constructs a MobileNetV3-Large model 187 | """ 188 | cfgs = [ 189 | # k, t, c, SE, HS, s 190 | [3, 1, 16, 0, 0, 1], 191 | [3, 4, 24, 0, 0, 2], 192 | [3, 3, 24, 0, 0, 1], 193 | [5, 3, 40, 1, 0, 2], 194 | [5, 3, 40, 1, 0, 1], 195 | [5, 3, 40, 1, 0, 1], 196 | [3, 6, 80, 0, 1, 2], 197 | [3, 2.5, 80, 0, 1, 1], 198 | [3, 2.3, 80, 0, 1, 1], 199 | [3, 2.3, 80, 0, 1, 1], 200 | [3, 6, 112, 1, 1, 1], 201 | [3, 6, 112, 1, 1, 1], 202 | [5, 6, 160, 1, 1, 2], 203 | [5, 6, 160, 1, 1, 1], 204 | [5, 6, 160, 1, 1, 1] 205 | ] 206 | return MobileNetV3(cfgs, mode='large', **kwargs) 207 | 208 | 209 | def mobilenetv3_small(**kwargs): 210 | """ 211 | Constructs a MobileNetV3-Small model 212 | """ 213 | cfgs = [ 214 | # k, t, c, SE, HS, s 215 | [3, 1, 16, 1, 0, 2], 216 | [3, 4.5, 24, 0, 0, 2], 217 | [3, 3.67, 24, 0, 0, 1], 218 | [5, 4, 40, 1, 1, 2], 219 | [5, 6, 40, 1, 1, 1], 220 | [5, 6, 40, 1, 1, 1], 221 | [5, 3, 48, 1, 1, 1], 222 | [5, 3, 48, 1, 1, 1], 223 | [5, 6, 96, 1, 1, 2], 224 | [5, 6, 96, 1, 1, 1], 225 | [5, 6, 96, 1, 1, 1], 226 | ] 227 | 228 | return MobileNetV3(cfgs, mode='small', **kwargs) 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | @time: 2021/4/17 20:14 4 | 5 | @ author: ysx 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.attention import CoordAtt 11 | 12 | 13 | class Mish(nn.Module): 14 | def __init__(self): 15 | super().__init__() 16 | 17 | def forward(self, x): 18 | return x * (torch.tanh(F.softplus(x))) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 23 | 24 | 25 | class Res2Block(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, kernel_size=5, stride=1, downsample=None, groups=1, base_width=26, 29 | dilation=1, scale=4, first_block=True, norm_layer=nn.BatchNorm1d, 30 | atten=True): 31 | 32 | super(Res2Block, self).__init__() 33 | if norm_layer is None: 34 | norm_layer = nn.BatchNorm1d 35 | 36 | width = int(planes * (base_width / 64.)) * groups 37 | # print(width) 38 | 39 | self.atten = atten 40 | 41 | self.conv1 = conv1x1(inplanes, width * scale) 42 | self.bn1 = norm_layer(width * scale) 43 | 44 | # If scale == 1, single conv else identity & (scale - 1) convs 45 | nb_branches = max(scale, 2) - 1 46 | if first_block: 47 | self.pool = nn.AvgPool1d(kernel_size=3, stride=stride, padding=1) 48 | self.convs = nn.ModuleList([nn.Conv1d(width, width, kernel_size=kernel_size, stride=stride, 49 | padding=kernel_size // 2, groups=1, bias=False, dilation=1) 50 | for _ in range(nb_branches)]) 51 | self.bns = nn.ModuleList([norm_layer(width) for _ in range(nb_branches)]) 52 | self.first_block = first_block 53 | self.scale = scale 54 | 55 | self.conv3 = conv1x1(width * scale, planes * self.expansion) 56 | 57 | self.relu = Mish() 58 | self.bn3 = norm_layer(planes * self.expansion) # bn reverse 59 | 60 | # self.dropout = nn.Dropout(.1) 61 | 62 | if self.atten is True: 63 | # self.attention = SELayer(planes 
* self.expansion) 64 | # self.attention = CBAM(planes * self.expansion) 65 | self.attention = CoordAtt(planes * self.expansion, planes * self.expansion) 66 | else: 67 | self.attention = None 68 | 69 | self.shortcut = nn.Sequential() 70 | if stride != 1 or inplanes != self.expansion * planes: 71 | self.shortcut = nn.Sequential( 72 | nn.Conv1d(inplanes, self.expansion * planes, kernel_size=1, stride=stride), 73 | nn.BatchNorm1d(self.expansion * planes) 74 | ) 75 | 76 | def forward(self, x): 77 | 78 | out = self.conv1(x) 79 | 80 | out = self.relu(out) 81 | out = self.bn1(out) # bn reverse 82 | # Chunk the feature map 83 | xs = torch.chunk(out, self.scale, dim=1) 84 | # Initialize output as empty tensor for proper concatenation 85 | y = 0 86 | for idx, conv in enumerate(self.convs): 87 | # Add previous y-value 88 | if self.first_block: 89 | y = xs[idx] 90 | else: 91 | y += xs[idx] 92 | y = conv(y) 93 | y = self.relu(self.bns[idx](y)) 94 | # Concatenate with previously computed values 95 | out = torch.cat((out, y), 1) if idx > 0 else y 96 | # Use last chunk as x1 97 | if self.scale > 1: 98 | if self.first_block: 99 | out = torch.cat((out, self.pool(xs[len(self.convs)])), 1) 100 | else: 101 | out = torch.cat((out, xs[len(self.convs)]), 1) 102 | 103 | # out = self.dropout(out) 104 | 105 | out = self.conv3(out) 106 | out = self.bn3(out) 107 | 108 | if self.atten: 109 | out = self.attention(out) 110 | 111 | out += self.shortcut(x) 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class MyNet(nn.Module): 118 | 119 | def __init__(self, num_classes=5, input_channels=12, single_view=False): 120 | super(MyNet, self).__init__() 121 | 122 | self.single_view = single_view 123 | 124 | self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=25, stride=1, padding=0, bias=False) 125 | self.bn1 = nn.BatchNorm1d(64) 126 | self.relu = Mish() 127 | 128 | self.layer1 = Res2Block(inplanes=64, planes=128, kernel_size=15, stride=2, atten=True) 129 | 130 | self.layer2 = Res2Block(inplanes=128, planes=128, kernel_size=15, stride=2, atten=True) 131 | 132 | self.avgpool = nn.AdaptiveAvgPool1d(1) 133 | 134 | if not self.single_view: 135 | self.fc = nn.Linear(128, num_classes) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, nn.BatchNorm2d): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | def forward(self, x): 145 | 146 | output = self.conv1(x) 147 | output = self.bn1(output) 148 | output = self.relu(output) 149 | 150 | output = self.layer1(output) 151 | 152 | output = self.layer2(output) 153 | 154 | output = self.avgpool(output) 155 | 156 | output = output.view(output.size(0), -1) 157 | 158 | if not self.single_view: 159 | output = self.fc(output) 160 | 161 | return output 162 | 163 | 164 | class AdaptiveWeight(nn.Module): 165 | def __init__(self, plances=32): 166 | super(AdaptiveWeight, self).__init__() 167 | 168 | self.fc = nn.Linear(plances, 1) 169 | # self.bn = nn.BatchNorm1d(1) 170 | self.sig = nn.Sigmoid() 171 | 172 | def forward(self, x): 173 | out = self.fc(x) 174 | # out = self.bn(out) 175 | out = self.sig(out) 176 | 177 | return out 178 | 179 | 180 | class MyNet6View(nn.Module): 181 | 182 | def __init__(self, num_classes=5): 183 | super(MyNet6View, self).__init__() 184 | 185 | self.MyNet1 = MyNet(input_channels=1, single_view=True) 186 | self.MyNet2 = MyNet(input_channels=2, single_view=True) 187 | self.MyNet3 = MyNet(input_channels=2, 
single_view=True) 188 | self.MyNet4 = MyNet(input_channels=2, single_view=True) 189 | self.MyNet5 = MyNet(input_channels=2, single_view=True) 190 | self.MyNet6 = MyNet(input_channels=3, single_view=True) 191 | 192 | self.fuse_weight_1 = AdaptiveWeight(128) 193 | self.fuse_weight_2 = AdaptiveWeight(128) 194 | self.fuse_weight_3 = AdaptiveWeight(128) 195 | self.fuse_weight_4 = AdaptiveWeight(128) 196 | self.fuse_weight_5 = AdaptiveWeight(128) 197 | self.fuse_weight_6 = AdaptiveWeight(128) 198 | 199 | self.fc = nn.Linear(128, num_classes) 200 | 201 | for m in self.modules(): 202 | if isinstance(m, nn.Conv2d): 203 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 204 | elif isinstance(m, nn.BatchNorm2d): 205 | nn.init.constant_(m.weight, 1) 206 | nn.init.constant_(m.bias, 0) 207 | 208 | def forward(self, x): 209 | 210 | outputs_view = [self.MyNet1(x[:, 3, :].unsqueeze(1)), 211 | self.MyNet2(torch.cat((x[:, 0, :].unsqueeze(1), x[:, 4, :].unsqueeze(1)), dim=1)), 212 | self.MyNet3(x[:, 6:8, :]), 213 | self.MyNet4(x[:, 8:10, :]), 214 | self.MyNet5(x[:, 10:12, :]), 215 | self.MyNet6(torch.cat((x[:, 1:3, :], x[:, 5, :].unsqueeze(1)), dim=1))] 216 | 217 | fuse_weight_1 = self.fuse_weight_1(outputs_view[0]) 218 | fuse_weight_2 = self.fuse_weight_2(outputs_view[1]) 219 | fuse_weight_3 = self.fuse_weight_3(outputs_view[2]) 220 | fuse_weight_4 = self.fuse_weight_4(outputs_view[3]) 221 | fuse_weight_5 = self.fuse_weight_5(outputs_view[4]) 222 | fuse_weight_6 = self.fuse_weight_6(outputs_view[5]) 223 | 224 | output = fuse_weight_1 * outputs_view[0] + fuse_weight_2 * outputs_view[1] + fuse_weight_3 * \ 225 | outputs_view[2] + fuse_weight_4 * outputs_view[3] + fuse_weight_5 * outputs_view[ 226 | 4] + fuse_weight_6 * outputs_view[5] 227 | 228 | x_out = self.fc(output) 229 | 230 | return x_out 231 | 232 | -------------------------------------------------------------------------------- /models/resnet1d_wang.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv(in_planes, out_planes, stride=1, kernel_size=3): 4 | return nn.Conv1d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 5 | padding=(kernel_size - 1) // 2, bias=False) 6 | 7 | 8 | class BasicBlock1d(nn.Module): 9 | expansion = 1 10 | 11 | def __init__(self, inplanes, planes, stride=1, kernel_size=[3, 3], downsample=None): 12 | super().__init__() 13 | if (isinstance(kernel_size, int)): kernel_size = [kernel_size, kernel_size // 2 + 1] 14 | 15 | self.conv1 = conv(inplanes, planes, stride=stride, kernel_size=kernel_size[0]) 16 | self.bn1 = nn.BatchNorm1d(planes) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.conv2 = conv(planes, planes, kernel_size=kernel_size[1]) 19 | self.bn2 = nn.BatchNorm1d(planes) 20 | self.downsample = downsample 21 | self.stride = stride 22 | 23 | def forward(self, x): 24 | residual = x 25 | 26 | out = self.conv1(x) 27 | out = self.bn1(out) 28 | out = self.relu(out) 29 | 30 | out = self.conv2(out) 31 | out = self.bn2(out) 32 | 33 | if self.downsample is not None: 34 | residual = self.downsample(x) 35 | 36 | out += residual 37 | out = self.relu(out) 38 | 39 | return out 40 | 41 | 42 | class ResNet1d(nn.Sequential): 43 | def __init__(self, block, layers, kernel_size=5, num_classes=2, input_channels=12): 44 | super(ResNet1d, self).__init__() 45 | self.inplanes = 64 46 | self.conv = nn.Conv1d(input_channels, 64, kernel_size=8, stride=2, padding=7 // 2, bias=False) 47 | self.bn = nn.BatchNorm1d(64) 48 | self.relu = 
nn.ReLU(inplace=True)
49 |         self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
50 |         self.block1 = self._make_layer(block, 64, layers[0], kernel_size=kernel_size)
51 |         self.block2 = self._make_layer(block, 128, layers[1], stride=2, kernel_size=kernel_size)
52 |         self.block3 = self._make_layer(block, 128, layers[2], stride=2, kernel_size=kernel_size)
53 | 
54 |         self.Avgpool = nn.AdaptiveAvgPool1d(1)
55 |         self.fc = nn.Linear(128, num_classes)
56 | 
57 |     def _make_layer(self, block, planes, blocks, stride=1, kernel_size=3):
58 |         downsample = None
59 | 
60 |         if stride != 1 or self.inplanes != planes * block.expansion:
61 |             downsample = nn.Sequential(
62 |                 nn.Conv1d(self.inplanes, planes * block.expansion,
63 |                           kernel_size=1, stride=stride, bias=False),
64 |                 nn.BatchNorm1d(planes * block.expansion),
65 |             )
66 | 
67 |         layers = []
68 |         layers.append(block(self.inplanes, planes, stride, kernel_size, downsample))
69 |         self.inplanes = planes * block.expansion
70 |         for i in range(1, blocks):
71 |             layers.append(block(self.inplanes, planes))
72 | 
73 |         return nn.Sequential(*layers)
74 | 
75 |     def forward(self, x):
76 |         output = self.relu(self.bn(self.conv(x)))
77 |         output = self.maxpool(output)
78 |         output = self.block1(output)
79 |         output = self.block2(output)
80 |         output = self.block3(output)
81 |         output = self.Avgpool(output)
82 |         output = output.view(output.size(0), -1)
83 |         output = self.fc(output)
84 |         return output
85 | 
86 | 
87 | def resnet1d_wang(**kwargs):
88 |     return ResNet1d(BasicBlock1d, [1, 1, 1], **kwargs)
89 | 
--------------------------------------------------------------------------------
/models/vit.py:
--------------------------------------------------------------------------------
 1 | """
 2 | original code from rwightman:
 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
 4 | """
 5 | from functools import partial
 6 | from collections import OrderedDict
 7 | import torch
 8 | import torch.nn as nn
 9 | 
10 | 
11 | def drop_path(x, drop_prob: float = 0., training: bool = False):
12 |     """
13 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
14 |     This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
15 |     the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
16 |     See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
17 |     changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
18 |     'survival rate' as the argument.
19 |     """
20 |     if drop_prob == 0. or not training:
21 |         return x
22 |     keep_prob = 1 - drop_prob
23 |     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
24 |     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
25 |     random_tensor.floor_()  # binarize
26 |     output = x.div(keep_prob) * random_tensor
27 |     return output
28 | 
29 | 
30 | class DropPath(nn.Module):
31 |     """
32 |     Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
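    Acts as an identity at inference time; during training each sample's residual branch is zeroed with probability drop_prob and the surviving samples are rescaled by 1 / keep_prob (see drop_path above).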
33 | """ 34 | 35 | def __init__(self, drop_prob=None): 36 | super(DropPath, self).__init__() 37 | self.drop_prob = drop_prob 38 | 39 | def forward(self, x): 40 | return drop_path(x, self.drop_prob, self.training) 41 | 42 | 43 | class PatchEmbed(nn.Module): 44 | """ 45 | 2D Image to Patch Embedding 46 | """ 47 | 48 | def __init__(self, img_size=1000, patch_size=50, in_c=12, embed_dim=600, norm_layer=None): 49 | super().__init__() 50 | img_size = (1, img_size) 51 | patch_size = (1, patch_size) 52 | self.img_size = img_size 53 | self.patch_size = patch_size 54 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 55 | self.num_patches = self.grid_size[0] * self.grid_size[1] 56 | 57 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 58 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 59 | 60 | def forward(self, x): 61 | B, C, H, W = x.shape 62 | assert H == self.img_size[0] and W == self.img_size[1], \ 63 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 64 | 65 | x = self.proj(x).flatten(2).transpose(1, 2) 66 | x = self.norm(x) 67 | return x 68 | 69 | 70 | class Attention(nn.Module): 71 | def __init__(self, 72 | dim, 73 | num_heads=8, 74 | qkv_bias=False, 75 | qk_scale=None, 76 | attn_drop_ratio=0., 77 | proj_drop_ratio=0.): 78 | super(Attention, self).__init__() 79 | self.num_heads = num_heads 80 | head_dim = dim // num_heads 81 | self.scale = qk_scale or head_dim ** -0.5 82 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 83 | self.attn_drop = nn.Dropout(attn_drop_ratio) 84 | self.proj = nn.Linear(dim, dim) 85 | self.proj_drop = nn.Dropout(proj_drop_ratio) 86 | 87 | def forward(self, x): 88 | B, N, C = x.shape 89 | 90 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 91 | q, k, v = qkv[0], qkv[1], qkv[2] 92 | 93 | attn = (q @ k.transpose(-2, -1)) * self.scale 94 | attn = attn.softmax(dim=-1) 95 | attn = self.attn_drop(attn) 96 | 97 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 98 | x = self.proj(x) 99 | x = self.proj_drop(x) 100 | return x 101 | 102 | 103 | class Mlp(nn.Module): 104 | """ 105 | MLP as used in Vision Transformer, MLP-Mixer and related networks 106 | """ 107 | 108 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 109 | super().__init__() 110 | out_features = out_features or in_features 111 | hidden_features = hidden_features or in_features 112 | self.fc1 = nn.Linear(in_features, hidden_features) 113 | self.act = act_layer() 114 | self.fc2 = nn.Linear(hidden_features, out_features) 115 | self.drop = nn.Dropout(drop) 116 | 117 | def forward(self, x): 118 | x = self.fc1(x) 119 | x = self.act(x) 120 | x = self.drop(x) 121 | x = self.fc2(x) 122 | x = self.drop(x) 123 | return x 124 | 125 | 126 | class Block(nn.Module): 127 | def __init__(self, 128 | dim, 129 | num_heads, 130 | mlp_ratio=4., 131 | qkv_bias=False, 132 | qk_scale=None, 133 | drop_ratio=0., 134 | attn_drop_ratio=0., 135 | drop_path_ratio=0., 136 | act_layer=nn.GELU, 137 | norm_layer=nn.LayerNorm): 138 | super(Block, self).__init__() 139 | self.norm1 = norm_layer(dim) 140 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 141 | attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio) 142 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 143 | self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. 
else nn.Identity() 144 | self.norm2 = norm_layer(dim) 145 | mlp_hidden_dim = int(dim * mlp_ratio) 146 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio) 147 | 148 | def forward(self, x): 149 | x = x + self.drop_path(self.attn(self.norm1(x))) 150 | x = x + self.drop_path(self.mlp(self.norm2(x))) 151 | return x 152 | 153 | 154 | class VisionTransformer(nn.Module): 155 | def __init__(self, img_size=1000, patch_size=50, in_c=12, num_classes=10, 156 | embed_dim=600, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True, 157 | qk_scale=None, representation_size=None, distilled=False, drop_ratio=0., 158 | attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None, 159 | act_layer=None): 160 | """ 161 | Args: 162 | img_size (int, tuple): input image size 163 | patch_size (int, tuple): patch size 164 | in_c (int): number of input channels 165 | num_classes (int): number of classes for classification head 166 | embed_dim (int): embedding dimension 167 | depth (int): depth of transformer 168 | num_heads (int): number of attention heads 169 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 170 | qkv_bias (bool): enable bias for qkv if True 171 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 172 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 173 | distilled (bool): model includes a distillation token and head as in DeiT models 174 | drop_ratio (float): dropout rate 175 | attn_drop_ratio (float): attention dropout rate 176 | drop_path_ratio (float): stochastic depth rate 177 | embed_layer (nn.Module): patch embedding layer 178 | norm_layer: (nn.Module): normalization layer 179 | """ 180 | super(VisionTransformer, self).__init__() 181 | self.num_classes = num_classes 182 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 183 | self.num_tokens = 2 if distilled else 1 184 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 185 | act_layer = act_layer or nn.GELU 186 | 187 | self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim) 188 | num_patches = self.patch_embed.num_patches 189 | 190 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 191 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 192 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 193 | self.pos_drop = nn.Dropout(p=drop_ratio) 194 | 195 | dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule 196 | self.blocks = nn.Sequential(*[ 197 | Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 198 | drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i], 199 | norm_layer=norm_layer, act_layer=act_layer) 200 | for i in range(depth) 201 | ]) 202 | self.norm = norm_layer(embed_dim) 203 | 204 | # Representation layer 205 | if representation_size and not distilled: 206 | self.has_logits = True 207 | self.num_features = representation_size 208 | self.pre_logits = nn.Sequential(OrderedDict([ 209 | ("fc", nn.Linear(embed_dim, representation_size)), 210 | ("act", nn.Tanh()) 211 | ])) 212 | else: 213 | self.has_logits = False 214 | self.pre_logits = nn.Identity() 215 | 216 | # Classifier head(s) 217 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 
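        # a second, distillation head is only created for DeiT-style models (distilled=True); see forward()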
218 |         self.head_dist = None
219 |         if distilled:
220 |             self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
221 | 
222 |         # Weight init
223 |         nn.init.trunc_normal_(self.pos_embed, std=0.02)
224 |         if self.dist_token is not None:
225 |             nn.init.trunc_normal_(self.dist_token, std=0.02)
226 | 
227 |         nn.init.trunc_normal_(self.cls_token, std=0.02)
228 |         self.apply(_init_vit_weights)
229 | 
230 |     def forward_features(self, x):
231 |         x = x.unsqueeze(2)
232 |         x = self.patch_embed(x)
233 |         cls_token = self.cls_token.expand(x.shape[0], -1, -1)
234 |         if self.dist_token is None:
235 |             x = torch.cat((cls_token, x), dim=1)
236 |         else:
237 |             x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
238 | 
239 |         x = self.pos_drop(x + self.pos_embed)
240 |         x = self.blocks(x)
241 |         x = self.norm(x)
242 |         if self.dist_token is None:
243 |             return self.pre_logits(x[:, 0])
244 |         else:
245 |             return x[:, 0], x[:, 1]
246 | 
247 |     def forward(self, x):
248 |         x = self.forward_features(x)
249 |         if self.head_dist is not None:
250 |             x, x_dist = self.head(x[0]), self.head_dist(x[1])
251 |             if self.training and not torch.jit.is_scripting():
252 |                 # during training, return both head outputs so each classifier can be supervised separately
253 |                 return x, x_dist
254 |             else:
255 |                 return (x + x_dist) / 2  # at inference, average the two classifier predictions
256 |         else:
257 |             x = self.head(x)
258 |             return x
259 | 
260 | 
261 | def _init_vit_weights(m):
262 |     """
263 |     ViT weight initialization
264 |     :param m: module
265 |     """
266 |     if isinstance(m, nn.Linear):
267 |         nn.init.trunc_normal_(m.weight, std=.01)
268 |         if m.bias is not None:
269 |             nn.init.zeros_(m.bias)
270 |     elif isinstance(m, nn.Conv2d):
271 |         nn.init.kaiming_normal_(m.weight, mode="fan_out")
272 |         if m.bias is not None:
273 |             nn.init.zeros_(m.bias)
274 |     elif isinstance(m, nn.LayerNorm):
275 |         nn.init.zeros_(m.bias)
276 |         nn.init.ones_(m.weight)
277 | 
278 | 
279 | def vit(num_classes: int = 10, has_logits: bool = True):
280 |     """
281 |     ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
282 |     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
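    Note: the instance built below is a shallow 2-block variant operating on 1-D ECG patches (img_size=1000, patch_size=50, embed_dim=600), not the full 12-layer ViT-Base.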
283 | weights ported from official Google JAX impl: 284 | https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth 285 | """ 286 | model = VisionTransformer(img_size=1000, 287 | patch_size=50, 288 | embed_dim=600, 289 | depth=2, 290 | num_heads=12, 291 | representation_size=512 if has_logits else None, 292 | num_classes=num_classes) 293 | return model 294 | 295 | -------------------------------------------------------------------------------- /models/xresnet1d101.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import inspect 3 | 4 | 5 | def delegates(to=None, keep=False): 6 | "Decorator: replace `**kwargs` in signature with params from `to`" 7 | 8 | def _f(f): 9 | if to is None: 10 | to_f, from_f = f.__base__.__init__, f.__init__ 11 | else: 12 | to_f, from_f = to, f 13 | sig = inspect.signature(from_f) 14 | sigd = dict(sig.parameters) 15 | k = sigd.pop('kwargs') 16 | s2 = {k: v for k, v in inspect.signature(to_f).parameters.items() 17 | if v.default != inspect.Parameter.empty and k not in sigd} 18 | sigd.update(s2) 19 | if keep: sigd['kwargs'] = k 20 | from_f.__signature__ = sig.replace(parameters=sigd.values()) 21 | return f 22 | 23 | return _f 24 | 25 | 26 | def AvgPool(ks=2, stride=None, padding=0, ceil_mode=False): 27 | return nn.AvgPool1d(ks, stride=stride, padding=padding, ceil_mode=ceil_mode) 28 | 29 | 30 | def MaxPool(ks=2, stride=None, padding=0, ceil_mode=False): 31 | return nn.MaxPool1d(ks, stride=stride, padding=padding) 32 | 33 | 34 | def AdaptiveAvgPool(sz=1): 35 | return nn.AdaptiveAvgPool1d(sz) 36 | 37 | 38 | class Flatten(nn.Module): 39 | def forward(self, x): return x.view(x.size(0), -1) 40 | 41 | 42 | def init_cnn(m): 43 | if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0) 44 | if isinstance(m, (nn.Conv2d, nn.Conv1d, nn.Linear)): nn.init.kaiming_normal_(m.weight) 45 | for l in m.children(): init_cnn(l) 46 | 47 | 48 | class ConvLayer(nn.Sequential): 49 | """ 50 | Creates a sequence of Conv, Act, Norm 51 | """ 52 | 53 | @delegates(nn.Conv1d) 54 | def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, norm='bn', bn_1st=True, 55 | act_cls=nn.ReLU, xtra=None, **kwargs): 56 | if padding is None: padding = ((ks - 1) // 2) 57 | norm = nn.BatchNorm1d(nf) 58 | bias = None if not (not norm) else bias 59 | conv = nn.Conv1d(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs) 60 | layers = [conv] 61 | act_bn = [] 62 | if act_cls is not None: act_bn.append(act_cls()) 63 | if norm: act_bn.append(norm) 64 | if bn_1st: act_bn.reverse() 65 | layers += act_bn 66 | if xtra: layers.append(xtra) 67 | super().__init__(*layers) 68 | 69 | 70 | class ResBlock(nn.Module): 71 | """ 72 | Resnet block from ni to nh with stride 73 | """ 74 | 75 | @delegates(ConvLayer.__init__) 76 | def __init__(self, expansion, ni, nf, stride=1, nh1=None, nh2=None, 77 | norm='bn', act_cls=nn.ReLU, ks=3, pool_first=True, **kwargs): 78 | super(ResBlock, self).__init__() 79 | norm1 = norm2 = norm 80 | pool = AvgPool 81 | if nh2 is None: nh2 = nf 82 | if nh1 is None: nh1 = nh2 83 | nf, ni = nf * expansion, ni * expansion 84 | k0 = dict(norm=norm1, act_cls=act_cls, **kwargs) 85 | k1 = dict(norm=norm2, act_cls=None, **kwargs) 86 | conv_path = [ 87 | ConvLayer(ni, nh2, ks, stride=stride, **k0), 88 | ConvLayer(nh2, nf, ks, **k1) 89 | ] if expansion == 1 else [ 90 | ConvLayer(ni, nh1, 1, **k0), 91 | ConvLayer(nh1, nh2, ks, stride=stride, 
**k0), 92 | ConvLayer(nh2, nf, 1, **k1)] 93 | self.conv_path = nn.Sequential(*conv_path) 94 | id_path = [] 95 | if ni != nf: id_path.append(ConvLayer(ni, nf, 1, norm=norm, act_cls=None, **kwargs)) 96 | if stride != 1: id_path.insert((1, 0)[pool_first], pool(stride, ceil_mode=True)) 97 | self.id_path = nn.Sequential(*id_path) 98 | self.act = nn.ReLU(inplace=True) if act_cls is nn.ReLU else act_cls() 99 | 100 | def forward(self, x): 101 | return self.act(self.conv_path(x) + self.id_path(x)) 102 | 103 | 104 | class XResNet(nn.Sequential): 105 | @delegates(ResBlock) 106 | def __init__(self, block, expansion, layers, p=0.0, input_channels=12, num_classes=5, stem_szs=(32, 32, 64), 107 | widen=1.0, norm='bn', act_cls=nn.ReLU, ks=3, stride=2, **kwargs): 108 | self.block, self.expansion, self.act_cls, self.ks = block, expansion, act_cls, ks 109 | if ks % 2 == 0: raise Exception('Kernel size has to be odd') 110 | self.norm = norm 111 | stem_szs = [input_channels, *stem_szs] 112 | stem = [ 113 | ConvLayer(stem_szs[i], stem_szs[i + 1], ks=ks, stride=stride if i == 0 else 1, norm=norm, act_cls=act_cls) 114 | for i in range(3)] 115 | # block_szs = [int(o * widen) for o in [64, 128, 256, 512] + [256] * (len(layers) - 4)] 116 | block_szs = [int(o * widen) for o in [64, 64, 64, 64] + [32] * (len(layers) - 4)] 117 | block_szs = [64 // expansion] + block_szs 118 | blocks = self._make_blocks(layers, block_szs, stride, **kwargs) 119 | 120 | # head = head_layer(inplanes=block_szs[-1] * expansion, ps_head=0.5, num_classes=num_classes) 121 | 122 | super().__init__( 123 | *stem, MaxPool(ks=ks, stride=stride, padding=ks // 2), 124 | *blocks, 125 | # head, 126 | AdaptiveAvgPool(sz=1), Flatten(), nn.Dropout(p), 127 | nn.Linear(block_szs[-1] * expansion, num_classes), 128 | ) 129 | init_cnn(self) 130 | 131 | def _make_blocks(self, layers, block_szs, stride, **kwargs): 132 | return [self._make_layer(ni=block_szs[i], nf=block_szs[i + 1], blocks=l, 133 | stride=1 if i == 0 else stride, **kwargs) 134 | for i, l in enumerate(layers)] 135 | 136 | def _make_layer(self, ni, nf, blocks, stride, **kwargs): 137 | return nn.Sequential( 138 | *[self.block(self.expansion, ni if i == 0 else nf, nf, stride=stride if i == 0 else 1, 139 | norm=self.norm, act_cls=self.act_cls, ks=self.ks, **kwargs) 140 | for i in range(blocks)]) 141 | 142 | 143 | def xresnet1d101(**kwargs): 144 | return XResNet(ResBlock, 4, [3, 4, 23, 3], input_channels=12, **kwargs) 145 | 146 | 147 | def xresnet1d50(**kwargs): 148 | return XResNet(ResBlock, 4, [3, 4, 6, 3], input_channels=12, **kwargs) 149 | 150 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | @time: 2021/4/7 15:21 4 | 5 | @ author: 6 | ''' 7 | 8 | import time 9 | from sklearn.metrics import average_precision_score, confusion_matrix, roc_auc_score 10 | import numpy as np 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def get_n_params(model): 16 | pp = 0 17 | for p in list(model.parameters()): 18 | nn = 1 19 | for s in list(p.size()): 20 | nn = nn * s 21 | pp += nn 22 | return pp 23 | 24 | 25 | def compute_mAP(y_true, y_pred): 26 | AP = [] 27 | for i in range(len(y_true)): 28 | AP.append(average_precision_score(y_true[i], y_pred[i])) 29 | return np.mean(AP) 30 | 31 | 32 | def compute_TPR(y_true, y_pred): 33 | y_true = np.array(y_true) 34 | y_pred = np.array(y_pred) 35 | sum, count = 0.0, 0 36 | for i, _ in 
enumerate(y_pred):
37 |         y_pred[i] = np.where(y_pred[i] >= 0.5, 1, 0)  # binarize scores at 0.5
38 |         (x, y) = confusion_matrix(y_true=y_true[i], y_pred=y_pred[i])[1]  # second row of the 2x2 matrix: [FN, TP]
39 |         sum += y / (x + y)  # per-sample sensitivity: TP / (TP + FN)
40 |         count += 1
41 | 
42 |     return sum / count
43 | 
44 | 
45 | def compute_AUC(y_true, y_pred):
46 |     y_true = np.array(y_true)
47 |     y_pred = np.array(y_pred)
48 |     class_auc = []
49 |     for i in range(y_true.shape[1]):  # one AUC per class column
50 |         class_auc.append(roc_auc_score(y_true[:, i], y_pred[:, i]))
51 |     auc = roc_auc_score(y_true, y_pred)
52 |     return auc, class_auc
53 | 
54 | 
55 | # PRINT TIME
56 | def print_time_cost(since):
57 |     time_elapsed = time.time() - since
58 |     return '{:.0f}m{:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)
59 | 
60 | 
61 | # KD loss
62 | class KdLoss(nn.Module):
63 |     def __init__(self, alpha, temperature):
64 |         super(KdLoss, self).__init__()
65 |         self.alpha = alpha
66 |         self.T = temperature
67 | 
68 |     def forward(self, outputs, labels, teacher_outputs):
69 |         # weighted sum of the soft-target KL term (scaled by T^2) and the hard-label BCE term
70 |         kd_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(outputs / self.T, dim=1),
71 |                                                       F.softmax(teacher_outputs / self.T, dim=1)) * (
72 |                           self.alpha * self.T * self.T) + F.binary_cross_entropy_with_logits(outputs, labels) * (
73 |                           1. - self.alpha)
74 |         return kd_loss
75 | 
76 | 
--------------------------------------------------------------------------------
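For reference, here is a minimal, self-contained sketch of how the helpers in utils.py fit together during knowledge distillation. The tensor shapes, hyperparameter values, and dummy labels are illustrative assumptions, not values taken from the repo's training scripts:

```python
# Hypothetical smoke test for KdLoss and compute_AUC (all values illustrative).
import torch
from utils import KdLoss, compute_AUC

criterion = KdLoss(alpha=0.5, temperature=2)  # alpha/temperature chosen arbitrarily here

student_logits = torch.randn(8, 5, requires_grad=True)  # single-view student outputs
teacher_logits = torch.randn(8, 5)                      # frozen multi-view teacher outputs

labels = torch.zeros(8, 5)  # multi-hot targets; alternate rows so every class
labels[::2] = 1.0           # column contains both positives and negatives

loss = criterion(student_logits, labels, teacher_logits)
loss.backward()  # gradients flow only into the student

# the evaluation helpers expect per-class probabilities in [0, 1]
auc, class_auc = compute_AUC(labels.numpy(),
                             torch.sigmoid(student_logits).detach().numpy())
print(loss.item(), auc)
```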