├── .gitignore
├── viz.py
├── utils
│   ├── trace_utils.py
│   ├── lr_sched.py
│   ├── helpers.py
│   ├── format.py
│   ├── crop.py
│   ├── lars.py
│   ├── datasets.py
│   ├── ecg_dataloader.py
│   ├── lr_decay.py
│   ├── pos_embed.py
│   ├── ecg_multilabel.py
│   ├── patch_embed.py
│   ├── misc.py
│   └── utils.py
├── README.md
├── adap_weight.py
├── models_vit.py
├── test.py
├── vit_model.py
├── engine_pretrain.py
├── engine_finetune.py
├── main_pretrain.py
├── models_mae.py
├── main_linprobe.py
└── main_finetune.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | dl/
 2 | output_dir/
 3 | __pycache__/
 4 | output_dir_fin/
 5 | output_dir_fin_f90/
 6 | fin_0_40k/
 7 | fin_360_40k/
 8 | checkpoint-86.pth/
 9 | Results/
10 | Checkpoints/
11 | ptb_xl/
12 | ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3.zip/
13 | ptb_benchmark/
--------------------------------------------------------------------------------
/viz.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | # make_dot was moved to https://github.com/szagoruyko/pytorchviz
 4 | from torchviz import make_dot
 5 | from models_mae import mae_vit_1dcnn
 6 | 
 7 | x = torch.rand(10, 1, 12, 1000)
 8 | model = mae_vit_1dcnn()
 9 | # the model returns (mae_loss, pred, mask, disc_loss, adv_loss, corrupt_img);
10 | # render the autograd graph of the reconstruction loss
11 | outputs = model(x, mask_ratio=0.75)
12 | make_dot(outputs[0], params=dict(model.named_parameters())).render('mae_graph')
--------------------------------------------------------------------------------
/utils/trace_utils.py:
--------------------------------------------------------------------------------
 1 | try:
 2 |     from torch import _assert
 3 | except ImportError:
 4 |     def _assert(condition: bool, message: str):
 5 |         assert condition, message
 6 | 
 7 | 
 8 | def _float_to_int(x: float) -> int:
 9 |     """
10 |     Symbolic tracing helper to substitute for inbuilt `int`.
11 |     Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 |     """
13 |     return int(x)
--------------------------------------------------------------------------------
/utils/lr_sched.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | import math
 8 | 
 9 | def adjust_learning_rate(optimizer, epoch, args):
10 |     """Decay the learning rate with half-cycle cosine after warmup"""
11 |     if epoch < args.warmup_epochs:
12 |         lr = args.lr * epoch / args.warmup_epochs
13 |     else:
14 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 |             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 |     for param_group in optimizer.param_groups:
17 |         if "lr_scale" in param_group:
18 |             param_group["lr"] = lr * param_group["lr_scale"]
19 |         else:
20 |             param_group["lr"] = lr
21 |     return lr
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Masked Autoencoder meets GAN for ECG
 2 | 
 3 | PyTorch implementation of [Masked Auto-Encoders Meet Generative Adversarial Networks and Beyond](https://feizc.github.io/resume/ganmae.pdf) for ECG signals.
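The models in this repo operate on 12-lead ECGs resampled to 100 Hz and shaped `(batch, 1, 12, 1000)` (see `utils/ecg_dataloader.py` and the `img_size=(12, 1000)` default in `vit_model.py`). A minimal smoke test, assuming the default `mae_vit_1dcnn` constructor and the output tuple unpacked in `engine_pretrain.py`:

```python
import torch
from models_mae import mae_vit_1dcnn

model = mae_vit_1dcnn()
x = torch.rand(2, 1, 12, 1000)  # (batch, in_chans, leads, samples)
mae_loss, pred, mask, disc_loss, adv_loss, corrupt_img = model(x, mask_ratio=0.75)
```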
 4 | 
 5 | To pretrain, run:
 6 | 
 7 | ```
 8 | python main_pretrain.py \
 9 |     --batch_size 64 \
10 |     --norm_pix_loss \
11 |     --mask_ratio 0.75 \
12 |     --epochs 500 \
13 |     --warmup_epochs 10 \
14 |     --data_path ${PHYSIONET_DIR} \
15 |     --lr 1e-3 \
16 |     --cuda "CUDA"
17 | 
18 | ```
19 | 
20 | Set `--data_path` to the PhysioNet dataset root; passing any non-empty string to `--cuda` enables GPU training.
21 | 
22 | E.g., if the path to the PhysioNet dataset is
23 | 
24 | /Users/parthagrawal02/Desktop/ECG_CNN/physionet/WFDBRecords
25 | 
26 | then pass `--data_path '/Users/parthagrawal02/Desktop/ECG_CNN/physionet'`.
27 | 
28 | To finetune:
29 | 
30 | ```
31 | python /kaggle/working/ECG_MAE/main_finetune.py \
32 |     --model vit_1dcnn \
33 |     --finetune '/checkpoint-360.pth' \
34 |     --epochs 70 \
35 |     --lr 5e-3 \
36 |     --data_path /Users/parthagrawal02/Desktop/ECG_CNN/physionet \
37 |     --cuda 'CUDA' \
38 |     --train_start 0 --train_end 46 --data_split 0.85
39 | ```
40 | 
41 | Modify `utils/ecg_dataloader.py` to match your dataset.
--------------------------------------------------------------------------------
/utils/helpers.py:
--------------------------------------------------------------------------------
 1 | """ Layer/Module Helpers
 2 | 
 3 | Hacked together by / Copyright 2020 Ross Wightman
 4 | """
 5 | from itertools import repeat
 6 | import collections.abc
 7 | 
 8 | 
 9 | # From PyTorch internals
10 | def _ntuple(n):
11 |     def parse(x):
12 |         if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13 |             return tuple(x)
14 |         return tuple(repeat(x, n))
15 |     return parse
16 | 
17 | 
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 | 
24 | 
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 |     min_value = min_value or divisor
27 |     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 |     # Make sure that round down does not go down by more than 10%.
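    # e.g. (illustrative): make_divisible(150, divisor=64) -> 192, since plain
    # rounding to a multiple of 64 would give 128, a drop of more than 10%.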
29 |     if new_v < round_limit * v:
30 |         new_v += divisor
31 |     return new_v
32 | 
33 | 
34 | def extend_tuple(x, n):
35 |     # pads a tuple to length n by repeating the last value
36 |     if not isinstance(x, (tuple, list)):
37 |         x = (x,)
38 |     else:
39 |         x = tuple(x)
40 |     pad_n = n - len(x)
41 |     if pad_n <= 0:
42 |         return x[:n]
43 |     return x + (x[-1],) * pad_n
--------------------------------------------------------------------------------
/utils/format.py:
--------------------------------------------------------------------------------
 1 | from enum import Enum
 2 | from typing import Union
 3 | 
 4 | import torch
 5 | 
 6 | 
 7 | class Format(str, Enum):
 8 |     NCHW = 'NCHW'
 9 |     NHWC = 'NHWC'
10 |     NCL = 'NCL'
11 |     NLC = 'NLC'
12 | 
13 | 
14 | FormatT = Union[str, Format]
15 | 
16 | 
17 | def get_spatial_dim(fmt: FormatT):
18 |     fmt = Format(fmt)
19 |     if fmt is Format.NLC:
20 |         dim = (1,)
21 |     elif fmt is Format.NCL:
22 |         dim = (2,)
23 |     elif fmt is Format.NHWC:
24 |         dim = (1, 2)
25 |     else:
26 |         dim = (2, 3)
27 |     return dim
28 | 
29 | 
30 | def get_channel_dim(fmt: FormatT):
31 |     fmt = Format(fmt)
32 |     if fmt is Format.NHWC:
33 |         dim = 3
34 |     elif fmt is Format.NLC:
35 |         dim = 2
36 |     else:
37 |         dim = 1
38 |     return dim
39 | 
40 | 
41 | def nchw_to(x: torch.Tensor, fmt: Format):
42 |     if fmt == Format.NHWC:
43 |         x = x.permute(0, 2, 3, 1)
44 |     elif fmt == Format.NLC:
45 |         x = x.flatten(2).transpose(1, 2)
46 |     elif fmt == Format.NCL:
47 |         x = x.flatten(2)
48 |     return x
49 | 
50 | 
51 | def nhwc_to(x: torch.Tensor, fmt: Format):
52 |     if fmt == Format.NCHW:
53 |         x = x.permute(0, 3, 1, 2)
54 |     elif fmt == Format.NLC:
55 |         x = x.flatten(1, 2)
56 |     elif fmt == Format.NCL:
57 |         x = x.flatten(1, 2).transpose(1, 2)
58 |     return x
--------------------------------------------------------------------------------
/utils/crop.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | 
 7 | import math
 8 | 
 9 | import torch
10 | 
11 | from torchvision import transforms
12 | from torchvision.transforms import functional as F
13 | 
14 | 
15 | class RandomResizedCrop(transforms.RandomResizedCrop):
16 |     """
17 |     RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18 |     This may lead to results different from torchvision's version.
19 |     Following BYOL's TF code:
20 |     https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21 |     """
22 |     @staticmethod
23 |     def get_params(img, scale, ratio):
24 |         width, height = F._get_image_size(img)
25 |         area = height * width
26 | 
27 |         target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
28 |         log_ratio = torch.log(torch.tensor(ratio))
29 |         aspect_ratio = torch.exp(
30 |             torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
31 |         ).item()
32 | 
33 |         w = int(round(math.sqrt(target_area * aspect_ratio)))
34 |         h = int(round(math.sqrt(target_area / aspect_ratio)))
35 | 
36 |         w = min(w, width)
37 |         h = min(h, height)
38 | 
39 |         i = torch.randint(0, height - h + 1, size=(1,)).item()
40 |         j = torch.randint(0, width - w + 1, size=(1,)).item()
41 | 
42 |         return i, j, h, w
--------------------------------------------------------------------------------
/adap_weight.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | 
 4 | 
 5 | def aw_loss(L_mae, L_adv, Gen_opt, Gen_net):
 6 |     """Combine the MAE reconstruction loss and the adversarial loss with an
 7 |     adaptive weight (the ratio of their gradient norms), following GAN-MAE."""
 8 |     # reset gradients to zero
 9 |     Gen_opt.zero_grad()
10 | 
11 |     # compute the gradient of the reconstruction (MAE) loss
12 |     L_mae.backward(retain_graph=True)
13 |     # tensors holding the MAE-loss gradients
14 |     grad_real_tensor = [param.grad.clone() for _, param in Gen_net.named_parameters() if param.grad is not None]
15 |     grad_real_list = torch.cat([grad.reshape(-1) for grad in grad_real_tensor], dim=0)
16 |     # norm of the MAE-loss gradient
17 |     rdotr = torch.dot(grad_real_list, grad_real_list).item()
18 |     mae_norm = np.sqrt(rdotr)
19 |     # reset gradients to zero
20 |     Gen_opt.zero_grad()
21 | 
22 |     # compute the gradient of the adversarial loss
23 |     L_adv.backward(retain_graph=True)
24 |     # tensors holding the adversarial-loss gradients
25 |     grad_fake_tensor = [param.grad.clone() for _, param in Gen_net.named_parameters() if param.grad is not None]
26 |     grad_fake_list = torch.cat([grad.reshape(-1) for grad in grad_fake_tensor], dim=0)
27 |     # norm of the adversarial-loss gradient
28 |     fdotf = torch.dot(grad_fake_list, grad_fake_list).item() + 1e-6  # 1e-6 added to avoid division by zero
29 |     adv_norm = np.sqrt(fdotf)
30 | 
31 |     # reset gradients to zero
32 |     Gen_opt.zero_grad()
33 | 
34 |     # ratio of the gradient norms: scales the adversarial term so both losses
35 |     # contribute gradients of comparable magnitude
36 |     adaptive_weight = mae_norm / adv_norm
37 | 
38 |     # combined adaptively weighted loss
39 |     aw_loss = L_mae + adaptive_weight * L_adv
40 | 
41 |     # write the combined gradient back into the parameters
42 |     for index, (_, param) in enumerate(Gen_net.named_parameters()):
43 |         if param.grad is not None:
44 |             param.grad = grad_real_tensor[index] + adaptive_weight * grad_fake_tensor[index]
45 | 
46 |     return aw_loss
--------------------------------------------------------------------------------
/utils/lars.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /utils/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | 22 | transform = build_transform(is_train, args) 23 | 24 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 25 | dataset = datasets.ImageFolder(root, transform=transform) 26 | 27 | print(dataset) 28 | 29 | return dataset 30 | 31 | 32 | def build_transform(is_train, args): 33 | mean = IMAGENET_DEFAULT_MEAN 34 | std = IMAGENET_DEFAULT_STD 35 | # train transform 36 | if is_train: 37 | # this should always dispatch to transforms_imagenet_train 38 | transform = create_transform( 39 | input_size=args.input_size, 40 | is_training=True, 41 | color_jitter=args.color_jitter, 42 | auto_augment=args.aa, 43 | interpolation='bicubic', 44 | re_prob=args.reprob, 45 | re_mode=args.remode, 46 | re_count=args.recount, 47 | mean=mean, 48 | std=std, 49 | ) 50 | return transform 51 | 52 | # eval transform 53 | t = [] 54 | if args.input_size <= 224: 55 | crop_pct = 224 / 256 56 | else: 57 | crop_pct = 1.0 58 | size = int(args.input_size / crop_pct) 59 | t.append( 60 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 
224 images 61 | ) 62 | t.append(transforms.CenterCrop(args.input_size)) 63 | 64 | t.append(transforms.ToTensor()) 65 | t.append(transforms.Normalize(mean, std)) 66 | return transforms.Compose(t) 67 | -------------------------------------------------------------------------------- /models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | import timm.models.vision_transformer 18 | 19 | 20 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 21 | """ Vision Transformer with support for global average pooling 22 | """ 23 | def __init__(self, global_pool=False, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | self.global_pool = global_pool 27 | if self.global_pool: 28 | norm_layer = kwargs['norm_layer'] 29 | embed_dim = kwargs['embed_dim'] 30 | self.fc_norm = norm_layer(embed_dim) 31 | 32 | del self.norm # remove the original norm 33 | 34 | def forward_features(self, x): 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | x = torch.cat((cls_tokens, x), dim=1) 40 | x = x + self.pos_embed 41 | x = self.pos_drop(x) 42 | 43 | for blk in self.blocks: 44 | x = blk(x) 45 | 46 | if self.global_pool: 47 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 48 | outcome = self.fc_norm(x) 49 | else: 50 | x = self.norm(x) 51 | outcome = x[:, 0] 52 | 53 | return outcome 54 | # VIT 55 | def vit_ecg_patch_50(**kwargs): 56 | model = VisionTransformer( 57 | patch_size=(1, 50), embed_dim=128, depth=6, num_heads=8, 58 | mlp_ratio=3, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 59 | return model 60 | 61 | def vit_1dcnn(**kwargs): 62 | model = VisionTransformer( 63 | patch_size=(1, 50), embed_dim=128, depth=6, num_heads=8, 64 | mlp_ratio=3, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 65 | return model -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from tsai.all import * 2 | import torch 3 | import numpy as np 4 | """ 5 | c_in - No. of Channels 6 | c_out - No. 
of target classes - 2 for binary classification 7 | seq_len - 10*frequency 8 | d_model - 128 9 | depth - 6 10 | n_heads - 8 11 | mlp_ratio - 3 12 | token_size - 50 13 | tokenizer - 14 | 15 | TSiTPlus(c_in:12, c_out:1, seq_len:int, d_model:int=128, depth:int=6, 16 | n_heads:int=16, act:str='gelu', lsa:bool=False, 17 | attn_dropout:float=0.0, dropout:float=0.0, 18 | drop_path_rate:float=0.0, mlp_ratio:int=1, qkv_bias:bool=True, 19 | pre_norm:bool=False, use_token:bool=False, use_pe:bool=True, 20 | cat_pos:Optional[list]=None, n_cat_embeds:Optional[list]=None, 21 | cat_embed_dims:Optional[list]=None, 22 | cat_padding_idxs:Optional[list]=None, token_size:int=None, 23 | tokenizer:Optional[Callable]=None, 24 | feature_extractor:Optional[Callable]=None, flatten:bool=False, 25 | concat_pool:bool=True, fc_dropout:float=0.0, use_bn:bool=False, 26 | bias_init:Union[float,list,NoneType]=None, 27 | y_range:Optional[tuple]=None, 28 | custom_head:Optional[Callable]=None, verbose:bool=True, 29 | **kwargs) 30 | 31 | """ 32 | c_in = 12 33 | c_out = 2 34 | seq_len = 1000 35 | bs = 16 36 | """ 37 | Model parameters : 38 | 39 | MAE - Pretrained ViT size 40 | d_model = 128 41 | depth = 6 42 | n_heads = 8 43 | mlp_ratio = 3 44 | token_size = 50 45 | No. of Model Parameters - 1.08M 46 | 47 | VIT Base parameters 48 | d_model = 768 49 | depth = 12 50 | n_heads = 12 51 | mlp_ratio = 4 52 | token_size = 50 53 | No. of Model Parameters - 86M 54 | 55 | ViT Large Parameters 56 | d_model = 1024 57 | depth = 24 58 | n_heads = 16 59 | mlp_ratio = 4 60 | token_size = 50 61 | No. of Model Parameters - 307M 62 | 63 | ViT Huge Parameters 64 | d_model = 1280 65 | depth = 32 66 | n_heads = 16 67 | mlp_ratio = 4 68 | token_size = 50 69 | No. of Model Parameters - 630M 70 | """ 71 | 72 | d_model = 1280 73 | depth = 32 74 | n_heads = 16 75 | mlp_ratio = 4 76 | token_size = 50 77 | 78 | xb = torch.rand(bs, c_in, seq_len) 79 | bias_init = np.array([0.8, .2]) 80 | model = TSiTPlus(c_in = c_in, c_out=c_out, seq_len=seq_len,d_model=d_model, depth = depth, n_heads = n_heads, mlp_ratio = mlp_ratio, token_size=token_size) 81 | 82 | # test_eq(model.head[1].bias.data, tensor(bias_init)) 83 | def count_parameters(model): 84 | return sum(p.numel() for p in model.parameters()) 85 | 86 | print(count_parameters(model)) -------------------------------------------------------------------------------- /utils/ecg_dataloader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | import re 6 | import wfdb 7 | from wfdb import processing 8 | import pdb 9 | 10 | class CustomDataset(Dataset): 11 | def __init__(self, data_path: str = "", start: int = 0, end: int = 46): 12 | self.class_map = { 13 | 426177001: 1, 14 | 426783006: 2, 15 | 164889003: 3, 16 | 427084000: 4, 17 | 164890007: 5, 18 | 427393009: 6, 19 | 426761007: 7, 20 | 713422000: 8, 21 | 233896004: 9, 22 | 233897008: 0 23 | } 24 | self.data_path = data_path 25 | self.data = [] 26 | y = [] 27 | for n in range(start, end): 28 | for j in range(0, 10): 29 | for filepath in glob.iglob(self.data_path + '/WFDBRecords/' + f"{n:02}" + '/' + f"{n:02}" + str(j) + '/*.hea'): 30 | try: 31 | ecg_record = wfdb.rdsamp(filepath[:-4]) 32 | except Exception: 33 | continue 34 | # pdb.set_trace() 35 | if(np.isnan(ecg_record[0]).any()): 36 | print(str(filepath)) 37 | continue 38 | numbers = re.findall(r'\d+', ecg_record[1]['comments'][2]) 39 | output_array = list(map(int, 
numbers)) 40 | for j in output_array: # Only classify into one of the predecided classes. 41 | if int(j) in self.class_map: 42 | output_array = j 43 | if isinstance(output_array, list): 44 | continue 45 | y.append(output_array) 46 | self.data.append([filepath, output_array]) 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, idx): 51 | ecg_path, class_name = self.data[idx] 52 | ecg_record = wfdb.rdsamp(ecg_path[:-4]) 53 | lx = [] 54 | for chan in range(ecg_record[0].shape[1]): 55 | resampled_x, _ = wfdb.processing.resample_sig(ecg_record[0][:, chan], 500, 100) 56 | lx.append(resampled_x) 57 | 58 | class_id = self.class_map[class_name] 59 | ecg_tensor = torch.from_numpy(np.array(lx)) 60 | img_tensor = ecg_tensor[None, :, :] 61 | mean = img_tensor.mean(dim=-1, keepdim=True) 62 | var = img_tensor.var(dim=-1, keepdim=True) 63 | img_tensor = (img_tensor - mean) / (var + 1.e-6)**.5 64 | class_id = torch.tensor([class_id]) 65 | return img_tensor, class_id -------------------------------------------------------------------------------- /utils/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /utils/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size[0], dtype=np.float32) 27 | grid_w = np.arange(grid_size[1], dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | # -------------------------------------------------------- 38 | # Interpolate position embeddings for high-resolution 39 | # References: 40 | # DeiT: https://github.com/facebookresearch/deit 41 | # -------------------------------------------------------- 42 | def interpolate_pos_embed(model, checkpoint_model): 43 | if 'pos_embed' in checkpoint_model: 44 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 45 | embedding_size = pos_embed_checkpoint.shape[-1] 46 | num_patches = model.patch_embed.num_patches 47 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 48 | # height (== width) for the checkpoint position embedding 49 | 
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
50 |         # height (== width) for the new position embedding
51 |         new_size = int(num_patches ** 0.5)
52 |         # class_token and dist_token are kept unchanged
53 |         if orig_size != new_size:
54 |             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
55 |             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
56 |             # only the position tokens are interpolated
57 |             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
58 |             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
59 |             pos_tokens = torch.nn.functional.interpolate(
60 |                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
61 |             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
62 |             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
63 |             checkpoint_model['pos_embed'] = new_pos_embed
64 | 
65 | 
66 | 
67 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
68 |     assert embed_dim % 2 == 0
69 | 
70 |     # use half of dimensions to encode grid_h
71 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
72 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
73 | 
74 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
75 |     return emb
76 | 
77 | 
78 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
79 |     """
80 |     embed_dim: output dimension for each position
81 |     pos: a list of positions to be encoded: size (M,)
82 |     out: (M, D)
83 |     """
84 |     assert embed_dim % 2 == 0
85 |     omega = np.arange(embed_dim // 2, dtype=np.float32)
86 |     omega /= embed_dim / 2.
87 |     omega = 1. / 10000**omega  # (D/2,)
88 | 
89 |     pos = pos.reshape(-1)  # (M,)
90 |     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
91 | 
92 |     emb_sin = np.sin(out)  # (M, D/2)
93 |     emb_cos = np.cos(out)  # (M, D/2)
94 | 
95 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
96 |     return emb
97 | 
98 | 
--------------------------------------------------------------------------------
/utils/ecg_multilabel.py:
--------------------------------------------------------------------------------
 1 | import ast
 2 | import numpy as np
 3 | import pandas as pd
 4 | import torch
 5 | from torch.utils.data import Dataset, DataLoader
 6 | import wfdb
 7 | 
 8 | 
 9 | def multihot_encoder(labels, dtype=torch.float32):
10 |     """Encode a list of label lists as multi-hot vectors."""
11 |     label_set = set()
12 |     for label_list in labels:
13 |         label_set = label_set.union(set(label_list))
14 |     label_set = sorted(label_set)
15 | 
16 |     multihot_vectors = []
17 |     for label_list in labels:
18 |         multihot_vectors.append([1 if x in label_list else 0 for x in label_set])
19 |     if dtype is None:
20 |         return pd.DataFrame(multihot_vectors, columns=label_set)
21 |     return torch.Tensor(multihot_vectors).to(dtype)
22 | 
23 | 
24 | class CustomDataset(Dataset):
25 |     """PTB-XL multi-label dataset: yields one normalized signal tensor and one
26 |     multi-hot vector of diagnostic superclasses per record."""
27 | 
28 |     def __init__(self, data_path: str = "", sampling_rate: int = 100,
29 |                  train: bool = True, test_fold: int = 10):
30 |         self.data_path = data_path
31 | 
32 |         # load and convert annotation data
33 |         Y = pd.read_csv(data_path + 'ptbxl_database.csv', index_col='ecg_id')
34 |         Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
35 | 
36 |         # load scp_statements.csv for diagnostic aggregation
37 |         agg_df = pd.read_csv(data_path + 'scp_statements.csv', index_col=0)
38 |         agg_df = agg_df[agg_df.diagnostic == 1]
39 | 
40 |         def aggregate_diagnostic(y_dic):
41 |             tmp = []
42 |             for key in y_dic.keys():
43 |                 if key in agg_df.index:
44 |                     tmp.append(agg_df.loc[key].diagnostic_class)
45 |             return list(set(tmp))
46 | 
47 |         # apply diagnostic superclass
48 |         Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)
49 | 
50 |         # recommended split: fold 10 is the test fold
51 |         mask = (Y.strat_fold != test_fold) if train else (Y.strat_fold == test_fold)
52 |         Y = Y[mask]
53 | 
54 |         filenames = Y.filename_lr if sampling_rate == 100 else Y.filename_hr
55 |         self.records = [data_path + f for f in filenames]
56 |         self.labels = multihot_encoder(Y.diagnostic_superclass)
57 | 
58 |     def __len__(self):
59 |         return len(self.records)
60 | 
61 |     def __getitem__(self, idx):
62 |         signal, meta = wfdb.rdsamp(self.records[idx])
63 |         ecg_tensor = torch.from_numpy(signal.T)  # (12, seq_len)
64 |         img_tensor = ecg_tensor[None, :, :]      # (1, 12, seq_len)
65 |         # per-lead normalization
66 |         mean = img_tensor.mean(dim=-1, keepdim=True)
67 |         var = img_tensor.var(dim=-1, keepdim=True)
68 |         img_tensor = (img_tensor - mean) / (var + 1.e-6)**.5
69 |         return img_tensor, self.labels[idx]
--------------------------------------------------------------------------------
/vit_model.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
 2 | # All rights reserved.
 3 | 
 4 | # This source code is licensed under the license found in the
 5 | # LICENSE file in the root directory of this source tree.
 6 | # --------------------------------------------------------
 7 | # References:
 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
 9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 | 
12 | from functools import partial
13 | 
14 | import torch
15 | import torch.nn as nn
16 | 
17 | from timm.models.vision_transformer import Block
18 | from utils.patch_embed import PatchEmbed
19 | 
20 | from utils.pos_embed import get_2d_sincos_pos_embed
21 | 
22 | 
23 | # Main changes: img_size adjusted to the 12-channel ECG signal - 12*1000
24 | # Functions - patchify and unpatchify
25 | # The other functions remain the same
26 | 
27 | 
28 | class VisionTransformer(nn.Module):
29 |     """ Vision Transformer for 12-lead ECG classification (same backbone as the MAE encoder)
30 |     """
31 |     def __init__(self, img_size=(12, 1000), patch_size=(1, 50), in_chans=1,
32 |                  embed_dim=128, depth=6, num_heads=8,
33 |                  mlp_ratio=3., norm_layer=nn.LayerNorm, norm_pix_loss=False, num_classes=10, global_pool=False, drop_rate=0):
34 |         super().__init__()
35 |         use_fc_norm = global_pool
36 |         self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
37 |         num_patches = self.patch_embed.num_patches
38 | 
39 |         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
40 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
41 | 
42 |         self.blocks = nn.ModuleList([
43 |             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
44 |             for i in range(depth)])
45 | 
46 |         self.global_pool = global_pool
47 |         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
48 |         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
49 |         self.head_drop = nn.Dropout(drop_rate)
50 |         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
51 | 
52 | 
53 |     @torch.jit.ignore
54 |     def no_weight_decay(self):
55 |         return {'pos_embed', 'cls_token', 'dist_token'}
56 | 
57 |     def forward(self, x):
58 |         # embed patches
59 |         x = self.patch_embed(x)
60 | 
61 |         # add pos embed w/o cls token
62 | 
63 | 
64 |         x = x + self.pos_embed[:, 1:, :]
65 | 
66 |         # append cls token
67 |         cls_token = self.cls_token + self.pos_embed[:, :1, :]
68 |         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
69 |         x = torch.cat((cls_tokens, x), dim=1)
70 | 
71 |         # apply Transformer blocks
72 |         for blk in self.blocks:
73 |             x = blk(x)
74 | 
75 |         if self.global_pool:
76 |             x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
77 |             # outcome = self.fc_norm(x)
78 |         else:
79 |             x = self.norm(x)
80 |             x = x[:, 0]
81 |         x = self.fc_norm(x)
82 |         x = self.head_drop(x)
83 |         return self.head(x)
84 | 
85 | 
86 | class Discriminator(nn.Module):
87 |     """ Real/fake discriminator built on the ViT backbone
88 |     """
89 |     def __init__(self, img_size=(12, 1000), patch_size=(1, 50), in_chans=1,
90 |                  embed_dim=128, depth=6, num_heads=8,
91 |                  mlp_ratio=3., norm_layer=nn.LayerNorm, norm_pix_loss=False, num_classes=10, global_pool=False, drop_rate=0):
92 |         super().__init__()
93 |         self.encoder = vit_1dcnn()
94 |         self.output_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
95 |         self.linear = nn.Linear(self.output_shape[0] * self.output_shape[1] * 128, 1)  # one real/fake logit; in_features assumes the encoder yields flattened per-patch tokens (grid_h * grid_w * embed_dim)
96 | 
97 |     def forward(self, x):
98 |         # NOTE: for the flatten below to match self.linear, the encoder must return
99 |         # per-patch token features of shape (B, grid_h * grid_w, embed_dim)
100 |         x = self.encoder(x)
101 |         x = x.view(x.size(0), -1)
102 |         return torch.sigmoid(self.linear(x))
103 | 
104 | # Model architecture as described in the paper.
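# A usage sketch (illustration only; shapes follow the defaults above):
#   model = vit_1dcnn(num_classes=10)             # (12, 1000) ECG, (1, 50) patches -> 240 tokens
#   logits = model(torch.randn(2, 1, 12, 1000))   # -> (2, 10)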
105 | def vit_1dcnn(**kwargs): 106 | model = VisionTransformer( 107 | patch_size=(1, 50), embed_dim=128, depth=6, num_heads=8, 108 | mlp_ratio=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 109 | return model 110 | 111 | 112 | # def generator(**kwargs): 113 | # model = VisionTransformer( 114 | # patch_size=(1, 50), embed_dim=128, depth=6, num_heads=8, 115 | # mlp_ratio=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 116 | # return model 117 | 118 | # def discriminator(**kwargs): 119 | # model = VisionTransformer( 120 | # patch_size=(1, 50), embed_dim=128, depth=6, num_heads=8, 121 | # mlp_ratio=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 122 | # return model -------------------------------------------------------------------------------- /engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable 15 | 16 | import torch 17 | 18 | import utils.misc as misc 19 | import utils.lr_sched as lr_sched 20 | from adap_weight import aw_loss 21 | from utils.misc import plot_reconstruction 22 | 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, loss_scaler, 28 | log_writer=None, 29 | args=None): 30 | model.train(True) 31 | metric_logger = misc.MetricLogger(delimiter=" ") 32 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 20 35 | 36 | accum_iter = args.accum_iter 37 | 38 | optimizer.zero_grad() 39 | 40 | if log_writer is not None: 41 | print('log_dir: {}'.format(log_writer.log_dir)) 42 | 43 | for data_iter_step, (samples,_) in enumerate(data_loader): 44 | 45 | # we use a per iteration (instead of per epoch) lr scheduler 46 | if data_iter_step % accum_iter == 0: 47 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 48 | 49 | if args.cuda is not None: 50 | # with torch.cuda.amp.autocast(): 51 | mae_loss, pred, mask, disc_loss, adv_loss, currupt_img = model(samples.to(device), mask_ratio=args.mask_ratio) 52 | else: 53 | mae_loss, pred, mask, disc_loss, adv_loss, currupt_img = model(samples, mask_ratio=args.mask_ratio) 54 | 55 | # print(model.parameters()) 56 | gen_loss = aw_loss(mae_loss, adv_loss, optimizer, model) 57 | # print(gen_loss) 58 | gen_loss_value = gen_loss.item() 59 | if not math.isfinite(gen_loss_value): 60 | print("Loss is {}, stopping training".format(gen_loss_value)) 61 | sys.exit(1) 62 | 63 | gen_loss = gen_loss/accum_iter 64 | 65 | loss_scaler(gen_loss, optimizer, parameters=model.parameters(), 66 | update_grad=(data_iter_step + 1) % accum_iter == 0, retain_graph = True) 67 | 68 | if (data_iter_step + 1) % accum_iter == 0: 69 | optimizer.zero_grad() 70 | 71 | mae_loss, pred, mask, disc_loss, adv_loss, currupt_img = model(samples.to(device), mask_ratio=args.mask_ratio) 72 | 73 | disc_loss_value = disc_loss.item() 74 | mae_loss_value = 
mae_loss.item() 75 | 76 | if not math.isfinite(disc_loss_value): 77 | print("Loss is {}, stopping training".format(disc_loss_value)) 78 | sys.exit(1) 79 | 80 | disc_loss = disc_loss/accum_iter 81 | mae_loss = mae_loss/accum_iter 82 | 83 | loss_scaler(disc_loss, optimizer, parameters=model.parameters(), 84 | update_grad=(data_iter_step + 1) % accum_iter == 0, retain_graph = True) 85 | if (data_iter_step + 1) % accum_iter == 0: 86 | optimizer.zero_grad() 87 | 88 | if args.cuda is not None: 89 | torch.cuda.synchronize() 90 | 91 | metric_logger.update(disc_loss=disc_loss_value) 92 | metric_logger.update(gen_loss=gen_loss_value) 93 | metric_logger.update(mae_loss=mae_loss_value) 94 | 95 | lr = optimizer.param_groups[0]["lr"] 96 | metric_logger.update(lr=lr) 97 | 98 | disc_loss_value_reduce = misc.all_reduce_mean(disc_loss_value) 99 | gen_loss_value_reduce = misc.all_reduce_mean(gen_loss_value) 100 | mae_loss_value_reduce = misc.all_reduce_mean(mae_loss_value) 101 | 102 | 103 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 104 | """ We use epoch_1000x as the x-axis in tensorboard. 105 | This calibrates different curves when batch size changes. 106 | """ 107 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 108 | log_writer.add_scalar('disc_train_loss', disc_loss_value_reduce, epoch_1000x) 109 | log_writer.add_scalar('gen_train_loss', gen_loss_value_reduce, epoch_1000x) 110 | log_writer.add_scalar('mae_loss', mae_loss_value_reduce, epoch_1000x) 111 | log_writer.add_scalar('lr', lr, epoch_1000x) 112 | log_writer.add_figure('Reconstructed vs. actuals', 113 | plot_reconstruction(currupt_img, samples), 114 | global_step=epoch * len(data_loader) + data_iter_step) 115 | 116 | # gather the stats from all processes 117 | metric_logger.synchronize_between_processes() 118 | print("Averaged stats:", metric_logger) 119 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 120 | -------------------------------------------------------------------------------- /utils/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 4 | 5 | Based on code in: 6 | * https://github.com/google-research/vision_transformer 7 | * https://github.com/google-research/big_vision/tree/main/big_vision 8 | 9 | Hacked together by / Copyright 2020 Ross Wightman 10 | """ 11 | import logging 12 | from typing import Callable, List, Optional, Tuple, Union 13 | import numpy as np 14 | import torch 15 | from torch import nn as nn 16 | import torch.nn.functional as F 17 | 18 | from utils.format import Format, nchw_to 19 | from utils.helpers import to_2tuple 20 | from utils.trace_utils import _assert 21 | 22 | _logger = logging.getLogger(__name__) 23 | 24 | 25 | class PatchEmbed(nn.Module): 26 | """ 12 Channel ECG to Patch Embedding 27 | change img_size and patch_size as per data. 
28 | """ 29 | output_fmt: Format 30 | dynamic_img_pad: torch.jit.Final[bool] 31 | 32 | def __init__( 33 | self, 34 | img_size: Optional[int] = (12, 1000), 35 | patch_size: int = (1, 50), 36 | in_chans: int = 1, 37 | embed_dim: int = 128, 38 | norm_layer: Optional[Callable] = None, 39 | flatten: bool = True, 40 | output_fmt: Optional[str] = None, 41 | bias: bool = True, 42 | strict_img_size: bool = True, 43 | dynamic_img_pad: bool = False, 44 | ): 45 | super().__init__() 46 | self.patch_size = patch_size 47 | if img_size is not None: 48 | self.img_size = img_size 49 | self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) 50 | self.num_patches = self.grid_size[0] * self.grid_size[1] 51 | else: 52 | self.img_size = None 53 | self.grid_size = None 54 | self.num_patches = None 55 | 56 | if output_fmt is not None: 57 | self.flatten = False 58 | self.output_fmt = Format(output_fmt) 59 | else: 60 | # flatten spatial dim and transpose to channels last, kept for bwd compat 61 | self.flatten = flatten 62 | self.output_fmt = Format.NCHW 63 | 64 | self.strict_img_size = strict_img_size 65 | self.dynamic_img_pad = dynamic_img_pad 66 | self.patch = nn.Sequential( 67 | nn.Conv1d(1, 32, kernel_size=15, stride=1, bias=bias), 68 | nn.BatchNorm1d(32), 69 | nn.ReLU(), 70 | nn.Conv1d(32, 64, kernel_size=7, stride=1, bias=bias), 71 | nn.BatchNorm1d(64), 72 | nn.ReLU(), 73 | nn.Conv1d(64, embed_dim, kernel_size=50, stride=50, padding = 50,dilation= 2, bias=bias), 74 | ) 75 | self.layer_norm = nn.LayerNorm(embed_dim) 76 | 77 | def forward(self, x): 78 | B, C, H, W = x.shape 79 | if self.img_size is not None: 80 | if self.strict_img_size: 81 | _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") 82 | _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") 83 | elif not self.dynamic_img_pad: 84 | _assert( 85 | H % self.patch_size[0] == 0, 86 | f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." 87 | ) 88 | _assert( 89 | W % self.patch_size[1] == 0, 90 | f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." 91 | ) 92 | if self.dynamic_img_pad: 93 | pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0] 94 | pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1] 95 | x = F.pad(x, (0, pad_w, 0, pad_h)) 96 | # print(x.size()) 97 | if self.flatten: 98 | x = x.flatten(2) # NCHW -> NLC 99 | 100 | elif self.output_fmt != Format.NCHW: 101 | x = nchw_to(x, self.output_fmt) 102 | # print(x.size()) 103 | # 3 Convolutional Layers as described in the MAE ECG paper, along with batch normalisations. 104 | x = self.patch(x).transpose(2, 1) 105 | # print(x.size()) 106 | x = self.layer_norm(x) 107 | return x 108 | 109 | 110 | # Need to read about this. 
111 | class PatchEmbedWithSize(PatchEmbed): 112 | """ 2D Image to Patch Embedding 113 | """ 114 | output_fmt: Format 115 | 116 | def __init__( 117 | self, 118 | img_size: Optional[int] = 224, 119 | patch_size: int = 16, 120 | in_chans: int = 3, 121 | embed_dim: int = 768, 122 | norm_layer: Optional[Callable] = None, 123 | flatten: bool = True, 124 | output_fmt: Optional[str] = None, 125 | bias: bool = True, 126 | ): 127 | super().__init__( 128 | img_size=img_size, 129 | patch_size=patch_size, 130 | in_chans=in_chans, 131 | embed_dim=embed_dim, 132 | norm_layer=norm_layer, 133 | flatten=flatten, 134 | output_fmt=output_fmt, 135 | bias=bias, 136 | ) 137 | 138 | def forward(self, x) -> Tuple[torch.Tensor, List[int]]: 139 | B, C, H, W = x.shape 140 | if self.img_size is not None: 141 | _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).") 142 | _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).") 143 | 144 | x = self.proj(x) 145 | grid_size = x.shape[-2:] 146 | if self.flatten: 147 | x = x.flatten(2).transpose(1, 2) # NCHW -> NLC 148 | elif self.output_fmt != Format.NCHW: 149 | x = nchw_to(x, self.output_fmt) 150 | x = self.norm(x) 151 | return x, grid_size 152 | 153 | 154 | 155 | 156 | """Resample the weights of the patch embedding kernel to target resolution. 157 | We resample the patch embedding kernel by approximately inverting the effect 158 | of patch resizing. 159 | 160 | Code based on: 161 | https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py 162 | 163 | With this resizing, we can for example load a B/8 filter into a B/16 model 164 | and, on 2x larger input image, the result will match. 165 | 166 | Args: 167 | patch_embed: original parameter to be resized. 168 | new_size (tuple(int, int): target shape (height, width)-only. 169 | interpolation (str): interpolation for resize 170 | antialias (bool): use anti-aliasing filter in resize 171 | verbose (bool): log operation 172 | Returns: 173 | Resized patch embedding kernel. 174 | """ 175 | -------------------------------------------------------------------------------- /engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | from torchmetrics.classification import MultilabelAUROC 16 | 17 | import torch 18 | 19 | from timm.data import Mixup 20 | from timm.utils import accuracy 21 | 22 | import utils.misc as misc 23 | import utils.lr_sched as lr_sched 24 | from sklearn.metrics import accuracy_score, roc_auc_score 25 | 26 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 27 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 28 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 29 | mixup_fn: Optional[Mixup] = None, log_writer=None, 30 | args=None): 31 | model.train(True) 32 | metric_logger = misc.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 34 | header = 'Epoch: [{}]'.format(epoch) 35 | print_freq = 20 36 | 37 | accum_iter = args.accum_iter 38 | 39 | optimizer.zero_grad() 40 | 41 | if log_writer is not None: 42 | print('log_dir: {}'.format(log_writer.log_dir)) 43 | 44 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 45 | 46 | # we use a per iteration (instead of per epoch) lr scheduler 47 | if(args.classf_type != "multi_label"): 48 | targets = targets[:, 0] 49 | if data_iter_step % accum_iter == 0: 50 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 51 | 52 | if args.cuda is not None: 53 | samples = samples.to(device, non_blocking=True) 54 | targets = targets.to(device, non_blocking=True) 55 | 56 | if mixup_fn is not None: 57 | samples, targets = mixup_fn(samples, targets) 58 | 59 | if args.cuda is not None: 60 | with torch.cuda.amp.autocast(): 61 | outputs = model(samples) 62 | loss = criterion(outputs, targets) 63 | else: 64 | outputs = model(samples) 65 | # print(outputs.size()) 66 | # print(targets.size()) 67 | loss = criterion(outputs, targets) 68 | 69 | loss_value = loss.item() 70 | 71 | if not math.isfinite(loss_value): 72 | print("Loss is {}, stopping training".format(loss_value)) 73 | sys.exit(1) 74 | 75 | loss /= accum_iter 76 | loss_scaler(loss, optimizer, clip_grad=max_norm, 77 | parameters=model.parameters(), create_graph=False, 78 | update_grad=(data_iter_step + 1) % accum_iter == 0) 79 | if (data_iter_step + 1) % accum_iter == 0: 80 | optimizer.zero_grad() 81 | 82 | if args.cuda is not None: 83 | torch.cuda.synchronize() 84 | 85 | metric_logger.update(loss=loss_value) 86 | min_lr = 10. 87 | max_lr = 0. 88 | for group in optimizer.param_groups: 89 | min_lr = min(min_lr, group["lr"]) 90 | max_lr = max(max_lr, group["lr"]) 91 | 92 | metric_logger.update(lr=max_lr) 93 | 94 | loss_value_reduce = misc.all_reduce_mean(loss_value) 95 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 96 | """ We use epoch_1000x as the x-axis in tensorboard. 97 | This calibrates different curves when batch size changes. 
98 | """ 99 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 100 | log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) 101 | log_writer.add_scalar('lr', max_lr, epoch_1000x) 102 | 103 | # gather the stats from all processes 104 | # metric_logger.synchronize_between_processes() 105 | print("Averaged stats:", metric_logger) 106 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 107 | 108 | @torch.no_grad() 109 | def evaluate(data_loader, model, device, args): 110 | criterion = torch.nn.BCEWithLogitsLoss() 111 | 112 | metric_logger = misc.MetricLogger(delimiter=" ") 113 | metric_logger.add_meter('auc') # Add this line 114 | 115 | header = 'Test:' 116 | 117 | # switch to evaluation mode 118 | model.eval() 119 | trues = [] 120 | preds = [] 121 | 122 | for batch in metric_logger.log_every(data_loader, 10, header): 123 | images = batch[0] 124 | target = batch[-1] 125 | if(args.classf_type != "multi_label"): 126 | target = target[:, 0] 127 | if args.cuda is not None: 128 | images = images.to(device, non_blocking=True) 129 | target = target.to(device, non_blocking=True) 130 | 131 | # compute output 132 | if args.cuda is not None: 133 | with torch.cuda.amp.autocast(): 134 | output = model(images) 135 | loss = criterion(output, target) 136 | else: 137 | output = model(images) 138 | loss = criterion(output, target) 139 | if(args.classf_type != "multi_label"): 140 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 141 | batch_size = images.shape[0] 142 | metric_logger.update(loss=loss.item()) 143 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 144 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 145 | else: 146 | acc1 = accuracy_score(target.cpu(), torch.sigmoid(output.cpu()) > 0.5)*100 147 | # ml_auroc = MultilabelAUROC(num_labels=args.nb_classes, average="macro", thresholds=None) 148 | # auc = ml_auroc(torch.sigmoid(output.cpu()), target.cpu().int()) 149 | batch_size = images.shape[0] 150 | metric_logger.update(loss=loss.item()) 151 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 152 | # metric_logger.meters['auc'].update(auc, n=batch_size) 153 | trues.append(target.cpu().int()) 154 | preds.append(torch.sigmoid(output.detach().cpu())) 155 | # gather the stats from all processes 156 | 157 | metric_logger.synchronize_between_processes() 158 | ml_auroc = MultilabelAUROC(num_labels=args.nb_classes, average="macro", thresholds=None) 159 | auc = ml_auroc(torch.cat(preds), torch.cat(trues)) 160 | metric_logger.meters['auc'].update(auc) # Update the AUC meter 161 | 162 | print('* Acc@1 {top1.global_avg:.3f} auc {aucs:.3f} loss {losses.global_avg:.3f}' 163 | .format(top1=metric_logger.acc1, aucs = auc, losses=metric_logger.loss)) 164 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import numpy as np 15 | import os 16 | import time 17 | 18 | from pathlib import Path 19 | 20 | import torch.nn.functional as F 21 | import wfdb 22 | from wfdb import processing 23 | import re 24 | from os.path import exists 25 | import glob 26 | 27 | import torch 28 | import torch.backends.cudnn as cudnn 29 | from torch.utils.tensorboard import SummaryWriter 30 | import torchvision.transforms as transforms 31 | import torchvision.datasets as datasets 32 | import timm 33 | # assert timm.__version__ == "0.3.2" # version check 34 | import timm.optim.optim_factory as optim_factory 35 | 36 | import utils.misc as misc 37 | from utils.misc import NativeScalerWithGradNormCount as NativeScaler 38 | from utils.ecg_dataloader import CustomDataset 39 | import models_mae 40 | from torchsummary import summary 41 | 42 | from engine_pretrain import train_one_epoch 43 | 44 | torch.autograd.set_detect_anomaly(True) 45 | 46 | def get_args_parser(): 47 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 48 | parser.add_argument('--batch_size', default=64, type=int, 49 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 50 | parser.add_argument('--epochs', default=800, type=int) 51 | parser.add_argument('--accum_iter', default=1, type=int, 52 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 53 | 54 | # Model parameters 55 | parser.add_argument('--model', default='mae_vit_1dcnn', type=str, metavar='MODEL', 56 | help='Name of model to train') 57 | 58 | parser.add_argument('--input_size', default=(12, 1000), type=int, 59 | help='images input size') 60 | 61 | parser.add_argument('--mask_ratio', default=0.75, type=float, 62 | help='Masking ratio (percentage of removed patches).') 63 | 64 | parser.add_argument('--norm_pix_loss', action='store_true', 65 | help='Use (per-patch) normalized pixels as targets for computing loss') 66 | parser.set_defaults(norm_pix_loss=True) 67 | 68 | # Optimizer parameters 69 | parser.add_argument('--weight_decay', type=float, default=0.05, 70 | help='weight decay (default: 0.05)') 71 | 72 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 73 | help='learning rate (absolute lr)') 74 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 75 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 76 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 77 | help='lower lr bound for cyclic schedulers that hit 0') 78 | 79 | parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', 80 | help='epochs to warmup LR') 81 | 82 | # Dataset parameters 83 | parser.add_argument('--data_path', default='/Users/parthagrawal02/Desktop/Carelog/ECG_CNN', type=str, 84 | help='dataset path') 85 | 86 | parser.add_argument('--output_dir', default='./output_dir', 87 | help='path where to save, empty for no saving') 88 | parser.add_argument('--log_dir', default='./output_dir', 89 | help='path where to tensorboard log') 90 | parser.add_argument('--device', default='cuda', 91 | help='device to use for training / testing') 92 | parser.add_argument('--seed', default=0, type=int) 93 | 
93 | parser.add_argument('--resume', default='',
94 | help='resume from checkpoint')
95 | 
96 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
97 | help='start epoch')
98 | parser.add_argument('--num_workers', default=1, type=int)
99 | parser.add_argument('--pin_mem', action='store_true',
100 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
101 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
102 | parser.set_defaults(pin_mem=True)
103 | parser.add_argument('--cuda', default=None, type=str, help='pass any string to enable CUDA; keep the default None to run on the CPU')
104 | 
105 | 
106 | # distributed training parameters
107 | parser.add_argument('--world_size', default=1, type=int,
108 | help='number of distributed processes')
109 | parser.add_argument('--start', default=0, type=int)
110 | parser.add_argument('--end', default=4, type=int)
111 | 
112 | 
113 | parser.add_argument('--local_rank', default=-1, type=int)
114 | parser.add_argument('--dist_on_itp', action='store_true')
115 | parser.add_argument('--dist_url', default='env://',
116 | help='url used to set up distributed training')
117 | 
118 | return parser
119 | 
120 | 
121 | def main(args):
122 | misc.init_distributed_mode(args)
123 | 
124 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
125 | print("{}".format(args).replace(', ', ',\n'))
126 | 
127 | device = torch.device(args.device)
128 | 
129 | # fix the seed for reproducibility
130 | seed = args.seed + misc.get_rank()
131 | torch.manual_seed(seed)
132 | np.random.seed(seed)
133 | if args.cuda is not None:
134 | cudnn.benchmark = True
135 | 
136 | # PhysioNet dataset: --start and --end select which of the 46 record folders to load
137 | # Custom dataloader, arguments: data_path, start folder and end folder (out of the 46 folders)
138 | dataset = CustomDataset(args.data_path, args.start, args.end)
139 | # print(dataset.size())
140 | sampler_train = torch.utils.data.RandomSampler(dataset)
141 | 
142 | if args.log_dir is not None:
143 | os.makedirs(args.log_dir, exist_ok=True)
144 | log_writer = SummaryWriter(log_dir=args.log_dir)
145 | else:
146 | log_writer = None
147 | 
148 | data_loader_train = torch.utils.data.DataLoader(
149 | dataset, sampler=sampler_train,
150 | batch_size=args.batch_size,
151 | num_workers=args.num_workers,
152 | pin_memory=args.pin_mem,
153 | drop_last=True,
154 | )
155 | 
156 | # define the model
157 | model = models_mae.__dict__[args.model](norm_pix_loss=args.norm_pix_loss)
158 | if args.cuda is not None:
159 | model.to(device)
160 | model = model.double()
161 | model.train()
162 | model_without_ddp = model
163 | print("Model = %s" % str(model_without_ddp))
164 | 
165 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
166 | 
167 | if args.lr is None: # only base_lr is specified
168 | args.lr = args.blr * eff_batch_size / 256
169 | 
170 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
171 | print("actual lr: %.2e" % args.lr)
172 | 
173 | print("accumulate grad iterations: %d" % args.accum_iter)
174 | print("effective batch size: %d" % eff_batch_size)
175 | 
176 | 
177 | # following timm: set wd as 0 for bias and norm layers
178 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
179 | 
180 | dis_optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))  # NOTE: built over the same param_groups and not passed to train_one_epoch below
181 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
182 | # print(param_groups)
183 | # print(optimizer)
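# NativeScaler wraps torch.cuda.amp.GradScaler (see utils/misc.py): it scales the loss
# before backward(), unscales the gradients before the optimizer step, and returns the
# gradient norm for logging.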
184 | loss_scaler = NativeScaler()
185 | 
186 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
187 | 
188 | print(f"Start training for {args.epochs} epochs")
189 | start_time = time.time()
190 | 
191 | # if args.resume != '':
192 | # checkpoint = torch.load("output_dir/checkpoint-" + str(args.start_epoch) + ".pth")
193 | # model.load_state_dict(checkpoint['model'])
194 | # epoch = checkpoint['epoch']
195 | 
196 | for epoch in range(args.start_epoch, args.epochs):
197 | train_stats = train_one_epoch(
198 | model, data_loader_train,
199 | optimizer, device, epoch, loss_scaler,
200 | log_writer=log_writer,
201 | args=args
202 | )
203 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
204 | 'epoch': epoch,}
205 | 
206 | if args.output_dir and (epoch % 2 == 0 or epoch + 1 == args.epochs):
207 | misc.save_model(
208 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
209 | loss_scaler=loss_scaler, epoch=epoch)
210 | 
211 | 
212 | 
213 | if args.output_dir:
214 | if log_writer is not None:
215 | log_writer.flush()
216 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
217 | f.write(json.dumps(log_stats) + "\n")
218 | 
219 | total_time = time.time() - start_time
220 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
221 | print('Training time {}'.format(total_time_str))
222 | 
223 | if __name__ == '__main__':
224 | args = get_args_parser()
225 | args = args.parse_args()
226 | if args.output_dir:
227 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
228 | main(args)
229 | 
-------------------------------------------------------------------------------- /utils/misc.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | import numpy as np
17 | from collections import defaultdict, deque
18 | from pathlib import Path
19 | import matplotlib.pyplot as plt
20 | import torch
21 | import torch.distributed as dist
22 | from torch import inf
23 | # torch.autograd.set_detect_anomaly(True)  # debugging aid only: anomaly detection slows training considerably
24 | 
25 | class SmoothedValue(object):
26 | """Track a series of values and provide access to smoothed values over a
27 | window or the global series average.
28 | """
29 | 
30 | def __init__(self, window_size=20, fmt=None):
31 | if fmt is None:
32 | fmt = "{median:.4f} ({global_avg:.4f})"
33 | self.deque = deque(maxlen=window_size)
34 | self.total = 0.0
35 | self.count = 0
36 | self.fmt = fmt
37 | 
38 | def update(self, value, n=1):
39 | self.deque.append(value)
40 | self.count += n
41 | self.total += value * n
42 | 
43 | def synchronize_between_processes(self):
44 | """
45 | Warning: does not synchronize the deque!
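Only count and total (and hence global_avg) are all-reduced across processes;
the windowed statistics (median, avg, max) stay local to each process.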
46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | class MetricLogger(object): 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError("'{}' object has no attribute '{}'".format( 107 | type(self).__name__, attr)) 108 | 109 | def __str__(self): 110 | loss_str = [] 111 | for name, meter in self.meters.items(): 112 | loss_str.append( 113 | "{}: {}".format(name, str(meter)) 114 | ) 115 | return self.delimiter.join(loss_str) 116 | 117 | def synchronize_between_processes(self): 118 | for meter in self.meters.values(): 119 | meter.synchronize_between_processes() 120 | 121 | def add_meter(self, name, meter): 122 | self.meters[name] = meter 123 | 124 | def log_every(self, iterable, print_freq, header=None): 125 | i = 0 126 | if not header: 127 | header = '' 128 | start_time = time.time() 129 | end = time.time() 130 | iter_time = SmoothedValue(fmt='{avg:.4f}') 131 | data_time = SmoothedValue(fmt='{avg:.4f}') 132 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 133 | log_msg = [ 134 | header, 135 | '[{0' + space_fmt + '}/{1}]', 136 | 'eta: {eta}', 137 | '{meters}', 138 | 'time: {time}', 139 | 'data: {data}' 140 | ] 141 | if torch.cuda.is_available(): 142 | log_msg.append('max mem: {memory:.0f}') 143 | log_msg = self.delimiter.join(log_msg) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print(log_msg.format( 154 | i, len(iterable), eta=eta_string, 155 | meters=str(self), 156 | time=str(iter_time), data=str(data_time), 157 | memory=torch.cuda.max_memory_allocated() / MB)) 158 | else: 159 | print(log_msg.format( 160 | i, len(iterable), eta=eta_string, 161 | meters=str(self), 162 | time=str(iter_time), data=str(data_time))) 163 | i += 1 164 | end = time.time() 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('{} Total time: {} ({:.4f} s / it)'.format( 168 | 
header, total_time_str, total_time / len(iterable))) 169 | 170 | 171 | def setup_for_distributed(is_master): 172 | """ 173 | This function disables printing when not in master process 174 | """ 175 | builtin_print = builtins.print 176 | 177 | def print(*args, **kwargs): 178 | force = kwargs.pop('force', False) 179 | force = force or (get_world_size() > 8) 180 | if is_master or force: 181 | now = datetime.datetime.now().time() 182 | builtin_print('[{}] '.format(now), end='') # print with time stamp 183 | builtin_print(*args, **kwargs) 184 | 185 | builtins.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if args.dist_on_itp: 219 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 220 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 221 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 222 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 223 | os.environ['LOCAL_RANK'] = str(args.gpu) 224 | os.environ['RANK'] = str(args.rank) 225 | os.environ['WORLD_SIZE'] = str(args.world_size) 226 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 227 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 228 | args.rank = int(os.environ["RANK"]) 229 | args.world_size = int(os.environ['WORLD_SIZE']) 230 | args.gpu = int(os.environ['LOCAL_RANK']) 231 | elif 'SLURM_PROCID' in os.environ: 232 | args.rank = int(os.environ['SLURM_PROCID']) 233 | args.gpu = args.rank % torch.cuda.device_count() 234 | else: 235 | print('Not using distributed mode') 236 | setup_for_distributed(is_master=True) # hack 237 | args.distributed = False 238 | return 239 | 240 | args.distributed = True 241 | 242 | torch.cuda.set_device(args.gpu) 243 | args.dist_backend = 'nccl' 244 | print('| distributed init (rank {}): {}, gpu {}'.format( 245 | args.rank, args.dist_url, args.gpu), flush=True) 246 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 247 | world_size=args.world_size, rank=args.rank) 248 | torch.distributed.barrier() 249 | setup_for_distributed(args.rank == 0) 250 | 251 | 252 | class NativeScalerWithGradNormCount: 253 | state_dict_key = "amp_scaler" 254 | 255 | def __init__(self): 256 | self._scaler = torch.cuda.amp.GradScaler() 257 | 258 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True, retain_graph = False): 259 | self._scaler.scale(loss).backward(create_graph=create_graph, retain_graph= retain_graph) 260 | if update_grad: 261 | if clip_grad is not None: 262 | assert parameters is not None 263 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 264 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 265 | else: 266 | self._scaler.unscale_(optimizer) 267 | norm = get_grad_norm_(parameters) 268 | self._scaler.step(optimizer) 269 | 
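# update() re-tunes the loss scale for the next iteration: GradScaler shrinks the scale
# after an inf/NaN overflow and grows it again after a run of successful steps.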
self._scaler.update()
270 | else:
271 | norm = None
272 | return norm
273 | 
274 | def state_dict(self):
275 | return self._scaler.state_dict()
276 | 
277 | def load_state_dict(self, state_dict):
278 | self._scaler.load_state_dict(state_dict)
279 | 
280 | 
281 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
282 | if isinstance(parameters, torch.Tensor):
283 | parameters = [parameters]
284 | parameters = [p for p in parameters if p.grad is not None]
285 | norm_type = float(norm_type)
286 | if len(parameters) == 0:
287 | return torch.tensor(0.)
288 | device = parameters[0].grad.device
289 | if norm_type == inf:
290 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
291 | else:
292 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
293 | return total_norm
294 | 
295 | 
296 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
297 | output_dir = Path(args.output_dir)
298 | epoch_name = str(epoch)
299 | if loss_scaler is not None:
300 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
301 | for checkpoint_path in checkpoint_paths:
302 | to_save = {
303 | 'model': model_without_ddp.state_dict(),
304 | 'optimizer': optimizer.state_dict(),
305 | 'epoch': epoch,
306 | 'scaler': loss_scaler.state_dict(),
307 | 'args': args,
308 | }
309 | 
310 | save_on_master(to_save, checkpoint_path)
311 | else:
312 | client_state = {'epoch': epoch}
313 | model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
314 | 
315 | 
316 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
317 | if args.resume:
318 | if args.resume.startswith('https'):
319 | checkpoint = torch.hub.load_state_dict_from_url(
320 | args.resume, map_location='cpu', check_hash=True)
321 | else:
322 | checkpoint = torch.load(args.resume, map_location='cpu')
323 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
324 | print("Resume checkpoint %s" % args.resume)
325 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
326 | optimizer.load_state_dict(checkpoint['optimizer'])
327 | args.start_epoch = checkpoint['epoch'] + 1
328 | if 'scaler' in checkpoint:
329 | loss_scaler.load_state_dict(checkpoint['scaler'])
330 | print("With optim & sched!")
331 | 
332 | 
333 | def all_reduce_mean(x):
334 | world_size = get_world_size()
335 | if world_size > 1:
336 | x_reduce = torch.tensor(x).cuda()
337 | dist.all_reduce(x_reduce)
338 | x_reduce /= world_size
339 | return x_reduce.item()
340 | else:
341 | return x
342 | 
343 | def plot_reconstruction(currupt_img, samples, size=1):
344 | 
345 | # one (original, corrupted) figure per lead index; the last figure is returned
346 | for idx in np.arange(size):
347 | fig, ax = plt.subplots(2, figsize=(10, 6))
348 | ax[0].plot(samples[0, 0, idx].numpy())
349 | ax[0].set_title('Original ECG')
350 | 
351 | # Plot processed (corrupted/reconstructed) ECG
352 | ax[1].plot(currupt_img[0, 0, idx].cpu().detach().numpy())
353 | ax[1].set_title('Processed ECG')
354 | plt.tight_layout()
355 | return fig
356 | 
357 | 
358 | 
-------------------------------------------------------------------------------- /models_mae.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 | 
12 | from functools import partial
13 | 
14 | import torch
15 | import torch.nn as nn
16 | import numpy as np
17 | 
18 | from timm.models.vision_transformer import Block
19 | from utils.patch_embed import PatchEmbed
20 | 
21 | from utils.pos_embed import get_2d_sincos_pos_embed
22 | 
23 | 
24 | # Main changes: img_size adjusted to a 12-lead ECG signal (12 x 1000)
25 | # Rewritten functions: patchify and unpatchify
26 | # All other functions remain the same
27 | 
28 | 
29 | class MaskedAutoencoderViT(nn.Module):
30 | """ Masked Autoencoder with VisionTransformer backbone
31 | """
32 | def __init__(self, img_size=(12, 1000), patch_size=(1, 50), in_chans=1,
33 | embed_dim=128, depth=6, num_heads=8,
34 | decoder_embed_dim=64, decoder_depth=3, decoder_num_heads=8,
35 | mlp_ratio=3., norm_layer=nn.LayerNorm, norm_pix_loss=False):
36 | super().__init__()
37 | 
38 | # --------------------------------------------------------------------------
39 | # MAE encoder specifics
40 | # self.linear1 = nn.Linear(1, 32, bias=True)
41 | # self.linear2 = nn.Linear(1, 32, bias=True)
42 | self.output_shape = (img_size[0]//patch_size[0], img_size[1]//patch_size[1])
43 | 
44 | self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
45 | num_patches = self.patch_embed.num_patches
46 | 
47 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
48 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
49 | 
50 | self.blocks = nn.ModuleList([
51 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
52 | for i in range(depth)])
53 | self.norm = norm_layer(embed_dim)
54 | 
55 | # Discriminator specifics: a linear head on the shared encoder that scores each token as original vs. reconstructed
56 | self.discriminate = nn.Linear(embed_dim, 1, bias=True)
57 | 
58 | # --------------------------------------------------------------------------
59 | # --------------------------------------------------------------------------
60 | # MAE decoder specifics
61 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
62 | 
63 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
64 | 
65 | self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
66 | 
67 | self.decoder_blocks = nn.ModuleList([
68 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
69 | for i in range(decoder_depth)])
70 | 
71 | self.decoder_norm = norm_layer(decoder_embed_dim)
72 | self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size[0] * patch_size[1] * in_chans, bias=True) # decoder to patch
73 | # --------------------------------------------------------------------------
74 | 
75 | self.norm_pix_loss = norm_pix_loss
76 | 
77 | self.initialize_weights()
78 | 
79 | 
80 | def initialize_weights(self):
81 | # initialization
82 | # initialize (and freeze) pos_embed by sin-cos embedding
83 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], (12, 
self.patch_embed.num_patches//12), cls_token=True) 84 | # grid = (height, width). height = 12 here for 12 lead ecg signals 85 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 86 | 87 | decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], (12, self.patch_embed.num_patches//12), cls_token=True) 88 | self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) 89 | 90 | for layer in self.patch_embed.patch: 91 | if isinstance(layer, nn.Conv1d): 92 | torch.nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu') 93 | # if layer.bias is not None: 94 | # torch.nn.init.constant_(layer.bias, 0.0) 95 | elif isinstance(layer, nn.LayerNorm): 96 | nn.init.constant_(layer.bias, 0) 97 | nn.init.constant_(layer.weight, 1.0) 98 | 99 | 100 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 101 | torch.nn.init.normal_(self.cls_token, std=.02) 102 | torch.nn.init.normal_(self.mask_token, std=.02) 103 | 104 | # initialize nn.Linear and nn.LayerNorm 105 | self.apply(self._init_weights) 106 | 107 | def _init_weights(self, m): 108 | if isinstance(m, nn.Linear): 109 | torch.nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 110 | if isinstance(m, nn.Linear) and m.bias is not None: 111 | nn.init.constant_(m.bias, 0) 112 | elif isinstance(m, nn.LayerNorm): 113 | nn.init.constant_(m.bias, 0) 114 | nn.init.constant_(m.weight, 1.0) 115 | 116 | def patchify(self, imgs): 117 | """ 118 | imgs: (N, 1, H, W) - 12 channel ECG - H = No. of channels, W = Length of ECG signal (1000 in this case) 119 | x: (N, L, patch_size_height*patch_size_width*1) 120 | """ 121 | ph = self.patch_embed.patch_size[0] 122 | pw = self.patch_embed.patch_size[1] 123 | 124 | h = imgs.shape[2] // ph 125 | w = imgs.shape[3] // pw 126 | x = imgs.reshape(shape=(imgs.shape[0], 1, h, ph, w, pw)) 127 | x = torch.einsum('nchpwq->nhwpqc', x) 128 | x = x.reshape(shape=(imgs.shape[0], h * w, ph*pw * 1)) 129 | return x 130 | 131 | def unpatchify(self, x): 132 | """ 133 | x: (N, L, patch_size_height*patch_size_width*1) 134 | imgs: (N, 1, H, W) - 12 channel ECG - H = No. of channels, W = Length of ECG signal (1000 in this case) 135 | """ 136 | ph = self.patch_embed.patch_size[0] 137 | pw = self.patch_embed.patch_size[1] 138 | 139 | # h = w = int(x.shape[1]**.5) 140 | # assert h * w == x.shape[1] 141 | h = 12 142 | w = x.shape[1]//12 143 | 144 | x = x.reshape(shape=(x.shape[0], h, w, ph, pw, 1)) 145 | x = torch.einsum('nhwpqc->nchpwq', x) 146 | imgs = x.reshape(shape=(x.shape[0], 1, h * ph, w * pw)) 147 | return imgs 148 | 149 | def random_masking(self, x, mask_ratio): 150 | """ 151 | Perform per-sample random masking by per-sample shuffling. 152 | Per-sample shuffling is done by argsort random noise. 
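Keeping the first len_keep entries of the argsort selects a uniformly random subset
of patches; ids_restore inverts the shuffle so the decoder can put mask tokens back in place.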
153 | x: [N, L, D], sequence
154 | """
155 | N, L, D = x.shape # batch, length, dim
156 | len_keep = int(L * (1 - mask_ratio))
157 | 
158 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
159 | 
160 | # sort noise for each sample
161 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
162 | ids_restore = torch.argsort(ids_shuffle, dim=1)
163 | 
164 | # keep the first subset
165 | ids_keep = ids_shuffle[:, :len_keep]
166 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
167 | 
168 | # generate the binary mask: 0 is keep, 1 is remove
169 | mask = torch.ones([N, L], device=x.device)
170 | mask[:, :len_keep] = 0
171 | # unshuffle to get the binary mask
172 | mask = torch.gather(mask, dim=1, index=ids_restore)
173 | 
174 | return x_masked, mask, ids_restore
175 | 
176 | def forward_encoder(self, x, mask_ratio):
177 | # embed patches
178 | # print("before patch embed = "+str(x.size()))
179 | x = self.patch_embed(x)
180 | # print("after patch embed = "+str(x.size()))
181 | # add pos embed w/o cls token
182 | x = x + self.pos_embed[:, 1:, :]
183 | 
184 | # masking: length -> length * mask_ratio
185 | x, mask, ids_restore = self.random_masking(x, mask_ratio)
186 | 
187 | # append cls token
188 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
189 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
190 | x = torch.cat((cls_tokens, x), dim=1)
191 | 
192 | # apply Transformer blocks
193 | for blk in self.blocks:
194 | x = blk(x)
195 | x = self.norm(x)
196 | # print(x.size())
197 | 
198 | return x, mask, ids_restore
199 | 
200 | 
201 | def discriminator(self, currupt_img):
202 | 
203 | x = self.patch_embed(currupt_img)
204 | # print("after patch embed = "+str(x.size()))
205 | # add pos embed w/o cls token
206 | x = x + self.pos_embed[:, 1:, :]
207 | 
208 | # no masking here: the discriminator sees the full corrupted signal
209 | # x, mask, ids_restore = self.random_masking(x, mask_ratio)
210 | # append cls token
211 | 
212 | cls_token = self.cls_token + self.pos_embed[:, :1, :]
213 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
214 | x = torch.cat((cls_tokens, x), dim=1)
215 | 
216 | # apply Transformer blocks
217 | for blk in self.blocks:
218 | x = blk(x)
219 | x = self.norm(x)
220 | 
221 | # x = x.view(x.size(0), -1)
222 | 
223 | x = self.discriminate(x)
224 | return torch.sigmoid(x)
225 | 
226 | 
227 | def discriminator_loss(self, x, mask):
228 | # Per-patch BCE: the discriminator should predict 1 for original patches and 0 for reconstructed ones
229 | output = self.discriminator(x)
230 | output = output[:, 1:, 0]  # per-patch scores; drop the cls token
231 | target = 1 - mask  # mask: 1 = reconstructed patch, so target 1 = original patch
232 | target = target.double()
233 | 
234 | disc_loss = torch.nn.BCELoss()
235 | return disc_loss(output, target)
236 | 
237 | 
238 | def adv_loss(self, currupt_img, mask):
239 | target = 1 - mask # This flips the mask values
240 | output = self.discriminator(currupt_img)
241 | disc_preds = output[:, 1:, 0]
242 | 
243 | # Reshape target to match the discriminator output shape
244 | target = target.view(disc_preds.shape)
245 | target = target.float()
246 | 
247 | # Mean log-probability the discriminator assigns to the correct label, computed separately for original (target=1) and reconstructed (target=0) patches
248 | corr_orig = (torch.log(disc_preds + 1e-8) * target).sum()/(target.sum())
249 | corr_recons = (torch.log(1 - disc_preds + 1e-8) * (1 - target)).sum()/((1 - target).sum())
250 | # print(corr_orig)
251 | # print(corr_orig + corr_recons)
252 | return (corr_orig) + (corr_recons)
253 | 
254 | 
255 | def forward_decoder(self, x, ids_restore):
256 | # embed tokens
257 | x = self.decoder_embed(x)
258 | 
259 | # append mask tokens to sequence
260 | mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
261 | x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
262 | x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
263 | x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
264 | 
265 | # add pos embed
266 | x = x + self.decoder_pos_embed
267 | 
268 | # apply Transformer blocks
269 | for blk in self.decoder_blocks:
270 | x = blk(x)
271 | x = self.decoder_norm(x)
272 | 
273 | # predictor projection
274 | x = self.decoder_pred(x)
275 | 
276 | # remove cls token
277 | x = x[:, 1:, :]
278 | 
279 | return x
280 | 
281 | def forward_loss(self, imgs, pred, mask):
282 | """
283 | imgs: (N, 1, H, W) - 12 channel ECG - H = No. of channels, W = Length of ECG signal (1000 in this case)
284 | x: (N, L, patch_size_height*patch_size_width*1)
285 | mask: [N, L], 0 is keep, 1 is remove,
286 | """
287 | # pred = self.unpatchify(pred)
288 | target = self.patchify(imgs)
289 | if self.norm_pix_loss:
290 | mean = target.mean(dim=-1, keepdim=True)
291 | var = target.var(dim=-1, keepdim=True)
292 | target = (target - mean) / (var + 1.e-6)**.5
293 | 
294 | if torch.isnan(target).any():
295 | print("NaN values found in target")
296 | if torch.isnan(pred).any():
297 | print("NaN values found in pred tensors")
298 | 
299 | loss = (pred - target)**2
300 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
301 | # print(loss)
302 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
303 | # loss = (loss).sum() / len(loss)*240
304 | # print(loss)
305 | return loss
306 | 
307 | def forward(self, imgs, mask_ratio=0.75):
308 | # print(imgs.size())
309 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
310 | # print(latent.size())
311 | pred = self.forward_decoder(latent, ids_restore) # [N, L, ph*pw*in_chans]
312 | # print(pred.size())
313 | mae_loss = self.forward_loss(imgs, pred, mask)
314 | 
315 | # loss = loss + self.adaptive_weight()*self.adv_loss()
316 | ###
317 | # currupt_img = reconstructed masked patches + unmasked patches
318 | img_patched = self.patchify(imgs)
319 | # mask == 1 marks removed patches: take the prediction there and the original patch elsewhere
320 | mask1 = mask.unsqueeze(-1).expand_as(pred)
321 | currupt_img = torch.where(mask1 == 1, pred, img_patched)
322 | currupt_img = self.unpatchify(currupt_img)
323 | ###
324 | 
325 | disc_loss = self.discriminator_loss(currupt_img, mask)
326 | adv_loss = self.adv_loss(currupt_img, mask)
327 | 
328 | return mae_loss, pred, mask, disc_loss, adv_loss, currupt_img
329 | 
330 | # Model architecture as described in the paper.
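# A minimal usage sketch (illustrative, not part of the training pipeline). Shapes follow
# from img_size=(12, 1000) and patch_size=(1, 50): 12 * (1000 // 50) = 240 patch tokens,
# each predicting 50 samples. The model is run in float64 to match main_pretrain.py.
#
#   model = mae_vit_1dcnn().double()
#   x = torch.rand(4, 1, 12, 1000, dtype=torch.float64)   # (N, 1, leads, samples)
#   mae_loss, pred, mask, disc_loss, adv_loss, corrupted = model(x)
#   # pred: (4, 240, 50) per-patch reconstructions; mask: (4, 240) with 1 = masked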
331 | def mae_vit_1dcnn(**kwargs): 332 | model = MaskedAutoencoderViT( 333 | patch_size=(1, 50), embed_dim=128, depth=6, num_heads=8, 334 | decoder_embed_dim=64, decoder_depth=3, decoder_num_heads=8, 335 | mlp_ratio=3, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 336 | return model 337 | 338 | def mae_vit_base_patch16_dec512d8b(**kwargs): 339 | model = MaskedAutoencoderViT( 340 | patch_size=16, embed_dim=768, depth=12, num_heads=12, 341 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 342 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 343 | return model 344 | 345 | def mae_vit_large_patch16_dec512d8b(**kwargs): 346 | model = MaskedAutoencoderViT( 347 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, 348 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 349 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 350 | return model 351 | 352 | def mae_vit_huge_patch14_dec512d8b(**kwargs): 353 | model = MaskedAutoencoderViT( 354 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, 355 | decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, 356 | mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 357 | return model 358 | 359 | # set recommended archs 360 | mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks 361 | mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks 362 | mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks 363 | -------------------------------------------------------------------------------- /main_linprobe.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
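# Linear probing / evaluation of the pre-trained encoder on PTB-XL diagnostic superclasses (multi-label).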
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # MoCo v3: https://github.com/facebookresearch/moco-v3
10 | # --------------------------------------------------------
11 | 
12 | import argparse
13 | import datetime
14 | import json
15 | import numpy as np
16 | import os
17 | import time
18 | from pathlib import Path
19 | 
20 | import torch
21 | import torch.backends.cudnn as cudnn
22 | from torch.utils.tensorboard import SummaryWriter
23 | import torchvision.transforms as transforms
24 | import torchvision.datasets as datasets
25 | import pandas as pd
26 | 
27 | import wfdb
28 | import ast
29 | 
30 | 
31 | import timm
32 | 
33 | # assert timm.__version__ == "0.3.2" # version check
34 | from timm.models.layers import trunc_normal_
35 | 
36 | import utils.misc as misc
37 | from utils.pos_embed import interpolate_pos_embed
38 | from utils.misc import NativeScalerWithGradNormCount as NativeScaler
39 | from utils.lars import LARS
40 | from utils.crop import RandomResizedCrop
41 | 
42 | import vit_model
43 | from utils.ecg_dataloader import CustomDataset
44 | 
45 | from engine_finetune import train_one_epoch, evaluate
46 | 
47 | 
48 | def get_args_parser():
49 | parser = argparse.ArgumentParser('MAE linear probing for ECG classification', add_help=False)
50 | parser.add_argument('--batch_size', default=512, type=int,
51 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
52 | parser.add_argument('--epochs', default=90, type=int)
53 | parser.add_argument('--accum_iter', default=1, type=int,
54 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
55 | 
56 | # Model parameters
57 | parser.add_argument('--model', default='vit_1dcnn', type=str, metavar='MODEL',
58 | help='Name of model to train')
59 | 
60 | # Optimizer parameters
61 | parser.add_argument('--weight_decay', type=float, default=0,
62 | help='weight decay (default: 0 for linear probe following MoCo v1)')
63 | 
64 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
65 | help='learning rate (absolute lr)')
66 | parser.add_argument('--blr', type=float, default=0.1, metavar='LR',
67 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
68 | 
69 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
70 | help='lower lr bound for cyclic schedulers that hit 0')
71 | 
72 | parser.add_argument('--warmup_epochs', type=int, default=10, metavar='N',
73 | help='epochs to warmup LR')
74 | 
75 | # * Finetuning params
76 | parser.add_argument('--finetune', default='',
77 | help='finetune from checkpoint')
78 | parser.add_argument('--global_pool', action='store_true')
79 | parser.set_defaults(global_pool=False)
80 | parser.add_argument('--cls_token', action='store_false', dest='global_pool',
81 | help='Use class token instead of global pool for classification')
82 | 
83 | # Dataset parameters
84 | parser.add_argument('--data_path', default='/Users/parthagrawal02/Desktop/Carelog/ECG_CNN/physionet', type=str,
85 | help='dataset path')
86 | parser.add_argument('--nb_classes', default=10, type=int,
87 | help='number of the classification types')
88 | 
89 | parser.add_argument('--output_dir', default='./output_dir',
90 | help='path where to save, empty for no saving')
91 | parser.add_argument('--log_dir', default='./output_dir',
92 | help='path where to tensorboard log')
93 | parser.add_argument('--device', default='cuda',
94 | help='device to use for training / testing')
95 | parser.add_argument('--seed', default=0, type=int)
96 | parser.add_argument('--resume', default='',
97 | help='resume from checkpoint')
98 | 
99 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
100 | help='start epoch')
101 | parser.add_argument('--eval', action='store_true',
102 | help='Perform evaluation only')
103 | parser.add_argument('--dist_eval', action='store_true', default=False,
104 | help='Enabling distributed evaluation (recommended during training for faster monitoring)')
105 | parser.add_argument('--num_workers', default=10, type=int)
106 | parser.add_argument('--pin_mem', action='store_true',
107 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
108 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
109 | parser.set_defaults(pin_mem=True)
110 | 
111 | # distributed training parameters
112 | parser.add_argument('--world_size', default=1, type=int,
113 | help='number of distributed processes')
114 | parser.add_argument('--local_rank', default=-1, type=int)
115 | parser.add_argument('--dist_on_itp', action='store_true')
116 | parser.add_argument('--dist_url', default='env://',
117 | help='url used to set up distributed training')
118 | parser.add_argument('--cuda', default=None,
119 | help='set to any string to enable CUDA; keep the default None to run on the CPU')
120 | parser.add_argument('--data_split', default=0.8, type=float,
121 | help='fraction of the data used for training when splitting a custom dataset')
122 | parser.add_argument('--mode', type=str, default="linprobe",
123 | help='Finetuning or Linear Eval')
124 | parser.add_argument('--train_start', type=int, default=31,
125 | help='train start')
126 | parser.add_argument('--train_end', type=int, default=40,
127 | help='train end')
128 | parser.add_argument('--classf_type', type=str, default="multi_label",
129 | help='classification type')
130 | return parser
131 | 
132 | def main(args):
133 | misc.init_distributed_mode(args)
134 | 
135 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
136 | print("{}".format(args).replace(', ', ',\n'))
137 | 
138 | device = torch.device(args.device)
139 | 
140 | # fix the seed for reproducibility
141 | seed = args.seed + misc.get_rank()
142 | torch.manual_seed(seed)
143 | np.random.seed(seed)
144 | 
145 | if args.cuda is not None:
146 | cudnn.benchmark = True
147 | 
148 | # linear probe: weak augmentation
149 | # transform_train = transforms.Compose([
150 | # RandomResizedCrop(224, interpolation=3),
151 | # transforms.RandomHorizontalFlip(),
152 | # transforms.ToTensor(),
153 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
154 | # transform_val = transforms.Compose([
155 | # transforms.Resize(256, interpolation=3),
156 | # transforms.CenterCrop(224),
157 | # transforms.ToTensor(),
158 | # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
159 | 
160 | def load_raw_data(df, sampling_rate, path):
161 | if sampling_rate == 100:
162 | data = [wfdb.rdsamp(path + f) for f in df.filename_lr]
163 | else:
164 | data = [wfdb.rdsamp(path + f) for f in df.filename_hr]
165 | data = np.array([signal for signal, meta in data])
166 | return data
167 | 
168 | 
169 | path = args.data_path  # NOTE: expected to end with '/', since filenames are appended directly
170 | sampling_rate = 100
171 | 
172 | # load and convert annotation data
173 | Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id')
174 | Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x))
175 | 
176 | # Load raw signal data
177 | X = load_raw_data(Y, sampling_rate, path)
178 | 
179 | # Load scp_statements.csv for diagnostic aggregation
180 | agg_df = pd.read_csv(path+'scp_statements.csv', index_col=0)
181 | agg_df = agg_df[agg_df.diagnostic == 1]
182 | 
183 | def aggregate_diagnostic(y_dic):
184 | tmp = []
185 | for key in y_dic.keys():
186 | if key in agg_df.index:
187 | tmp.append(agg_df.loc[key].diagnostic_class)
188 | return list(set(tmp))
189 | 
190 | # Apply diagnostic superclass
191 | Y['diagnostic_superclass'] = Y.scp_codes.apply(aggregate_diagnostic)
192 | 
193 | # Split data into train and test
194 | test_fold = 10
195 | # Train
196 | X_train = X[np.where(Y.strat_fold != test_fold)]
197 | y_train = Y[(Y.strat_fold != test_fold)].diagnostic_superclass
198 | # Test
199 | X_test = X[np.where(Y.strat_fold == test_fold)]
200 | y_test = Y[Y.strat_fold == test_fold].diagnostic_superclass
201 | 
202 | def multihot_encoder(labels, n_categories=1, dtype=torch.float32):  # n_categories is unused; the label set is inferred from the data
203 | label_set = set()
204 | for label_list in labels:
205 | label_set = label_set.union(set(label_list))
206 | label_set = sorted(label_set)
207 | 
208 | multihot_vectors = []
209 | for label_list in labels:
210 | multihot_vectors.append([1 if x in label_list else 0 for x in label_set])
211 | if dtype is None:
212 | return pd.DataFrame(multihot_vectors, columns=label_set)
213 | return torch.Tensor(multihot_vectors).to(dtype)
214 | X_train = torch.tensor(X_train.transpose(0, 2, 1))
215 | # mean = X_train.mean(dim=-1, keepdim=True)
216 | # var = X_train.var(dim=-1, keepdim=True)
217 | # X_train = (X_train - mean) / (var + 1.e-6)**.5
218 | X_test = torch.tensor(X_test.transpose(0, 2, 1))
219 | # mean = X_test.mean(dim=-1, keepdim=True)
220 | # var = X_test.var(dim=-1, keepdim=True)
221 | # X_test = (X_test - mean) / (var + 1.e-6)**.5
222 | 
223 | y_train = multihot_encoder(y_train, n_categories=5)
224 | y_test = multihot_encoder(y_test, n_categories=5)
225 | dataset_train = torch.utils.data.TensorDataset(X_train[:, None, :, :], y_train)  # already tensors; re-wrapping with torch.tensor() would copy and warn
226 | dataset_val = torch.utils.data.TensorDataset(X_test[:, None, :, :], y_test)
227 | 
228 | # full_dataset = CustomDataset(args.data_path, args.train_start, args.train_end) # Training Data -
229 | # train_size = int(args.data_split * len(full_dataset))
230 | # val_size = len(full_dataset) - train_size
231 | # dataset_train, dataset_val = torch.utils.data.random_split(full_dataset, [train_size, val_size])
232 | 
233 | num_tasks = misc.get_world_size()
234 | global_rank = misc.get_rank()
235 | if args.distributed:  # was `args.distributed is not None`, which is always True for a bool
236 | sampler_train = torch.utils.data.DistributedSampler(
237 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
238 | )
239 | print("Sampler_train = %s" % str(sampler_train))
240 | if args.dist_eval:
241 | if len(dataset_val) % num_tasks != 0:
242 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 243 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 244 | 'equal num of samples per-process.') 245 | sampler_val = torch.utils.data.DistributedSampler( 246 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias 247 | else: 248 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 249 | else: 250 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 251 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 252 | 253 | if global_rank == 0 and args.log_dir is not None and not args.eval: 254 | os.makedirs(args.log_dir, exist_ok=True) 255 | log_writer = SummaryWriter(log_dir=args.log_dir) 256 | else: 257 | log_writer = None 258 | 259 | data_loader_train = torch.utils.data.DataLoader( 260 | dataset_train, sampler=sampler_train, 261 | batch_size=args.batch_size, 262 | num_workers=args.num_workers, 263 | pin_memory=args.pin_mem, 264 | drop_last=True, 265 | ) 266 | 267 | data_loader_val = torch.utils.data.DataLoader( 268 | dataset_val, sampler=sampler_val, 269 | batch_size=args.batch_size, 270 | num_workers=args.num_workers, 271 | pin_memory=args.pin_mem, 272 | drop_last=False 273 | ) 274 | 275 | model = vit_model.__dict__[args.model]( 276 | num_classes=args.nb_classes, 277 | global_pool=args.global_pool, 278 | ) 279 | 280 | if args.finetune and not args.eval: 281 | checkpoint = torch.load(args.finetune, map_location='cpu') 282 | 283 | print("Load pre-trained checkpoint from: %s" % args.finetune) 284 | checkpoint_model = checkpoint['model'] 285 | state_dict = model.state_dict() 286 | for k in ['head.weight', 'head.bias']: 287 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 288 | print(f"Removing key {k} from pretrained checkpoint") 289 | del checkpoint_model[k] 290 | # interpolate position embedding 291 | interpolate_pos_embed(model, checkpoint_model) 292 | 293 | # load pre-trained model 294 | msg = model.load_state_dict(checkpoint_model, strict=False) 295 | print(msg) 296 | 297 | if args.global_pool: 298 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 299 | else: 300 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 301 | 302 | # manually initialize fc layer: following MoCo v3 303 | trunc_normal_(model.head.weight, std=0.01) 304 | 305 | # for linear prob only 306 | # hack: revise model's head with BN 307 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 308 | 309 | # freeze all but the head 310 | if(args.mode == "linprobe"): 311 | for _, p in model.named_parameters(): 312 | p.requires_grad = False 313 | for _, p in model.head.named_parameters(): 314 | p.requires_grad = True 315 | model = model.double() 316 | if args.cuda is not None: 317 | model.to(device) 318 | 319 | model_without_ddp = model 320 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 321 | 322 | print("Model = %s" % str(model_without_ddp)) 323 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 324 | 325 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 326 | 327 | if args.lr is None: # only base_lr is specified 328 | args.lr = args.blr * eff_batch_size / 256 329 | 330 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 331 | print("actual lr: %.2e" % args.lr) 332 | 333 | print("accumulate grad iterations: %d" % args.accum_iter) 334 | print("effective batch size: %d" % 
eff_batch_size)
335 | 
336 | if args.distributed:
337 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
338 | model_without_ddp = model.module
339 | 
340 | optimizer = torch.optim.AdamW(model_without_ddp.head.parameters(), lr=args.lr, weight_decay=args.weight_decay)
341 | print(optimizer)
342 | loss_scaler = NativeScaler()
343 | 
344 | criterion = torch.nn.BCEWithLogitsLoss()
345 | 
346 | print("criterion = %s" % str(criterion))
347 | 
348 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
349 | 
350 | if args.eval:
351 | test_stats = evaluate(data_loader_val, model, device, args)
352 | print(f"Accuracy of the network on the {len(dataset_val)} test ECGs: {test_stats['acc1']:.1f}%")
353 | exit(0)
354 | 
355 | print(f"Start training for {args.epochs} epochs")
356 | start_time = time.time()
357 | max_accuracy = 0.0
358 | for epoch in range(args.start_epoch, args.epochs):
359 | if args.distributed:
360 | data_loader_train.sampler.set_epoch(epoch)
361 | train_stats = train_one_epoch(
362 | model, criterion, data_loader_train,
363 | optimizer, device, epoch, loss_scaler,
364 | max_norm=None,
365 | log_writer=log_writer,
366 | args=args
367 | )
368 | if args.output_dir:
369 | misc.save_model(
370 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
371 | loss_scaler=loss_scaler, epoch=epoch)
372 | 
373 | test_stats = evaluate(data_loader_val, model, device, args)
374 | print(f"Accuracy of the network on the {len(dataset_val)} test ECGs: {test_stats['acc1']:.1f}%")
375 | max_accuracy = max(max_accuracy, test_stats["acc1"])
376 | print(f'Max accuracy: {max_accuracy:.2f}%')
377 | 
378 | if log_writer is not None:
379 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
380 | # log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
381 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
382 | 
383 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
384 | **{f'test_{k}': v for k, v in test_stats.items()},
385 | 'epoch': epoch,
386 | 'n_parameters': n_parameters}
387 | 
388 | if args.output_dir and misc.is_main_process():
389 | if log_writer is not None:
390 | log_writer.flush()
391 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
392 | f.write(json.dumps(log_stats) + "\n")
393 | 
394 | total_time = time.time() - start_time
395 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
396 | print('Training time {}'.format(total_time_str))
397 | 
398 | 
399 | if __name__ == '__main__':
400 | args = get_args_parser()
401 | args = args.parse_args()
402 | if args.output_dir:
403 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
404 | main(args)
405 | 
-------------------------------------------------------------------------------- /utils/utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import re
4 | import glob
5 | import pickle
6 | import copy
7 | 
8 | import pandas as pd
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | from tqdm import tqdm
12 | import wfdb
13 | import ast
14 | from sklearn.metrics import fbeta_score, roc_auc_score, roc_curve, auc
15 | from sklearn.preprocessing import StandardScaler, MultiLabelBinarizer
16 | from matplotlib.axes._axes import _log as matplotlib_axes_logger
17 | import warnings
18 | 
19 | # EVALUATION STUFF
20 | def generate_results(idxs, y_true, y_pred, thresholds):
21 | return evaluate_experiment(y_true[idxs], y_pred[idxs], thresholds)
22 | 
23 | def evaluate_experiment(y_true, y_pred, thresholds=None):
24 | results = {}
25 | 
26 | if thresholds is not None:
27 | # binary predictions
28 | y_pred_binary = apply_thresholds(y_pred, thresholds)
29 | # PhysioNet/CinC Challenges metrics
30 | challenge_scores = challenge_metrics(y_true, y_pred_binary, beta1=2, beta2=2)
31 | results['F_beta_macro'] = challenge_scores['F_beta_macro']
32 | results['G_beta_macro'] = challenge_scores['G_beta_macro']
33 | 
34 | # label based metric
35 | results['macro_auc'] = roc_auc_score(y_true, y_pred, average='macro')
36 | 
37 | df_result = pd.DataFrame(results, index=[0])
38 | return df_result
39 | 
40 | def challenge_metrics(y_true, y_pred, beta1=2, beta2=2, class_weights=None, single=False):
41 | f_beta = 0
42 | g_beta = 0
43 | if single: # if evaluating single class in case of threshold-optimization
44 | sample_weights = np.ones(y_true.sum(axis=1).shape)
45 | else:
46 | sample_weights = y_true.sum(axis=1)
47 | for classi in range(y_true.shape[1]):
48 | y_truei, y_predi = y_true[:,classi], y_pred[:,classi]
49 | TP, FP, TN, FN = 0.,0.,0.,0.
50 | for i in range(len(y_predi)):
51 | sample_weight = sample_weights[i]
52 | if y_truei[i]==y_predi[i]==1:
53 | TP += 1./sample_weight
54 | if ((y_predi[i]==1) and (y_truei[i]!=y_predi[i])):
55 | FP += 1./sample_weight
56 | if y_truei[i]==y_predi[i]==0:
57 | TN += 1./sample_weight
58 | if ((y_predi[i]==0) and (y_truei[i]!=y_predi[i])):
59 | FN += 1./sample_weight
60 | f_beta_i = ((1+beta1**2)*TP)/((1+beta1**2)*TP + FP + (beta1**2)*FN)
61 | g_beta_i = (TP)/(TP+FP+beta2*FN)
62 | 
63 | f_beta += f_beta_i
64 | g_beta += g_beta_i
65 | 
66 | return {'F_beta_macro':f_beta/y_true.shape[1], 'G_beta_macro':g_beta/y_true.shape[1]}
67 | 
68 | def get_appropriate_bootstrap_samples(y_true, n_bootstraping_samples):
69 | samples=[]
70 | while True:
71 | ridxs = np.random.randint(0, len(y_true), len(y_true))
72 | if y_true[ridxs].sum(axis=0).min() != 0:
73 | samples.append(ridxs)
74 | if len(samples) == n_bootstraping_samples:
75 | break
76 | return samples
77 | 
78 | def find_optimal_cutoff_threshold(target, predicted):
79 | """
80 | Find the optimal probability cutoff for a classifier by maximizing Youden's J statistic (TPR - FPR) over the ROC curve
81 | """
82 | fpr, tpr, threshold = roc_curve(target, predicted)
83 | optimal_idx = np.argmax(tpr - fpr)
84 | optimal_threshold = threshold[optimal_idx]
85 | return optimal_threshold
86 | 
87 | def find_optimal_cutoff_thresholds(y_true, y_pred):
88 | return [find_optimal_cutoff_threshold(y_true[:,i], y_pred[:,i]) for i in range(y_true.shape[1])]
89 | 
90 | def find_optimal_cutoff_threshold_for_Gbeta(target, predicted, n_thresholds=100):
91 | thresholds = np.linspace(0.00,1,n_thresholds)
92 | scores = [challenge_metrics(target, predicted>t, single=True)['G_beta_macro'] for t in thresholds]
93 | optimal_idx = np.argmax(scores)
94 | return thresholds[optimal_idx]
95 | 
96 | def find_optimal_cutoff_thresholds_for_Gbeta(y_true, y_pred):
97 | print("optimize thresholds with respect to G_beta")
98 | return [find_optimal_cutoff_threshold_for_Gbeta(y_true[:,k][:,np.newaxis], y_pred[:,k][:,np.newaxis]) for k in tqdm(range(y_true.shape[1]))]
99 | 
100 | def apply_thresholds(preds, thresholds):
101 | """
102 | apply class-wise thresholds to prediction scores in order to get binary predictions.
103 | BUT: if no score is above its threshold, pick the class with the maximum score. This is needed because some metrics are undefined for all-zero predictions. 
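Returns an int array of shape (n_samples, n_classes) with entries in {0, 1}.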
104 | """ 105 | tmp = [] 106 | for p in preds: 107 | tmp_p = (p > thresholds).astype(int) 108 | if np.sum(tmp_p) == 0: 109 | tmp_p[np.argmax(p)] = 1 110 | tmp.append(tmp_p) 111 | tmp = np.array(tmp) 112 | return tmp 113 | 114 | # DATA PROCESSING STUFF 115 | 116 | def load_dataset(path, sampling_rate, release=False): 117 | # if path.split('/')[-1] == 'ptb_xl': 118 | # load and convert annotation data 119 | Y = pd.read_csv(path+'ptbxl_database.csv', index_col='ecg_id') 120 | Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) 121 | 122 | # Load raw signal data 123 | X = load_raw_data_ptbxl(Y, sampling_rate, path) 124 | 125 | # elif path.split('/')[-2] == 'ICBEB': 126 | # # load and convert annotation data 127 | # Y = pd.read_csv(path+'icbeb_database.csv', index_col='ecg_id') 128 | # Y.scp_codes = Y.scp_codes.apply(lambda x: ast.literal_eval(x)) 129 | 130 | # # Load raw signal data 131 | # X = load_raw_data_icbeb(Y, sampling_rate, path) 132 | 133 | return X, Y 134 | 135 | 136 | def load_raw_data_icbeb(df, sampling_rate, path): 137 | 138 | if sampling_rate == 100: 139 | if os.path.exists(path + 'raw100.npy'): 140 | data = np.load(path+'raw100.npy', allow_pickle=True) 141 | else: 142 | data = [wfdb.rdsamp(path + 'records100/'+str(f)) for f in tqdm(df.index)] 143 | data = np.array([signal for signal, meta in data]) 144 | pickle.dump(data, open(path+'raw100.npy', 'wb'), protocol=4) 145 | elif sampling_rate == 500: 146 | if os.path.exists(path + 'raw500.npy'): 147 | data = np.load(path+'raw500.npy', allow_pickle=True) 148 | else: 149 | data = [wfdb.rdsamp(path + 'records500/'+str(f)) for f in tqdm(df.index)] 150 | data = np.array([signal for signal, meta in data]) 151 | pickle.dump(data, open(path+'raw500.npy', 'wb'), protocol=4) 152 | return data 153 | 154 | def load_raw_data_ptbxl(df, sampling_rate, path): 155 | if sampling_rate == 100: 156 | if os.path.exists(path + 'raw100.npy'): 157 | data = np.load(path+'raw100.npy', allow_pickle=True) 158 | else: 159 | data = [wfdb.rdsamp(path+f) for f in tqdm(df.filename_lr)] 160 | data = np.array([signal for signal, meta in data]) 161 | pickle.dump(data, open(path+'raw100.npy', 'wb'), protocol=4) 162 | elif sampling_rate == 500: 163 | if os.path.exists(path + 'raw500.npy'): 164 | data = np.load(path+'raw500.npy', allow_pickle=True) 165 | else: 166 | data = [wfdb.rdsamp(path+f) for f in tqdm(df.filename_hr)] 167 | data = np.array([signal for signal, meta in data]) 168 | pickle.dump(data, open(path+'raw500.npy', 'wb'), protocol=4) 169 | return data 170 | 171 | def compute_label_aggregations(df, folder, ctype): 172 | 173 | df['scp_codes_len'] = df.scp_codes.apply(lambda x: len(x)) 174 | 175 | aggregation_df = pd.read_csv(folder+'scp_statements.csv', index_col=0) 176 | 177 | if ctype in ['diagnostic', 'subdiagnostic', 'superdiagnostic']: 178 | 179 | def aggregate_all_diagnostic(y_dic): 180 | tmp = [] 181 | for key in y_dic.keys(): 182 | if key in diag_agg_df.index: 183 | tmp.append(key) 184 | return list(set(tmp)) 185 | 186 | def aggregate_subdiagnostic(y_dic): 187 | tmp = [] 188 | for key in y_dic.keys(): 189 | if key in diag_agg_df.index: 190 | c = diag_agg_df.loc[key].diagnostic_subclass 191 | if str(c) != 'nan': 192 | tmp.append(c) 193 | return list(set(tmp)) 194 | 195 | def aggregate_diagnostic(y_dic): 196 | tmp = [] 197 | for key in y_dic.keys(): 198 | if key in diag_agg_df.index: 199 | c = diag_agg_df.loc[key].diagnostic_class 200 | if str(c) != 'nan': 201 | tmp.append(c) 202 | return list(set(tmp)) 203 | 204 | diag_agg_df = 
aggregation_df[aggregation_df.diagnostic == 1.0] 205 | if ctype == 'diagnostic': 206 | df['diagnostic'] = df.scp_codes.apply(aggregate_all_diagnostic) 207 | df['diagnostic_len'] = df.diagnostic.apply(lambda x: len(x)) 208 | elif ctype == 'subdiagnostic': 209 | df['subdiagnostic'] = df.scp_codes.apply(aggregate_subdiagnostic) 210 | df['subdiagnostic_len'] = df.subdiagnostic.apply(lambda x: len(x)) 211 | elif ctype == 'superdiagnostic': 212 | df['superdiagnostic'] = df.scp_codes.apply(aggregate_diagnostic) 213 | df['superdiagnostic_len'] = df.superdiagnostic.apply(lambda x: len(x)) 214 | elif ctype == 'form': 215 | form_agg_df = aggregation_df[aggregation_df.form == 1.0] 216 | 217 | def aggregate_form(y_dic): 218 | tmp = [] 219 | for key in y_dic.keys(): 220 | if key in form_agg_df.index: 221 | c = key 222 | if str(c) != 'nan': 223 | tmp.append(c) 224 | return list(set(tmp)) 225 | 226 | df['form'] = df.scp_codes.apply(aggregate_form) 227 | df['form_len'] = df.form.apply(lambda x: len(x)) 228 | elif ctype == 'rhythm': 229 | rhythm_agg_df = aggregation_df[aggregation_df.rhythm == 1.0] 230 | 231 | def aggregate_rhythm(y_dic): 232 | tmp = [] 233 | for key in y_dic.keys(): 234 | if key in rhythm_agg_df.index: 235 | c = key 236 | if str(c) != 'nan': 237 | tmp.append(c) 238 | return list(set(tmp)) 239 | 240 | df['rhythm'] = df.scp_codes.apply(aggregate_rhythm) 241 | df['rhythm_len'] = df.rhythm.apply(lambda x: len(x)) 242 | elif ctype == 'all': 243 | df['all_scp'] = df.scp_codes.apply(lambda x: list(set(x.keys()))) 244 | 245 | return df 246 | 247 | def select_data(XX,YY, ctype, min_samples, outputfolder): 248 | # convert multilabel to multi-hot 249 | mlb = MultiLabelBinarizer() 250 | 251 | if ctype == 'diagnostic': 252 | X = XX[YY.diagnostic_len > 0] 253 | Y = YY[YY.diagnostic_len > 0] 254 | mlb.fit(Y.diagnostic.values) 255 | y = mlb.transform(Y.diagnostic.values) 256 | elif ctype == 'subdiagnostic': 257 | counts = pd.Series(np.concatenate(YY.subdiagnostic.values)).value_counts() 258 | counts = counts[counts > min_samples] 259 | YY.subdiagnostic = YY.subdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 260 | YY['subdiagnostic_len'] = YY.subdiagnostic.apply(lambda x: len(x)) 261 | X = XX[YY.subdiagnostic_len > 0] 262 | Y = YY[YY.subdiagnostic_len > 0] 263 | mlb.fit(Y.subdiagnostic.values) 264 | y = mlb.transform(Y.subdiagnostic.values) 265 | elif ctype == 'superdiagnostic': 266 | counts = pd.Series(np.concatenate(YY.superdiagnostic.values)).value_counts() 267 | counts = counts[counts > min_samples] 268 | YY.superdiagnostic = YY.superdiagnostic.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 269 | YY['superdiagnostic_len'] = YY.superdiagnostic.apply(lambda x: len(x)) 270 | X = XX[YY.superdiagnostic_len > 0] 271 | Y = YY[YY.superdiagnostic_len > 0] 272 | mlb.fit(Y.superdiagnostic.values) 273 | y = mlb.transform(Y.superdiagnostic.values) 274 | elif ctype == 'form': 275 | # filter 276 | counts = pd.Series(np.concatenate(YY.form.values)).value_counts() 277 | counts = counts[counts > min_samples] 278 | YY.form = YY.form.apply(lambda x: list(set(x).intersection(set(counts.index.values)))) 279 | YY['form_len'] = YY.form.apply(lambda x: len(x)) 280 | # select 281 | X = XX[YY.form_len > 0] 282 | Y = YY[YY.form_len > 0] 283 | mlb.fit(Y.form.values) 284 | y = mlb.transform(Y.form.values) 285 | elif ctype == 'rhythm': 286 | # filter 287 | counts = pd.Series(np.concatenate(YY.rhythm.values)).value_counts() 288 | counts = counts[counts > min_samples] 289 | 
YY.rhythm = YY.rhythm.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
290 | YY['rhythm_len'] = YY.rhythm.apply(lambda x: len(x))
291 | # select
292 | X = XX[YY.rhythm_len > 0]
293 | Y = YY[YY.rhythm_len > 0]
294 | mlb.fit(Y.rhythm.values)
295 | y = mlb.transform(Y.rhythm.values)
296 | elif ctype == 'all':
297 | # filter
298 | counts = pd.Series(np.concatenate(YY.all_scp.values)).value_counts()
299 | counts = counts[counts > min_samples]
300 | YY.all_scp = YY.all_scp.apply(lambda x: list(set(x).intersection(set(counts.index.values))))
301 | YY['all_scp_len'] = YY.all_scp.apply(lambda x: len(x))
302 | # select
303 | X = XX[YY.all_scp_len > 0]
304 | Y = YY[YY.all_scp_len > 0]
305 | mlb.fit(Y.all_scp.values)
306 | y = mlb.transform(Y.all_scp.values)
307 | else:
308 | raise ValueError(f'Unknown ctype: {ctype}')
309 | 
310 | # save LabelBinarizer
311 | # with open(outputfolder+'mlb.pkl', 'wb') as tokenizer:
312 | # pickle.dump(mlb, tokenizer)
313 | 
314 | return X, Y, y, mlb
315 | 
316 | def preprocess_signals(X_train, X_validation, X_test, outputfolder):
317 | # Standardize data to zero mean and unit variance
318 | ss = StandardScaler()
319 | ss.fit(np.vstack(X_train).flatten()[:,np.newaxis].astype(float))
320 | 
321 | # Save Standardizer data
322 | with open(outputfolder+'standard_scaler.pkl', 'wb') as ss_file:
323 | pickle.dump(ss, ss_file)
324 | 
325 | return apply_standardizer(X_train, ss), apply_standardizer(X_validation, ss), apply_standardizer(X_test, ss)
326 | 
327 | def apply_standardizer(X, ss):
328 | X_tmp = []
329 | for x in X:
330 | x_shape = x.shape
331 | X_tmp.append(ss.transform(x.flatten()[:,np.newaxis]).reshape(x_shape))
332 | X_tmp = np.array(X_tmp)
333 | return X_tmp
334 | 
335 | 
336 | # DOCUMENTATION STUFF
337 | 
338 | def generate_ptbxl_summary_table(selection=None, folder='../output/'):
339 | 
340 | exps = ['exp0', 'exp1', 'exp1.1', 'exp1.1.1', 'exp2', 'exp3']
341 | metric1 = 'macro_auc'
342 | 
343 | # get models
344 | models = {}
345 | for i, exp in enumerate(exps):
346 | if selection is None:
347 | exp_models = [m.split('/')[-1] for m in glob.glob(folder+str(exp)+'/models/*')]
348 | else:
349 | exp_models = selection
350 | if i == 0:
351 | models = set(exp_models)
352 | else:
353 | models = models.union(set(exp_models))
354 | 
355 | results_dic = {'Method':[],
356 | 'exp0_AUC':[],
357 | 'exp1_AUC':[],
358 | 'exp1.1_AUC':[],
359 | 'exp1.1.1_AUC':[],
360 | 'exp2_AUC':[],
361 | 'exp3_AUC':[]
362 | }
363 | 
364 | for m in models:
365 | results_dic['Method'].append(m)
366 | 
367 | for e in exps:
368 | 
369 | try:
370 | me_res = pd.read_csv(folder+str(e)+'/models/'+str(m)+'/results/te_results.csv', index_col=0)
371 | 
372 | mean1 = me_res.loc['point'][metric1]
373 | unc1 = max(me_res.loc['upper'][metric1]-me_res.loc['point'][metric1], me_res.loc['point'][metric1]-me_res.loc['lower'][metric1])
374 | 
375 | results_dic[e+'_AUC'].append("%.3f(%.2d)" %(np.round(mean1,3), int(unc1*1000)))
376 | 
377 | except FileNotFoundError:
378 | results_dic[e+'_AUC'].append("--")
379 | 
380 | 
381 | df = pd.DataFrame(results_dic)
382 | df_index = df[df.Method.isin(['naive', 'ensemble'])]
383 | df_rest = df[~df.Method.isin(['naive', 'ensemble'])]
384 | df = pd.concat([df_rest, df_index])
385 | df.to_csv(folder+'results_ptbxl.csv')
386 | 
387 | titles = [
388 | '### 1. PTB-XL: all statements',
389 | '### 2. PTB-XL: diagnostic statements',
390 | '### 3. PTB-XL: Diagnostic subclasses',
391 | '### 4. PTB-XL: Diagnostic superclasses',
392 | '### 5. PTB-XL: Form statements',
393 | '### 6. 
PTB-XL: Rhythm statements'
394 | ]
395 | 
396 | # helper output function for markdown tables
397 | our_work = 'https://arxiv.org/abs/2004.13701'
398 | our_repo = 'https://github.com/helme/ecg_ptbxl_benchmarking/'
399 | md_source = ''
400 | for i, e in enumerate(exps):
401 | md_source += '\n '+titles[i]+' \n \n'
402 | md_source += '| Model | AUC ↓ | paper/source | code | \n'
403 | md_source += '|---:|:---|:---|:---| \n'
404 | for row in df_rest[['Method', e+'_AUC']].sort_values(e+'_AUC', ascending=False).values:
405 | md_source += '| ' + row[0].replace('fastai_', '') + ' | ' + row[1] + ' | [our work]('+our_work+') | [this repo]('+our_repo+')| \n'
406 | print(md_source)
407 | 
408 | def ICBEBE_table(selection=None, folder='../output/'):
409 | cols = ['macro_auc', 'F_beta_macro', 'G_beta_macro']
410 | 
411 | if selection is None:
412 | models = [m.split('/')[-1].split('_pretrained')[0] for m in glob.glob(folder+'exp_ICBEB/models/*')]
413 | else:
414 | models = []
415 | for s in selection:
416 | #if s != 'Wavelet+NN':
417 | models.append(s)
418 | 
419 | data = []
420 | for model in models:
421 | me_res = pd.read_csv(folder+'exp_ICBEB/models/'+model+'/results/te_results.csv', index_col=0)
422 | mcol = []
423 | for col in cols:
424 | mean = me_res.loc['point'][col]
425 | unc = max(me_res.loc['upper'][col]-me_res.loc['point'][col], me_res.loc['point'][col]-me_res.loc['lower'][col])
426 | mcol.append("%.3f(%.2d)" %(np.round(mean,3), int(unc*1000)))
427 | data.append(mcol)
428 | data = np.array(data)
429 | 
430 | df = pd.DataFrame(data, columns=cols, index=models)
431 | df.to_csv(folder+'results_icbeb.csv')
432 | 
433 | df_rest = df[~df.index.isin(['naive', 'ensemble'])]
434 | df_rest = df_rest.sort_values('macro_auc', ascending=False)
435 | our_work = 'https://arxiv.org/abs/2004.13701'
436 | our_repo = 'https://github.com/helme/ecg_ptbxl_benchmarking/'
437 | 
438 | md_source = '| Model | AUC ↓ | F_beta=2 | G_beta=2 | paper/source | code | \n'
439 | md_source += '|---:|:---|:---|:---|:---|:---| \n'
440 | for i, row in enumerate(df_rest[cols].values):
441 | md_source += '| ' + df_rest.index[i].replace('fastai_', '') + ' | ' + row[0] + ' | ' + row[1] + ' | ' + row[2] + ' | [our work]('+our_work+') | [this repo]('+our_repo+')| \n'
442 | print(md_source)
443 | 
-------------------------------------------------------------------------------- /main_finetune.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree. 
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | 
12 | import argparse
13 | import pandas as pd
14 | import datetime
15 | import json
16 | import numpy as np
17 | from utils import utils
18 | 
19 | import ast
20 | import wfdb
21 | 
22 | 
23 | import os
24 | import time
25 | from pathlib import Path
26 | import glob
27 | import torch
28 | from wfdb import processing
29 | import re
30 | import torch.backends.cudnn as cudnn
31 | from torch.utils.tensorboard import SummaryWriter
32 | import timm
33 | 
34 | # assert timm.__version__ == "0.3.2" # version check
35 | from timm.models.layers import trunc_normal_
36 | from timm.data.mixup import Mixup
37 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
38 | 
39 | import utils.lr_decay as lrd
40 | import utils.misc as misc
41 | from utils.datasets import build_dataset
42 | from utils.pos_embed import interpolate_pos_embed
43 | from utils.misc import NativeScalerWithGradNormCount as NativeScaler
44 | 
45 | import vit_model
46 | from utils.ecg_dataloader import CustomDataset
47 | from engine_finetune import train_one_epoch, evaluate
48 | 
49 | 
50 | def get_args_parser():
51 | parser = argparse.ArgumentParser('MAE fine-tuning for ECG classification', add_help=False)
52 | parser.add_argument('--batch_size', default=64, type=int,
53 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
54 | parser.add_argument('--epochs', default=50, type=int)
55 | parser.add_argument('--accum_iter', default=1, type=int,
56 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
57 | 
58 | # Model parameters
59 | parser.add_argument('--model', default='vit_1dcnn', type=str, metavar='MODEL',
60 | help='Name of model to train')
61 | 
62 | parser.add_argument('--input_size', default=224, type=int,
63 | help='input size')
64 | 
65 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT',
66 | help='Drop path rate (default: 0.1)')
67 | 
68 | # Optimizer parameters
69 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
70 | help='Clip gradient norm (default: None, no clipping)')
71 | 
72 | parser.add_argument('--weight_decay', type=float, default=0.05,
73 | help='weight decay (default: 0.05)')
74 | 
75 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
76 | help='learning rate (absolute lr)')
77 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
78 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
79 | parser.add_argument('--layer_decay', type=float, default=0.75,
80 | help='layer-wise lr decay from ELECTRA/BEiT')
81 | 
82 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
83 | help='lower lr bound for cyclic schedulers that hit 0')
84 | 
85 | parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N',
86 | help='epochs to warmup LR')
87 | 
88 | # Augmentation parameters
89 | parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT',
90 | help='Color jitter factor (enabled only when not using Auto/RandAug)')
91 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
92 | help='Use AutoAugment policy. 
"v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)'), 93 | parser.add_argument('--smoothing', type=float, default=0.1, 94 | help='Label smoothing (default: 0.1)') 95 | 96 | # * Random Erase params 97 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 98 | help='Random erase prob (default: 0.25)') 99 | parser.add_argument('--remode', type=str, default='pixel', 100 | help='Random erase mode (default: "pixel")') 101 | parser.add_argument('--recount', type=int, default=1, 102 | help='Random erase count (default: 1)') 103 | parser.add_argument('--resplit', action='store_true', default=False, 104 | help='Do not random erase first (clean) augmentation split') 105 | 106 | # * Mixup params 107 | parser.add_argument('--mixup', type=float, default=0, 108 | help='mixup alpha, mixup enabled if > 0.') 109 | parser.add_argument('--cutmix', type=float, default=0, 110 | help='cutmix alpha, cutmix enabled if > 0.') 111 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 112 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 113 | parser.add_argument('--mixup_prob', type=float, default=1.0, 114 | help='Probability of performing mixup or cutmix when either/both is enabled') 115 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 116 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 117 | parser.add_argument('--mixup_mode', type=str, default='batch', 118 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 119 | # * Data Set 120 | 121 | parser.add_argument('--val_start',type=int, default= 1, 122 | help='validation start') 123 | parser.add_argument('--val_end',type=int, default=30, 124 | help='validation end') 125 | parser.add_argument('--train_start',type=int, default=31, 126 | help='train start') 127 | parser.add_argument('--train_end',type=int, default=40, 128 | help='train end') 129 | parser.add_argument('--data',type=str, default=" ", 130 | help='Which dataset') 131 | parser.add_argument('--classf_type',type=str, default="multi_label", 132 | help='Which Classification') 133 | parser.add_argument('--mode',type=str, default="finetune", 134 | help='Which Classification') 135 | 136 | 137 | 138 | # * Finetuning params 139 | parser.add_argument('--finetune', default='', 140 | help='finetune from checkpoint') 141 | parser.add_argument('--global_pool', action='store_true') 142 | parser.set_defaults(global_pool=True) 143 | parser.add_argument('--cls_token', action='store_false', dest='global_pool', 144 | help='Use class token instead of global pool for classification') 145 | 146 | # Dataset parameters 147 | parser.add_argument('--data_path', default='/Users/parthagrawal02/Desktop/Carelog/ECG_CNN/physionet', type=str, 148 | help='dataset path') 149 | parser.add_argument('--nb_classes', default=10, type=int, 150 | help='number of the classification types') 151 | 152 | parser.add_argument('--output_dir', default='./output_dir_fin', 153 | help='path where to save, empty for no saving') 154 | parser.add_argument('--log_dir', default='./output_dir_fin', 155 | help='path where to tensorboard log') 156 | parser.add_argument('--device', default='cuda', 157 | help='device to use for training / testing') 158 | parser.add_argument('--seed', default=0, type=int) 159 | parser.add_argument('--resume', default='', 160 | help='resume from checkpoint') 161 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 162 | help='start epoch') 163 | 
parser.add_argument('--eval', action='store_true',
164 | help='Perform evaluation only')
165 | parser.add_argument('--dist_eval', action='store_true', default=False,
166 | help='Enabling distributed evaluation (recommended during training for faster monitoring)')
167 | parser.add_argument('--num_workers', default=2, type=int)
168 | parser.add_argument('--pin_mem', action='store_true',
169 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
170 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
171 | parser.set_defaults(pin_mem=True)
172 | 
173 | # distributed training parameters
174 | 
175 | parser.add_argument('--world_size', default=1, type=int,
176 | help='number of distributed processes')
177 | parser.add_argument('--local_rank', default=-1, type=int)
178 | parser.add_argument('--dist_on_itp', action='store_true')
179 | parser.add_argument('--dist_url', default='env://',
180 | help='url used to set up distributed training')
181 | parser.add_argument('--distributed', default=None,
182 | help='set to any value to enable distributed training (default: single process)')
183 | parser.add_argument('--cuda', default=None,
184 | help='set to any value to run on CUDA and enable cudnn benchmarking')
185 | parser.add_argument('--data_split', default=0.8, type=float,
186 | help='fraction of the custom dataset used for training; the rest is validation')
187 | parser.add_argument('--task', default='superdiagnostic', type=str,
188 | help='PTB-XL label aggregation: diagnostic, subdiagnostic, superdiagnostic, form, rhythm or all')
189 | 
190 | 
191 | 
192 | return parser
193 | 
194 | 
195 | def main(args):
196 | misc.init_distributed_mode(args)
197 | 
198 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
199 | print("{}".format(args).replace(', ', ',\n'))
200 | 
201 | device = torch.device(args.device)
202 | 
203 | # fix the seed for reproducibility
204 | seed = args.seed + misc.get_rank()
205 | torch.manual_seed(seed)
206 | np.random.seed(seed)
207 | 
208 | if args.cuda is not None:
209 | cudnn.benchmark = True
210 | 
211 | if args.data == "PTB":
212 | def load_raw_data(df, sampling_rate, path):
213 | if sampling_rate == 100:
214 | data = [wfdb.rdsamp(path + f) for f in df.filename_lr]
215 | else:
216 | data = [wfdb.rdsamp(path + f) for f in df.filename_hr]
217 | data = np.array([signal for signal, meta in data])
218 | return data
219 | 
220 | 
221 | sampling_frequency = 100
222 | datafolder = args.data_path
223 | task = args.task
224 | outputfolder = '/output/'
225 | 
226 | # Load PTB-XL data
227 | raw_labels = pd.read_csv(datafolder+'ptbxl_database.csv', index_col='ecg_id')
228 | raw_labels.scp_codes = raw_labels.scp_codes.apply(lambda x: ast.literal_eval(x))
229 | 
230 | # Load raw signal data
231 | data = load_raw_data(raw_labels, sampling_frequency, datafolder)
232 | 
233 | # data, raw_labels = utils.load_dataset(datafolder, sampling_frequency)
234 | # Preprocess label data
235 | labels = utils.compute_label_aggregations(raw_labels, datafolder, task)
236 | # Select relevant data and convert to one-hot
237 | data, labels, Y, _ = utils.select_data(data, labels, task, min_samples=0, outputfolder=outputfolder)
238 | 
239 | # stratified folds 1-9 for training
240 | X_train = data[labels.strat_fold < 10]
241 | y_train = Y[labels.strat_fold < 10]
242 | # fold 10 for validation
243 | X_test = data[labels.strat_fold == 10]
244 | y_test = Y[labels.strat_fold == 10]
245 | 
246 | 
247 | X_train = torch.tensor(X_train.transpose(0, 2, 1))
248 | mean = X_train.mean(dim=-1, keepdim=True)
249 | var = X_train.var(dim=-1, keepdim=True)
250 | X_train = (X_train - mean) / (var + 1.e-6)**.5
251 | X_test = 
torch.tensor(X_test.transpose(0, 2, 1))
252 | mean = X_test.mean(dim=-1, keepdim=True)
253 | var = X_test.var(dim=-1, keepdim=True)
254 | X_test = (X_test - mean) / (var + 1.e-6)**.5
255 | dataset_train = torch.utils.data.TensorDataset(X_train[:, None, :, :].double(), torch.tensor(y_train).double())
256 | dataset_val = torch.utils.data.TensorDataset(X_test[:, None, :, :].double(), torch.tensor(y_test).double())
257 | 
258 | else:
259 | 
260 | full_dataset = CustomDataset(args.data_path, args.train_start, args.train_end) # Training data
261 | train_size = int(args.data_split * len(full_dataset))
262 | val_size = len(full_dataset) - train_size
263 | dataset_train, dataset_val = torch.utils.data.random_split(full_dataset, [train_size, val_size])
264 | 
265 | 
266 | 
267 | 
268 | if args.distributed is not None:
269 | num_tasks = misc.get_world_size()
270 | global_rank = misc.get_rank()
271 | sampler_train = torch.utils.data.DistributedSampler(
272 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
273 | )
274 | print("Sampler_train = %s" % str(sampler_train))
275 | if args.dist_eval:
276 | if len(dataset_val) % num_tasks != 0:
277 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
278 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
279 | 'equal num of samples per-process.')
280 | sampler_val = torch.utils.data.DistributedSampler(
281 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True) # shuffle=True to reduce monitor bias
282 | else:
283 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
284 | else:
285 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
286 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
287 | 
288 | if args.log_dir is not None and not args.eval:
289 | os.makedirs(args.log_dir, exist_ok=True)
290 | log_writer = SummaryWriter(log_dir=args.log_dir)
291 | else:
292 | log_writer = None
293 | 
294 | data_loader_train = torch.utils.data.DataLoader(
295 | dataset_train, sampler=sampler_train,
296 | batch_size=args.batch_size,
297 | num_workers=args.num_workers,
298 | pin_memory=args.pin_mem,
299 | drop_last=True,
300 | )
301 | data_loader_val = torch.utils.data.DataLoader(
302 | dataset_val, sampler=sampler_val,
303 | batch_size=args.batch_size,
304 | num_workers=args.num_workers,
305 | pin_memory=args.pin_mem,
306 | drop_last=False
307 | )
308 | mixup_fn = None
309 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 310 | if mixup_active: 311 | print("Mixup is activated!") 312 | mixup_fn = Mixup( 313 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 314 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 315 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 316 | 317 | model = vit_model.__dict__[args.model]( 318 | num_classes=args.nb_classes, 319 | # drop_path_rate=args.drop_path, 320 | global_pool=args.global_pool, 321 | ) 322 | 323 | if args.finetune and not args.eval: 324 | checkpoint = torch.load(args.finetune, map_location='cpu') 325 | 326 | print("Load pre-trained checkpoint from: %s" % args.finetune) 327 | checkpoint_model = checkpoint['model'] 328 | state_dict = model.state_dict() 329 | for k in ['head.weight', 'head.bias']: 330 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 331 | print(f"Removing key {k} from pretrained checkpoint") 332 | del checkpoint_model[k] 333 | # interpolate position embedding 334 | interpolate_pos_embed(model, checkpoint_model) 335 | 336 | # load pre-trained model 337 | msg = model.load_state_dict(checkpoint_model, strict=False) 338 | print(msg) 339 | 340 | if args.global_pool: 341 | assert set(msg.missing_keys) == {'head.weight', 'head.bias', 'fc_norm.weight', 'fc_norm.bias'} 342 | else: 343 | assert set(msg.missing_keys) == {'head.weight', 'head.bias'} 344 | 345 | # manually initialize fc layer 346 | trunc_normal_(model.head.weight, std=2e-5) 347 | 348 | if(args.mode == "linprobe"): 349 | for _, p in model.named_parameters(): 350 | p.requires_grad = False 351 | for _, p in model.head.named_parameters(): 352 | p.requires_grad = True 353 | 354 | model = model.double() 355 | if args.cuda is not None: 356 | model.to(device) 357 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 358 | model_without_ddp = model 359 | 360 | print("Model = %s" % str(model_without_ddp)) 361 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 362 | 363 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 364 | 365 | if args.lr is None: # only base_lr is specified 366 | args.lr = args.blr * eff_batch_size / 256 367 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 368 | print("actual lr: %.2e" % args.lr) 369 | 370 | print("accumulate grad iterations: %d" % args.accum_iter) 371 | print("effective batch size: %d" % eff_batch_size) 372 | 373 | if args.distributed: 374 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 375 | model_without_ddp = model.module 376 | 377 | # build optimizer with layer-wise lr decay (lrd) 378 | param_groups = lrd.param_groups_lrd(model_without_ddp, args.weight_decay, 379 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 380 | layer_decay=args.layer_decay 381 | ) 382 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 383 | loss_scaler = NativeScaler() 384 | 385 | # if mixup_fn is not None: 386 | # # smoothing is handled with mixup label transform 387 | # criterion = SoftTargetCrossEntropy() 388 | # elif args.smoothing > 0.: 389 | # criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 390 | # else: 391 | if args.classf_type == "multi_label": 392 | criterion = torch.nn.BCEWithLogitsLoss() 393 | else: 394 | criterion = torch.nn.CrossEntropyLoss() 395 | 396 | print("criterion = %s" % str(criterion)) 397 | 398 | # ckpt_file = args.ckpt 399 | # state_dict = torch.load(ckpt_file, map_location="cpu") 400 
| 401 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
402 | # model.load_state_dict(state_dict, strict=True)
403 | if args.eval:
404 | test_stats = evaluate(data_loader_val, model, device, args=args)
405 | print(f"Accuracy of the network on the {len(dataset_val)} test samples: {test_stats['acc1']:.1f}%")
406 | exit(0)
407 | 
408 | print(f"Start training for {args.epochs} epochs")
409 | start_time = time.time()
410 | max_accuracy = 0.0
411 | for epoch in range(args.start_epoch, args.epochs):
412 | if args.distributed:
413 | data_loader_train.sampler.set_epoch(epoch)
414 | train_stats = train_one_epoch(
415 | model, criterion, data_loader_train,
416 | optimizer, device, epoch, loss_scaler,
417 | args.clip_grad, mixup_fn,
418 | log_writer=log_writer,
419 | args=args
420 | )
421 | if args.output_dir:
422 | misc.save_model(
423 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
424 | loss_scaler=loss_scaler, epoch=epoch)
425 | 
426 | test_stats = evaluate(data_loader_val, model, device, args=args)
427 | print(f"Accuracy of the network on the {len(dataset_val)} test samples: {test_stats['acc1']:.1f}%")
428 | max_accuracy = max(max_accuracy, test_stats["acc1"])
429 | print(f'Max accuracy: {max_accuracy:.2f}%')
430 | 
431 | if log_writer is not None:
432 | log_writer.add_scalar('perf/test_acc1', test_stats['acc1'], epoch)
433 | # log_writer.add_scalar('perf/auc')
434 | # log_writer.add_scalar('perf/test_acc5', test_stats['acc5'], epoch)
435 | log_writer.add_scalar('perf/test_loss', test_stats['loss'], epoch)
436 | 
437 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
438 | **{f'test_{k}': v for k, v in test_stats.items()},
439 | 'epoch': epoch,
440 | 'n_parameters': n_parameters}
441 | 
442 | if args.output_dir and misc.is_main_process():
443 | if log_writer is not None:
444 | log_writer.flush()
445 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
446 | f.write(json.dumps(log_stats) + "\n")
447 | 
448 | total_time = time.time() - start_time
449 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
450 | print('Training time {}'.format(total_time_str))
451 | 
452 | 
453 | if __name__ == '__main__':
454 | parser = get_args_parser()
455 | args = parser.parse_args()
456 | if args.output_dir:
457 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
458 | main(args)
459 | 
--------------------------------------------------------------------------------
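For reference, here is a minimal, hypothetical sketch (not a file in this repo) of how the PTB-XL helpers in `utils/utils.py` compose, mirroring the `--data PTB` branch of `main_finetune.py`. The data path and the `superdiagnostic` task are illustrative assumptions; point the path at a local PTB-XL download containing `ptbxl_database.csv` and the `records100/` folder (trailing slash required, since the helpers concatenate paths).

```python
# Hypothetical usage sketch; function names and signatures match utils/utils.py above,
# but the path below is an assumption, not part of the repo.
from utils import utils

datafolder = '/path/to/ptb_xl/'  # assumed PTB-XL root

# Load signals (cached to raw100.npy on first run) and the parsed scp_codes labels
X, Y = utils.load_dataset(datafolder, sampling_rate=100)

# Map scp_codes onto one aggregation level, then multi-hot encode it
labels = utils.compute_label_aggregations(Y, datafolder, 'superdiagnostic')
X, labels, y, mlb = utils.select_data(X, labels, 'superdiagnostic',
                                      min_samples=0, outputfolder='./output/')

# PTB-XL's recommended split: stratified folds 1-9 for training, fold 10 for validation
X_train, y_train = X[labels.strat_fold < 10], y[labels.strat_fold < 10]
X_val, y_val = X[labels.strat_fold == 10], y[labels.strat_fold == 10]
print(X_train.shape, y_train.shape, list(mlb.classes_))
```

The arrays returned here are what the fine-tuning script normalizes per-lead and wraps in `TensorDataset`s before training.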