├── .gitignore ├── LICENSE ├── README.md ├── contrast.py ├── environment.yml ├── mit_utils.py ├── net.py └── transform.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 XueJiang16 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ssl-torch 2 | 3 | This is the code implementation of paper *Self-supervised Contrastive Learning for EEG-based Sleep Staging*. 4 | 5 | ### Environment Setup 6 | 7 | We recommend to setup the environment through `conda`. 8 | 9 | ```shell 10 | $ conda env create -f environment.yml 11 | ``` 12 | 13 | ### Data Preparation 14 | 15 | The dataset Sleep-edf can be downloaded [here](https://physionet.org/content/sleep-edfx/1.0.0/). 16 | 17 | ### Training 18 | 19 | We use Pytorch 3.6 to build the network, which is trained on the NVIDIA GTX 1080Ti with the batch size of 128. The network is trained for 70 epochs. We use the SGD optimizer with the momentum $= 0.9$. 20 | 21 | For training the network, run 22 | 23 | ```shell 24 | $ python contrast.py -F1 time_warp -F2 time_warp 25 | ``` 26 | 27 | 28 | -------------------------------------------------------------------------------- /contrast.py: -------------------------------------------------------------------------------- 1 | from net import resnet18, resnet34, resnet50, resnet101, resnet152 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | # import pandas as pd 7 | import tqdm 8 | import mit_utils as utils 9 | # import analytics 10 | import time 11 | import os, shutil 12 | from mail import mail_it 13 | from sklearn.metrics import confusion_matrix 14 | from sklearn.metrics import f1_score 15 | 16 | import random 17 | 18 | from torch.optim.lr_scheduler import CosineAnnealingLR 19 | from warmup_scheduler import GradualWarmupScheduler 20 | 21 | import argparse 22 | 23 | 24 | 25 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | # parser.add_argument('-d', '--dataset', type=int) 27 | # parser.add_argument('-g', '--gpu_id', type=str, default=0) 28 | parser.add_argument('-F1', '--transform_function_1', type=str) 29 | parser.add_argument('-F2', '--transform_function_2', type=str) 30 | # parser.add_argument('-e', '--epoch', type=int, default=60) 31 | 32 | 33 | arg = parser.parse_args() 34 | 35 | torch.set_default_tensor_type(torch.FloatTensor) 36 | 37 | device = "cuda" 38 | 39 | log_dir = "logs" 40 | model_name = 'resnet17' 41 | model_save_dir = '%s/%s_%s' % (log_dir, model_name, time.strftime("%m%d%H%M")) 42 | 43 | os.makedirs(model_save_dir, exist_ok=True) 44 | log_file = "%s_%s_%s.log" % (arg.transform_function_1, arg.transform_function_2, time.strftime("%m%d%H%M")) 45 | 46 | log_templete = {"acc": None, 47 | "cm": None, 48 | "f1": None, 49 | "per F1":None, 50 | "epoch":None, 51 | } 52 | 53 | data = np.load('data.npz') 54 | orig_x = data['x'] 55 | x = np.zeros((orig_x.shape[0],3072)) 56 | x[:,36:3036] = orig_x 57 | x = x[:,None,:] 58 | y = data['y'] 59 | 60 | 61 | from sklearn.model_selection import train_test_split 62 | 63 | x_train, x_test, y_train, y_test = \ 64 | train_test_split(x, y, test_size=0.3) 65 | x_train = torch.tensor(x_train, dtype=torch.float).to(device) 66 | x_test = torch.tensor(x_test, dtype=torch.float).to(device) 67 | y_train = torch.tensor(y_train, dtype=torch.long).to(device) 68 | y_test = torch.tensor(y_test, dtype=torch.long).to(device) 69 | print(x_train.shape) 70 | 71 | import torch.nn.functional as F 72 | from transform import Transform 73 | 74 | def save_ckpt(state, is_best, model_save_dir, message='best_w.pth'): 75 | current_w = os.path.join(model_save_dir, 'latest_w.pth') 76 | best_w = os.path.join(model_save_dir, message) 77 | torch.save(state, current_w) 78 | if is_best: shutil.copyfile(current_w, best_w) 79 | 80 | def transform(x, mode): 81 | x_ = x.cpu().numpy() 82 | 83 | Trans = Transform() 84 | if mode == 'time_warp': 85 | pieces = random.randint(5,20) 86 | stretch = random.uniform(1.5,4) 87 | squeeze = random.uniform(0.25,0.67) 88 | x_ = Trans.time_warp(x_, 100, pieces, stretch, squeeze) 89 | elif mode == 'noise': 90 | factor = random.uniform(10,20) 91 | x_ = Trans.add_noise_with_SNR(x_,factor) 92 | elif mode == 'scale': 93 | x_ = Trans.scaled(x_,[0.3,3]) 94 | elif mode == 'negate': 95 | x_ = Trans.negate(x_) 96 | elif mode == 'hor_flip': 97 | x_ = Trans.hor_filp(x_) 98 | elif mode == 'permute': 99 | pieces = random.randint(5,20) 100 | x_ = Trans.permute(x_,pieces) 101 | elif mode == 'cutout_resize': 102 | pieces = random.randint(5, 20) 103 | x_ = Trans.cutout_resize(x_, pieces) 104 | elif mode == 'cutout_zero': 105 | pieces = random.randint(5, 20) 106 | x_ = Trans.cutout_zero(x_, pieces) 107 | elif mode == 'crop_resize': 108 | size = random.uniform(0.25,0.75) 109 | x_ = Trans.crop_resize(x_, size) 110 | elif mode == 'move_avg': 111 | n = random.randint(3, 10) 112 | x_ = Trans.move_avg(x_,n, mode="same") 113 | # to test 114 | elif mode == 'lowpass': 115 | order = random.randint(3, 10) 116 | cutoff = random.uniform(5,20) 117 | x_ = Trans.lowpass_filter(x_, order, [cutoff]) 118 | elif mode == 'highpass': 119 | order = random.randint(3, 10) 120 | cutoff = random.uniform(5, 10) 121 | x_ = Trans.highpass_filter(x_, order, [cutoff]) 122 | elif mode == 'bandpass': 123 | order = random.randint(3, 10) 124 | cutoff_l = random.uniform(1, 5) 125 | cutoff_h = random.uniform(20, 40) 126 | cutoff = [cutoff_l, cutoff_h] 127 | x_ = Trans.bandpass_filter(x_, order, cutoff) 128 | 129 | else: 130 | print("Error") 131 | 132 | x_ = x_.copy() 133 | x_ = x_[:,None,:] 134 | return x_ 135 | 136 | def comtrast_loss(x, criterion): 137 | LARGE_NUM = 1e9 138 | temperature = 0.1 139 | x = F.normalize(x, dim=-1) 140 | 141 | num = int(x.shape[0] / 2) 142 | hidden1, hidden2 = torch.split(x, num) 143 | 144 | 145 | hidden1_large = hidden1 146 | hidden2_large = hidden2 147 | labels = torch.arange(0,num).to('cuda') 148 | masks = F.one_hot(torch.arange(0,num), num).to('cuda') 149 | 150 | 151 | logits_aa = torch.matmul(hidden1, hidden1_large.T) / temperature 152 | logits_aa = logits_aa - masks * LARGE_NUM 153 | logits_bb = torch.matmul(hidden2, hidden2_large.T) / temperature 154 | logits_bb = logits_bb - masks * LARGE_NUM 155 | logits_ab = torch.matmul(hidden1, hidden2_large.T) / temperature 156 | logits_ba = torch.matmul(hidden2, hidden1_large.T) / temperature 157 | # print(labels) 158 | # 159 | # print(torch.cat([logits_ab, logits_aa], 1).shape) 160 | 161 | loss_a = criterion(torch.cat([logits_ab, logits_aa], 1), 162 | labels) 163 | loss_b = criterion(torch.cat([logits_ba, logits_bb], 1), 164 | labels) 165 | loss = torch.mean(loss_a + loss_b) 166 | return loss, labels, logits_ab 167 | 168 | net = resnet18(classification=False).to(device) 169 | net = nn.DataParallel(net) 170 | criterion = nn.CrossEntropyLoss().to(device) 171 | 172 | batch_size = 512 173 | 174 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1 * (batch_size / 64), momentum=0.9, weight_decay=0.00001) 175 | 176 | epochs = 70 177 | lr_schduler = CosineAnnealingLR(optimizer, T_max=epochs - 10, eta_min=0.05)#default =0.07 178 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=10, after_scheduler=lr_schduler) 179 | optimizer.zero_grad() 180 | optimizer.step() 181 | scheduler_warmup.step() 182 | 183 | 184 | 185 | train_dataset = torch.utils.data.TensorDataset(x_train, y_train) 186 | train_iter = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True) 187 | test_dataset = torch.utils.data.TensorDataset(x_test, y_test) 188 | test_iter = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=True) 189 | 190 | target_class = ['W', 'N1', 'N2', 'N3', 'REM'] 191 | 192 | 193 | 194 | val_acc_list = [] 195 | n_train_samples = x_train.shape[0] 196 | iter_per_epoch = n_train_samples // batch_size + 1 197 | best_acc = -1 198 | err = [] 199 | best_err = 1 200 | margin = 1 201 | 202 | 203 | for epoch in range(epochs): 204 | net.train() 205 | loss_sum = 0 206 | evaluation = [] 207 | iter = 0 208 | with tqdm.tqdm(total=iter_per_epoch) as pbar: 209 | error_counter = 0 210 | 211 | for X, y in train_iter: 212 | trans = [] 213 | for i in range(X.shape[0]): 214 | t1 = transform(X[i], arg.transform_function_1) 215 | trans.append(t1) 216 | for i in range(X.shape[0]): 217 | t2 = transform(X[i], arg.transform_function_2) 218 | trans.append(t2) 219 | trans = np.concatenate(trans) 220 | trans = torch.tensor(trans, dtype=torch.float, device="cuda") 221 | 222 | output = net(trans) 223 | 224 | optimizer.zero_grad() 225 | 226 | l, lab_con, log_con = comtrast_loss(output, criterion) 227 | _, log_p = torch.max(log_con.data,1) 228 | evaluation.append((log_p == lab_con).tolist()) 229 | l.backward() 230 | optimizer.step() 231 | loss_sum += l 232 | iter += 1 233 | pbar.set_description("Epoch %d, loss = %.2f" % (epoch, l.data)) 234 | pbar.update(1) 235 | err = l.data 236 | evaluation = [item for sublist in evaluation for item in sublist] 237 | 238 | 239 | train_acc = sum(evaluation) / len(evaluation) 240 | error = 1 - train_acc 241 | current_lr = optimizer.param_groups[0]['lr'] 242 | print("Epoch:", epoch,"lr:", current_lr, "error:", error, " train_loss =", loss_sum.data) 243 | scheduler_warmup.step() 244 | state = {"state_dict": net.state_dict(), "epoch": epoch} 245 | save_ckpt(state, best_err > error, model_save_dir) 246 | best_err = min(best_err, error) 247 | #========================= 248 | 249 | net = resnet18(classification=True).to('cuda') 250 | net = nn.DataParallel(net) 251 | checkpoint = torch.load(os.path.join(model_save_dir,'best_w.pth')) 252 | net.load_state_dict(checkpoint['state_dict'], strict=False) 253 | criterion = nn.CrossEntropyLoss().to(device) 254 | 255 | optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.00001) 256 | 257 | epochs_t = 70 258 | lr_schduler = CosineAnnealingLR(optimizer, T_max=epochs_t - 10, eta_min=0.09)#default =0.07 259 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=10, after_scheduler=lr_schduler) 260 | optimizer.zero_grad() 261 | optimizer.step() 262 | scheduler_warmup.step() 263 | 264 | 265 | batch_size = 256 266 | 267 | val_acc_list = [] 268 | n_train_samples = x_train.shape[0] 269 | iter_per_epoch = n_train_samples // batch_size + 1 270 | best_acc = -1 271 | 272 | for epoch in range(epochs_t): 273 | net.train() 274 | loss_sum = 0 275 | evaluation = [] 276 | iter = 0 277 | with tqdm.tqdm(total=iter_per_epoch) as pbar: 278 | for X, y in train_iter: 279 | output = net(X) 280 | _, predicted = torch.max(output.data, 1) 281 | evaluation.append((predicted == y).tolist()) 282 | optimizer.zero_grad() 283 | l = criterion(output, y) 284 | l.backward() 285 | optimizer.step() 286 | loss_sum += l 287 | iter += 1 288 | pbar.set_description("Epoch %d, loss = %.2f" % (epoch, l.data)) 289 | pbar.update(1) 290 | evaluation = [item for sublist in evaluation for item in sublist] 291 | train_acc = sum(evaluation) / len(evaluation) 292 | current_lr = optimizer.param_groups[0]['lr'] 293 | print("Epoch:", epoch,"lr:", current_lr," train_loss =", loss_sum.data, " train_acc =", train_acc) 294 | # scheduler.step() 295 | scheduler_warmup.step() 296 | val_loss = 0 297 | evaluation = [] 298 | pred_v = [] 299 | true_v = [] 300 | with torch.no_grad(): 301 | net.eval() 302 | for X, y in test_iter: 303 | output = net(X) 304 | _, predicted = torch.max(output.data, 1) 305 | evaluation.append((predicted == y).tolist()) 306 | l = criterion(output, y) 307 | val_loss += l 308 | pred_v.append(predicted.tolist()) 309 | true_v.append(y.tolist()) 310 | evaluation = [item for sublist in evaluation for item in sublist] 311 | pred_v = [item for sublist in pred_v for item in sublist] 312 | true_v = [item for sublist in true_v for item in sublist] 313 | 314 | running_acc = sum(evaluation) / len(evaluation) 315 | val_acc_list.append(running_acc) 316 | print("val_loss =", val_loss, "val_acc =", running_acc) 317 | 318 | 319 | state = {"state_dict": net.state_dict(), "epoch": epoch} 320 | save_ckpt(state, best_acc < running_acc, model_save_dir, 'best_cls.pth') 321 | best_acc = max(best_acc, running_acc) 322 | 323 | print("Highest acc:", max(val_acc_list)) 324 | 325 | 326 | 327 | 328 | # =========================test 329 | model = resnet18(classification=True).to('cuda') 330 | checkpoint = torch.load(os.path.join(model_save_dir,'best_cls.pth')) 331 | model.load_state_dict(checkpoint['state_dict'], strict=True) 332 | epoch_b = checkpoint['epoch'] 333 | # model.train() 334 | model.eval() 335 | val_loss = 0 336 | evaluation = [] 337 | pred_v = [] 338 | true_v = [] 339 | with torch.no_grad(): 340 | for X, y in test_iter: 341 | output = model(X) 342 | _, predicted = torch.max(output.data, 1) 343 | evaluation.append((predicted == y).tolist()) 344 | l = criterion(output, y) 345 | val_loss += l 346 | pred_v.append(predicted.tolist()) 347 | true_v.append(y.tolist()) 348 | evaluation = [item for sublist in evaluation for item in sublist] 349 | pred_v = [item for sublist in pred_v for item in sublist] 350 | true_v = [item for sublist in true_v for item in sublist] 351 | 352 | highest_acc = sum(evaluation) / len(evaluation) 353 | print("epoch=" , epoch_b, "val_acc =", highest_acc) 354 | def calculate_all_prediction(confMatrix): 355 | ''' 356 | 计算总精度:对角线上所有值除以总数 357 | ''' 358 | total_sum = confMatrix.sum() 359 | correct_sum = (np.diag(confMatrix)).sum() 360 | prediction = round(100 * float(correct_sum) / float(total_sum), 2) 361 | return prediction 362 | 363 | 364 | def calculate_label_prediction(confMatrix, labelidx): 365 | ''' 366 | 计算某一个类标预测精度:该类被预测正确的数除以该类的总数 367 | ''' 368 | label_total_sum = confMatrix.sum(axis=0)[labelidx] 369 | label_correct_sum = confMatrix[labelidx][labelidx] 370 | prediction = 0 371 | if label_total_sum != 0: 372 | prediction = round(100 * float(label_correct_sum) / float(label_total_sum), 2) 373 | return prediction 374 | 375 | 376 | def calculate_label_recall(confMatrix, labelidx): 377 | ''' 378 | 计算某一个类标的召回率: 379 | ''' 380 | label_total_sum = confMatrix.sum(axis=1)[labelidx] 381 | label_correct_sum = confMatrix[labelidx][labelidx] 382 | recall = 0 383 | if label_total_sum != 0: 384 | recall = round(100 * float(label_correct_sum) / float(label_total_sum), 2) 385 | return recall 386 | 387 | 388 | def calculate_f1(prediction, recall): 389 | if (prediction + recall) == 0: 390 | return 0 391 | return round(2 * prediction * recall / (prediction + recall), 2) 392 | 393 | cm = confusion_matrix(true_v, pred_v) 394 | f1_macro = f1_score(true_v, pred_v, average='macro') 395 | 396 | i=0 397 | f1 = [] 398 | for i in range(5): 399 | r = calculate_label_recall(cm,i) 400 | p = calculate_label_prediction(cm,i) 401 | f = calculate_f1(p,r) 402 | f1.append(f) 403 | 404 | 405 | log_templete["acc"] = '{:.3%}'.format(highest_acc) 406 | log_templete["epoch"] = epoch_b 407 | 408 | log_templete["cm"] = str(cm) 409 | log_templete["f1"] = str(f1_macro) 410 | log_templete["per F1"] = str(f1) 411 | log = log_templete 412 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ssl-torch 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - blas=1.0=mkl 8 | - ca-certificates=2020.7.22=0 9 | - certifi=2020.6.20=py37_0 10 | - cudatoolkit=10.0.130=0 11 | - freetype=2.10.2=h5ab3b9f_0 12 | - intel-openmp=2020.2=254 13 | - jpeg=9b=h024ee3a_2 14 | - lcms2=2.11=h396b838_0 15 | - ld_impl_linux-64=2.33.1=h53a641e_7 16 | - libedit=3.1.20191231=h14c3975_1 17 | - libffi=3.3=he6710b0_2 18 | - libgcc-ng=9.1.0=hdf63c60_0 19 | - libpng=1.6.37=hbc83047_0 20 | - libstdcxx-ng=9.1.0=hdf63c60_0 21 | - libtiff=4.1.0=h2733197_1 22 | - lz4-c=1.9.2=he6710b0_1 23 | - mkl=2020.2=256 24 | - mkl-service=2.3.0=py37he904b0f_0 25 | - mkl_fft=1.2.0=py37h23d657b_0 26 | - mkl_random=1.1.1=py37h0573a6f_0 27 | - ncurses=6.2=he6710b0_1 28 | - ninja=1.10.1=py37hfd86e86_0 29 | - olefile=0.46=py37_0 30 | - openssl=1.1.1h=h7b6447c_0 31 | - pillow=7.2.0=py37hb39fc2d_0 32 | - pip=20.2.2=py37_0 33 | - python=3.7.9=h7579374_0 34 | - readline=8.0=h7b6447c_0 35 | - setuptools=49.6.0=py37_0 36 | - six=1.15.0=py_0 37 | - sqlite=3.33.0=h62c20be_0 38 | - tk=8.6.10=hbc83047_0 39 | - torchvision=0.5.0=py37_cu100 40 | - wheel=0.35.1=py_0 41 | - xz=5.2.5=h7b6447c_0 42 | - zlib=1.2.11=h7b6447c_3 43 | - zstd=1.4.5=h9ceee32_0 44 | - pip: 45 | - chardet==3.0.4 46 | - cycler==0.10.0 47 | - h5py==2.10.0 48 | - idna==2.10 49 | - install==1.3.4 50 | - joblib==1.0.0 51 | - kiwisolver==1.2.0 52 | - matplotlib==3.3.2 53 | - mne==0.18.0 54 | - numpy==1.19.4 55 | - opencv-python==4.4.0.46 56 | - pandas==1.1.3 57 | - pyparsing==2.4.7 58 | - python-dateutil==2.8.1 59 | - pytz==2020.1 60 | - requests==2.24.0 61 | - scikit-learn==0.23.2 62 | - scipy==1.5.2 63 | - threadpoolctl==2.1.0 64 | - torch==1.7.1 65 | - tqdm==4.54.1 66 | - typing-extensions==3.7.4.3 67 | - urllib3==1.25.10 68 | - warmup-scheduler==0.3.2 69 | -------------------------------------------------------------------------------- /mit_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Mar 14 23:47:38 2019 4 | 5 | @author: Winham 6 | 7 | 辅助函数 8 | """ 9 | 10 | import warnings 11 | import numpy as np 12 | from scipy.signal import resample 13 | # import pywt 14 | from sklearn.preprocessing import scale 15 | from sklearn.metrics import confusion_matrix 16 | from sklearn.metrics import accuracy_score 17 | from sklearn.utils.multiclass import unique_labels 18 | import matplotlib.pyplot as plt 19 | # =========================================== 20 | warnings.filterwarnings("ignore") 21 | import torch 22 | import numpy as np 23 | import time,os 24 | from sklearn.metrics import f1_score 25 | from torch import nn 26 | 27 | 28 | def mkdirs(path): 29 | if not os.path.exists(path): 30 | os.makedirs(path) 31 | 32 | def calc_f1(y_true, y_pre, threshold=0.5): 33 | y_true = y_true.view(-1).cpu().detach().numpy().astype(np.int) 34 | y_pre = y_pre.cpu().detach().numpy() 35 | y_pre = np.argmax(y_pre, axis=-1) 36 | return f1_score(y_true, y_pre, average='macro') 37 | 38 | def print_time_cost(since): 39 | time_elapsed = time.time() - since 40 | return '{:.0f}m{:.0f}s\n'.format(time_elapsed // 60, time_elapsed % 60) 41 | 42 | 43 | def adjust_learning_rate(optimizer, lr): 44 | for param_group in optimizer.param_groups: 45 | param_group['lr'] = lr 46 | return lr 47 | 48 | class WeightedMultilabel(nn.Module): 49 | def __init__(self, weights: torch.Tensor): 50 | super(WeightedMultilabel, self).__init__() 51 | self.cerition = nn.BCEWithLogitsLoss(reduction='none') 52 | self.weights = weights 53 | 54 | def forward(self, outputs, targets): 55 | loss = self.cerition(outputs, targets) 56 | return (loss * self.weights).mean() 57 | # ======================================= 58 | def sig_wt_filt(sig): 59 | """ 60 | 对信号进行小波变换滤波 61 | :param sig: 输入信号,1-d array 62 | :return: 小波滤波后的信号,1-d array 63 | 64 | """ 65 | coeffs = pywt.wavedec(sig, 'db6', level=9) 66 | coeffs[-1] = np.zeros(len(coeffs[-1])) 67 | coeffs[-2] = np.zeros(len(coeffs[-2])) 68 | coeffs[0] = np.zeros(len(coeffs[0])) 69 | sig_filt = pywt.waverec(coeffs, 'db6') 70 | return sig_filt 71 | 72 | 73 | def multi_prep(sig, target_point_num=1280): 74 | """ 75 | 信号预处理 76 | :param sig: 原始信号,1-d array 77 | :param target_point_num: 信号目标长度,int 78 | :return: 重采样并z-score标准化后的信号,1-d array 79 | """ 80 | assert len(sig.shape) == 2, 'Not for 1-D data.Use 2-D data.' 81 | sig = resample(sig, target_point_num, axis=1) 82 | for i in range(sig.shape[0]): 83 | sig[i] = sig_wt_filt(sig[i]) 84 | sig = scale(sig, axis=1) 85 | return sig 86 | 87 | 88 | def plot_confusion_matrix(y_true, y_pred, classes, 89 | normalize=False, 90 | title=None, 91 | cmap=plt.cm.Blues): 92 | """ 93 | 绘制混淆矩阵图,来源: 94 | https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py 95 | """ 96 | if not title: 97 | if normalize: 98 | title = 'Normalized confusion matrix' 99 | else: 100 | title = 'Confusion matrix, without normalization' 101 | 102 | cm = confusion_matrix(y_true, y_pred) 103 | 104 | classes = classes[unique_labels(y_true, y_pred)] 105 | if normalize: 106 | cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] 107 | print("Normalized confusion matrix") 108 | else: 109 | print('Confusion matrix, without normalization') 110 | 111 | print(cm) 112 | 113 | # fig, ax = plt.subplots() 114 | # # for i in range(5): 115 | # # cm[i,i] = 0 116 | # im = ax.imshow(cm, interpolation='nearest', cmap=cmap) 117 | # ax.figure.colorbar(im, ax=ax) 118 | # ax.set(xticks=np.arange(cm.shape[1]), 119 | # yticks=np.arange(cm.shape[0]), 120 | # xticklabels=classes, yticklabels=classes, 121 | # title=title, 122 | # ylabel='True label', 123 | # xlabel='Predicted label') 124 | # 125 | # plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 126 | # rotation_mode="anchor") 127 | # 128 | # fmt = '.2f' if normalize else 'd' 129 | # thresh = cm.max() / 2. 130 | # for i in range(cm.shape[0]): 131 | # for j in range(cm.shape[1]): 132 | # ax.text(j, i, format(cm[i, j], fmt), 133 | # ha="center", va="center", 134 | # color="white" if cm[i, j] > thresh else "black") 135 | # fig.tight_layout() 136 | return cm 137 | 138 | 139 | def print_results(y_true, y_pred, target_names): 140 | """ 141 | 打印相关结果 142 | :param y_true: 期望输出,1-d array 143 | :param y_pred: 实际输出,1-d array 144 | :param target_names: 各类别名称 145 | :return: 打印结果 146 | """ 147 | overall_accuracy = accuracy_score(y_true, y_pred) 148 | print('\n----- overall_accuracy: {0:f} -----'.format(overall_accuracy)) 149 | cm = confusion_matrix(y_true, y_pred) 150 | for i in range(len(target_names)): 151 | print(target_names[i] + ':') 152 | Se = cm[i][i]/np.sum(cm[i]) 153 | Pp = cm[i][i]/np.sum(cm[:, i]) 154 | print(' Se = ' + str(Se)) 155 | print(' P+ = ' + str(Pp)) 156 | print('--------------------------------------') 157 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed1d.pth', 16 | } 17 | 18 | dp_rate = 0 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv1d(in_planes, out_planes, kernel_size=33, stride=stride, 22 | padding=16, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 2 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.bn0 = nn.BatchNorm1d(inplanes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm1d(planes) 34 | self.conv2 = conv3x3(planes, planes*2) 35 | 36 | self.downsample = downsample 37 | self.stride = stride 38 | self.dropout = nn.Dropout(dp_rate) 39 | 40 | def forward(self, x): 41 | residual = x 42 | out = self.bn0(x) 43 | out = self.relu(out) 44 | # out = self.dropout(out) 45 | out = self.conv1(out) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | out = self.dropout(out) 49 | out = self.conv2(out) 50 | 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | # residual = torch.cat((residual,residual),1) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.bn0 = nn.BatchNorm1d(inplanes) 68 | self.conv1 = nn.Conv1d(inplanes, planes, kernel_size=33, bias=False, padding=16) 69 | self.bn1 = nn.BatchNorm1d(planes) 70 | self.conv2 = nn.Conv1d(planes, planes, kernel_size=65, stride=stride, 71 | padding=32, bias=False) 72 | self.bn2 = nn.BatchNorm1d(planes) 73 | self.conv3 = nn.Conv1d(planes, planes * 4, kernel_size=1, bias=False, padding=0) 74 | self.bn3 = nn.BatchNorm1d(planes * 4) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | self.dropout = nn.Dropout(dp_rate) 79 | 80 | def forward(self, x): 81 | residual = x 82 | out = self.bn0(x) 83 | out = self.relu(out) 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | out = self.dropout(out) 93 | 94 | out = self.conv3(out) 95 | # out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | # residual = torch.cat((residual, residual), 1) 100 | 101 | out += residual 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class ResNet(nn.Module): 108 | 109 | def __init__(self, block, layers, classification, num_classes=5): 110 | self.inplanes = 12 111 | self.classification = classification 112 | super(ResNet, self).__init__() 113 | self.conv1 = nn.Conv1d(1, self.inplanes, kernel_size=33, stride=1, padding=16, 114 | bias=False) 115 | self.bn1 = nn.BatchNorm1d(self.inplanes) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 118 | self.conv2 = nn.Conv1d(self.inplanes, self.inplanes, kernel_size=33, stride=2, padding=16, 119 | bias=False) 120 | self.bn2 = nn.BatchNorm1d(self.inplanes) 121 | self.downsample = nn.MaxPool1d(kernel_size=2, stride=2) 122 | self.conv3 = nn.Conv1d(self.inplanes, self.inplanes, kernel_size=33, stride=1, padding=16, 123 | bias=False) 124 | self.dropout = nn.Dropout(dp_rate) 125 | self.layer1 = self._make_layer(block, 12, layers[0], stride=2) 126 | self.layer2 = self._make_layer(block, 24, layers[1], stride=2) 127 | self.layer3 = self._make_layer(block, 48, layers[2], stride=2) 128 | self.layer4 = self._make_layer(block, 96, layers[3], stride=2) 129 | # self.layer5 = self._make_layer(block, self.inplanes, layers[4], stride=2) 130 | self.bn_final = nn.BatchNorm1d(96*2) 131 | self.avgpool = nn.AdaptiveAvgPool1d(2) 132 | self.fc1 = nn.Linear(96*4, 384) 133 | self.bn3 = nn.BatchNorm1d(384) 134 | self.fc2 = nn.Linear(384, 192) 135 | self.bn4 = nn.BatchNorm1d(192) 136 | self.fc3 = nn.Linear(192, 5) 137 | self.softmax = nn.Softmax(1) 138 | 139 | for m in self.modules(): 140 | if isinstance(m, nn.Conv1d): 141 | nn.init.kaiming_normal_(m.weight.data, mode='fan_in', nonlinearity='relu') 142 | elif isinstance(m, nn.BatchNorm1d): 143 | m.weight.data.fill_(1) 144 | m.bias.data.zero_() 145 | elif isinstance(m, nn.Linear): 146 | m.weight.data.normal_(0, 0.01) 147 | m.bias.data.zero_() 148 | 149 | def _make_layer(self, block, planes, blocks, stride=1): 150 | downsample = None 151 | if stride != 1: 152 | downsample = nn.Sequential( 153 | nn.Conv1d(self.inplanes, planes * block.expansion, 154 | kernel_size=1, stride=stride, bias=False), 155 | nn.BatchNorm1d(planes * block.expansion), 156 | ) 157 | 158 | layers = [] 159 | 160 | layers.append(block(self.inplanes, planes, stride, downsample)) 161 | self.inplanes = planes * block.expansion 162 | for _ in range(1, blocks): 163 | layers.append(block(self.inplanes, planes)) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def forward(self, x): 168 | x = self.conv1(x) 169 | x = self.bn1(x) 170 | x = self.relu(x) 171 | # x = self.maxpool(x) 172 | out = self.conv2(x) 173 | out = self.bn2(out) 174 | out = self.relu(out) 175 | out = self.dropout(out) 176 | out = self.conv3(out) 177 | residual = self.downsample(x) 178 | out += residual 179 | x = self.relu(out) 180 | 181 | x = self.layer1(x) 182 | x = self.layer2(x) 183 | x = self.layer3(x) 184 | x = self.layer4(x) 185 | # x = self.layer5(x) 186 | x = self.bn_final(x) 187 | x = self.avgpool(x) 188 | x = x.view(x.size(0), -1) 189 | if self.classification: 190 | x = self.fc1(x) 191 | x = self.bn3(x) 192 | x = self.relu(x) 193 | x = self.dropout(x) 194 | x = self.fc2(x) 195 | x = self.bn4(x) 196 | x = self.relu(x) 197 | x = self.dropout(x) 198 | x = self.fc3(x) 199 | # x = self.softmax(x) 200 | 201 | return x 202 | 203 | 204 | def resnet18(pretrained=False, **kwargs): 205 | """Constructs a ResNet-18 model. 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(BasicBlock, [ 2, 2, 2, 2], **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 212 | return model 213 | 214 | 215 | def resnet34(pretrained=False, **kwargs): 216 | """Constructs a ResNet-34 model. 217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | """ 220 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 221 | if pretrained: 222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 223 | return model 224 | 225 | 226 | def resnet50(pretrained=False, **kwargs): 227 | """Constructs a ResNet-50 model. 228 | Args: 229 | pretrained (bool): If True, returns a model pre-trained on ImageNet 230 | """ 231 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 232 | if pretrained: 233 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 234 | return model 235 | 236 | 237 | def resnet101(pretrained=False, **kwargs): 238 | """Constructs a ResNet-101 model. 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | """ 242 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 243 | if pretrained: 244 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 245 | return model 246 | 247 | 248 | def resnet152(pretrained=False, **kwargs): 249 | """Constructs a ResNet-152 model. 250 | Args: 251 | pretrained (bool): If True, returns a model pre-trained on ImageNet 252 | """ 253 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 254 | if pretrained: 255 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 256 | return model 257 | 258 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import signal 4 | 5 | import math 6 | import cv2 7 | 8 | import random 9 | class Transform: 10 | def __init__(self): 11 | pass 12 | 13 | 14 | 15 | 16 | def add_noise(self, signal, noise_amount): 17 | """ 18 | adding noise 19 | """ 20 | signal = signal.T 21 | noise = (0.4 ** 0.5) * np.random.normal(1, noise_amount, np.shape(signal)[0]) 22 | noise = noise[:,None] 23 | noised_signal = signal + noise 24 | noised_signal = noised_signal.T 25 | # print(noised_signal.shape) 26 | return noised_signal 27 | 28 | def add_noise_with_SNR(self,signal, noise_amount): 29 | """ 30 | adding noise 31 | created using: https://stackoverflow.com/a/53688043/10700812 32 | """ 33 | signal = signal[0] 34 | target_snr_db = noise_amount # 20 35 | x_watts = signal ** 2 # Calculate signal power and convert to dB 36 | sig_avg_watts = np.mean(x_watts) 37 | sig_avg_db = 10 * np.log10(sig_avg_watts) # Calculate noise then convert to watts 38 | noise_avg_db = sig_avg_db - target_snr_db 39 | noise_avg_watts = 10 ** (noise_avg_db / 10) 40 | mean_noise = 0 41 | noise_volts = np.random.normal(mean_noise, np.sqrt(noise_avg_watts), 42 | len(x_watts)) # Generate an sample of white noise 43 | noised_signal = signal + noise_volts # noise added signal 44 | noised_signal = noised_signal[None,:] 45 | # print(noised_signal.shape) 46 | 47 | return noised_signal 48 | 49 | def scaled(self,signal, factor_list): 50 | """" 51 | scale the signal 52 | """ 53 | factor = round(np.random.uniform(factor_list[0],factor_list[1]),2) 54 | signal[0] = 1 / (1 + np.exp(-signal[0])) 55 | # print(signal.max()) 56 | return signal 57 | 58 | def negate(self,signal): 59 | """ 60 | negate the signal 61 | """ 62 | signal[0] = signal[0] * (-1) 63 | return signal 64 | 65 | def hor_filp(self,signal): 66 | """ 67 | flipped horizontally 68 | """ 69 | hor_flipped = np.flip(signal,axis=1) 70 | return hor_flipped 71 | 72 | def permute(self,signal, pieces): 73 | """ 74 | signal: numpy array (batch x window) 75 | pieces: number of segments along time 76 | """ 77 | signal = signal.T 78 | pieces = int(np.ceil(np.shape(signal)[0] / (np.shape(signal)[0] // pieces)).tolist()) #向上取整 79 | piece_length = int(np.shape(signal)[0] // pieces) 80 | 81 | sequence = list(range(0, pieces)) 82 | np.random.shuffle(sequence) 83 | 84 | permuted_signal = np.reshape(signal[:(np.shape(signal)[0] // pieces * pieces)], 85 | (pieces, piece_length)).tolist() 86 | 87 | tail = signal[(np.shape(signal)[0] // pieces * pieces):] 88 | permuted_signal = np.asarray(permuted_signal)[sequence] 89 | permuted_signal = np.concatenate(permuted_signal, axis=0) 90 | permuted_signal = np.concatenate((permuted_signal,tail[:,0]), axis=0) 91 | permuted_signal = permuted_signal[:,None] 92 | permuted_signal = permuted_signal.T 93 | return permuted_signal 94 | 95 | def cutout_resize(self,signal,pieces): 96 | """ 97 | signal: numpy array (batch x window) 98 | pieces: number of segments along time 99 | cutout 1 piece 100 | """ 101 | signal = signal.T 102 | pieces = int(np.ceil(np.shape(signal)[0] / (np.shape(signal)[0] // pieces)).tolist()) # 向上取整 103 | piece_length = int(np.shape(signal)[0] // pieces) 104 | import random 105 | sequence = [] 106 | 107 | cutout = random.randint(0, pieces) 108 | # print(cutout) 109 | # sequence1 = list(range(0, cutout)) 110 | # sequence2 = list(range(int(cutout + 1), pieces)) 111 | # sequence = np.hstack((sequence1, sequence2)) 112 | for i in range(pieces): 113 | if i == cutout: 114 | pass 115 | else: 116 | sequence.append(i) 117 | # print(sequence) 118 | 119 | cutout_signal = np.reshape(signal[:(np.shape(signal)[0] // pieces * pieces)], 120 | (pieces, piece_length)).tolist() 121 | 122 | tail = signal[(np.shape(signal)[0] // pieces * pieces):] 123 | 124 | cutout_signal = np.asarray(cutout_signal)[sequence] 125 | 126 | cutout_signal = np.hstack(cutout_signal) 127 | cutout_signal = np.concatenate((cutout_signal, tail[:, 0]), axis=0) 128 | 129 | cutout_signal = cv2.resize(cutout_signal, (1, 3072), interpolation=cv2.INTER_LINEAR) 130 | cutout_signal = cutout_signal.T 131 | 132 | 133 | return cutout_signal 134 | 135 | def cutout_zero(self,signal,pieces): 136 | """ 137 | signal: numpy array (batch x window) 138 | pieces: number of segments along time 139 | cutout 1 piece 140 | """ 141 | signal = signal.T 142 | ones = np.ones((np.shape(signal)[0],np.shape(signal)[1])) 143 | # print(ones.shape) 144 | # assert False 145 | pieces = int(np.ceil(np.shape(signal)[0] / (np.shape(signal)[0] // pieces)).tolist()) # 向上取整 146 | piece_length = int(np.shape(signal)[0] // pieces) 147 | 148 | 149 | cutout = random.randint(1, pieces) 150 | cutout_signal = np.reshape(signal[:(np.shape(signal)[0] // pieces * pieces)], 151 | (pieces, piece_length)).tolist() 152 | ones_pieces = np.reshape(ones[:(np.shape(signal)[0] // pieces * pieces)], 153 | (pieces, piece_length)).tolist() 154 | tail = signal[(np.shape(signal)[0] // pieces * pieces):] 155 | 156 | cutout_signal = np.asarray(cutout_signal) 157 | ones_pieces = np.asarray(ones_pieces) 158 | for i in range(pieces): 159 | if i == cutout: 160 | ones_pieces[i]*=0 161 | 162 | cutout_signal = cutout_signal * ones_pieces 163 | cutout_signal = np.hstack(cutout_signal) 164 | cutout_signal = np.concatenate((cutout_signal, tail[:, 0]), axis=0) 165 | cutout_signal = cutout_signal[:,None] 166 | cutout_signal = cutout_signal.T 167 | 168 | return cutout_signal 169 | # mic 170 | def crop_resize(self, signal, size): 171 | signal = signal.T 172 | size = signal.shape[0] * size 173 | size = int(size) 174 | start = random.randint(0, signal.shape[0]-size) 175 | crop_signal = signal[start:start + size,:] 176 | # print(crop_signal.shape) 177 | 178 | crop_signal = cv2.resize(crop_signal, (1, 3072), interpolation=cv2.INTER_LINEAR) 179 | # print(crop_signal.shape) 180 | crop_signal = crop_signal.T 181 | return crop_signal 182 | 183 | def move_avg(self,a,n, mode="same"): 184 | # a = a.T 185 | 186 | result = np.convolve(a[0], np.ones((n,)) / n, mode=mode) 187 | return result[None,:] 188 | 189 | def bandpass_filter(self, x, order, cutoff, fs=100): 190 | result = np.zeros((x.shape[0], x.shape[1])) 191 | w1 = 2 * cutoff[0] / int(fs) 192 | w2 = 2 * cutoff[1] / int(fs) 193 | b, a = signal.butter(order, [w1, w2], btype='bandpass') # 配置滤波器 8 表示滤波器的阶数 194 | result = signal.filtfilt(b, a, x, axis=1) 195 | # print(result.shape) 196 | 197 | return result 198 | 199 | def lowpass_filter(self, x, order, cutoff, fs=100): 200 | result = np.zeros((x.shape[0], x.shape[1])) 201 | w1 = 2 * cutoff[0] / int(fs) 202 | # w2 = 2 * cutoff[1] / fs 203 | b, a = signal.butter(order, w1, btype='lowpass') # 配置滤波器 8 表示滤波器的阶数 204 | result = signal.filtfilt(b, a, x, axis=1) 205 | # print(result.shape) 206 | 207 | return result 208 | 209 | def highpass_filter(self, x, order, cutoff, fs=100): 210 | result = np.zeros((x.shape[0], x.shape[1])) 211 | w1 = 2 * cutoff[0] / int(fs) 212 | # w2 = 2 * cutoff[1] / fs 213 | b, a = signal.butter(order, w1, btype='highpass') # 配置滤波器 8 表示滤波器的阶数 214 | result = signal.filtfilt(b, a, x, axis=1) 215 | # print(result.shape) 216 | 217 | return result 218 | 219 | 220 | def time_warp(self,signal, sampling_freq, pieces, stretch_factor, squeeze_factor): 221 | """ 222 | signal: numpy array (batch x window) 223 | sampling freq 224 | pieces: number of segments along time 225 | stretch factor 226 | squeeze factor 227 | """ 228 | signal = signal.T 229 | 230 | total_time = np.shape(signal)[0] // sampling_freq 231 | segment_time = total_time / pieces 232 | sequence = list(range(0, pieces)) 233 | stretch = np.random.choice(sequence, math.ceil(len(sequence) / 2), replace=False) 234 | squeeze = list(set(sequence).difference(set(stretch))) 235 | initialize = True 236 | for i in sequence: 237 | orig_signal = signal[int(i * np.floor(segment_time * sampling_freq)):int( 238 | (i + 1) * np.floor(segment_time * sampling_freq))] 239 | orig_signal = orig_signal.reshape(np.shape(orig_signal)[0], 1) 240 | if i in stretch: 241 | output_shape = int(np.ceil(np.shape(orig_signal)[0] * stretch_factor)) 242 | new_signal = cv2.resize(orig_signal, (1, output_shape), interpolation=cv2.INTER_LINEAR) 243 | if initialize == True: 244 | time_warped = new_signal 245 | initialize = False 246 | else: 247 | time_warped = np.vstack((time_warped, new_signal)) 248 | elif i in squeeze: 249 | output_shape = int(np.ceil(np.shape(orig_signal)[0] * squeeze_factor)) 250 | new_signal = cv2.resize(orig_signal, (1, output_shape), interpolation=cv2.INTER_LINEAR) 251 | if initialize == True: 252 | time_warped = new_signal 253 | initialize = False 254 | else: 255 | time_warped = np.vstack((time_warped, new_signal)) 256 | time_warped = cv2.resize(time_warped, (1,3072), interpolation=cv2.INTER_LINEAR) 257 | time_warped = time_warped.T 258 | return time_warped 259 | 260 | if __name__ == '__main__': 261 | from transform import Transform 262 | import matplotlib.pyplot as plt 263 | Trans = Transform() 264 | input = np.zeros((1,3072)) 265 | input = Trans.add_noise(input,10) 266 | plt.subplot(211) 267 | plt.plot(input[0]) 268 | 269 | # print(input.shape) 270 | # output = Trans.cutout_resize(input,10) 271 | order = random.randint(3, 10) 272 | cutoff = random.uniform(5, 20) 273 | output = Trans.filter(input, order, [2,15], mode='lowpass') 274 | plt.subplot(212) 275 | plt.plot(output[0]) 276 | plt.savefig('filter.png') 277 | # print(output.shape) 278 | --------------------------------------------------------------------------------