├── .DS_Store
├── figure
│   ├── fig1.png
│   ├── fig2.png
│   └── .DS_Store
├── .gitattributes
├── EchoFM
│   ├── .DS_Store
│   ├── util
│   │   ├── __pycache__
│   │   │   ├── env.cpython-310.pyc
│   │   │   ├── env.cpython-312.pyc
│   │   │   ├── misc.cpython-310.pyc
│   │   │   ├── kinetics.cpython-310.pyc
│   │   │   ├── logging.cpython-310.pyc
│   │   │   ├── lr_sched.cpython-310.pyc
│   │   │   └── video_vit.cpython-310.pyc
│   │   ├── decoder
│   │   │   ├── __pycache__
│   │   │   │   ├── utils.cpython-310.pyc
│   │   │   │   ├── decoder.cpython-310.pyc
│   │   │   │   ├── transform.cpython-310.pyc
│   │   │   │   ├── rand_augment.cpython-310.pyc
│   │   │   │   └── video_container.cpython-310.pyc
│   │   │   ├── video_container.py
│   │   │   ├── mixup.py
│   │   │   ├── random_erasing.py
│   │   │   ├── decoder.py
│   │   │   ├── utils.py
│   │   │   ├── rand_augment.py
│   │   │   └── transform.py
│   │   ├── env.py
│   │   ├── lr_sched.py
│   │   ├── pos_embed.py
│   │   ├── lr_decay.py
│   │   ├── logging.py
│   │   ├── video_vit.py
│   │   ├── meters.py
│   │   └── misc.py
│   ├── __pycache__
│   │   ├── models_mae.cpython-310.pyc
│   │   └── engine_pretrain.cpython-310.pyc
│   ├── engine_test.py
│   ├── engine_pretrain.py
│   ├── engine_finetune.py
│   ├── models_vit.py
│   └── models_mae.py
├── data
│   ├── __pycache__
│   │   └── dataset.cpython-310.pyc
│   └── dataset.py
├── run_pretrain.py
├── environment_setup.sh
├── README.md
├── config
│   └── config.yaml
└── main_pretrain.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/.DS_Store
--------------------------------------------------------------------------------
/figure/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/figure/fig1.png
--------------------------------------------------------------------------------
/figure/fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/figure/fig2.png
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/EchoFM/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/.DS_Store
--------------------------------------------------------------------------------
/figure/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/figure/.DS_Store
--------------------------------------------------------------------------------
/data/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/data/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/env.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/env.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/env.cpython-312.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/env.cpython-312.pyc
--------------------------------------------------------------------------------
/EchoFM/__pycache__/models_mae.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/__pycache__/models_mae.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/misc.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/misc.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/kinetics.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/kinetics.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/logging.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/logging.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/lr_sched.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/lr_sched.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/__pycache__/engine_pretrain.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/__pycache__/engine_pretrain.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/__pycache__/video_vit.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/__pycache__/video_vit.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/decoder/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/decoder/__pycache__/decoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/decoder.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/decoder/__pycache__/transform.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/transform.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/decoder/__pycache__/rand_augment.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/rand_augment.cpython-310.pyc
--------------------------------------------------------------------------------
/EchoFM/util/decoder/__pycache__/video_container.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SekeunKim/EchoFM/HEAD/EchoFM/util/decoder/__pycache__/video_container.cpython-310.pyc
--------------------------------------------------------------------------------
/run_pretrain.py:
--------------------------------------------------------------------------------
1 | from main_pretrain import get_args_parser, main
2 | from pathlib import Path
3 |
4 | def invoke_main() -> None:
5 | args = get_args_parser()
6 | args = args.parse_args()
7 | if args.output_dir:
8 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
9 | main(args)
10 |
11 | if __name__ == "__main__":
12 | invoke_main()
13 |
--------------------------------------------------------------------------------
/EchoFM/util/env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | """Set up Environment."""
6 |
7 | from iopath.common.file_io import PathManagerFactory
8 |
9 | _ENV_SETUP_DONE = False
10 | pathmgr = PathManagerFactory.get(key="mae_st")
11 | checkpoint_pathmgr = PathManagerFactory.get(key="mae_st_checkpoint")
12 |
13 |
14 | def setup_environment():
15 | global _ENV_SETUP_DONE
16 | if _ENV_SETUP_DONE:
17 | return
18 | _ENV_SETUP_DONE = True
19 |
--------------------------------------------------------------------------------
/EchoFM/util/decoder/video_container.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | import av
6 |
7 |
8 | def get_video_container(path_to_vid, multi_thread_decode=False):
 9 |     """
10 |     Given the path to a video, read and return its raw encoded bytes;
11 |     the downstream decoder is responsible for wrapping them (e.g. with PyAV).
12 |     Args:
13 |         path_to_vid (str): path to the video.
14 |         multi_thread_decode (bool): kept for API compatibility; threaded
15 |             decoding is handled by the downstream decoder rather than here.
16 |     Returns:
17 |         container (bytes): raw encoded video bytes.
18 |     """
19 | with open(path_to_vid, "rb") as fp:
20 | container = fp.read()
21 | return container
22 |
--------------------------------------------------------------------------------
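
For illustration, here is a minimal decoding sketch (not part of the repository) that wraps the raw bytes returned by `get_video_container` in a PyAV container; the clip path is a placeholder.

```python
# Illustrative sketch only: wrap the raw bytes from get_video_container()
# in a PyAV container and decode RGB frames. "clip.mp4" is a placeholder.
import io

import av
import numpy as np

from EchoFM.util.decoder.video_container import get_video_container

raw = get_video_container("clip.mp4")
with av.open(io.BytesIO(raw)) as container:
    frames = [
        frame.to_ndarray(format="rgb24")        # H x W x 3 uint8 per frame
        for frame in container.decode(video=0)  # first video stream
    ]
video = np.stack(frames)  # (T, H, W, 3)
print(video.shape)
```
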
/EchoFM/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 |
10 | def adjust_learning_rate(optimizer, epoch, args):
11 | """Decay the learning rate with half-cycle cosine after warmup"""
12 | if epoch < args.warmup_epochs:
13 | lr = args.lr * epoch / args.warmup_epochs
14 | else:
15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
16 | 1.0
17 | + math.cos(
18 | math.pi
19 | * (epoch - args.warmup_epochs)
20 | / (args.epochs - args.warmup_epochs)
21 | )
22 | )
23 | for param_group in optimizer.param_groups:
24 | if "lr_scale" in param_group:
25 | param_group["lr"] = lr * param_group["lr_scale"]
26 | else:
27 | param_group["lr"] = lr
28 | return lr
29 |
--------------------------------------------------------------------------------
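
As a quick illustration (not from the repository), the warmup-plus-cosine schedule above can be inspected with a dummy optimizer and an `argparse.Namespace` carrying the fields it reads; the hyperparameter values below are placeholders.

```python
# Illustrative sketch: check the warmup + half-cycle cosine schedule.
# The Namespace fields mirror the attributes adjust_learning_rate() reads.
from argparse import Namespace

import torch

from EchoFM.util.lr_sched import adjust_learning_rate

args = Namespace(lr=1.6e-3, min_lr=0.0, warmup_epochs=10, epochs=100)
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)

for epoch in [0, 5, 10, 50, 99]:
    lr = adjust_learning_rate(optimizer, epoch, args)
    print(f"epoch {epoch:3d}: lr = {lr:.6f}")
```
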
/environment_setup.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | set -e
3 |
4 | CONDA_ENV=${1:-""}
5 | if [ -n "$CONDA_ENV" ]; then
6 | # This is required to activate conda environment
7 | eval "$(conda shell.bash hook)"
8 |
 9 |     conda create -n "$CONDA_ENV" python=3.10.14 -y
10 |     conda activate "$CONDA_ENV"
11 |
12 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
13 | else
14 | echo "Skipping conda environment creation. Make sure you have the correct environment activated."
15 | fi
16 |
17 | pip install iopath psutil scipy einops tensorboard
18 | conda install -y simplejson
19 | # # This is required to enable PEP 660 support
20 | # pip install --upgrade pip setuptools
21 |
22 | # # Install FlashAttention2
23 | # pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
24 |
25 | # # Install VILA
26 | # pip install -e ".[train,eval]"
27 |
28 | # pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
29 |
30 | # pip install git+https://github.com/huggingface/transformers@v4.36.2
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## EchoFM - A Video Vision Foundation Model for Echocardiogram
2 |
 3 | Official repo for [EchoFM: Foundation Model for Generalizable Echocardiogram Analysis](https://arxiv.org/abs/2410.23413)
4 |
5 | This model and associated code are released under the CC-BY-NC-ND 4.0 license and may only be used for non-commercial, academic research purposes with proper attribution. Any commercial use, sale, or other monetization of the EchoFM model and its derivatives, which include models trained on outputs from the EchoFM model or datasets created from the EchoFM model, is prohibited and requires prior approval.
6 |
7 |
8 |
 9 | This work was supported by the National Research Foundation of Korea (NRF) grant funded by the Korea government (MSIT) (RS-2024-00348696).
10 |
11 | ## Key features
12 |
13 | - EchoFM is pre-trained on 290K echocardiography clips with self-supervised learning.
14 | - EchoFM has been validated on multiple downstream tasks, including segmentation, classification, and disease detection.
15 | - EchoFM can be efficiently adapted to custom downstream tasks.
16 |
17 |
18 |
19 | ## 1. Environment Setup
20 |
21 | ```bash
22 | git clone https://github.com/SekeunKim/EchoFM.git
23 | cd EchoFM
24 | ./environment_setup.sh EchoFM
25 | ```
26 |
27 | ## 2. Download model
28 | Download the EchoFM weights from the following link:
29 | [EchoFM Weights](https://drive.google.com/drive/folders/1Gn43_qMwk-wzZIxZdxXLyk2mXDv5Jsxt?usp=share_link)
30 |
31 | ## 3. Citation
32 | If you find this repository useful, please consider citing this paper: [will be released soon]
33 | ```
34 | @article{kim2024echofm,
35 | title={EchoFM: Foundation Model for Generalizable Echocardiogram Analysis},
36 | author={Kim, Sekeun and Jin, Pengfei and Song, Sifan and Chen, Cheng and Li, Yiwei and Ren, Hui and Li, Xiang and Liu, Tianming and Li, Quanzheng},
37 | journal={arXiv preprint arXiv:2410.23413},
38 | year={2024}
39 | }
40 | ```
41 |
--------------------------------------------------------------------------------
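
Building on the download step in the README, here is a minimal weight-loading sketch; the checkpoint filename, the `"model"` key, and the `num_classes`/`num_frames`/`t_patch_size` values are assumptions, not confirmed by the repository.

```python
# Illustrative sketch only: the checkpoint path and the "model" key are
# assumptions about the released weights, not confirmed by this repo.
import torch

from EchoFM.models_vit import vit_base_patch16

model = vit_base_patch16(num_frames=32, t_patch_size=4, num_classes=2)
checkpoint = torch.load("echofm_pretrained.pth", map_location="cpu")
state_dict = checkpoint.get("model", checkpoint)  # tolerate either layout
msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys, msg.unexpected_keys)
```
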
/EchoFM/engine_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import EchoFM.util.misc as misc
13 | import torch
14 |
15 |
16 | @torch.no_grad()
17 | def test(data_loader, model, device, test_meter, fp32=False):
18 | metric_logger = misc.MetricLogger(delimiter=" ")
19 |
20 | # switch to evaluation mode
21 | model.eval()
22 | softmax = torch.nn.Softmax(dim=1).cuda()
23 |
24 | for cur_iter, (images, labels, video_idx) in enumerate(data_loader):
25 | images = images.to(device, non_blocking=True)
26 | labels = labels.to(device, non_blocking=True)
27 | video_idx = video_idx.to(device, non_blocking=True)
28 |
29 | if len(images.shape) == 6:
30 | b, r, c, t, h, w = images.shape
31 | images = images.view(b * r, c, t, h, w)
32 | labels = labels.view(b * r)
33 |
34 | # compute output
35 | with torch.cuda.amp.autocast(enabled=not fp32):
36 | preds = model(images)
37 | preds = softmax(preds)
38 |
39 | if torch.distributed.is_initialized():
40 | preds, labels, video_idx = misc.all_gather([preds, labels, video_idx])
41 | preds = preds.cpu()
42 | labels = labels.cpu()
43 | video_idx = video_idx.cpu()
44 | # Update and log stats.
45 | test_meter.update_stats(preds.detach(), labels.detach(), video_idx.detach())
46 | test_meter.log_iter_stats(cur_iter)
47 |
48 | test_meter.finalize_metrics()
49 | # gather the stats from all processes
50 | metric_logger.synchronize_between_processes()
51 | return test_meter.stats
52 |
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | # architecture
2 | arch: vit_base
3 | enc_arch: MAEViTEncoder
4 | dec_arch: MAEViTDecoder
5 |
6 | # wandb
7 | proj_name: mae3d
8 | run_name: ${proj_name}_${arch}_${dataset}
9 | wandb_id:
10 | disable_wandb: 0
11 | eval : True
12 | # dataset
13 | dataset: echo #mgh echo
14 | data_path: /nvme/zhoulei/MSD
15 | data_seed: 12345
16 | ts_fold: 0
17 | resize_h_w : [224,224]
18 | max_fr : 32
19 | view_type : "A24C"
20 |
21 | # output
22 | # output_dir: /nvme/zhoulei/ssl-framework/${run_name}
23 | # ckpt_dir: ${output_dir}/ckpts
24 |
25 | # data preprocessing
26 | roi_x: 128
27 | roi_y: 128
28 | roi_z: 128
29 | RandFlipd_prob: 0.2
30 | RandRotate90d_prob: 0.2
31 | RandScaleIntensityd_prob: 0.1
32 | RandShiftIntensityd_prob: 0.1
33 | spatial_dim: 3
34 | cache_rate: 1.
35 |
36 | # trainer
37 | trainer_name: MAE3DTrainer
38 | batch_size: 3
39 | vis_batch_size: 1
40 | start_epoch: 0
41 | warmup_epochs: 10
42 | epochs: 1000
43 | workers: 1 #8
44 | pretrain:
45 | resume:
46 | # model
47 | patchembed: 'PatchEmbed3D'
48 | pos_embed_type: 'sincos'
49 | mask_ratio: 0.75
50 | input_size: [224, 224, 32]
51 | patch_size: 16
52 | in_chans: 3
53 | encoder_embed_dim: 768
54 | encoder_depth: 12
55 | encoder_num_heads: 12
56 | decoder_embed_dim: 384
57 | decoder_depth: 8
58 | decoder_num_heads: 12
59 |
60 | # optimizer
61 | type: adamw
62 | lr: 6.4e-3
63 | beta1: 0.9
64 | beta2: 0.95
65 | weight_decay: 0.05
66 |
67 | # logging
68 | vis_freq: 10
69 | save_freq: 100
70 | print_freq: 5
71 |
72 | # distributed processing
73 | gpu: 0
74 | dist_url: # 'tcp://localhost:10001'
75 | world_size: 1
76 | multiprocessing_distributed: false
77 | dist_backend: nccl
78 | distributed:
79 | rank: 0
80 | ngpus_per_node:
81 |
82 | # randomness
83 | seed:
84 |
85 | # debugging
86 | debug: false
87 |
88 | #path
89 | base_pr_path : '/mount/home/local/PARTNERS/sk1064/workspace/nature/'
90 | base_data_path : '/mount/mnt/CAMCA/home/sk/us'
91 | base_json_path : '/mount/home/local/PARTNERS/sk1064/workspace/echo_samv2/dataset/Echo'
92 | json_path : '/camus/train_2c4c.json'
93 | output_dir: ${base_pr_path}/output/ssl-framework/${run_name}
94 | ckpt_dir: ${output_dir}/ckpts
95 |
96 | base_mgh_data_path : /mount/mnt/CAMCA/AorticStenosis
--------------------------------------------------------------------------------
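
The `${...}` references in this file follow OmegaConf-style interpolation, so a minimal resolution sketch might look like the following (assuming OmegaConf is the intended loader, which the repository does not state explicitly).

```python
# Illustrative sketch: resolve config.yaml with OmegaConf, whose ${...}
# interpolation syntax matches run_name/output_dir/ckpt_dir above.
# The relative path is a placeholder for the repo root.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/config.yaml")
print(cfg.run_name)                        # e.g. "mae3d_vit_base_echo"
print(OmegaConf.to_yaml(cfg, resolve=True))
```
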
/EchoFM/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import EchoFM.util.logging as logging
11 | import numpy as np
12 | import torch
13 |
14 |
15 | logger = logging.get_logger(__name__)
16 |
17 |
18 | # --------------------------------------------------------
19 | # Interpolate position embeddings for high-resolution
20 | # References:
21 | # DeiT: https://github.com/facebookresearch/deit
22 | # --------------------------------------------------------
23 | def interpolate_pos_embed(model, checkpoint_model):
24 | if "pos_embed" in checkpoint_model:
25 | pos_embed_checkpoint = checkpoint_model["pos_embed"]
26 | embedding_size = pos_embed_checkpoint.shape[-1]
27 | num_patches = model.patch_embed.num_patches
28 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
29 | # height (== width) for the checkpoint position embedding
30 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
31 | # height (== width) for the new position embedding
32 | new_size = int(num_patches**0.5)
33 | # class_token and dist_token are kept unchanged
34 | if orig_size != new_size:
35 | print(
36 | "Position interpolate from %dx%d to %dx%d"
37 | % (orig_size, orig_size, new_size, new_size)
38 | )
39 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
40 | # only the position tokens are interpolated
41 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
42 | pos_tokens = pos_tokens.reshape(
43 | -1, orig_size, orig_size, embedding_size
44 | ).permute(0, 3, 1, 2)
45 | pos_tokens = torch.nn.functional.interpolate(
46 | pos_tokens,
47 | size=(new_size, new_size),
48 | mode="bicubic",
49 | align_corners=False,
50 | )
51 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
52 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
53 | checkpoint_model["pos_embed"] = new_pos_embed
54 |
--------------------------------------------------------------------------------
/EchoFM/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import json
13 |
14 |
15 | def param_groups_lrd(
16 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
17 | ):
18 | """
19 | Parameter groups for layer-wise lr decay
20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
21 | """
22 | param_group_names = {}
23 | param_groups = {}
24 |
25 | num_layers = len(model.blocks) + 1
26 |
27 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
28 |
29 | for n, p in model.named_parameters():
30 | if not p.requires_grad:
31 | continue
32 |
33 | # no decay: all 1D parameters and model specific ones
34 | if p.ndim == 1 or n in no_weight_decay_list:
35 | g_decay = "no_decay"
36 | this_decay = 0.0
37 | else:
38 | g_decay = "decay"
39 | this_decay = weight_decay
40 |
41 | layer_id = get_layer_id_for_vit(n, num_layers)
42 | group_name = "layer_%d_%s" % (layer_id, g_decay)
43 |
44 | if group_name not in param_group_names:
45 | this_scale = layer_scales[layer_id]
46 |
47 | param_group_names[group_name] = {
48 | "lr_scale": this_scale,
49 | "weight_decay": this_decay,
50 | "params": [],
51 | }
52 | param_groups[group_name] = {
53 | "lr_scale": this_scale,
54 | "weight_decay": this_decay,
55 | "params": [],
56 | }
57 |
58 | param_group_names[group_name]["params"].append(n)
59 | param_groups[group_name]["params"].append(p)
60 |
61 | print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
62 |
63 | return list(param_groups.values())
64 |
65 |
66 | def get_layer_id_for_vit(name, num_layers):
67 | """
68 | Assign a parameter with its layer id
69 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
70 | """
71 | if name in [
72 | "cls_token",
73 | "mask_token",
74 | ]:
75 | return 0
76 | elif name.startswith("patch_embed"):
77 | return 0
78 | elif name.startswith("pos_embed"):
79 | return 0
80 | elif name.startswith("blocks"):
81 | return int(name.split(".")[1]) + 1
82 | else:
83 | return num_layers
84 |
--------------------------------------------------------------------------------
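
A short usage sketch (not from the repository) pairing `param_groups_lrd` with the ViT from `models_vit.py` and AdamW; the hyperparameters and `num_classes` are placeholders.

```python
# Illustrative sketch: build layer-wise-decayed parameter groups for the
# EchoFM ViT and hand them to AdamW. Hyperparameters are placeholders.
import torch

from EchoFM.models_vit import vit_base_patch16
from EchoFM.util.lr_decay import param_groups_lrd

model = vit_base_patch16(num_frames=32, t_patch_size=4, num_classes=2)
param_groups = param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```
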
/EchoFM/util/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | """Logging."""
6 |
7 | import atexit
8 | import builtins
9 | import decimal
10 | import functools
11 | import logging
12 | import os
13 | import sys
14 |
15 | import simplejson
16 | import torch
17 | import torch.distributed as dist
18 | from iopath.common.file_io import g_pathmgr as pathmgr
19 |
20 |
21 | def is_master_proc(multinode=False):
22 | """
23 | Determines if the current process is the master process.
24 | """
25 | if dist.is_initialized():
26 | if multinode:
27 | return dist.get_rank() % dist.get_world_size() == 0
28 | else:
29 | return dist.get_rank() % torch.cuda.device_count() == 0
30 | else:
31 | return True
32 |
33 |
34 | def _suppress_print():
35 | """
36 | Suppresses printing from the current process.
37 | """
38 |
39 | def print_pass(*objects, sep=" ", end="\n", file=sys.stdout, flush=False):
40 | pass
41 |
42 | builtins.print = print_pass
43 |
44 |
45 | @functools.lru_cache(maxsize=None)
46 | def _cached_log_stream(filename):
47 | # Use 1K buffer if writing to cloud storage.
48 | io = pathmgr.open(filename, "a", buffering=1024 if "://" in filename else -1)
49 | atexit.register(io.close)
50 | return io
51 |
52 |
53 | def setup_logging(output_dir=None):
54 | """
55 | Sets up the logging for multiple processes. Only enable the logging for the
56 | master process, and suppress logging for the non-master processes.
57 | """
58 | # Set up logging format.
59 | if is_master_proc():
60 | # Enable logging for the master process.
61 | logging.root.handlers = []
62 | else:
63 | # Suppress logging for non-master processes.
64 | _suppress_print()
65 |
66 | logger = logging.getLogger()
67 | logger.setLevel(logging.DEBUG)
68 | logger.propagate = False
69 | plain_formatter = logging.Formatter(
70 | "[%(asctime)s][%(levelname)s] %(filename)s: %(lineno)3d: %(message)s",
71 | datefmt="%m/%d %H:%M:%S",
72 | )
73 |
74 | if is_master_proc():
75 | ch = logging.StreamHandler(stream=sys.stdout)
76 | ch.setLevel(logging.DEBUG)
77 | ch.setFormatter(plain_formatter)
78 | logger.addHandler(ch)
79 |
80 | if output_dir is not None and is_master_proc(multinode=True):
81 | filename = os.path.join(output_dir, "stdout.log")
82 | fh = logging.StreamHandler(_cached_log_stream(filename))
83 | fh.setLevel(logging.DEBUG)
84 | fh.setFormatter(plain_formatter)
85 | logger.addHandler(fh)
86 |
87 |
88 | def get_logger(name):
89 | """
90 | Retrieve the logger with the specified name or, if name is None, return a
91 | logger which is the root logger of the hierarchy.
92 | Args:
93 | name (string): name of the logger.
94 | """
95 | return logging.getLogger(name)
96 |
97 |
98 | def log_json_stats(stats):
99 | """
100 | Logs json stats.
101 | Args:
102 | stats (dict): a dictionary of statistical information to log.
103 | """
104 | stats = {
105 | k: decimal.Decimal("{:.5f}".format(v)) if isinstance(v, float) else v
106 | for k, v in stats.items()
107 | }
108 | json_stats = simplejson.dumps(stats, sort_keys=True, use_decimal=True)
109 | logger = get_logger(__name__)
110 | print("json_stats: {:s}".format(json_stats))
111 |
112 |
113 | def master_print(*args, **kwargs):
114 | if is_master_proc():
115 | print(*args, **kwargs)
116 | else:
117 | pass
118 |
--------------------------------------------------------------------------------
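
A minimal usage sketch for these logging helpers (not part of the repository); passing an output directory instead of `None` would additionally write to `stdout.log`.

```python
# Illustrative sketch: initialize logging on the master process, emit a
# message, and log a JSON stats line.
from EchoFM.util import logging as echofm_logging

echofm_logging.setup_logging(output_dir=None)
logger = echofm_logging.get_logger(__name__)
logger.info("logging initialized")
echofm_logging.log_json_stats({"epoch": 0, "loss": 1.2345})
```
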
/EchoFM/engine_pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | import math
12 | from typing import Iterable
13 |
14 | import EchoFM.util.lr_sched as lr_sched
15 | import EchoFM.util.misc as misc
16 | import torch
17 | from iopath.common.file_io import g_pathmgr as pathmgr
18 |
19 |
20 | def train_one_epoch(
21 | model: torch.nn.Module,
22 | data_loader: Iterable,
23 | optimizer: torch.optim.Optimizer,
24 | device: torch.device,
25 | epoch: int,
26 | loss_scaler,
27 | log_writer=None,
28 | args=None,
29 | fp32=False,
30 | ):
31 | model.train(True)
32 | metric_logger = misc.MetricLogger(delimiter=" ")
33 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
34 | metric_logger.add_meter(
35 | "cpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
36 | )
37 | metric_logger.add_meter(
38 | "cpu_mem_all", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
39 | )
40 | metric_logger.add_meter(
41 | "gpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
42 | )
43 | metric_logger.add_meter(
44 | "mask_ratio", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
45 | )
46 | header = "Epoch: [{}]".format(epoch)
47 | print_freq = 20
48 |
49 | accum_iter = args.accum_iter
50 |
51 | optimizer.zero_grad()
52 |
53 | if log_writer is not None:
54 | print("log_dir: {}".format(log_writer.log_dir))
55 |
56 | for data_iter_step, samples in enumerate(
57 | metric_logger.log_every(data_loader, print_freq, header)
58 | ):
59 | # we use a per iteration (instead of per epoch) lr scheduler
60 | if data_iter_step % accum_iter == 0:
61 | lr_sched.adjust_learning_rate(
62 | optimizer, data_iter_step / len(data_loader) + epoch, args
63 | )
64 |
65 | samples = samples.to(device, non_blocking=True)
66 | if len(samples.shape) == 6:
67 | b, r, c, t, h, w = samples.shape
68 | samples = samples.reshape(b * r, c, t, h, w)
69 |
70 | with torch.cuda.amp.autocast(enabled=not fp32):
71 | loss, _, _ = model(
72 | samples,
73 | mask_ratio=args.mask_ratio,
74 | )
75 |
76 | loss_value = loss.item()
77 |
78 | if not math.isfinite(loss_value):
79 | for _ in range(args.num_checkpoint_del):
80 | try:
81 | path = misc.get_last_checkpoint(args)
82 | pathmgr.rm(path)
83 | print(f"remove checkpoint {path}")
84 | except Exception as _:
85 | pass
86 | raise Exception("Loss is {}, stopping training".format(loss_value))
87 |
88 | loss /= accum_iter
89 | loss_scaler(
90 | loss,
91 | optimizer,
92 | parameters=model.parameters(),
93 | update_grad=(data_iter_step + 1) % accum_iter == 0,
94 | clip_grad=args.clip_grad,
95 | )
96 |
97 | if (data_iter_step + 1) % accum_iter == 0:
98 | optimizer.zero_grad()
99 |
100 | torch.cuda.synchronize()
101 |
102 | metric_logger.update(loss=loss_value)
103 | metric_logger.update(cpu_mem=misc.cpu_mem_usage()[0])
104 | metric_logger.update(cpu_mem_all=misc.cpu_mem_usage()[1])
105 | metric_logger.update(gpu_mem=misc.gpu_mem_usage())
106 | metric_logger.update(mask_ratio=args.mask_ratio)
107 |
108 | lr = optimizer.param_groups[0]["lr"]
109 | metric_logger.update(lr=lr)
110 |
111 | loss_value_reduce = misc.all_reduce_mean(loss_value)
112 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
113 | """We use epoch_1000x as the x-axis in tensorboard.
114 | This calibrates different curves when batch size changes.
115 | """
116 | epoch_1000x = int(
117 | (data_iter_step / len(data_loader) + epoch) * 1000 * args.repeat_aug
118 | )
119 | log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
120 | log_writer.add_scalar("lr", lr, epoch_1000x)
121 |
122 | # gather the stats from all processes
123 | metric_logger.synchronize_between_processes()
124 | print("Averaged stats:", metric_logger)
125 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
126 |
--------------------------------------------------------------------------------
/EchoFM/util/video_vit.py:
--------------------------------------------------------------------------------
1 | import EchoFM.util.logging as logging
2 | import torch
3 | import torch.nn as nn
4 | from timm.models.layers import to_2tuple
5 | from timm.models.vision_transformer import DropPath, Mlp
6 |
7 |
8 | logger = logging.get_logger(__name__)
9 |
10 |
11 | class PatchEmbed(nn.Module):
12 | """Image to Patch Embedding"""
13 |
14 | def __init__(
15 | self,
16 | img_size=224,
17 | patch_size=16,
18 | in_chans=3,
19 | embed_dim=768,
20 | # temporal related:
21 | frames=32,
22 | t_patch_size=4,
23 | ):
24 | super().__init__()
25 | img_size = to_2tuple(img_size)
26 | patch_size = to_2tuple(patch_size)
27 | assert img_size[1] % patch_size[1] == 0
28 | assert img_size[0] % patch_size[0] == 0
29 | assert frames % t_patch_size == 0
30 | num_patches = (
31 | (img_size[1] // patch_size[1])
32 | * (img_size[0] // patch_size[0])
33 | * (frames // t_patch_size)
34 | )
35 | self.input_size = (
36 | frames // t_patch_size,
37 | img_size[0] // patch_size[0],
38 | img_size[1] // patch_size[1],
39 | )
40 | print(
41 | f"img_size {img_size} patch_size {patch_size} frames {frames} t_patch_size {t_patch_size}"
42 | )
43 | self.img_size = img_size
44 | self.patch_size = patch_size
45 |
46 | self.frames = frames
47 | self.t_patch_size = t_patch_size
48 |
49 | self.num_patches = num_patches
50 |
51 | self.grid_size = img_size[0] // patch_size[0]
52 | self.t_grid_size = frames // t_patch_size
53 |
54 | kernel_size = [t_patch_size] + list(patch_size)
55 | self.proj = nn.Conv3d(
56 | in_chans, embed_dim, kernel_size=kernel_size, stride=kernel_size
57 | )
58 |
59 | def forward(self, x):
60 | B, C, T, H, W = x.shape
61 | assert (
62 | H == self.img_size[0] and W == self.img_size[1]
63 | ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
64 | assert T == self.frames
65 | x = self.proj(x).flatten(3)
66 | x = torch.einsum("ncts->ntsc", x) # [N, T, H*W, C]
67 | return x
68 |
69 |
70 | class Attention(nn.Module):
71 | def __init__(
72 | self,
73 | dim,
74 | num_heads=8,
75 | qkv_bias=False,
76 | qk_scale=None,
77 | attn_drop=0.0,
78 | proj_drop=0.0,
79 | input_size=(4, 14, 14),
80 | ):
81 | super().__init__()
82 | assert dim % num_heads == 0, "dim should be divisible by num_heads"
83 | self.num_heads = num_heads
84 | head_dim = dim // num_heads
85 | self.scale = qk_scale or head_dim**-0.5
86 |
87 | self.q = nn.Linear(dim, dim, bias=qkv_bias)
88 | self.k = nn.Linear(dim, dim, bias=qkv_bias)
89 | self.v = nn.Linear(dim, dim, bias=qkv_bias)
90 | assert attn_drop == 0.0 # do not use
91 | self.proj = nn.Linear(dim, dim)
92 | self.proj_drop = nn.Dropout(proj_drop)
93 | self.input_size = input_size
94 | assert input_size[1] == input_size[2]
95 |
96 | def forward(self, x):
97 | B, N, C = x.shape
98 | q = (
99 | self.q(x)
100 | .reshape(B, N, self.num_heads, C // self.num_heads)
101 | .permute(0, 2, 1, 3)
102 | )
103 | k = (
104 | self.k(x)
105 | .reshape(B, N, self.num_heads, C // self.num_heads)
106 | .permute(0, 2, 1, 3)
107 | )
108 | v = (
109 | self.v(x)
110 | .reshape(B, N, self.num_heads, C // self.num_heads)
111 | .permute(0, 2, 1, 3)
112 | )
113 |
114 | attn = (q @ k.transpose(-2, -1)) * self.scale
115 |
116 | attn = attn.softmax(dim=-1)
117 |
118 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
119 | x = self.proj(x)
120 | x = self.proj_drop(x)
121 | x = x.view(B, -1, C)
122 | return x
123 |
124 |
125 | class Block(nn.Module):
126 | """
127 | Transformer Block with specified Attention function
128 | """
129 |
130 | def __init__(
131 | self,
132 | dim,
133 | num_heads,
134 | mlp_ratio=4.0,
135 | qkv_bias=False,
136 | qk_scale=None,
137 | drop=0.0,
138 | attn_drop=0.0,
139 | drop_path=0.0,
140 | act_layer=nn.GELU,
141 | norm_layer=nn.LayerNorm,
142 | attn_func=Attention,
143 | ):
144 | super().__init__()
145 | self.norm1 = norm_layer(dim)
146 | self.attn = attn_func(
147 | dim,
148 | num_heads=num_heads,
149 | qkv_bias=qkv_bias,
150 | qk_scale=qk_scale,
151 | attn_drop=attn_drop,
152 | proj_drop=drop,
153 | )
154 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
155 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
156 | self.norm2 = norm_layer(dim)
157 | mlp_hidden_dim = int(dim * mlp_ratio)
158 | self.mlp = Mlp(
159 | in_features=dim,
160 | hidden_features=mlp_hidden_dim,
161 | act_layer=act_layer,
162 | drop=drop,
163 | )
164 |
165 | def forward(self, x):
166 | x = x + self.drop_path(self.attn(self.norm1(x)))
167 | x = x + self.drop_path(self.mlp(self.norm2(x)))
168 | return x
169 |
--------------------------------------------------------------------------------
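
A small shape-check sketch (not from the repository) showing what `PatchEmbed` produces for a dummy clip with the default spatiotemporal patching.

```python
# Illustrative sketch: PatchEmbed maps a (B, C, T, H, W) clip to
# (B, T', H'*W', embed_dim) tokens, matching the [N, T, H*W, C] comment above.
import torch

from EchoFM.util.video_vit import PatchEmbed

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3,
                         embed_dim=768, frames=32, t_patch_size=4)
clip = torch.randn(2, 3, 32, 224, 224)  # (B, C, T, H, W)
tokens = patch_embed(clip)
print(tokens.shape)                      # torch.Size([2, 8, 196, 768])
```
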
/EchoFM/engine_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import math
13 | import sys
14 | from typing import Iterable, Optional
15 |
16 | import EchoFM.util.lr_sched as lr_sched
17 | import EchoFM.util.misc as misc
18 | import torch
19 | from EchoFM.util.logging import master_print as print
20 | from timm.data import Mixup
21 | from timm.utils import accuracy
22 |
23 |
24 | def train_one_epoch(
25 | model: torch.nn.Module,
26 | criterion: torch.nn.Module,
27 | data_loader: Iterable,
28 | optimizer: torch.optim.Optimizer,
29 | device: torch.device,
30 | epoch: int,
31 | loss_scaler,
32 | max_norm: float = 0,
33 | mixup_fn: Optional[Mixup] = None,
34 | log_writer=None,
35 | args=None,
36 | fp32=False,
37 | ):
38 | model.train(True)
39 | metric_logger = misc.MetricLogger(delimiter=" ")
40 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
41 | metric_logger.add_meter(
42 | "cpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
43 | )
44 | metric_logger.add_meter(
45 | "cpu_mem_all", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
46 | )
47 | metric_logger.add_meter(
48 | "gpu_mem", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")
49 | )
50 | header = "Epoch: [{}]".format(epoch)
51 | print_freq = 20
52 |
53 | accum_iter = args.accum_iter
54 |
55 | optimizer.zero_grad()
56 |
57 | if log_writer is not None:
58 | print("log_dir: {}".format(log_writer.log_dir))
59 |
60 | for data_iter_step, (samples, targets) in enumerate(
61 | metric_logger.log_every(data_loader, print_freq, header)
62 | ):
63 |
64 | # we use a per iteration (instead of per epoch) lr scheduler
65 | if data_iter_step % accum_iter == 0:
66 | lr_sched.adjust_learning_rate(
67 | optimizer, data_iter_step / len(data_loader) + epoch, args
68 | )
69 |
70 | if len(samples.shape) == 6:
71 | b, r, c, t, h, w = samples.shape
72 | samples = samples.view(b * r, c, t, h, w)
73 | targets = targets.view(b * r)
74 |
75 | if args.cpu_mix:
76 | if mixup_fn is not None:
77 | samples, targets = mixup_fn(samples, targets)
78 | samples = samples.to(device, non_blocking=True)
79 | targets = targets.to(device, non_blocking=True)
80 | else:
81 | samples = samples.to(device, non_blocking=True)
82 | targets = targets.to(device, non_blocking=True)
83 | if mixup_fn is not None:
84 | samples, targets = mixup_fn(samples, targets)
85 |
86 | with torch.cuda.amp.autocast(enabled=not fp32):
87 | outputs = model(samples)
88 | loss = criterion(outputs, targets)
89 |
90 | loss_value = loss.item()
91 |
92 | if not math.isfinite(loss_value):
93 | print("Loss is {}, stopping training".format(loss_value))
94 | sys.exit(1)
95 |
96 | loss /= accum_iter
97 | loss_scaler(
98 | loss,
99 | optimizer,
100 | clip_grad=max_norm,
101 | parameters=model.parameters(),
102 | create_graph=False,
103 | update_grad=(data_iter_step + 1) % accum_iter == 0,
104 | )
105 | if (data_iter_step + 1) % accum_iter == 0:
106 | optimizer.zero_grad()
107 |
108 | torch.cuda.synchronize()
109 |
110 | metric_logger.update(loss=loss_value)
111 | metric_logger.update(cpu_mem=misc.cpu_mem_usage()[0])
112 | metric_logger.update(cpu_mem_all=misc.cpu_mem_usage()[1])
113 | metric_logger.update(gpu_mem=misc.gpu_mem_usage())
114 | min_lr = 10.0
115 | max_lr = 0.0
116 | for group in optimizer.param_groups:
117 | min_lr = min(min_lr, group["lr"])
118 | max_lr = max(max_lr, group["lr"])
119 |
120 | metric_logger.update(lr=max_lr)
121 |
122 | loss_value_reduce = misc.all_reduce_mean(loss_value)
123 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
124 | """We use epoch_1000x as the x-axis in tensorboard.
125 | This calibrates different curves when batch size changes.
126 | """
127 | epoch_1000x = int(
128 | (data_iter_step / len(data_loader) + epoch) * 1000 * args.repeat_aug
129 | )
130 | log_writer.add_scalar("loss", loss_value_reduce, epoch_1000x)
131 | log_writer.add_scalar("lr", max_lr, epoch_1000x)
132 |
133 | # gather the stats from all processes
134 | metric_logger.synchronize_between_processes()
135 | print("Averaged stats:", metric_logger)
136 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
137 |
138 |
139 | @torch.no_grad()
140 | def evaluate(data_loader, model, device):
141 | criterion = torch.nn.CrossEntropyLoss()
142 |
143 | metric_logger = misc.MetricLogger(delimiter=" ")
144 | header = "Test:"
145 |
146 | # switch to evaluation mode
147 | model.eval()
148 |
149 | for batch in metric_logger.log_every(data_loader, 10, header):
150 | images = batch[0]
151 | target = batch[-1]
152 | images = images.to(device, non_blocking=True)
153 | target = target.to(device, non_blocking=True)
154 |
155 | if len(images.shape) == 6:
156 | b, r, c, t, h, w = images.shape
157 | images = images.view(b * r, c, t, h, w)
158 | target = target.view(b * r)
159 |
160 | # compute output
161 | with torch.cuda.amp.autocast():
162 | output = model(images)
163 | loss = criterion(output, target)
164 |
165 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
166 |
167 | batch_size = images.shape[0]
168 | metric_logger.update(loss=loss.item())
169 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
170 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
171 | # gather the stats from all processes
172 | metric_logger.synchronize_between_processes()
173 | print(
174 | "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format(
175 | top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss
176 | )
177 | )
178 |
179 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
180 |
--------------------------------------------------------------------------------
/EchoFM/models_vit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # MAE: https://github.com/facebookresearch/mae
11 | # --------------------------------------------------------
12 |
13 | from functools import partial
14 |
15 | import torch
16 | import torch.nn as nn
17 | from EchoFM.util.logging import master_print as print
18 |
19 | from EchoFM.util.video_vit import Attention, Block, PatchEmbed
20 |
21 |
22 | class VisionTransformer(nn.Module):
23 | """Vision Transformer with support for global average pooling"""
24 |
25 | def __init__(
26 | self,
27 | num_frames,
28 | t_patch_size,
29 | img_size=224,
30 | patch_size=16,
31 | in_chans=3,
32 | num_classes=400,
33 | embed_dim=768,
34 | depth=12,
35 | num_heads=12,
36 | mlp_ratio=4.0,
37 | no_qkv_bias=False,
38 | qk_scale=None,
39 | drop_rate=0.0,
40 | attn_drop_rate=0.0,
41 | drop_path_rate=0.0,
42 | norm_layer=nn.LayerNorm,
43 | dropout=0.5,
44 | sep_pos_embed=False,
45 | cls_embed=False,
46 | **kwargs,
47 | ):
48 | super().__init__()
49 | print(locals())
50 |
51 | self.sep_pos_embed = sep_pos_embed
52 | # --------------------------------------------------------------------------
53 | # MAE encoder specifics
54 | self.patch_embed = PatchEmbed(
55 | img_size, patch_size, in_chans, embed_dim, num_frames, t_patch_size
56 | )
57 | num_patches = self.patch_embed.num_patches
58 | input_size = self.patch_embed.input_size
59 | self.input_size = input_size
60 | self.cls_embed = cls_embed
61 |
62 | if self.cls_embed:
63 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
64 |
65 | if sep_pos_embed:
66 | self.pos_embed_spatial = nn.Parameter(
67 | torch.zeros(1, input_size[1] * input_size[2], embed_dim)
68 | )
69 | self.pos_embed_temporal = nn.Parameter(
70 | torch.zeros(1, input_size[0], embed_dim)
71 | )
72 | if self.cls_embed:
73 | self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
74 | else:
75 | if self.cls_embed:
76 | _num_patches = num_patches + 1
77 | else:
78 | _num_patches = num_patches
79 |
80 | self.pos_embed = nn.Parameter(
81 | torch.zeros(1, _num_patches, embed_dim), requires_grad=True
82 | ) # fixed or not?
83 |
84 | dpr = [
85 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
86 | ] # stochastic depth decay rule
87 |
88 | self.blocks = nn.ModuleList(
89 | [
90 | Block(
91 | embed_dim,
92 | num_heads,
93 | mlp_ratio,
94 | qkv_bias=not no_qkv_bias,
95 | qk_scale=None,
96 | norm_layer=norm_layer,
97 | drop_path=dpr[i],
98 | attn_func=partial(
99 | Attention,
100 | input_size=self.patch_embed.input_size,
101 | ),
102 | )
103 | for i in range(depth)
104 | ]
105 | )
106 | self.norm = norm_layer(embed_dim)
107 | # --------------------------------------------------------------------------
108 |
109 | self.dropout = nn.Dropout(dropout)
110 | self.head = nn.Linear(embed_dim, num_classes)
111 |
112 | torch.nn.init.normal_(self.head.weight, std=0.02)
113 |
114 | @torch.jit.ignore
115 | def no_weight_decay(self):
116 | return {
117 | "cls_token",
118 | "pos_embed",
119 | "pos_embed_spatial",
120 | "pos_embed_temporal",
121 | "pos_embed_class",
122 | }
123 |
124 | def forward(self, x):
125 | # embed patches
126 | x = self.patch_embed(x)
127 | N, T, L, C = x.shape # T: temporal; L: spatial
128 |
129 | x = x.view([N, T * L, C])
130 |
131 | # append cls token
132 | if self.cls_embed:
133 | cls_token = self.cls_token
134 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
135 | x = torch.cat((cls_tokens, x), dim=1)
136 |
137 | if self.sep_pos_embed:
138 | pos_embed = self.pos_embed_spatial.repeat(
139 | 1, self.input_size[0], 1
140 | ) + torch.repeat_interleave(
141 | self.pos_embed_temporal,
142 | self.input_size[1] * self.input_size[2],
143 | dim=1,
144 | )
145 | if self.cls_embed:
146 | pos_embed = torch.cat(
147 | [
148 | self.pos_embed_class.expand(pos_embed.shape[0], -1, -1),
149 | pos_embed,
150 | ],
151 | 1,
152 | )
153 | else:
154 | pos_embed = self.pos_embed[:, :, :]
155 | x = x + pos_embed
156 |
157 | # reshape to [N, T, L, C] or [N, T*L, C]
158 | requires_t_shape = (
159 | len(self.blocks) > 0 # support empty decoder
160 | and hasattr(self.blocks[0].attn, "requires_t_shape")
161 | and self.blocks[0].attn.requires_t_shape
162 | )
163 | if requires_t_shape:
164 | x = x.view([N, T, L, C])
165 |
166 | # apply Transformer blocks
167 | for blk in self.blocks:
168 | x = blk(x)
169 |
170 | if requires_t_shape:
171 | x = x.view([N, T * L, C])
172 |
173 | # classifier
174 | x = x[:, 1:, :].mean(dim=1) # global pool
175 | x = self.norm(x)
176 | # x = self.fc_norm(x)
177 | x = self.dropout(x)
178 | x = self.head(x)
179 |
180 | return x
181 |
182 |
183 | def vit_base_patch16(**kwargs):
184 | model = VisionTransformer(
185 | patch_size=16,
186 | embed_dim=768,
187 | depth=12,
188 | num_heads=12,
189 | mlp_ratio=4,
190 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
191 | **kwargs,
192 | )
193 | return model
194 |
195 |
196 | def vit_large_patch16(**kwargs):
197 | model = VisionTransformer(
198 | patch_size=16,
199 | embed_dim=1024,
200 | depth=24,
201 | num_heads=16,
202 | mlp_ratio=4,
203 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
204 | **kwargs,
205 | )
206 | return model
207 |
208 |
209 | def vit_huge_patch14(**kwargs):
210 | model = VisionTransformer(
211 |         patch_size=14,
212 | embed_dim=1280,
213 | depth=32,
214 | num_heads=16,
215 | mlp_ratio=4,
216 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
217 | **kwargs,
218 | )
219 | return model
220 |
--------------------------------------------------------------------------------
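
A quick end-to-end shape check (not from the repository) for the classifier variant; `num_frames=32` and `t_patch_size=4` match defaults used elsewhere in the code, while `num_classes=2` is a placeholder.

```python
# Illustrative sketch: run a dummy clip through vit_base_patch16 and check
# the logits shape. num_classes=2 is a placeholder downstream task size.
import torch

from EchoFM.models_vit import vit_base_patch16

model = vit_base_patch16(num_frames=32, t_patch_size=4, num_classes=2)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 32, 224, 224))  # (B, C, T, H, W)
print(logits.shape)  # torch.Size([1, 2])
```
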
/EchoFM/util/decoder/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | """
6 | This implementation is based on
7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/mixup.py,
8 | published under an Apache License 2.0.
9 |
10 | COMMENT FROM ORIGINAL:
11 | Mixup and Cutmix
12 | Papers:
13 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
14 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) # NOQA
15 | Code Reference:
16 | CutMix: https://github.com/clovaai/CutMix-PyTorch
17 | Hacked together by / Copyright 2020 Ross Wightman
18 | """
19 |
20 | import numpy as np
21 | import torch
22 |
23 |
24 | def convert_to_one_hot(targets, num_classes, on_value=1.0, off_value=0.0):
25 | """
26 | This function converts target class indices to one-hot vectors, given the
27 | number of classes.
28 | Args:
29 | targets (loader): Class labels.
30 | num_classes (int): Total number of classes.
31 |         on_value (float): Target value for the ground truth class.
32 |         off_value (float): Target value for other classes. This value is used for
33 |             label smoothing.
34 | """
35 |
36 | targets = targets.long().view(-1, 1)
37 | return torch.full(
38 | (targets.size()[0], num_classes), off_value, device=targets.device
39 | ).scatter_(1, targets, on_value)
40 |
41 |
42 | def mixup_target(target, num_classes, lam=1.0, smoothing=0.0):
43 | """
44 | This function converts target class indices to one-hot vectors, given the
45 | number of classes.
46 | Args:
47 | targets (loader): Class labels.
48 | num_classes (int): Total number of classes.
49 |         lam (float): lambda value for mixup/cutmix.
50 | smoothing (float): Label smoothing value.
51 | """
52 | off_value = smoothing / num_classes
53 | on_value = 1.0 - smoothing + off_value
54 | target1 = convert_to_one_hot(
55 | target,
56 | num_classes,
57 | on_value=on_value,
58 | off_value=off_value,
59 | )
60 | target2 = convert_to_one_hot(
61 | target.flip(0),
62 | num_classes,
63 | on_value=on_value,
64 | off_value=off_value,
65 | )
66 | return target1 * lam + target2 * (1.0 - lam)
67 |
68 |
69 | def rand_bbox(img_shape, lam, margin=0.0, count=None):
70 | """
71 | Generates a random square bbox based on lambda value.
72 |
73 | Args:
74 | img_shape (tuple): Image shape as tuple
75 | lam (float): Cutmix lambda value
76 | margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
77 | count (int): Number of bbox to generate
78 | """
79 | ratio = np.sqrt(1 - lam)
80 | img_h, img_w = img_shape[-2:]
81 | cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
82 | margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
83 | cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
84 | cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
85 | yl = np.clip(cy - cut_h // 2, 0, img_h)
86 | yh = np.clip(cy + cut_h // 2, 0, img_h)
87 | xl = np.clip(cx - cut_w // 2, 0, img_w)
88 | xh = np.clip(cx + cut_w // 2, 0, img_w)
89 | return yl, yh, xl, xh
90 |
91 |
92 | def get_cutmix_bbox(img_shape, lam, correct_lam=True, count=None):
93 | """
94 | Generates the box coordinates for cutmix.
95 |
96 | Args:
97 | img_shape (tuple): Image shape as tuple
98 | lam (float): Cutmix lambda value
99 | correct_lam (bool): Apply lambda correction when cutmix bbox clipped by
100 | image borders.
101 | count (int): Number of bbox to generate
102 | """
103 |
104 | yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
105 | if correct_lam:
106 | bbox_area = (yu - yl) * (xu - xl)
107 | lam = 1.0 - bbox_area / float(img_shape[-2] * img_shape[-1])
108 | return (yl, yu, xl, xu), lam
109 |
110 |
111 | class MixUp:
112 | """
113 | Apply mixup and/or cutmix for videos at batch level.
114 | mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
115 | CutMix: Regularization Strategy to Train Strong Classifiers with Localizable
116 | Features (https://arxiv.org/abs/1905.04899)
117 | """
118 |
119 | def __init__(
120 | self,
121 | mixup_alpha=1.0,
122 | cutmix_alpha=0.0,
123 | mix_prob=1.0,
124 | switch_prob=0.5,
125 | correct_lam=True,
126 | label_smoothing=0.1,
127 | num_classes=1000,
128 | ):
129 | """
130 | Args:
131 | mixup_alpha (float): Mixup alpha value.
132 | cutmix_alpha (float): Cutmix alpha value.
133 | mix_prob (float): Probability of applying mixup or cutmix.
134 | switch_prob (float): Probability of switching to cutmix instead of
135 | mixup when both are active.
136 | correct_lam (bool): Apply lambda correction when cutmix bbox
137 | clipped by image borders.
138 | label_smoothing (float): Apply label smoothing to the mixed target
139 | tensor. If label_smoothing is not used, set it to 0.
140 | num_classes (int): Number of classes for target.
141 | """
142 | self.mixup_alpha = mixup_alpha
143 | self.cutmix_alpha = cutmix_alpha
144 | self.mix_prob = mix_prob
145 | self.switch_prob = switch_prob
146 | self.label_smoothing = label_smoothing
147 | self.num_classes = num_classes
148 | self.correct_lam = correct_lam
149 |
150 | def _get_mixup_params(self):
151 | lam = 1.0
152 | use_cutmix = False
153 | if np.random.rand() < self.mix_prob:
154 | if self.mixup_alpha > 0.0 and self.cutmix_alpha > 0.0:
155 | use_cutmix = np.random.rand() < self.switch_prob
156 | lam_mix = (
157 | np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
158 | if use_cutmix
159 | else np.random.beta(self.mixup_alpha, self.mixup_alpha)
160 | )
161 | elif self.mixup_alpha > 0.0:
162 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
163 | elif self.cutmix_alpha > 0.0:
164 | use_cutmix = True
165 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
166 | lam = float(lam_mix)
167 | return lam, use_cutmix
168 |
169 | def _mix_batch(self, x):
170 | lam, use_cutmix = self._get_mixup_params()
171 | if lam == 1.0:
172 | return 1.0
173 | if use_cutmix:
174 | (yl, yh, xl, xh), lam = get_cutmix_bbox(
175 | x.shape,
176 | lam,
177 | correct_lam=self.correct_lam,
178 | )
179 | x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
180 | else:
181 | x_flipped = x.flip(0).mul_(1.0 - lam)
182 | x.mul_(lam).add_(x_flipped)
183 | return lam
184 |
185 | def __call__(self, x, target):
186 | assert len(x) > 1, "Batch size should be greater than 1 for mixup."
187 | lam = self._mix_batch(x)
188 | target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
189 | return x, target
190 |
--------------------------------------------------------------------------------
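
A brief usage sketch (not from the repository) applying `MixUp` to a dummy video batch; the batch size and `num_classes` are placeholders.

```python
# Illustrative sketch: mix a dummy (B, C, T, H, W) batch and its labels.
# num_classes=2 and the batch size are placeholders.
import torch

from EchoFM.util.decoder.mixup import MixUp

mixup = MixUp(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1,
              num_classes=2)
clips = torch.randn(4, 3, 32, 224, 224)
labels = torch.randint(0, 2, (4,))
clips, soft_labels = mixup(clips, labels)
print(soft_labels.shape)  # torch.Size([4, 2]) soft one-hot targets
```
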
/EchoFM/util/decoder/random_erasing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | """
6 | This implementation is based on
7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
 8 | published under an Apache License 2.0.
9 |
10 | COMMENT FROM ORIGINAL:
11 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
12 | Copyright Zhun Zhong & Liang Zheng
13 | Hacked together by / Copyright 2020 Ross Wightman
14 | """
15 | import math
16 | import random
17 |
18 | import torch
19 |
20 |
21 | def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"):
22 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
23 | # paths, flip the order so normal is run on CPU if this becomes a problem
24 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
25 | if per_pixel:
26 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
27 | elif rand_color:
28 | return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
29 | else:
30 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
31 |
32 |
33 | class RandomErasing:
34 | """Randomly selects a rectangle region in an image and erases its pixels.
35 | 'Random Erasing Data Augmentation' by Zhong et al.
36 | See https://arxiv.org/pdf/1708.04896.pdf
37 | This variant of RandomErasing is intended to be applied to either a batch
38 | or single image tensor after it has been normalized by dataset mean and std.
39 | Args:
40 | probability: Probability that the Random Erasing operation will be performed.
41 | min_area: Minimum percentage of erased area wrt input image area.
42 | max_area: Maximum percentage of erased area wrt input image area.
43 | min_aspect: Minimum aspect ratio of erased area.
44 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
45 | 'const' - erase block is constant color of 0 for all channels
46 | 'rand' - erase block is same per-channel random (normal) color
47 | 'pixel' - erase block is per-pixel random (normal) color
48 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
49 | per-image count is randomly chosen between 1 and this value.
50 | """
51 |
52 | def __init__(
53 | self,
54 | probability=0.5,
55 | min_area=0.02,
56 | max_area=1 / 3,
57 | min_aspect=0.3,
58 | max_aspect=None,
59 | mode="const",
60 | min_count=1,
61 | max_count=None,
62 | num_splits=0,
63 | device="cuda",
64 | cube=True,
65 | ):
66 | self.probability = probability
67 | self.min_area = min_area
68 | self.max_area = max_area
69 | max_aspect = max_aspect or 1 / min_aspect
70 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
71 | self.min_count = min_count
72 | self.max_count = max_count or min_count
73 | self.num_splits = num_splits
74 | mode = mode.lower()
75 | self.rand_color = False
76 | self.per_pixel = False
77 | self.cube = cube
78 | if mode == "rand":
79 | self.rand_color = True # per block random normal
80 | elif mode == "pixel":
81 | self.per_pixel = True # per pixel random normal
82 | else:
83 | assert not mode or mode == "const"
84 | self.device = device
85 |
86 | def _erase(self, img, chan, img_h, img_w, dtype):
87 | if random.random() > self.probability:
88 | return
89 | area = img_h * img_w
90 | count = (
91 | self.min_count
92 | if self.min_count == self.max_count
93 | else random.randint(self.min_count, self.max_count)
94 | )
95 | for _ in range(count):
96 | for _ in range(10):
97 | target_area = (
98 | random.uniform(self.min_area, self.max_area) * area / count
99 | )
100 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
101 | h = int(round(math.sqrt(target_area * aspect_ratio)))
102 | w = int(round(math.sqrt(target_area / aspect_ratio)))
103 | if w < img_w and h < img_h:
104 | top = random.randint(0, img_h - h)
105 | left = random.randint(0, img_w - w)
106 | img[:, top : top + h, left : left + w] = _get_pixels(
107 | self.per_pixel,
108 | self.rand_color,
109 | (chan, h, w),
110 | dtype=dtype,
111 | device=self.device,
112 | )
113 | break
114 |
115 | def _erase_cube(
116 | self,
117 | img,
118 | batch_start,
119 | batch_size,
120 | chan,
121 | img_h,
122 | img_w,
123 | dtype,
124 | ):
125 | if random.random() > self.probability:
126 | return
127 | area = img_h * img_w
128 | count = (
129 | self.min_count
130 | if self.min_count == self.max_count
131 | else random.randint(self.min_count, self.max_count)
132 | )
133 | for _ in range(count):
134 | for _ in range(100):
135 | target_area = (
136 | random.uniform(self.min_area, self.max_area) * area / count
137 | )
138 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
139 | h = int(round(math.sqrt(target_area * aspect_ratio)))
140 | w = int(round(math.sqrt(target_area / aspect_ratio)))
141 | if w < img_w and h < img_h:
142 | top = random.randint(0, img_h - h)
143 | left = random.randint(0, img_w - w)
144 | for i in range(batch_start, batch_size):
145 | img_instance = img[i]
146 | img_instance[:, top : top + h, left : left + w] = _get_pixels(
147 | self.per_pixel,
148 | self.rand_color,
149 | (chan, h, w),
150 | dtype=dtype,
151 | device=self.device,
152 | )
153 | break
154 |
155 | def __call__(self, input):
156 | if len(input.size()) == 3:
157 | self._erase(input, *input.size(), input.dtype)
158 | else:
159 | batch_size, chan, img_h, img_w = input.size()
160 | # skip first slice of batch if num_splits is set (for clean portion of samples)
161 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
162 | if self.cube:
163 | self._erase_cube(
164 | input,
165 | batch_start,
166 | batch_size,
167 | chan,
168 | img_h,
169 | img_w,
170 | input.dtype,
171 | )
172 | else:
173 | for i in range(batch_start, batch_size):
174 | self._erase(input[i], chan, img_h, img_w, input.dtype)
175 | return input
176 |
--------------------------------------------------------------------------------
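
A minimal usage sketch for the RandomErasing class above; it assumes the repo root is on PYTHONPATH so EchoFM.util.decoder.random_erasing is importable, the clip shape is illustrative, and, per the docstring, the input is expected to be mean/std normalized already:

import torch
from EchoFM.util.decoder.random_erasing import RandomErasing

# device="cpu" overrides the CUDA default so the sketch runs anywhere.
erase = RandomErasing(probability=0.5, mode="pixel", device="cpu", cube=True)
clip = torch.randn(2, 3, 112, 112)   # (batch, channels, H, W), already normalized
clip = erase(clip)                   # random rectangles replaced with per-pixel noise, in place
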
/EchoFM/util/decoder/decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | import math
6 | import random
7 |
8 | import numpy as np
9 | import torch
10 | import torchvision.io as io
11 |
12 |
13 | def temporal_sampling(frames, start_idx, end_idx, num_samples):
14 | """
15 | Given the start and end frame index, sample num_samples frames between
16 | the start and end with equal interval.
17 | Args:
18 | frames (tensor): a tensor of video frames, dimension is
19 | `num video frames` x `channel` x `height` x `width`.
20 | start_idx (int): the index of the start frame.
21 | end_idx (int): the index of the end frame.
22 | num_samples (int): number of frames to sample.
23 | Returns:
24 | frames (tensor): a tensor of temporally sampled video frames, dimension is
25 | `num clip frames` x `channel` x `height` x `width`.
26 | """
27 | index = torch.linspace(start_idx, end_idx, num_samples)
28 | index = torch.clamp(index, 0, frames.shape[0] - 1).long()
29 | new_frames = torch.index_select(frames, 0, index)
30 | return new_frames
31 |
32 |
33 | def get_start_end_idx(video_size, clip_size, clip_idx, num_clips, use_offset=False):
34 | """
35 | Sample a clip of size clip_size from a video of size video_size and
36 | return the indices of the first and last frame of the clip. If clip_idx is
37 | -1, the clip is randomly sampled, otherwise uniformly split the video to
38 | num_clips clips, and select the start and end index of clip_idx-th video
39 | clip.
40 | Args:
41 | video_size (int): number of overall frames.
42 | clip_size (int): size of the clip to sample from the frames.
43 | clip_idx (int): if clip_idx is -1, perform random jitter sampling. If
44 | clip_idx is larger than -1, uniformly split the video to num_clips
45 | clips, and select the start and end index of the clip_idx-th video
46 | clip.
47 | num_clips (int): overall number of clips to uniformly sample from the
48 | given video for testing.
49 | Returns:
50 | start_idx (int): the start frame index.
51 | end_idx (int): the end frame index.
52 | """
53 | delta = max(video_size - clip_size, 0)
54 | if clip_idx == -1:
55 | # Random temporal sampling.
56 | start_idx = random.uniform(0, delta)
57 | else:
58 | if use_offset:
59 | if num_clips == 1:
60 | # Take the center clip if num_clips is 1.
61 | start_idx = math.floor(delta / 2)
62 | else:
63 | # Uniformly sample the clip with the given index.
64 | start_idx = clip_idx * math.floor(delta / (num_clips - 1))
65 | else:
66 | # Uniformly sample the clip with the given index.
67 | start_idx = delta * clip_idx / num_clips
68 | end_idx = start_idx + clip_size - 1
69 | return start_idx, end_idx
70 |
71 |
72 | def decode(
73 | container,
74 | sampling_rate,
75 | num_frames,
76 | clip_idx=-1,
77 | num_clips=10,
78 | video_meta=None,
79 | target_fps=30,
80 | max_spatial_scale=0,
81 | use_offset=False,
82 | rigid_decode_all_video=True,
83 | modalities=("visual",),
84 | ):
85 | """
86 | Decode the video and perform temporal sampling.
87 | Args:
88 | container (container): pyav container.
89 | sampling_rate (int): frame sampling rate (interval between two sampled
90 | frames).
91 | num_frames (int): number of frames to sample.
92 | clip_idx (int): if clip_idx is -1, perform random temporal
93 | sampling. If clip_idx is larger than -1, uniformly split the
94 | video to num_clips clips, and select the
95 | clip_idx-th video clip.
96 | num_clips (int): overall number of clips to uniformly
97 | sample from the given video.
98 | video_meta (dict): a dict that contains VideoMetaData. Details can be found
99 | at `pytorch/vision/torchvision/io/_video_opt.py`.
100 | target_fps (int): the input video may have different fps, convert it to
101 | the target video fps before frame sampling.
102 | max_spatial_scale (int): keep the aspect ratio and resize the frame so
103 | that shorter edge size is max_spatial_scale. Only used in
104 | `torchvision` backend.
105 | Returns:
106 | frames (tensor): decoded frames from the video.
107 | """
108 | try:
109 | assert clip_idx >= -1, "Not valid clip_idx {}".format(clip_idx)
110 | # Convert the bytes to a tensor.
111 | video_tensor = torch.from_numpy(np.frombuffer(container, dtype=np.uint8))
112 |
113 | decode_all_video = True
114 | video_start_pts, video_end_pts = 0, -1
115 | # The video_meta is empty, fetch the meta data from the raw video.
116 | if len(video_meta) == 0:
117 | # Tracking the meta info for selective decoding in the future.
118 | meta = io._probe_video_from_memory(video_tensor)
119 | # Using the information from video_meta to perform selective decoding.
120 | video_meta["video_timebase"] = meta.video_timebase
121 | video_meta["video_numerator"] = meta.video_timebase.numerator
122 | video_meta["video_denominator"] = meta.video_timebase.denominator
123 | video_meta["has_video"] = meta.has_video
124 | video_meta["video_duration"] = meta.video_duration
125 | video_meta["video_fps"] = meta.video_fps
126 | video_meta["audio_timebase"] = meta.audio_timebase
127 | video_meta["audio_numerator"] = meta.audio_timebase.numerator
128 | video_meta["audio_denominator"] = meta.audio_timebase.denominator
129 | video_meta["has_audio"] = meta.has_audio
130 | video_meta["audio_duration"] = meta.audio_duration
131 | video_meta["audio_sample_rate"] = meta.audio_sample_rate
132 |
133 | fps = video_meta["video_fps"]
134 | if not rigid_decode_all_video:
135 | if (
136 | video_meta["has_video"]
137 | and video_meta["video_denominator"] > 0
138 | and video_meta["video_duration"] > 0
139 | ):
140 | # try selective decoding.
141 | decode_all_video = False
142 | clip_size = sampling_rate * num_frames / target_fps * fps
143 | start_idx, end_idx = get_start_end_idx(
144 | fps * video_meta["video_duration"],
145 | clip_size,
146 | clip_idx,
147 | num_clips,
148 | use_offset=use_offset,
149 | )
150 | # Convert frame index to pts.
151 | pts_per_frame = video_meta["video_denominator"] / fps
152 | video_start_pts = int(start_idx * pts_per_frame)
153 | video_end_pts = int(end_idx * pts_per_frame)
154 |
155 | # Decode the raw video with the tv decoder.
156 | v_frames, _ = io._read_video_from_memory(
157 | video_tensor,
158 | seek_frame_margin=1.0,
159 | read_video_stream="visual" in modalities,
160 | video_width=0,
161 | video_height=0,
162 | video_min_dimension=max_spatial_scale,
163 | video_pts_range=(video_start_pts, video_end_pts),
164 | video_timebase_numerator=video_meta["video_numerator"],
165 | video_timebase_denominator=video_meta["video_denominator"],
166 | )
167 |
168 | if v_frames.shape == torch.Size([0]):
169 | # failed selective decoding
170 | decode_all_video = True
171 | video_start_pts, video_end_pts = 0, -1
172 | v_frames, _ = io._read_video_from_memory(
173 | video_tensor,
174 | seek_frame_margin=1.0,
175 | read_video_stream="visual" in modalities,
176 | video_width=0,
177 | video_height=0,
178 | video_min_dimension=max_spatial_scale,
179 | video_pts_range=(video_start_pts, video_end_pts),
180 | video_timebase_numerator=video_meta["video_numerator"],
181 | video_timebase_denominator=video_meta["video_denominator"],
182 | )
183 | except Exception as e:
184 | print("Failed to decode by torchvision with exception: {}".format(e))
185 | return None
186 |
187 | # Return None if the frames were not decoded successfully.
188 | if v_frames is None or v_frames.size(0) == 0:
189 | return None, fps, decode_all_video
190 | return v_frames, fps, decode_all_video
191 |
--------------------------------------------------------------------------------
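
A minimal sketch of how get_start_end_idx and temporal_sampling above combine to cut a fixed-length clip out of a decoded video; the shapes and the sampling_rate/num_frames values are assumptions, and the repo root is assumed to be on PYTHONPATH:

import torch
from EchoFM.util.decoder.decoder import get_start_end_idx, temporal_sampling

frames = torch.randn(120, 3, 224, 224)                 # decoded frames, (T, C, H, W)
clip_size = 4 * 16                                     # sampling_rate * num_frames
start, end = get_start_end_idx(frames.shape[0], clip_size, clip_idx=-1, num_clips=1)
clip = temporal_sampling(frames, start, end, num_samples=16)   # -> (16, 3, 224, 224)
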
/EchoFM/util/meters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | import numpy as np
6 | import torch
7 | from sklearn.metrics import average_precision_score
8 |
9 |
10 | def topks_correct(preds, labels, ks):
11 | """
12 | Given the predictions, labels, and a list of top-k values, compute the
13 | number of correct predictions for each top-k value.
14 |
15 | Args:
16 | preds (array): array of predictions. Dimension is batchsize
17 | N x ClassNum.
18 | labels (array): array of labels. Dimension is batchsize N.
19 | ks (list): list of top-k values. For example, ks = [1, 5] corresponds
20 | to top-1 and top-5.
21 |
22 | Returns:
23 | topks_correct (list): list of numbers, where the `i`-th entry
24 | corresponds to the number of top-`ks[i]` correct predictions.
25 | """
26 | assert preds.size(0) == labels.size(
27 | 0
28 | ), "Batch dim of predictions and labels must match"
29 | # Find the top max_k predictions for each sample
30 | _top_max_k_vals, top_max_k_inds = torch.topk(
31 | preds, max(ks), dim=1, largest=True, sorted=True
32 | )
33 | # (batch_size, max_k) -> (max_k, batch_size).
34 | top_max_k_inds = top_max_k_inds.t()
35 | # (batch_size, ) -> (max_k, batch_size).
36 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds)
37 | # (i, j) = 1 if top i-th prediction for the j-th sample is correct.
38 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels)
39 | # Compute the number of topk correct predictions for each k.
40 | topks_correct = [top_max_k_correct[:k, :].float().sum() for k in ks]
41 | return topks_correct
42 |
43 |
44 | def topk_errors(preds, labels, ks):
45 | """
46 | Computes the top-k error for each k.
47 | Args:
48 | preds (array): array of predictions. Dimension is N.
49 | labels (array): array of labels. Dimension is N.
50 | ks (list): list of ks to calculate the top accuracies.
51 | """
52 | num_topks_correct = topks_correct(preds, labels, ks)
53 | return [(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct]
54 |
55 |
56 | def topk_accuracies(preds, labels, ks):
57 | """
58 | Computes the top-k accuracy for each k.
59 | Args:
60 | preds (array): array of predictions. Dimension is N.
61 | labels (array): array of labels. Dimension is N.
62 | ks (list): list of ks to calculate the top accuracies.
63 | """
64 | num_topks_correct = topks_correct(preds, labels, ks)
65 | return [(x / preds.size(0)) * 100.0 for x in num_topks_correct]
66 |
67 |
68 | def get_map(preds, labels):
69 | """
70 | Compute mAP for multi-label case.
71 | Args:
72 | preds (numpy tensor): num_examples x num_classes.
73 | labels (numpy tensor): num_examples x num_classes.
74 | Returns:
75 | mean_ap (float): final mAP score.
76 | """
77 |
78 | print("Getting mAP for {} examples".format(preds.shape[0]))
79 |
80 | preds = preds[:, ~(np.all(labels == 0, axis=0))]
81 | labels = labels[:, ~(np.all(labels == 0, axis=0))]
82 | aps = [0]
83 | try:
84 | aps = average_precision_score(labels, preds, average=None)
85 | except ValueError:
86 | print(
87 | "Average precision requires a sufficient number of samples \
88 | in a batch which are missing in this sample."
89 | )
90 |
91 | mean_ap = np.mean(aps)
92 | return mean_ap
93 |
94 |
95 | class TestMeter:
96 | """
97 | Perform the multi-view ensemble for testing: each video with a unique index
98 | will be sampled with multiple clips, and the predictions of the clips will
99 | be aggregated to produce the final prediction for the video.
100 | The accuracy is calculated with the given ground truth labels.
101 | """
102 |
103 | def __init__(
104 | self,
105 | num_videos,
106 | num_clips,
107 | num_cls,
108 | overall_iters,
109 | multi_label=False,
110 | ensemble_method="sum",
111 | ):
112 | """
113 | Construct tensors to store the predictions and labels. Expect to get
114 | num_clips predictions from each video, and calculate the metrics on
115 | num_videos videos.
116 | Args:
117 | num_videos (int): number of videos to test.
118 | num_clips (int): number of clips sampled from each video for
119 | aggregating the final prediction for the video.
120 | num_cls (int): number of classes for each prediction.
121 | overall_iters (int): overall iterations for testing.
122 | multi_label (bool): if True, use map as the metric.
123 | ensemble_method (str): method to perform the ensemble, options
124 | include "sum", and "max".
125 | """
126 |
127 | self.num_clips = num_clips
128 | self.overall_iters = overall_iters
129 | self.multi_label = multi_label
130 | self.ensemble_method = ensemble_method
131 | # Initialize tensors.
132 | self.video_preds = torch.zeros((num_videos, num_cls))
133 | if multi_label:
134 | self.video_preds -= 1e10
135 |
136 | self.video_labels = (
137 | torch.zeros((num_videos, num_cls))
138 | if multi_label
139 | else torch.zeros((num_videos)).long()
140 | )
141 | self.clip_count = torch.zeros((num_videos)).long()
142 | self.topk_accs = []
143 | self.stats = {}
144 |
145 | # Reset metric.
146 | self.reset()
147 |
148 | def reset(self):
149 | """
150 | Reset the metric.
151 | """
152 | self.clip_count.zero_()
153 | self.video_preds.zero_()
154 | if self.multi_label:
155 | self.video_preds -= 1e10
156 | self.video_labels.zero_()
157 |
158 | def update_stats(self, preds, labels, clip_ids):
159 | """
160 | Collect the predictions from the current batch and perform on-the-fly
161 | summation as ensemble.
162 | Args:
163 | preds (tensor): predictions from the current batch. Dimension is
164 | N x C where N is the batch size and C is the channel size
165 | (num_cls).
166 | labels (tensor): the corresponding labels of the current batch.
167 | Dimension is N.
168 | clip_ids (tensor): clip indexes of the current batch, dimension is
169 | N.
170 | """
171 | for ind in range(preds.shape[0]):
172 | vid_id = int(clip_ids[ind]) // self.num_clips
173 | if self.video_labels[vid_id].sum() > 0:
174 | assert torch.equal(
175 | self.video_labels[vid_id].type(torch.FloatTensor),
176 | labels[ind].type(torch.FloatTensor),
177 | )
178 | self.video_labels[vid_id] = labels[ind]
179 | if self.ensemble_method == "sum":
180 | self.video_preds[vid_id] += preds[ind]
181 | elif self.ensemble_method == "max":
182 | self.video_preds[vid_id] = torch.max(
183 | self.video_preds[vid_id], preds[ind]
184 | )
185 | else:
186 | raise NotImplementedError(
187 | "Ensemble Method {} is not supported".format(self.ensemble_method)
188 | )
189 | self.clip_count[vid_id] += 1
190 |
191 | def log_iter_stats(self, cur_iter):
192 | """
193 | Log the stats.
194 | Args:
195 | cur_iter (int): the current iteration of testing.
196 | """
197 | stats = {
198 | "split": "test_iter",
199 | "cur_iter": "{}".format(cur_iter + 1),
200 | }
201 | print(stats)
202 |
203 | def finalize_metrics(self, ks=(1, 5)):
204 | """
205 | Calculate and log the final ensembled metrics.
206 | ks (tuple): list of top-k values for topk_accuracies. For example,
207 | ks = (1, 5) corresponds to top-1 and top-5 accuracy.
208 | """
209 | if not all(self.clip_count == self.num_clips):
210 | print(
211 | "clip count {} ~= num clips {}".format(
212 | ", ".join(
213 | [
214 | "{}: {}".format(i, k)
215 | for i, k in enumerate(self.clip_count.tolist())
216 | ]
217 | ),
218 | self.num_clips,
219 | )
220 | )
221 |
222 | self.stats = {"split": "test_final"}
223 | if self.multi_label:
224 | map = get_map(
225 | self.video_preds.cpu().numpy(), self.video_labels.cpu().numpy()
226 | )
227 | self.stats["map"] = map
228 | else:
229 | num_topks_correct = topks_correct(self.video_preds, self.video_labels, ks)
230 | topks = [(x / self.video_preds.size(0)) * 100.0 for x in num_topks_correct]
231 | assert len({len(ks), len(topks)}) == 1
232 | for k, topk in zip(ks, topks):
233 | self.stats["top{}_acc".format(k)] = "{:.{prec}f}".format(topk, prec=2)
234 | print(self.stats)
235 |
--------------------------------------------------------------------------------
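
A minimal sketch of the multi-view ensemble flow of TestMeter above, using two hypothetical videos with three clips each; the predictions are random placeholders and the repo root is assumed to be on PYTHONPATH:

import torch
from EchoFM.util.meters import TestMeter

meter = TestMeter(num_videos=2, num_clips=3, num_cls=5, overall_iters=6)
for clip_id in range(6):                               # 2 videos x 3 clips each
    preds = torch.softmax(torch.randn(1, 5), dim=1)    # one clip-level prediction
    label = torch.tensor([clip_id // 3])               # same label for every clip of a video
    meter.update_stats(preds, label, torch.tensor([clip_id]))
meter.finalize_metrics(ks=(1,))                        # prints the ensembled top-1 accuracy
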
/EchoFM/util/decoder/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | import logging
6 | import os
7 | import random
8 | import time
9 | from collections import defaultdict
10 |
11 | import cv2
12 | import numpy as np
13 | import torch
14 | from iopath.common.file_io import g_pathmgr as pathmgr
15 | from torch.utils.data.distributed import DistributedSampler
16 |
17 | from . import transform as transform
18 |
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def retry_load_images(image_paths, retry=10, backend="pytorch"):
23 | """
24 | This function is to load images with support of retrying for failed load.
25 |
26 | Args:
27 | image_paths (list): paths of images needed to be loaded.
28 | retry (int, optional): maximum time of loading retrying. Defaults to 10.
29 | backend (str): `pytorch` or `cv2`.
30 |
31 | Returns:
32 | imgs (list): list of loaded images.
33 | """
34 | for i in range(retry):
35 | imgs = []
36 | for image_path in image_paths:
37 | with pathmgr.open(image_path, "rb") as f:
38 | img_str = np.frombuffer(f.read(), np.uint8)
39 | img = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
40 | imgs.append(img)
41 |
42 | if all(img is not None for img in imgs):
43 | if backend == "pytorch":
44 | imgs = torch.as_tensor(np.stack(imgs))
45 | return imgs
46 | else:
47 | logger.warning("Reading failed. Will retry.")
48 | time.sleep(1.0)
49 | if i == retry - 1:
50 | raise Exception("Failed to load images {}".format(image_paths))
51 |
52 |
53 | def get_sequence(center_idx, half_len, sample_rate, num_frames):
54 | """
55 | Sample frames among the corresponding clip.
56 |
57 | Args:
58 | center_idx (int): center frame idx for current clip
59 | half_len (int): half of the clip length
60 | sample_rate (int): sampling rate for sampling frames inside of the clip
61 | num_frames (int): number of expected sampled frames
62 |
63 | Returns:
64 | seq (list): list of indexes of sampled frames in this clip.
65 | """
66 | seq = list(range(center_idx - half_len, center_idx + half_len, sample_rate))
67 |
68 | for seq_idx in range(len(seq)):
69 | if seq[seq_idx] < 0:
70 | seq[seq_idx] = 0
71 | elif seq[seq_idx] >= num_frames:
72 | seq[seq_idx] = num_frames - 1
73 | return seq
74 |
75 |
76 | def spatial_sampling(
77 | frames,
78 | spatial_idx=-1,
79 | min_scale=256,
80 | max_scale=320,
81 | crop_size=224,
82 | random_horizontal_flip=True,
83 | inverse_uniform_sampling=False,
84 | aspect_ratio=None,
85 | scale=None,
86 | motion_shift=False,
87 | ):
88 | """
89 | Perform spatial sampling on the given video frames. If spatial_idx is
90 | -1, perform random scale, random crop, and random flip on the given
91 | frames. If spatial_idx is 0, 1, or 2, perform spatial uniform sampling
92 | with the given spatial_idx.
93 | Args:
94 | frames (tensor): frames of images sampled from the video. The
95 | dimension is `num frames` x `height` x `width` x `channel`.
96 | spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
97 | or 2, perform left, center, right crop if width is larger than
98 | height, and perform top, center, bottom crop if height is larger
99 | than width.
100 | min_scale (int): the minimal size of scaling.
101 | max_scale (int): the maximal size of scaling.
102 | crop_size (int): the size of height and width used to crop the
103 | frames.
104 | inverse_uniform_sampling (bool): if True, sample uniformly in
105 | [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
106 | scale. If False, take a uniform sample from [min_scale,
107 | max_scale].
108 | aspect_ratio (list): Aspect ratio range for resizing.
109 | scale (list): Scale range for resizing.
110 | motion_shift (bool): Whether to apply motion shift for resizing.
111 | Returns:
112 | frames (tensor): spatially sampled frames.
113 | """
114 | assert spatial_idx in [-1, 0, 1, 2]
115 | if spatial_idx == -1:
116 | if aspect_ratio is None and scale is None:
117 | frames = transform.random_short_side_scale_jitter(
118 | images=frames,
119 | min_size=min_scale,
120 | max_size=max_scale,
121 | inverse_uniform_sampling=inverse_uniform_sampling,
122 | )
123 | frames = transform.random_crop(frames, crop_size)
124 | else:
125 | transform_func = (
126 | transform.random_resized_crop_with_shift
127 | if motion_shift
128 | else transform.random_resized_crop
129 | )
130 | frames = transform_func(
131 | images=frames,
132 | target_height=crop_size,
133 | target_width=crop_size,
134 | scale=scale,
135 | ratio=aspect_ratio,
136 | )
137 | if random_horizontal_flip:
138 | frames = transform.horizontal_flip(0.5, frames)
139 | else:
140 | # The testing is deterministic and no jitter should be performed.
141 | # min_scale, max_scale, and crop_size are expected to be the same.
142 | assert len({min_scale, max_scale}) == 1
143 | frames = transform.random_short_side_scale_jitter(frames, min_scale, max_scale)
144 | frames = transform.uniform_crop(frames, crop_size, spatial_idx)
145 | return frames
146 |
147 |
148 | def as_binary_vector(labels, num_classes):
149 | """
150 | Construct binary label vector given a list of label indices.
151 | Args:
152 | labels (list): The input label list.
153 | num_classes (int): Number of classes of the label vector.
154 | Returns:
155 | labels (numpy array): the resulting binary vector.
156 | """
157 | label_arr = np.zeros((num_classes,))
158 |
159 | for lbl in set(labels):
160 | label_arr[lbl] = 1.0
161 | return label_arr
162 |
163 |
164 | def aggregate_labels(label_list):
165 | """
166 | Join a list of label lists.
167 | Args:
168 | label_list (list): The input list of label lists.
169 | Returns:
170 | labels (list): The joint list of all lists in input.
171 | """
172 | all_labels = []
173 | for labels in label_list:
174 | for l in labels:
175 | all_labels.append(l)
176 | return list(set(all_labels))
177 |
178 |
179 | def convert_to_video_level_labels(labels):
180 | """
181 | Aggregate annotations from all frames of a video to form video-level labels.
182 | Args:
183 | labels (list): The input label list.
184 | Returns:
185 | labels (list): Same as input, but with each label replaced by
186 | a video-level one.
187 | """
188 | for video_id in range(len(labels)):
189 | video_level_labels = aggregate_labels(labels[video_id])
190 | for i in range(len(labels[video_id])):
191 | labels[video_id][i] = video_level_labels
192 | return labels
193 |
194 |
195 | def load_image_lists(frame_list_file, prefix="", return_list=False):
196 | """
197 | Load image paths and labels from a "frame list".
198 | Each line of the frame list contains:
199 | `original_vido_id video_id frame_id path labels`
200 | Args:
201 | frame_list_file (string): path to the frame list.
202 | prefix (str): the prefix for the path.
203 | return_list (bool): if True, return a list. If False, return a dict.
204 | Returns:
205 | image_paths (list or dict): list of list containing path to each frame.
206 | If return_list is False, then return in a dict form.
207 | labels (list or dict): list of list containing label of each frame.
208 | If return_list is False, then return in a dict form.
209 | """
210 | image_paths = defaultdict(list)
211 | labels = defaultdict(list)
212 | with pathmgr.open(frame_list_file, "r") as f:
213 | assert f.readline().startswith("original_vido_id")
214 | for line in f:
215 | row = line.split()
216 | # original_vido_id video_id frame_id path labels
217 | assert len(row) == 5
218 | video_name = row[0]
219 | if prefix == "":
220 | path = row[3]
221 | else:
222 | path = os.path.join(prefix, row[3])
223 | image_paths[video_name].append(path)
224 | frame_labels = row[-1].replace('"', "")
225 | if frame_labels != "":
226 | labels[video_name].append([int(x) for x in frame_labels.split(",")])
227 | else:
228 | labels[video_name].append([])
229 |
230 | if return_list:
231 | keys = image_paths.keys()
232 | image_paths = [image_paths[key] for key in keys]
233 | labels = [labels[key] for key in keys]
234 | return image_paths, labels
235 | return dict(image_paths), dict(labels)
236 |
237 |
238 | def tensor_normalize(tensor, mean, std):
239 | """
240 | Normalize a given tensor by subtracting the mean and dividing the std.
241 | Args:
242 | tensor (tensor): tensor to normalize.
243 | mean (tensor or list): mean value to subtract.
244 | std (tensor or list): std to divide.
245 | """
246 | if tensor.dtype == torch.uint8:
247 | tensor = tensor.float()
248 | tensor = tensor / 255.0
249 | if type(mean) == tuple:
250 | mean = torch.tensor(mean)
251 | if type(std) == tuple:
252 | std = torch.tensor(std)
253 | tensor = tensor - mean
254 | tensor = tensor / std
255 | return tensor
256 |
257 |
258 | def get_random_sampling_rate(long_cycle_sampling_rate, sampling_rate):
259 | """
260 | When multigrid training uses fewer frames, we randomly
261 | increase the sampling rate so that some clips cover the original span.
262 | """
263 | if long_cycle_sampling_rate > 0:
264 | assert long_cycle_sampling_rate >= sampling_rate
265 | return random.randint(sampling_rate, long_cycle_sampling_rate)
266 | else:
267 | return sampling_rate
268 |
269 |
270 | def revert_tensor_normalize(tensor, mean, std):
271 | """
272 | Revert normalization for a given tensor by multiplying by the std and adding the mean.
273 | Args:
274 | tensor (tensor): tensor to revert normalization.
275 | mean (tensor or list): mean value to add.
276 | std (tensor or list): std to multiply.
277 | """
278 | if type(mean) == list:
279 | mean = torch.tensor(mean)
280 | if type(std) == list:
281 | std = torch.tensor(std)
282 | tensor = tensor * std
283 | tensor = tensor + mean
284 | return tensor
285 |
286 |
287 | def create_sampler(dataset, shuffle, cfg):
288 | """
289 | Create sampler for the given dataset.
290 | Args:
291 | dataset (torch.utils.data.Dataset): the given dataset.
292 | shuffle (bool): set to ``True`` to have the data reshuffled
293 | at every epoch.
294 | cfg (CfgNode): configs. Details can be found in
295 | slowfast/config/defaults.py
296 | Returns:
297 | sampler (Sampler): the created sampler.
298 | """
299 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None
300 |
301 | return sampler
302 |
303 |
304 | def loader_worker_init_fn(dataset):
305 | """
306 | Create init function passed to pytorch data loader.
307 | Args:
308 | dataset (torch.utils.data.Dataset): the given dataset.
309 | """
310 | return None
311 |
--------------------------------------------------------------------------------
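
A minimal round-trip sketch for tensor_normalize / revert_tensor_normalize above; the mean/std values are placeholders rather than the statistics used by this repo, and importing the module also requires its cv2 and iopath dependencies:

import torch
from EchoFM.util.decoder.utils import tensor_normalize, revert_tensor_normalize

frames = torch.randint(0, 256, (16, 224, 224, 3), dtype=torch.uint8)   # raw uint8 frames
mean, std = (0.45, 0.45, 0.45), (0.225, 0.225, 0.225)                  # assumed statistics
norm = tensor_normalize(frames, mean, std)          # scales to [0, 1], then standardizes
back = revert_tensor_normalize(norm, list(mean), list(std))
assert torch.allclose(back, frames.float() / 255.0, atol=1e-6)          # round trip recovers the input
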
/main_pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 | import argparse
12 | import datetime
13 | import json
14 | import os
15 | import time
16 |
17 | # import mae_st.util.env
18 |
19 | import EchoFM.util.misc as misc
20 |
21 | import numpy as np
22 | # import timm
23 | import torch
24 | import torch.backends.cudnn as cudnn
25 | from iopath.common.file_io import g_pathmgr as pathmgr
26 | from EchoFM import models_mae
27 | from EchoFM.engine_pretrain import train_one_epoch
28 | from EchoFM.util.misc import NativeScalerWithGradNormCount as NativeScaler
29 |
30 | from torch.utils.tensorboard import SummaryWriter
31 | from data.dataset import EchoDataset_from_Video_mp4
32 | import torch.distributed as dist
33 |
34 |
35 | def get_args_parser():
36 | parser = argparse.ArgumentParser("MAE pre-training", add_help=False)
37 | parser.add_argument(
38 | "--batch_size",
39 | default=4,
40 | type=int,
41 | help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus",
42 | )
43 | parser.add_argument("--epochs", default=100, type=int)
44 | parser.add_argument(
45 | "--accum_iter",
46 | default=1,
47 | type=int,
48 | help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)",
49 | )
50 |
51 | # Model parameters
52 | parser.add_argument(
53 | "--model",
54 | default="mae_vit_large_patch16",
55 | type=str,
56 | metavar="MODEL",
57 | help="Name of model to train",
58 | )
59 |
60 | parser.add_argument("--input_size", default=224, type=int, help="images input size")
61 |
62 | parser.add_argument(
63 | "--mask_ratio",
64 | default=0.75,
65 | type=float,
66 | help="Masking ratio (percentage of removed patches).",
67 | )
68 |
69 | parser.add_argument(
70 | "--norm_pix_loss",
71 | action="store_true",
72 | help="Use (per-patch) normalized pixels as targets for computing loss",
73 | )
74 | parser.set_defaults(norm_pix_loss=False)
75 |
76 | # Optimizer parameters
77 | parser.add_argument(
78 | "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)"
79 | )
80 |
81 | parser.add_argument(
82 | "--lr",
83 | type=float,
84 | default=None,
85 | metavar="LR",
86 | help="learning rate (absolute lr)",
87 | )
88 | parser.add_argument(
89 | "--blr",
90 | type=float,
91 | default=1e-3,
92 | metavar="LR",
93 | help="base learning rate: absolute_lr = base_lr * total_batch_size / 256",
94 | )
95 | parser.add_argument(
96 | "--min_lr",
97 | type=float,
98 | default=0.0,
99 | metavar="LR",
100 | help="lower lr bound for cyclic schedulers that hit 0",
101 | )
102 |
103 | parser.add_argument(
104 | "--warmup_epochs", type=int, default=40, metavar="N", help="epochs to warmup LR"
105 | )
106 | parser.add_argument(
107 | "--path_to_data_dir",
108 | default="",
109 | help="path where to save, empty for no saving",
110 | )
111 | parser.add_argument(
112 | "--output_dir",
113 | default="./output_dir",
114 | )
115 | parser.add_argument(
116 | "--data_path",
117 | default="/raid/camca/sk1064/us/fullset/video/",
118 | help="path where to save, empty for no saving",
119 | )
120 | parser.add_argument(
121 | "--log_dir",
122 | default="",
123 | help="path where to tensorboard log",
124 | )
125 | parser.add_argument(
126 | "--device", default="cuda", help="device to use for training / testing"
127 | )
128 | parser.add_argument("--seed", default=0, type=int)
129 | parser.add_argument("--resume", default="", help="resume from checkpoint")
130 |
131 | parser.add_argument(
132 | "--start_epoch", default=0, type=int, metavar="N", help="start epoch"
133 | )
134 | parser.add_argument("--num_workers", default=10, type=int)
135 | parser.add_argument(
136 | "--pin_mem",
137 | action="store_true",
138 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
139 | )
140 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
141 | parser.set_defaults(pin_mem=True)
142 |
143 | # distributed training parameters
144 | parser.add_argument(
145 | "--world_size", default=1, type=int, help="number of distributed processes"
146 | )
147 | parser.add_argument("--local_rank", default=-1, type=int)
148 | parser.add_argument("--dist_on_itp", action="store_true")
149 | parser.add_argument("--no_env", action="store_true")
150 |
151 | # Video related configs
152 | parser.add_argument(
153 | "--dist_url", default="env://", help="url used to set up distributed training"
154 | )
155 |
156 | parser.add_argument("--decoder_embed_dim", default=512, type=int)
157 | parser.add_argument("--decoder_depth", default=8, type=int)
158 | parser.add_argument("--decoder_num_heads", default=16, type=int)
159 | parser.add_argument("--t_patch_size", default=4, type=int)
160 | parser.add_argument("--num_frames", default=32, type=int)
161 | parser.add_argument("--checkpoint_period", default=1, type=int)
162 | parser.add_argument("--sampling_rate", default=4, type=int)
163 | parser.add_argument("--distributed", default=True, type=bool, help="Enable distributed training")
164 | parser.add_argument("--repeat_aug", default=4, type=int)
165 | parser.add_argument(
166 | "--clip_grad",
167 | type=float,
168 | default=None,
169 | )
170 | parser.add_argument("--no_qkv_bias", action="store_true")
171 | parser.add_argument("--bias_wd", action="store_true")
172 | parser.add_argument("--num_checkpoint_del", default=20, type=int)
173 | parser.add_argument("--sep_pos_embed", action="store_true")
174 | parser.set_defaults(sep_pos_embed=True)
175 | parser.add_argument(
176 | "--trunc_init",
177 | action="store_true",
178 | )
179 | parser.add_argument(
180 | "--fp32",
181 | action="store_true",
182 | )
183 | parser.set_defaults(fp32=True)
184 | parser.add_argument(
185 | "--jitter_scales_relative",
186 | default=[0.5, 1.0],
187 | type=float,
188 | nargs="+",
189 | )
190 | parser.add_argument(
191 | "--jitter_aspect_relative",
192 | default=[0.75, 1.3333],
193 | type=float,
194 | nargs="+",
195 | )
196 | parser.add_argument(
197 | "--beta",
198 | default=None,
199 | type=float,
200 | nargs="+",
201 | )
202 | parser.add_argument(
203 | "--pred_t_dim",
204 | type=int,
205 | default=8,
206 | )
207 | parser.add_argument("--cls_embed", action="store_true")
208 | parser.set_defaults(cls_embed=True)
209 | return parser
210 |
211 |
212 | def main(args):
213 | misc.init_distributed_mode(args)
214 |
215 | print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
216 | print("{}".format(args).replace(", ", ",\n"))
217 |
218 | # Multi-GPU initialization
219 | # if args.distributed:
220 | # dist.init_process_group(backend="nccl", init_method=args.dist_url, rank=args.local_rank, world_size=args.world_size)
221 | # torch.cuda.set_device(args.local_rank)
222 | if args.distributed:
223 | if not dist.is_initialized():  # avoid re-initializing if a process group already exists
224 | dist.init_process_group(
225 | backend="nccl",
226 | init_method=args.dist_url,
227 | rank=args.local_rank,
228 | world_size=args.world_size
229 | )
230 | torch.cuda.set_device(args.local_rank)
231 |
232 | print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
233 | print("{}".format(args).replace(", ", ",\n"))
234 |
235 | device = torch.device(args.device)
236 |
237 | # fix the seed for reproducibility
238 | seed = args.seed + misc.get_rank()
239 | torch.manual_seed(seed)
240 | np.random.seed(seed)
241 |
242 | cudnn.benchmark = True
243 |
244 | dataset_train = EchoDataset_from_Video_mp4(args.data_path)
245 |
246 | if args.distributed:
247 | num_tasks = misc.get_world_size()
248 | global_rank = misc.get_rank()
249 | sampler_train = torch.utils.data.DistributedSampler(
250 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
251 | )
252 | print("Sampler_train = %s" % str(sampler_train))
253 | else:
254 | num_tasks = 1
255 | global_rank = 0
256 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
257 |
258 | if global_rank == 0 and args.log_dir is not None:
259 | try:
260 | pathmgr.mkdirs(args.log_dir)
261 | except Exception as _:
262 | pass
263 | log_writer = SummaryWriter(log_dir=args.log_dir)
264 | else:
265 | log_writer = None
266 |
267 | data_loader_train = torch.utils.data.DataLoader(
268 | dataset_train,
269 | sampler=sampler_train,
270 | batch_size=args.batch_size,
271 | num_workers=args.num_workers,
272 | pin_memory=args.pin_mem,
273 | drop_last=True,
274 | )
275 |
276 | # define the model
277 | model = models_mae.__dict__[args.model](
278 | **vars(args),
279 | )
280 |
281 | model.to(device)
282 |
283 | model_without_ddp = model
284 | print("Model = %s" % str(model_without_ddp))
285 |
286 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
287 |
288 | if args.lr is None: # only base_lr is specified
289 | args.lr = args.blr * eff_batch_size / 256
290 |
291 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
292 | print("actual lr: %.2e" % args.lr)
293 |
294 | print("accumulate grad iterations: %d" % args.accum_iter)
295 | print("effective batch size: %d" % eff_batch_size)
296 |
297 | if args.distributed:
298 | model = torch.nn.parallel.DistributedDataParallel(
299 | model,
300 | device_ids=[torch.cuda.current_device()],
301 | # find_unused_parameters=True,
302 | )
303 | model_without_ddp = model.module
304 |
305 | # following timm: set wd as 0 for bias and norm layers
306 | param_groups = misc.add_weight_decay(
307 | model_without_ddp,
308 | args.weight_decay,
309 | bias_wd=args.bias_wd,
310 | )
311 | if args.beta is None:
312 | beta = (0.9, 0.95)
313 | else:
314 | beta = args.beta
315 | optimizer = torch.optim.AdamW(
316 | param_groups,
317 | lr=args.lr,
318 | betas=beta,
319 | )
320 | loss_scaler = NativeScaler(fp32=args.fp32)
321 |
322 | misc.load_model(
323 | args=args,
324 | model_without_ddp=model_without_ddp,
325 | optimizer=optimizer,
326 | loss_scaler=loss_scaler,
327 | )
328 |
329 | checkpoint_path = ""
330 | print(f"Start training for {args.epochs} epochs")
331 | start_time = time.time()
332 | for epoch in range(args.start_epoch, args.epochs):
333 | if args.distributed:
334 | data_loader_train.sampler.set_epoch(epoch)
335 | train_stats = train_one_epoch(
336 | model,
337 | data_loader_train,
338 | optimizer,
339 | device,
340 | epoch,
341 | loss_scaler,
342 | log_writer=log_writer,
343 | args=args,
344 | fp32=args.fp32,
345 | )
346 | if args.output_dir and (
347 | epoch % args.checkpoint_period == 0 or epoch + 1 == args.epochs
348 | ):
349 | checkpoint_path = misc.save_model(
350 | args=args,
351 | model=model,
352 | model_without_ddp=model_without_ddp,
353 | optimizer=optimizer,
354 | loss_scaler=loss_scaler,
355 | epoch=epoch,
356 | )
357 |
358 | log_stats = {
359 | **{f"train_{k}": v for k, v in train_stats.items()},
360 | "epoch": epoch,
361 | }
362 |
363 | if args.output_dir and misc.is_main_process():
364 | if log_writer is not None:
365 | log_writer.flush()
366 | with pathmgr.open(
367 | f"{args.output_dir}/log.txt",
368 | "a",
369 | ) as f:
370 | f.write(json.dumps(log_stats) + "\n")
371 |
372 | total_time = time.time() - start_time
373 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
374 | print("Training time {}".format(total_time_str))
375 | print(torch.cuda.memory_allocated())
376 | return [checkpoint_path]
377 |
378 |
379 | def launch_one_thread(
380 | local_rank,
381 | shard_rank,
382 | num_gpus_per_node,
383 | num_shards,
384 | init_method,
385 | output_path,
386 | opts,
387 | stats_queue,
388 | ):
389 | print(opts)
390 | args = get_args_parser()
391 | args = args.parse_args(opts)
392 | args.rank = shard_rank * num_gpus_per_node + local_rank
393 | args.world_size = num_shards * num_gpus_per_node
394 | args.gpu = local_rank
395 | args.dist_url = init_method
396 | args.output_dir = output_path
397 | output = main(args)
398 | stats_queue.put(output)
--------------------------------------------------------------------------------
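
The learning-rate rule in main() above scales the base lr by the effective batch size; a small worked example with the script defaults (blr=1e-3, batch_size=4, accum_iter=1) on an assumed 8-GPU node:

blr, batch_size, accum_iter, world_size = 1e-3, 4, 1, 8
eff_batch_size = batch_size * accum_iter * world_size   # 4 * 1 * 8 = 32
lr = blr * eff_batch_size / 256                         # 1e-3 * 32 / 256 = 1.25e-4
print(eff_batch_size, lr)                               # 32 0.000125
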
/EchoFM/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import math
15 | import os
16 | import time
17 | from collections import defaultdict, deque, OrderedDict
18 |
19 | import EchoFM.util.logging as logging
20 | import psutil
21 | import torch
22 | import torch.distributed as dist
23 | # import torch.fb.rendezvous.zeus
24 | from iopath.common.file_io import g_pathmgr as pathmgr
25 | from EchoFM.util.logging import master_print as print
26 | from torch import inf
27 |
28 |
29 | logger = logging.get_logger(__name__)
30 |
31 |
32 | class SmoothedValue:
33 | """Track a series of values and provide access to smoothed values over a
34 | window or the global series average.
35 | """
36 |
37 | def __init__(self, window_size=20, fmt=None):
38 | if fmt is None:
39 | fmt = "{median:.4f} ({global_avg:.4f})"
40 | self.deque = deque(maxlen=window_size)
41 | self.total = 0.0
42 | self.count = 0
43 | self.fmt = fmt
44 |
45 | def update(self, value, n=1):
46 | self.deque.append(value)
47 | self.count += n
48 | self.total += value * n
49 |
50 | def synchronize_between_processes(self):
51 | """
52 | Warning: does not synchronize the deque!
53 | """
54 | if not is_dist_avail_and_initialized():
55 | return
56 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
57 | dist.barrier()
58 | dist.all_reduce(t)
59 | t = t.tolist()
60 | self.count = int(t[0])
61 | self.total = t[1]
62 |
63 | @property
64 | def median(self):
65 | d = torch.tensor(list(self.deque))
66 | return d.median().item()
67 |
68 | @property
69 | def avg(self):
70 | d = torch.tensor(list(self.deque), dtype=torch.float32)
71 | return d.mean().item()
72 |
73 | @property
74 | def global_avg(self):
75 | return self.total / self.count
76 |
77 | @property
78 | def max(self):
79 | return max(self.deque)
80 |
81 | @property
82 | def value(self):
83 | return self.deque[-1]
84 |
85 | def __str__(self):
86 | return self.fmt.format(
87 | median=self.median,
88 | avg=self.avg,
89 | global_avg=self.global_avg,
90 | max=self.max,
91 | value=self.value,
92 | )
93 |
94 |
95 | class MetricLogger:
96 | def __init__(self, delimiter="\t"):
97 | self.meters = defaultdict(SmoothedValue)
98 | self.delimiter = delimiter
99 |
100 | def update(self, **kwargs):
101 | for k, v in kwargs.items():
102 | if v is None:
103 | continue
104 | if isinstance(v, torch.Tensor):
105 | v = v.item()
106 | assert isinstance(v, (float, int))
107 | self.meters[k].update(v)
108 |
109 | def __getattr__(self, attr):
110 | if attr in self.meters:
111 | return self.meters[attr]
112 | if attr in self.__dict__:
113 | return self.__dict__[attr]
114 | raise AttributeError(
115 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
116 | )
117 |
118 | def __str__(self):
119 | loss_str = []
120 | for name, meter in self.meters.items():
121 | loss_str.append("{}: {}".format(name, str(meter)))
122 | return self.delimiter.join(loss_str)
123 |
124 | def synchronize_between_processes(self):
125 | for meter in self.meters.values():
126 | meter.synchronize_between_processes()
127 |
128 | def add_meter(self, name, meter):
129 | self.meters[name] = meter
130 |
131 | def log_every(self, iterable, print_freq, header=None):
132 | i = 0
133 | if not header:
134 | header = ""
135 | start_time = time.time()
136 | end = time.time()
137 | iter_time = SmoothedValue(fmt="{avg:.4f}")
138 | data_time = SmoothedValue(fmt="{avg:.4f}")
139 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
140 | log_msg = [
141 | header,
142 | "[{0" + space_fmt + "}/{1}]",
143 | "eta: {eta}",
144 | "{meters}",
145 | "time: {time}",
146 | "data: {data}",
147 | ]
148 | if torch.cuda.is_available():
149 | log_msg.append("max mem: {memory:.0f}")
150 | log_msg = self.delimiter.join(log_msg)
151 | MB = 1024.0 * 1024.0
152 | for obj in iterable:
153 | data_time.update(time.time() - end)
154 | yield obj
155 | iter_time.update(time.time() - end)
156 | if i % print_freq == 0 or i == len(iterable) - 1:
157 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
158 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
159 | if torch.cuda.is_available():
160 | print(
161 | log_msg.format(
162 | i,
163 | len(iterable),
164 | eta=eta_string,
165 | meters=str(self),
166 | time=str(iter_time),
167 | data=str(data_time),
168 | memory=torch.cuda.max_memory_allocated() / MB,
169 | )
170 | )
171 |
172 | else:
173 | print(
174 | log_msg.format(
175 | i,
176 | len(iterable),
177 | eta=eta_string,
178 | meters=str(self),
179 | time=str(iter_time),
180 | data=str(data_time),
181 | )
182 | )
183 | i += 1
184 | end = time.time()
185 | total_time = time.time() - start_time
186 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
187 | print(
188 | "{} Total time: {} ({:.4f} s / it)".format(
189 | header, total_time_str, total_time / len(iterable)
190 | )
191 | )
192 |
193 |
194 | def setup_for_distributed(is_master):
195 | """
196 | This function disables printing when not in master process
197 | """
198 | builtin_print = builtins.print
199 |
200 | def print(*args, **kwargs):
201 | force = kwargs.pop("force", False)
202 | force = force or (get_world_size() > 8)
203 | if is_master or force:
204 | now = datetime.datetime.now().time()
205 | builtin_print("[{}] ".format(now), end="") # print with time stamp
206 | builtin_print(*args, **kwargs)
207 |
208 | builtins.print = print
209 |
210 |
211 | def is_dist_avail_and_initialized():
212 | if not dist.is_available():
213 | return False
214 | if not dist.is_initialized():
215 | return False
216 | return True
217 |
218 |
219 | def get_world_size():
220 | if not is_dist_avail_and_initialized():
221 | return 1
222 | return dist.get_world_size()
223 |
224 |
225 | def get_rank():
226 | if not is_dist_avail_and_initialized():
227 | return 0
228 | return dist.get_rank()
229 |
230 |
231 | def is_main_process():
232 | return get_rank() == 0
233 |
234 |
235 | def save_on_master(state, path):
236 | if is_main_process():
237 | print(f"save path {path}")
238 | with pathmgr.open(path, "wb") as f:
239 | torch.save(state, f)
240 |
241 |
242 | def init_distributed_mode(args):
243 | if args.no_env:
244 | pass
245 | elif args.dist_on_itp:
246 | args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
247 | args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
248 | args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
249 | args.dist_url = "tcp://%s:%s" % (
250 | os.environ["MASTER_ADDR"],
251 | os.environ["MASTER_PORT"],
252 | )
253 | os.environ["LOCAL_RANK"] = str(args.gpu)
254 | os.environ["RANK"] = str(args.rank)
255 | os.environ["WORLD_SIZE"] = str(args.world_size)
256 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
257 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
258 | args.rank = int(os.environ["RANK"])
259 | args.world_size = int(os.environ["WORLD_SIZE"])
260 | args.gpu = int(os.environ["LOCAL_RANK"])
261 | elif "SLURM_PROCID" in os.environ:
262 | args.rank = int(os.environ["SLURM_PROCID"])
263 | args.gpu = args.rank % torch.cuda.device_count()
264 | else:
265 | print("Not using distributed mode")
266 | setup_for_distributed(is_master=True) # hack
267 | args.distributed = False
268 | return
269 |
270 | args.distributed = True
271 |
272 | torch.cuda.set_device(args.gpu)
273 | args.dist_backend = "nccl"
274 | print(
275 | "| distributed init (rank {}): {}, gpu {}".format(
276 | args.rank, args.dist_url, args.gpu
277 | ),
278 | # flush=True,
279 | )
280 | torch.distributed.init_process_group(
281 | backend=args.dist_backend,
282 | init_method=args.dist_url,
283 | world_size=args.world_size,
284 | rank=args.rank,
285 | )
286 | torch.distributed.barrier()
287 | setup_for_distributed(args.rank == 0)
288 |
289 |
290 | class NativeScalerWithGradNormCount:
291 | state_dict_key = "amp_scaler"
292 |
293 | def __init__(self, fp32=False):
294 | self._scaler = torch.cuda.amp.GradScaler(enabled=not fp32)
295 |
296 | def __call__(
297 | self,
298 | loss,
299 | optimizer,
300 | clip_grad=None,
301 | parameters=None,
302 | create_graph=False,
303 | update_grad=True,
304 | ):
305 | self._scaler.scale(loss).backward(create_graph=create_graph)
306 | if update_grad:
307 | if clip_grad is not None:
308 | assert parameters is not None
309 | self._scaler.unscale_(
310 | optimizer
311 | ) # unscale the gradients of optimizer's assigned params in-place
312 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
313 | else:
314 | self._scaler.unscale_(optimizer)
315 | norm = get_grad_norm_(parameters)
316 | self._scaler.step(optimizer)
317 | self._scaler.update()
318 | else:
319 | norm = None
320 | return norm
321 |
322 | def state_dict(self):
323 | return self._scaler.state_dict()
324 |
325 | def load_state_dict(self, state_dict):
326 | self._scaler.load_state_dict(state_dict)
327 |
328 |
329 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
330 | if isinstance(parameters, torch.Tensor):
331 | parameters = [parameters]
332 | parameters = [p for p in parameters if p.grad is not None]
333 | norm_type = float(norm_type)
334 | if len(parameters) == 0:
335 | return torch.tensor(0.0)
336 | device = parameters[0].grad.device
337 | if norm_type == inf:
338 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
339 | else:
340 | total_norm = torch.norm(
341 | torch.stack(
342 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
343 | ),
344 | norm_type,
345 | )
346 | return total_norm
347 |
348 |
349 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
350 | checkpoint_path = "{}/checkpoint-{:05d}.pth".format(args.output_dir, epoch)
351 | to_save = {
352 | "model": model_without_ddp.state_dict(),
353 | "optimizer": optimizer.state_dict(),
354 | "epoch": epoch,
355 | "scaler": loss_scaler.state_dict(),
356 | "args": args,
357 | }
358 |
359 | save_on_master(to_save, checkpoint_path)
360 | return checkpoint_path
361 |
362 |
363 | def get_last_checkpoint(args):
364 | """
365 | Get the last checkpoint from the checkpointing folder.
366 | Args:
367 | path_to_job (string): the path to the folder of the current job.
368 | """
369 | d = args.output_dir
370 | names = pathmgr.ls(d) if pathmgr.exists(d) else []
371 | names = [f for f in names if "checkpoint" in f]
372 | if len(names) == 0:
373 | print("No checkpoints found in '{}'.".format(d))
374 | return None
375 | else:
376 | # Sort the checkpoints by epoch.
377 | name = sorted(names)[-1]
378 | return os.path.join(d, name)
379 |
380 |
381 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
382 | if not args.resume:
383 | args.resume = get_last_checkpoint(args)
384 | if args.resume:
385 | if args.resume.startswith("https"):
386 | checkpoint = torch.hub.load_state_dict_from_url(
387 | args.resume, map_location="cpu", check_hash=True
388 | )
389 | else:
390 | with pathmgr.open(args.resume, "rb") as f:
391 | checkpoint = torch.load(f, map_location="cpu")
392 | model_without_ddp.load_state_dict(checkpoint["model"])
393 | print("Resume checkpoint %s" % args.resume)
394 | if (
395 | "optimizer" in checkpoint
396 | and "epoch" in checkpoint
397 | and not (hasattr(args, "eval") and args.eval)
398 | ):
399 | optimizer.load_state_dict(checkpoint["optimizer"])
400 | args.start_epoch = checkpoint["epoch"] + 1
401 | if "scaler" in checkpoint:
402 | loss_scaler.load_state_dict(checkpoint["scaler"])
403 | print("With optim & sched!")
404 |
405 |
406 | def all_reduce_mean(x):
407 | world_size = get_world_size()
408 | if world_size > 1:
409 | x_reduce = torch.tensor(x).cuda()
410 | dist.all_reduce(x_reduce)
411 | x_reduce /= world_size
412 | return x_reduce.item()
413 | else:
414 | return x
415 |
416 |
417 | def gpu_mem_usage():
418 | """
419 | Compute the GPU memory usage for the current device (GB).
420 | """
421 | if torch.cuda.is_available():
422 | mem_usage_bytes = torch.cuda.max_memory_allocated()
423 | else:
424 | mem_usage_bytes = 0
425 | return mem_usage_bytes / 1024**3
426 |
427 |
428 | def cpu_mem_usage():
429 | """
430 | Compute the system memory (RAM) usage for the current device (GB).
431 | Returns:
432 | usage (float): used memory (GB).
433 | total (float): total memory (GB).
434 | """
435 | vram = psutil.virtual_memory()
436 | usage = (vram.total - vram.available) / 1024**3
437 | total = vram.total / 1024**3
438 |
439 | return usage, total
440 |
441 |
442 | def all_gather(tensors):
443 | """
444 | All gathers the provided tensors from all processes across machines.
445 | Args:
446 | tensors (list): tensors to perform all gather across all processes in
447 | all machines.
448 | """
449 |
450 | gather_list = []
451 | output_tensor = []
452 | world_size = dist.get_world_size()
453 | for tensor in tensors:
454 | tensor_placeholder = [torch.ones_like(tensor) for _ in range(world_size)]
455 | dist.all_gather(tensor_placeholder, tensor, async_op=False)
456 | gather_list.append(tensor_placeholder)
457 | for gathered_tensor in gather_list:
458 | output_tensor.append(torch.cat(gathered_tensor, dim=0))
459 | return output_tensor
460 |
461 |
462 | def add_weight_decay(model, weight_decay=1e-5, skip_list=(), bias_wd=False):
463 | decay = []
464 | no_decay = []
465 | for name, param in model.named_parameters():
466 | if not param.requires_grad:
467 | continue # frozen weights
468 | if (
469 | (not bias_wd)
470 | and len(param.shape) == 1
471 | or name.endswith(".bias")
472 | or name in skip_list
473 | ):
474 | no_decay.append(param)
475 | else:
476 | decay.append(param)
477 | return [
478 | {"params": no_decay, "weight_decay": 0.0},
479 | {"params": decay, "weight_decay": weight_decay},
480 | ]
481 |
482 |
483 | def inflate(model_2d, model_3d):
484 | state_dict_inflated = OrderedDict()
485 | for k, v2d in model_2d.items():
486 | if "patch_embed.proj.weight" in k:
487 | v3d = model_3d[k]
488 | v3d = v2d.unsqueeze(2).repeat(1, 1, v3d.shape[2], 1, 1) / v3d.shape[2]
489 | state_dict_inflated[k] = v3d.clone()
490 | elif "pos_embed" in k:
491 | pos_embed_cls, pos_embed_spatial = torch.split(v2d, [1, 196], dim=1)
492 | state_dict_inflated["pos_embed_cls"] = pos_embed_cls.clone()
493 | state_dict_inflated["pos_embed"] = pos_embed_spatial.clone()
494 | else:
495 | state_dict_inflated[k] = v2d.clone()
496 | return state_dict_inflated
497 |
498 |
499 | def convert_checkpoint(model_2d):
500 | state_dict_inflated = OrderedDict()
501 | for k, v2d in model_2d.items():
502 | if "head.projection.weight" in k:
503 | state_dict_inflated["head.weight"] = v2d.clone()
504 | elif "head.projection.bias" in k:
505 | state_dict_inflated["head.bias"] = v2d.clone()
506 | else:
507 | state_dict_inflated[k] = v2d.clone()
508 | return state_dict_inflated
509 |
--------------------------------------------------------------------------------
/EchoFM/util/decoder/rand_augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 |
5 | """
6 | This implementation is based on
7 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py
8 | published under an Apache 2.0 license.
9 |
10 | COMMENT FROM ORIGINAL:
11 | AutoAugment, RandAugment, and AugMix for PyTorch
12 | This code implements the searched ImageNet policies with various tweaks and
13 | improvements and does not include any of the search code. AA and RA
14 | Implementation adapted from:
15 | https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
16 | AugMix adapted from:
17 | https://github.com/google-research/augmix
18 | Papers:
19 | AutoAugment: Learning Augmentation Policies from Data
20 | https://arxiv.org/abs/1805.09501
21 | Learning Data Augmentation Strategies for Object Detection
22 | https://arxiv.org/abs/1906.11172
23 | RandAugment: Practical automated data augmentation...
24 | https://arxiv.org/abs/1909.13719
25 | AugMix: A Simple Data Processing Method to Improve Robustness and
26 | Uncertainty https://arxiv.org/abs/1912.02781
27 |
28 | Hacked together by / Copyright 2020 Ross Wightman
29 | """
30 |
31 | import math
32 | import random
33 | import re
34 |
35 | import numpy as np
36 | import PIL
37 | from PIL import Image, ImageEnhance, ImageOps
38 |
39 | _PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]])
40 |
41 | _FILL = (128, 128, 128)
42 |
43 | # This signifies the max integer that the controller RNN could predict for the
44 | # augmentation scheme.
45 | _MAX_LEVEL = 10.0
46 |
47 | _HPARAMS_DEFAULT = {
48 | "translate_const": 250,
49 | "img_mean": _FILL,
50 | }
51 |
52 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
53 |
54 |
55 | def _interpolation(kwargs):
56 | interpolation = kwargs.pop("resample", Image.BILINEAR)
57 | if isinstance(interpolation, (list, tuple)):
58 | return random.choice(interpolation)
59 | else:
60 | return interpolation
61 |
62 |
63 | def _check_args_tf(kwargs):
64 | if "fillcolor" in kwargs and _PIL_VER < (5, 0):
65 | kwargs.pop("fillcolor")
66 | kwargs["resample"] = _interpolation(kwargs)
67 |
68 |
69 | def shear_x(img, factor, **kwargs):
70 | _check_args_tf(kwargs)
71 | return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
72 |
73 |
74 | def shear_y(img, factor, **kwargs):
75 | _check_args_tf(kwargs)
76 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
77 |
78 |
79 | def translate_x_rel(img, pct, **kwargs):
80 | pixels = pct * img.size[0]
81 | _check_args_tf(kwargs)
82 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
83 |
84 |
85 | def translate_y_rel(img, pct, **kwargs):
86 | pixels = pct * img.size[1]
87 | _check_args_tf(kwargs)
88 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
89 |
90 |
91 | def translate_x_abs(img, pixels, **kwargs):
92 | _check_args_tf(kwargs)
93 | return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
94 |
95 |
96 | def translate_y_abs(img, pixels, **kwargs):
97 | _check_args_tf(kwargs)
98 | return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
99 |
100 |
101 | def rotate(img, degrees, **kwargs):
102 | _check_args_tf(kwargs)
103 | if _PIL_VER >= (5, 2):
104 | return img.rotate(degrees, **kwargs)
105 | elif _PIL_VER >= (5, 0):
106 | w, h = img.size
107 | post_trans = (0, 0)
108 | rotn_center = (w / 2.0, h / 2.0)
109 | angle = -math.radians(degrees)
110 | matrix = [
111 | round(math.cos(angle), 15),
112 | round(math.sin(angle), 15),
113 | 0.0,
114 | round(-math.sin(angle), 15),
115 | round(math.cos(angle), 15),
116 | 0.0,
117 | ]
118 |
119 | def transform(x, y, matrix):
120 | (a, b, c, d, e, f) = matrix
121 | return a * x + b * y + c, d * x + e * y + f
122 |
123 | matrix[2], matrix[5] = transform(
124 | -rotn_center[0] - post_trans[0],
125 | -rotn_center[1] - post_trans[1],
126 | matrix,
127 | )
128 | matrix[2] += rotn_center[0]
129 | matrix[5] += rotn_center[1]
130 | return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
131 | else:
132 | return img.rotate(degrees, resample=kwargs["resample"])
133 |
134 |
135 | def auto_contrast(img, **__):
136 | return ImageOps.autocontrast(img)
137 |
138 |
139 | def invert(img, **__):
140 | return ImageOps.invert(img)
141 |
142 |
143 | def equalize(img, **__):
144 | return ImageOps.equalize(img)
145 |
146 |
147 | def solarize(img, thresh, **__):
148 | return ImageOps.solarize(img, thresh)
149 |
150 |
151 | def solarize_add(img, add, thresh=128, **__):
152 | lut = []
153 | for i in range(256):
154 | if i < thresh:
155 | lut.append(min(255, i + add))
156 | else:
157 | lut.append(i)
158 | if img.mode in ("L", "RGB"):
159 | if img.mode == "RGB" and len(lut) == 256:
160 | lut = lut + lut + lut
161 | return img.point(lut)
162 | else:
163 | return img
164 |
165 |
166 | def posterize(img, bits_to_keep, **__):
167 | if bits_to_keep >= 8:
168 | return img
169 | return ImageOps.posterize(img, bits_to_keep)
170 |
171 |
172 | def contrast(img, factor, **__):
173 | return ImageEnhance.Contrast(img).enhance(factor)
174 |
175 |
176 | def color(img, factor, **__):
177 | return ImageEnhance.Color(img).enhance(factor)
178 |
179 |
180 | def brightness(img, factor, **__):
181 | return ImageEnhance.Brightness(img).enhance(factor)
182 |
183 |
184 | def sharpness(img, factor, **__):
185 | return ImageEnhance.Sharpness(img).enhance(factor)
186 |
187 |
188 | def _randomly_negate(v):
189 | """With 50% prob, negate the value"""
190 | return -v if random.random() > 0.5 else v
191 |
192 |
193 | def _rotate_level_to_arg(level, _hparams):
194 | # range [-30, 30]
195 | level = (level / _MAX_LEVEL) * 30.0
196 | level = _randomly_negate(level)
197 | return (level,)
198 |
199 |
200 | def _enhance_level_to_arg(level, _hparams):
201 | # range [0.1, 1.9]
202 | return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
203 |
204 |
205 | def _enhance_increasing_level_to_arg(level, _hparams):
206 | # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
207 | # range [0.1, 1.9]
208 | level = (level / _MAX_LEVEL) * 0.9
209 | level = 1.0 + _randomly_negate(level)
210 | return (level,)
211 |
212 |
213 | def _shear_level_to_arg(level, _hparams):
214 | # range [-0.3, 0.3]
215 | level = (level / _MAX_LEVEL) * 0.3
216 | level = _randomly_negate(level)
217 | return (level,)
218 |
219 |
220 | def _translate_abs_level_to_arg(level, hparams):
221 | translate_const = hparams["translate_const"]
222 | level = (level / _MAX_LEVEL) * float(translate_const)
223 | level = _randomly_negate(level)
224 | return (level,)
225 |
226 |
227 | def _translate_rel_level_to_arg(level, hparams):
228 | # default range [-0.45, 0.45]
229 | translate_pct = hparams.get("translate_pct", 0.45)
230 | level = (level / _MAX_LEVEL) * translate_pct
231 | level = _randomly_negate(level)
232 | return (level,)
233 |
234 |
235 | def _posterize_level_to_arg(level, _hparams):
236 | # As per Tensorflow TPU EfficientNet impl
237 | # range [0, 4], 'keep 0 up to 4 MSB of original image'
238 | # intensity/severity of augmentation decreases with level
239 | return (int((level / _MAX_LEVEL) * 4),)
240 |
241 |
242 | def _posterize_increasing_level_to_arg(level, hparams):
243 | # As per Tensorflow models research and UDA impl
244 | # range [4, 0], 'keep 4 down to 0 MSB of original image',
245 | # intensity/severity of augmentation increases with level
246 | return (4 - _posterize_level_to_arg(level, hparams)[0],)
247 |
248 |
249 | def _posterize_original_level_to_arg(level, _hparams):
250 | # As per original AutoAugment paper description
251 | # range [4, 8], 'keep 4 up to 8 MSB of image'
252 | # intensity/severity of augmentation decreases with level
253 | return (int((level / _MAX_LEVEL) * 4) + 4,)
254 |
255 |
256 | def _solarize_level_to_arg(level, _hparams):
257 | # range [0, 256]
258 | # intensity/severity of augmentation decreases with level
259 | return (int((level / _MAX_LEVEL) * 256),)
260 |
261 |
262 | def _solarize_increasing_level_to_arg(level, _hparams):
263 | # range [0, 256]
264 | # intensity/severity of augmentation increases with level
265 | return (256 - _solarize_level_to_arg(level, _hparams)[0],)
266 |
267 |
268 | def _solarize_add_level_to_arg(level, _hparams):
269 | # range [0, 110]
270 | return (int((level / _MAX_LEVEL) * 110),)
271 |
272 |
273 | LEVEL_TO_ARG = {
274 | "AutoContrast": None,
275 | "Equalize": None,
276 | "Invert": None,
277 | "Rotate": _rotate_level_to_arg,
278 | # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
279 | "Posterize": _posterize_level_to_arg,
280 | "PosterizeIncreasing": _posterize_increasing_level_to_arg,
281 | "PosterizeOriginal": _posterize_original_level_to_arg,
282 | "Solarize": _solarize_level_to_arg,
283 | "SolarizeIncreasing": _solarize_increasing_level_to_arg,
284 | "SolarizeAdd": _solarize_add_level_to_arg,
285 | "Color": _enhance_level_to_arg,
286 | "ColorIncreasing": _enhance_increasing_level_to_arg,
287 | "Contrast": _enhance_level_to_arg,
288 | "ContrastIncreasing": _enhance_increasing_level_to_arg,
289 | "Brightness": _enhance_level_to_arg,
290 | "BrightnessIncreasing": _enhance_increasing_level_to_arg,
291 | "Sharpness": _enhance_level_to_arg,
292 | "SharpnessIncreasing": _enhance_increasing_level_to_arg,
293 | "ShearX": _shear_level_to_arg,
294 | "ShearY": _shear_level_to_arg,
295 | "TranslateX": _translate_abs_level_to_arg,
296 | "TranslateY": _translate_abs_level_to_arg,
297 | "TranslateXRel": _translate_rel_level_to_arg,
298 | "TranslateYRel": _translate_rel_level_to_arg,
299 | }
300 |
301 |
302 | NAME_TO_OP = {
303 | "AutoContrast": auto_contrast,
304 | "Equalize": equalize,
305 | "Invert": invert,
306 | "Rotate": rotate,
307 | "Posterize": posterize,
308 | "PosterizeIncreasing": posterize,
309 | "PosterizeOriginal": posterize,
310 | "Solarize": solarize,
311 | "SolarizeIncreasing": solarize,
312 | "SolarizeAdd": solarize_add,
313 | "Color": color,
314 | "ColorIncreasing": color,
315 | "Contrast": contrast,
316 | "ContrastIncreasing": contrast,
317 | "Brightness": brightness,
318 | "BrightnessIncreasing": brightness,
319 | "Sharpness": sharpness,
320 | "SharpnessIncreasing": sharpness,
321 | "ShearX": shear_x,
322 | "ShearY": shear_y,
323 | "TranslateX": translate_x_abs,
324 | "TranslateY": translate_y_abs,
325 | "TranslateXRel": translate_x_rel,
326 | "TranslateYRel": translate_y_rel,
327 | }
328 |
329 |
330 | class AugmentOp:
331 | """
332 | Apply a single augmentation op to a video clip (a list of PIL frames) or to a single image.
333 | """
334 |
335 | def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
336 | hparams = hparams or _HPARAMS_DEFAULT
337 | self.aug_fn = NAME_TO_OP[name]
338 | self.level_fn = LEVEL_TO_ARG[name]
339 | self.prob = prob
340 | self.magnitude = magnitude
341 | self.hparams = hparams.copy()
342 | self.kwargs = {
343 | "fillcolor": hparams["img_mean"] if "img_mean" in hparams else _FILL,
344 | "resample": (
345 | hparams["interpolation"]
346 | if "interpolation" in hparams
347 | else _RANDOM_INTERPOLATION
348 | ),
349 | }
350 |
351 | # If magnitude_std is > 0, we introduce some randomness
352 | # in the usually fixed policy and sample magnitude from a normal distribution
353 | # with mean `magnitude` and std-dev of `magnitude_std`.
354 | # NOTE This is my own hack, being tested, not in papers or reference impls.
355 | self.magnitude_std = self.hparams.get("magnitude_std", 0)
356 |
357 | def __call__(self, img_list):
358 | if self.prob < 1.0 and random.random() > self.prob:
359 | return img_list
360 | magnitude = self.magnitude
361 | if self.magnitude_std and self.magnitude_std > 0:
362 | magnitude = random.gauss(magnitude, self.magnitude_std)
363 | magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
364 | level_args = (
365 | self.level_fn(magnitude, self.hparams) if self.level_fn is not None else ()
366 | )
367 |
368 | if isinstance(img_list, list):
369 | return [self.aug_fn(img, *level_args, **self.kwargs) for img in img_list]
370 | else:
371 | return self.aug_fn(img_list, *level_args, **self.kwargs)
372 |
373 |
374 | _RAND_TRANSFORMS = [
375 | "AutoContrast",
376 | "Equalize",
377 | "Invert",
378 | "Rotate",
379 | "Posterize",
380 | "Solarize",
381 | "SolarizeAdd",
382 | "Color",
383 | "Contrast",
384 | "Brightness",
385 | "Sharpness",
386 | "ShearX",
387 | "ShearY",
388 | "TranslateXRel",
389 | "TranslateYRel",
390 | ]
391 |
392 |
393 | _RAND_INCREASING_TRANSFORMS = [
394 | "AutoContrast",
395 | "Equalize",
396 | "Invert",
397 | "Rotate",
398 | "PosterizeIncreasing",
399 | "SolarizeIncreasing",
400 | "SolarizeAdd",
401 | "ColorIncreasing",
402 | "ContrastIncreasing",
403 | "BrightnessIncreasing",
404 | "SharpnessIncreasing",
405 | "ShearX",
406 | "ShearY",
407 | "TranslateXRel",
408 | "TranslateYRel",
409 | ]
410 |
411 |
412 | # These experimental weights are based loosely on the relative improvements mentioned in the paper.
413 | # They may not result in increased performance, but could likely be tuned to do so.
414 | _RAND_CHOICE_WEIGHTS_0 = {
415 | "Rotate": 0.3,
416 | "ShearX": 0.2,
417 | "ShearY": 0.2,
418 | "TranslateXRel": 0.1,
419 | "TranslateYRel": 0.1,
420 | "Color": 0.025,
421 | "Sharpness": 0.025,
422 | "AutoContrast": 0.025,
423 | "Solarize": 0.005,
424 | "SolarizeAdd": 0.005,
425 | "Contrast": 0.005,
426 | "Brightness": 0.005,
427 | "Equalize": 0.005,
428 | "Posterize": 0,
429 | "Invert": 0,
430 | }
431 |
432 |
433 | def _select_rand_weights(weight_idx=0, transforms=None):
434 | transforms = transforms or _RAND_TRANSFORMS
435 | assert weight_idx == 0 # only one set of weights currently
436 | rand_weights = _RAND_CHOICE_WEIGHTS_0
437 | probs = [rand_weights[k] for k in transforms]
438 | probs /= np.sum(probs)
439 | return probs
440 |
441 |
442 | def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
443 | hparams = hparams or _HPARAMS_DEFAULT
444 | transforms = transforms or _RAND_TRANSFORMS
445 | return [
446 | AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams)
447 | for name in transforms
448 | ]
449 |
450 |
451 | class RandAugment:
452 | def __init__(self, ops, num_layers=2, choice_weights=None):
453 | self.ops = ops
454 | self.num_layers = num_layers
455 | self.choice_weights = choice_weights
456 |
457 | def __call__(self, img):
458 | # no replacement when using weighted choice
459 | ops = np.random.choice(
460 | self.ops,
461 | self.num_layers,
462 | replace=self.choice_weights is None,
463 | p=self.choice_weights,
464 | )
465 | for op in ops:
466 | img = op(img)
467 | return img
468 |
469 |
470 | def rand_augment_transform(config_str, hparams):
471 | """
472 | RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
473 |
474 | Create a RandAugment transform
475 | :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
476 | dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
477 | sections, which are not order specific, determine
478 | 'm' - integer magnitude of rand augment
479 | 'n' - integer num layers (number of transform ops selected per image)
480 | 'w' - integer probability weight index (index of a set of weights to influence choice of op)
481 | 'mstd' - float std deviation of magnitude noise applied
482 | 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
483 | Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
484 | 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
485 | :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
486 | :return: A PyTorch compatible Transform
487 | """
488 | magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10)
489 | num_layers = 2 # default to 2 ops per image
490 | weight_idx = None # default to no probability weights for op choice
491 | transforms = _RAND_TRANSFORMS
492 | config = config_str.split("-")
493 | assert config[0] == "rand"
494 | config = config[1:]
495 | for c in config:
496 | cs = re.split(r"(\d.*)", c)
497 | if len(cs) < 2:
498 | continue
499 | key, val = cs[:2]
500 | if key == "mstd":
501 | # noise param injected via hparams for now
502 | hparams.setdefault("magnitude_std", float(val))
503 | elif key == "inc":
504 | if bool(val):
505 | transforms = _RAND_INCREASING_TRANSFORMS
506 | elif key == "m":
507 | magnitude = int(val)
508 | elif key == "n":
509 | num_layers = int(val)
510 | elif key == "w":
511 | weight_idx = int(val)
512 | else:
513 | raise NotImplementedError(f"unknown RandAugment config section: {c}")
514 | ra_ops = rand_augment_ops(
515 | magnitude=magnitude, hparams=hparams, transforms=transforms
516 | )
517 | choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
518 | return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
519 |
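# --- Editorial sketch (not part of the original file) ------------------------
# Shows how the config string documented above is typically decoded into a
# transform and applied to a list of video frames. The synthetic PIL frames and
# the hparams values are placeholders, not settings from this repository.
def _example_rand_augment():
    # 'rand-m9-n3-mstd0.5' -> magnitude 9, 3 ops per clip, magnitude_std 0.5
    frames = [Image.new("RGB", (224, 224), color=(128, 128, 128)) for _ in range(8)]
    augment = rand_augment_transform("rand-m9-n3-mstd0.5", hparams={"translate_const": 100})
    # every frame in the list receives the same sampled ops and arguments
    return augment(frames)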
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import re
3 |
4 | import cv2
5 | from PIL import Image
6 | from functools import partial
7 |
8 | from typing import Tuple, List
9 | # from beartype.door import is_bearable
10 |
11 | import numpy as np
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.utils.data import Dataset, DataLoader as PytorchDataLoader
16 | from torchvision import transforms as T, utils
17 |
18 | from einops import rearrange
19 | import os
20 | import pickle as pkl
21 | import random
22 |
23 | # helper functions
24 |
25 | from scipy import interpolate
26 | import torchvision.utils as vutils
27 | from PIL import Image
28 |
29 |
30 | def exists(val):
31 | return val is not None
32 |
33 |
34 | def identity(t, *args, **kwargs):
35 | return t
36 |
37 |
38 | def pair(val):
39 | return val if isinstance(val, tuple) else (val, val)
40 |
41 |
42 | def bgr_to_rgb(video_tensor):
43 | video_tensor = video_tensor[[2, 1, 0], :, :, :]
44 | return video_tensor
45 |
46 |
47 |
48 | def cast_num_frames(t, *, frames):
49 | f = t.shape[1]
50 |
51 | if f == frames:
52 | return t
53 |
54 | if f > frames:
55 | return t[:, :frames]
56 |
57 | return F.pad(t, (0, 0, 0, 0, 0, frames - f))
58 |
59 | def convert_image_to_fn(img_type, image):
60 | if image.mode != img_type:
61 | return image.convert(img_type)
62 | return image
63 |
64 | # image-related helper functions and dataset
65 | def z_normalize(data):
66 | """
67 | Perform z-score normalization on the input data.
68 | """
69 | mean = np.mean(data)
70 | std = np.std(data)
71 | return (data - mean) / std
72 |
73 | def save_tensor_images(tensor, output_dir="output_images"):
74 | # Create the output directory
75 | os.makedirs(output_dir, exist_ok=True)
76 | 
77 | # Normalize the tensor values to the 0-255 range
78 | tensor = (tensor - tensor.min()) / (tensor.max() - tensor.min()) * 255
79 | tensor = tensor.byte()
80 | 
81 | # Save each frame individually
82 | for frame in range(tensor.shape[1]): # iterate over the frame dimension
83 | # Select all channels of the current frame
84 | frame_data = tensor[:, frame, :, :]
85 | 
86 | # Reorder channels (C, H, W) -> (H, W, C)
87 | frame_data = frame_data.permute(1, 2, 0)
88 | 
89 | # Convert to a NumPy array
90 | frame_array = frame_data.numpy()
91 | 
92 | # Convert to an RGB image (when there are 3 channels)
93 | if frame_array.shape[2] == 3:
94 | img = Image.fromarray(frame_array, 'RGB')
95 | else:
96 | img = Image.fromarray(frame_array[:,:,0], 'L')
97 | 
98 | # Save the image
99 | img.save(os.path.join(output_dir, f"frame_{frame:03d}.png"))
100 |
101 | class ImageDataset(Dataset):
102 | def __init__(
103 | self,
104 | folder,
105 | image_size,
106 | exts=['jpg', 'jpeg', 'png']
107 | ):
108 | super().__init__()
109 | self.folder = folder
110 | self.image_size = image_size
111 | self.paths = [p for ext in exts for p in Path(
112 | f'{folder}').glob(f'**/*.{ext}')]
113 |
114 | print(f'{len(self.paths)} training samples found at {folder}')
115 |
116 | self.transform = T.Compose([
117 | T.Lambda(lambda img: img.convert('RGB')
118 | if img.mode != 'RGB' else img),
119 | T.Resize(image_size),
120 | # T.RandomHorizontalFlip(),
121 | # T.CenterCrop(image_size),
122 | T.ToTensor()
123 | ])
124 |
125 | def __len__(self):
126 | return len(self.paths)
127 |
128 | def __getitem__(self, index):
129 | path = self.paths[index]
130 | img = Image.open(path)
131 | return self.transform(img)
132 |
133 | # tensor of shape (channels, frames, height, width) -> gif
134 |
135 | # handle reading and writing gif
136 |
137 |
138 | CHANNELS_TO_MODE = {
139 | 1: 'L',
140 | 3: 'RGB',
141 | 4: 'RGBA'
142 | }
143 |
144 |
145 | def seek_all_images(img, channels=3):
146 | assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid'
147 | mode = CHANNELS_TO_MODE[channels]
148 |
149 | i = 0
150 | while True:
151 | try:
152 | img.seek(i)
153 | yield img.convert(mode)
154 | except EOFError:
155 | break
156 | i += 1
157 |
158 | # tensor of shape (channels, frames, height, width) -> gif
159 |
160 |
161 | def video_tensor_to_pil_first_image(tensor):
162 |
163 | tensor = bgr_to_rgb(tensor)
164 | images = map(T.ToPILImage(), tensor.unbind(dim=1))
165 | first_img, *rest_imgs = images
166 |
167 | return first_img
168 |
169 |
170 | def video_tensor_to_gif(
171 | tensor,
172 | path,
173 | duration=120,
174 | loop=0,
175 | optimize=True
176 | ):
177 |
178 | tensor = torch.clamp(tensor, min=0, max=1) # clipping underflow and overflow
179 | #tensor = bgr_to_rgb(tensor)
180 | images = map(T.ToPILImage(), tensor.unbind(dim=1))
181 | first_img, *rest_imgs = images
182 | first_img.save(path, save_all=True, append_images=rest_imgs,
183 | loop=loop, optimize=optimize)
184 | return images
185 |
186 | # gif -> (channels, frame, height, width) tensor
187 |
188 |
189 | def gif_to_tensor(
190 | path,
191 | channels=3,
192 | transform=T.ToTensor()
193 | ):
194 | img = Image.open(path)
195 | tensors = tuple(map(transform, seek_all_images(img, channels=channels)))
196 | return torch.stack(tensors, dim=1)
197 |
198 | # handle reading and writing mp4
199 |
200 |
201 |
202 | def tensor_to_video(
203 | tensor, # Pytorch video tensor
204 | path: str, # Path of the video to be saved
205 | fps=8, # Frames per second for the saved video
206 | video_format=('m', 'p', '4', 'v')
207 | ):
208 | # Import the video and cut it into frames.
209 | tensor = tensor.cpu()*255. # TODO: have a better function for that? Not using cv2?
210 |
211 | num_frames, height, width = tensor.shape[-3:]
212 |
213 | # Changes in this line can allow for different video formats.
214 | fourcc = cv2.VideoWriter_fourcc(*video_format)
215 | video = cv2.VideoWriter(path, fourcc, fps, (width, height))
216 |
217 | frames = []
218 |
219 | for idx in range(num_frames):
220 | numpy_frame = tensor[:, idx, :, :].numpy()
221 | numpy_frame = np.uint8(rearrange(numpy_frame, 'c h w -> h w c'))
222 | video.write(numpy_frame)
223 |
224 | video.release()
225 |
226 | cv2.destroyAllWindows()
227 |
228 | return video
229 |
230 |
231 | def crop_center(
232 | img, # tensor
233 | cropx, # Length of the final image in the x direction.
234 | cropy # Length of the final image in the y direction.
235 | ) -> torch.Tensor:
236 | y, x, c = img.shape
237 | startx = x // 2 - cropx // 2
238 | starty = y // 2 - cropy // 2
239 | return img[starty:(starty + cropy), startx:(startx + cropx), :]
240 |
241 | def sort_key(file_path):
242 | # Extract the numerical parts from the file name using regex
243 | match = re.findall(r'(\d+)', file_path.stem)
244 | if match:
245 | return [int(part) for part in match]
246 | return str(file_path)
247 | # video dataset
248 |
249 | def save_tensor_as_grid(tensor, grid_size, save_path="grid_image.png"):
250 | """
251 | 4x4 그리드로 텐서를 저장하는 함수.
252 |
253 | Args:
254 | tensor (torch.Tensor): 텐서 크기 (C, N, H, W) 또는 (N, C, H, W).
255 | - C: 채널 수 (3이면 RGB, 1이면 그레이스케일).
256 | - N: 이미지 개수.
257 | - H, W: 이미지 높이와 너비.
258 | grid_size (int): 그리드의 행과 열 수. (예: 4이면 4x4 그리드).
259 | save_path (str): 저장할 이미지 파일 경로.
260 | """
261 | # 텐서 차원 맞추기: (N, C, H, W)
262 | if tensor.shape[0] == 3 or tensor.shape[0] == 1:
263 | tensor = tensor.permute(1, 0, 2, 3) # (C, N, H, W) -> (N, C, H, W)
264 |
265 | # 그리드 생성
266 | grid = vutils.make_grid(tensor, nrow=grid_size, padding=2)
267 |
268 | # 텐서를 PIL 이미지로 변환
269 | grid_image = (grid * 255).byte().permute(1, 2, 0).numpy() # RGB 순서로 변환
270 | image = Image.fromarray(grid_image)
271 |
272 | # 이미지 저장
273 | image.save(save_path)
274 | print(f"4x4 그리드 이미지를 '{save_path}'로 저장했습니다.")
275 |
276 | def process_ekg(ekg_data, target_length=2250, repetitions=16):
277 | processed_ekg = np.zeros((12, target_length))
278 |
279 | # Repeat the original data 16 times (the `repetitions` argument)
280 | repeated_data = np.tile(ekg_data, repetitions)
281 | 
282 | # Length of the repeated data
283 | original_length = len(repeated_data)
284 |
285 | if original_length < target_length:
286 | # Interpolation
287 | x = np.linspace(0, 1, original_length)
288 | f = interpolate.interp1d(x, repeated_data, kind='linear')
289 | x_new = np.linspace(0, 1, target_length)
290 | processed_ekg[1] = f(x_new)
291 | elif original_length > target_length:
292 | # Resampling
293 | x = np.linspace(0, 1, original_length)
294 | x_new = np.linspace(0, 1, target_length)
295 | processed_ekg[1] = np.interp(x_new, x, repeated_data)
296 | else:
297 | processed_ekg[1] = repeated_data[:target_length]
298 |
299 | return processed_ekg
300 |
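# --- Editorial sketch (not part of the original file) ------------------------
# process_ekg() tiles a 1-D EKG trace `repetitions` times and then
# interpolates/resamples it to `target_length`, writing the result into row
# index 1 of a 12 x target_length array while the other rows stay zero.
# The synthetic trace below is an illustrative placeholder.
def _example_process_ekg():
    one_cycle = np.sin(np.linspace(0, 2 * np.pi, 140))  # fake single-cycle trace
    ekg = process_ekg(one_cycle, target_length=2250, repetitions=16)
    assert ekg.shape == (12, 2250)  # only ekg[1] is populated
    return ekg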
301 |
302 | def video_to_tensor(
303 | path: str,
304 | transform, # Path of the video to be imported
305 | num_frames=-1, # Number of frames to be stored in the output tensor
306 | crop_size=None
307 | ) -> torch.Tensor: # shape (1, channels, frames, height, width)
308 |
309 | video = cv2.VideoCapture(path)
310 |
311 | total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
312 |
313 | # print ("PATH", path)
314 | # print ("TOTAL frame : ",total_frames )
315 | frames = []
316 | check = True
317 |
318 | shear_x = random.uniform(-5, 5)
319 | shear_y = random.uniform(-5, 5)
320 | contrast_factor = random.uniform(0.6, 1.4)
321 |
322 | while check:
323 | check, frame = video.read()
324 |
325 | if not check:
326 | continue
327 | # frame = np.transpose(frame, (2, 0, 1))
328 | # Apply the transform with the fixed augmentation values
329 | # frame = transform(frame, shear_x, shear_y, contrast_factor)
330 | frame = transform(frame)
331 | frames.append(rearrange(frame, '... -> 1 ...'))
332 |
333 | # convert list of frames to numpy array
334 | frames = np.array(np.concatenate(frames, axis=0))
335 | # frames = rearrange(frames, 'f c h w -> c f h w')
336 | frames = rearrange(frames, 'f c h w -> c f h w')
337 |
338 | frames_torch = torch.tensor(frames).float()
339 |
340 | return frames_torch
341 |
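# --- Editorial sketch (not part of the original file) ------------------------
# video_to_tensor() decodes every frame of an mp4 with OpenCV, applies the
# given per-frame transform, and returns a float tensor shaped
# (channels, frames, height, width). The file path below is a placeholder.
def _example_load_clip(path="example_clip.mp4"):
    to_tensor = T.Compose([T.ToPILImage(), T.Resize([224, 224]), T.ToTensor()])
    clip = video_to_tensor(path, transform=to_tensor)  # (C, F, H, W)
    clip = cast_num_frames(clip, frames=32)            # pad or truncate the frame axis
    return clip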
342 |
343 |
344 | def process_ultrasound_image(video_tensor):
345 | # Check the shape of the input tensor
346 | B, T, H, W = video_tensor.shape
347 | 
348 | # Initialize the tensor that will hold the result
349 | result = torch.zeros_like(video_tensor)
350 |
351 | for b in range(B):
352 | for t in range(T):
353 |
354 | # Extract the current frame
355 | frame = video_tensor[b, t].cpu().numpy() #shape of 128 128
356 |
357 | # frame_ = (frame*255).astype(np.uint8)
358 |
359 | # Save the grayscale image using OpenCV
360 | # output_path_gray = "/home/local/PARTNERS/sk1064/project/EchoHub/dataset/frame_image_gray.jpg"
361 | # cv2.imwrite(output_path_gray, frame_)
362 |
363 | # If you want to convert the numpy array to PIL Image and save it
364 |
365 | # print ("FRAME SIZE CHECK " *10, np.max(frame_))
366 | # Set the threshold (this value may need to be tuned per image)
367 | threshold = 0.1
368 |
369 | # Find the first pixel in each column that exceeds the threshold
370 | first_pixels = np.argmax(frame > threshold, axis=0)
371 |
372 | # Find the y coordinate of the topmost such pixel
373 | top_y = np.min(first_pixels[first_pixels > 0])
374 |
375 | # Crop a fixed-size window from that point (224x224; the 128x128 variant is kept commented out)
376 | # cropped = frame[top_y:top_y+128, :128]
377 | cropped = frame[top_y:top_y+224, :224]
378 |
379 | # If the crop is not the target size, resize it with interpolation
380 | # if cropped.shape != (128, 128):
381 | # cropped = cv2.resize(cropped, (128, 128), interpolation=cv2.INTER_LINEAR)
382 | if cropped.shape != (224, 224):
383 | cropped = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_LINEAR)
384 |
385 | # Store the crop in the result tensor
386 | result[b, t] = torch.from_numpy(cropped).float()
387 |
388 | return result
389 |
390 | class EchoDataset_from_Video_mp4(Dataset):
391 | def __init__(
392 | self,
393 | folder,
394 | image_size = [224, 224],
395 | channels = 3,
396 | ):
397 | super().__init__()
398 | self.folder = folder
399 |
400 | self.image_size = image_size
401 | self.channels = channels
402 |
403 | def apply_augmentation(img, shear_x, shear_y, contrast_factor):
404 | # Apply contrast augmentation
405 | img = T.functional.adjust_contrast(img, contrast_factor)
406 |
407 | # # Apply shear x, y augmentation
408 | img = T.functional.affine(img, angle=0, translate=[0, 0], scale=1.0, shear=[shear_x, shear_y])
409 |
410 | return img
411 |
412 | def create_transform(image_size):
413 | # def transform(img, shear_x, shear_y, contrast_factor):
414 | def transform(img):
415 | if not isinstance(img, Image.Image):
416 | img = T.ToPILImage()(img)
417 | img = T.Resize(image_size)(img)
418 | # img = apply_augmentation(img, shear_x, shear_y, contrast_factor)
419 | return T.ToTensor()(img)
420 | return transform
421 |
422 | self.transform_for_videos = create_transform(self.image_size)
423 |
424 | self.transform = T.Compose([
425 | T.Resize(image_size),
426 | T.ToTensor()
427 | ])
428 | self.paths = os.listdir(folder)
429 |
430 | self.mp4_to_tensor = partial(
431 | video_to_tensor, transform=self.transform_for_videos, crop_size=self.image_size, num_frames=10)
432 |
433 | force_num_frames = True
434 |
435 | self.cast_num_frames_fn = partial(
436 | cast_num_frames, frames=32) if force_num_frames else identity
437 |
438 | def __len__(self):
439 | return len(self.paths)
440 |
441 | def __getitem__(self, index):
442 | path = self.paths[index]
443 |
444 | path = os.path.join(self.folder, path)
445 | tensor = self.mp4_to_tensor(str(path))
446 |
447 | tensor = self.cast_num_frames_fn(tensor)
448 |
449 | # print ("Check Final output : ", tensor.size())
450 |
451 | # save_tensor_as_grid(tensor, grid_size=4, save_path="grid_image.png")
452 | # data = {"image":tensor , "p_id" : self.paths[index]}
453 | return tensor
454 |
455 |
456 | class EchoDataset_from_Video(Dataset):
457 | def __init__(
458 | self,
459 | folder,
460 | image_size,
461 | channels=3,
462 | num_frames=11,
463 | horizontal_flip=False,
464 | force_num_frames=True,
465 | exts=['gif', 'mp4'],
466 | sample_texts=None # newly added parameter
467 | ):
468 | super().__init__()
469 | self.folder = os.path.join(folder, "mp4")
470 | self.folder_ekg= os.path.join(folder, "ekg")
471 |
472 | self.image_size = image_size
473 | self.channels = channels
474 | self.paths = [p for ext in exts for p in Path(
475 | f'{folder}').glob(f'**/*.{ext}')]
476 | self.paths.sort(key=sort_key)
477 | self.sample_texts = sample_texts
478 | self.transform = T.Compose([
479 | T.Resize(image_size),
480 | T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
481 | T.CenterCrop(image_size),
482 | T.ToTensor()
483 | ])
484 |
485 | # TODO: rework so it is faster, for now it works but is bad
486 | # self.transform_for_videos = T.Compose([
487 | # T.ToPILImage(), # added to PIL conversion because video is read with cv2
488 | # T.Resize(image_size),
489 | # # T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity),
490 | # T.ToTensor()
491 | # ])
492 |
493 | self.gif_to_tensor = partial(
494 | gif_to_tensor, channels=self.channels, transform=self.transform)
495 | # self.mp4_to_tensor = partial(
496 | # video_to_tensor, transform=self.transform_for_videos, crop_size=self.image_size, num_frames=num_frames)
497 |
498 |
499 | self.cast_num_frames_fn = partial(
500 | cast_num_frames, frames=num_frames) if force_num_frames else identity
501 |
502 |
503 | def apply_augmentation(img, shear_x, shear_y, contrast_factor):
504 | # Apply contrast augmentation
505 | img = T.functional.adjust_contrast(img, contrast_factor)
506 |
507 | # # Apply shear x, y augmentation
508 | img = T.functional.affine(img, angle=0, translate=[0, 0], scale=1.0, shear=[shear_x, shear_y])
509 |
510 | return img
511 |
512 | def create_transform(image_size):
513 | def transform(img, shear_x, shear_y, contrast_factor):
514 | if not isinstance(img, Image.Image):
515 | img = T.ToPILImage()(img)
516 | img = T.Resize(image_size)(img)
517 | img = apply_augmentation(img, shear_x, shear_y, contrast_factor)
518 | return T.ToTensor()(img)
519 | return transform
520 |
521 | self.transform_for_videos = create_transform(image_size)
522 | self.mp4_to_tensor = partial(
523 | video_to_tensor, transform=self.transform_for_videos, crop_size=self.image_size, num_frames=num_frames)
524 |
525 | def __len__(self):
526 | return len(self.paths)
527 |
528 | def __getitem__(self, index):
529 | path = self.paths[index]
530 |
531 | ext = path.suffix
532 |
533 | if ext == '.gif':
534 | tensor = self.gif_to_tensor(path)
535 | elif ext == '.mp4':
536 | tensor = self.mp4_to_tensor(str(path))
537 | else:
538 | raise ValueError(f'unknown extension {ext}')
539 |
540 | tensor = self.cast_num_frames_fn(tensor)
541 |
542 | return tensor
543 |
544 | def collate_tensors_and_strings(batch):
545 | tensors, ekgs = zip(*batch)
546 |
547 | # Process tensors (assuming they are already in the correct format)
548 | tensors = torch.stack(tensors, dim=0)
549 |
550 | # Process EKGs
551 | processed_ekgs = []
552 | for ekg in ekgs:
553 | processed_ekgs.append(ekg)
554 |
555 | processed_ekgs = torch.stack(processed_ekgs, dim=0)
556 |
557 | # print ("BATCH COLLATE ", tensors.size(), processed_ekgs.size())
558 |
559 | return tensors, processed_ekgs
560 |
561 |
562 | def DataLoader(*args, **kwargs):
563 | return PytorchDataLoader(*args, collate_fn=collate_tensors_and_strings, **kwargs)
564 |
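# --- Editorial sketch (not part of the original file) ------------------------
# Minimal wiring of the mp4 dataset with a plain PyTorch DataLoader. The folder
# path and batch size are placeholders. EchoDataset_from_Video_mp4 yields a
# single (C, 32, 224, 224) tensor per clip, so the default collate function is
# sufficient; the custom collate above expects (video, ekg) pairs instead.
def _example_pretrain_loader(folder="/path/to/echo_mp4_clips", batch_size=4):
    dataset = EchoDataset_from_Video_mp4(folder, image_size=[224, 224])
    return PytorchDataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)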
--------------------------------------------------------------------------------
/EchoFM/models_mae.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # MAE: https://github.com/facebookresearch/mae
11 | # --------------------------------------------------------
12 |
13 | from functools import partial
14 |
15 | import torch
16 | import torch.nn as nn
17 | from EchoFM.util import video_vit
18 | from EchoFM.util.logging import master_print as print
19 | import torch.nn.functional as F
20 |
21 | class MaskedAutoencoderViT(nn.Module):
22 | """Masked Autoencoder with VisionTransformer backbone"""
23 |
24 | def __init__(
25 | self,
26 | img_size=224,
27 | patch_size=16,
28 | in_chans=3,
29 | embed_dim=1024,
30 | depth=24,
31 | num_heads=16,
32 | decoder_embed_dim=512,
33 | decoder_depth=8,
34 | decoder_num_heads=16,
35 | mlp_ratio=4.0,
36 | norm_layer=nn.LayerNorm,
37 | norm_pix_loss=False,
38 | num_frames=16,
39 | t_patch_size=4,
40 | patch_embed=video_vit.PatchEmbed,
41 | no_qkv_bias=False,
42 | sep_pos_embed=False,
43 | trunc_init=False,
44 | cls_embed=False,
45 | pred_t_dim=8,
46 | **kwargs,
47 | ):
48 | super().__init__()
49 | self.trunc_init = trunc_init
50 | self.sep_pos_embed = sep_pos_embed
51 | self.cls_embed = cls_embed
52 | self.pred_t_dim = pred_t_dim
53 | self.t_pred_patch_size = t_patch_size * pred_t_dim // num_frames
54 |
55 | self.patch_embed = patch_embed(
56 | img_size,
57 | patch_size,
58 | in_chans,
59 | embed_dim,
60 | num_frames,
61 | t_patch_size,
62 | )
63 | num_patches = self.patch_embed.num_patches
64 | input_size = self.patch_embed.input_size
65 | self.input_size = input_size
66 |
67 | if self.cls_embed:
68 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
69 | self.decoder_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
70 | self.decoder_prj_cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
71 |
72 | if sep_pos_embed:
73 | self.pos_embed_spatial = nn.Parameter(
74 | torch.zeros(1, input_size[1] * input_size[2], embed_dim)
75 | )
76 | self.pos_embed_temporal = nn.Parameter(
77 | torch.zeros(1, input_size[0], embed_dim)
78 | )
79 | if self.cls_embed:
80 | self.pos_embed_class = nn.Parameter(torch.zeros(1, 1, embed_dim))
81 | else:
82 | if self.cls_embed:
83 | _num_patches = num_patches + 1
84 | else:
85 | _num_patches = num_patches
86 |
87 | self.pos_embed = nn.Parameter(
88 | torch.zeros(1, _num_patches, embed_dim),
89 | )
90 |
91 | self.blocks = nn.ModuleList(
92 | [
93 | video_vit.Block(
94 | embed_dim,
95 | num_heads,
96 | mlp_ratio,
97 | qkv_bias=not no_qkv_bias,
98 | qk_scale=None,
99 | norm_layer=norm_layer,
100 | )
101 | for i in range(depth)
102 | ]
103 | )
104 |
105 | self.decoder_block = video_vit.Block(
106 | embed_dim,
107 | num_heads,
108 | mlp_ratio,
109 | qkv_bias=not no_qkv_bias,
110 | qk_scale=None,
111 | norm_layer=norm_layer,
112 | )
113 |
114 | self.norm = norm_layer(embed_dim)
115 |
116 | self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
117 |
118 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
119 |
120 | if sep_pos_embed:
121 | self.decoder_pos_embed_spatial = nn.Parameter(
122 | torch.zeros(1, input_size[1] * input_size[2], decoder_embed_dim)
123 | )
124 | self.decoder_pos_embed_temporal = nn.Parameter(
125 | torch.zeros(1, input_size[0], decoder_embed_dim)
126 | )
127 | if self.cls_embed:
128 | self.decoder_pos_embed_class = nn.Parameter(
129 | torch.zeros(1, 1, decoder_embed_dim)
130 | )
131 | else:
132 | if self.cls_embed:
133 | _num_patches = num_patches + 1
134 | else:
135 | _num_patches = num_patches
136 |
137 | self.decoder_pos_embed = nn.Parameter(
138 | torch.zeros(1, _num_patches, decoder_embed_dim),
139 | )
140 |
141 | self.decoder_blocks = nn.ModuleList(
142 | [
143 | video_vit.Block(
144 | decoder_embed_dim,
145 | decoder_num_heads,
146 | mlp_ratio,
147 | qkv_bias=not no_qkv_bias,
148 | qk_scale=None,
149 | norm_layer=norm_layer,
150 | )
151 | for i in range(decoder_depth)
152 | ]
153 | )
154 |
155 | self.decoder_norm = norm_layer(decoder_embed_dim)
156 | self.decoder_pred = nn.Linear(
157 | decoder_embed_dim,
158 | self.t_pred_patch_size * patch_size**2 * in_chans,
159 | bias=True,
160 | )
161 |
162 | self.norm_pix_loss = norm_pix_loss
163 |
164 | self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
165 | self.initialize_weights()
166 |
167 |
168 | print("model initialized")
169 |
170 | def self_similarity(self, cls_tokens):
171 | """
172 | Compute self-similarity map using cosine similarity.
173 |
174 | Args:
175 | cls_tokens (list of tensors): List of tensors, where each tensor is of shape [N, D].
176 |
177 | Returns:
178 | similarity_map (tensor): Tensor of shape [N, T, T] containing self-similarity values.
179 | """
180 | # Concatenate the list into a single tensor of shape [N, T, D]
181 | cls_tokens_tensor = torch.stack(cls_tokens, dim=1) # Shape: [N, T, D]
182 |
183 | # Normalize embeddings to unit vectors
184 | cls_tokens_tensor = F.normalize(cls_tokens_tensor, p=2, dim=-1) # Shape: [N, T, D]
185 |
186 | # Compute cosine similarity
187 | similarity_map = torch.matmul(cls_tokens_tensor, cls_tokens_tensor.transpose(1, 2)) # Shape: [N, T, T]
188 |
189 | return similarity_map
190 |
191 | def triplet_sampling(self, similarity_map, cls_tokens):
192 | """
193 | Perform triplet sampling with one anchor, one positive, and one negative per batch.
194 |
195 | Args:
196 | similarity_map (tensor): Self-similarity map of shape [N, T, T].
197 | cls_tokens (tensor): Tensor of CLS tokens, shape [N, T, D].
198 |
199 | Returns:
200 | anchor (tensor): Tensor of anchor embeddings, shape [N, D].
201 | positive (tensor): Tensor of positive embeddings, shape [N, D].
202 | negative (tensor): Tensor of negative embeddings, shape [N, D].
203 | """
204 |
205 | cls_tokens = torch.stack(cls_tokens, dim=1) # Shape: [N, T, D]
206 |
207 | N, T, D = cls_tokens.shape
208 |
209 | anchors, positives, negatives = [], [], []
210 |
211 | for n in range(N): # Iterate over batches
212 | # Extract the first row (anchor is always index 0)
213 | first_row = similarity_map[n, 0, :] # Shape: [T]
214 |
215 | # Compute mean similarity for the first row
216 | mean_similarity = first_row.mean().item()
217 |
218 | # Identify positive and negative indices, excluding anchor index (0)
219 | positive_indices = (first_row > mean_similarity).nonzero(as_tuple=True)[0]
220 | positive_indices = positive_indices[positive_indices != 0] # Exclude anchor (index 0)
221 |
222 | negative_indices = (first_row <= mean_similarity).nonzero(as_tuple=True)[0]
223 | negative_indices = negative_indices[negative_indices != 0] # Exclude anchor (index 0)
224 |
225 | # Ensure we have at least one positive and one negative
226 | if len(positive_indices) > 0 and len(negative_indices) > 0:
227 | # Randomly select one positive and one negative
228 | pos_idx = positive_indices[torch.randint(len(positive_indices), (1,))].item()
229 | neg_idx = negative_indices[torch.randint(len(negative_indices), (1,))].item()
230 |
231 | # Append CLS tokens for the selected indices
232 | anchors.append(cls_tokens[n, 0, :]) # Anchor is always index 0
233 | positives.append(cls_tokens[n, pos_idx, :]) # Positive CLS token
234 | negatives.append(cls_tokens[n, neg_idx, :]) # Negative CLS token
235 |
236 | # Stack tensors to create final batch outputs
237 | anchor = torch.stack(anchors) # Shape: [N, D]
238 | positive = torch.stack(positives) # Shape: [N, D]
239 | negative = torch.stack(negatives) # Shape: [N, D]
240 |
241 | return anchor, positive, negative
242 |
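# --- Editorial note (worked example, not in the original file) ---------------
# For one sample, suppose the first row of the similarity map (anchor = temporal
# slice 0 vs. all T slices) is [1.00, 0.92, 0.41, 0.88]; its mean is ~0.80.
# Indices {1, 3} lie above the mean and become positive candidates, index {2}
# lies below and becomes the negative candidate. One positive and one negative
# are drawn at random, and their CLS tokens, together with the anchor CLS token
# at index 0, form the triplet fed to nn.TripletMarginLoss in forward().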
243 | def initialize_weights(self):
244 | if self.cls_embed:
245 | torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
246 | if self.sep_pos_embed:
247 | torch.nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
248 | torch.nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
249 |
250 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_spatial, std=0.02)
251 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_temporal, std=0.02)
252 |
253 | if self.cls_embed:
254 | torch.nn.init.trunc_normal_(self.pos_embed_class, std=0.02)
255 | torch.nn.init.trunc_normal_(self.decoder_pos_embed_class, std=0.02)
256 | else:
257 | torch.nn.init.trunc_normal_(self.pos_embed, std=0.02)
258 | torch.nn.init.trunc_normal_(self.decoder_pos_embed, std=0.02)
259 | w = self.patch_embed.proj.weight.data
260 | if self.trunc_init:
261 | torch.nn.init.trunc_normal_(w)
262 | torch.nn.init.trunc_normal_(self.mask_token, std=0.02)
263 | else:
264 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
265 | torch.nn.init.normal_(self.mask_token, std=0.02)
266 |
267 | # initialize nn.Linear and nn.LayerNorm
268 | self.apply(self._init_weights)
269 |
270 | def _init_weights(self, m):
271 | if isinstance(m, nn.Linear):
272 | # we use xavier_uniform following official JAX ViT:
273 | if self.trunc_init:
274 | nn.init.trunc_normal_(m.weight, std=0.02)
275 | else:
276 | torch.nn.init.xavier_uniform_(m.weight)
277 | if isinstance(m, nn.Linear) and m.bias is not None:
278 | nn.init.constant_(m.bias, 0)
279 | elif isinstance(m, nn.LayerNorm):
280 | nn.init.constant_(m.bias, 0)
281 | nn.init.constant_(m.weight, 1.0)
282 |
283 | def patchify(self, imgs):
284 | """
285 | imgs: (N, 3, T, H, W)
286 | x: (N, L, t_pred_patch_size * patch_size**2 * 3)
287 | """
288 | N, _, T, H, W = imgs.shape
289 | p = self.patch_embed.patch_size[0]
290 | u = self.t_pred_patch_size
291 | assert H == W and H % p == 0 and T % u == 0
292 | h = w = H // p
293 | t = T // u
294 |
295 | x = imgs.reshape(shape=(N, 3, t, u, h, p, w, p))
296 | x = torch.einsum("nctuhpwq->nthwupqc", x)
297 | x = x.reshape(shape=(N, t * h * w, u * p**2 * 3))
298 | self.patch_info = (N, T, H, W, p, u, t, h, w)
299 | return x
300 |
301 | def unpatchify(self, x):
302 | """
303 | x: (N, L, t_pred_patch_size * patch_size**2 * 3)
304 | imgs: (N, 3, T, H, W)
305 | """
306 | N, T, H, W, p, u, t, h, w = self.patch_info
307 |
308 | x = x.reshape(shape=(N, t, h, w, u, p, p, 3))
309 |
310 | x = torch.einsum("nthwupqc->nctuhpwq", x)
311 | imgs = x.reshape(shape=(N, 3, T, H, W))
312 | return imgs
313 |
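# --- Editorial note (worked shape example, not in the original file) ---------
# With the constructor defaults above (img_size=224, patch_size=16,
# num_frames=16, t_patch_size=4, pred_t_dim=8), u = t_pred_patch_size = 2 and
# h = w = 224 // 16 = 14. forward_loss() calls patchify() on a temporally
# resampled clip of shape (N, 3, 8, 224, 224), so t = 8 // 2 = 4 and the output
# is (N, 4*14*14, 2*16*16*3) = (N, 784, 1536). unpatchify() inverts the reshape
# using the cached self.patch_info.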
314 | def uniform_random_masking(self, x, mask_ratio, L):
315 | """
316 | Perform temporal consistent random masking by sampling the same spatial tokens across time steps.
317 | Args:
318 | x: Tensor of shape [N, T * L, D], sequence after patch embedding (flattened temporal and spatial dimensions).
319 | mask_ratio: Float, proportion of tokens to mask.
320 | L: Number of spatial tokens per time step.
321 |
322 | Returns:
323 | x_masked: Tensor of shape [N, len_keep * T, D], after masking.
324 | mask: Binary mask of shape [N, T * L], 0 is keep, 1 is remove.
325 | ids_restore: Indices to restore original sequence order.
326 | ids_keep: Indices of kept tokens.
327 | """
328 | N, TL, D = x.shape # Batch size, total tokens, embedding dimension
329 | T = TL // L # Temporal length
330 |
331 | # Compute the number of tokens to keep per spatial location
332 | len_keep = int(L * (1 - mask_ratio))
333 |
334 | # Generate random noise for each spatial location
335 | noise = torch.rand(N, L, device=x.device) # [N, L]
336 |
337 | # Sort spatial tokens based on noise
338 | ids_shuffle = torch.argsort(noise, dim=1) # [N, L]
339 | ids_keep = ids_shuffle[:, :len_keep] # Keep top len_keep indices [N, len_keep]
340 | ids_keep = ids_keep.unsqueeze(1).repeat(1, T, 1) # Broadcast to all time steps [N, T, len_keep]
341 |
342 | # Create a binary mask for all time steps
343 | mask = torch.ones(N, T, L, device=x.device) # Initialize mask with all 1s [N, T, L]
344 |
345 | for n in range(N): # Iterate over batch
346 | for t in range(T):
347 | mask[n, t, ids_keep[n]] = 0 # Use batch-specific ids_keep[n]
348 |
349 | mask = mask.view(N, TL) # Flatten to [N, T * L]
350 | ids_restore = torch.argsort(mask, dim=1) # Indices for restoring order
351 |
352 | # Mask input
353 | x_masked = x[mask == 0].view(N, -1, D) # Kept tokens only [N, len_keep * T, D]
354 |
355 | ids_keep = ids_keep.view(N, -1)
356 | return x_masked, mask, ids_restore, ids_keep
357 |
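# --- Editorial note (worked example, not in the original file) ---------------
# With a 224x224 input and patch_size=16 there are L = 14*14 = 196 spatial
# tokens per temporal slice. At mask_ratio=0.75, len_keep = int(196 * 0.25) = 49,
# and the SAME 49 spatial positions are kept in every one of the T temporal
# slices, so x_masked has shape [N, 49 * T, D]. This temporally consistent
# ("tube"-style) masking distinguishes it from random_masking() below, which
# samples positions over the flattened T*L axis without regard to time.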
358 | def random_masking(self, x, mask_ratio):
359 | """
360 | Perform per-sample random masking by per-sample shuffling.
361 | Per-sample shuffling is done by argsort random noise.
362 | x: [N, L, D], sequence
363 | """
364 | N, L, D = x.shape # batch, length, dim
365 | len_keep = int(L * (1 - mask_ratio))
366 |
367 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
368 |
369 | # sort noise for each sample
370 | ids_shuffle = torch.argsort(
371 | noise, dim=1
372 | ) # ascend: small is keep, large is remove
373 | ids_restore = torch.argsort(ids_shuffle, dim=1)
374 |
375 | # keep the first subset
376 | ids_keep = ids_shuffle[:, :len_keep]
377 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
378 |
379 | # generate the binary mask: 0 is keep, 1 is remove
380 | mask = torch.ones([N, L], device=x.device)
381 | mask[:, :len_keep] = 0
382 | # unshuffle to get the binary mask
383 | mask = torch.gather(mask, dim=1, index=ids_restore)
384 |
385 | return x_masked, mask, ids_restore, ids_keep
386 |
387 | def forward_encoder(self, x, mask_ratio):
388 | # embed patches
389 | x = self.patch_embed(x)
390 | N, T, L, C = x.shape
391 |
392 | x = x.reshape(N, T * L, C)
393 |
394 | # masking: length -> length * mask_ratio
395 | # x, mask, ids_restore, ids_keep = self.random_masking(x, mask_ratio)
396 |
397 | x, mask, ids_restore, ids_keep = self.uniform_random_masking(x, mask_ratio, L)
398 |
399 | x = x.view(N, -1, C)
400 | # append cls token
401 | if self.cls_embed:
402 | cls_token = self.cls_token
403 | cls_tokens = cls_token.expand(x.shape[0], -1, -1)
404 | x = torch.cat((cls_tokens, x), dim=1)
405 |
406 | # add pos embed w/o cls token
407 | if self.sep_pos_embed:
408 | pos_embed = self.pos_embed_spatial.repeat(
409 | 1, self.input_size[0], 1
410 | ) + torch.repeat_interleave(
411 | self.pos_embed_temporal,
412 | self.input_size[1] * self.input_size[2],
413 | dim=1,
414 | )
415 | pos_embed = pos_embed.expand(x.shape[0], -1, -1)
416 | pos_embed = torch.gather(
417 | pos_embed,
418 | dim=1,
419 | index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2]),
420 | )
421 | if self.cls_embed:
422 | pos_embed = torch.cat(
423 | [
424 | self.pos_embed_class.expand(pos_embed.shape[0], -1, -1),
425 | pos_embed,
426 | ],
427 | 1,
428 | )
429 | else:
430 | if self.cls_embed:
431 | cls_ind = 1
432 | else:
433 | cls_ind = 0
434 | pos_embed = self.pos_embed[:, cls_ind:, :].expand(x.shape[0], -1, -1)
435 | pos_embed = torch.gather(
436 | pos_embed,
437 | dim=1,
438 | index=ids_keep.unsqueeze(-1).repeat(1, 1, pos_embed.shape[2]),
439 | )
440 | if self.cls_embed:
441 | pos_embed = torch.cat(
442 | [
443 | self.pos_embed[:, :1, :].expand(x.shape[0], -1, -1),
444 | pos_embed,
445 | ],
446 | 1,
447 | )
448 | x = x.view([N, -1, C]) + pos_embed
449 |
450 | # apply Transformer blocks
451 | for blk in self.blocks:
452 | x = blk(x)
453 | x = self.norm(x)
454 |
455 | if self.cls_embed:
456 | # remove cls token
457 | x = x[:, 1:, :]
458 | else:
459 | x = x[:, :, :]
460 |
461 | return x, mask, ids_restore
462 |
463 | def decoder_prj(self, x):
464 | # apply Transformer blocks
465 | x = self.decoder_block(x)
466 | x = self.norm(x)
467 |
468 | if self.cls_embed:
469 | return x[:, 0, :]
470 | else:
471 | print ('CLS token is needed')
472 |
473 |
474 | def forward_prj(self, x, ids_restore):
475 | N = x.shape[0]
476 | T = self.patch_embed.t_grid_size
477 | H = W = self.patch_embed.grid_size
478 |
479 | # embed tokens (divide to temporal)
480 |
481 | # x: [N, T * L_keep, C], e.g. [4, 392, 1024] when 49 spatial tokens are kept per temporal slice
482 | 
483 | # reshape to [N, T, L_keep, C], e.g. [4, 8, 49, 1024]; 49 and 1024 are hard-coded for this configuration
484 | x = x.view(N, T, 49, 1024)
485 |
486 | cls_ = []
487 | for i in range(T):
488 | x_t = x[:,i,:,:]
489 |
490 | if self.cls_embed:
491 | decoder_cls_token = self.decoder_prj_cls_token
492 | decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1)
493 | x_t = torch.cat((decoder_cls_tokens, x_t), dim=1)
494 |
495 | x_t_cls = self.decoder_prj(x_t) #vit
496 | cls_.append(x_t_cls)
497 | return cls_
498 |
499 | def forward_decoder(self, x, ids_restore):
500 | N = x.shape[0]
501 | T = self.patch_embed.t_grid_size
502 | H = W = self.patch_embed.grid_size
503 |
504 | # embed tokens
505 | x = self.decoder_embed(x)
506 | C = x.shape[-1]
507 |
508 | # append mask tokens to sequence
509 | mask_tokens = self.mask_token.repeat(N, T * H * W + 0 - x.shape[1], 1)
510 | x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
511 | x_ = x_.view([N, T * H * W, C])
512 | x_ = torch.gather(
513 | x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x_.shape[2])
514 | ) # unshuffle
515 | x = x_.view([N, T * H * W, C])
516 | # append cls token
517 | if self.cls_embed:
518 | decoder_cls_token = self.decoder_cls_token
519 | decoder_cls_tokens = decoder_cls_token.expand(x.shape[0], -1, -1)
520 | x = torch.cat((decoder_cls_tokens, x), dim=1)
521 |
522 | if self.sep_pos_embed:
523 | decoder_pos_embed = self.decoder_pos_embed_spatial.repeat(
524 | 1, self.input_size[0], 1
525 | ) + torch.repeat_interleave(
526 | self.decoder_pos_embed_temporal,
527 | self.input_size[1] * self.input_size[2],
528 | dim=1,
529 | )
530 | if self.cls_embed:
531 | decoder_pos_embed = torch.cat(
532 | [
533 | self.decoder_pos_embed_class.expand(
534 | decoder_pos_embed.shape[0], -1, -1
535 | ),
536 | decoder_pos_embed,
537 | ],
538 | 1,
539 | )
540 | else:
541 | decoder_pos_embed = self.decoder_pos_embed[:, :, :]
542 |
543 | # add pos embed
544 | x = x + decoder_pos_embed
545 |
546 | attn = self.decoder_blocks[0].attn
547 | requires_t_shape = hasattr(attn, "requires_t_shape") and attn.requires_t_shape
548 | if requires_t_shape:
549 | x = x.view([N, T, H * W, C])
550 |
551 | # apply Transformer blocks
552 | for blk in self.decoder_blocks:
553 | x = blk(x)
554 | x = self.decoder_norm(x)
555 |
556 | # predictor projection
557 | x = self.decoder_pred(x)
558 |
559 | if requires_t_shape:
560 | x = x.view([N, T * H * W, -1])
561 |
562 | if self.cls_embed:
563 | # remove cls token
564 | x = x[:, 1:, :]
565 | else:
566 | x = x[:, :, :]
567 |
568 | return x
569 |
570 | def forward_loss(self, imgs, pred, mask):
571 | """
572 | imgs: [N, 3, T, H, W]
573 | pred: [N, t*h*w, u*p*p*3]
574 | mask: [N, t*h*w], 0 is keep, 1 is remove,
575 | """
576 | _imgs = torch.index_select(
577 | imgs,
578 | 2,
579 | torch.linspace(
580 | 0,
581 | imgs.shape[2] - 1,
582 | self.pred_t_dim,
583 | )
584 | .long()
585 | .to(imgs.device),
586 | )
587 | target = self.patchify(_imgs)
588 | if self.norm_pix_loss:
589 | mean = target.mean(dim=-1, keepdim=True)
590 | var = target.var(dim=-1, keepdim=True)
591 | target = (target - mean) / (var + 1.0e-6) ** 0.5
592 |
593 | loss = (pred - target) ** 2
594 | loss = loss.mean(dim=-1) # [N, L], mean loss per patch
595 | mask = mask.view(loss.shape)
596 |
597 | loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
598 |
599 | return loss
600 |
601 | def forward(self, imgs, mask_ratio=0.75):
602 | latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
603 |
604 | cls_tokens = self.forward_prj(latent, ids_restore)
605 |
606 | similarity_map = self.self_similarity(cls_tokens)
607 |
608 | anchor, positive, negative = self.triplet_sampling(similarity_map, cls_tokens)
609 |
610 | # triplet sampling
611 | triplet_loss = self.triplet_loss(anchor, positive, negative)
612 |
613 | pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3]
614 | loss = self.forward_loss(imgs, pred, mask)
615 |
616 | loss = loss + triplet_loss
617 | return loss, pred, mask
618 |
619 |
620 | def mae_vit_base_patch16(**kwargs):
621 | model = MaskedAutoencoderViT(
622 | patch_size=16,
623 | embed_dim=768,
624 | depth=12,
625 | num_heads=12,
626 | mlp_ratio=4,
627 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
628 | **kwargs,
629 | )
630 | return model
631 |
632 |
633 | def mae_vit_large_patch16(**kwargs):
634 | model = MaskedAutoencoderViT(
635 | patch_size=16,
636 | embed_dim=1024,
637 | depth=24,
638 | num_heads=16,
639 | mlp_ratio=4,
640 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
641 | **kwargs,
642 | )
643 | return model
644 |
645 |
646 | def mae_vit_huge_patch14(**kwargs):
647 | model = MaskedAutoencoderViT(
648 | patch_size=14,
649 | embed_dim=1280,
650 | depth=32,
651 | num_heads=16,
652 | mlp_ratio=4,
653 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
654 | **kwargs,
655 | )
656 | return model
657 |
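# --- Editorial sketch (not in the original repository) -----------------------
# End-to-end smoke test of the pretraining objective: the reconstruction loss
# and the CLS-token triplet loss are combined inside forward(). The batch size
# and cls_embed/sep_pos_embed flags below are illustrative assumptions;
# num_frames=16 and 224x224 inputs match the constructor defaults, and
# mask_ratio=0.75 with the large model's embed_dim=1024 is what the hard-coded
# reshape in forward_prj() expects.
def _example_forward_pass():
    model = mae_vit_large_patch16(num_frames=16, cls_embed=True, sep_pos_embed=True)
    clip = torch.randn(2, 3, 16, 224, 224)           # (N, C, T, H, W)
    loss, pred, mask = model(clip, mask_ratio=0.75)  # scalar loss, patch predictions, binary mask
    return loss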
--------------------------------------------------------------------------------
/EchoFM/util/decoder/transform.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 |
6 | import math
7 |
8 | # import cv2
9 | import random
10 |
11 | import numpy as np
12 | import torch
13 | import torchvision.transforms.functional as F
14 | from PIL import Image
15 | from torchvision import transforms
16 |
17 | from .rand_augment import rand_augment_transform
18 |
19 | _pil_interpolation_to_str = {
20 | Image.NEAREST: "PIL.Image.NEAREST",
21 | Image.BILINEAR: "PIL.Image.BILINEAR",
22 | Image.BICUBIC: "PIL.Image.BICUBIC",
23 | Image.LANCZOS: "PIL.Image.LANCZOS",
24 | Image.HAMMING: "PIL.Image.HAMMING",
25 | Image.BOX: "PIL.Image.BOX",
26 | }
27 |
28 |
29 | _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
30 |
31 |
32 | def _pil_interp(method):
33 | if method == "bicubic":
34 | return Image.BICUBIC
35 | elif method == "lanczos":
36 | return Image.LANCZOS
37 | elif method == "hamming":
38 | return Image.HAMMING
39 | else:
40 | return Image.BILINEAR
41 |
42 |
43 | def random_short_side_scale_jitter(
44 | images, min_size, max_size, inverse_uniform_sampling=False
45 | ):
46 | """
47 | Perform spatial short-side scale jittering on the given images.
48 | Args:
49 | images (tensor): images to perform scale jitter. Dimension is
50 | `num frames` x `channel` x `height` x `width`.
51 | min_size (int): the minimal size to scale the frames.
52 | max_size (int): the maximal size to scale the frames.
53 | inverse_uniform_sampling (bool): if True, sample uniformly in
54 | [1 / max_size, 1 / min_size] and take the reciprocal to get the
55 | size. If False, take a uniform sample from [min_size, max_size].
56 | Returns:
57 | (tensor): the scaled images with dimension of
58 | `num frames` x `channel` x `new height` x `new width`.
59 | """
60 | if inverse_uniform_sampling:
61 | size = int(round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)))
62 | else:
63 | size = int(round(np.random.uniform(min_size, max_size)))
64 |
65 | height = images.shape[2]
66 | width = images.shape[3]
67 | if (width <= height and width == size) or (height <= width and height == size):
68 | return images
69 | new_width = size
70 | new_height = size
71 | if width < height:
72 | new_height = int(math.floor((float(height) / width) * size))
73 | else:
74 | new_width = int(math.floor((float(width) / height) * size))
75 | return torch.nn.functional.interpolate(
76 | images,
77 | size=(new_height, new_width),
78 | mode="bilinear",
79 | align_corners=False,
80 | )
81 |
82 |
83 | def random_crop(images, size):
84 | """
85 | Perform random spatial crop on the given images.
86 | Args:
87 | images (tensor): images to perform random crop. The dimension is
88 | `num frames` x `channel` x `height` x `width`.
89 | size (int): the size of height and width to crop on the image.
90 | Returns:
91 | cropped (tensor): cropped images with dimension of
92 | `num frames` x `channel` x `size` x `size`.
93 | """
94 | if images.shape[2] == size and images.shape[3] == size:
95 | return images
96 | height = images.shape[2]
97 | width = images.shape[3]
98 | y_offset = 0
99 | if height > size:
100 | y_offset = int(np.random.randint(0, height - size))
101 | x_offset = 0
102 | if width > size:
103 | x_offset = int(np.random.randint(0, width - size))
104 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
105 | return cropped
106 |
107 |
108 | def horizontal_flip(prob, images):
109 | """
110 | Perform horizontal flip on the given images.
111 | Args:
112 | prob (float): probability of flipping the images.
113 | images (tensor): images to perform horizontal flip, the dimension is
114 | `num frames` x `channel` x `height` x `width`.
115 | Returns:
116 | images (tensor): images with dimension of
117 | `num frames` x `channel` x `height` x `width`.
118 | """
119 | if np.random.uniform() < prob:
120 | images = images.flip((-1))
121 | return images
122 |
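# A minimal sketch of how the three train-time spatial augmentations above can
# be chained; the clip shape and jitter bounds are illustrative, not values
# used by this repository.
def _example_spatial_train_augment():
    clip = torch.rand(16, 3, 256, 320)  # T x C x H x W
    clip = random_short_side_scale_jitter(clip, min_size=224, max_size=288)
    clip = random_crop(clip, 224)
    clip = horizontal_flip(0.5, clip)
    return clip  # 16 x 3 x 224 x 224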
123 |
124 | def uniform_crop(images, size, spatial_idx, scale_size=None):
125 | """
126 | Perform uniform spatial sampling on the images.
127 | Args:
128 | images (tensor): images to perform uniform crop. The dimension is
129 | `num frames` x `channel` x `height` x `width`.
130 | size (int): size of height and width to crop the images.
131 | spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
132 | is larger than height. Or 0, 1, or 2 for top, center, and bottom
133 | crop if height is larger than width.
134 | scale_size (int): optional. If not None, resize the images to scale_size before
135 | performing any crop.
136 | Returns:
137 | cropped (tensor): images with dimension of
138 | `num frames` x `channel` x `size` x `size`.
139 | """
140 | assert spatial_idx in [0, 1, 2]
141 | ndim = len(images.shape)
142 | if ndim == 3:
143 | images = images.unsqueeze(0)
144 | height = images.shape[2]
145 | width = images.shape[3]
146 |
147 | if scale_size is not None:
148 | if width <= height:
149 | width, height = scale_size, int(height / width * scale_size)
150 | else:
151 | width, height = int(width / height * scale_size), scale_size
152 | images = torch.nn.functional.interpolate(
153 | images,
154 | size=(height, width),
155 | mode="bilinear",
156 | align_corners=False,
157 | )
158 |
159 | y_offset = int(math.ceil((height - size) / 2))
160 | x_offset = int(math.ceil((width - size) / 2))
161 |
162 | if height > width:
163 | if spatial_idx == 0:
164 | y_offset = 0
165 | elif spatial_idx == 2:
166 | y_offset = height - size
167 | else:
168 | if spatial_idx == 0:
169 | x_offset = 0
170 | elif spatial_idx == 2:
171 | x_offset = width - size
172 | cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
173 | if ndim == 3:
174 | cropped = cropped.squeeze(0)
175 | return cropped
176 |
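# A minimal sketch of test-time spatial sampling with uniform_crop: three crops
# (left/top, center, right/bottom) of the same clip; shapes are illustrative.
def _example_uniform_crops():
    clip = torch.rand(16, 3, 256, 320)  # T x C x H x W
    return [uniform_crop(clip, size=224, spatial_idx=idx) for idx in (0, 1, 2)]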
177 |
178 | def blend(images1, images2, alpha):
179 | """
180 | Blend two images with a given weight alpha.
181 | Args:
182 | images1 (tensor): the first images to be blended, the dimension is
183 | `num frames` x `channel` x `height` x `width`.
184 | images2 (tensor): the second images to be blended, the dimension is
185 | `num frames` x `channel` x `height` x `width`.
186 | alpha (float): the blending weight.
187 | Returns:
188 | (tensor): blended images, the dimension is
189 | `num frames` x `channel` x `height` x `width`.
190 | """
191 | return images1 * alpha + images2 * (1 - alpha)
192 |
193 |
194 | def grayscale(images):
195 | """
196 | Get the grayscale for the input images. The channels of images should be
197 | in order BGR.
198 | Args:
199 | images (tensor): the input images for getting grayscale. Dimension is
200 | `num frames` x `channel` x `height` x `width`.
201 | Returns:
202 | img_gray (tensor): grayscale images, the dimension is
203 | `num frames` x `channel` x `height` x `width`.
204 | """
205 | # R -> 0.299, G -> 0.587, B -> 0.114.
206 | img_gray = images.clone()  # copy; avoids the torch.tensor(tensor) copy warning
207 | gray_channel = 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
208 | img_gray[:, 0] = gray_channel
209 | img_gray[:, 1] = gray_channel
210 | img_gray[:, 2] = gray_channel
211 | return img_gray
212 |
213 |
214 | def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
215 | """
216 | Perform color jittering on the input images. The channels of images
217 | should be in order BGR.
218 | Args:
219 | images (tensor): images to perform color jitter. Dimension is
220 | `num frames` x `channel` x `height` x `width`.
221 | img_brightness (float): jitter ratio for brightness.
222 | img_contrast (float): jitter ratio for contrast.
223 | img_saturation (float): jitter ratio for saturation.
224 | Returns:
225 | images (tensor): the jittered images, the dimension is
226 | `num frames` x `channel` x `height` x `width`.
227 | """
228 |
229 | jitter = []
230 | if img_brightness != 0:
231 | jitter.append("brightness")
232 | if img_contrast != 0:
233 | jitter.append("contrast")
234 | if img_saturation != 0:
235 | jitter.append("saturation")
236 |
237 | if len(jitter) > 0:
238 | order = np.random.permutation(np.arange(len(jitter)))
239 | for idx in range(0, len(jitter)):
240 | if jitter[order[idx]] == "brightness":
241 | images = brightness_jitter(img_brightness, images)
242 | elif jitter[order[idx]] == "contrast":
243 | images = contrast_jitter(img_contrast, images)
244 | elif jitter[order[idx]] == "saturation":
245 | images = saturation_jitter(img_saturation, images)
246 | return images
247 |
248 |
249 | def brightness_jitter(var, images):
250 | """
251 | Perform brightness jittering on the input images. The channels of images
252 | should be in order BGR.
253 | Args:
254 | var (float): jitter ratio for brightness.
255 | images (tensor): images to perform color jitter. Dimension is
256 | `num frames` x `channel` x `height` x `width`.
257 | Returns:
258 | images (tensor): the jittered images, the dimension is
259 | `num frames` x `channel` x `height` x `width`.
260 | """
261 | alpha = 1.0 + np.random.uniform(-var, var)
262 |
263 | img_bright = torch.zeros(images.shape)
264 | images = blend(images, img_bright, alpha)
265 | return images
266 |
267 |
268 | def contrast_jitter(var, images):
269 | """
270 | Perform contrast jittering on the input images. The channels of images
271 | should be in order BGR.
272 | Args:
273 | var (float): jitter ratio for contrast.
274 | images (tensor): images to perform color jitter. Dimension is
275 | `num frames` x `channel` x `height` x `width`.
276 | Returns:
277 | images (tensor): the jittered images, the dimension is
278 | `num frames` x `channel` x `height` x `width`.
279 | """
280 | alpha = 1.0 + np.random.uniform(-var, var)
281 |
282 | img_gray = grayscale(images)
283 | img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
284 | images = blend(images, img_gray, alpha)
285 | return images
286 |
287 |
288 | def saturation_jitter(var, images):
289 | """
290 | Perform saturation jittering on the input images. The channels of images
291 | should be in order BGR.
292 | Args:
293 | var (float): jitter ratio for saturation.
294 | images (tensor): images to perform color jitter. Dimension is
295 | `num frames` x `channel` x `height` x `width`.
296 | Returns:
297 | images (tensor): the jittered images, the dimension is
298 | `num frames` x `channel` x `height` x `width`.
299 | """
300 | alpha = 1.0 + np.random.uniform(-var, var)
301 | img_gray = grayscale(images)
302 | images = blend(images, img_gray, alpha)
303 |
304 | return images
305 |
306 |
307 | def lighting_jitter(images, alphastd, eigval, eigvec):
308 | """
309 | Perform AlexNet-style PCA jitter on the given images.
310 | Args:
311 | images (tensor): images to perform lighting jitter. Dimension is
312 | `num frames` x `channel` x `height` x `width`.
313 | alphastd (float): jitter ratio for PCA jitter.
314 | eigval (list): eigenvalues for PCA jitter.
315 | eigvec (list[list]): eigenvectors for PCA jitter.
316 | Returns:
317 | out_images (tensor): the jittered images, the dimension is
318 | `num frames` x `channel` x `height` x `width`.
319 | """
320 | if alphastd == 0:
321 | return images
322 | # generate alpha1, alpha2, alpha3.
323 | alpha = np.random.normal(0, alphastd, size=(1, 3))
324 | eig_vec = np.array(eigvec)
325 | eig_val = np.reshape(eigval, (1, 3))
326 | rgb = np.sum(
327 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
328 | axis=1,
329 | )
330 | out_images = torch.zeros_like(images)
331 | if len(images.shape) == 3:
332 | # C H W
333 | channel_dim = 0
334 | elif len(images.shape) == 4:
335 | # T C H W
336 | channel_dim = 1
337 | else:
338 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
339 |
340 | for idx in range(images.shape[channel_dim]):
341 | # C H W
342 | if len(images.shape) == 3:
343 | out_images[idx] = images[idx] + rgb[2 - idx]
344 | # T C H W
345 | elif len(images.shape) == 4:
346 | out_images[:, idx] = images[:, idx] + rgb[2 - idx]
347 | else:
348 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
349 |
350 | return out_images
351 |
352 |
353 | def color_normalization(images, mean, stddev):
354 | """
355 | Perform color normalization on the given images.
356 | Args:
357 | images (tensor): images to perform color normalization. Dimension is
358 | `num frames` x `channel` x `height` x `width`.
359 | mean (list): mean values for normalization.
360 | stddev (list): standard deviations for normalization.
361 |
362 | Returns:
363 | out_images (tensor): the normalized images, the dimension is
364 | `num frames` x `channel` x `height` x `width`.
365 | """
366 | if len(images.shape) == 3:
367 | assert len(mean) == images.shape[0], "channel mean not computed properly"
368 | assert len(stddev) == images.shape[0], "channel stddev not computed properly"
369 | elif len(images.shape) == 4:
370 | assert len(mean) == images.shape[1], "channel mean not computed properly"
371 | assert len(stddev) == images.shape[1], "channel stddev not computed properly"
372 | else:
373 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
374 |
375 | out_images = torch.zeros_like(images)
376 | for idx in range(len(mean)):
377 | # C H W
378 | if len(images.shape) == 3:
379 | out_images[idx] = (images[idx] - mean[idx]) / stddev[idx]
380 | elif len(images.shape) == 4:
381 | out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
382 | else:
383 | raise NotImplementedError(f"Unsupported dimension {len(images.shape)}")
384 | return out_images
385 |
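# A minimal sketch combining the photometric jitter helpers with per-channel
# normalization; the jitter strengths and mean/stddev values are illustrative
# placeholders, not the statistics used by this repository.
def _example_color_pipeline():
    clip = torch.rand(16, 3, 224, 224)  # T x C x H x W, values in [0, 1]
    clip = color_jitter(clip, img_brightness=0.4, img_contrast=0.4, img_saturation=0.4)
    clip = color_normalization(clip, mean=[0.45, 0.45, 0.45], stddev=[0.225, 0.225, 0.225])
    return clip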
386 |
387 | def _get_param_spatial_crop(
388 | scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False
389 | ):
390 | """
391 | Given scale, ratio, height and width, return sampled coordinates of the videos.
392 | """
393 | for _ in range(num_repeat):
394 | area = height * width
395 | target_area = random.uniform(*scale) * area
396 | if log_scale:
397 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
398 | aspect_ratio = math.exp(random.uniform(*log_ratio))
399 | else:
400 | aspect_ratio = random.uniform(*ratio)
401 |
402 | w = int(round(math.sqrt(target_area * aspect_ratio)))
403 | h = int(round(math.sqrt(target_area / aspect_ratio)))
404 |
405 | if np.random.uniform() < 0.5 and switch_hw:
406 | w, h = h, w
407 |
408 | if 0 < w <= width and 0 < h <= height:
409 | i = random.randint(0, height - h)
410 | j = random.randint(0, width - w)
411 | return i, j, h, w
412 |
413 | # Fallback to central crop
414 | in_ratio = float(width) / float(height)
415 | if in_ratio < min(ratio):
416 | w = width
417 | h = int(round(w / min(ratio)))
418 | elif in_ratio > max(ratio):
419 | h = height
420 | w = int(round(h * max(ratio)))
421 | else: # whole image
422 | w = width
423 | h = height
424 | i = (height - h) // 2
425 | j = (width - w) // 2
426 | return i, j, h, w
427 |
428 |
429 | def random_resized_crop(
430 | images,
431 | target_height,
432 | target_width,
433 | scale=(0.8, 1.0),
434 | ratio=(3.0 / 4.0, 4.0 / 3.0),
435 | ):
436 | """
437 | Crop the given images to a random size and aspect ratio. A crop of random
438 | area (default: 0.8 to 1.0 of the original area) and a random aspect
439 | ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This
440 | crop is finally resized to the given size. This is popularly used to train
441 | Inception networks.
442 |
443 | Args:
444 | images: Images to perform resizing and cropping.
445 | target_height: Desired height after cropping.
446 | target_width: Desired width after cropping.
447 | scale: Scale range of Inception-style area based random resizing.
448 | ratio: Aspect ratio range of Inception-style area based random resizing.
449 | """
450 |
451 | height = images.shape[2]
452 | width = images.shape[3]
453 |
454 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
455 | cropped = images[:, :, i : i + h, j : j + w]
456 | return torch.nn.functional.interpolate(
457 | cropped,
458 | size=(target_height, target_width),
459 | mode="bilinear",
460 | align_corners=False,
461 | )
462 |
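# A minimal sketch of Inception-style cropping for a whole clip; the input
# shape and target resolution are illustrative.
def _example_random_resized_crop():
    clip = torch.rand(16, 3, 256, 320)  # T x C x H x W
    return random_resized_crop(clip, target_height=224, target_width=224)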
463 |
464 | def random_resized_crop_with_shift(
465 | images,
466 | target_height,
467 | target_width,
468 | scale=(0.8, 1.0),
469 | ratio=(3.0 / 4.0, 4.0 / 3.0),
470 | ):
471 | """
472 | This is similar to random_resized_crop. However, it samples two different
473 | boxes (for cropping) for the first and last frame. It then linearly
474 | interpolates the two boxes for other frames.
475 |
476 | Args:
477 | images: Images to perform resizing and cropping.
478 | target_height: Desired height after cropping.
479 | target_width: Desired width after cropping.
480 | scale: Scale range of Inception-style area based random resizing.
481 | ratio: Aspect ratio range of Inception-style area based random resizing.
482 | """
483 | t = images.shape[1]
484 | height = images.shape[2]
485 | width = images.shape[3]
486 |
487 | i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width)
488 | i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width)
489 | i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()]
490 | j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()]
491 | h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()]
492 | w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()]
493 | out = torch.zeros((3, t, target_height, target_width))
494 | for ind in range(t):
495 | out[:, ind : ind + 1, :, :] = torch.nn.functional.interpolate(
496 | images[
497 | :,
498 | ind : ind + 1,
499 | i_s[ind] : i_s[ind] + h_s[ind],
500 | j_s[ind] : j_s[ind] + w_s[ind],
501 | ],
502 | size=(target_height, target_width),
503 | mode="bilinear",
504 | align_corners=False,
505 | )
506 | return out
507 |
508 |
509 | def create_random_augment(
510 | input_size,
511 | auto_augment=None,
512 | interpolation="bilinear",
513 | ):
514 | """
515 | Get video randaug transform.
516 |
517 | Args:
518 | input_size: The size of the input video in tuple.
519 | auto_augment: Parameters for randaug. An example:
520 | "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number
521 | of operations to apply).
522 | interpolation: Interpolation method.
523 | """
524 | if isinstance(input_size, tuple):
525 | img_size = input_size[-2:]
526 | else:
527 | img_size = input_size
528 |
529 | if auto_augment:
530 | assert isinstance(auto_augment, str)
531 | if isinstance(img_size, tuple):
532 | img_size_min = min(img_size)
533 | else:
534 | img_size_min = img_size
535 | aa_params = {"translate_const": int(img_size_min * 0.45)}
536 | if interpolation and interpolation != "random":
537 | aa_params["interpolation"] = _pil_interp(interpolation)
538 | if auto_augment.startswith("rand"):
539 | return transforms.Compose([rand_augment_transform(auto_augment, aa_params)])
540 | raise NotImplementedError
541 |
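# A minimal sketch of building the RandAugment transform; the policy string is
# the example from the docstring above, and applying it to a single PIL frame
# is an assumption (the decoder pipeline may pass lists of frames instead).
def _example_rand_augment(frame: Image.Image) -> Image.Image:
    transform = create_random_augment(
        input_size=(224, 224),
        auto_augment="rand-m7-n4-mstd0.5-inc1",
        interpolation="bicubic",
    )
    return transform(frame)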
542 |
543 | def random_sized_crop_img(
544 | im,
545 | size,
546 | jitter_scale=(0.08, 1.0),
547 | jitter_aspect=(3.0 / 4.0, 4.0 / 3.0),
548 | max_iter=10,
549 | ):
550 | """
551 | Performs Inception-style cropping (used for training).
552 | """
553 | assert len(im.shape) == 3, "random_sized_crop_img currently only supports single images"
554 | h, w = im.shape[1:3]
555 | i, j, h, w = _get_param_spatial_crop(
556 | scale=jitter_scale,
557 | ratio=jitter_aspect,
558 | height=h,
559 | width=w,
560 | num_repeat=max_iter,
561 | log_scale=False,
562 | switch_hw=True,
563 | )
564 | cropped = im[:, i : i + h, j : j + w]
565 | return torch.nn.functional.interpolate(
566 | cropped.unsqueeze(0),
567 | size=(size, size),
568 | mode="bilinear",
569 | align_corners=False,
570 | ).squeeze(0)
571 |
572 |
573 | # The following code is adapted from the timm library; we plan to replace it
574 | # with a dependency on PyTorchVideo.
575 | # https://github.com/facebookresearch/pytorchvideo
576 | class RandomResizedCropAndInterpolation:
577 | """Crop the given PIL Image to random size and aspect ratio with random interpolation.
578 | A crop of random area (default: 0.08 to 1.0 of the original area) and a random
579 | aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
580 | is finally resized to the given size.
581 | This is popularly used to train the Inception networks.
582 | Args:
583 | size: expected output size of each edge
584 | scale: range of the fraction of the original area to crop
585 | ratio: range of aspect ratios of the crop
586 | interpolation: Default: PIL.Image.BILINEAR
587 | """
588 |
589 | def __init__(
590 | self,
591 | size,
592 | scale=(0.08, 1.0),
593 | ratio=(3.0 / 4.0, 4.0 / 3.0),
594 | interpolation="bilinear",
595 | ):
596 | if isinstance(size, tuple):
597 | self.size = size
598 | else:
599 | self.size = (size, size)
600 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
601 | print("range should be of kind (min, max)")
602 |
603 | if interpolation == "random":
604 | self.interpolation = _RANDOM_INTERPOLATION
605 | else:
606 | self.interpolation = _pil_interp(interpolation)
607 | self.scale = scale
608 | self.ratio = ratio
609 |
610 | @staticmethod
611 | def get_params(img, scale, ratio):
612 | """Get parameters for ``crop`` for a random sized crop.
613 | Args:
614 | img (PIL Image): Image to be cropped.
615 | scale (tuple): range of the fraction of the original area to crop
616 | ratio (tuple): range of aspect ratios of the crop
617 | Returns:
618 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
619 | sized crop.
620 | """
621 | area = img.size[0] * img.size[1]
622 |
623 | for _ in range(10):
624 | target_area = random.uniform(*scale) * area
625 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
626 | aspect_ratio = math.exp(random.uniform(*log_ratio))
627 |
628 | w = int(round(math.sqrt(target_area * aspect_ratio)))
629 | h = int(round(math.sqrt(target_area / aspect_ratio)))
630 |
631 | if w <= img.size[0] and h <= img.size[1]:
632 | i = random.randint(0, img.size[1] - h)
633 | j = random.randint(0, img.size[0] - w)
634 | return i, j, h, w
635 |
636 | # Fallback to central crop
637 | in_ratio = img.size[0] / img.size[1]
638 | if in_ratio < min(ratio):
639 | w = img.size[0]
640 | h = int(round(w / min(ratio)))
641 | elif in_ratio > max(ratio):
642 | h = img.size[1]
643 | w = int(round(h * max(ratio)))
644 | else: # whole image
645 | w = img.size[0]
646 | h = img.size[1]
647 | i = (img.size[1] - h) // 2
648 | j = (img.size[0] - w) // 2
649 | return i, j, h, w
650 |
651 | def __call__(self, img):
652 | """
653 | Args:
654 | img (PIL Image): Image to be cropped and resized.
655 | Returns:
656 | PIL Image: Randomly cropped and resized image.
657 | """
658 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
659 | if isinstance(self.interpolation, (tuple, list)):
660 | interpolation = random.choice(self.interpolation)
661 | else:
662 | interpolation = self.interpolation
663 | return F.resized_crop(img, i, j, h, w, self.size, interpolation)
664 |
665 | def __repr__(self):
666 | if isinstance(self.interpolation, (tuple, list)):
667 | interpolate_str = " ".join(
668 | [_pil_interpolation_to_str[x] for x in self.interpolation]
669 | )
670 | else:
671 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
672 | format_string = self.__class__.__name__ + "(size={0}".format(self.size)
673 | format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale))
674 | format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio))
675 | format_string += ", interpolation={0})".format(interpolate_str)
676 | return format_string
677 |
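# A minimal sketch of the class above: crop-and-resize a single PIL frame with
# a randomly chosen interpolation; the output size is illustrative.
def _example_random_resized_crop_pil(frame: Image.Image) -> Image.Image:
    transform = RandomResizedCropAndInterpolation(size=224, interpolation="random")
    return transform(frame)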
--------------------------------------------------------------------------------