├── .DS_Store
├── figure
    ├── fig1.png
    ├── fig2.png
    └── .DS_Store
├── .gitattributes
├── EchoFM
    ├── .DS_Store
    ├── util
    │   ├── __pycache__
    │   │   ├── env.cpython-310.pyc
    │   │   ├── env.cpython-312.pyc
    │   │   ├── misc.cpython-310.pyc
    │   │   ├── kinetics.cpython-310.pyc
    │   │   ├── logging.cpython-310.pyc
    │   │   ├── lr_sched.cpython-310.pyc
    │   │   └── video_vit.cpython-310.pyc
    │   ├── decoder
    │   │   ├── __pycache__
    │   │   │   ├── utils.cpython-310.pyc
    │   │   │   ├── decoder.cpython-310.pyc
    │   │   │   ├── transform.cpython-310.pyc
    │   │   │   ├── rand_augment.cpython-310.pyc
    │   │   │   └── video_container.cpython-310.pyc
    │   │   ├── video_container.py
    │   │   ├── mixup.py
    │   │   ├── random_erasing.py
    │   │   ├── decoder.py
    │   │   ├── utils.py
    │   │   ├── rand_augment.py
    │   │   └── transform.py
    │   ├── env.py
    │   ├── lr_sched.py
    │   ├── pos_embed.py
    │   ├── lr_decay.py
    │   ├── logging.py
    │   ├── video_vit.py
    │   ├── meters.py
    │   └── misc.py
    ├── __pycache__
    │   ├── models_mae.cpython-310.pyc
    │   └── engine_pretrain.cpython-310.pyc
    ├── engine_test.py
    ├── engine_pretrain.py
    ├── engine_finetune.py
    ├── models_vit.py
    └── models_mae.py
├── data
    ├── __pycache__
    │   └── dataset.cpython-310.pyc
    └── dataset.py
├── run_pretrain.py
├── environment_setup.sh
├── README.md
├── config
    └── config.yaml
└── main_pretrain.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/.DS_Store
--------------------------------------------------------------------------------
/figure/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/figure/fig1.png
--------------------------------------------------------------------------------
/figure/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/figure/fig2.png
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/EchoFM/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/.DS_Store
--------------------------------------------------------------------------------
/figure/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/figure/.DS_Store
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/data/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/env.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/env.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/env.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/env.cpython-312.pyc -------------------------------------------------------------------------------- /EchoFM/__pycache__/models_mae.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/__pycache__/models_mae.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/__pycache__/misc.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/misc.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/__pycache__/kinetics.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/kinetics.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/__pycache__/logging.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/logging.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/__pycache__/lr_sched.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/lr_sched.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/__pycache__/engine_pretrain.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/__pycache__/engine_pretrain.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/__pycache__/video_vit.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/video_vit.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/decoder/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/decoder/__pycache__/decoder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/decoder.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/decoder/__pycache__/transform.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/transform.cpython-310.pyc -------------------------------------------------------------------------------- /EchoFM/util/decoder/__pycache__/rand_augment.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/rand_augment.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/decoder/__pycache__/video_container.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/video_container.cpython-310.pyc
--------------------------------------------------------------------------------
/run_pretrain.py:
--------------------------------------------------------------------------------
1 | from main_pretrain import get_args_parser, main
2 | from pathlib import Path
3 |
4 | def invoke_main() -> None:
5 |     args = get_args_parser()
6 |     args = args.parse_args()
7 |     if args.output_dir:
8 |         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
9 |     main(args)
10 |
11 | if __name__ == "__main__":
12 |     invoke_main()
13 |
--------------------------------------------------------------------------------
/EchoFM/util/env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | """Set up Environment."""
6 |
7 | from iopath.common.file_io import PathManagerFactory
8 |
9 | _ENV_SETUP_DONE = False
10 | pathmgr = PathManagerFactory.get(key="mae_st")
11 | checkpoint_pathmgr = PathManagerFactory.get(key="mae_st_checkpoint")
12 |
13 |
14 | def setup_environment():
15 |     global _ENV_SETUP_DONE
16 |     if _ENV_SETUP_DONE:
17 |         return
18 |     _ENV_SETUP_DONE = True
19 |
--------------------------------------------------------------------------------
/EchoFM/util/decoder/video_container.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | import av
6 |
7 |
8 | def get_video_container(path_to_vid, multi_thread_decode=False):
9 |     """
10 |     Read the encoded video file and return its raw bytes; decoding is
11 |     handled downstream by the decoder utilities.
12 |     Args:
13 |         path_to_vid (str): path to the video.
14 |         multi_thread_decode (bool): unused here; decoding (and any
15 |             multi-threading) happens in the downstream decoder.
16 |     Returns:
17 |         container (bytes): raw encoded video bytes.
18 |     """
19 |     with open(path_to_vid, "rb") as fp:
20 |         container = fp.read()
21 |     return container
22 |
--------------------------------------------------------------------------------
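Note that `get_video_container` returns raw encoded bytes rather than an opened container; the decoder utilities in `EchoFM/util/decoder/` are expected to wrap those bytes before decoding. A minimal sketch of that pattern with PyAV (the clip path is a placeholder, and PyAV is imported by `video_container.py` but not installed by `environment_setup.sh`):

```python
import io

import av  # PyAV; used here to decode the raw bytes in memory

from EchoFM.util.decoder.video_container import get_video_container

raw_bytes = get_video_container("example_clip.avi")  # placeholder path
container = av.open(io.BytesIO(raw_bytes))           # open an in-memory PyAV container
stream = container.streams.video[0]
frames = [frame.to_ndarray(format="rgb24") for frame in container.decode(stream)]
print(f"decoded {len(frames)} frames, first frame shape {frames[0].shape}")
```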
/EchoFM/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 |
10 | def adjust_learning_rate(optimizer, epoch, args):
11 |     """Decay the learning rate with half-cycle cosine after warmup"""
12 |     if epoch < args.warmup_epochs:
13 |         lr = args.lr * epoch / args.warmup_epochs
14 |     else:
15 |         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
16 |             1.0
17 |             + math.cos(
18 |                 math.pi
19 |                 * (epoch - args.warmup_epochs)
20 |                 / (args.epochs - args.warmup_epochs)
21 |             )
22 |         )
23 |     for param_group in optimizer.param_groups:
24 |         if "lr_scale" in param_group:
25 |             param_group["lr"] = lr * param_group["lr_scale"]
26 |         else:
27 |             param_group["lr"] = lr
28 |     return lr
29 |
--------------------------------------------------------------------------------
/environment_setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | CONDA_ENV=${1:-""}
5 | if [ -n "$CONDA_ENV" ]; then
6 |     # This is required to activate conda environment
7 |     eval "$(conda shell.bash hook)"
8 |
9 |     conda create -n $CONDA_ENV python=3.10.14 -y
10 |     conda activate $CONDA_ENV
11 |
12 |     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
13 | else
14 |     echo "Skipping conda environment creation. Make sure you have the correct environment activated."
15 | fi
16 |
17 | pip install iopath psutil scipy einops tensorboard
18 | conda install -y simplejson
19 | # # This is required to enable PEP 660 support
20 | # pip install --upgrade pip setuptools
21 |
22 | # # Install FlashAttention2
23 | # pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
24 |
25 | # # Install VILA
26 | # pip install -e ".[train,eval]"
27 |
28 | # pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
29 |
30 | # pip install git+https://github.com/huggingface/transformers@v4.36.2
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## EchoFM - A Video Vision Foundation Model for Echocardiogram
2 |
3 | Official repo for [EchoFM: Foundation Model for Generalizable Echocardiogram Analysis](https://arxiv.org/abs/2410.23413)
4 |
5 | This model and associated code are released under the CC-BY-NC-ND 4.0 license and may only be used for non-commercial, academic research purposes with proper attribution. Any commercial use, sale, or other monetization of the EchoFM model and its derivatives, which include models trained on outputs from the EchoFM model or datasets created from the EchoFM model, is prohibited and requires prior approval.
6 |
7 |
8 |
9 | This work was supported by the National Research Foundation of Korea (NRF) grant funded by the Korea government (MSIT) (RS-2024-00348696).
10 |
11 | ## Key features
12 |
13 | - EchoFM is pre-trained on 290K echocardiography clips with self-supervised learning.
14 | - EchoFM has been validated on multiple downstream tasks, including segmentation, classification, and disease detection.
15 | - EchoFM can be efficiently adapted to customised tasks.
16 |
17 |
18 |
19 | ## 1. Environment Setup
20 |
21 | ```bash
22 | git clone https://github.com/SekeunKim/EchoFM.git
23 | cd EchoFM
24 | ./environment_setup.sh EchoFM
25 | ```
26 |
27 | ## 2. Download model
28 | Download the EchoFM weights from the following link:
29 | [EchoFM Weights](https://drive.google.com/drive/folders/1Gn43_qMwk-wzZIxZdxXLyk2mXDv5Jsxt?usp=share_link)
30 |
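Once downloaded, the weights can be loaded into the video ViT encoder defined in `EchoFM/models_vit.py`. The snippet below is only a sketch: the checkpoint filename, the state-dict key (`"model"` vs. `"model_state"`), the number of classes, and the temporal settings (`num_frames=32`, `t_patch_size=4`, taken from `config/config.yaml` and the `PatchEmbed` default) are assumptions, not documented interfaces.

```python
import torch

from EchoFM import models_vit

# ViT-B/16 video encoder; temporal arguments follow config/config.yaml (max_fr: 32)
# and the default t_patch_size=4 in EchoFM/util/video_vit.py.
model = models_vit.vit_base_patch16(num_frames=32, t_patch_size=4, num_classes=2)

ckpt = torch.load("echofm_pretrained.pth", map_location="cpu")  # hypothetical filename
state_dict = ckpt.get("model", ckpt.get("model_state", ckpt))   # MAE-style checkpoints usually nest weights
msg = model.load_state_dict(state_dict, strict=False)           # decoder/head keys are expected to mismatch
print("missing:", len(msg.missing_keys), "unexpected:", len(msg.unexpected_keys))
```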
31 | ## 3. Citation
32 | If you find this repository useful, please consider citing this paper: [will be released soon]
33 | ```
34 | @article{kim2024echofm,
35 |   title={EchoFM: Foundation Model for Generalizable Echocardiogram Analysis},
36 |   author={Kim, Sekeun and Jin, Pengfei and Song, Sifan and Chen, Cheng and Li, Yiwei and Ren, Hui and Li, Xiang and Liu, Tianming and Li, Quanzheng},
37 |   journal={arXiv preprint arXiv:2410.23413},
38 |   year={2024}
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
/EchoFM/engine_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import EchoFM.util.misc as misc
13 | import torch
14 |
15 |
16 | @torch.no_grad()
17 | def test(data_loader, model, device, test_meter, fp32=False):
18 |     metric_logger = misc.MetricLogger(delimiter="  ")
19 |
20 |     # switch to evaluation mode
21 |     model.eval()
22 |     softmax = torch.nn.Softmax(dim=1).cuda()
23 |
24 |     for cur_iter, (images, labels, video_idx) in enumerate(data_loader):
25 |         images = images.to(device, non_blocking=True)
26 |         labels = labels.to(device, non_blocking=True)
27 |         video_idx = video_idx.to(device, non_blocking=True)
28 |
29 |         if len(images.shape) == 6:
30 |             b, r, c, t, h, w = images.shape
31 |             images = images.view(b * r, c, t, h, w)
32 |             labels = labels.view(b * r)
33 |
34 |         # compute output
35 |         with torch.cuda.amp.autocast(enabled=not fp32):
36 |             preds = model(images)
37 |             preds = softmax(preds)
38 |
39 |         if torch.distributed.is_initialized():
40 |             preds, labels, video_idx = misc.all_gather([preds, labels, video_idx])
41 |         preds = preds.cpu()
42 |         labels = labels.cpu()
43 |         video_idx = video_idx.cpu()
44 |         # Update and log stats.
45 |         test_meter.update_stats(preds.detach(), labels.detach(), video_idx.detach())
46 |         test_meter.log_iter_stats(cur_iter)
47 |
48 |     test_meter.finalize_metrics()
49 |     # gather the stats from all processes
50 |     metric_logger.synchronize_between_processes()
51 |     return test_meter.stats
52 |
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # architecture
2 | arch: vit_base
3 | enc_arch: MAEViTEncoder
4 | dec_arch: MAEViTDecoder
5 |
6 | # wandb
7 | proj_name: mae3d
8 | run_name: ${proj_name}_${arch}_${dataset}
9 | wandb_id:
10 | disable_wandb: 0
11 | eval : True
12 | # dataset
13 | dataset: echo #mgh echo
14 | data_path: /nvme/zhoulei/MSD
15 | data_seed: 12345
16 | ts_fold: 0
17 | resize_h_w : [224,224]
18 | max_fr : 32
19 | view_type : "A24C"
20 |
21 | # output
22 | # output_dir: /nvme/zhoulei/ssl-framework/${run_name}
23 | # ckpt_dir: ${output_dir}/ckpts
24 |
25 | # data preprocessing
26 | roi_x: 128
27 | roi_y: 128
28 | roi_z: 128
29 | RandFlipd_prob: 0.2
30 | RandRotate90d_prob: 0.2
31 | RandScaleIntensityd_prob: 0.1
32 | RandShiftIntensityd_prob: 0.1
33 | spatial_dim: 3
34 | cache_rate: 1.
35 |
36 | # trainer
37 | trainer_name: MAE3DTrainer
38 | batch_size: 3
39 | vis_batch_size: 1
40 | start_epoch: 0
41 | warmup_epochs: 10
42 | epochs: 1000
43 | workers: 1 #8
44 | pretrain:
45 | resume:
46 | # model
47 | patchembed: 'PatchEmbed3D'
48 | pos_embed_type: 'sincos'
49 | mask_ratio: 0.75
50 | input_size: [224, 224, 32]
51 | patch_size: 16
52 | in_chans: 3
53 | encoder_embed_dim: 768
54 | encoder_depth: 12
55 | encoder_num_heads: 12
56 | decoder_embed_dim: 384
57 | decoder_depth: 8
58 | decoder_num_heads: 12
59 |
60 | # optimizer
61 | type: adamw
62 | lr: 6.4e-3
63 | beta1: 0.9
64 | beta2: 0.95
65 | weight_decay: 0.05
66 |
67 | # logging
68 | vis_freq: 10
69 | save_freq: 100
70 | print_freq: 5
71 |
72 | # distributed processing
73 | gpu: 0
74 | dist_url: # 'tcp://localhost:10001'
75 | world_size: 1
76 | multiprocessing_distributed: false
77 | dist_backend: nccl
78 | distributed:
79 | rank: 0
80 | ngpus_per_node:
81 |
82 | # randomness
83 | seed:
84 |
85 | # debugging
86 | debug: false
87 |
88 | #path
89 | base_pr_path : '/mount/home/local/PARTNERS/sk1064/workspace/nature/'
90 | base_data_path : '/mount/mnt/CAMCA/home/sk/us'
91 | base_json_path : '/mount/home/local/PARTNERS/sk1064/workspace/echo_samv2/dataset/Echo'
92 | json_path : '/camus/train_2c4c.json'
93 | output_dir: ${base_pr_path}/output/ssl-framework/${run_name}
94 | ckpt_dir: ${output_dir}/ckpts
95 |
96 | base_mgh_data_path : /mount/mnt/CAMCA/AorticStenosis
--------------------------------------------------------------------------------
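The `${...}` references above (`run_name`, `output_dir`, `ckpt_dir`) follow OmegaConf-style interpolation. The loader actually used lives in `main_pretrain.py`, which is not shown here, so the following is only a sketch of how this file could be read and resolved, assuming OmegaConf is available:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/config.yaml")

# Interpolations such as ${proj_name}_${arch}_${dataset} resolve on access.
print(cfg.run_name)        # -> mae3d_vit_base_echo
print(cfg.mask_ratio, cfg.input_size, cfg.encoder_embed_dim)

# Materialize a plain dict with every interpolation resolved.
resolved = OmegaConf.to_container(cfg, resolve=True)
```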
/EchoFM/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import EchoFM.util.logging as logging
11 | import numpy as np
12 | import torch
13 |
14 |
15 | logger = logging.get_logger(__name__)
16 |
17 |
18 | # --------------------------------------------------------
19 | # Interpolate position embeddings for high-resolution
20 | # References:
21 | # DeiT: https://github.com/facebookresearch/deit
22 | # --------------------------------------------------------
23 | def interpolate_pos_embed(model, checkpoint_model):
24 |     if "pos_embed" in checkpoint_model:
25 |         pos_embed_checkpoint = checkpoint_model["pos_embed"]
26 |         embedding_size = pos_embed_checkpoint.shape[-1]
27 |         num_patches = model.patch_embed.num_patches
28 |         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
29 |         # height (== width) for the checkpoint position embedding
30 |         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
31 |         # height (== width) for the new position embedding
32 |         new_size = int(num_patches**0.5)
33 |         # class_token and dist_token are kept unchanged
34 |         if orig_size != new_size:
35 |             print(
36 |                 "Position interpolate from %dx%d to %dx%d"
37 |                 % (orig_size, orig_size, new_size, new_size)
38 |             )
39 |             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
40 |             # only the position tokens are interpolated
41 |             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
42 |             pos_tokens = pos_tokens.reshape(
43 |                 -1, orig_size, orig_size, embedding_size
44 |             ).permute(0, 3, 1, 2)
45 |             pos_tokens = torch.nn.functional.interpolate(
46 |                 pos_tokens,
47 |                 size=(new_size, new_size),
48 |                 mode="bicubic",
49 |                 align_corners=False,
50 |             )
51 |             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
52 |             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
53 |             checkpoint_model["pos_embed"] = new_pos_embed
54 |
--------------------------------------------------------------------------------
/EchoFM/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd( 16 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 17 | ): 18 | """ 19 | Parameter groups for layer-wise lr decay 20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 21 | """ 22 | param_group_names = {} 23 | param_groups = {} 24 | 25 | num_layers = len(model.blocks) + 1 26 | 27 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 28 | 29 | for n, p in model.named_parameters(): 30 | if not p.requires_grad: 31 | continue 32 | 33 | # no decay: all 1D parameters and model specific ones 34 | if p.ndim == 1 or n in no_weight_decay_list: 35 | g_decay = "no_decay" 36 | this_decay = 0.0 37 | else: 38 | g_decay = "decay" 39 | this_decay = weight_decay 40 | 41 | layer_id = get_layer_id_for_vit(n, num_layers) 42 | group_name = "layer_%d_%s" % (layer_id, g_decay) 43 | 44 | if group_name not in param_group_names: 45 | this_scale = layer_scales[layer_id] 46 | 47 | param_group_names[group_name] = { 48 | "lr_scale": this_scale, 49 | "weight_decay": this_decay, 50 | "params": [], 51 | } 52 | param_groups[group_name] = { 53 | "lr_scale": this_scale, 54 | "weight_decay": this_decay, 55 | "params": [], 56 | } 57 | 58 | param_group_names[group_name]["params"].append(n) 59 | param_groups[group_name]["params"].append(p) 60 | 61 | print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 62 | 63 | return list(param_groups.values()) 64 | 65 | 66 | def get_layer_id_for_vit(name, num_layers): 67 | """ 68 | Assign a parameter with its layer id 69 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 70 | """ 71 | if name in [ 72 | "cls_token", 73 | "mask_token", 74 | ]: 75 | return 0 76 | elif name.startswith("patch_embed"): 77 | return 0 78 | elif name.startswith("pos_embed"): 79 | return 0 80 | elif name.startswith("blocks"): 81 | return int(name.split(".")[1]) + 1 82 | else: 83 | return num_layers 84 | -------------------------------------------------------------------------------- /EchoFM/util/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """Logging.""" 6 | 7 | import atexit 8 | import builtins 9 | import decimal 10 | import functools 11 | import logging 12 | import os 13 | import sys 14 | 15 | import simplejson 16 | import torch 17 | import torch.distributed as dist 18 | from iopath.common.file_io import g_pathmgr as pathmgr 19 | 20 | 21 | def is_master_proc(multinode=False): 22 | """ 23 | Determines if the current process is the master process. 24 | """ 25 | if dist.is_initialized(): 26 | if multinode: 27 | return dist.get_rank() % dist.get_world_size() == 0 28 | else: 29 | return dist.get_rank() % torch.cuda.device_count() == 0 30 | else: 31 | return True 32 | 33 | 34 | def _suppress_print(): 35 | """ 36 | Suppresses printing from the current process. 
37 | """ 38 | 39 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False): 40 | pass 41 | 42 | builtins.print = print_pass 43 | 44 | 45 | @functools.lru_cache(maxsize=None) 46 | def _cached_log_stream(filename): 47 | # Use 1K buffer if writing to cloud storage. 48 | io = pathmgr.open(filename, "a", buffering=1024 if "://" in filename else -1) 49 | atexit.register(io.close) 50 | return io 51 | 52 | 53 | def setup_logging(output_dir=None): 54 | """ 55 | Sets up the logging for multiple processes. Only enable the logging for the 56 | master process, and suppress logging for the non-master processes. 57 | """ 58 | # Set up logging format. 59 | if is_master_proc(): 60 | # Enable logging for the master process. 61 | logging.root.handlers = [] 62 | else: 63 | # Suppress logging for non-master processes. 64 | _suppress_print() 65 | 66 | logger = logging.getLogger() 67 | logger.setLevel(logging.DEBUG) 68 | logger.propagate = False 69 | plain_formatter = logging.Formatter( 70 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s", 71 | datefmt="%m/%d %H:%M:%S", 72 | ) 73 | 74 | if is_master_proc(): 75 | ch = logging.StreamHandler(stream=sys.stdout) 76 | ch.setLevel(logging.DEBUG) 77 | ch.setFormatter(plain_formatter) 78 | logger.addHandler(ch) 79 | 80 | if output_dir is not None and is_master_proc(multinode=True): 81 | filename = os.path.join(output_dir, "stdout.log") 82 | fh = logging.StreamHandler(_cached_log_stream(filename)) 83 | fh.setLevel(logging.DEBUG) 84 | fh.setFormatter(plain_formatter) 85 | logger.addHandler(fh) 86 | 87 | 88 | def get_logger(name): 89 | """ 90 | Retrieve the logger with the specified name or, if name is None, return a 91 | logger which is the root logger of the hierarchy. 92 | Args: 93 | name (string): name of the logger. 94 | """ 95 | return logging.getLogger(name) 96 | 97 | 98 | def log_json_stats(stats): 99 | """ 100 | Logs json stats. 101 | Args: 102 | stats (dict): a dictionary of statistical information to log. 103 | """ 104 | stats = { 105 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v 106 | for k, v in stats.items() 107 | } 108 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True) 109 | logger = get_logger(__name__) 110 | print("json_stats: {:s}".format(json_stats)) 111 | 112 | 113 | def master_print(*args, **kwargs): 114 | if is_master_proc(): 115 | print(*args, **kwargs) 116 | else: 117 | pass 118 | -------------------------------------------------------------------------------- /EchoFM/engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | from typing import Iterable 13 | 14 | import EchoFM.util.lr_sched as lr_sched 15 | import EchoFM.util.misc as misc 16 | import torch 17 | from iopath.common.file_io import g_pathmgr as pathmgr 18 | 19 | 20 | def train_one_epoch( 21 | model: torch.nn.Module, 22 | data_loader: Iterable, 23 | optimizer: torch.optim.Optimizer, 24 | device: torch.device, 25 | epoch: int, 26 | loss_scaler, 27 | log_writer=None, 28 | args=None, 29 | fp32=False, 30 | ): 31 | model.train(True) 32 | metric_logger = misc.MetricLogger(delimiter=" ") 33 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 34 | metric_logger.add_meter( 35 | "cpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 36 | ) 37 | metric_logger.add_meter( 38 | "cpu_mem_all", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 39 | ) 40 | metric_logger.add_meter( 41 | "gpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 42 | ) 43 | metric_logger.add_meter( 44 | "mask_ratio", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 45 | ) 46 | header = "Epoch: [{}]".format(epoch) 47 | print_freq = 20 48 | 49 | accum_iter = args.accum_iter 50 | 51 | optimizer.zero_grad() 52 | 53 | if log_writer is not None: 54 | print("log_dir: {}".format(log_writer.log_dir)) 55 | 56 | for data_iter_step, samples in enumerate( 57 | metric_logger.log_every(data_loader, print_freq, header) 58 | ): 59 | # we use a per iteration (instead of per epoch) lr scheduler 60 | if data_iter_step % accum_iter == 0: 61 | lr_sched.adjust_learning_rate( 62 | optimizer, data_iter_step / len(data_loader) + epoch, args 63 | ) 64 | 65 | samples = samples.to(device, non_blocking=True) 66 | if len(samples.shape) == 6: 67 | b, r, c, t, h, w = samples.shape 68 | samples = samples.reshape(b * r, c, t, h, w) 69 | 70 | with torch.cuda.amp.autocast(enabled=not fp32): 71 | loss, _, _ = model( 72 | samples, 73 | mask_ratio=args.mask_ratio, 74 | ) 75 | 76 | loss_value = loss.item() 77 | 78 | if not math.isfinite(loss_value): 79 | for _ in range(args.num_checkpoint_del): 80 | try: 81 | path = misc.get_last_checkpoint(args) 82 | pathmgr.rm(path) 83 | print(f"remove checkpoint {path}") 84 | except Exception as _: 85 | pass 86 | raise Exception("Loss is {}, stopping training".format(loss_value)) 87 | 88 | loss /= accum_iter 89 | loss_scaler( 90 | loss, 91 | optimizer, 92 | parameters=model.parameters(), 93 | update_grad=(data_iter_step + 1) % accum_iter == 0, 94 | clip_grad=args.clip_grad, 95 | ) 96 | 97 | if (data_iter_step + 1) % accum_iter == 0: 98 | optimizer.zero_grad() 99 | 100 | torch.cuda.synchronize() 101 | 102 | metric_logger.update(loss=loss_value) 103 | metric_logger.update(cpu_mem=misc.cpu_mem_usage()[0]) 104 | metric_logger.update(cpu_mem_all=misc.cpu_mem_usage()[1]) 105 | metric_logger.update(gpu_mem=misc.gpu_mem_usage()) 106 | metric_logger.update(mask_ratio=args.mask_ratio) 107 | 108 | lr = optimizer.param_groups[0]["lr"] 109 | metric_logger.update(lr=lr) 110 | 111 | loss_value_reduce = misc.all_reduce_mean(loss_value) 112 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 113 | """We use epoch_1000x as the x-axis in tensorboard. 114 | This calibrates different curves when batch size changes. 
115 | """ 116 | epoch_1000x = int( 117 | (data_iter_step / len(data_loader) + epoch) * 1000 * args.repeat_aug 118 | ) 119 | log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x) 120 | log_writer.add_scalar("lr", lr, epoch_1000x) 121 | 122 | # gather the stats from all processes 123 | metric_logger.synchronize_between_processes() 124 | print("Averaged stats:", metric_logger) 125 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 126 | -------------------------------------------------------------------------------- /EchoFM/util/video_vit.py: -------------------------------------------------------------------------------- 1 | import EchoFM.util.logging as logging 2 | import torch 3 | import torch.nn as nn 4 | from timm.models.layers import to_2tuple 5 | from timm.models.vision_transformer import DropPath, Mlp 6 | 7 | 8 | logger = logging.get_logger(__name__) 9 | 10 | 11 | class PatchEmbed(nn.Module): 12 | """Image to Patch Embedding""" 13 | 14 | def __init__( 15 | self, 16 | img_size=224, 17 | patch_size=16, 18 | in_chans=3, 19 | embed_dim=768, 20 | # temporal related: 21 | frames=32, 22 | t_patch_size=4, 23 | ): 24 | super().__init__() 25 | img_size = to_2tuple(img_size) 26 | patch_size = to_2tuple(patch_size) 27 | assert img_size[1] % patch_size[1] == 0 28 | assert img_size[0] % patch_size[0] == 0 29 | assert frames % t_patch_size == 0 30 | num_patches = ( 31 | (img_size[1] // patch_size[1]) 32 | * (img_size[0] // patch_size[0]) 33 | * (frames // t_patch_size) 34 | ) 35 | self.input_size = ( 36 | frames // t_patch_size, 37 | img_size[0] // patch_size[0], 38 | img_size[1] // patch_size[1], 39 | ) 40 | print( 41 | f"img_size {img_size} patch_size {patch_size} frames {frames} t_patch_size {t_patch_size}" 42 | ) 43 | self.img_size = img_size 44 | self.patch_size = patch_size 45 | 46 | self.frames = frames 47 | self.t_patch_size = t_patch_size 48 | 49 | self.num_patches = num_patches 50 | 51 | self.grid_size = img_size[0] // patch_size[0] 52 | self.t_grid_size = frames // t_patch_size 53 | 54 | kernel_size = [t_patch_size] + list(patch_size) 55 | self.proj = nn.Conv3d( 56 | in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size 57 | ) 58 | 59 | def forward(self, x): 60 | B, C, T, H, W = x.shape 61 | assert ( 62 | H == self.img_size[0] and W == self.img_size[1] 63 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
64 | assert T == self.frames 65 | x = self.proj(x).flatten(3) 66 | x = torch.einsum("ncts->ntsc", x) # [N, T, H*W, C] 67 | return x 68 | 69 | 70 | class Attention(nn.Module): 71 | def __init__( 72 | self, 73 | dim, 74 | num_heads=8, 75 | qkv_bias=False, 76 | qk_scale=None, 77 | attn_drop=0.0, 78 | proj_drop=0.0, 79 | input_size=(4, 14, 14), 80 | ): 81 | super().__init__() 82 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 83 | self.num_heads = num_heads 84 | head_dim = dim // num_heads 85 | self.scale = qk_scale or head_dim**-0.5 86 | 87 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 88 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 89 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 90 | assert attn_drop == 0.0 # do not use 91 | self.proj = nn.Linear(dim, dim) 92 | self.proj_drop = nn.Dropout(proj_drop) 93 | self.input_size = input_size 94 | assert input_size[1] == input_size[2] 95 | 96 | def forward(self, x): 97 | B, N, C = x.shape 98 | q = ( 99 | self.q(x) 100 | .reshape(B, N, self.num_heads, C // self.num_heads) 101 | .permute(0, 2, 1, 3) 102 | ) 103 | k = ( 104 | self.k(x) 105 | .reshape(B, N, self.num_heads, C // self.num_heads) 106 | .permute(0, 2, 1, 3) 107 | ) 108 | v = ( 109 | self.v(x) 110 | .reshape(B, N, self.num_heads, C // self.num_heads) 111 | .permute(0, 2, 1, 3) 112 | ) 113 | 114 | attn = (q @ k.transpose(-2, -1)) * self.scale 115 | 116 | attn = attn.softmax(dim=-1) 117 | 118 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 119 | x = self.proj(x) 120 | x = self.proj_drop(x) 121 | x = x.view(B, -1, C) 122 | return x 123 | 124 | 125 | class Block(nn.Module): 126 | """ 127 | Transformer Block with specified Attention function 128 | """ 129 | 130 | def __init__( 131 | self, 132 | dim, 133 | num_heads, 134 | mlp_ratio=4.0, 135 | qkv_bias=False, 136 | qk_scale=None, 137 | drop=0.0, 138 | attn_drop=0.0, 139 | drop_path=0.0, 140 | act_layer=nn.GELU, 141 | norm_layer=nn.LayerNorm, 142 | attn_func=Attention, 143 | ): 144 | super().__init__() 145 | self.norm1 = norm_layer(dim) 146 | self.attn = attn_func( 147 | dim, 148 | num_heads=num_heads, 149 | qkv_bias=qkv_bias, 150 | qk_scale=qk_scale, 151 | attn_drop=attn_drop, 152 | proj_drop=drop, 153 | ) 154 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 155 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 156 | self.norm2 = norm_layer(dim) 157 | mlp_hidden_dim = int(dim * mlp_ratio) 158 | self.mlp = Mlp( 159 | in_features=dim, 160 | hidden_features=mlp_hidden_dim, 161 | act_layer=act_layer, 162 | drop=drop, 163 | ) 164 | 165 | def forward(self, x): 166 | x = x + self.drop_path(self.attn(self.norm1(x))) 167 | x = x + self.drop_path(self.mlp(self.norm2(x))) 168 | return x 169 | -------------------------------------------------------------------------------- /EchoFM/engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import EchoFM.util.lr_sched as lr_sched 17 | import EchoFM.util.misc as misc 18 | import torch 19 | from EchoFM.util.logging import master_print as print 20 | from timm.data import Mixup 21 | from timm.utils import accuracy 22 | 23 | 24 | def train_one_epoch( 25 | model: torch.nn.Module, 26 | criterion: torch.nn.Module, 27 | data_loader: Iterable, 28 | optimizer: torch.optim.Optimizer, 29 | device: torch.device, 30 | epoch: int, 31 | loss_scaler, 32 | max_norm: float = 0, 33 | mixup_fn: Optional[Mixup] = None, 34 | log_writer=None, 35 | args=None, 36 | fp32=False, 37 | ): 38 | model.train(True) 39 | metric_logger = misc.MetricLogger(delimiter=" ") 40 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 41 | metric_logger.add_meter( 42 | "cpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 43 | ) 44 | metric_logger.add_meter( 45 | "cpu_mem_all", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 46 | ) 47 | metric_logger.add_meter( 48 | "gpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}") 49 | ) 50 | header = "Epoch: [{}]".format(epoch) 51 | print_freq = 20 52 | 53 | accum_iter = args.accum_iter 54 | 55 | optimizer.zero_grad() 56 | 57 | if log_writer is not None: 58 | print("log_dir: {}".format(log_writer.log_dir)) 59 | 60 | for data_iter_step, (samples, targets) in enumerate( 61 | metric_logger.log_every(data_loader, print_freq, header) 62 | ): 63 | 64 | # we use a per iteration (instead of per epoch) lr scheduler 65 | if data_iter_step % accum_iter == 0: 66 | lr_sched.adjust_learning_rate( 67 | optimizer, data_iter_step / len(data_loader) + epoch, args 68 | ) 69 | 70 | if len(samples.shape) == 6: 71 | b, r, c, t, h, w = samples.shape 72 | samples = samples.view(b * r, c, t, h, w) 73 | targets = targets.view(b * r) 74 | 75 | if args.cpu_mix: 76 | if mixup_fn is not None: 77 | samples, targets = mixup_fn(samples, targets) 78 | samples = samples.to(device, non_blocking=True) 79 | targets = targets.to(device, non_blocking=True) 80 | else: 81 | samples = samples.to(device, non_blocking=True) 82 | targets = targets.to(device, non_blocking=True) 83 | if mixup_fn is not None: 84 | samples, targets = mixup_fn(samples, targets) 85 | 86 | with torch.cuda.amp.autocast(enabled=not fp32): 87 | outputs = model(samples) 88 | loss = criterion(outputs, targets) 89 | 90 | loss_value = loss.item() 91 | 92 | if not math.isfinite(loss_value): 93 | print("Loss is {}, stopping training".format(loss_value)) 94 | sys.exit(1) 95 | 96 | loss /= accum_iter 97 | loss_scaler( 98 | loss, 99 | optimizer, 100 | clip_grad=max_norm, 101 | parameters=model.parameters(), 102 | create_graph=False, 103 | update_grad=(data_iter_step + 1) % accum_iter == 0, 104 | ) 105 | if (data_iter_step + 1) % accum_iter == 0: 106 | optimizer.zero_grad() 107 | 108 | torch.cuda.synchronize() 109 | 110 | metric_logger.update(loss=loss_value) 111 | metric_logger.update(cpu_mem=misc.cpu_mem_usage()[0]) 112 | metric_logger.update(cpu_mem_all=misc.cpu_mem_usage()[1]) 113 | metric_logger.update(gpu_mem=misc.gpu_mem_usage()) 114 | min_lr = 10.0 115 | max_lr = 0.0 116 | for group in optimizer.param_groups: 117 | min_lr = min(min_lr, group["lr"]) 118 
| max_lr = max(max_lr, group["lr"]) 119 | 120 | metric_logger.update(lr=max_lr) 121 | 122 | loss_value_reduce = misc.all_reduce_mean(loss_value) 123 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 124 | """We use epoch_1000x as the x-axis in tensorboard. 125 | This calibrates different curves when batch size changes. 126 | """ 127 | epoch_1000x = int( 128 | (data_iter_step / len(data_loader) + epoch) * 1000 * args.repeat_aug 129 | ) 130 | log_writer.add_scalar("loss", loss_value_reduce, epoch_1000x) 131 | log_writer.add_scalar("lr", max_lr, epoch_1000x) 132 | 133 | # gather the stats from all processes 134 | metric_logger.synchronize_between_processes() 135 | print("Averaged stats:", metric_logger) 136 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 137 | 138 | 139 | @torch.no_grad() 140 | def evaluate(data_loader, model, device): 141 | criterion = torch.nn.CrossEntropyLoss() 142 | 143 | metric_logger = misc.MetricLogger(delimiter=" ") 144 | header = "Test:" 145 | 146 | # switch to evaluation mode 147 | model.eval() 148 | 149 | for batch in metric_logger.log_every(data_loader, 10, header): 150 | images = batch[0] 151 | target = batch[-1] 152 | images = images.to(device, non_blocking=True) 153 | target = target.to(device, non_blocking=True) 154 | 155 | if len(images.shape) == 6: 156 | b, r, c, t, h, w = images.shape 157 | images = images.view(b * r, c, t, h, w) 158 | target = target.view(b * r) 159 | 160 | # compute output 161 | with torch.cuda.amp.autocast(): 162 | output = model(images) 163 | loss = criterion(output, target) 164 | 165 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 166 | 167 | batch_size = images.shape[0] 168 | metric_logger.update(loss=loss.item()) 169 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 170 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 171 | # gather the stats from all processes 172 | metric_logger.synchronize_between_processes() 173 | print( 174 | "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format( 175 | top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss 176 | ) 177 | ) 178 | 179 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 180 | -------------------------------------------------------------------------------- /EchoFM/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | from EchoFM.util.logging import master_print as print 18 | 19 | from EchoFM.util.video_vit import Attention, Block, PatchEmbed 20 | 21 | 22 | class VisionTransformer(nn.Module): 23 | """Vision Transformer with support for global average pooling""" 24 | 25 | def __init__( 26 | self, 27 | num_frames, 28 | t_patch_size, 29 | img_size=224, 30 | patch_size=16, 31 | in_chans=3, 32 | num_classes=400, 33 | embed_dim=768, 34 | depth=12, 35 | num_heads=12, 36 | mlp_ratio=4.0, 37 | no_qkv_bias=False, 38 | qk_scale=None, 39 | drop_rate=0.0, 40 | attn_drop_rate=0.0, 41 | drop_path_rate=0.0, 42 | norm_layer=nn.LayerNorm, 43 | dropout=0.5, 44 | sep_pos_embed=False, 45 | cls_embed=False, 46 | **kwargs, 47 | ): 48 | super().__init__() 49 | print(locals()) 50 | 51 | self.sep_pos_embed = sep_pos_embed 52 | # -------------------------------------------------------------------------- 53 | # MAE encoder specifics 54 | self.patch_embed = PatchEmbed( 55 | img_size, patch_size, in_chans, embed_dim, num_frames, t_patch_size 56 | ) 57 | num_patches = self.patch_embed.num_patches 58 | input_size = self.patch_embed.input_size 59 | self.input_size = input_size 60 | self.cls_embed = cls_embed 61 | 62 | if self.cls_embed: 63 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 64 | 65 | if sep_pos_embed: 66 | self.pos_embed_spatial = nn.Parameter( 67 | torch.zeros(1, input_size[1] * input_size[2], embed_dim) 68 | ) 69 | self.pos_embed_temporal = nn.Parameter( 70 | torch.zeros(1, input_size[0], embed_dim) 71 | ) 72 | if self.cls_embed: 73 | self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim)) 74 | else: 75 | if self.cls_embed: 76 | _num_patches = num_patches + 1 77 | else: 78 | _num_patches = num_patches 79 | 80 | self.pos_embed = nn.Parameter( 81 | torch.zeros(1, _num_patches, embed_dim), requires_grad=True 82 | ) # fixed or not? 
83 | 84 | dpr = [ 85 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 86 | ] # stochastic depth decay rule 87 | 88 | self.blocks = nn.ModuleList( 89 | [ 90 | Block( 91 | embed_dim, 92 | num_heads, 93 | mlp_ratio, 94 | qkv_bias=not no_qkv_bias, 95 | qk_scale=None, 96 | norm_layer=norm_layer, 97 | drop_path=dpr[i], 98 | attn_func=partial( 99 | Attention, 100 | input_size=self.patch_embed.input_size, 101 | ), 102 | ) 103 | for i in range(depth) 104 | ] 105 | ) 106 | self.norm = norm_layer(embed_dim) 107 | # -------------------------------------------------------------------------- 108 | 109 | self.dropout = nn.Dropout(dropout) 110 | self.head = nn.Linear(embed_dim, num_classes) 111 | 112 | torch.nn.init.normal_(self.head.weight, std=0.02) 113 | 114 | @torch.jit.ignore 115 | def no_weight_decay(self): 116 | return { 117 | "cls_token", 118 | "pos_embed", 119 | "pos_embed_spatial", 120 | "pos_embed_temporal", 121 | "pos_embed_class", 122 | } 123 | 124 | def forward(self, x): 125 | # embed patches 126 | x = self.patch_embed(x) 127 | N, T, L, C = x.shape # T: temporal; L: spatial 128 | 129 | x = x.view([N, T * L, C]) 130 | 131 | # append cls token 132 | if self.cls_embed: 133 | cls_token = self.cls_token 134 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 135 | x = torch.cat((cls_tokens, x), dim=1) 136 | 137 | if self.sep_pos_embed: 138 | pos_embed = self.pos_embed_spatial.repeat( 139 | 1, self.input_size[0], 1 140 | ) + torch.repeat_interleave( 141 | self.pos_embed_temporal, 142 | self.input_size[1] * self.input_size[2], 143 | dim=1, 144 | ) 145 | if self.cls_embed: 146 | pos_embed = torch.cat( 147 | [ 148 | self.pos_embed_class.expand(pos_embed.shape[0], -1, -1), 149 | pos_embed, 150 | ], 151 | 1, 152 | ) 153 | else: 154 | pos_embed = self.pos_embed[:, :, :] 155 | x = x + pos_embed 156 | 157 | # reshape to [N, T, L, C] or [N, T*L, C] 158 | requires_t_shape = ( 159 | len(self.blocks) > 0 # support empty decoder 160 | and hasattr(self.blocks[0].attn, "requires_t_shape") 161 | and self.blocks[0].attn.requires_t_shape 162 | ) 163 | if requires_t_shape: 164 | x = x.view([N, T, L, C]) 165 | 166 | # apply Transformer blocks 167 | for blk in self.blocks: 168 | x = blk(x) 169 | 170 | if requires_t_shape: 171 | x = x.view([N, T * L, C]) 172 | 173 | # classifier 174 | x = x[:, 1:, :].mean(dim=1) # global pool 175 | x = self.norm(x) 176 | # x = self.fc_norm(x) 177 | x = self.dropout(x) 178 | x = self.head(x) 179 | 180 | return x 181 | 182 | 183 | def vit_base_patch16(**kwargs): 184 | model = VisionTransformer( 185 | patch_size=16, 186 | embed_dim=768, 187 | depth=12, 188 | num_heads=12, 189 | mlp_ratio=4, 190 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 191 | **kwargs, 192 | ) 193 | return model 194 | 195 | 196 | def vit_large_patch16(**kwargs): 197 | model = VisionTransformer( 198 | patch_size=16, 199 | embed_dim=1024, 200 | depth=24, 201 | num_heads=16, 202 | mlp_ratio=4, 203 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 204 | **kwargs, 205 | ) 206 | return model 207 | 208 | 209 | def vit_huge_patch14(**kwargs): 210 | model = VisionTransformer( 211 | patch_size=16, 212 | embed_dim=1280, 213 | depth=32, 214 | num_heads=16, 215 | mlp_ratio=4, 216 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 217 | **kwargs, 218 | ) 219 | return model 220 | -------------------------------------------------------------------------------- /EchoFM/util/decoder/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | 4 | 5 | """ 6 | This implementation is based on 7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/mixup.py, 8 | published under an Apache License 2.0. 9 | 10 | COMMENT FROM ORIGINAL: 11 | Mixup and Cutmix 12 | Papers: 13 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 14 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) # NOQA 15 | Code Reference: 16 | CutMix: https://github.com/clovaai/CutMix-PyTorch 17 | Hacked together by / Copyright 2020 Ross Wightman 18 | """ 19 | 20 | import numpy as np 21 | import torch 22 | 23 | 24 | def convert_to_one_hot(targets, num_classes, on_value=1.0, off_value=0.0): 25 | """ 26 | This function converts target class indices to one-hot vectors, given the 27 | number of classes. 28 | Args: 29 | targets (loader): Class labels. 30 | num_classes (int): Total number of classes. 31 | on_value (float): Target Value for ground truth class. 32 | off_value (float): Target Value for other classes.This value is used for 33 | label smoothing. 34 | """ 35 | 36 | targets = targets.long().view(-1, 1) 37 | return torch.full( 38 | (targets.size()[0], num_classes), off_value, device=targets.device 39 | ).scatter_(1, targets, on_value) 40 | 41 | 42 | def mixup_target(target, num_classes, lam=1.0, smoothing=0.0): 43 | """ 44 | This function converts target class indices to one-hot vectors, given the 45 | number of classes. 46 | Args: 47 | targets (loader): Class labels. 48 | num_classes (int): Total number of classes. 49 | lam (float): lamba value for mixup/cutmix. 50 | smoothing (float): Label smoothing value. 51 | """ 52 | off_value = smoothing / num_classes 53 | on_value = 1.0 - smoothing + off_value 54 | target1 = convert_to_one_hot( 55 | target, 56 | num_classes, 57 | on_value=on_value, 58 | off_value=off_value, 59 | ) 60 | target2 = convert_to_one_hot( 61 | target.flip(0), 62 | num_classes, 63 | on_value=on_value, 64 | off_value=off_value, 65 | ) 66 | return target1 * lam + target2 * (1.0 - lam) 67 | 68 | 69 | def rand_bbox(img_shape, lam, margin=0.0, count=None): 70 | """ 71 | Generates a random square bbox based on lambda value. 72 | 73 | Args: 74 | img_shape (tuple): Image shape as tuple 75 | lam (float): Cutmix lambda value 76 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) 77 | count (int): Number of bbox to generate 78 | """ 79 | ratio = np.sqrt(1 - lam) 80 | img_h, img_w = img_shape[-2:] 81 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) 82 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) 83 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) 84 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) 85 | yl = np.clip(cy - cut_h // 2, 0, img_h) 86 | yh = np.clip(cy + cut_h // 2, 0, img_h) 87 | xl = np.clip(cx - cut_w // 2, 0, img_w) 88 | xh = np.clip(cx + cut_w // 2, 0, img_w) 89 | return yl, yh, xl, xh 90 | 91 | 92 | def get_cutmix_bbox(img_shape, lam, correct_lam=True, count=None): 93 | """ 94 | Generates the box coordinates for cutmix. 95 | 96 | Args: 97 | img_shape (tuple): Image shape as tuple 98 | lam (float): Cutmix lambda value 99 | correct_lam (bool): Apply lambda correction when cutmix bbox clipped by 100 | image borders. 
101 | count (int): Number of bbox to generate 102 | """ 103 | 104 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) 105 | if correct_lam: 106 | bbox_area = (yu - yl) * (xu - xl) 107 | lam = 1.0 - bbox_area / float(img_shape[-2] * img_shape[-1]) 108 | return (yl, yu, xl, xu), lam 109 | 110 | 111 | class MixUp: 112 | """ 113 | Apply mixup and/or cutmix for videos at batch level. 114 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) 115 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable 116 | Features (https://arxiv.org/abs/1905.04899) 117 | """ 118 | 119 | def __init__( 120 | self, 121 | mixup_alpha=1.0, 122 | cutmix_alpha=0.0, 123 | mix_prob=1.0, 124 | switch_prob=0.5, 125 | correct_lam=True, 126 | label_smoothing=0.1, 127 | num_classes=1000, 128 | ): 129 | """ 130 | Args: 131 | mixup_alpha (float): Mixup alpha value. 132 | cutmix_alpha (float): Cutmix alpha value. 133 | mix_prob (float): Probability of applying mixup or cutmix. 134 | switch_prob (float): Probability of switching to cutmix instead of 135 | mixup when both are active. 136 | correct_lam (bool): Apply lambda correction when cutmix bbox 137 | clipped by image borders. 138 | label_smoothing (float): Apply label smoothing to the mixed target 139 | tensor. If label_smoothing is not used, set it to 0. 140 | num_classes (int): Number of classes for target. 141 | """ 142 | self.mixup_alpha = mixup_alpha 143 | self.cutmix_alpha = cutmix_alpha 144 | self.mix_prob = mix_prob 145 | self.switch_prob = switch_prob 146 | self.label_smoothing = label_smoothing 147 | self.num_classes = num_classes 148 | self.correct_lam = correct_lam 149 | 150 | def _get_mixup_params(self): 151 | lam = 1.0 152 | use_cutmix = False 153 | if np.random.rand() < self.mix_prob: 154 | if self.mixup_alpha > 0.0 and self.cutmix_alpha > 0.0: 155 | use_cutmix = np.random.rand() < self.switch_prob 156 | lam_mix = ( 157 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 158 | if use_cutmix 159 | else np.random.beta(self.mixup_alpha, self.mixup_alpha) 160 | ) 161 | elif self.mixup_alpha > 0.0: 162 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) 163 | elif self.cutmix_alpha > 0.0: 164 | use_cutmix = True 165 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) 166 | lam = float(lam_mix) 167 | return lam, use_cutmix 168 | 169 | def _mix_batch(self, x): 170 | lam, use_cutmix = self._get_mixup_params() 171 | if lam == 1.0: 172 | return 1.0 173 | if use_cutmix: 174 | (yl, yh, xl, xh), lam = get_cutmix_bbox( 175 | x.shape, 176 | lam, 177 | correct_lam=self.correct_lam, 178 | ) 179 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh] 180 | else: 181 | x_flipped = x.flip(0).mul_(1.0 - lam) 182 | x.mul_(lam).add_(x_flipped) 183 | return lam 184 | 185 | def __call__(self, x, target): 186 | assert len(x) > 1, "Batch size should be greater than 1 for mixup." 187 | lam = self._mix_batch(x) 188 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing) 189 | return x, target 190 | -------------------------------------------------------------------------------- /EchoFM/util/decoder/random_erasing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """ 6 | This implementation is based on 7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 8 | pulished under an Apache License 2.0. 
9 | 10 | COMMENT FROM ORIGINAL: 11 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 12 | Copyright Zhun Zhong & Liang Zheng 13 | Hacked together by / Copyright 2020 Ross Wightman 14 | """ 15 | import math 16 | import random 17 | 18 | import torch 19 | 20 | 21 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"): 22 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 23 | # paths, flip the order so normal is run on CPU if this becomes a problem 24 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 25 | if per_pixel: 26 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 27 | elif rand_color: 28 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() 29 | else: 30 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 31 | 32 | 33 | class RandomErasing: 34 | """Randomly selects a rectangle region in an image and erases its pixels. 35 | 'Random Erasing Data Augmentation' by Zhong et al. 36 | See https://arxiv.org/pdf/1708.04896.pdf 37 | This variant of RandomErasing is intended to be applied to either a batch 38 | or single image tensor after it has been normalized by dataset mean and std. 39 | Args: 40 | probability: Probability that the Random Erasing operation will be performed. 41 | min_area: Minimum percentage of erased area wrt input image area. 42 | max_area: Maximum percentage of erased area wrt input image area. 43 | min_aspect: Minimum aspect ratio of erased area. 44 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 45 | 'const' - erase block is constant color of 0 for all channels 46 | 'rand' - erase block is same per-channel random (normal) color 47 | 'pixel' - erase block is per-pixel random (normal) color 48 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 49 | per-image count is randomly chosen between 1 and this value. 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | probability=0.5, 55 | min_area=0.02, 56 | max_area=1 / 3, 57 | min_aspect=0.3, 58 | max_aspect=None, 59 | mode="const", 60 | min_count=1, 61 | max_count=None, 62 | num_splits=0, 63 | device="cuda", 64 | cube=True, 65 | ): 66 | self.probability = probability 67 | self.min_area = min_area 68 | self.max_area = max_area 69 | max_aspect = max_aspect or 1 / min_aspect 70 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 71 | self.min_count = min_count 72 | self.max_count = max_count or min_count 73 | self.num_splits = num_splits 74 | mode = mode.lower() 75 | self.rand_color = False 76 | self.per_pixel = False 77 | self.cube = cube 78 | if mode == "rand": 79 | self.rand_color = True # per block random normal 80 | elif mode == "pixel": 81 | self.per_pixel = True # per pixel random normal 82 | else: 83 | assert not mode or mode == "const" 84 | self.device = device 85 | 86 | def _erase(self, img, chan, img_h, img_w, dtype): 87 | if random.random() > self.probability: 88 | return 89 | area = img_h * img_w 90 | count = ( 91 | self.min_count 92 | if self.min_count == self.max_count 93 | else random.randint(self.min_count, self.max_count) 94 | ) 95 | for _ in range(count): 96 | for _ in range(10): 97 | target_area = ( 98 | random.uniform(self.min_area, self.max_area) * area / count 99 | ) 100 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 101 | h = int(round(math.sqrt(target_area * aspect_ratio))) 102 | w = int(round(math.sqrt(target_area / aspect_ratio))) 103 | if w < img_w and h < img_h: 104 | top = random.randint(0, img_h - h) 105 | left = random.randint(0, img_w - w) 106 | img[:, top : top + h, left : left + w] = _get_pixels( 107 | self.per_pixel, 108 | self.rand_color, 109 | (chan, h, w), 110 | dtype=dtype, 111 | device=self.device, 112 | ) 113 | break 114 | 115 | def _erase_cube( 116 | self, 117 | img, 118 | batch_start, 119 | batch_size, 120 | chan, 121 | img_h, 122 | img_w, 123 | dtype, 124 | ): 125 | if random.random() > self.probability: 126 | return 127 | area = img_h * img_w 128 | count = ( 129 | self.min_count 130 | if self.min_count == self.max_count 131 | else random.randint(self.min_count, self.max_count) 132 | ) 133 | for _ in range(count): 134 | for _ in range(100): 135 | target_area = ( 136 | random.uniform(self.min_area, self.max_area) * area / count 137 | ) 138 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 139 | h = int(round(math.sqrt(target_area * aspect_ratio))) 140 | w = int(round(math.sqrt(target_area / aspect_ratio))) 141 | if w < img_w and h < img_h: 142 | top = random.randint(0, img_h - h) 143 | left = random.randint(0, img_w - w) 144 | for i in range(batch_start, batch_size): 145 | img_instance = img[i] 146 | img_instance[:, top : top + h, left : left + w] = _get_pixels( 147 | self.per_pixel, 148 | self.rand_color, 149 | (chan, h, w), 150 | dtype=dtype, 151 | device=self.device, 152 | ) 153 | break 154 | 155 | def __call__(self, input): 156 | if len(input.size()) == 3: 157 | self._erase(input, *input.size(), input.dtype) 158 | else: 159 | batch_size, chan, img_h, img_w = input.size() 160 | # skip first slice of batch if num_splits is set (for clean portion of samples) 161 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 162 | if self.cube: 163 | self._erase_cube( 164 | input, 165 | batch_start, 166 | batch_size, 167 | chan, 168 | img_h, 169 | img_w, 170 | input.dtype, 171 | ) 172 | else: 173 | for i in range(batch_start, batch_size): 174 | 
self._erase(input[i], chan, img_h, img_w, input.dtype) 175 | return input 176 | -------------------------------------------------------------------------------- /EchoFM/util/decoder/decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import math 6 | import random 7 | 8 | import numpy as np 9 | import torch 10 | import torchvision.io as io 11 | 12 | 13 | def temporal_sampling(frames, start_idx, end_idx, num_samples): 14 | """ 15 | Given the start and end frame index, sample num_samples frames between 16 | the start and end with equal interval. 17 | Args: 18 | frames (tensor): a tensor of video frames, dimension is 19 | `num video frames` x `channel` x `height` x `width`. 20 | start_idx (int): the index of the start frame. 21 | end_idx (int): the index of the end frame. 22 | num_samples (int): number of frames to sample. 23 | Returns: 24 | frames (tersor): a tensor of temporal sampled video frames, dimension is 25 | `num clip frames` x `channel` x `height` x `width`. 26 | """ 27 | index = torch.linspace(start_idx, end_idx, num_samples) 28 | index = torch.clamp(index, 0, frames.shape[0] - 1).long() 29 | new_frames = torch.index_select(frames, 0, index) 30 | return new_frames 31 | 32 | 33 | def get_start_end_idx(video_size, clip_size, clip_idx, num_clips, use_offset=False): 34 | """ 35 | Sample a clip of size clip_size from a video of size video_size and 36 | return the indices of the first and last frame of the clip. If clip_idx is 37 | -1, the clip is randomly sampled, otherwise uniformly split the video to 38 | num_clips clips, and select the start and end index of clip_idx-th video 39 | clip. 40 | Args: 41 | video_size (int): number of overall frames. 42 | clip_size (int): size of the clip to sample from the frames. 43 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If 44 | clip_idx is larger than -1, uniformly split the video to num_clips 45 | clips, and select the start and end index of the clip_idx-th video 46 | clip. 47 | num_clips (int): overall number of clips to uniformly sample from the 48 | given video for testing. 49 | Returns: 50 | start_idx (int): the start frame index. 51 | end_idx (int): the end frame index. 52 | """ 53 | delta = max(video_size - clip_size, 0) 54 | if clip_idx == -1: 55 | # Random temporal sampling. 56 | start_idx = random.uniform(0, delta) 57 | else: 58 | if use_offset: 59 | if num_clips == 1: 60 | # Take the center clip if num_clips is 1. 61 | start_idx = math.floor(delta / 2) 62 | else: 63 | # Uniformly sample the clip with the given index. 64 | start_idx = clip_idx * math.floor(delta / (num_clips - 1)) 65 | else: 66 | # Uniformly sample the clip with the given index. 67 | start_idx = delta * clip_idx / num_clips 68 | end_idx = start_idx + clip_size - 1 69 | return start_idx, end_idx 70 | 71 | 72 | def decode( 73 | container, 74 | sampling_rate, 75 | num_frames, 76 | clip_idx=-1, 77 | num_clips=10, 78 | video_meta=None, 79 | target_fps=30, 80 | max_spatial_scale=0, 81 | use_offset=False, 82 | rigid_decode_all_video=True, 83 | modalities=("visual",), 84 | ): 85 | """ 86 | Decode the video and perform temporal sampling. 87 | Args: 88 | container (container): pyav container. 89 | sampling_rate (int): frame sampling rate (interval between two sampled 90 | frames). 91 | num_frames (int): number of frames to sample. 92 | clip_idx (int): if clip_idx is -1, perform random temporal 93 | sampling. 
If clip_idx is larger than -1, uniformly split the 94 | video to num_clips clips, and select the 95 | clip_idx-th video clip. 96 | num_clips (int): overall number of clips to uniformly 97 | sample from the given video. 98 | video_meta (dict): a dict contains VideoMetaData. Details can be find 99 | at `pytorch/vision/torchvision/io/_video_opt.py`. 100 | target_fps (int): the input video may have different fps, convert it to 101 | the target video fps before frame sampling. 102 | max_spatial_scale (int): keep the aspect ratio and resize the frame so 103 | that shorter edge size is max_spatial_scale. Only used in 104 | `torchvision` backend. 105 | Returns: 106 | frames (tensor): decoded frames from the video. 107 | """ 108 | try: 109 | assert clip_idx >= -1, "Not valied clip_idx {}".format(clip_idx) 110 | # Convert the bytes to a tensor. 111 | video_tensor = torch.from_numpy(np.frombuffer(container, dtype=np.uint8)) 112 | 113 | decode_all_video = True 114 | video_start_pts, video_end_pts = 0, -1 115 | # The video_meta is empty, fetch the meta data from the raw video. 116 | if len(video_meta) == 0: 117 | # Tracking the meta info for selective decoding in the future. 118 | meta = io._probe_video_from_memory(video_tensor) 119 | # Using the information from video_meta to perform selective decoding. 120 | video_meta["video_timebase"] = meta.video_timebase 121 | video_meta["video_numerator"] = meta.video_timebase.numerator 122 | video_meta["video_denominator"] = meta.video_timebase.denominator 123 | video_meta["has_video"] = meta.has_video 124 | video_meta["video_duration"] = meta.video_duration 125 | video_meta["video_fps"] = meta.video_fps 126 | video_meta["audio_timebas"] = meta.audio_timebase 127 | video_meta["audio_numerator"] = meta.audio_timebase.numerator 128 | video_meta["audio_denominator"] = meta.audio_timebase.denominator 129 | video_meta["has_audio"] = meta.has_audio 130 | video_meta["audio_duration"] = meta.audio_duration 131 | video_meta["audio_sample_rate"] = meta.audio_sample_rate 132 | 133 | fps = video_meta["video_fps"] 134 | if not rigid_decode_all_video: 135 | if ( 136 | video_meta["has_video"] 137 | and video_meta["video_denominator"] > 0 138 | and video_meta["video_duration"] > 0 139 | ): 140 | # try selective decoding. 141 | decode_all_video = False 142 | clip_size = sampling_rate * num_frames / target_fps * fps 143 | start_idx, end_idx = get_start_end_idx( 144 | fps * video_meta["video_duration"], 145 | clip_size, 146 | clip_idx, 147 | num_clips, 148 | use_offset=use_offset, 149 | ) 150 | # Convert frame index to pts. 151 | pts_per_frame = video_meta["video_denominator"] / fps 152 | video_start_pts = int(start_idx * pts_per_frame) 153 | video_end_pts = int(end_idx * pts_per_frame) 154 | 155 | # Decode the raw video with the tv decoder. 
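        # Note: `io._read_video_from_memory` is a private torchvision API. The pts
        # range below limits decoding to the selected clip when selective decoding
        # succeeded; the default (0, -1) set earlier decodes the whole video.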
156 | v_frames, _ = io._read_video_from_memory( 157 | video_tensor, 158 | seek_frame_margin=1.0, 159 | read_video_stream="visual" in modalities, 160 | video_width=0, 161 | video_height=0, 162 | video_min_dimension=max_spatial_scale, 163 | video_pts_range=(video_start_pts, video_end_pts), 164 | video_timebase_numerator=video_meta["video_numerator"], 165 | video_timebase_denominator=video_meta["video_denominator"], 166 | ) 167 | 168 | if v_frames.shape == torch.Size([0]): 169 | # failed selective decoding 170 | decode_all_video = True 171 | video_start_pts, video_end_pts = 0, -1 172 | v_frames, _ = io._read_video_from_memory( 173 | video_tensor, 174 | seek_frame_margin=1.0, 175 | read_video_stream="visual" in modalities, 176 | video_width=0, 177 | video_height=0, 178 | video_min_dimension=max_spatial_scale, 179 | video_pts_range=(video_start_pts, video_end_pts), 180 | video_timebase_numerator=video_meta["video_numerator"], 181 | video_timebase_denominator=video_meta["video_denominator"], 182 | ) 183 | except Exception as e: 184 | print("Failed to decode by torchvision with exception: {}".format(e)) 185 | return None 186 | 187 | # Return None if the frames was not decoded successfully. 188 | if v_frames is None or v_frames.size(0) == 0: 189 | return None, fps, decode_all_video 190 | return v_frames, fps, decode_all_video 191 | -------------------------------------------------------------------------------- /EchoFM/util/meters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import numpy as np 6 | import torch 7 | from sklearn.metrics import average_precision_score 8 | 9 | 10 | def topks_correct(preds, labels, ks): 11 | """ 12 | Given the predictions, labels, and a list of top-k values, compute the 13 | number of correct predictions for each top-k value. 14 | 15 | Args: 16 | preds (array): array of predictions. Dimension is batchsize 17 | N x ClassNum. 18 | labels (array): array of labels. Dimension is batchsize N. 19 | ks (list): list of top-k values. For example, ks = [1, 5] correspods 20 | to top-1 and top-5. 21 | 22 | Returns: 23 | topks_correct (list): list of numbers, where the `i`-th entry 24 | corresponds to the number of top-`ks[i]` correct predictions. 25 | """ 26 | assert preds.size(0) == labels.size( 27 | 0 28 | ), "Batch dim of predictions and labels must match" 29 | # Find the top max_k predictions for each sample 30 | _top_max_k_vals, top_max_k_inds = torch.topk( 31 | preds, max(ks), dim=1, largest=True, sorted=True 32 | ) 33 | # (batch_size, max_k) -> (max_k, batch_size). 34 | top_max_k_inds = top_max_k_inds.t() 35 | # (batch_size, ) -> (max_k, batch_size). 36 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 37 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct. 38 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 39 | # Compute the number of topk correct predictions for each k. 40 | topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks] 41 | return topks_correct 42 | 43 | 44 | def topk_errors(preds, labels, ks): 45 | """ 46 | Computes the top-k error for each k. 47 | Args: 48 | preds (array): array of predictions. Dimension is N. 49 | labels (array): array of labels. Dimension is N. 50 | ks (list): list of ks to calculate the top accuracies. 
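    Example (illustrative values; only the first sample's top-1 prediction is correct,
    so the top-1 error is 50%):
        >>> preds = torch.tensor([[0.9, 0.05, 0.05], [0.1, 0.8, 0.1]])
        >>> labels = torch.tensor([0, 2])
        >>> topk_errors(preds, labels, [1])
        [tensor(50.)]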
51 | """ 52 | num_topks_correct = topks_correct(preds, labels, ks) 53 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct] 54 | 55 | 56 | def topk_accuracies(preds, labels, ks): 57 | """ 58 | Computes the top-k accuracy for each k. 59 | Args: 60 | preds (array): array of predictions. Dimension is N. 61 | labels (array): array of labels. Dimension is N. 62 | ks (list): list of ks to calculate the top accuracies. 63 | """ 64 | num_topks_correct = topks_correct(preds, labels, ks) 65 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct] 66 | 67 | 68 | def get_map(preds, labels): 69 | """ 70 | Compute mAP for multi-label case. 71 | Args: 72 | preds (numpy tensor): num_examples x num_classes. 73 | labels (numpy tensor): num_examples x num_classes. 74 | Returns: 75 | mean_ap (int): final mAP score. 76 | """ 77 | 78 | print("Getting mAP for {} examples".format(preds.shape[0])) 79 | 80 | preds = preds[:, ~(np.all(labels == 0, axis=0))] 81 | labels = labels[:, ~(np.all(labels == 0, axis=0))] 82 | aps = [0] 83 | try: 84 | aps = average_precision_score(labels, preds, average=None) 85 | except ValueError: 86 | print( 87 | "Average precision requires a sufficient number of samples \ 88 | in a batch which are missing in this sample." 89 | ) 90 | 91 | mean_ap = np.mean(aps) 92 | return mean_ap 93 | 94 | 95 | class TestMeter: 96 | """ 97 | Perform the multi-view ensemble for testing: each video with an unique index 98 | will be sampled with multiple clips, and the predictions of the clips will 99 | be aggregated to produce the final prediction for the video. 100 | The accuracy is calculated with the given ground truth labels. 101 | """ 102 | 103 | def __init__( 104 | self, 105 | num_videos, 106 | num_clips, 107 | num_cls, 108 | overall_iters, 109 | multi_label=False, 110 | ensemble_method="sum", 111 | ): 112 | """ 113 | Construct tensors to store the predictions and labels. Expect to get 114 | num_clips predictions from each video, and calculate the metrics on 115 | num_videos videos. 116 | Args: 117 | num_videos (int): number of videos to test. 118 | num_clips (int): number of clips sampled from each video for 119 | aggregating the final prediction for the video. 120 | num_cls (int): number of classes for each prediction. 121 | overall_iters (int): overall iterations for testing. 122 | multi_label (bool): if True, use map as the metric. 123 | ensemble_method (str): method to perform the ensemble, options 124 | include "sum", and "max". 125 | """ 126 | 127 | self.num_clips = num_clips 128 | self.overall_iters = overall_iters 129 | self.multi_label = multi_label 130 | self.ensemble_method = ensemble_method 131 | # Initialize tensors. 132 | self.video_preds = torch.zeros((num_videos, num_cls)) 133 | if multi_label: 134 | self.video_preds -= 1e10 135 | 136 | self.video_labels = ( 137 | torch.zeros((num_videos, num_cls)) 138 | if multi_label 139 | else torch.zeros((num_videos)).long() 140 | ) 141 | self.clip_count = torch.zeros((num_videos)).long() 142 | self.topk_accs = [] 143 | self.stats = {} 144 | 145 | # Reset metric. 146 | self.reset() 147 | 148 | def reset(self): 149 | """ 150 | Reset the metric. 151 | """ 152 | self.clip_count.zero_() 153 | self.video_preds.zero_() 154 | if self.multi_label: 155 | self.video_preds -= 1e10 156 | self.video_labels.zero_() 157 | 158 | def update_stats(self, preds, labels, clip_ids): 159 | """ 160 | Collect the predictions from the current batch and perform on-the-flight 161 | summation as ensemble. 
162 | Args: 163 | preds (tensor): predictions from the current batch. Dimension is 164 | N x C where N is the batch size and C is the channel size 165 | (num_cls). 166 | labels (tensor): the corresponding labels of the current batch. 167 | Dimension is N. 168 | clip_ids (tensor): clip indexes of the current batch, dimension is 169 | N. 170 | """ 171 | for ind in range(preds.shape[0]): 172 | vid_id = int(clip_ids[ind]) // self.num_clips 173 | if self.video_labels[vid_id].sum() > 0: 174 | assert torch.equal( 175 | self.video_labels[vid_id].type(torch.FloatTensor), 176 | labels[ind].type(torch.FloatTensor), 177 | ) 178 | self.video_labels[vid_id] = labels[ind] 179 | if self.ensemble_method == "sum": 180 | self.video_preds[vid_id] += preds[ind] 181 | elif self.ensemble_method == "max": 182 | self.video_preds[vid_id] = torch.max( 183 | self.video_preds[vid_id], preds[ind] 184 | ) 185 | else: 186 | raise NotImplementedError( 187 | "Ensemble Method {} is not supported".format(self.ensemble_method) 188 | ) 189 | self.clip_count[vid_id] += 1 190 | 191 | def log_iter_stats(self, cur_iter): 192 | """ 193 | Log the stats. 194 | Args: 195 | cur_iter (int): the current iteration of testing. 196 | """ 197 | stats = { 198 | "split": "test_iter", 199 | "cur_iter": "{}".format(cur_iter + 1), 200 | } 201 | print(stats) 202 | 203 | def finalize_metrics(self, ks=(1, 5)): 204 | """ 205 | Calculate and log the final ensembled metrics. 206 | ks (tuple): list of top-k values for topk_accuracies. For example, 207 | ks = (1, 5) correspods to top-1 and top-5 accuracy. 208 | """ 209 | if not all(self.clip_count == self.num_clips): 210 | print( 211 | "clip count {} ~= num clips {}".format( 212 | ", ".join( 213 | [ 214 | "{}: {}".format(i, k) 215 | for i, k in enumerate(self.clip_count.tolist()) 216 | ] 217 | ), 218 | self.num_clips, 219 | ) 220 | ) 221 | 222 | self.stats = {"split": "test_final"} 223 | if self.multi_label: 224 | map = get_map( 225 | self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy() 226 | ) 227 | self.stats["map"] = map 228 | else: 229 | num_topks_correct = topks_correct(self.video_preds, self.video_labels, ks) 230 | topks = [(x / self.video_preds.size(0)) * 100.0 for x in num_topks_correct] 231 | assert len({len(ks), len(topks)}) == 1 232 | for k, topk in zip(ks, topks): 233 | self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format(topk, prec=2) 234 | print(self.stats) 235 | -------------------------------------------------------------------------------- /EchoFM/util/decoder/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | import logging 6 | import os 7 | import random 8 | import time 9 | from collections import defaultdict 10 | 11 | import cv2 12 | import numpy as np 13 | import torch 14 | from iopath.common.file_io import g_pathmgr as pathmgr 15 | from torch.utils.data.distributed import DistributedSampler 16 | 17 | from . import transform as transform 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def retry_load_images(image_paths, retry=10, backend="pytorch"): 23 | """ 24 | This function is to load images with support of retrying for failed load. 25 | 26 | Args: 27 | image_paths (list): paths of images needed to be loaded. 28 | retry (int, optional): maximum time of loading retrying. Defaults to 10. 29 | backend (str): `pytorch` or `cv2`. 30 | 31 | Returns: 32 | imgs (list): list of loaded images. 
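    Example (the paths below are placeholders):
        >>> imgs = retry_load_images(["frames/0001.jpg", "frames/0002.jpg"])
        >>> # imgs: uint8 tensor of shape (2, H, W, 3), BGR channel order from OpenCV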
33 | """ 34 | for i in range(retry): 35 | imgs = [] 36 | for image_path in image_paths: 37 | with pathmgr.open(image_path, "rb") as f: 38 | img_str = np.frombuffer(f.read(), np.uint8) 39 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR) 40 | imgs.append(img) 41 | 42 | if all(img is not None for img in imgs): 43 | if backend == "pytorch": 44 | imgs = torch.as_tensor(np.stack(imgs)) 45 | return imgs 46 | else: 47 | logger.warn("Reading failed. Will retry.") 48 | time.sleep(1.0) 49 | if i == retry - 1: 50 | raise Exception("Failed to load images {}".format(image_paths)) 51 | 52 | 53 | def get_sequence(center_idx, half_len, sample_rate, num_frames): 54 | """ 55 | Sample frames among the corresponding clip. 56 | 57 | Args: 58 | center_idx (int): center frame idx for current clip 59 | half_len (int): half of the clip length 60 | sample_rate (int): sampling rate for sampling frames inside of the clip 61 | num_frames (int): number of expected sampled frames 62 | 63 | Returns: 64 | seq (list): list of indexes of sampled frames in this clip. 65 | """ 66 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate)) 67 | 68 | for seq_idx in range(len(seq)): 69 | if seq[seq_idx] < 0: 70 | seq[seq_idx] = 0 71 | elif seq[seq_idx] >= num_frames: 72 | seq[seq_idx] = num_frames - 1 73 | return seq 74 | 75 | 76 | def spatial_sampling( 77 | frames, 78 | spatial_idx=-1, 79 | min_scale=256, 80 | max_scale=320, 81 | crop_size=224, 82 | random_horizontal_flip=True, 83 | inverse_uniform_sampling=False, 84 | aspect_ratio=None, 85 | scale=None, 86 | motion_shift=False, 87 | ): 88 | """ 89 | Perform spatial sampling on the given video frames. If spatial_idx is 90 | -1, perform random scale, random crop, and random flip on the given 91 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling 92 | with the given spatial_idx. 93 | Args: 94 | frames (tensor): frames of images sampled from the video. The 95 | dimension is `num frames` x `height` x `width` x `channel`. 96 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1, 97 | or 2, perform left, center, right crop if width is larger than 98 | height, and perform top, center, buttom crop if height is larger 99 | than width. 100 | min_scale (int): the minimal size of scaling. 101 | max_scale (int): the maximal size of scaling. 102 | crop_size (int): the size of height and width used to crop the 103 | frames. 104 | inverse_uniform_sampling (bool): if True, sample uniformly in 105 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 106 | scale. If False, take a uniform sample from [min_scale, 107 | max_scale]. 108 | aspect_ratio (list): Aspect ratio range for resizing. 109 | scale (list): Scale range for resizing. 110 | motion_shift (bool): Whether to apply motion shift for resizing. 111 | Returns: 112 | frames (tensor): spatially sampled frames. 
113 | """ 114 | assert spatial_idx in [-1, 0, 1, 2] 115 | if spatial_idx == -1: 116 | if aspect_ratio is None and scale is None: 117 | frames = transform.random_short_side_scale_jitter( 118 | images=frames, 119 | min_size=min_scale, 120 | max_size=max_scale, 121 | inverse_uniform_sampling=inverse_uniform_sampling, 122 | ) 123 | frames = transform.random_crop(frames, crop_size) 124 | else: 125 | transform_func = ( 126 | transform.random_resized_crop_with_shift 127 | if motion_shift 128 | else transform.random_resized_crop 129 | ) 130 | frames = transform_func( 131 | images=frames, 132 | target_height=crop_size, 133 | target_width=crop_size, 134 | scale=scale, 135 | ratio=aspect_ratio, 136 | ) 137 | if random_horizontal_flip: 138 | frames = transform.horizontal_flip(0.5, frames) 139 | else: 140 | # The testing is deterministic and no jitter should be performed. 141 | # min_scale, max_scale, and crop_size are expect to be the same. 142 | assert len({min_scale, max_scale}) == 1 143 | frames = transform.random_short_side_scale_jitter(frames, min_scale, max_scale) 144 | frames = transform.uniform_crop(frames, crop_size, spatial_idx) 145 | return frames 146 | 147 | 148 | def as_binary_vector(labels, num_classes): 149 | """ 150 | Construct binary label vector given a list of label indices. 151 | Args: 152 | labels (list): The input label list. 153 | num_classes (int): Number of classes of the label vector. 154 | Returns: 155 | labels (numpy array): the resulting binary vector. 156 | """ 157 | label_arr = np.zeros((num_classes,)) 158 | 159 | for lbl in set(labels): 160 | label_arr[lbl] = 1.0 161 | return label_arr 162 | 163 | 164 | def aggregate_labels(label_list): 165 | """ 166 | Join a list of label list. 167 | Args: 168 | labels (list): The input label list. 169 | Returns: 170 | labels (list): The joint list of all lists in input. 171 | """ 172 | all_labels = [] 173 | for labels in label_list: 174 | for l in labels: 175 | all_labels.append(l) 176 | return list(set(all_labels)) 177 | 178 | 179 | def convert_to_video_level_labels(labels): 180 | """ 181 | Aggregate annotations from all frames of a video to form video-level labels. 182 | Args: 183 | labels (list): The input label list. 184 | Returns: 185 | labels (list): Same as input, but with each label replaced by 186 | a video-level one. 187 | """ 188 | for video_id in range(len(labels)): 189 | video_level_labels = aggregate_labels(labels[video_id]) 190 | for i in range(len(labels[video_id])): 191 | labels[video_id][i] = video_level_labels 192 | return labels 193 | 194 | 195 | def load_image_lists(frame_list_file, prefix="", return_list=False): 196 | """ 197 | Load image paths and labels from a "frame list". 198 | Each line of the frame list contains: 199 | `original_vido_id video_id frame_id path labels` 200 | Args: 201 | frame_list_file (string): path to the frame list. 202 | prefix (str): the prefix for the path. 203 | return_list (bool): if True, return a list. If False, return a dict. 204 | Returns: 205 | image_paths (list or dict): list of list containing path to each frame. 206 | If return_list is False, then return in a dict form. 207 | labels (list or dict): list of list containing label of each frame. 208 | If return_list is False, then return in a dict form. 
209 | """ 210 | image_paths = defaultdict(list) 211 | labels = defaultdict(list) 212 | with pathmgr.open(frame_list_file, "r") as f: 213 | assert f.readline().startswith("original_vido_id") 214 | for line in f: 215 | row = line.split() 216 | # original_vido_id video_id frame_id path labels 217 | assert len(row) == 5 218 | video_name = row[0] 219 | if prefix == "": 220 | path = row[3] 221 | else: 222 | path = os.path.join(prefix, row[3]) 223 | image_paths[video_name].append(path) 224 | frame_labels = row[-1].replace('"', "") 225 | if frame_labels != "": 226 | labels[video_name].append([int(x) for x in frame_labels.split(",")]) 227 | else: 228 | labels[video_name].append([]) 229 | 230 | if return_list: 231 | keys = image_paths.keys() 232 | image_paths = [image_paths[key] for key in keys] 233 | labels = [labels[key] for key in keys] 234 | return image_paths, labels 235 | return dict(image_paths), dict(labels) 236 | 237 | 238 | def tensor_normalize(tensor, mean, std): 239 | """ 240 | Normalize a given tensor by subtracting the mean and dividing the std. 241 | Args: 242 | tensor (tensor): tensor to normalize. 243 | mean (tensor or list): mean value to subtract. 244 | std (tensor or list): std to divide. 245 | """ 246 | if tensor.dtype == torch.uint8: 247 | tensor = tensor.float() 248 | tensor = tensor / 255.0 249 | if type(mean) == tuple: 250 | mean = torch.tensor(mean) 251 | if type(std) == tuple: 252 | std = torch.tensor(std) 253 | tensor = tensor - mean 254 | tensor = tensor / std 255 | return tensor 256 | 257 | 258 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate): 259 | """ 260 | When multigrid training uses a fewer number of frames, we randomly 261 | increase the sampling rate so that some clips cover the original span. 262 | """ 263 | if long_cycle_sampling_rate > 0: 264 | assert long_cycle_sampling_rate >= sampling_rate 265 | return random.randint(sampling_rate, long_cycle_sampling_rate) 266 | else: 267 | return sampling_rate 268 | 269 | 270 | def revert_tensor_normalize(tensor, mean, std): 271 | """ 272 | Revert normalization for a given tensor by multiplying by the std and adding the mean. 273 | Args: 274 | tensor (tensor): tensor to revert normalization. 275 | mean (tensor or list): mean value to add. 276 | std (tensor or list): std to multiply. 277 | """ 278 | if type(mean) == list: 279 | mean = torch.tensor(mean) 280 | if type(std) == list: 281 | std = torch.tensor(std) 282 | tensor = tensor * std 283 | tensor = tensor + mean 284 | return tensor 285 | 286 | 287 | def create_sampler(dataset, shuffle, cfg): 288 | """ 289 | Create sampler for the given dataset. 290 | Args: 291 | dataset (torch.utils.data.Dataset): the given dataset. 292 | shuffle (bool): set to ``True`` to have the data reshuffled 293 | at every epoch. 294 | cfg (CfgNode): configs. Details can be found in 295 | slowfast/config/defaults.py 296 | Returns: 297 | sampler (Sampler): the created sampler. 298 | """ 299 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 300 | 301 | return sampler 302 | 303 | 304 | def loader_worker_init_fn(dataset): 305 | """ 306 | Create init function passed to pytorch data loader. 307 | Args: 308 | dataset (torch.utils.data.Dataset): the given dataset. 309 | """ 310 | return None 311 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import os 15 | import time 16 | 17 | # import mae_st.util.env 18 | 19 | import mae_st.util.misc as misc 20 | 21 | import numpy as np 22 | # import timm 23 | import torch 24 | import torch.backends.cudnn as cudnn 25 | from iopath.common.file_io import g_pathmgr as pathmgr 26 | from mae_st import models_mae 27 | from mae_st.engine_pretrain import train_one_epoch 28 | from mae_st.util.misc import NativeScalerWithGradNormCount as NativeScaler 29 | 30 | from torch.utils.tensorboard import SummaryWriter 31 | from data.dataset import EchoDataset_from_Video_mp4 32 | import torch.distributed as dist 33 | 34 | 35 | def get_args_parser(): 36 | parser = argparse.ArgumentParser("MAE pre-training", add_help=False) 37 | parser.add_argument( 38 | "--batch_size", 39 | default=4, 40 | type=int, 41 | help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus", 42 | ) 43 | parser.add_argument("--epochs", default=100, type=int) 44 | parser.add_argument( 45 | "--accum_iter", 46 | default=1, 47 | type=int, 48 | help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)", 49 | ) 50 | 51 | # Model parameters 52 | parser.add_argument( 53 | "--model", 54 | default="mae_vit_large_patch16", 55 | type=str, 56 | metavar="MODEL", 57 | help="Name of model to train", 58 | ) 59 | 60 | parser.add_argument("--input_size", default=224, type=int, help="images input size") 61 | 62 | parser.add_argument( 63 | "--mask_ratio", 64 | default=0.75, 65 | type=float, 66 | help="Masking ratio (percentage of removed patches).", 67 | ) 68 | 69 | parser.add_argument( 70 | "--norm_pix_loss", 71 | action="store_true", 72 | help="Use (per-patch) normalized pixels as targets for computing loss", 73 | ) 74 | parser.set_defaults(norm_pix_loss=False) 75 | 76 | # Optimizer parameters 77 | parser.add_argument( 78 | "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)" 79 | ) 80 | 81 | parser.add_argument( 82 | "--lr", 83 | type=float, 84 | default=None, 85 | metavar="LR", 86 | help="learning rate (absolute lr)", 87 | ) 88 | parser.add_argument( 89 | "--blr", 90 | type=float, 91 | default=1e-3, 92 | metavar="LR", 93 | help="base learning rate: absolute_lr = base_lr * total_batch_size / 256", 94 | ) 95 | parser.add_argument( 96 | "--min_lr", 97 | type=float, 98 | default=0.0, 99 | metavar="LR", 100 | help="lower lr bound for cyclic schedulers that hit 0", 101 | ) 102 | 103 | parser.add_argument( 104 | "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR" 105 | ) 106 | parser.add_argument( 107 | "--path_to_data_dir", 108 | default="", 109 | help="path where to save, empty for no saving", 110 | ) 111 | parser.add_argument( 112 | "--output_dir", 113 | default="./output_dir", 114 | ) 115 | parser.add_argument( 116 | "--data_path", 117 | default="/raid/camca/sk1064/us/fullset/video/", 118 | help="path where to save, empty for no saving", 119 | ) 120 | parser.add_argument( 121 | "--log_dir", 122 | default="", 123 | help="path where to tensorboard log", 124 | ) 125 | parser.add_argument( 
126 | "--device", default="cuda", help="device to use for training / testing" 127 | ) 128 | parser.add_argument("--seed", default=0, type=int) 129 | parser.add_argument("--resume", default="", help="resume from checkpoint") 130 | 131 | parser.add_argument( 132 | "--start_epoch", default=0, type=int, metavar="N", help="start epoch" 133 | ) 134 | parser.add_argument("--num_workers", default=10, type=int) 135 | parser.add_argument( 136 | "--pin_mem", 137 | action="store_true", 138 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.", 139 | ) 140 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 141 | parser.set_defaults(pin_mem=True) 142 | 143 | # distributed training parameters 144 | parser.add_argument( 145 | "--world_size", default=1, type=int, help="number of distributed processes" 146 | ) 147 | parser.add_argument("--local_rank", default=-1, type=int) 148 | parser.add_argument("--dist_on_itp", action="store_true") 149 | parser.add_argument("--no_env", action="store_true") 150 | 151 | # Video related configs 152 | parser.add_argument( 153 | "--dist_url", default="env://", help="url used to set up distributed training" 154 | ) 155 | 156 | parser.add_argument("--decoder_embed_dim", default=512, type=int) 157 | parser.add_argument("--decoder_depth", default=8, type=int) 158 | parser.add_argument("--decoder_num_heads", default=16, type=int) 159 | parser.add_argument("--t_patch_size", default=4, type=int) 160 | parser.add_argument("--num_frames", default=32, type=int) 161 | parser.add_argument("--checkpoint_period", default=1, type=int) 162 | parser.add_argument("--sampling_rate", default=4, type=int) 163 | parser.add_argument("--distributed", default=True, type=bool, help="Enable distributed training") 164 | parser.add_argument("--repeat_aug", default=4, type=int) 165 | parser.add_argument( 166 | "--clip_grad", 167 | type=float, 168 | default=None, 169 | ) 170 | parser.add_argument("--no_qkv_bias", action="store_true") 171 | parser.add_argument("--bias_wd", action="store_true") 172 | parser.add_argument("--num_checkpoint_del", default=20, type=int) 173 | parser.add_argument("--sep_pos_embed", action="store_true") 174 | parser.set_defaults(sep_pos_embed=True) 175 | parser.add_argument( 176 | "--trunc_init", 177 | action="store_true", 178 | ) 179 | parser.add_argument( 180 | "--fp32", 181 | action="store_true", 182 | ) 183 | parser.set_defaults(fp32=True) 184 | parser.add_argument( 185 | "--jitter_scales_relative", 186 | default=[0.5, 1.0], 187 | type=float, 188 | nargs="+", 189 | ) 190 | parser.add_argument( 191 | "--jitter_aspect_relative", 192 | default=[0.75, 1.3333], 193 | type=float, 194 | nargs="+", 195 | ) 196 | parser.add_argument( 197 | "--beta", 198 | default=None, 199 | type=float, 200 | nargs="+", 201 | ) 202 | parser.add_argument( 203 | "--pred_t_dim", 204 | type=int, 205 | default=8, 206 | ) 207 | parser.add_argument("--cls_embed", action="store_true") 208 | parser.set_defaults(cls_embed=True) 209 | return parser 210 | 211 | 212 | def main(args): 213 | misc.init_distributed_mode(args) 214 | 215 | print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) 216 | print("{}".format(args).replace(", ", ",\n")) 217 | 218 | # 멀티 GPU 초기화 219 | # if args.distributed: 220 | # dist.init_process_group(backend="nccl", init_method=args.dist_url, rank=args.local_rank, world_size=args.world_size) 221 | # torch.cuda.set_device(args.local_rank) 222 | if args.distributed: 223 | if not dist.is_initialized(): # 이미 초기화된 경우 중복 호출 
방지 224 | dist.init_process_group( 225 | backend="nccl", 226 | init_method=args.dist_url, 227 | rank=args.local_rank, 228 | world_size=args.world_size 229 | ) 230 | torch.cuda.set_device(args.local_rank) 231 | 232 | print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__)))) 233 | print("{}".format(args).replace(", ", ",\n")) 234 | 235 | device = torch.device(args.device) 236 | 237 | # fix the seed for reproducibility 238 | seed = args.seed + misc.get_rank() 239 | torch.manual_seed(seed) 240 | np.random.seed(seed) 241 | 242 | cudnn.benchmark = True 243 | 244 | dataset_train = EchoDataset_from_Video_mp4(args.data_path) 245 | 246 | if args.distributed: 247 | num_tasks = misc.get_world_size() 248 | global_rank = misc.get_rank() 249 | sampler_train = torch.utils.data.DistributedSampler( 250 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 251 | ) 252 | print("Sampler_train = %s" % str(sampler_train)) 253 | else: 254 | num_tasks = 1 255 | global_rank = 0 256 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 257 | 258 | if global_rank == 0 and args.log_dir is not None: 259 | try: 260 | pathmgr.mkdirs(args.log_dir) 261 | except Exception as _: 262 | pass 263 | log_writer = SummaryWriter(log_dir=args.log_dir) 264 | else: 265 | log_writer = None 266 | 267 | data_loader_train = torch.utils.data.DataLoader( 268 | dataset_train, 269 | sampler=sampler_train, 270 | batch_size=args.batch_size, 271 | num_workers=args.num_workers, 272 | pin_memory=args.pin_mem, 273 | drop_last=True, 274 | ) 275 | 276 | # define the model 277 | model = models_mae.__dict__[args.model]( 278 | **vars(args), 279 | ) 280 | 281 | model.to(device) 282 | 283 | model_without_ddp = model 284 | print("Model = %s" % str(model_without_ddp)) 285 | 286 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 287 | 288 | if args.lr is None: # only base_lr is specified 289 | args.lr = args.blr * eff_batch_size / 256 290 | 291 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 292 | print("actual lr: %.2e" % args.lr) 293 | 294 | print("accumulate grad iterations: %d" % args.accum_iter) 295 | print("effective batch size: %d" % eff_batch_size) 296 | 297 | if args.distributed: 298 | model = torch.nn.parallel.DistributedDataParallel( 299 | model, 300 | device_ids=[torch.cuda.current_device()], 301 | # find_unused_parameters=True, 302 | ) 303 | model_without_ddp = model.module 304 | 305 | # following timm: set wd as 0 for bias and norm layers 306 | param_groups = misc.add_weight_decay( 307 | model_without_ddp, 308 | args.weight_decay, 309 | bias_wd=args.bias_wd, 310 | ) 311 | if args.beta is None: 312 | beta = (0.9, 0.95) 313 | else: 314 | beta = args.beta 315 | optimizer = torch.optim.AdamW( 316 | param_groups, 317 | lr=args.lr, 318 | betas=beta, 319 | ) 320 | loss_scaler = NativeScaler(fp32=args.fp32) 321 | 322 | misc.load_model( 323 | args=args, 324 | model_without_ddp=model_without_ddp, 325 | optimizer=optimizer, 326 | loss_scaler=loss_scaler, 327 | ) 328 | 329 | checkpoint_path = "" 330 | print(f"Start training for {args.epochs} epochs") 331 | start_time = time.time() 332 | for epoch in range(args.start_epoch, args.epochs): 333 | if args.distributed: 334 | data_loader_train.sampler.set_epoch(epoch) 335 | train_stats = train_one_epoch( 336 | model, 337 | data_loader_train, 338 | optimizer, 339 | device, 340 | epoch, 341 | loss_scaler, 342 | log_writer=log_writer, 343 | args=args, 344 | fp32=args.fp32, 345 | ) 346 | if args.output_dir and ( 347 | epoch % 
args.checkpoint_period == 0 or epoch + 1 == args.epochs 348 | ): 349 | checkpoint_path = misc.save_model( 350 | args=args, 351 | model=model, 352 | model_without_ddp=model_without_ddp, 353 | optimizer=optimizer, 354 | loss_scaler=loss_scaler, 355 | epoch=epoch, 356 | ) 357 | 358 | log_stats = { 359 | **{f"train_{k}": v for k, v in train_stats.items()}, 360 | "epoch": epoch, 361 | } 362 | 363 | if args.output_dir and misc.is_main_process(): 364 | if log_writer is not None: 365 | log_writer.flush() 366 | with pathmgr.open( 367 | f"{args.output_dir}/log.txt", 368 | "a", 369 | ) as f: 370 | f.write(json.dumps(log_stats) + "\n") 371 | 372 | total_time = time.time() - start_time 373 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 374 | print("Training time {}".format(total_time_str)) 375 | print(torch.cuda.memory_allocated()) 376 | return [checkpoint_path] 377 | 378 | 379 | def launch_one_thread( 380 | local_rank, 381 | shard_rank, 382 | num_gpus_per_node, 383 | num_shards, 384 | init_method, 385 | output_path, 386 | opts, 387 | stats_queue, 388 | ): 389 | print(opts) 390 | args = get_args_parser() 391 | args = args.parse_args(opts) 392 | args.rank = shard_rank * num_gpus_per_node + local_rank 393 | args.world_size = num_shards * num_gpus_per_node 394 | args.gpu = local_rank 395 | args.dist_url = init_method 396 | args.output_dir = output_path 397 | output = main(args) 398 | stats_queue.put(output) -------------------------------------------------------------------------------- /EchoFM/util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import math 15 | import os 16 | import time 17 | from collections import defaultdict, deque, OrderedDict 18 | 19 | import EchoFM.util.logging as logging 20 | import psutil 21 | import torch 22 | import torch.distributed as dist 23 | # import torch.fb.rendezvous.zeus 24 | from iopath.common.file_io import g_pathmgr as pathmgr 25 | from EchoFM.util.logging import master_print as print 26 | from torch import inf 27 | 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | class SmoothedValue: 33 | """Track a series of values and provide access to smoothed values over a 34 | window or the global series average. 35 | """ 36 | 37 | def __init__(self, window_size=20, fmt=None): 38 | if fmt is None: 39 | fmt = "{median:.4f} ({global_avg:.4f})" 40 | self.deque = deque(maxlen=window_size) 41 | self.total = 0.0 42 | self.count = 0 43 | self.fmt = fmt 44 | 45 | def update(self, value, n=1): 46 | self.deque.append(value) 47 | self.count += n 48 | self.total += value * n 49 | 50 | def synchronize_between_processes(self): 51 | """ 52 | Warning: does not synchronize the deque! 
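        Only `count` and `total` are all-reduced, so `global_avg` becomes a
        cross-process statistic while the window-based `median` and `avg`
        remain local to each process.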
53 | """ 54 | if not is_dist_avail_and_initialized(): 55 | return 56 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 57 | dist.barrier() 58 | dist.all_reduce(t) 59 | t = t.tolist() 60 | self.count = int(t[0]) 61 | self.total = t[1] 62 | 63 | @property 64 | def median(self): 65 | d = torch.tensor(list(self.deque)) 66 | return d.median().item() 67 | 68 | @property 69 | def avg(self): 70 | d = torch.tensor(list(self.deque), dtype=torch.float32) 71 | return d.mean().item() 72 | 73 | @property 74 | def global_avg(self): 75 | return self.total / self.count 76 | 77 | @property 78 | def max(self): 79 | return max(self.deque) 80 | 81 | @property 82 | def value(self): 83 | return self.deque[-1] 84 | 85 | def __str__(self): 86 | return self.fmt.format( 87 | median=self.median, 88 | avg=self.avg, 89 | global_avg=self.global_avg, 90 | max=self.max, 91 | value=self.value, 92 | ) 93 | 94 | 95 | class MetricLogger: 96 | def __init__(self, delimiter="\t"): 97 | self.meters = defaultdict(SmoothedValue) 98 | self.delimiter = delimiter 99 | 100 | def update(self, **kwargs): 101 | for k, v in kwargs.items(): 102 | if v is None: 103 | continue 104 | if isinstance(v, torch.Tensor): 105 | v = v.item() 106 | assert isinstance(v, (float, int)) 107 | self.meters[k].update(v) 108 | 109 | def __getattr__(self, attr): 110 | if attr in self.meters: 111 | return self.meters[attr] 112 | if attr in self.__dict__: 113 | return self.__dict__[attr] 114 | raise AttributeError( 115 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 116 | ) 117 | 118 | def __str__(self): 119 | loss_str = [] 120 | for name, meter in self.meters.items(): 121 | loss_str.append("{}: {}".format(name, str(meter))) 122 | return self.delimiter.join(loss_str) 123 | 124 | def synchronize_between_processes(self): 125 | for meter in self.meters.values(): 126 | meter.synchronize_between_processes() 127 | 128 | def add_meter(self, name, meter): 129 | self.meters[name] = meter 130 | 131 | def log_every(self, iterable, print_freq, header=None): 132 | i = 0 133 | if not header: 134 | header = "" 135 | start_time = time.time() 136 | end = time.time() 137 | iter_time = SmoothedValue(fmt="{avg:.4f}") 138 | data_time = SmoothedValue(fmt="{avg:.4f}") 139 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 140 | log_msg = [ 141 | header, 142 | "[{0" + space_fmt + "}/{1}]", 143 | "eta: {eta}", 144 | "{meters}", 145 | "time: {time}", 146 | "data: {data}", 147 | ] 148 | if torch.cuda.is_available(): 149 | log_msg.append("max mem: {memory:.0f}") 150 | log_msg = self.delimiter.join(log_msg) 151 | MB = 1024.0 * 1024.0 152 | for obj in iterable: 153 | data_time.update(time.time() - end) 154 | yield obj 155 | iter_time.update(time.time() - end) 156 | if i % print_freq == 0 or i == len(iterable) - 1: 157 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 158 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 159 | if torch.cuda.is_available(): 160 | print( 161 | log_msg.format( 162 | i, 163 | len(iterable), 164 | eta=eta_string, 165 | meters=str(self), 166 | time=str(iter_time), 167 | data=str(data_time), 168 | memory=torch.cuda.max_memory_allocated() / MB, 169 | ) 170 | ) 171 | 172 | else: 173 | print( 174 | log_msg.format( 175 | i, 176 | len(iterable), 177 | eta=eta_string, 178 | meters=str(self), 179 | time=str(iter_time), 180 | data=str(data_time), 181 | ) 182 | ) 183 | i += 1 184 | end = time.time() 185 | total_time = time.time() - start_time 186 | total_time_str = 
str(datetime.timedelta(seconds=int(total_time))) 187 | print( 188 | "{} Total time: {} ({:.4f} s / it)".format( 189 | header, total_time_str, total_time / len(iterable) 190 | ) 191 | ) 192 | 193 | 194 | def setup_for_distributed(is_master): 195 | """ 196 | This function disables printing when not in master process 197 | """ 198 | builtin_print = builtins.print 199 | 200 | def print(*args, **kwargs): 201 | force = kwargs.pop("force", False) 202 | force = force or (get_world_size() > 8) 203 | if is_master or force: 204 | now = datetime.datetime.now().time() 205 | builtin_print("[{}] ".format(now), end="") # print with time stamp 206 | builtin_print(*args, **kwargs) 207 | 208 | builtins.print = print 209 | 210 | 211 | def is_dist_avail_and_initialized(): 212 | if not dist.is_available(): 213 | return False 214 | if not dist.is_initialized(): 215 | return False 216 | return True 217 | 218 | 219 | def get_world_size(): 220 | if not is_dist_avail_and_initialized(): 221 | return 1 222 | return dist.get_world_size() 223 | 224 | 225 | def get_rank(): 226 | if not is_dist_avail_and_initialized(): 227 | return 0 228 | return dist.get_rank() 229 | 230 | 231 | def is_main_process(): 232 | return get_rank() == 0 233 | 234 | 235 | def save_on_master(state, path): 236 | if is_main_process(): 237 | print(f"save path {path}") 238 | with pathmgr.open(path, "wb") as f: 239 | torch.save(state, f) 240 | 241 | 242 | def init_distributed_mode(args): 243 | if args.no_env: 244 | pass 245 | elif args.dist_on_itp: 246 | args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 247 | args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 248 | args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 249 | args.dist_url = "tcp://%s:%s" % ( 250 | os.environ["MASTER_ADDR"], 251 | os.environ["MASTER_PORT"], 252 | ) 253 | os.environ["LOCAL_RANK"] = str(args.gpu) 254 | os.environ["RANK"] = str(args.rank) 255 | os.environ["WORLD_SIZE"] = str(args.world_size) 256 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 257 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: 258 | args.rank = int(os.environ["RANK"]) 259 | args.world_size = int(os.environ["WORLD_SIZE"]) 260 | args.gpu = int(os.environ["LOCAL_RANK"]) 261 | elif "SLURM_PROCID" in os.environ: 262 | args.rank = int(os.environ["SLURM_PROCID"]) 263 | args.gpu = args.rank % torch.cuda.device_count() 264 | else: 265 | print("Not using distributed mode") 266 | setup_for_distributed(is_master=True) # hack 267 | args.distributed = False 268 | return 269 | 270 | args.distributed = True 271 | 272 | torch.cuda.set_device(args.gpu) 273 | args.dist_backend = "nccl" 274 | print( 275 | "| distributed init (rank {}): {}, gpu {}".format( 276 | args.rank, args.dist_url, args.gpu 277 | ), 278 | # flush=True, 279 | ) 280 | torch.distributed.init_process_group( 281 | backend=args.dist_backend, 282 | init_method=args.dist_url, 283 | world_size=args.world_size, 284 | rank=args.rank, 285 | ) 286 | torch.distributed.barrier() 287 | setup_for_distributed(args.rank == 0) 288 | 289 | 290 | class NativeScalerWithGradNormCount: 291 | state_dict_key = "amp_scaler" 292 | 293 | def __init__(self, fp32=False): 294 | self._scaler = torch.cuda.amp.GradScaler(enabled=not fp32) 295 | 296 | def __call__( 297 | self, 298 | loss, 299 | optimizer, 300 | clip_grad=None, 301 | parameters=None, 302 | create_graph=False, 303 | update_grad=True, 304 | ): 305 | self._scaler.scale(loss).backward(create_graph=create_graph) 306 | if update_grad: 307 | if clip_grad is not None: 308 | assert 
parameters is not None 309 | self._scaler.unscale_( 310 | optimizer 311 | ) # unscale the gradients of optimizer's assigned params in-place 312 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 313 | else: 314 | self._scaler.unscale_(optimizer) 315 | norm = get_grad_norm_(parameters) 316 | self._scaler.step(optimizer) 317 | self._scaler.update() 318 | else: 319 | norm = None 320 | return norm 321 | 322 | def state_dict(self): 323 | return self._scaler.state_dict() 324 | 325 | def load_state_dict(self, state_dict): 326 | self._scaler.load_state_dict(state_dict) 327 | 328 | 329 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 330 | if isinstance(parameters, torch.Tensor): 331 | parameters = [parameters] 332 | parameters = [p for p in parameters if p.grad is not None] 333 | norm_type = float(norm_type) 334 | if len(parameters) == 0: 335 | return torch.tensor(0.0) 336 | device = parameters[0].grad.device 337 | if norm_type == inf: 338 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 339 | else: 340 | total_norm = torch.norm( 341 | torch.stack( 342 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] 343 | ), 344 | norm_type, 345 | ) 346 | return total_norm 347 | 348 | 349 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 350 | checkpoint_path = "{}/checkpoint-{:05d}.pth".format(args.output_dir, epoch) 351 | to_save = { 352 | "model": model_without_ddp.state_dict(), 353 | "optimizer": optimizer.state_dict(), 354 | "epoch": epoch, 355 | "scaler": loss_scaler.state_dict(), 356 | "args": args, 357 | } 358 | 359 | save_on_master(to_save, checkpoint_path) 360 | return checkpoint_path 361 | 362 | 363 | def get_last_checkpoint(args): 364 | """ 365 | Get the last checkpoint from the checkpointing folder. 366 | Args: 367 | path_to_job (string): the path to the folder of the current job. 368 | """ 369 | d = args.output_dir 370 | names = pathmgr.ls(d) if pathmgr.exists(d) else [] 371 | names = [f for f in names if "checkpoint" in f] 372 | if len(names) == 0: 373 | print("No checkpoints found in '{}'.".format(d)) 374 | return None 375 | else: 376 | # Sort the checkpoints by epoch. 
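        # save_model() writes "checkpoint-{epoch:05d}.pth", so the zero-padded
        # epoch makes lexicographic order match epoch order and the last entry
        # is the most recent checkpoint.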
377 | name = sorted(names)[-1] 378 | return os.path.join(d, name) 379 | 380 | 381 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 382 | if not args.resume: 383 | args.resume = get_last_checkpoint(args) 384 | if args.resume: 385 | if args.resume.startswith("https"): 386 | checkpoint = torch.hub.load_state_dict_from_url( 387 | args.resume, map_location="cpu", check_hash=True 388 | ) 389 | else: 390 | with pathmgr.open(args.resume, "rb") as f: 391 | checkpoint = torch.load(f, map_location="cpu") 392 | model_without_ddp.load_state_dict(checkpoint["model"]) 393 | print("Resume checkpoint %s" % args.resume) 394 | if ( 395 | "optimizer" in checkpoint 396 | and "epoch" in checkpoint 397 | and not (hasattr(args, "eval") and args.eval) 398 | ): 399 | optimizer.load_state_dict(checkpoint["optimizer"]) 400 | args.start_epoch = checkpoint["epoch"] + 1 401 | if "scaler" in checkpoint: 402 | loss_scaler.load_state_dict(checkpoint["scaler"]) 403 | print("With optim & sched!") 404 | 405 | 406 | def all_reduce_mean(x): 407 | world_size = get_world_size() 408 | if world_size > 1: 409 | x_reduce = torch.tensor(x).cuda() 410 | dist.all_reduce(x_reduce) 411 | x_reduce /= world_size 412 | return x_reduce.item() 413 | else: 414 | return x 415 | 416 | 417 | def gpu_mem_usage(): 418 | """ 419 | Compute the GPU memory usage for the current device (GB). 420 | """ 421 | if torch.cuda.is_available(): 422 | mem_usage_bytes = torch.cuda.max_memory_allocated() 423 | else: 424 | mem_usage_bytes = 0 425 | return mem_usage_bytes / 1024**3 426 | 427 | 428 | def cpu_mem_usage(): 429 | """ 430 | Compute the system memory (RAM) usage for the current device (GB). 431 | Returns: 432 | usage (float): used memory (GB). 433 | total (float): total memory (GB). 434 | """ 435 | vram = psutil.virtual_memory() 436 | usage = (vram.total - vram.available) / 1024**3 437 | total = vram.total / 1024**3 438 | 439 | return usage, total 440 | 441 | 442 | def all_gather(tensors): 443 | """ 444 | All gathers the provided tensors from all processes across machines. 445 | Args: 446 | tensors (list): tensors to perform all gather across all processes in 447 | all machines. 
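    Returns:
        output_tensor (list): for each input tensor, the gathered copies from
            all processes concatenated along dim 0.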
448 | """ 449 | 450 | gather_list = [] 451 | output_tensor = [] 452 | world_size = dist.get_world_size() 453 | for tensor in tensors: 454 | tensor_placeholder = [torch.ones_like(tensor) for _ in range(world_size)] 455 | dist.all_gather(tensor_placeholder, tensor, async_op=False) 456 | gather_list.append(tensor_placeholder) 457 | for gathered_tensor in gather_list: 458 | output_tensor.append(torch.cat(gathered_tensor, dim=0)) 459 | return output_tensor 460 | 461 | 462 | def add_weight_decay(model, weight_decay=1e-5, skip_list=(), bias_wd=False): 463 | decay = [] 464 | no_decay = [] 465 | for name, param in model.named_parameters(): 466 | if not param.requires_grad: 467 | continue # frozen weights 468 | if ( 469 | (not bias_wd) 470 | and len(param.shape) == 1 471 | or name.endswith(".bias") 472 | or name in skip_list 473 | ): 474 | no_decay.append(param) 475 | else: 476 | decay.append(param) 477 | return [ 478 | {"params": no_decay, "weight_decay": 0.0}, 479 | {"params": decay, "weight_decay": weight_decay}, 480 | ] 481 | 482 | 483 | def inflate(model_2d, model_3d): 484 | state_dict_inflated = OrderedDict() 485 | for k, v2d in model_2d.items(): 486 | if "patch_embed.proj.weight" in k: 487 | v3d = model_3d[k] 488 | v3d = v2d.unsqueeze(2).repeat(1, 1, v3d.shape[2], 1, 1) / v3d.shape[2] 489 | state_dict_inflated[k] = v3d.clone() 490 | elif "pos_embed" in k: 491 | pos_embed_cls, pos_embed_spatial = torch.split(v2d, [1, 196], dim=1) 492 | state_dict_inflated["pos_embed_cls"] = pos_embed_cls.clone() 493 | state_dict_inflated["pos_embed"] = pos_embed_spatial.clone() 494 | else: 495 | state_dict_inflated[k] = v2d.clone() 496 | return state_dict_inflated 497 | 498 | 499 | def convert_checkpoint(model_2d): 500 | state_dict_inflated = OrderedDict() 501 | for k, v2d in model_2d.items(): 502 | if "head.projection.weight" in k: 503 | state_dict_inflated["head.weight"] = v2d.clone() 504 | elif "head.projection.bias" in k: 505 | state_dict_inflated["head.bias"] = v2d.clone() 506 | else: 507 | state_dict_inflated[k] = v2d.clone() 508 | return state_dict_inflated 509 | -------------------------------------------------------------------------------- /EchoFM/util/decoder/rand_augment.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | 5 | """ 6 | This implementation is based on 7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py 8 | pulished under an Apache License 2.0. 9 | 10 | COMMENT FROM ORIGINAL: 11 | AutoAugment, RandAugment, and AugMix for PyTorch 12 | This code implements the searched ImageNet policies with various tweaks and 13 | improvements and does not include any of the search code. AA and RA 14 | Implementation adapted from: 15 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 16 | AugMix adapted from: 17 | https://github.com/google-research/augmix 18 | Papers: 19 | AutoAugment: Learning Augmentation Policies from Data 20 | https://arxiv.org/abs/1805.09501 21 | Learning Data Augmentation Strategies for Object Detection 22 | https://arxiv.org/abs/1906.11172 23 | RandAugment: Practical automated data augmentation... 
24 | https://arxiv.org/abs/1909.13719 25 | AugMix: A Simple Data Processing Method to Improve Robustness and 26 | Uncertainty https://arxiv.org/abs/1912.02781 27 | 28 | Hacked together by / Copyright 2020 Ross Wightman 29 | """ 30 | 31 | import math 32 | import random 33 | import re 34 | 35 | import numpy as np 36 | import PIL 37 | from PIL import Image, ImageEnhance, ImageOps 38 | 39 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) 40 | 41 | _FILL = (128, 128, 128) 42 | 43 | # This signifies the max integer that the controller RNN could predict for the 44 | # augmentation scheme. 45 | _MAX_LEVEL = 10.0 46 | 47 | _HPARAMS_DEFAULT = { 48 | "translate_const": 250, 49 | "img_mean": _FILL, 50 | } 51 | 52 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 53 | 54 | 55 | def _interpolation(kwargs): 56 | interpolation = kwargs.pop("resample", Image.BILINEAR) 57 | if isinstance(interpolation, (list, tuple)): 58 | return random.choice(interpolation) 59 | else: 60 | return interpolation 61 | 62 | 63 | def _check_args_tf(kwargs): 64 | if "fillcolor" in kwargs and _PIL_VER < (5, 0): 65 | kwargs.pop("fillcolor") 66 | kwargs["resample"] = _interpolation(kwargs) 67 | 68 | 69 | def shear_x(img, factor, **kwargs): 70 | _check_args_tf(kwargs) 71 | return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) 72 | 73 | 74 | def shear_y(img, factor, **kwargs): 75 | _check_args_tf(kwargs) 76 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) 77 | 78 | 79 | def translate_x_rel(img, pct, **kwargs): 80 | pixels = pct * img.size[0] 81 | _check_args_tf(kwargs) 82 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 83 | 84 | 85 | def translate_y_rel(img, pct, **kwargs): 86 | pixels = pct * img.size[1] 87 | _check_args_tf(kwargs) 88 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 89 | 90 | 91 | def translate_x_abs(img, pixels, **kwargs): 92 | _check_args_tf(kwargs) 93 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) 94 | 95 | 96 | def translate_y_abs(img, pixels, **kwargs): 97 | _check_args_tf(kwargs) 98 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) 99 | 100 | 101 | def rotate(img, degrees, **kwargs): 102 | _check_args_tf(kwargs) 103 | if _PIL_VER >= (5, 2): 104 | return img.rotate(degrees, **kwargs) 105 | elif _PIL_VER >= (5, 0): 106 | w, h = img.size 107 | post_trans = (0, 0) 108 | rotn_center = (w / 2.0, h / 2.0) 109 | angle = -math.radians(degrees) 110 | matrix = [ 111 | round(math.cos(angle), 15), 112 | round(math.sin(angle), 15), 113 | 0.0, 114 | round(-math.sin(angle), 15), 115 | round(math.cos(angle), 15), 116 | 0.0, 117 | ] 118 | 119 | def transform(x, y, matrix): 120 | (a, b, c, d, e, f) = matrix 121 | return a * x + b * y + c, d * x + e * y + f 122 | 123 | matrix[2], matrix[5] = transform( 124 | -rotn_center[0] - post_trans[0], 125 | -rotn_center[1] - post_trans[1], 126 | matrix, 127 | ) 128 | matrix[2] += rotn_center[0] 129 | matrix[5] += rotn_center[1] 130 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs) 131 | else: 132 | return img.rotate(degrees, resample=kwargs["resample"]) 133 | 134 | 135 | def auto_contrast(img, **__): 136 | return ImageOps.autocontrast(img) 137 | 138 | 139 | def invert(img, **__): 140 | return ImageOps.invert(img) 141 | 142 | 143 | def equalize(img, **__): 144 | return ImageOps.equalize(img) 145 | 146 | 147 | def solarize(img, thresh, **__): 148 | return 
ImageOps.solarize(img, thresh) 149 | 150 | 151 | def solarize_add(img, add, thresh=128, **__): 152 | lut = [] 153 | for i in range(256): 154 | if i < thresh: 155 | lut.append(min(255, i + add)) 156 | else: 157 | lut.append(i) 158 | if img.mode in ("L", "RGB"): 159 | if img.mode == "RGB" and len(lut) == 256: 160 | lut = lut + lut + lut 161 | return img.point(lut) 162 | else: 163 | return img 164 | 165 | 166 | def posterize(img, bits_to_keep, **__): 167 | if bits_to_keep >= 8: 168 | return img 169 | return ImageOps.posterize(img, bits_to_keep) 170 | 171 | 172 | def contrast(img, factor, **__): 173 | return ImageEnhance.Contrast(img).enhance(factor) 174 | 175 | 176 | def color(img, factor, **__): 177 | return ImageEnhance.Color(img).enhance(factor) 178 | 179 | 180 | def brightness(img, factor, **__): 181 | return ImageEnhance.Brightness(img).enhance(factor) 182 | 183 | 184 | def sharpness(img, factor, **__): 185 | return ImageEnhance.Sharpness(img).enhance(factor) 186 | 187 | 188 | def _randomly_negate(v): 189 | """With 50% prob, negate the value""" 190 | return -v if random.random() > 0.5 else v 191 | 192 | 193 | def _rotate_level_to_arg(level, _hparams): 194 | # range [-30, 30] 195 | level = (level / _MAX_LEVEL) * 30.0 196 | level = _randomly_negate(level) 197 | return (level,) 198 | 199 | 200 | def _enhance_level_to_arg(level, _hparams): 201 | # range [0.1, 1.9] 202 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,) 203 | 204 | 205 | def _enhance_increasing_level_to_arg(level, _hparams): 206 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend 207 | # range [0.1, 1.9] 208 | level = (level / _MAX_LEVEL) * 0.9 209 | level = 1.0 + _randomly_negate(level) 210 | return (level,) 211 | 212 | 213 | def _shear_level_to_arg(level, _hparams): 214 | # range [-0.3, 0.3] 215 | level = (level / _MAX_LEVEL) * 0.3 216 | level = _randomly_negate(level) 217 | return (level,) 218 | 219 | 220 | def _translate_abs_level_to_arg(level, hparams): 221 | translate_const = hparams["translate_const"] 222 | level = (level / _MAX_LEVEL) * float(translate_const) 223 | level = _randomly_negate(level) 224 | return (level,) 225 | 226 | 227 | def _translate_rel_level_to_arg(level, hparams): 228 | # default range [-0.45, 0.45] 229 | translate_pct = hparams.get("translate_pct", 0.45) 230 | level = (level / _MAX_LEVEL) * translate_pct 231 | level = _randomly_negate(level) 232 | return (level,) 233 | 234 | 235 | def _posterize_level_to_arg(level, _hparams): 236 | # As per Tensorflow TPU EfficientNet impl 237 | # range [0, 4], 'keep 0 up to 4 MSB of original image' 238 | # intensity/severity of augmentation decreases with level 239 | return (int((level / _MAX_LEVEL) * 4),) 240 | 241 | 242 | def _posterize_increasing_level_to_arg(level, hparams): 243 | # As per Tensorflow models research and UDA impl 244 | # range [4, 0], 'keep 4 down to 0 MSB of original image', 245 | # intensity/severity of augmentation increases with level 246 | return (4 - _posterize_level_to_arg(level, hparams)[0],) 247 | 248 | 249 | def _posterize_original_level_to_arg(level, _hparams): 250 | # As per original AutoAugment paper description 251 | # range [4, 8], 'keep 4 up to 8 MSB of image' 252 | # intensity/severity of augmentation decreases with level 253 | return (int((level / _MAX_LEVEL) * 4) + 4,) 254 | 255 | 256 | def _solarize_level_to_arg(level, _hparams): 257 | # range [0, 256] 258 | # intensity/severity of augmentation decreases with level 259 | return (int((level / _MAX_LEVEL) * 256),) 260 | 261 | 262 
| def _solarize_increasing_level_to_arg(level, _hparams): 263 | # range [0, 256] 264 | # intensity/severity of augmentation increases with level 265 | return (256 - _solarize_level_to_arg(level, _hparams)[0],) 266 | 267 | 268 | def _solarize_add_level_to_arg(level, _hparams): 269 | # range [0, 110] 270 | return (int((level / _MAX_LEVEL) * 110),) 271 | 272 | 273 | LEVEL_TO_ARG = { 274 | "AutoContrast": None, 275 | "Equalize": None, 276 | "Invert": None, 277 | "Rotate": _rotate_level_to_arg, 278 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers 279 | "Posterize": _posterize_level_to_arg, 280 | "PosterizeIncreasing": _posterize_increasing_level_to_arg, 281 | "PosterizeOriginal": _posterize_original_level_to_arg, 282 | "Solarize": _solarize_level_to_arg, 283 | "SolarizeIncreasing": _solarize_increasing_level_to_arg, 284 | "SolarizeAdd": _solarize_add_level_to_arg, 285 | "Color": _enhance_level_to_arg, 286 | "ColorIncreasing": _enhance_increasing_level_to_arg, 287 | "Contrast": _enhance_level_to_arg, 288 | "ContrastIncreasing": _enhance_increasing_level_to_arg, 289 | "Brightness": _enhance_level_to_arg, 290 | "BrightnessIncreasing": _enhance_increasing_level_to_arg, 291 | "Sharpness": _enhance_level_to_arg, 292 | "SharpnessIncreasing": _enhance_increasing_level_to_arg, 293 | "ShearX": _shear_level_to_arg, 294 | "ShearY": _shear_level_to_arg, 295 | "TranslateX": _translate_abs_level_to_arg, 296 | "TranslateY": _translate_abs_level_to_arg, 297 | "TranslateXRel": _translate_rel_level_to_arg, 298 | "TranslateYRel": _translate_rel_level_to_arg, 299 | } 300 | 301 | 302 | NAME_TO_OP = { 303 | "AutoContrast": auto_contrast, 304 | "Equalize": equalize, 305 | "Invert": invert, 306 | "Rotate": rotate, 307 | "Posterize": posterize, 308 | "PosterizeIncreasing": posterize, 309 | "PosterizeOriginal": posterize, 310 | "Solarize": solarize, 311 | "SolarizeIncreasing": solarize, 312 | "SolarizeAdd": solarize_add, 313 | "Color": color, 314 | "ColorIncreasing": color, 315 | "Contrast": contrast, 316 | "ContrastIncreasing": contrast, 317 | "Brightness": brightness, 318 | "BrightnessIncreasing": brightness, 319 | "Sharpness": sharpness, 320 | "SharpnessIncreasing": sharpness, 321 | "ShearX": shear_x, 322 | "ShearY": shear_y, 323 | "TranslateX": translate_x_abs, 324 | "TranslateY": translate_y_abs, 325 | "TranslateXRel": translate_x_rel, 326 | "TranslateYRel": translate_y_rel, 327 | } 328 | 329 | 330 | class AugmentOp: 331 | """ 332 | Apply for video. 333 | """ 334 | 335 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None): 336 | hparams = hparams or _HPARAMS_DEFAULT 337 | self.aug_fn = NAME_TO_OP[name] 338 | self.level_fn = LEVEL_TO_ARG[name] 339 | self.prob = prob 340 | self.magnitude = magnitude 341 | self.hparams = hparams.copy() 342 | self.kwargs = { 343 | "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL, 344 | "resample": ( 345 | hparams["interpolation"] 346 | if "interpolation" in hparams 347 | else _RANDOM_INTERPOLATION 348 | ), 349 | } 350 | 351 | # If magnitude_std is > 0, we introduce some randomness 352 | # in the usually fixed policy and sample magnitude from a normal distribution 353 | # with mean `magnitude` and std-dev of `magnitude_std`. 354 | # NOTE This is my own hack, being tested, not in papers or reference impls. 
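        # When magnitude_std > 0, __call__ below draws the per-call magnitude from
        # gauss(magnitude, magnitude_std) and clips it back into [0, _MAX_LEVEL],
        # so the sampled level always stays inside the valid policy range.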
355 | self.magnitude_std = self.hparams.get("magnitude_std", 0) 356 | 357 | def __call__(self, img_list): 358 | if self.prob < 1.0 and random.random() > self.prob: 359 | return img_list 360 | magnitude = self.magnitude 361 | if self.magnitude_std and self.magnitude_std > 0: 362 | magnitude = random.gauss(magnitude, self.magnitude_std) 363 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range 364 | level_args = ( 365 | self.level_fn(magnitude, self.hparams) if self.level_fn is not None else () 366 | ) 367 | 368 | if isinstance(img_list, list): 369 | return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list] 370 | else: 371 | return self.aug_fn(img_list, *level_args, **self.kwargs) 372 | 373 | 374 | _RAND_TRANSFORMS = [ 375 | "AutoContrast", 376 | "Equalize", 377 | "Invert", 378 | "Rotate", 379 | "Posterize", 380 | "Solarize", 381 | "SolarizeAdd", 382 | "Color", 383 | "Contrast", 384 | "Brightness", 385 | "Sharpness", 386 | "ShearX", 387 | "ShearY", 388 | "TranslateXRel", 389 | "TranslateYRel", 390 | ] 391 | 392 | 393 | _RAND_INCREASING_TRANSFORMS = [ 394 | "AutoContrast", 395 | "Equalize", 396 | "Invert", 397 | "Rotate", 398 | "PosterizeIncreasing", 399 | "SolarizeIncreasing", 400 | "SolarizeAdd", 401 | "ColorIncreasing", 402 | "ContrastIncreasing", 403 | "BrightnessIncreasing", 404 | "SharpnessIncreasing", 405 | "ShearX", 406 | "ShearY", 407 | "TranslateXRel", 408 | "TranslateYRel", 409 | ] 410 | 411 | 412 | # These experimental weights are based loosely on the relative improvements mentioned in paper. 413 | # They may not result in increased performance, but could likely be tuned to so. 414 | _RAND_CHOICE_WEIGHTS_0 = { 415 | "Rotate": 0.3, 416 | "ShearX": 0.2, 417 | "ShearY": 0.2, 418 | "TranslateXRel": 0.1, 419 | "TranslateYRel": 0.1, 420 | "Color": 0.025, 421 | "Sharpness": 0.025, 422 | "AutoContrast": 0.025, 423 | "Solarize": 0.005, 424 | "SolarizeAdd": 0.005, 425 | "Contrast": 0.005, 426 | "Brightness": 0.005, 427 | "Equalize": 0.005, 428 | "Posterize": 0, 429 | "Invert": 0, 430 | } 431 | 432 | 433 | def _select_rand_weights(weight_idx=0, transforms=None): 434 | transforms = transforms or _RAND_TRANSFORMS 435 | assert weight_idx == 0 # only one set of weights currently 436 | rand_weights = _RAND_CHOICE_WEIGHTS_0 437 | probs = [rand_weights[k] for k in transforms] 438 | probs /= np.sum(probs) 439 | return probs 440 | 441 | 442 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None): 443 | hparams = hparams or _HPARAMS_DEFAULT 444 | transforms = transforms or _RAND_TRANSFORMS 445 | return [ 446 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) 447 | for name in transforms 448 | ] 449 | 450 | 451 | class RandAugment: 452 | def __init__(self, ops, num_layers=2, choice_weights=None): 453 | self.ops = ops 454 | self.num_layers = num_layers 455 | self.choice_weights = choice_weights 456 | 457 | def __call__(self, img): 458 | # no replacement when using weighted choice 459 | ops = np.random.choice( 460 | self.ops, 461 | self.num_layers, 462 | replace=self.choice_weights is None, 463 | p=self.choice_weights, 464 | ) 465 | for op in ops: 466 | img = op(img) 467 | return img 468 | 469 | 470 | def rand_augment_transform(config_str, hparams): 471 | """ 472 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 473 | 474 | Create a RandAugment transform 475 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by 476 | dashes ('-'). 
The first section defines the specific variant of rand augment (currently only 'rand'). The remaining 477 | sections, not order sepecific determine 478 | 'm' - integer magnitude of rand augment 479 | 'n' - integer num layers (number of transform ops selected per image) 480 | 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) 481 | 'mstd' - float std deviation of magnitude noise applied 482 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) 483 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 484 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 485 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme 486 | :return: A PyTorch compatible Transform 487 | """ 488 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) 489 | num_layers = 2 # default to 2 ops per image 490 | weight_idx = None # default to no probability weights for op choice 491 | transforms = _RAND_TRANSFORMS 492 | config = config_str.split("-") 493 | assert config[0] == "rand" 494 | config = config[1:] 495 | for c in config: 496 | cs = re.split(r"(\d.*)", c) 497 | if len(cs) < 2: 498 | continue 499 | key, val = cs[:2] 500 | if key == "mstd": 501 | # noise param injected via hparams for now 502 | hparams.setdefault("magnitude_std", float(val)) 503 | elif key == "inc": 504 | if bool(val): 505 | transforms = _RAND_INCREASING_TRANSFORMS 506 | elif key == "m": 507 | magnitude = int(val) 508 | elif key == "n": 509 | num_layers = int(val) 510 | elif key == "w": 511 | weight_idx = int(val) 512 | else: 513 | assert NotImplementedError 514 | ra_ops = rand_augment_ops( 515 | magnitude=magnitude, hparams=hparams, transforms=transforms 516 | ) 517 | choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx) 518 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) 519 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import re 3 | 4 | import cv2 5 | from PIL import Image 6 | from functools import partial 7 | 8 | from typing import Tuple, List 9 | # from beartype.door import is_bearable 10 | 11 | import numpy as np 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.utils.data import Dataset, DataLoader as PytorchDataLoader 16 | from torchvision import transforms as T, utils 17 | 18 | from einops import rearrange 19 | import os 20 | import pickle as pkl 21 | import random 22 | 23 | # helper functions 24 | 25 | from scipy import interpolate 26 | import torchvision.utils as vutils 27 | from PIL import Image 28 | 29 | 30 | def exists(val): 31 | return val is not None 32 | 33 | 34 | def identity(t, *args, **kwargs): 35 | return t 36 | 37 | 38 | def pair(val): 39 | return val if isinstance(val, tuple) else (val, val) 40 | 41 | 42 | def bgr_to_rgb(video_tensor): 43 | video_tensor = video_tensor[[2, 1, 0], :, :, :] 44 | return video_tensor 45 | 46 | 47 | 48 | def cast_num_frames(t, *, frames): 49 | f = t.shape[1] 50 | 51 | if f == frames: 52 | return t 53 | 54 | if f > frames: 55 | return t[:, :frames] 56 | 57 | return F.pad(t, (0, 0, 0, 0, 0, frames - f)) 58 | 59 | def convert_image_to_fn(img_type, image): 60 | if image.mode != img_type: 61 | return image.convert(img_type) 62 | return image 63 | 
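# A minimal sanity check of the helper above (an illustrative sketch, not part of
# the original module; the function name and shapes here are arbitrary examples):
# `cast_num_frames` fixes the temporal length of a (channels, frames, H, W) clip,
# truncating extra frames or zero-padding missing ones at the end of the frame axis.
def _demo_cast_num_frames():
    clip = torch.rand(3, 10, 224, 224)         # (C, F, H, W) clip with 10 frames
    longer = cast_num_frames(clip, frames=32)  # zero-padded up to 32 frames
    shorter = cast_num_frames(clip, frames=4)  # truncated to the first 4 frames
    assert longer.shape == (3, 32, 224, 224)
    assert shorter.shape == (3, 4, 224, 224)
    assert torch.all(longer[:, 10:] == 0)      # the padded tail is all zeros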
64 | # image related helpers functions and dataset 65 | def z_normalize(data): 66 | """ 67 | Perform z-score normalization on the input data. 68 | """ 69 | mean = np.mean(data) 70 | std = np.std(data) 71 | return (data - mean) / std 72 | 73 | def save_tensor_images(tensor, output_dir="output_images"): 74 | # 출력 디렉토리 생성 75 | os.makedirs(output_dir, exist_ok=True) 76 | 77 | # 텐서의 값을 0-255 범위로 정규화 78 | tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * 255 79 | tensor = tensor.byte() 80 | 81 | # 각 프레임을 개별적으로 저장 82 | for frame in range(tensor.shape[1]): # 프레임 차원을 순회 83 | # 현재 프레임의 모든 채널 선택 84 | frame_data = tensor[:, frame, :, :] 85 | 86 | # 채널 순서 변경 (C, H, W) -> (H, W, C) 87 | frame_data = frame_data.permute(1, 2, 0) 88 | 89 | # NumPy 배열로 변환 90 | frame_array = frame_data.numpy() 91 | 92 | # RGB 이미지로 변환 (채널이 3개인 경우) 93 | if frame_array.shape[2] == 3: 94 | img = Image.fromarray(frame_array, 'RGB') 95 | else: 96 | img = Image.fromarray(frame_array[:,:,0], 'L') 97 | 98 | # 이미지 저장 99 | img.save(os.path.join(output_dir, f"frame_{frame:03d}.png")) 100 | 101 | class ImageDataset(Dataset): 102 | def __init__( 103 | self, 104 | folder, 105 | image_size, 106 | exts=['jpg', 'jpeg', 'png'] 107 | ): 108 | super().__init__() 109 | self.folder = folder 110 | self.image_size = image_size 111 | self.paths = [p for ext in exts for p in Path( 112 | f'{folder}').glob(f'**/*.{ext}')] 113 | 114 | print(f'{len(self.paths)} training samples found at {folder}') 115 | 116 | self.transform = T.Compose([ 117 | T.Lambda(lambda img: img.convert('RGB') 118 | if img.mode != 'RGB' else img), 119 | T.Resize(image_size), 120 | # T.RandomHorizontalFlip(), 121 | # T.CenterCrop(image_size), 122 | T.ToTensor() 123 | ]) 124 | 125 | def __len__(self): 126 | return len(self.paths) 127 | 128 | def __getitem__(self, index): 129 | path = self.paths[index] 130 | img = Image.open(path) 131 | return self.transform(img) 132 | 133 | # tensor of shape (channels, frames, height, width) -> gif 134 | 135 | # handle reading and writing gif 136 | 137 | 138 | CHANNELS_TO_MODE = { 139 | 1: 'L', 140 | 3: 'RGB', 141 | 4: 'RGBA' 142 | } 143 | 144 | 145 | def seek_all_images(img, channels=3): 146 | assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' 147 | mode = CHANNELS_TO_MODE[channels] 148 | 149 | i = 0 150 | while True: 151 | try: 152 | img.seek(i) 153 | yield img.convert(mode) 154 | except EOFError: 155 | break 156 | i += 1 157 | 158 | # tensor of shape (channels, frames, height, width) -> gif 159 | 160 | 161 | def video_tensor_to_pil_first_image(tensor): 162 | 163 | tensor = bgr_to_rgb(tensor) 164 | images = map(T.ToPILImage(), tensor.unbind(dim=1)) 165 | first_img, *rest_imgs = images 166 | 167 | return first_img 168 | 169 | 170 | def video_tensor_to_gif( 171 | tensor, 172 | path, 173 | duration=120, 174 | loop=0, 175 | optimize=True 176 | ): 177 | 178 | tensor = torch.clamp(tensor, min=0, max=1) # clipping underflow and overflow 179 | #tensor = bgr_to_rgb(tensor) 180 | images = map(T.ToPILImage(), tensor.unbind(dim=1)) 181 | first_img, *rest_imgs = images 182 | first_img.save(path, save_all=True, append_images=rest_imgs, 183 | loop=loop, optimize=optimize) 184 | return images 185 | 186 | # gif -> (channels, frame, height, width) tensor 187 | 188 | 189 | def gif_to_tensor( 190 | path, 191 | channels=3, 192 | transform=T.ToTensor() 193 | ): 194 | img = Image.open(path) 195 | tensors = tuple(map(transform, seek_all_images(img, channels=channels))) 196 | return torch.stack(tensors, dim=1) 197 | 198 | # handle reading 
and writing mp4 199 | 200 | 201 | 202 | def tensor_to_video( 203 | tensor, # Pytorch video tensor 204 | path: str, # Path of the video to be saved 205 | fps=8, # Frames per second for the saved video 206 | video_format=('m', 'p', '4', 'v') 207 | ): 208 | # Import the video and cut it into frames. 209 | tensor = tensor.cpu()*255. # TODO: have a better function for that? Not using cv2? 210 | 211 | num_frames, height, width = tensor.shape[-3:] 212 | 213 | # Changes in this line can allow for different video formats. 214 | fourcc = cv2.VideoWriter_fourcc(*video_format) 215 | video = cv2.VideoWriter(path, fourcc, fps, (width, height)) 216 | 217 | frames = [] 218 | 219 | for idx in range(num_frames): 220 | numpy_frame = tensor[:, idx, :, :].numpy() 221 | numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c')) 222 | video.write(numpy_frame) 223 | 224 | video.release() 225 | 226 | cv2.destroyAllWindows() 227 | 228 | return video 229 | 230 | 231 | def crop_center( 232 | img, # tensor 233 | cropx, # Length of the final image in the x direction. 234 | cropy # Length of the final image in the y direction. 235 | ) -> torch.Tensor: 236 | y, x, c = img.shape 237 | startx = x // 2 - cropx // 2 238 | starty = y // 2 - cropy // 2 239 | return img[starty:(starty + cropy), startx:(startx + cropx), :] 240 | 241 | def sort_key(file_path): 242 | # Extract the numerical parts from the file name using regex 243 | match = re.findall(r'(\d+)', file_path.stem) 244 | if match: 245 | return [int(part) for part in match] 246 | return str(file_path) 247 | # video dataset 248 | 249 | def save_tensor_as_grid(tensor, grid_size, save_path="grid_image.png"): 250 | """ 251 | 4x4 그리드로 텐서를 저장하는 함수. 252 | 253 | Args: 254 | tensor (torch.Tensor): 텐서 크기 (C, N, H, W) 또는 (N, C, H, W). 255 | - C: 채널 수 (3이면 RGB, 1이면 그레이스케일). 256 | - N: 이미지 개수. 257 | - H, W: 이미지 높이와 너비. 258 | grid_size (int): 그리드의 행과 열 수. (예: 4이면 4x4 그리드). 259 | save_path (str): 저장할 이미지 파일 경로. 
260 | """ 261 | # 텐서 차원 맞추기: (N, C, H, W) 262 | if tensor.shape[0] == 3 or tensor.shape[0] == 1: 263 | tensor = tensor.permute(1, 0, 2, 3) # (C, N, H, W) -> (N, C, H, W) 264 | 265 | # 그리드 생성 266 | grid = vutils.make_grid(tensor, nrow=grid_size, padding=2) 267 | 268 | # 텐서를 PIL 이미지로 변환 269 | grid_image = (grid * 255).byte().permute(1, 2, 0).numpy() # RGB 순서로 변환 270 | image = Image.fromarray(grid_image) 271 | 272 | # 이미지 저장 273 | image.save(save_path) 274 | print(f"4x4 그리드 이미지를 '{save_path}'로 저장했습니다.") 275 | 276 | def process_ekg(ekg_data, target_length=2250, repetitions=16): 277 | processed_ekg = np.zeros((12, target_length)) 278 | 279 | # 원본 데이터를 16번 반복 280 | repeated_data = np.tile(ekg_data, repetitions) 281 | 282 | # 반복된 데이터의 길이 283 | original_length = len(repeated_data) 284 | 285 | if original_length < target_length: 286 | # Interpolation 287 | x = np.linspace(0, 1, original_length) 288 | f = interpolate.interp1d(x, repeated_data, kind='linear') 289 | x_new = np.linspace(0, 1, target_length) 290 | processed_ekg[1] = f(x_new) 291 | elif original_length > target_length: 292 | # Resampling 293 | x = np.linspace(0, 1, original_length) 294 | x_new = np.linspace(0, 1, target_length) 295 | processed_ekg[1] = np.interp(x_new, x, repeated_data) 296 | else: 297 | processed_ekg[1] = repeated_data[:target_length] 298 | 299 | return processed_ekg 300 | 301 | 302 | def video_to_tensor( 303 | path: str, 304 | transform, # Path of the video to be imported 305 | num_frames=-1, # Number of frames to be stored in the output tensor 306 | crop_size=None 307 | ) -> torch.Tensor: # shape (1, channels, frames, height, width) 308 | 309 | video = cv2.VideoCapture(path) 310 | 311 | total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 312 | 313 | # print ("PATH", path) 314 | # print ("TOTAL frame : ",total_frames ) 315 | frames = [] 316 | check = True 317 | 318 | shear_x = random.uniform(-5, 5) 319 | shear_y = random.uniform(-5, 5) 320 | contrast_factor = random.uniform(0.6, 1.4) 321 | 322 | while check: 323 | check, frame = video.read() 324 | 325 | if not check: 326 | continue 327 | # frame = np.transpose(frame, (2, 0, 1)) 328 | # 고정된 augmentation 값들로 transform 적용 329 | # frame = transform(frame, shear_x, shear_y, contrast_factor) 330 | frame = transform(frame) 331 | frames.append(rearrange(frame, '... 
-> 1 ...')) 332 | 333 | # convert list of frames to numpy array 334 | frames = np.array(np.concatenate(frames, axis=0)) 335 | # frames = rearrange(frames, 'f c h w -> c f h w') 336 | frames = rearrange(frames, 'f c h w -> c f h w') 337 | 338 | frames_torch = torch.tensor(frames).float() 339 | 340 | return frames_torch 341 | 342 | 343 | 344 | def process_ultrasound_image(video_tensor): 345 | # 입력 텐서의 shape 확인 346 | B, T, H, W = video_tensor.shape 347 | 348 | # 결과를 저장할 텐서 초기화 349 | result = torch.zeros_like(video_tensor) 350 | 351 | for b in range(B): 352 | for t in range(T): 353 | 354 | # 현재 프레임 추출 355 | frame = video_tensor[b, t].cpu().numpy() #shape of 128 128 356 | 357 | # frame_ = (frame*255).astype(np.uint8) 358 | 359 | # Save the grayscale image using OpenCV 360 | # output_path_gray = "/home/local/PARTNERS/sk1064/project/EchoHub/dataset/frame_image_gray.jpg" 361 | # cv2.imwrite(output_path_gray, frame_) 362 | 363 | # If you want to convert the numpy array to PIL Image and save it 364 | 365 | # print ("FRAME SIZE CHECK " *10, np.max(frame_)) 366 | # 임계값 설정 (이 값은 이미지에 따라 조정이 필요할 수 있습니다) 367 | threshold = 0.1 368 | 369 | # 각 열에서 첫 번째로 임계값을 넘는 픽셀 찾기 370 | first_pixels = np.argmax(frame > threshold, axis=0) 371 | 372 | # 가장 위에 있는 픽셀의 y 좌표 찾기 373 | top_y = np.min(first_pixels[first_pixels > 0]) 374 | 375 | # 128x128 크기로 자르기 376 | # cropped = frame[top_y:top_y+128, :128] 377 | cropped = frame[top_y:top_y+224, :224] 378 | 379 | # 크기가 128x128이 아닌 경우 인터폴레이션을 사용하여 리사이징 380 | # if cropped.shape != (128, 128): 381 | # cropped = cv2.resize(cropped, (128, 128), interpolation=cv2.INTER_LINEAR) 382 | if cropped.shape != (224, 224): 383 | cropped = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_LINEAR) 384 | 385 | # 결과 텐서에 저장 386 | result[b, t] = torch.from_numpy(cropped).float() 387 | 388 | return result 389 | 390 | class EchoDataset_from_Video_mp4(Dataset): 391 | def __init__( 392 | self, 393 | folder, 394 | image_size = [224, 224], 395 | channels = 3, 396 | ): 397 | super().__init__() 398 | self.folder = folder 399 | 400 | self.image_size = image_size 401 | self.channels = channels 402 | 403 | def apply_augmentation(img, shear_x, shear_y, contrast_factor): 404 | # Apply contrast augmentation 405 | img = T.functional.adjust_contrast(img, contrast_factor) 406 | 407 | # # Apply shear x, y augmentation 408 | img = T.functional.affine(img, angle=0, translate=[0, 0], scale=1.0, shear=[shear_x, shear_y]) 409 | 410 | return img 411 | 412 | def create_transform(image_size): 413 | # def transform(img, shear_x, shear_y, contrast_factor): 414 | def transform(img): 415 | if not isinstance(img, Image.Image): 416 | img = T.ToPILImage()(img) 417 | img = T.Resize(image_size)(img) 418 | # img = apply_augmentation(img, shear_x, shear_y, contrast_factor) 419 | return T.ToTensor()(img) 420 | return transform 421 | 422 | self.transform_for_videos = create_transform(self.image_size) 423 | 424 | self.transform = T.Compose([ 425 | T.Resize(image_size), 426 | T.ToTensor() 427 | ]) 428 | self.paths = os.listdir(folder) 429 | 430 | self.mp4_to_tensor = partial( 431 | video_to_tensor, transform=self.transform_for_videos, crop_size=self.image_size, num_frames=10) 432 | 433 | force_num_frames = True 434 | 435 | self.cast_num_frames_fn = partial( 436 | cast_num_frames, frames=32) if force_num_frames else identity 437 | 438 | def __len__(self): 439 | return len(self.paths) 440 | 441 | def __getitem__(self, index): 442 | path = self.paths[index] 443 | 444 | path = os.path.join(self.folder, path) 445 | tensor = 
self.mp4_to_tensor(str(path)) 446 | 447 | tensor = self.cast_num_frames_fn(tensor) 448 | 449 | # print ("Check Final output : ", tensor.size()) 450 | 451 | # save_tensor_as_grid(tensor, grid_size=4, save_path="grid_image.png") 452 | # data = {"image":tensor , "p_id" : self.paths[index]} 453 | return tensor 454 | 455 | 456 | class EchoDataset_from_Video(Dataset): 457 | def __init__( 458 | self, 459 | folder, 460 | image_size, 461 | channels=3, 462 | num_frames=11, 463 | horizontal_flip=False, 464 | force_num_frames=True, 465 | exts=['gif', 'mp4'], 466 | sample_texts=None # 新增参数 467 | ): 468 | super().__init__() 469 | self.folder = os.path.join(folder, "mp4") 470 | self.folder_ekg= os.path.join(folder, "ekg") 471 | 472 | self.image_size = image_size 473 | self.channels = channels 474 | self.paths = [p for ext in exts for p in Path( 475 | f'{folder}').glob(f'**/*.{ext}')] 476 | self.paths.sort(key=sort_key) 477 | self.sample_texts = sample_texts 478 | self.transform = T.Compose([ 479 | T.Resize(image_size), 480 | T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), 481 | T.CenterCrop(image_size), 482 | T.ToTensor() 483 | ]) 484 | 485 | # TODO: rework so it is faster, for now it works but is bad 486 | # self.transform_for_videos = T.Compose([ 487 | # T.ToPILImage(), # added to PIL conversion because video is read with cv2 488 | # T.Resize(image_size), 489 | # # T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), 490 | # T.ToTensor() 491 | # ]) 492 | 493 | self.gif_to_tensor = partial( 494 | gif_to_tensor, channels=self.channels, transform=self.transform) 495 | # self.mp4_to_tensor = partial( 496 | # video_to_tensor, transform=self.transform_for_videos, crop_size=self.image_size, num_frames=num_frames) 497 | 498 | 499 | self.cast_num_frames_fn = partial( 500 | cast_num_frames, frames=num_frames) if force_num_frames else identity 501 | 502 | 503 | def apply_augmentation(img, shear_x, shear_y, contrast_factor): 504 | # Apply contrast augmentation 505 | img = T.functional.adjust_contrast(img, contrast_factor) 506 | 507 | # # Apply shear x, y augmentation 508 | img = T.functional.affine(img, angle=0, translate=[0, 0], scale=1.0, shear=[shear_x, shear_y]) 509 | 510 | return img 511 | 512 | def create_transform(image_size): 513 | def transform(img, shear_x, shear_y, contrast_factor): 514 | if not isinstance(img, Image.Image): 515 | img = T.ToPILImage()(img) 516 | img = T.Resize(image_size)(img) 517 | img = apply_augmentation(img, shear_x, shear_y, contrast_factor) 518 | return T.ToTensor()(img) 519 | return transform 520 | 521 | self.transform_for_videos = create_transform(image_size) 522 | self.mp4_to_tensor = partial( 523 | video_to_tensor, transform=self.transform_for_videos, crop_size=self.image_size, num_frames=num_frames) 524 | 525 | def __len__(self): 526 | return len(self.paths) 527 | 528 | def __getitem__(self, index): 529 | path = self.paths[index] 530 | 531 | ext = path.suffix 532 | 533 | if ext == '.gif': 534 | tensor = self.gif_to_tensor(path) 535 | elif ext == '.mp4': 536 | tensor = self.mp4_to_tensor(str(path)) 537 | else: 538 | raise ValueError(f'unknown extension {ext}') 539 | 540 | tensor = self.cast_num_frames_fn(tensor) 541 | 542 | return tensor 543 | 544 | def collate_tensors_and_strings(batch): 545 | tensors, ekgs = zip(*batch) 546 | 547 | # Process tensors (assuming they are already in the correct format) 548 | tensors = torch.stack(tensors, dim=0) 549 | 550 | # Process EKGs 551 | processed_ekgs = [] 552 | for ekg in ekgs: 553 | 
processed_ekgs.append(ekg) 554 | 555 | processed_ekgs = torch.stack(processed_ekgs, dim=0) 556 | 557 | # print ("BATCH COLLATE ", tensors.size(), processed_ekgs.size()) 558 | 559 | return tensors, processed_ekgs 560 | 561 | 562 | def DataLoader(*args, **kwargs): 563 | return PytorchDataLoader(*args, collate_fn=collate_tensors_and_strings, **kwargs) 564 | -------------------------------------------------------------------------------- /EchoFM/models_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # MAE: https://github.com/facebookresearch/mae 11 | # -------------------------------------------------------- 12 | 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | from EchoFM.util import video_vit 18 | from EchoFM.util.logging import master_print as print 19 | import torch.nn.functional as F 20 | 21 | class MaskedAutoencoderViT(nn.Module): 22 | """Masked Autoencoder with VisionTransformer backbone""" 23 | 24 | def __init__( 25 | self, 26 | img_size=224, 27 | patch_size=16, 28 | in_chans=3, 29 | embed_dim=1024, 30 | depth=24, 31 | num_heads=16, 32 | decoder_embed_dim=512, 33 | decoder_depth=8, 34 | decoder_num_heads=16, 35 | mlp_ratio=4.0, 36 | norm_layer=nn.LayerNorm, 37 | norm_pix_loss=False, 38 | num_frames=16, 39 | t_patch_size=4, 40 | patch_embed=video_vit.PatchEmbed, 41 | no_qkv_bias=False, 42 | sep_pos_embed=False, 43 | trunc_init=False, 44 | cls_embed=False, 45 | pred_t_dim=8, 46 | **kwargs, 47 | ): 48 | super().__init__() 49 | self.trunc_init = trunc_init 50 | self.sep_pos_embed = sep_pos_embed 51 | self.cls_embed = cls_embed 52 | self.pred_t_dim = pred_t_dim 53 | self.t_pred_patch_size = t_patch_size * pred_t_dim // num_frames 54 | 55 | self.patch_embed = patch_embed( 56 | img_size, 57 | patch_size, 58 | in_chans, 59 | embed_dim, 60 | num_frames, 61 | t_patch_size, 62 | ) 63 | num_patches = self.patch_embed.num_patches 64 | input_size = self.patch_embed.input_size 65 | self.input_size = input_size 66 | 67 | if self.cls_embed: 68 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 69 | self.decoder_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 70 | self.decoder_prj_cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 71 | 72 | if sep_pos_embed: 73 | self.pos_embed_spatial = nn.Parameter( 74 | torch.zeros(1, input_size[1] * input_size[2], embed_dim) 75 | ) 76 | self.pos_embed_temporal = nn.Parameter( 77 | torch.zeros(1, input_size[0], embed_dim) 78 | ) 79 | if self.cls_embed: 80 | self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim)) 81 | else: 82 | if self.cls_embed: 83 | _num_patches = num_patches + 1 84 | else: 85 | _num_patches = num_patches 86 | 87 | self.pos_embed = nn.Parameter( 88 | torch.zeros(1, _num_patches, embed_dim), 89 | ) 90 | 91 | self.blocks = nn.ModuleList( 92 | [ 93 | video_vit.Block( 94 | embed_dim, 95 | num_heads, 96 | mlp_ratio, 97 | qkv_bias=not no_qkv_bias, 98 | qk_scale=None, 99 | norm_layer=norm_layer, 100 | ) 101 | for i in range(depth) 102 | ] 103 | ) 104 | 105 | self.decoder_block = video_vit.Block( 106 | 
embed_dim, 107 | num_heads, 108 | mlp_ratio, 109 | qkv_bias=not no_qkv_bias, 110 | qk_scale=None, 111 | norm_layer=norm_layer, 112 | ) 113 | 114 | self.norm = norm_layer(embed_dim) 115 | 116 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) 117 | 118 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 119 | 120 | if sep_pos_embed: 121 | self.decoder_pos_embed_spatial = nn.Parameter( 122 | torch.zeros(1, input_size[1] * input_size[2], decoder_embed_dim) 123 | ) 124 | self.decoder_pos_embed_temporal = nn.Parameter( 125 | torch.zeros(1, input_size[0], decoder_embed_dim) 126 | ) 127 | if self.cls_embed: 128 | self.decoder_pos_embed_class = nn.Parameter( 129 | torch.zeros(1, 1, decoder_embed_dim) 130 | ) 131 | else: 132 | if self.cls_embed: 133 | _num_patches = num_patches + 1 134 | else: 135 | _num_patches = num_patches 136 | 137 | self.decoder_pos_embed = nn.Parameter( 138 | torch.zeros(1, _num_patches, decoder_embed_dim), 139 | ) 140 | 141 | self.decoder_blocks = nn.ModuleList( 142 | [ 143 | video_vit.Block( 144 | decoder_embed_dim, 145 | decoder_num_heads, 146 | mlp_ratio, 147 | qkv_bias=not no_qkv_bias, 148 | qk_scale=None, 149 | norm_layer=norm_layer, 150 | ) 151 | for i in range(decoder_depth) 152 | ] 153 | ) 154 | 155 | self.decoder_norm = norm_layer(decoder_embed_dim) 156 | self.decoder_pred = nn.Linear( 157 | decoder_embed_dim, 158 | self.t_pred_patch_size * patch_size**2 * in_chans, 159 | bias=True, 160 | ) 161 | 162 | self.norm_pix_loss = norm_pix_loss 163 | 164 | self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7) 165 | self.initialize_weights() 166 | 167 | 168 | print("model initialized") 169 | 170 | def self_similarity(self, cls_tokens): 171 | """ 172 | Compute self-similarity map using cosine similarity. 173 | 174 | Args: 175 | cls_tokens (list of tensors): List of tensors, where each tensor is of shape [N, D]. 176 | 177 | Returns: 178 | similarity_map (tensor): Tensor of shape [N, T, T] containing self-similarity values. 179 | """ 180 | # Concatenate the list into a single tensor of shape [N, T, D] 181 | cls_tokens_tensor = torch.stack(cls_tokens, dim=1) # Shape: [N, T, D] 182 | 183 | # Normalize embeddings to unit vectors 184 | cls_tokens_tensor = F.normalize(cls_tokens_tensor, p=2, dim=-1) # Shape: [N, T, D] 185 | 186 | # Compute cosine similarity 187 | similarity_map = torch.matmul(cls_tokens_tensor, cls_tokens_tensor.transpose(1, 2)) # Shape: [N, T, T] 188 | 189 | return similarity_map 190 | 191 | def triplet_sampling(self, similarity_map, cls_tokens): 192 | """ 193 | Perform triplet sampling with one anchor, one positive, and one negative per batch. 194 | 195 | Args: 196 | similarity_map (tensor): Self-similarity map of shape [N, T, T]. 197 | cls_tokens (tensor): Tensor of CLS tokens, shape [N, T, D]. 198 | 199 | Returns: 200 | anchor (tensor): Tensor of anchor embeddings, shape [N, D]. 201 | positive (tensor): Tensor of positive embeddings, shape [N, D]. 202 | negative (tensor): Tensor of negative embeddings, shape [N, D]. 
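        Sampling rule (as implemented below): the anchor is always the CLS token
        of the first time step; non-anchor time steps whose similarity to the
        anchor exceeds the mean of the first row of the similarity map are
        candidate positives, the remaining non-anchor steps are candidate
        negatives, and one of each is drawn uniformly at random per batch element.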
203 | """ 204 | 205 | cls_tokens = torch.stack(cls_tokens, dim=1) # Shape: [N, T, D] 206 | 207 | N, T, D = cls_tokens.shape 208 | 209 | anchors, positives, negatives = [], [], [] 210 | 211 | for n in range(N): # Iterate over batches 212 | # Extract the first row (anchor is always index 0) 213 | first_row = similarity_map[n, 0, :] # Shape: [T] 214 | 215 | # Compute mean similarity for the first row 216 | mean_similarity = first_row.mean().item() 217 | 218 | # Identify positive and negative indices, excluding anchor index (0) 219 | positive_indices = (first_row > mean_similarity).nonzero(as_tuple=True)[0] 220 | positive_indices = positive_indices[positive_indices != 0] # Exclude anchor (index 0) 221 | 222 | negative_indices = (first_row <= mean_similarity).nonzero(as_tuple=True)[0] 223 | negative_indices = negative_indices[negative_indices != 0] # Exclude anchor (index 0) 224 | 225 | # Ensure we have at least one positive and one negative 226 | if len(positive_indices) > 0 and len(negative_indices) > 0: 227 | # Randomly select one positive and one negative 228 | pos_idx = positive_indices[torch.randint(len(positive_indices), (1,))].item() 229 | neg_idx = negative_indices[torch.randint(len(negative_indices), (1,))].item() 230 | 231 | # Append CLS tokens for the selected indices 232 | anchors.append(cls_tokens[n, 0, :]) # Anchor is always index 0 233 | positives.append(cls_tokens[n, pos_idx, :]) # Positive CLS token 234 | negatives.append(cls_tokens[n, neg_idx, :]) # Negative CLS token 235 | 236 | # Stack tensors to create final batch outputs 237 | anchor = torch.stack(anchors) # Shape: [N, D] 238 | positive = torch.stack(positives) # Shape: [N, D] 239 | negative = torch.stack(negatives) # Shape: [N, D] 240 | 241 | return anchor, positive, negative 242 | 243 | def initialize_weights(self): 244 | if self.cls_embed: 245 | torch.nn.init.trunc_normal_(self.cls_token, std=0.02) 246 | if self.sep_pos_embed: 247 | torch.nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02) 248 | torch.nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02) 249 | 250 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_spatial, std=0.02) 251 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_temporal, std=0.02) 252 | 253 | if self.cls_embed: 254 | torch.nn.init.trunc_normal_(self.pos_embed_class, std=0.02) 255 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_class, std=0.02) 256 | else: 257 | torch.nn.init.trunc_normal_(self.pos_embed, std=0.02) 258 | torch.nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02) 259 | w = self.patch_embed.proj.weight.data 260 | if self.trunc_init: 261 | torch.nn.init.trunc_normal_(w) 262 | torch.nn.init.trunc_normal_(self.mask_token, std=0.02) 263 | else: 264 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 265 | torch.nn.init.normal_(self.mask_token, std=0.02) 266 | 267 | # initialize nn.Linear and nn.LayerNorm 268 | self.apply(self._init_weights) 269 | 270 | def _init_weights(self, m): 271 | if isinstance(m, nn.Linear): 272 | # we use xavier_uniform following official JAX ViT: 273 | if self.trunc_init: 274 | nn.init.trunc_normal_(m.weight, std=0.02) 275 | else: 276 | torch.nn.init.xavier_uniform_(m.weight) 277 | if isinstance(m, nn.Linear) and m.bias is not None: 278 | nn.init.constant_(m.bias, 0) 279 | elif isinstance(m, nn.LayerNorm): 280 | nn.init.constant_(m.bias, 0) 281 | nn.init.constant_(m.weight, 1.0) 282 | 283 | def patchify(self, imgs): 284 | """ 285 | imgs: (N, 3, H, W) 286 | x: (N, L, patch_size**2 *3) 287 | """ 288 | N, _, T, H, W = imgs.shape 289 | p 
= self.patch_embed.patch_size[0] 290 | u = self.t_pred_patch_size 291 | assert H == W and H % p == 0 and T % u == 0 292 | h = w = H // p 293 | t = T // u 294 | 295 | x = imgs.reshape(shape=(N, 3, t, u, h, p, w, p)) 296 | x = torch.einsum("nctuhpwq->nthwupqc", x) 297 | x = x.reshape(shape=(N, t * h * w, u * p**2 * 3)) 298 | self.patch_info = (N, T, H, W, p, u, t, h, w) 299 | return x 300 | 301 | def unpatchify(self, x): 302 | """ 303 | x: (N, L, patch_size**2 *3) 304 | imgs: (N, 3, H, W) 305 | """ 306 | N, T, H, W, p, u, t, h, w = self.patch_info 307 | 308 | x = x.reshape(shape=(N, t, h, w, u, p, p, 3)) 309 | 310 | x = torch.einsum("nthwupqc->nctuhpwq", x) 311 | imgs = x.reshape(shape=(N, 3, T, H, W)) 312 | return imgs 313 | 314 | def uniform_random_masking(self, x, mask_ratio, L): 315 | """ 316 | Perform temporal consistent random masking by sampling the same spatial tokens across time steps. 317 | Args: 318 | x: Tensor of shape [N, T * L, D], sequence after patch embedding (flattened temporal and spatial dimensions). 319 | mask_ratio: Float, proportion of tokens to mask. 320 | L: Number of spatial tokens per time step. 321 | 322 | Returns: 323 | x_masked: Tensor of shape [N, len_keep * T, D], after masking. 324 | mask: Binary mask of shape [N, T * L], 0 is keep, 1 is remove. 325 | ids_restore: Indices to restore original sequence order. 326 | ids_keep: Indices of kept tokens. 327 | """ 328 | N, TL, D = x.shape # Batch size, total tokens, embedding dimension 329 | T = TL // L # Temporal length 330 | 331 | # Compute the number of tokens to keep per spatial location 332 | len_keep = int(L * (1 - mask_ratio)) 333 | 334 | # Generate random noise for each spatial location 335 | noise = torch.rand(N, L, device=x.device) # [N, L] 336 | 337 | # Sort spatial tokens based on noise 338 | ids_shuffle = torch.argsort(noise, dim=1) # [N, L] 339 | ids_keep = ids_shuffle[:, :len_keep] # Keep top len_keep indices [N, len_keep] 340 | ids_keep = ids_keep.unsqueeze(1).repeat(1, T, 1) # Broadcast to all time steps [N, T, len_keep] 341 | 342 | # Create a binary mask for all time steps 343 | mask = torch.ones(N, T, L, device=x.device) # Initialize mask with all 1s [N, T, L] 344 | 345 | for n in range(N): # Iterate over batch 346 | for t in range(T): 347 | mask[n, t, ids_keep[n]] = 0 # Use batch-specific ids_keep[n] 348 | 349 | mask = mask.view(N, TL) # Flatten to [N, T * L] 350 | ids_restore = torch.argsort(mask, dim=1) # Indices for restoring order 351 | 352 | # Mask input 353 | x_masked = x[mask == 0].view(N, -1, D) # Kept tokens only [N, len_keep * T, D] 354 | 355 | ids_keep = ids_keep.view(N, -1) 356 | return x_masked, mask, ids_restore, ids_keep 357 | 358 | def random_masking(self, x, mask_ratio): 359 | """ 360 | Perform per-sample random masking by per-sample shuffling. 361 | Per-sample shuffling is done by argsort random noise. 
362 | x: [N, L, D], sequence 363 | """ 364 | N, L, D = x.shape # batch, length, dim 365 | len_keep = int(L * (1 - mask_ratio)) 366 | 367 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 368 | 369 | # sort noise for each sample 370 | ids_shuffle = torch.argsort( 371 | noise, dim=1 372 | ) # ascend: small is keep, large is remove 373 | ids_restore = torch.argsort(ids_shuffle, dim=1) 374 | 375 | # keep the first subset 376 | ids_keep = ids_shuffle[:, :len_keep] 377 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 378 | 379 | # generate the binary mask: 0 is keep, 1 is remove 380 | mask = torch.ones([N, L], device=x.device) 381 | mask[:, :len_keep] = 0 382 | # unshuffle to get the binary mask 383 | mask = torch.gather(mask, dim=1, index=ids_restore) 384 | 385 | return x_masked, mask, ids_restore, ids_keep 386 | 387 | def forward_encoder(self, x, mask_ratio): 388 | # embed patches 389 | x = self.patch_embed(x) 390 | N, T, L, C = x.shape 391 | 392 | x = x.reshape(N, T * L, C) 393 | 394 | # masking: length -> length * mask_ratio 395 | # x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio) 396 | 397 | x, mask, ids_restore, ids_keep = self.uniform_random_masking(x, mask_ratio, L) 398 | 399 | x = x.view(N, -1, C) 400 | # append cls token 401 | if self.cls_embed: 402 | cls_token = self.cls_token 403 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 404 | x = torch.cat((cls_tokens, x), dim=1) 405 | 406 | # add pos embed w/o cls token 407 | if self.sep_pos_embed: 408 | pos_embed = self.pos_embed_spatial.repeat( 409 | 1, self.input_size[0], 1 410 | ) + torch.repeat_interleave( 411 | self.pos_embed_temporal, 412 | self.input_size[1] * self.input_size[2], 413 | dim=1, 414 | ) 415 | pos_embed = pos_embed.expand(x.shape[0], -1, -1) 416 | pos_embed = torch.gather( 417 | pos_embed, 418 | dim=1, 419 | index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2]), 420 | ) 421 | if self.cls_embed: 422 | pos_embed = torch.cat( 423 | [ 424 | self.pos_embed_class.expand(pos_embed.shape[0], -1, -1), 425 | pos_embed, 426 | ], 427 | 1, 428 | ) 429 | else: 430 | if self.cls_embed: 431 | cls_ind = 1 432 | else: 433 | cls_ind = 0 434 | pos_embed = self.pos_embed[:, cls_ind:, :].expand(x.shape[0], -1, -1) 435 | pos_embed = torch.gather( 436 | pos_embed, 437 | dim=1, 438 | index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2]), 439 | ) 440 | if self.cls_embed: 441 | pos_embed = torch.cat( 442 | [ 443 | self.pos_embed[:, :1, :].expand(x.shape[0], -1, -1), 444 | pos_embed, 445 | ], 446 | 1, 447 | ) 448 | x = x.view([N, -1, C]) + pos_embed 449 | 450 | # apply Transformer blocks 451 | for blk in self.blocks: 452 | x = blk(x) 453 | x = self.norm(x) 454 | 455 | if self.cls_embed: 456 | # remove cls token 457 | x = x[:, 1:, :] 458 | else: 459 | x = x[:, :, :] 460 | 461 | return x, mask, ids_restore 462 | 463 | def decoder_prj(self, x): 464 | # apply Transformer blocks 465 | x = self.decoder_block(x) 466 | x = self.norm(x) 467 | 468 | if self.cls_embed: 469 | return x[:, 0, :] 470 | else: 471 | print ('CLS token is needed') 472 | 473 | 474 | def forward_prj(self, x, ids_restore): 475 | N = x.shape[0] 476 | T = self.patch_embed.t_grid_size 477 | H = W = self.patch_embed.grid_size 478 | 479 | # embed tokens (divide to temporal) 480 | 481 | # x = 4 392 1024 482 | 483 | # x = reshape() -> 4 8 49 1024 484 | x = x.view(N, T, 49, 1024) 485 | 486 | cls_ = [] 487 | for i in range(T): 488 | x_t = x[:,i,:,:] 489 | 490 | if self.cls_embed: 491 | decoder_cls_token = 
self.decoder_prj_cls_token 492 | decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1) 493 | x_t = torch.cat((decoder_cls_tokens, x_t), dim=1) 494 | 495 | x_t_cls = self.decoder_prj(x_t) #vit 496 | cls_.append(x_t_cls) 497 | return cls_ 498 | 499 | def forward_decoder(self, x, ids_restore): 500 | N = x.shape[0] 501 | T = self.patch_embed.t_grid_size 502 | H = W = self.patch_embed.grid_size 503 | 504 | # embed tokens 505 | x = self.decoder_embed(x) 506 | C = x.shape[-1] 507 | 508 | # append mask tokens to sequence 509 | mask_tokens = self.mask_token.repeat(N, T * H * W + 0 - x.shape[1], 1) 510 | x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token 511 | x_ = x_.view([N, T * H * W, C]) 512 | x_ = torch.gather( 513 | x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2]) 514 | ) # unshuffle 515 | x = x_.view([N, T * H * W, C]) 516 | # append cls token 517 | if self.cls_embed: 518 | decoder_cls_token = self.decoder_cls_token 519 | decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1) 520 | x = torch.cat((decoder_cls_tokens, x), dim=1) 521 | 522 | if self.sep_pos_embed: 523 | decoder_pos_embed = self.decoder_pos_embed_spatial.repeat( 524 | 1, self.input_size[0], 1 525 | ) + torch.repeat_interleave( 526 | self.decoder_pos_embed_temporal, 527 | self.input_size[1] * self.input_size[2], 528 | dim=1, 529 | ) 530 | if self.cls_embed: 531 | decoder_pos_embed = torch.cat( 532 | [ 533 | self.decoder_pos_embed_class.expand( 534 | decoder_pos_embed.shape[0], -1, -1 535 | ), 536 | decoder_pos_embed, 537 | ], 538 | 1, 539 | ) 540 | else: 541 | decoder_pos_embed = self.decoder_pos_embed[:, :, :] 542 | 543 | # add pos embed 544 | x = x + decoder_pos_embed 545 | 546 | attn = self.decoder_blocks[0].attn 547 | requires_t_shape = hasattr(attn, "requires_t_shape") and attn.requires_t_shape 548 | if requires_t_shape: 549 | x = x.view([N, T, H * W, C]) 550 | 551 | # apply Transformer blocks 552 | for blk in self.decoder_blocks: 553 | x = blk(x) 554 | x = self.decoder_norm(x) 555 | 556 | # predictor projection 557 | x = self.decoder_pred(x) 558 | 559 | if requires_t_shape: 560 | x = x.view([N, T * H * W, -1]) 561 | 562 | if self.cls_embed: 563 | # remove cls token 564 | x = x[:, 1:, :] 565 | else: 566 | x = x[:, :, :] 567 | 568 | return x 569 | 570 | def forward_loss(self, imgs, pred, mask): 571 | """ 572 | imgs: [N, 3, T, H, W] 573 | pred: [N, t*h*w, u*p*p*3] 574 | mask: [N*t, h*w], 0 is keep, 1 is remove, 575 | """ 576 | _imgs = torch.index_select( 577 | imgs, 578 | 2, 579 | torch.linspace( 580 | 0, 581 | imgs.shape[2] - 1, 582 | self.pred_t_dim, 583 | ) 584 | .long() 585 | .to(imgs.device), 586 | ) 587 | target = self.patchify(_imgs) 588 | if self.norm_pix_loss: 589 | mean = target.mean(dim=-1, keepdim=True) 590 | var = target.var(dim=-1, keepdim=True) 591 | target = (target - mean) / (var + 1.0e-6) ** 0.5 592 | 593 | loss = (pred - target) ** 2 594 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch 595 | mask = mask.view(loss.shape) 596 | 597 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches 598 | 599 | return loss 600 | 601 | def forward(self, imgs, mask_ratio=0.75): 602 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) 603 | 604 | cls_tokens = self.forward_prj(latent, ids_restore) 605 | 606 | similarity_map = self.self_similarity(cls_tokens) 607 | 608 | anchor, positive, negative = self.triplet_sampling(similarity_map, cls_tokens) 609 | 610 | # triplet sampling 611 | triplet_loss = self.triplet_loss(anchor, 
positive, negative) 612 | 613 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] 614 | loss = self.forward_loss(imgs, pred, mask) 615 | 616 | loss = loss + triplet_loss 617 | return loss, pred, mask 618 | 619 | 620 | def mae_vit_base_patch16(**kwargs): 621 | model = MaskedAutoencoderViT( 622 | patch_size=16, 623 | embed_dim=768, 624 | depth=12, 625 | num_heads=12, 626 | mlp_ratio=4, 627 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 628 | **kwargs, 629 | ) 630 | return model 631 | 632 | 633 | def mae_vit_large_patch16(**kwargs): 634 | model = MaskedAutoencoderViT( 635 | patch_size=16, 636 | embed_dim=1024, 637 | depth=24, 638 | num_heads=16, 639 | mlp_ratio=4, 640 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 641 | **kwargs, 642 | ) 643 | return model 644 | 645 | 646 | def mae_vit_huge_patch14(**kwargs): 647 | model = MaskedAutoencoderViT( 648 | patch_size=14, 649 | embed_dim=1280, 650 | depth=32, 651 | num_heads=16, 652 | mlp_ratio=4, 653 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 654 | **kwargs, 655 | ) 656 | return model 657 | -------------------------------------------------------------------------------- /EchoFM/util/decoder/transform.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | 6 | import math 7 | 8 | # import cv2 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | import torchvision.transforms.functional as F 14 | from PIL import Image 15 | from torchvision import transforms 16 | 17 | from .rand_augment import rand_augment_transform 18 | 19 | _pil_interpolation_to_str = { 20 | Image.NEAREST: "PIL.Image.NEAREST", 21 | Image.BILINEAR: "PIL.Image.BILINEAR", 22 | Image.BICUBIC: "PIL.Image.BICUBIC", 23 | Image.LANCZOS: "PIL.Image.LANCZOS", 24 | Image.HAMMING: "PIL.Image.HAMMING", 25 | Image.BOX: "PIL.Image.BOX", 26 | } 27 | 28 | 29 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) 30 | 31 | 32 | def _pil_interp(method): 33 | if method == "bicubic": 34 | return Image.BICUBIC 35 | elif method == "lanczos": 36 | return Image.LANCZOS 37 | elif method == "hamming": 38 | return Image.HAMMING 39 | else: 40 | return Image.BILINEAR 41 | 42 | 43 | def random_short_side_scale_jitter( 44 | images, min_size, max_size, inverse_uniform_sampling=False 45 | ): 46 | """ 47 | Perform a spatial short scale jittering on the given images. 48 | Args: 49 | images (tensor): images to perform scale jitter. Dimension is 50 | `num frames` x `channel` x `height` x `width`. 51 | min_size (int): the minimal size to scale the frames. 52 | max_size (int): the maximal size to scale the frames. 53 | inverse_uniform_sampling (bool): if True, sample uniformly in 54 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the 55 | scale. If False, take a uniform sample from [min_scale, max_scale]. 56 | Returns: 57 | (tensor): the scaled images with dimension of 58 | `num frames` x `channel` x `new height` x `new width`. 
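        For example, with min_size=256 and max_size=320 a target short side is
        drawn uniformly from [256, 320] (or from the reciprocal range when
        inverse_uniform_sampling is True); a clip of shape
        `num frames` x `channel` x 240 x 320 then becomes
        `num frames` x `channel` x size x floor(320 / 240 * size), i.e. the
        short side is rescaled to the sampled size and the aspect ratio is kept.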
59 | """ 60 | if inverse_uniform_sampling: 61 | size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))) 62 | else: 63 | size = int(round(np.random.uniform(min_size, max_size))) 64 | 65 | height = images.shape[2] 66 | width = images.shape[3] 67 | if (width <= height and width == size) or (height <= width and height == size): 68 | return images 69 | new_width = size 70 | new_height = size 71 | if width < height: 72 | new_height = int(math.floor((float(height) / width) * size)) 73 | else: 74 | new_width = int(math.floor((float(width) / height) * size)) 75 | return torch.nn.functional.interpolate( 76 | images, 77 | size=(new_height, new_width), 78 | mode="bilinear", 79 | align_corners=False, 80 | ) 81 | 82 | 83 | def random_crop(images, size): 84 | """ 85 | Perform random spatial crop on the given images. 86 | Args: 87 | images (tensor): images to perform random crop. The dimension is 88 | `num frames` x `channel` x `height` x `width`. 89 | size (int): the size of height and width to crop on the image. 90 | Returns: 91 | cropped (tensor): cropped images with dimension of 92 | `num frames` x `channel` x `size` x `size`. 93 | """ 94 | if images.shape[2] == size and images.shape[3] == size: 95 | return images 96 | height = images.shape[2] 97 | width = images.shape[3] 98 | y_offset = 0 99 | if height > size: 100 | y_offset = int(np.random.randint(0, height - size)) 101 | x_offset = 0 102 | if width > size: 103 | x_offset = int(np.random.randint(0, width - size)) 104 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] 105 | return cropped 106 | 107 | 108 | def horizontal_flip(prob, images): 109 | """ 110 | Perform horizontal flip on the given images. 111 | Args: 112 | prob (float): probility to flip the images. 113 | images (tensor): images to perform horizontal flip, the dimension is 114 | `num frames` x `channel` x `height` x `width`. 115 | Returns: 116 | images (tensor): images with dimension of 117 | `num frames` x `channel` x `height` x `width`. 118 | """ 119 | if np.random.uniform() < prob: 120 | images = images.flip((-1)) 121 | return images 122 | 123 | 124 | def uniform_crop(images, size, spatial_idx, scale_size=None): 125 | """ 126 | Perform uniform spatial sampling on the images. 127 | Args: 128 | images (tensor): images to perform uniform crop. The dimension is 129 | `num frames` x `channel` x `height` x `width`. 130 | size (int): size of height and weight to crop the images. 131 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width 132 | is larger than height. Or 0, 1, or 2 for top, center, and bottom 133 | crop if height is larger than width. 134 | scale_size (int): optinal. If not None, resize the images to scale_size before 135 | performing any crop. 136 | Returns: 137 | cropped (tensor): images with dimension of 138 | `num frames` x `channel` x `size` x `size`. 
139 | """ 140 | assert spatial_idx in [0, 1, 2] 141 | ndim = len(images.shape) 142 | if ndim == 3: 143 | images = images.unsqueeze(0) 144 | height = images.shape[2] 145 | width = images.shape[3] 146 | 147 | if scale_size is not None: 148 | if width <= height: 149 | width, height = scale_size, int(height / width * scale_size) 150 | else: 151 | width, height = int(width / height * scale_size), scale_size 152 | images = torch.nn.functional.interpolate( 153 | images, 154 | size=(height, width), 155 | mode="bilinear", 156 | align_corners=False, 157 | ) 158 | 159 | y_offset = int(math.ceil((height - size) / 2)) 160 | x_offset = int(math.ceil((width - size) / 2)) 161 | 162 | if height > width: 163 | if spatial_idx == 0: 164 | y_offset = 0 165 | elif spatial_idx == 2: 166 | y_offset = height - size 167 | else: 168 | if spatial_idx == 0: 169 | x_offset = 0 170 | elif spatial_idx == 2: 171 | x_offset = width - size 172 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size] 173 | if ndim == 3: 174 | cropped = cropped.squeeze(0) 175 | return cropped 176 | 177 | 178 | def blend(images1, images2, alpha): 179 | """ 180 | Blend two images with a given weight alpha. 181 | Args: 182 | images1 (tensor): the first images to be blended, the dimension is 183 | `num frames` x `channel` x `height` x `width`. 184 | images2 (tensor): the second images to be blended, the dimension is 185 | `num frames` x `channel` x `height` x `width`. 186 | alpha (float): the blending weight. 187 | Returns: 188 | (tensor): blended images, the dimension is 189 | `num frames` x `channel` x `height` x `width`. 190 | """ 191 | return images1 * alpha + images2 * (1 - alpha) 192 | 193 | 194 | def grayscale(images): 195 | """ 196 | Get the grayscale for the input images. The channels of images should be 197 | in order BGR. 198 | Args: 199 | images (tensor): the input images for getting grayscale. Dimension is 200 | `num frames` x `channel` x `height` x `width`. 201 | Returns: 202 | img_gray (tensor): blended images, the dimension is 203 | `num frames` x `channel` x `height` x `width`. 204 | """ 205 | # R -> 0.299, G -> 0.587, B -> 0.114. 206 | img_gray = torch.tensor(images) 207 | gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] 208 | img_gray[:, 0] = gray_channel 209 | img_gray[:, 1] = gray_channel 210 | img_gray[:, 2] = gray_channel 211 | return img_gray 212 | 213 | 214 | def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): 215 | """ 216 | Perfrom a color jittering on the input images. The channels of images 217 | should be in order BGR. 218 | Args: 219 | images (tensor): images to perform color jitter. Dimension is 220 | `num frames` x `channel` x `height` x `width`. 221 | img_brightness (float): jitter ratio for brightness. 222 | img_contrast (float): jitter ratio for contrast. 223 | img_saturation (float): jitter ratio for saturation. 224 | Returns: 225 | images (tensor): the jittered images, the dimension is 226 | `num frames` x `channel` x `height` x `width`. 
227 | """ 228 | 229 | jitter = [] 230 | if img_brightness != 0: 231 | jitter.append("brightness") 232 | if img_contrast != 0: 233 | jitter.append("contrast") 234 | if img_saturation != 0: 235 | jitter.append("saturation") 236 | 237 | if len(jitter) > 0: 238 | order = np.random.permutation(np.arange(len(jitter))) 239 | for idx in range(0, len(jitter)): 240 | if jitter[order[idx]] == "brightness": 241 | images = brightness_jitter(img_brightness, images) 242 | elif jitter[order[idx]] == "contrast": 243 | images = contrast_jitter(img_contrast, images) 244 | elif jitter[order[idx]] == "saturation": 245 | images = saturation_jitter(img_saturation, images) 246 | return images 247 | 248 | 249 | def brightness_jitter(var, images): 250 | """ 251 | Perfrom brightness jittering on the input images. The channels of images 252 | should be in order BGR. 253 | Args: 254 | var (float): jitter ratio for brightness. 255 | images (tensor): images to perform color jitter. Dimension is 256 | `num frames` x `channel` x `height` x `width`. 257 | Returns: 258 | images (tensor): the jittered images, the dimension is 259 | `num frames` x `channel` x `height` x `width`. 260 | """ 261 | alpha = 1.0 + np.random.uniform(-var, var) 262 | 263 | img_bright = torch.zeros(images.shape) 264 | images = blend(images, img_bright, alpha) 265 | return images 266 | 267 | 268 | def contrast_jitter(var, images): 269 | """ 270 | Perfrom contrast jittering on the input images. The channels of images 271 | should be in order BGR. 272 | Args: 273 | var (float): jitter ratio for contrast. 274 | images (tensor): images to perform color jitter. Dimension is 275 | `num frames` x `channel` x `height` x `width`. 276 | Returns: 277 | images (tensor): the jittered images, the dimension is 278 | `num frames` x `channel` x `height` x `width`. 279 | """ 280 | alpha = 1.0 + np.random.uniform(-var, var) 281 | 282 | img_gray = grayscale(images) 283 | img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) 284 | images = blend(images, img_gray, alpha) 285 | return images 286 | 287 | 288 | def saturation_jitter(var, images): 289 | """ 290 | Perfrom saturation jittering on the input images. The channels of images 291 | should be in order BGR. 292 | Args: 293 | var (float): jitter ratio for saturation. 294 | images (tensor): images to perform color jitter. Dimension is 295 | `num frames` x `channel` x `height` x `width`. 296 | Returns: 297 | images (tensor): the jittered images, the dimension is 298 | `num frames` x `channel` x `height` x `width`. 299 | """ 300 | alpha = 1.0 + np.random.uniform(-var, var) 301 | img_gray = grayscale(images) 302 | images = blend(images, img_gray, alpha) 303 | 304 | return images 305 | 306 | 307 | def lighting_jitter(images, alphastd, eigval, eigvec): 308 | """ 309 | Perform AlexNet-style PCA jitter on the given images. 310 | Args: 311 | images (tensor): images to perform lighting jitter. Dimension is 312 | `num frames` x `channel` x `height` x `width`. 313 | alphastd (float): jitter ratio for PCA jitter. 314 | eigval (list): eigenvalues for PCA jitter. 315 | eigvec (list[list]): eigenvectors for PCA jitter. 316 | Returns: 317 | out_images (tensor): the jittered images, the dimension is 318 | `num frames` x `channel` x `height` x `width`. 319 | """ 320 | if alphastd == 0: 321 | return images 322 | # generate alpha1, alpha2, alpha3. 
323 |     alpha = np.random.normal(0, alphastd, size=(1, 3))
324 |     eig_vec = np.array(eigvec)
325 |     eig_val = np.reshape(eigval, (1, 3))
326 |     rgb = np.sum(
327 |         eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
328 |         axis=1,
329 |     )
330 |     out_images = torch.zeros_like(images)
331 |     if len(images.shape) == 3:
332 |         # C H W
333 |         channel_dim = 0
334 |     elif len(images.shape) == 4:
335 |         # T C H W
336 |         channel_dim = 1
337 |     else:
338 |         raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
339 | 
340 |     for idx in range(images.shape[channel_dim]):
341 |         # C H W
342 |         if len(images.shape) == 3:
343 |             out_images[idx] = images[idx] + rgb[2 - idx]
344 |         # T C H W
345 |         elif len(images.shape) == 4:
346 |             out_images[:, idx] = images[:, idx] + rgb[2 - idx]
347 |         else:
348 |             raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
349 | 
350 |     return out_images
351 | 
352 | 
353 | def color_normalization(images, mean, stddev):
354 |     """
355 |     Perform color normalization on the given images.
356 |     Args:
357 |         images (tensor): images to perform color normalization. Dimension is
358 |             `num frames` x `channel` x `height` x `width`.
359 |         mean (list): mean values for normalization.
360 |         stddev (list): standard deviations for normalization.
361 | 
362 |     Returns:
363 |         out_images (tensor): the normalized images, the dimension is
364 |             `num frames` x `channel` x `height` x `width`.
365 |     """
366 |     if len(images.shape) == 3:
367 |         assert len(mean) == images.shape[0], "channel mean not computed properly"
368 |         assert len(stddev) == images.shape[0], "channel stddev not computed properly"
369 |     elif len(images.shape) == 4:
370 |         assert len(mean) == images.shape[1], "channel mean not computed properly"
371 |         assert len(stddev) == images.shape[1], "channel stddev not computed properly"
372 |     else:
373 |         raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
374 | 
375 |     out_images = torch.zeros_like(images)
376 |     for idx in range(len(mean)):
377 |         # C H W
378 |         if len(images.shape) == 3:
379 |             out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
380 |         elif len(images.shape) == 4:
381 |             out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
382 |         else:
383 |             raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
384 |     return out_images
385 | 
386 | 
387 | def _get_param_spatial_crop(
388 |     scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
389 | ):
390 |     """
391 |     Given scale, ratio, height and width, return sampled coordinates of the videos.
392 | """ 393 | for _ in range(num_repeat): 394 | area = height * width 395 | target_area = random.uniform(*scale) * area 396 | if log_scale: 397 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 398 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 399 | else: 400 | aspect_ratio = random.uniform(*ratio) 401 | 402 | w = int(round(math.sqrt(target_area * aspect_ratio))) 403 | h = int(round(math.sqrt(target_area / aspect_ratio))) 404 | 405 | if np.random.uniform() < 0.5 and switch_hw: 406 | w, h = h, w 407 | 408 | if 0 < w <= width and 0 < h <= height: 409 | i = random.randint(0, height - h) 410 | j = random.randint(0, width - w) 411 | return i, j, h, w 412 | 413 | # Fallback to central crop 414 | in_ratio = float(width) / float(height) 415 | if in_ratio < min(ratio): 416 | w = width 417 | h = int(round(w / min(ratio))) 418 | elif in_ratio > max(ratio): 419 | h = height 420 | w = int(round(h * max(ratio))) 421 | else: # whole image 422 | w = width 423 | h = height 424 | i = (height - h) // 2 425 | j = (width - w) // 2 426 | return i, j, h, w 427 | 428 | 429 | def random_resized_crop( 430 | images, 431 | target_height, 432 | target_width, 433 | scale=(0.8, 1.0), 434 | ratio=(3.0 / 4.0, 4.0 / 3.0), 435 | ): 436 | """ 437 | Crop the given images to random size and aspect ratio. A crop of random 438 | size (default: of 0.08 to 1.0) of the original size and a random aspect 439 | ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This 440 | crop is finally resized to given size. This is popularly used to train the 441 | Inception networks. 442 | 443 | Args: 444 | images: Images to perform resizing and cropping. 445 | target_height: Desired height after cropping. 446 | target_width: Desired width after cropping. 447 | scale: Scale range of Inception-style area based random resizing. 448 | ratio: Aspect ratio range of Inception-style area based random resizing. 449 | """ 450 | 451 | height = images.shape[2] 452 | width = images.shape[3] 453 | 454 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) 455 | cropped = images[:, :, i : i + h, j : j + w] 456 | return torch.nn.functional.interpolate( 457 | cropped, 458 | size=(target_height, target_width), 459 | mode="bilinear", 460 | align_corners=False, 461 | ) 462 | 463 | 464 | def random_resized_crop_with_shift( 465 | images, 466 | target_height, 467 | target_width, 468 | scale=(0.8, 1.0), 469 | ratio=(3.0 / 4.0, 4.0 / 3.0), 470 | ): 471 | """ 472 | This is similar to random_resized_crop. However, it samples two different 473 | boxes (for cropping) for the first and last frame. It then linearly 474 | interpolates the two boxes for other frames. 475 | 476 | Args: 477 | images: Images to perform resizing and cropping. 478 | target_height: Desired height after cropping. 479 | target_width: Desired width after cropping. 480 | scale: Scale range of Inception-style area based random resizing. 481 | ratio: Aspect ratio range of Inception-style area based random resizing. 
482 | """ 483 | t = images.shape[1] 484 | height = images.shape[2] 485 | width = images.shape[3] 486 | 487 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) 488 | i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) 489 | i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] 490 | j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] 491 | h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] 492 | w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] 493 | out = torch.zeros((3, t, target_height, target_width)) 494 | for ind in range(t): 495 | out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate( 496 | images[ 497 | :, 498 | ind : ind + 1, 499 | i_s[ind] : i_s[ind] + h_s[ind], 500 | j_s[ind] : j_s[ind] + w_s[ind], 501 | ], 502 | size=(target_height, target_width), 503 | mode="bilinear", 504 | align_corners=False, 505 | ) 506 | return out 507 | 508 | 509 | def create_random_augment( 510 | input_size, 511 | auto_augment=None, 512 | interpolation="bilinear", 513 | ): 514 | """ 515 | Get video randaug transform. 516 | 517 | Args: 518 | input_size: The size of the input video in tuple. 519 | auto_augment: Parameters for randaug. An example: 520 | "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number 521 | of operations to apply). 522 | interpolation: Interpolation method. 523 | """ 524 | if isinstance(input_size, tuple): 525 | img_size = input_size[-2:] 526 | else: 527 | img_size = input_size 528 | 529 | if auto_augment: 530 | assert isinstance(auto_augment, str) 531 | if isinstance(img_size, tuple): 532 | img_size_min = min(img_size) 533 | else: 534 | img_size_min = img_size 535 | aa_params = {"translate_const": int(img_size_min * 0.45)} 536 | if interpolation and interpolation != "random": 537 | aa_params["interpolation"] = _pil_interp(interpolation) 538 | if auto_augment.startswith("rand"): 539 | return transforms.Compose([rand_augment_transform(auto_augment, aa_params)]) 540 | raise NotImplementedError 541 | 542 | 543 | def random_sized_crop_img( 544 | im, 545 | size, 546 | jitter_scale=(0.08, 1.0), 547 | jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), 548 | max_iter=10, 549 | ): 550 | """ 551 | Performs Inception-style cropping (used for training). 552 | """ 553 | assert len(im.shape) == 3, "Currently only support image for random_sized_crop" 554 | h, w = im.shape[1:3] 555 | i, j, h, w = _get_param_spatial_crop( 556 | scale=jitter_scale, 557 | ratio=jitter_aspect, 558 | height=h, 559 | width=w, 560 | num_repeat=max_iter, 561 | log_scale=False, 562 | switch_hw=True, 563 | ) 564 | cropped = im[:, i : i + h, j : j + w] 565 | return torch.nn.functional.interpolate( 566 | cropped.unsqueeze(0), 567 | size=(size, size), 568 | mode="bilinear", 569 | align_corners=False, 570 | ).squeeze(0) 571 | 572 | 573 | # The following code are modified based on timm lib, we will replace the following 574 | # contents with dependency from PyTorchVideo. 575 | # https://github.com/facebookresearch/pytorchvideo 576 | class RandomResizedCropAndInterpolation: 577 | """Crop the given PIL Image to random size and aspect ratio with random interpolation. 578 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 579 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 580 | is finally resized to given size. 581 | This is popularly used to train the Inception networks. 
582 | Args: 583 | size: expected output size of each edge 584 | scale: range of size of the origin size cropped 585 | ratio: range of aspect ratio of the origin aspect ratio cropped 586 | interpolation: Default: PIL.Image.BILINEAR 587 | """ 588 | 589 | def __init__( 590 | self, 591 | size, 592 | scale=(0.08, 1.0), 593 | ratio=(3.0 / 4.0, 4.0 / 3.0), 594 | interpolation="bilinear", 595 | ): 596 | if isinstance(size, tuple): 597 | self.size = size 598 | else: 599 | self.size = (size, size) 600 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 601 | print("range should be of kind (min, max)") 602 | 603 | if interpolation == "random": 604 | self.interpolation = _RANDOM_INTERPOLATION 605 | else: 606 | self.interpolation = _pil_interp(interpolation) 607 | self.scale = scale 608 | self.ratio = ratio 609 | 610 | @staticmethod 611 | def get_params(img, scale, ratio): 612 | """Get parameters for ``crop`` for a random sized crop. 613 | Args: 614 | img (PIL Image): Image to be cropped. 615 | scale (tuple): range of size of the origin size cropped 616 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 617 | Returns: 618 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 619 | sized crop. 620 | """ 621 | area = img.size[0] * img.size[1] 622 | 623 | for _ in range(10): 624 | target_area = random.uniform(*scale) * area 625 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 626 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 627 | 628 | w = int(round(math.sqrt(target_area * aspect_ratio))) 629 | h = int(round(math.sqrt(target_area / aspect_ratio))) 630 | 631 | if w <= img.size[0] and h <= img.size[1]: 632 | i = random.randint(0, img.size[1] - h) 633 | j = random.randint(0, img.size[0] - w) 634 | return i, j, h, w 635 | 636 | # Fallback to central crop 637 | in_ratio = img.size[0] / img.size[1] 638 | if in_ratio < min(ratio): 639 | w = img.size[0] 640 | h = int(round(w / min(ratio))) 641 | elif in_ratio > max(ratio): 642 | h = img.size[1] 643 | w = int(round(h * max(ratio))) 644 | else: # whole image 645 | w = img.size[0] 646 | h = img.size[1] 647 | i = (img.size[1] - h) // 2 648 | j = (img.size[0] - w) // 2 649 | return i, j, h, w 650 | 651 | def __call__(self, img): 652 | """ 653 | Args: 654 | img (PIL Image): Image to be cropped and resized. 655 | Returns: 656 | PIL Image: Randomly cropped and resized image. 657 | """ 658 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 659 | if isinstance(self.interpolation, (tuple, list)): 660 | interpolation = random.choice(self.interpolation) 661 | else: 662 | interpolation = self.interpolation 663 | return F.resized_crop(img, i, j, h, w, self.size, interpolation) 664 | 665 | def __repr__(self): 666 | if isinstance(self.interpolation, (tuple, list)): 667 | interpolate_str = " ".join( 668 | [_pil_interpolation_to_str[x] for x in self.interpolation] 669 | ) 670 | else: 671 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 672 | format_string = self.__class__.__name__ + "(size={0}".format(self.size) 673 | format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) 674 | format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) 675 | format_string += ", interpolation={0})".format(interpolate_str) 676 | return format_string 677 | --------------------------------------------------------------------------------
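For orientation, here is a minimal usage sketch (not part of the repository) showing how the clip-level helpers defined in EchoFM/util/decoder/transform.py can be chained to augment a single video tensor. The import path assumes the repository root is on PYTHONPATH and that the intermediate directories are importable as packages; the sizes and normalization statistics below are illustrative placeholders, not values taken from the EchoFM configuration.

import torch

# Hypothetical import path; assumes the EchoFM repository root is on PYTHONPATH.
from EchoFM.util.decoder.transform import (
    color_normalization,
    horizontal_flip,
    random_crop,
    random_short_side_scale_jitter,
)

# Dummy clip in the layout the helpers expect: `num frames` x `channel` x `height` x `width`.
clip = torch.rand(16, 3, 180, 240)

clip = random_short_side_scale_jitter(clip, min_size=128, max_size=160)  # short side resized into [128, 160]
clip = random_crop(clip, size=112)                                       # one 112x112 crop shared by all frames
clip = horizontal_flip(0.5, clip)                                        # flip the whole clip with probability 0.5
clip = color_normalization(clip, mean=[0.45, 0.45, 0.45], stddev=[0.225, 0.225, 0.225])

print(clip.shape)  # torch.Size([16, 3, 112, 112])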