├── static ├── readme.txt ├── framework.png └── figure-4-updated.png ├── requirements.txt ├── README.md ├── example.py ├── vit.py ├── train.py ├── evaluate.py └── utils.py /static/readme.txt: -------------------------------------------------------------------------------- 1 | Static asset hosting for the repository. 2 | -------------------------------------------------------------------------------- /static/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghost-signal/myna/HEAD/static/framework.png -------------------------------------------------------------------------------- /static/figure-4-updated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ghost-signal/myna/HEAD/static/figure-4-updated.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | libauc>=1.3.0 2 | mir-eval>=0.7 3 | nnAudio>=0.3.3 4 | numpy>=1.25.2 5 | scikit_learn>=1.5.1 6 | torch>=2.0.0 7 | torchvision>=0.15.1 8 | tqdm>=4.66.5 9 | vit_pytorch>=1.6.5 10 | wandb>=0.14.0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Myna: Masking-Based Contrastive Learning of Musical Representations 2 | 3 | Welcome to the official repository for **Myna: Masking-Based Contrastive Learning of Musical Representations**. This repository contains the codebase and public models used in the paper. 4 | 5 | ![The Myna Framework](https://raw.githubusercontent.com/ghost-signal/myna/refs/heads/main/static/framework.png) 6 | 7 | ## Model Checkpoints 8 | 9 | We provide pretrained model checkpoints for different configurations of the Myna model: 10 | 11 | - [Myna-Base](https://drive.google.com/file/d/1JZgR9zqTHz7a0To4X6PWHOuOlM8DBwvL/view?usp=sharing) 12 | - [Myna-Vertical](https://drive.google.com/file/d/1C8ZjL29Y_GII1v808x0k0-tv5rR7GWGA/view?usp=sharing) 13 | - [Myna-Hybrid](https://drive.google.com/file/d/1-U4BmDVOf2kllsXY0H3R1GNat9kZaWrp/view?usp=sharing) 14 | 15 | Feel free to download and use these models for your tasks. 16 | 17 | ## Inference Example 18 | 19 | To get started with inference, you can refer to the minimal example provided in the script [example.py](https://github.com/ghost-signal/myna/blob/main/example.py). 20 | 21 | 22 | ## Citation 23 | 24 | If you use this code or the models in your research, please cite our work: 25 | 26 | ``` 27 | @article{myna2025, 28 | title={Myna: Masking-Based Contrastive Learning of Musical Representations}, 29 | author={Anonymous}, 30 | journal={}, 31 | year={2025} 32 | } 33 | ``` 34 | 35 | --- 36 | 37 | ## License 38 | 39 | This repository is licensed under [MIT License](https://mit-license.org/). 
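---

## Quick Start: Extracting Embeddings

The snippet below is a condensed sketch of the inference flow in [example.py](https://github.com/ghost-signal/myna/blob/main/example.py). It assumes a square-patch checkpoint (e.g., Myna-Base) downloaded to a local path; the filename `myna-base.pth` is a placeholder. Audio loading, resampling, and chunking the spectrogram into fixed-size frames are handled in `example.py`.

```python
from argparse import Namespace
import torch

from utils import get_n_frames, load_model
from vit import SimpleViT

# number of spectrogram frames corresponding to 50,000 audio samples at 16 kHz
n_frames = get_n_frames(n_samples=50000, args=Namespace(sr=16000, patch_size=16))

# ViT configuration used in example.py (square 16x16 patches)
model = SimpleViT(
    image_size=(128, n_frames), channels=1, patch_size=16,
    num_classes=50, dim=384, depth=12, heads=6, mlp_dim=1536,
)
load_model(model, 'myna-base.pth', 'cpu', ignore_layers=['linear_head'], verbose=True)
model.linear_head = torch.nn.Identity()  # return embeddings instead of class logits

# `batch` should be a (B, 1, 128, n_frames) mel spectrogram tensor (see example.py)
# embeddings = model(batch)
```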
40 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Minimal script example for model inference 3 | ''' 4 | 5 | from argparse import Namespace 6 | from nnAudio.features.mel import MelSpectrogram 7 | import torch 8 | import torchaudio 9 | import torchaudio.transforms as T 10 | 11 | from utils import get_n_frames, load_model 12 | from vit import SimpleViT 13 | 14 | FILENAME = 'your_file_here.wav' # file to get embeddings from 15 | MODEL_PATH = 'myna-hybrid.pth' # path to model checkpoint 16 | MODEL_TYPE = 'hybrid' # 'square', 'vertical', or 'hybrid' 17 | HYBRID_MODE = True # concatenate embeddings for hybrid models; disable to only use square patches 18 | N_SAMPLES = 50000 # number of samples per embedding 19 | MYNA_SR = 16000 # myna constant 20 | 21 | 22 | def load_and_preprocess_audio(filename: str): 23 | # load audio 24 | signal, sr = torchaudio.load(filename) 25 | 26 | # make mono if necessary 27 | if signal.shape[0] > 1: 28 | signal = signal.mean(dim=0, keepdim=True) 29 | 30 | # resample to target sample rate 31 | if sr != MYNA_SR: 32 | resampler = T.Resample(orig_freq=sr, new_freq=MYNA_SR) 33 | signal = resampler(signal) 34 | 35 | # sanity check 36 | assert signal.dim() == 2 37 | 38 | # compute spectrogram 39 | mel_spec = MelSpectrogram(sr=16000, n_mels=128, verbose=False) 40 | ms = mel_spec(signal) 41 | 42 | return ms 43 | 44 | def batch_spectrogram(ms: torch.Tensor, n_frames: int): 45 | # sanity check 46 | assert ms.dim() == 3 and ms.shape[0] == 1 47 | 48 | # discard excess frames 49 | num_chunks = ms.shape[-1] // n_frames 50 | ms = ms[:, :, :num_chunks * n_frames] 51 | 52 | # split the tensor into chunks and stack them 53 | chunks = torch.chunk(ms, num_chunks, dim=2) 54 | batch = torch.stack(chunks) 55 | 56 | return batch 57 | 58 | 59 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 60 | patch_size = (128, 2) if MODEL_TYPE == 'vertical' else 16 61 | 62 | # sanity check 63 | if HYBRID_MODE: 64 | assert MODEL_TYPE == 'hybrid', 'hybrid mode can only be enabled for hybrid model types' 65 | 66 | # number of spectrogram frames to feed into the model 67 | n_frames = get_n_frames( 68 | n_samples=N_SAMPLES, 69 | args=Namespace( 70 | sr=16000, 71 | patch_size=patch_size 72 | ) 73 | ) 74 | 75 | # initialize model first 76 | model = SimpleViT( 77 | image_size=(128, n_frames), 78 | channels=1, 79 | patch_size=patch_size, 80 | num_classes=50, # doesn't matter 81 | dim=384, 82 | depth=12, 83 | heads=6, 84 | mlp_dim=1536, 85 | additional_patch_size=(128, 2) if MODEL_TYPE == 'hybrid' else None 86 | ) 87 | 88 | # now load weights 89 | load_model(model, MODEL_PATH, device, ignore_layers=['linear_head'], verbose=True) 90 | model.linear_head = torch.nn.Identity() 91 | model.hybrid_mode = HYBRID_MODE 92 | 93 | # load and preprocess audio 94 | ms = load_and_preprocess_audio(FILENAME) 95 | ms = batch_spectrogram(ms, n_frames) 96 | 97 | # forward pass 98 | model.eval() 99 | with torch.no_grad(): 100 | embeds = model(ms) 101 | 102 | print(f'Successfully computed embeddings of shape: {embeds.shape}') -------------------------------------------------------------------------------- /vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from the vit_pytorch library: https://github.com/lucidrains/vit-pytorch 3 | ''' 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from einops import rearrange 9 
| from einops.layers.torch import Rearrange 10 | 11 | # helpers 12 | 13 | def pair(t): 14 | return t if isinstance(t, tuple) else (t, t) 15 | 16 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 17 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 18 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 19 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 20 | omega = 1.0 / (temperature ** omega) 21 | 22 | y = y.flatten()[:, None] * omega[None, :] 23 | x = x.flatten()[:, None] * omega[None, :] 24 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 25 | return pe.type(dtype) 26 | 27 | # classes 28 | 29 | class FeedForward(nn.Module): 30 | def __init__(self, dim, hidden_dim): 31 | super().__init__() 32 | self.net = nn.Sequential( 33 | nn.LayerNorm(dim), 34 | nn.Linear(dim, hidden_dim), 35 | nn.GELU(), 36 | nn.Linear(hidden_dim, dim), 37 | ) 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | class Attention(nn.Module): 42 | def __init__(self, dim, heads = 8, dim_head = 64): 43 | super().__init__() 44 | inner_dim = dim_head * heads 45 | self.heads = heads 46 | self.scale = dim_head ** -0.5 47 | self.norm = nn.LayerNorm(dim) 48 | 49 | self.attend = nn.Softmax(dim = -1) 50 | 51 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 52 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 53 | 54 | def forward(self, x): 55 | x = self.norm(x) 56 | 57 | qkv = self.to_qkv(x).chunk(3, dim = -1) 58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 59 | 60 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 61 | 62 | attn = self.attend(dots) 63 | 64 | out = torch.matmul(attn, v) 65 | out = rearrange(out, 'b h n d -> b n (h d)') 66 | return self.to_out(out) 67 | 68 | class Transformer(nn.Module): 69 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 70 | super().__init__() 71 | self.norm = nn.LayerNorm(dim) 72 | self.layers = nn.ModuleList([]) 73 | for _ in range(depth): 74 | self.layers.append(nn.ModuleList([ 75 | Attention(dim, heads = heads, dim_head = dim_head), 76 | FeedForward(dim, mlp_dim) 77 | ])) 78 | def forward(self, x): 79 | for attn, ff in self.layers: 80 | x = attn(x) + x 81 | x = ff(x) + x 82 | return self.norm(x) 83 | 84 | class SimpleViT(nn.Module): 85 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, additional_patch_size = None, hybrid_mode: bool = False): 86 | super().__init__() 87 | self.hybrid_mode = hybrid_mode 88 | image_height, image_width = pair(image_size) 89 | patch_height, patch_width = pair(patch_size) 90 | 91 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
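        # Patch geometry: inputs in this repository are (1, n_mels, n_frames) mel spectrograms
        # (n_mels=128 by default), so a square patch_size of 16 yields (128/16) * (n_frames/16)
        # tokens, while a "vertical" patch size of (128, 2) yields one full-height token every
        # 2 frames. When additional_patch_size is set, a second patch/positional embedding pair
        # is built below; it can be swapped in with toggle_embeddings() or combined with the
        # primary one via hybrid_mode in forward().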
92 | 93 | self.additional_patch_size = additional_patch_size 94 | if additional_patch_size: 95 | patch_height_b, patch_width_b = pair(additional_patch_size) 96 | patch_dim_b = channels * patch_height_b * patch_width_b 97 | 98 | self.to_patch_embedding_b, self.pos_embedding_b = self._make_embeddings( 99 | patch_height_b, patch_width_b, patch_dim_b, dim, image_height, image_width 100 | ) 101 | 102 | patch_dim = channels * patch_height * patch_width 103 | 104 | self.to_patch_embedding, self.pos_embedding = self._make_embeddings( 105 | patch_height, patch_width, patch_dim, dim, image_height, image_width 106 | ) 107 | 108 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 109 | 110 | self.pool = "mean" 111 | self.to_latent = nn.Identity() 112 | 113 | self.linear_head = nn.Linear(dim, num_classes) 114 | 115 | def forward(self, img, recurse=True): 116 | if self.hybrid_mode and recurse: 117 | a = self(img, recurse=False) 118 | self.toggle_embeddings() 119 | b = self(img, recurse=False) 120 | self.toggle_embeddings() 121 | return torch.cat((a, b), dim=-1) 122 | 123 | device = img.device 124 | 125 | x = self.to_patch_embedding(img) 126 | x += self.pos_embedding.to(device, dtype=x.dtype) 127 | 128 | x = self.transformer(x) 129 | x = x.mean(dim = 1) 130 | 131 | x = self.to_latent(x) 132 | return self.linear_head(x) 133 | 134 | def toggle_embeddings(self): 135 | if not self.additional_patch_size: 136 | print('toggle_embeddings() called but no additional patch size provided! Ignoring call.') 137 | return 138 | self.to_patch_embedding, self.to_patch_embedding_b = self.to_patch_embedding_b, self.to_patch_embedding 139 | self.pos_embedding, self.pos_embedding_b = self.pos_embedding_b, self.pos_embedding 140 | 141 | def _make_embeddings(self, patch_height, patch_width, patch_dim, dim, image_height, image_width): 142 | to_patch_embedding = nn.Sequential( 143 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 144 | nn.LayerNorm(patch_dim), 145 | nn.Linear(patch_dim, dim), 146 | nn.LayerNorm(dim), 147 | ) 148 | 149 | pos_embedding = posemb_sincos_2d( 150 | h = image_height // patch_height, 151 | w = image_width // patch_width, 152 | dim = dim, 153 | ) 154 | 155 | return to_patch_embedding, pos_embedding -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # suppress warnings 2 | import warnings 3 | warnings.filterwarnings('ignore') 4 | 5 | import argparse 6 | import os 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from vit import SimpleViT # from vit_pytorch import SimpleViT 10 | import wandb 11 | 12 | from utils import * 13 | 14 | 15 | def main_worker(rank: int, world_size: int, args: argparse.Namespace): 16 | args.rank = rank 17 | args.world_size = world_size 18 | if world_size > 1: 19 | # Initialize process group for distributed training 20 | dist.init_process_group(backend=args.dist_backend, rank=rank, world_size=world_size) 21 | torch.cuda.set_device(rank) 22 | args.device = f'cuda:{rank}' 23 | else: 24 | # Use CPU or a single GPU 25 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | model, criterion, optimizer, train_loader, test_dataset, use_wandb, _ = setup_for_training(rank, world_size, args) 28 | 29 | if world_size > 1: 30 | # Wrap model in DDP 31 | model = DDP(model, device_ids=[rank], output_device=rank) 32 | 33 | best_test_metrics = {} 
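    # The first --train_only_head_epochs epochs train only the head (the backbone is frozen in
    # setup_for_training) and the backbone is unfrozen once that epoch is reached. Testing,
    # metric logging, and checkpointing run on rank 0 only.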
34 | for epoch in range(args.resume_epochs, args.epochs): 35 | if epoch == args.train_only_head_epochs: 36 | freeze_unfreeze_backbone(model, freeze=False) 37 | 38 | train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, epoch, args=args) 39 | 40 | if rank == 0: 41 | test_loss, test_metrics = test(model, test_dataset, criterion, epoch, args=args) 42 | log_metrics(train_loss, test_loss, train_metrics, test_metrics, best_test_metrics, use_wandb) 43 | 44 | # checkpoint model 45 | if rank == 0 and args.checkpoint_dir and (epoch+1) % args.checkpoint_epochs == 0: 46 | save_model(model, args.checkpoint_dir, f'model_epoch_{epoch+1}.pth') 47 | save_optimizer(optimizer, criterion, args.checkpoint_dir, 'optimizer.pth') 48 | 49 | # log best values achieved 50 | if use_wandb: 51 | log_dict = {} 52 | for metric_name, value in best_test_metrics.items(): 53 | log_dict[f'run_best/test_{metric_name}'] = value 54 | 55 | wandb.log(log_dict) 56 | 57 | # print best values 58 | for metric_name, value in best_test_metrics.items(): 59 | print(f'Best {metric_name}: {value}') 60 | 61 | if world_size > 1: 62 | dist.destroy_process_group() 63 | 64 | print('Training complete.') 65 | return best_test_metrics 66 | 67 | 68 | def setup_for_training(rank: int, world_size: int, args: argparse.Namespace): 69 | seed_everything(args.seed + rank) 70 | if rank == 0: 71 | print(f'==> Seed: {args.seed}') 72 | 73 | args.mel_frames = get_n_frames(args.n_samples, args) 74 | 75 | # initialize wandb 76 | use_wandb = False 77 | if rank == 0 and args.wandb: 78 | assert args.run_name is not None, 'wandb run needs to have a name.' 79 | assert args.wandb_project is not None, 'wandb project needs to be defined.' 80 | wandb.init( 81 | project=args.wandb_project, 82 | config=vars(args), 83 | name=args.run_name, 84 | reinit=False 85 | ) 86 | use_wandb = True 87 | 88 | # load dataset 89 | train_dataset, train_loader = get_dataset( 90 | dataroot=os.path.join(args.dataroot, 'train'), 91 | args=args, 92 | distributed=world_size > 1, 93 | rank=rank, 94 | world_size=world_size 95 | ) 96 | test_dataset, _ = get_dataset( 97 | dataroot=os.path.join(args.dataroot, 'test'), 98 | args=args, 99 | drop_last=False 100 | ) 101 | 102 | if rank == 0: 103 | print(f'==> Training dataset contains {len(train_dataset):,} songs.') 104 | print(f'==> Testing dataset contains {len(test_dataset):,} songs.') 105 | print(f'==> Using {args.device}') 106 | 107 | model = SimpleViT( 108 | image_size=(args.n_mels, args.mel_frames), 109 | channels=1, 110 | patch_size=args.patch_size, 111 | num_classes=args.num_outputs, 112 | dim=args.dim, 113 | depth=args.depth, 114 | heads=args.heads, 115 | mlp_dim=args.mlp_dim, 116 | dim_head=args.dim_head, 117 | additional_patch_size=args.additional_patch_size 118 | ) 119 | 120 | if args.proj_head: 121 | add_proj_head(model, args.proj_head) 122 | 123 | if args.task_type.lower() == 'mae': 124 | setup_for_mae(model, args) 125 | 126 | if args.resume: 127 | load_model(model, args.resume, args.device, args.ignore_layers, verbose=(rank == 0)) 128 | 129 | if args.use_other_patch_size: 130 | model.toggle_embeddings() 131 | 132 | if args.eval_hybrid: 133 | model.hybrid_mode = True 134 | args.dim *= 2 135 | 136 | # EXPERIMENTAL 137 | if args.add_layers: 138 | from vit_pytorch.simple_vit import Attention, FeedForward 139 | model.transformer.layers.extend([nn.ModuleList([ 140 | Attention(dim=args.dim, heads=args.heads, dim_head=64), 141 | FeedForward(dim=args.dim, hidden_dim=args.mlp_dim) 142 | ]) for _ in 
range(args.add_layers)]) 143 | 144 | if args.train_only_head_epochs: 145 | freeze_unfreeze_backbone(model, freeze=True) 146 | 147 | # EXPERIMENTAL 148 | if args.unfreeze_last_n_layers: 149 | for layer in model.transformer.layers[-args.unfreeze_last_n_layers:]: 150 | layer.train() 151 | for param in layer.parameters(): 152 | param.requires_grad = True 153 | 154 | model = model.to(args.device) 155 | 156 | criterion = get_criterion(args, N=len(train_dataset)) 157 | optimizer = get_optimizer(model, args) 158 | 159 | if args.resume_optimizer: 160 | load_optimizer(optimizer, criterion, args.resume_optimizer) 161 | 162 | n_params = sum(p.numel() for p in model.parameters()) 163 | n_active = sum(p.numel() for p in model.parameters() if p.requires_grad) 164 | active = f'({n_active:,} active)' if n_params != n_active else '' 165 | print(f'==> Model contains {n_params:,} parameters {active}') 166 | print(f'==> Using {args.optimizer.upper()} optimizer') 167 | 168 | return model, criterion, optimizer, train_loader, test_dataset, use_wandb, wandb 169 | 170 | 171 | def log_metrics(train_loss: float, test_loss: float, train_metrics: dict, test_metrics: dict, best_test_metrics: dict, use_wandb: bool): 172 | # log to Weights and Biases 173 | if use_wandb: 174 | log_dict = { 175 | 'train/loss': train_loss, 176 | 'test/loss': test_loss 177 | } 178 | 179 | for metric_name, value in train_metrics.items(): 180 | if metric_name == 'text': continue 181 | log_dict[f'train/{metric_name}'] = value 182 | for metric_name, value in test_metrics.items(): 183 | if metric_name == 'text': continue 184 | log_dict[f'test/{metric_name}'] = value 185 | 186 | wandb.log(log_dict) 187 | 188 | for metric_name, value in test_metrics.items(): 189 | if metric_name == 'text': continue 190 | if metric_name not in best_test_metrics or value > best_test_metrics[metric_name]: 191 | best_test_metrics[metric_name] = value 192 | 193 | 194 | if __name__ == '__main__': 195 | args = parse_args() 196 | 197 | try: 198 | rank = int(os.environ['LOCAL_RANK']) 199 | world_size = int(os.environ['WORLD_SIZE']) 200 | except KeyError: 201 | rank = 0 202 | world_size = 1 203 | 204 | main_worker(rank=rank, world_size=world_size, args=args) 205 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # suppress warnings 2 | import warnings 3 | warnings.filterwarnings('ignore') 4 | 5 | import argparse 6 | from collections import defaultdict 7 | import json 8 | from itertools import product 9 | from sklearn.preprocessing import StandardScaler 10 | import torch 11 | from torch import nn 12 | from tqdm.auto import trange, tqdm 13 | 14 | from train import setup_for_training 15 | from utils import * 16 | 17 | 18 | # evaluations that consistently give the best results. used to narrow 19 | # the search space for faster evaluations during model development 20 | FAST_EVAL_INDICES = [ 21 | 2, 3, 4, 5, 6, 9, 10, 11, 12, 22 | 13, 14, 16, 19, 37, 43, 52, 55, 23 | 56, 58, 59, 70, 94, 106, 213 24 | ] 25 | 26 | 27 | def main(args: argparse.Namespace): 28 | assert args.resume or args.eval_load_embeds, 'Model checkpoint or precomputed embeddings must be passed in.' 29 | assert args.device.lower() != 'cpu', 'CUDA not found. Terminating early to avoid wasting HPRC resources.' 
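    # Probing protocol: embeddings are computed once per split with the frozen backbone (or loaded
    # from --eval_load_embeds), then a grid search over shallow probes (linear or one-hidden-layer
    # MLP) is run on top, selecting each configuration by its best validation metric and reporting
    # the corresponding test metrics.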
30 | 31 | args.ignore_layers = ['linear_head'] 32 | args.proj_head = None 33 | args.proj_head_dropout = None 34 | args.hop_samples = args.hop_samples or args.n_samples 35 | args.hop_frames = get_n_frames(args.hop_samples, args) 36 | 37 | model, _, _, train_loader, test_dataset, use_wandb, wandb = setup_for_training(rank=0, world_size=1, args=args) 38 | 39 | valid_dataset, _ = get_dataset( 40 | dataroot=os.path.join(args.dataroot, 'valid'), 41 | args=args 42 | ) 43 | 44 | train_dataset = train_loader.dataset 45 | train_dataset.frame_size = None 46 | valid_dataset.frame_size = None 47 | test_dataset.frame_size = None 48 | model.linear_head = nn.Identity() 49 | 50 | if args.eval_load_embeds: 51 | train_dataset, valid_dataset, test_dataset = load_embeds(args.eval_load_embeds) 52 | else: 53 | train_dataset = compute_embeddings(model, train_dataset, args) 54 | valid_dataset = compute_embeddings(model, valid_dataset, args, testmode=True) 55 | test_dataset = compute_embeddings(model, test_dataset, args, testmode=True) 56 | 57 | if args.eval_save_embeds: 58 | save_embeds(train_dataset, valid_dataset, test_dataset, args.eval_save_embeds) 59 | 60 | hyperparam_grid = get_hyperparameter_grid() 61 | 62 | print(f'==> Starting grid search over {len(hyperparam_grid)} hyperparameter combinations.') 63 | 64 | best_overall_metric = -np.inf 65 | best_overall_metrics = {} 66 | best_hyperparams = None 67 | for idx, hyperparams in enumerate(hyperparam_grid): 68 | if args.eval_start_idx > idx: continue 69 | if args.fast_eval and (idx+1) not in FAST_EVAL_INDICES: continue 70 | print(f'==> Starting run {idx+1}/{len(hyperparam_grid)}') 71 | 72 | best_primary_metric, best_metrics = evaluate(train_dataset, valid_dataset, test_dataset, hyperparams, args) 73 | 74 | if best_primary_metric > best_overall_metric: 75 | print(f'Found new best run with best primary metric = {best_primary_metric:.4f}') 76 | print(f'All metrics: {best_metrics["text"]}') 77 | best_overall_metric = best_primary_metric 78 | best_overall_metrics = best_metrics.copy() 79 | best_hyperparams = hyperparams 80 | 81 | if use_wandb: 82 | wandb.log(best_overall_metrics) 83 | 84 | print('\n==> Grid Search Complete.') 85 | print('Best Overall Results:') 86 | for metric_name, value in best_overall_metrics.items(): 87 | if metric_name != 'text': 88 | print(f'{metric_name}: {value}') 89 | print('\nBest Hyperparameters:') 90 | print(best_hyperparams) 91 | 92 | if args.eval_save_results: 93 | with open(args.eval_save_results, 'w') as f: 94 | json.dump(best_overall_metrics, f, indent=4) 95 | 96 | 97 | class EmbeddingDataset(Dataset): 98 | ''' Dataset for precomputed embeddings. 
''' 99 | def __init__(self, indices: torch.LongTensor, embeddings: torch.Tensor, labels: list, standardize: bool = False): 100 | self.indices = indices 101 | self.embeddings = embeddings 102 | self.labels = labels 103 | self.standardize = standardize 104 | 105 | self.standardized_embeddings = torch.tensor( 106 | StandardScaler().fit_transform(embeddings.cpu().numpy()), dtype=torch.float32 107 | ).to(embeddings.device) 108 | 109 | def __len__(self): 110 | return len(self.embeddings) 111 | 112 | def __getitem__(self, idx): 113 | actual_idx = self.indices[idx] 114 | embedding = self.embeddings[idx] if not self.standardize else self.standardized_embeddings[idx] 115 | label = self.labels[actual_idx] 116 | return actual_idx, embedding, label 117 | 118 | 119 | @torch.no_grad() 120 | def compute_embeddings(model: nn.Module, dataset: Dataset, args: argparse.Namespace, testmode: bool = False): 121 | ''' Computes model embeddings for every spectrogram in the dataset with the given hop size. ''' 122 | model.eval() 123 | indices, embeddings, labels = [], [], [] 124 | for i, spec, label in tqdm(dataset, desc='Computing Embeddings'): 125 | spec_hops = extract_hops(spec, args, testmode).to(args.device) 126 | embeds = model(spec_hops) 127 | 128 | n_embeds = embeds.shape[0] 129 | indices.extend([i] * n_embeds) 130 | embeddings.append(embeds) 131 | labels.append(label) 132 | 133 | indices = torch.LongTensor(indices) 134 | embeddings = torch.cat(embeddings, dim=0) 135 | return EmbeddingDataset(indices, embeddings, labels) 136 | 137 | 138 | def extract_hops(spec: torch.Tensor, args: argparse.Namespace, testmode: bool): 139 | ''' Extract overlapping hops from a spectrogram (1, n_mels, total_frames). ''' 140 | total_frames = spec.shape[-1] 141 | hop_size = args.hop_frames if not testmode else args.mel_frames 142 | frame_size = args.mel_frames 143 | 144 | specs = [] 145 | for start in range(0, total_frames - frame_size + 1, hop_size): 146 | end = start + frame_size 147 | specs.append(spec[..., start:end].unsqueeze(0)) 148 | 149 | if not specs: 150 | # if no hops can be extracted, pad the spectrogram 151 | pad_amount = frame_size - total_frames 152 | padded_spec = torch.nn.functional.pad(spec, (0, pad_amount), mode='constant', value=0) 153 | specs.append(padded_spec.unsqueeze(0)) 154 | 155 | return torch.cat(specs, dim=0) 156 | 157 | 158 | def get_hyperparameter_grid(): 159 | # hyperparameter grid, from JukeMIR paper 160 | feature_standardization_options = ['off', 'on'] 161 | model_types = ['linear', 'mlp'] 162 | batch_sizes = [64, 256] 163 | learning_rates = [1e-5, 1e-4, 1e-3] 164 | dropouts = [0.25, 0.5, 0.75] 165 | weight_decays = [0, 1e-4, 1e-3] 166 | 167 | hyperparameter_grid = list(product( 168 | feature_standardization_options, 169 | model_types, 170 | batch_sizes, 171 | learning_rates, 172 | dropouts, 173 | weight_decays 174 | )) 175 | 176 | return hyperparameter_grid 177 | 178 | 179 | def evaluate(train_dataset: Dataset, valid_dataset: Dataset, test_dataset: Dataset, hyperparams: list, args: argparse.Namespace): 180 | standardize, model_type, batch_size, learning_rate, dropout, weight_decay = hyperparams 181 | args.learning_rate = learning_rate 182 | args.weight_decay = weight_decay 183 | 184 | train_dataset.standardize = standardize 185 | test_dataset.standardize = standardize 186 | 187 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 188 | valid_loader = DataLoader(valid_dataset, batch_size=batch_size,shuffle=False) 189 | test_loader = DataLoader(test_dataset, 
batch_size=batch_size,shuffle=False) 190 | 191 | model = make_model(model_type, dropout, args) 192 | criterion = get_criterion(args, N=-1) 193 | optimizer = get_optimizer(model, args) 194 | 195 | best_valid_metric = -np.inf 196 | epochs_without_improvement = 0 197 | primary_metric = 0 198 | metrics = {} 199 | for epoch in (pbar := trange(args.epochs)): 200 | train_metrics, train_primary = train_epoch_embeddings(model, train_loader, criterion, optimizer, args) 201 | valid_metrics, valid_primary = test_embeddings(model, valid_loader, criterion, args) 202 | test_metrics, test_primary = test_embeddings(model, test_loader, criterion, args) 203 | 204 | if valid_primary > best_valid_metric: 205 | best_valid_metric = valid_primary 206 | primary_metric = test_primary 207 | metrics = test_metrics.copy() 208 | epochs_without_improvement = 0 209 | 210 | pbar.set_description(f'Best: {test_primary:.4f} (epoch {epoch})') 211 | else: 212 | epochs_without_improvement += 1 213 | 214 | if args.eval_patience is not None and epochs_without_improvement >= args.eval_patience: 215 | print(f'Early stopping after {epoch + 1} epochs.') 216 | break 217 | 218 | return primary_metric, metrics 219 | 220 | 221 | def make_model(model_type: str, dropout: float, args: argparse.Namespace): 222 | if model_type == 'linear': 223 | model = nn.Linear(args.dim, args.num_outputs) 224 | elif model_type == 'mlp': 225 | model = nn.Sequential( 226 | nn.Linear(args.dim, 512), 227 | nn.ReLU(), 228 | nn.Dropout(p=dropout), 229 | nn.Linear(512, args.num_outputs) 230 | ) 231 | else: 232 | raise ValueError(f'Unsupported model type: {args.model_type}') 233 | 234 | model = model.to(args.device) 235 | return model 236 | 237 | 238 | def update_predictions_targets(all_targets: defaultdict, all_predictions: defaultdict, indices: torch.LongTensor, labels: torch.Tensor, outputs: torch.Tensor): 239 | indices = indices.cpu().numpy() 240 | labels = labels.cpu().numpy() 241 | outputs = outputs.detach().cpu().numpy() 242 | 243 | # Accumulate predictions and store targets 244 | for idx, label, output in zip(indices, labels, outputs): 245 | all_predictions[idx].append(output) 246 | all_targets[idx] = label 247 | 248 | 249 | def cat_predictions_targets(all_targets: defaultdict, all_predictions: defaultdict): 250 | all_targets = np.array([ 251 | all_targets[i] for i in range(len(all_targets)) 252 | ]) 253 | all_predictions = np.array([ 254 | np.mean(all_predictions[i], axis=0) for i in range(len(all_predictions)) 255 | ]) 256 | 257 | return all_targets, all_predictions 258 | 259 | 260 | def train_epoch_embeddings(model: nn.Module, train_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, args: argparse.Namespace): 261 | ''' Train the classifier/MLP on embeddings for one epoch. 
''' 262 | model.train() 263 | running_loss = 0.0 264 | all_targets = defaultdict(list) 265 | all_predictions = defaultdict(list) 266 | 267 | for indices, inputs, labels in train_loader: 268 | inputs = inputs.to(args.device, dtype=torch.float) 269 | labels = labels.to(args.device) 270 | 271 | optimizer.zero_grad() 272 | outputs = model(inputs) 273 | loss = criterion(outputs, labels) 274 | 275 | loss.backward() 276 | optimizer.step() 277 | 278 | running_loss += loss.item() 279 | 280 | update_predictions_targets(all_targets, all_predictions, indices, labels, outputs) 281 | 282 | all_targets, all_predictions = cat_predictions_targets(all_targets, all_predictions) 283 | 284 | # Compute metrics 285 | train_metrics = compute_metrics(all_targets, all_predictions, args) 286 | primary_metric = select_primary_metric(train_metrics, args) 287 | 288 | return train_metrics, primary_metric 289 | 290 | 291 | def test_embeddings(model: nn.Module, test_loader: DataLoader, criterion: nn.Module, args: argparse.Namespace): 292 | model.eval() 293 | running_loss = 0.0 294 | all_targets = defaultdict(list) 295 | all_predictions = defaultdict(list) 296 | 297 | with torch.no_grad(): 298 | for indices, inputs, labels in test_loader: 299 | inputs = inputs.to(args.device, dtype=torch.float) 300 | labels = labels.to(args.device) 301 | 302 | outputs = model(inputs) 303 | loss = criterion(outputs, labels) 304 | running_loss += loss.item() 305 | 306 | update_predictions_targets(all_targets, all_predictions, indices, labels, outputs) 307 | 308 | all_targets, all_predictions = cat_predictions_targets(all_targets, all_predictions) 309 | 310 | # Compute metrics 311 | test_metrics = compute_metrics(all_targets, all_predictions, args) 312 | primary_metric = select_primary_metric(test_metrics, args) 313 | 314 | return test_metrics, primary_metric 315 | 316 | 317 | def select_primary_metric(metrics: dict, args: argparse.Namespace): 318 | ''' Select the primary metric based on task type. 
''' 319 | if args.task_type.lower() == 'binary': 320 | return metrics.get('auprc', -np.inf) 321 | 322 | elif args.task_type.lower() == 'multiclass': 323 | if args.key_detection: 324 | return metrics.get('weighted_accuracy', -np.inf) 325 | else: 326 | return metrics.get('top_1_accuracy', -np.inf) 327 | 328 | elif args.task_type.lower() == 'regression': 329 | dims = len([r for r in metrics if r.startswith('r2_dim')]) 330 | if dims > 1: 331 | return np.mean([v for k, v in metrics.items() if k.startswith('r2_dim')]) 332 | else: 333 | return metrics.get('r2', -np.inf) 334 | 335 | else: 336 | # no primary metric 337 | return -np.inf 338 | 339 | 340 | def save_embeds(train_dataset: EmbeddingDataset, valid_dataset: EmbeddingDataset, test_dataset: EmbeddingDataset, filepath: str): 341 | data = { 342 | 'train': { 343 | 'indices': train_dataset.indices.cpu(), 344 | 'embeddings': train_dataset.embeddings.cpu(), 345 | 'labels': train_dataset.labels 346 | }, 347 | 'valid': { 348 | 'indices': valid_dataset.indices.cpu(), 349 | 'embeddings': valid_dataset.embeddings.cpu(), 350 | 'labels': valid_dataset.labels 351 | }, 352 | 'test': { 353 | 'indices': test_dataset.indices.cpu(), 354 | 'embeddings': test_dataset.embeddings.cpu(), 355 | 'labels': test_dataset.labels 356 | } 357 | } 358 | 359 | with open(filepath, 'wb') as f: 360 | pickle.dump(data, f) 361 | 362 | 363 | def load_embeds(filepath: str): 364 | with open(filepath, 'rb') as f: 365 | data = pickle.load(f) 366 | 367 | train_dataset = EmbeddingDataset( 368 | indices=data['train']['indices'], 369 | embeddings=data['train']['embeddings'], 370 | labels=data['train']['labels'] 371 | ) 372 | 373 | valid_dataset = EmbeddingDataset( 374 | indices=data['valid']['indices'], 375 | embeddings=data['valid']['embeddings'], 376 | labels=data['valid']['labels'] 377 | ) 378 | 379 | test_dataset = EmbeddingDataset( 380 | indices=data['test']['indices'], 381 | embeddings=data['test']['embeddings'], 382 | labels=data['test']['labels'] 383 | ) 384 | 385 | return train_dataset, valid_dataset, test_dataset 386 | 387 | 388 | if __name__ == '__main__': 389 | args = parse_args() 390 | main(args) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from einops.layers.torch import Rearrange 3 | from glob import glob 4 | import json 5 | from libauc.losses.contrastive import GCLoss_v1 6 | import math 7 | import matplotlib.pyplot as plt 8 | from mir_eval.key import weighted_score 9 | from nnAudio.features.mel import MelSpectrogram 10 | import numpy as np 11 | import os 12 | import pickle 13 | import random 14 | from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, top_k_accuracy_score, r2_score 15 | import torch 16 | import torch.distributed.nn.functional as dist_fn 17 | from torch import nn, optim 18 | from torch.nn.parallel import DistributedDataParallel as DDP 19 | from torch.utils.data import DataLoader, Dataset, DistributedSampler 20 | from tqdm.auto import tqdm 21 | from vit_pytorch.simple_vit import Transformer 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='Hyperparameter configuration') 26 | 27 | # Run parameters 28 | parser.add_argument('--run_name', type=str, help='Name of the run') 29 | parser.add_argument('--wandb_project', type=str, help='Weights and Biases project name') 30 | parser.add_argument('--wandb', action='store_true', help='Use Weights and Biases') 
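    # Note: when --wandb is set, --run_name and --wandb_project must also be provided (checked in
    # setup_for_training). Illustrative invocation (paths and hyperparameter values below are
    # placeholders, not the paper settings):
    #   python train.py --dataroot /path/to/mel_specs --task_type contrastive --unlabeled \
    #       --mask_ratio 0.9 --arch vit-s-32 --wandb --wandb_project myna --run_name pretrain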
31 | 32 | # Dataset parameters 33 | parser.add_argument('--dataroot', type=str, required=True, help='Root directory of the dataset') 34 | parser.add_argument('--n_samples', type=int, default=100000, help='Number of samples to use') 35 | parser.add_argument('--hop_samples', type=int, help='Number of samples to hop (overlapping window) for evaluation.') 36 | parser.add_argument('--sr', type=int, default=16000, help='Sampling rate of audio data') 37 | parser.add_argument('--filenames', type=str, help='JSON file containing a list of filenames (useful for very large datasets)') 38 | parser.add_argument('--unlabeled', action='store_true', help='Is the dataset unlabeled? (We assume a labeled dataset otherwise).') 39 | 40 | # Preprocessing parameters 41 | parser.add_argument('--n_mels', type=int, default=128, help='Number of mel bands') 42 | parser.add_argument('--log_mel', action='store_true', help='Take the log of the input spectrograms before passing them to the model') 43 | 44 | # DataLoader parameters 45 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU for training') 46 | parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading') 47 | 48 | # Optimizer parameters 49 | parser.add_argument('--optimizer', type=str, choices=['adam', 'sgd'], default='adam', help='Optimizer to use (adam or sgd)') 50 | parser.add_argument('--learning_rate', type=float, default=3e-4, help='Learning rate for the optimizer') 51 | parser.add_argument('--lr_schedule', type=str, default='constant', choices=['constant', 'cosine'], help='Learning rate schedule') 52 | parser.add_argument('--warmup_epochs', type=int, default=10, help='Number of warmup epochs if non-constant learning rate schedule is set') 53 | parser.add_argument('--weight_decay', type=float, default=1e-5, help='Weight decay for the optimizer') 54 | parser.add_argument('--sgd_momentum', type=float, default=0.9, help='Momentum for SGD optimizer') 55 | parser.add_argument('--sogclr_tau', type=float, default=0.1, help='Temperature for SogCLR (contrastive)') 56 | parser.add_argument('--sogclr_gamma', type=float, default=0.9, help='Gamma for SogCLR (contrastive)') 57 | parser.add_argument('--sogclr_eps', type=float, default=1e-8, help='Epsilon value for SogCLR (contrastive)') 58 | parser.add_argument('--isogclr', action='store_true', help='Use iSogCLR for individualized temperatures') 59 | parser.add_argument('--gamma_schedule', type=str, default='constant', choices=['constant', 'cosine'], help='Gamma schedule for SogCLR. If cosine, decays from 1.0 to --sogclr_gamma.') 60 | parser.add_argument('--grad_clip', type=float, help='Gradient clipping (by norm)') 61 | 62 | # Training parameters 63 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use for training') 64 | parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs') 65 | parser.add_argument('--checkpoint_dir', type=str, help='Directory to save checkpoints') 66 | parser.add_argument('--checkpoint_epochs', type=int, default=1, help='Checkpoint every x epochs') 67 | parser.add_argument('--resume', type=str, help='Resume training from the last checkpoint') 68 | parser.add_argument('--resume_optimizer', type=str, help='Load optimizer state from the last checkpoint') 69 | parser.add_argument('--resume_epochs', type=int, default=0, help='Epoch to resume training from (for checkpointing and learning rate scheduling). 
If provided, the script will run args.epochs - args.resume_epochs epochs of training.') 70 | parser.add_argument('--ignore_layers', type=str, nargs='*', default=[], help='List of layer names to ignore during loading from checkpoint (default: none)') 71 | parser.add_argument('--seed', type=int, default=42, help='Seed for deterministic output') 72 | parser.add_argument('--task_type', type=str, choices=['binary', 'multiclass', 'regression', 'contrastive', 'mae'], required=True, help='Task type for training (binary, multiclass, regression, contrastive, or mae [masked autoencoder])') 73 | parser.add_argument('--mask_ratio', type=float, default=None, help='Mask ratio for masked autoencoder task') 74 | parser.add_argument('--train_only_head_epochs', type=int, help='Number of epochs to train only the model head (freeze backbone for this many epochs)') 75 | parser.add_argument('--local-rank', type=int, help='Local rank (passed in by torchrun)') 76 | parser.add_argument('--dist_backend', type=str, default='nccl', help='Backend for distributed training (default: NCCL)') 77 | 78 | # Evaluation parameters 79 | parser.add_argument('--key_detection', action='store_true', help='Additionally evaluate on key detection (must have num_outputs=24)') 80 | parser.add_argument('--fast_eval', action='store_true', help='Skip some parts of hyperparameter search for quicker evaluation (usually accurate, especially for MTAT, and 10x faster). The full evaluation result is at least as good as the fast evaluation result.') 81 | parser.add_argument('--eval_start_idx', type=int, default=0, help='Start evaluation script (grid search) from permutation index i (useful if a run fails)') 82 | parser.add_argument('--eval_save_embeds', type=str, help='Location to save model embeddings to (saved as a pickle file)') 83 | parser.add_argument('--eval_load_embeds', type=str, help='Location to load model embeddings from (expects a pickle file)') 84 | parser.add_argument('--eval_save_results', type=str, help='Location to save evaluation results to (JSON file)') 85 | parser.add_argument('--eval_patience', type=int, help='Patience for evaluation (default: none)') 86 | parser.add_argument('--use_other_patch_size', action='store_true', help='Use the additional patch size for evaluation') 87 | parser.add_argument('--eval_hybrid', action='store_true', help='If an additional patch size is provided (see --additional_patch_size), this sets the model to compute both forward passes and concatenate their outputs.') 88 | 89 | # Model parameters 90 | parser.add_argument('--patch_size', type=int, nargs='+', default=[16], help='Patch size for model input (can be a single integer or a list of two integers)') 91 | parser.add_argument('--additional_patch_size', type=int, nargs='*', help='Additional patch size for model input (can be omitted, a single integer, or a list of two integers)') 92 | parser.add_argument('--num_outputs', type=int, default=50, help='Number of model outputs') 93 | parser.add_argument('--dim', type=int, default=256, help='ViT model dimension') 94 | parser.add_argument('--depth', type=int, default=6, help='ViT number of layers') 95 | parser.add_argument('--decoder_depth', type=int, default=4, help='Decoder depth for masked autoencoder task (default: 4)') 96 | parser.add_argument('--heads', type=int, default=16, help='ViT number of attention heads') 97 | parser.add_argument('--mlp_dim', type=int, default=1024, help='ViT MLP dimension') 98 | parser.add_argument('--dim_head', type=int, default=64, help='ViT attention head dimension') 99 | 
parser.add_argument('--proj_head', type=str, nargs='+', help='Projection head dimensions to add on to ViT model') 100 | parser.add_argument('--arch', type=str, help='ViT architecture', choices=['vit-s-32', 'vit-b-16', 'vit-l-16']) 101 | 102 | # Experimental parameters 103 | parser.add_argument('--mixup_alpha', type=float, help='Alpha parameter for MixUp augmentation') 104 | parser.add_argument('--mixup_beta', type=float, help='Beta parameter for MixUp augmentation') 105 | parser.add_argument('--max_frame_distance', type=int, help='Maximum frame distance from first view for frame selection of positives in contrastive learning (default: None (inf))') 106 | parser.add_argument('--add_layers', type=int, help='Number of transformer layers to add to the model (default: 0)') 107 | parser.add_argument('--unfreeze_last_n_layers', type=int, help='Unfreeze the last n transformer layers') 108 | 109 | 110 | args = parser.parse_args() 111 | 112 | # architecture 113 | if args.arch and args.arch.lower() == 'vit-s-32': 114 | # dim 384, depth 12, MLP 1536, 6 heads, 22M parameters 115 | args.dim = 384 116 | args.depth = 12 117 | args.mlp_dim = 1536 118 | args.heads = 6 119 | if args.arch and args.arch.lower() == 'vit-b-16': 120 | # dim 768, depth 12, MLP 3072, 12 heads, 87M parameters 121 | args.dim = 768 122 | args.depth = 12 123 | args.mlp_dim = 3072 124 | args.heads = 12 125 | if args.arch and args.arch.lower() == 'vit-l-16': 126 | # dim 1024, depth 24, MLP 4096, 16 heads, 303M parameters 127 | args.dim = 1024 128 | args.depth = 24 129 | args.mlp_dim = 4096 130 | args.heads = 16 131 | 132 | # if args.patch_size is a list of [single element], convert it to int 133 | # vit-pytorch only accepts tuples 134 | args.patch_size = tuple(args.patch_size) 135 | if len(args.patch_size) == 1: 136 | args.patch_size = args.patch_size[0] 137 | 138 | if args.additional_patch_size: 139 | args.additional_patch_size = tuple(args.additional_patch_size) 140 | if len(args.additional_patch_size) == 1: 141 | args.additional_patch_size = args.additional_patch_size[0] 142 | 143 | return args 144 | 145 | 146 | class MelSpectrogramDataset(Dataset): 147 | ''' Dataset for pre-computed Mel Spectrograms ''' 148 | def __init__( 149 | self, 150 | dataroot: str, # path to dataset 151 | frame_size: int = None, # how many frames to return? None for entire spectrogram 152 | labeled: bool = True, # do pickle files contain (spec, label) pairs or just spec? 
153 | pickle_extensions: list = ['.pkl'], # extensions to match when searching for files 154 | n_views: int = 1, # number of 'crops' of the spectrogram to return 155 | filenames = None, # path to JSON file containing filenames 156 | max_frame_distance: int = None # maximum number of frames far from the first starting point, sampled uniformly 157 | ): 158 | super().__init__() 159 | self.dataroot = dataroot 160 | self.frame_size = frame_size 161 | self.labeled = labeled 162 | self.n_views = n_views 163 | self.max_frame_distance = max_frame_distance 164 | 165 | # find files 166 | if filenames is None: 167 | self.filenames = [] 168 | for ext in pickle_extensions: 169 | self.filenames.extend(glob(os.path.join(dataroot, '**/*' + ext), recursive=True)) 170 | else: 171 | with open(filenames, 'r') as f: 172 | filenames = json.load(f) 173 | 174 | # filter filenames to this subset 175 | subset = os.path.basename(os.path.normpath(dataroot)) 176 | self.filenames = [ 177 | os.path.join(dataroot, *os.path.normpath(f).split(os.sep)[1:]) 178 | for f in filenames 179 | if os.path.normpath(f).startswith(subset) 180 | ] 181 | 182 | def __len__(self): 183 | return len(self.filenames) 184 | 185 | def __getitem__(self, i: int): 186 | with open(self.filenames[i], 'rb') as f: 187 | data = pickle.load(f) 188 | 189 | if self.labeled: 190 | spec, label = data 191 | else: 192 | spec, label = data, 0 # no label 193 | 194 | if self.frame_size is not None: 195 | total_frames = spec.shape[-1] 196 | if total_frames < self.frame_size: 197 | raise ValueError(f'Spectrogram has fewer frames ({total_frames}) than the requested frame size ({self.frame_size}).') 198 | 199 | # Select the first frame uniformly 200 | start = random.randint(0, total_frames - self.frame_size) 201 | end = start + self.frame_size 202 | specs = [spec[..., start:end]] 203 | 204 | # Select subsequent frames based on the first one if n_views > 1 205 | for _ in range(1, self.n_views): 206 | if self.max_frame_distance is not None: 207 | # Generate new start position uniformly within the range of max_frame_distance 208 | new_start = start + random.randint(-self.max_frame_distance, self.max_frame_distance) 209 | # Make sure the new start is within valid bounds 210 | new_start = max(0, min(new_start, total_frames - self.frame_size)) 211 | else: 212 | # Uniform selection as fallback 213 | new_start = random.randint(0, total_frames - self.frame_size) 214 | 215 | new_end = new_start + self.frame_size 216 | specs.append(spec[..., new_start:new_end]) 217 | 218 | spec = specs if self.n_views > 1 else specs[0] 219 | 220 | return i, spec, label 221 | 222 | 223 | @torch.no_grad() 224 | def predict(model: nn.Module, spec: torch.Tensor, chunk_size: int): 225 | ''' Averages predictions over all chunks (of size chunk_size) ''' 226 | 227 | # No batched inputs 228 | spec = spec.squeeze() 229 | assert spec.dim() == 2 230 | 231 | model.eval() 232 | n_mels, total_frames = spec.shape 233 | 234 | # Calculate the number of chunks we can extract 235 | n_chunks = total_frames // chunk_size 236 | 237 | # If there are no full chunks 238 | if n_chunks == 0: 239 | print("The input spectrogram is too small for the specified chunk size.") 240 | return None 241 | 242 | chunks = spec[:, :n_chunks * chunk_size] 243 | chunks = chunks.view(n_mels, n_chunks, chunk_size) 244 | chunks = chunks.permute(1, 0, 2).unsqueeze(1) 245 | 246 | outputs = model(chunks) 247 | return torch.mean(outputs, dim=0) 248 | 249 | 250 | def compute_metrics(targets: np.ndarray, predictions: np.ndarray, args: 
argparse.Namespace): 251 | ''' Compute metrics for the given task type. Input shape: (batch_size, n_outputs) ''' 252 | if args.task_type.lower() == 'binary': 253 | auroc = roc_auc_score(targets, predictions, average='macro') 254 | auprc = average_precision_score(targets, predictions, average='macro') 255 | 256 | return { 257 | 'auroc': auroc, 258 | 'auprc': auprc, 259 | 'text': f'AUROC: {auroc:.4f}, AUPRC: {auprc:.4f}' 260 | } 261 | 262 | elif args.task_type.lower() == 'multiclass': 263 | top_1_accuracy = accuracy_score(targets, np.argmax(predictions, axis=1)) 264 | top_5_accuracy = top_k_accuracy_score(targets, predictions, k=5) 265 | 266 | metrics = { 267 | 'top_1_accuracy': top_1_accuracy, 268 | 'top_5_accuracy': top_5_accuracy, 269 | 'text': f'Top-1 Accuracy: {top_1_accuracy:.4f}, Top-5 Accuracy: {top_5_accuracy:.4f}' 270 | } 271 | 272 | if args.key_detection: 273 | weighted_accuracy = key_detection_accuracy(targets, np.argmax(predictions, axis=1)) 274 | metrics['weighted_accuracy'] = weighted_accuracy 275 | metrics['text'] += f', Weighted: {weighted_accuracy:.4f}' 276 | 277 | return metrics 278 | 279 | elif args.task_type.lower() == 'regression': 280 | r2 = r2_score(targets, predictions) 281 | 282 | metrics = { 283 | 'r2': r2, 284 | 'text': f'R^2: {r2:.4f}' 285 | } 286 | 287 | dims = targets.shape[1] 288 | if dims > 1: 289 | for i in range(targets.shape[1]): 290 | r2_dim = r2_score(targets[:, i], predictions[:, i]) 291 | metrics[f'r2_dim_{i}'] = r2_dim 292 | metrics['text'] += f', R^2_dim_{i}: {r2_dim:.4f}' 293 | 294 | return metrics 295 | 296 | elif args.task_type.lower() in ['contrastive', 'mae']: 297 | return {'text': ''} # no metrics 298 | 299 | else: 300 | raise Exception(f'Task type {args.task_type} not supported') 301 | 302 | 303 | def compute_loss(model: nn.Module, inputs: torch.Tensor, labels: torch.Tensor, indices: torch.Tensor, criterion: nn.Module, args: argparse.Namespace, return_mae_outputs: bool = False, recurse: bool = True): 304 | # evaluate for both patch sizes if there are two patch sizes 305 | if args.additional_patch_size is not None and recurse: 306 | outputs, loss_a = compute_loss(model, inputs, labels, indices, criterion, args, recurse = False) 307 | model.toggle_embeddings() 308 | _, loss_b = compute_loss(model, inputs, labels, indices, criterion, args, recurse = False) 309 | model.toggle_embeddings() 310 | loss = (loss_a + loss_b) / 2 311 | 312 | elif args.task_type.lower() in ['binary', 'multiclass', 'regression']: 313 | if args.mixup_alpha is not None: 314 | inputs, y_a, y_b, lam = mixup_data(inputs, labels, args.mixup_alpha, args.mixup_beta) 315 | 316 | if args.mask_ratio: 317 | outputs = masked_model_output(inputs, model, args) 318 | else: 319 | outputs = model(inputs) 320 | 321 | if args.mixup_alpha is not None: 322 | loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b) 323 | else: loss = criterion(outputs, labels) 324 | 325 | elif args.task_type.lower() in ['contrastive']: 326 | assert args.mask_ratio is not None, 'You must provide --mask_ratio for contrastive learning.'
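        # Contrastive branch: `inputs` holds two random crops (views) of the same spectrogram; each view
        # is encoded with random patch masking via masked_model_output(), embeddings are all-gathered
        # across processes when distributed, and the SogCLR criterion (GCLoss_v1) uses the dataset
        # indices to track its per-sample statistics.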
327 | a, b = inputs 328 | 329 | if args.mixup_alpha is not None: 330 | a = mixup_data(a, None, args.mixup_alpha, args.mixup_beta)[0] 331 | b = mixup_data(b, None, args.mixup_alpha, args.mixup_beta)[0] 332 | 333 | za = masked_model_output(a, model, args) 334 | zb = masked_model_output(b, model, args) 335 | 336 | # distributed training 337 | if args.world_size > 1: 338 | # all gather using dist.nn.functional which allows autograd tracking 339 | gathered_za = dist_fn.all_gather(za) 340 | gathered_zb = dist_fn.all_gather(zb) 341 | gathered_indices = dist_fn.all_gather(indices) 342 | 343 | # concatenate the gathered outputs 344 | za = torch.cat(gathered_za, dim=0) 345 | zb = torch.cat(gathered_zb, dim=0) 346 | indices = torch.cat(gathered_indices, dim=0) 347 | 348 | outputs = torch.zeros(1) # no output 349 | loss = criterion(za, zb, indices.cpu()) 350 | 351 | elif args.task_type.lower() in ['mae']: 352 | assert args.mask_ratio is not None, 'You must provide --mask_ratio for masked autoencoding.' 353 | 354 | x = model.to_patch_embedding(inputs) 355 | x += model.pos_embedding.to(inputs.device, dtype=inputs.dtype) 356 | 357 | B, N, _ = x.shape 358 | n_masked = int(args.mask_ratio * N) 359 | indices = torch.stack([torch.randperm(N) for _ in range(B)]) 360 | mask_indices = indices[:, :n_masked].to(args.device) 361 | unmask_indices = indices[:, n_masked:].to(args.device) 362 | 363 | unmasked = x.gather(1, unmask_indices.unsqueeze(-1).expand(-1, -1, x.size(-1))) 364 | encoded = model.transformer(unmasked) 365 | 366 | mask_tokens = model.mask_token.repeat(B, n_masked, 1) 367 | mask_tokens += model.pos_embedding[mask_indices.cpu()].to(args.device) 368 | decoder_input = torch.cat([encoded, mask_tokens], dim=1) 369 | 370 | combined_indices = torch.cat([unmask_indices, mask_indices], dim=1) 371 | sorted_indices = torch.argsort(combined_indices, dim=1).to(args.device) 372 | decoder_input = decoder_input.gather(1, sorted_indices.unsqueeze(-1).expand(-1, -1, decoder_input.size(-1))) 373 | 374 | decoded = model.decoder(decoder_input) 375 | decoded = model.decoder_norm(decoded) 376 | outputs = model.decoder_head(decoded) 377 | 378 | outputs_masked = outputs.gather(1, mask_indices.unsqueeze(-1).expand(-1, -1, outputs.size(-1))) 379 | 380 | actual_patched = model.patchify(inputs) 381 | actual = actual_patched.gather(1, mask_indices.unsqueeze(-1).expand(-1, -1, actual_patched.size(-1))) 382 | 383 | outputs = outputs.scatter(1, unmask_indices.unsqueeze(-1).expand(-1, -1, outputs.size(-1)), torch.log1p(actual_patched).gather(1, unmask_indices.unsqueeze(-1).expand(-1, -1, actual_patched.size(-1)))) 384 | outputs = model.unpatchify(outputs) if return_mae_outputs else torch.zeros(1) 385 | 386 | loss = criterion(outputs_masked, torch.log1p(actual)) 387 | 388 | else: 389 | raise Exception(f'Task type {args.task_type} not supported.') 390 | 391 | return outputs, loss 392 | 393 | 394 | def train_epoch(model: nn.Module, train_loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer, epoch: int, args: argparse.Namespace): 395 | model.train() 396 | running_loss = 0.0 397 | all_targets = [] 398 | all_predictions = [] 399 | 400 | if args.world_size > 1: 401 | train_loader.sampler.set_epoch(epoch) 402 | 403 | set_schedulers(epoch, criterion, optimizer, args) 404 | 405 | progress_bar = tqdm(train_loader) if args.rank == 0 else train_loader 406 | for indices, inputs, labels in progress_bar: 407 | if isinstance(inputs, list): 408 | inputs = torch.stack(inputs, dim=0) 409 | if args.log_mel: 410 | inputs = 
torch.log1p(inputs) 411 | indices = indices.to(args.device) 412 | inputs = inputs.to(args.device) 413 | labels = labels.to(args.device) 414 | 415 | optimizer.zero_grad() 416 | outputs, loss = compute_loss(model, inputs, labels, indices, criterion, args) 417 | running_loss += loss.item() 418 | 419 | loss.backward() 420 | if args.grad_clip: 421 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.grad_clip) 422 | optimizer.step() 423 | 424 | all_targets.append(labels.cpu().numpy()) 425 | all_predictions.append(outputs.detach().cpu().numpy()) 426 | 427 | if args.rank == 0: 428 | progress_bar.set_description(f'Loss: {loss.item():.4f}') 429 | 430 | # Concatenate all stored targets and predictions 431 | all_targets = np.concatenate(all_targets, axis=0) 432 | all_predictions = np.concatenate(all_predictions, axis=0) 433 | 434 | # Compute metrics 435 | train_metrics = compute_metrics(all_targets, all_predictions, args) 436 | 437 | avg_loss = running_loss / len(train_loader) 438 | 439 | if args.rank == 0: 440 | print(f'Epoch {epoch}: Train Loss: {avg_loss:.4f}, Train {train_metrics["text"]}') 441 | 442 | return avg_loss, train_metrics 443 | 444 | 445 | def test(model: nn.Module, test_dataset: Dataset, criterion: nn.Module, epoch: int, args: argparse.Namespace): 446 | # don't test on contrastive pretraining 447 | if args.task_type == 'contrastive': 448 | return 0, {'text': ''} 449 | 450 | # just output examples for MAE 451 | if args.task_type == 'mae': 452 | output_mae_examples(model, test_dataset, epoch, args) 453 | return 0, {'text': ''} 454 | 455 | model.eval() 456 | all_targets = [] 457 | all_predictions = [] 458 | running_loss = 0.0 459 | 460 | progress_bar = tqdm(test_dataset) 461 | with torch.no_grad(): 462 | for _, inputs, labels in progress_bar: 463 | if isinstance(inputs, list): 464 | inputs = torch.stack(inputs, dim=0) 465 | if args.log_mel: 466 | inputs = torch.log1p(inputs) 467 | if not isinstance(labels, torch.Tensor): 468 | labels = torch.tensor(labels) 469 | inputs = inputs.to(args.device) 470 | labels = labels.to(args.device) 471 | 472 | outputs = predict(model, inputs, chunk_size=args.mel_frames) 473 | loss = criterion(outputs, labels) 474 | running_loss += loss.item() 475 | 476 | all_targets.append(labels.unsqueeze(0).cpu().numpy()) 477 | all_predictions.append(outputs.unsqueeze(0).cpu().numpy()) 478 | 479 | progress_bar.set_description(f'Loss: {loss.item():.4f}') 480 | 481 | # Concatenate all stored targets and predictions 482 | if labels.numel() == 1: 483 | all_targets = np.array(all_targets).squeeze(1) 484 | all_predictions = np.array(all_predictions).squeeze(1) 485 | else: 486 | all_targets = np.concatenate(all_targets, axis=0) 487 | all_predictions = np.concatenate(all_predictions, axis=0) 488 | 489 | # Compute metrics 490 | test_metrics = compute_metrics(all_targets, all_predictions, args) 491 | avg_loss = running_loss / len(test_dataset) 492 | test_metrics_str = f', Test {test_metrics["text"]}' if len(test_metrics["text"]) > 0 else '' 493 | 494 | if args.rank == 0: 495 | print(f'Test Loss: {avg_loss:.4f}{test_metrics_str}') 496 | 497 | return avg_loss, test_metrics 498 | 499 | 500 | def get_dataset(dataroot: str, args: argparse.Namespace, distributed=False, rank=0, world_size=1, drop_last=True): 501 | dataset = MelSpectrogramDataset( 502 | dataroot=dataroot, 503 | frame_size=args.mel_frames, 504 | labeled=not args.unlabeled, 505 | n_views=1 if args.task_type != 'contrastive' else 2, 506 | filenames=args.filenames, 507 | 
max_frame_distance=args.max_frame_distance 508 | ) 509 | 510 | # Use DistributedSampler if training in distributed mode 511 | sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) if distributed else None 512 | 513 | dataloader = DataLoader( 514 | dataset=dataset, 515 | batch_size=args.batch_size, 516 | shuffle=(sampler is None), 517 | sampler=sampler, 518 | num_workers=args.num_workers, 519 | pin_memory=True, 520 | drop_last=drop_last 521 | ) 522 | 523 | return dataset, dataloader 524 | 525 | 526 | def get_n_frames(n_samples: int, args: argparse.Namespace): 527 | ''' How many frames is n_samples samples? ''' 528 | mel_spectrogram = MelSpectrogram(sr=args.sr, n_mels=128, verbose=False) 529 | 530 | # patch size along the time dimension 531 | patch_size_time = args.patch_size if isinstance(args.patch_size, int) else args.patch_size[1] 532 | 533 | mel_frames = mel_spectrogram(torch.randn(1, 1, n_samples)).shape[-1] 534 | mel_frames = math.floor(mel_frames / patch_size_time) * patch_size_time 535 | return mel_frames 536 | 537 | 538 | def get_optimizer(model: nn.Module, args: argparse.Namespace): 539 | if args.optimizer.lower() == 'sgd': 540 | optimizer = optim.SGD( 541 | model.parameters(), 542 | lr=args.learning_rate, 543 | momentum=args.sgd_momentum, 544 | weight_decay=args.weight_decay 545 | ) 546 | elif args.optimizer.lower() == 'adam': 547 | optimizer = optim.Adam( 548 | model.parameters(), 549 | lr=args.learning_rate, 550 | weight_decay=args.weight_decay 551 | ) 552 | else: 553 | raise ValueError("Unsupported optimizer. Please use 'sgd' or 'adam'.") 554 | 555 | return optimizer 556 | 557 | 558 | def get_criterion(args: argparse.Namespace, N: int): 559 | if args.task_type.lower() == 'binary': 560 | return nn.BCEWithLogitsLoss() 561 | 562 | elif args.task_type.lower() == 'multiclass': 563 | return nn.CrossEntropyLoss() 564 | 565 | elif args.task_type.lower() in ['regression', 'mae']: 566 | return nn.MSELoss() 567 | 568 | elif args.task_type.lower() == 'contrastive': 569 | return GCLoss_v1( 570 | N=N, 571 | tau=args.sogclr_tau, 572 | gamma=args.sogclr_gamma, 573 | gamma_schedule=args.gamma_schedule, 574 | device=args.device, 575 | distributed=False, # args.world_size > 1, 576 | gamma_decay_epochs=args.epochs, 577 | eps=args.sogclr_eps, 578 | enable_isogclr=args.isogclr 579 | ) 580 | 581 | else: 582 | raise Exception(f'Task type {args.task_type} not supported') 583 | 584 | 585 | def seed_everything(seed: int): 586 | # Seed the built-in random module 587 | random.seed(seed) 588 | 589 | # Seed numpy 590 | np.random.seed(seed) 591 | 592 | # Seed torch 593 | torch.manual_seed(seed) 594 | torch.cuda.manual_seed(seed) 595 | torch.cuda.manual_seed_all(seed) # for multi-GPU. 596 | 597 | # Ensure deterministic behavior when using torch.backends.cudnn 598 | torch.backends.cudnn.deterministic = True 599 | torch.backends.cudnn.benchmark = False 600 | 601 | 602 | def save_model(model: nn.Module, checkpoint_dir: str, filename: str): 603 | os.makedirs(checkpoint_dir, exist_ok=True) 604 | checkpoint_path = os.path.join(checkpoint_dir, filename) 605 | model = model.module if isinstance(model, DDP) else model 606 | torch.save(model.state_dict(), checkpoint_path) 607 | 608 | 609 | def load_model(model: nn.Module, checkpoint_path: str, device: str, ignore_layers: list, verbose: bool): 610 | ''' 611 | Load model from checkpoint. Ignores (does not load) weights for 612 | layers whose names start with any string in ignore_layers. 
613 | ''' 614 | checkpoint = torch.load(checkpoint_path, map_location=device) 615 | 616 | filtered_state_dict = { 617 | k: v for k, v in checkpoint.items() 618 | if not any(k.startswith(layer) for layer in ignore_layers) 619 | } 620 | 621 | model.load_state_dict(filtered_state_dict, strict=False) 622 | 623 | if ignore_layers and verbose: 624 | print(f'==> Loaded model from {checkpoint_path}, ignoring layers: {", ".join(ignore_layers)}') 625 | 626 | 627 | def save_optimizer(optimizer: optim.Optimizer, criterion: nn.Module, checkpoint_dir: str, filename: str): 628 | os.makedirs(checkpoint_dir, exist_ok=True) 629 | checkpoint_path = os.path.join(checkpoint_dir, filename) 630 | data = { 631 | 'optimizer_state_dict': optimizer.state_dict() 632 | } 633 | 634 | if isinstance(criterion, GCLoss_v1): 635 | data['sogclr_u'] = criterion.u 636 | 637 | torch.save(data, checkpoint_path) 638 | 639 | 640 | def load_optimizer(optimizer: optim.Optimizer, criterion: nn.Module, checkpoint_path: str): 641 | checkpoint = torch.load(checkpoint_path) 642 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 643 | print(f'==> Optimizer state loaded from {checkpoint_path}') 644 | 645 | if isinstance(criterion, GCLoss_v1) and 'sogclr_u' in checkpoint: 646 | criterion.u = checkpoint['sogclr_u'] 647 | print(f'==> SogCLR parameter loaded from {checkpoint_path}') 648 | 649 | 650 | def add_proj_head(model: nn.Module, proj_head: list): 651 | ''' 652 | Adds MLP projection head on top of the model, where proj_head are the layer output dimensions and dropout layers are specified in the format "d" (e.g., "d0.5"). 653 | ''' 654 | 655 | mlp_dimensions = [model.linear_head.in_features] + proj_head 656 | layers = [] 657 | current_dim = model.linear_head.in_features 658 | 659 | for output_dim in mlp_dimensions[1:]: 660 | # handle dropout layers 661 | if output_dim.lower().startswith('d'): 662 | try: 663 | dropout_prob = float(output_dim[1:]) 664 | layers.append(nn.Dropout(dropout_prob)) 665 | except ValueError: 666 | raise ValueError(f'Invalid dropout probability: {output_dim}') 667 | 668 | # handle linear layers 669 | else: 670 | input_dim = current_dim 671 | output_dim = int(output_dim) 672 | layers.append(nn.Linear(input_dim, output_dim)) 673 | layers.append(nn.ReLU()) 674 | current_dim = output_dim 675 | 676 | # remove the last relu, if added 677 | if layers and isinstance(layers[-1], nn.ReLU): 678 | layers.pop() 679 | 680 | model.linear_head = nn.Sequential(*layers) 681 | 682 | 683 | def setup_for_mae(model: nn.Module, args: argparse.Namespace): 684 | patch_height, patch_width = args.patch_size if isinstance(args.patch_size, tuple) else (args.patch_size, args.patch_size) 685 | h = args.n_mels // patch_height 686 | w = args.mel_frames // patch_width 687 | 688 | model.mask_token = nn.Parameter(torch.randn(1, 1, args.dim)) 689 | model.decoder = Transformer(args.dim, args.decoder_depth, args.heads, 64, args.mlp_dim) 690 | model.decoder_norm = nn.LayerNorm(args.dim) 691 | 692 | model.patchify = Rearrange( 693 | 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', 694 | p1 = patch_height, p2 = patch_width 695 | ) 696 | model.unpatchify = Rearrange( 697 | 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', 698 | h = h, w = w, p1 = patch_height, p2 = patch_width 699 | ) 700 | model.decoder_head = nn.Linear(args.dim, patch_height * patch_width) 701 | 702 | 703 | def freeze_unfreeze_backbone(model: nn.Module, freeze: bool): 704 | ''' Freeze or unfreeze all model parameters except mlp head ''' 705 | for name, param in model.named_parameters(): 706 
| if name.startswith('linear_head'): 707 | param.requires_grad = True 708 | else: 709 | param.requires_grad = not freeze 710 | 711 | 712 | def key_detection_accuracy(targets: list, predictions: list): 713 | keys = [ 714 | 'c major', 'c minor', 'db major', 'db minor', 'd major', 715 | 'd minor', 'eb major', 'eb minor', 'e major', 'e minor', 716 | 'f major', 'f minor', 'gb major', 'gb minor', 'g major', 717 | 'g minor', 'ab major', 'ab minor', 'a major', 'a minor', 718 | 'bb major', 'bb minor', 'b major', 'b minor' 719 | ] 720 | 721 | targets = [keys[i] for i in targets] 722 | predictions = [keys[i] for i in predictions] 723 | 724 | score = np.mean([weighted_score(target, pred) for target, pred in zip(targets, predictions)]) 725 | return score 726 | 727 | 728 | def output_mae_examples(model: nn.Module, test_dataset: Dataset, epoch: int, args: argparse.Namespace): 729 | if not args.checkpoint_dir or ((epoch+1) % args.checkpoint_epochs != 0): 730 | return 731 | 732 | n_imgs = 10 733 | 734 | indices = torch.randint(0, len(test_dataset), (n_imgs,)) 735 | inputs = torch.stack([test_dataset[i][1] for i in indices]).to(args.device) 736 | outputs, _ = compute_loss(model, inputs, None, None, nn.MSELoss(), args, return_mae_outputs=True) 737 | 738 | img_dir = os.path.join(args.checkpoint_dir, 'outputs', f'epoch_{epoch:03}') 739 | os.makedirs(img_dir, exist_ok=True) 740 | 741 | for i in range(n_imgs): 742 | input_spectrogram_np = torch.log1p(inputs[i]).squeeze().cpu().detach().numpy() 743 | output_spectrogram_np = outputs[i].squeeze().cpu().detach().numpy() 744 | 745 | _, axes = plt.subplots(1, 2, figsize=(15, 5)) 746 | 747 | im1 = axes[0].imshow(input_spectrogram_np, aspect='auto', origin='lower') 748 | axes[0].set_title(f'Input Spectrogram {i}') 749 | axes[0].set_ylabel('Frequency bins') 750 | axes[0].set_xlabel('Time frames') 751 | plt.colorbar(im1, ax=axes[0], format='%+2.0f dB') 752 | 753 | im2 = axes[1].imshow(output_spectrogram_np, aspect='auto', origin='lower') 754 | axes[1].set_title(f'Output Spectrogram {i}') 755 | axes[1].set_ylabel('Frequency bins') 756 | axes[1].set_xlabel('Time frames') 757 | plt.colorbar(im2, ax=axes[1], format='%+2.0f dB') 758 | 759 | file_path = os.path.join(img_dir, f'input_output_spectrogram_{i}.png') 760 | plt.savefig(file_path, bbox_inches='tight') 761 | plt.close() 762 | 763 | if args.rank == 0: 764 | print(f'==> Saved {n_imgs} input-output spectrogram images in {img_dir}') 765 | 766 | 767 | def mask_inputs(x: torch.Tensor, mask_ratio: float, device: str): 768 | ''' Input masking for contrastive learning ''' 769 | # input B, N, D --> output B, N * (1 - mask_ratio), D 770 | B, N, _ = x.shape 771 | n_masked = int(mask_ratio * N) 772 | indices = torch.stack([torch.randperm(N) for _ in range(B)]) 773 | unmask_indices = indices[:, n_masked:].to(device) 774 | unmasked = x.gather(1, unmask_indices.unsqueeze(-1).expand(-1, -1, x.size(-1))) 775 | return unmasked 776 | 777 | 778 | def masked_model_output(x: torch.Tensor, model: nn.Module, args: argparse.Namespace): 779 | ''' ViT forward pass with masking. 
''' 780 | model = model.module if isinstance(model, DDP) else model 781 | 782 | x = model.to_patch_embedding(x) 783 | x += model.pos_embedding.to(x.device, dtype=x.dtype) 784 | 785 | x = mask_inputs(x, args.mask_ratio, args.device) 786 | 787 | z = model.transformer(x) 788 | z = z.mean(dim=1) 789 | z = model.to_latent(z) 790 | z = model.linear_head(z) 791 | 792 | return z 793 | 794 | 795 | def mixup_data(x: torch.Tensor, y: torch.Tensor, alpha=5.0, beta=2.0): 796 | ''' Compute the mixup data. Return mixed inputs, pairs of targets, and lambda ''' 797 | lam = np.random.beta(alpha, beta) if alpha > 0 else 1 798 | 799 | batch_size = x.size()[0] 800 | index = torch.randperm(batch_size) 801 | 802 | mixed_x = lam * x + (1 - lam) * x[index, :] 803 | y_a, y_b = (y, y[index]) if y is not None else (None, None) 804 | return mixed_x, y_a, y_b, lam 805 | 806 | 807 | def set_schedulers(epoch: int, criterion: nn.Module, optimizer: optim.Optimizer, args: argparse.Namespace): 808 | # gamma schedule 809 | if isinstance(criterion, GCLoss_v1): 810 | criterion.adjust_gamma(epoch) 811 | 812 | if args.rank == 0: 813 | print(f'Adjusted gamma according to schedule: {criterion.gamma:.5f}') 814 | 815 | # learning rate schedule 816 | if args.lr_schedule.lower() == 'cosine': 817 | # warmup 818 | if epoch < args.warmup_epochs: 819 | lr = args.learning_rate * float(epoch + 1) / args.warmup_epochs 820 | 821 | # cosine decay 822 | else: 823 | progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs) 824 | lr = args.learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress)) 825 | 826 | for param_group in optimizer.param_groups: 827 | param_group['lr'] = lr --------------------------------------------------------------------------------
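A quick sanity check of the warmup-plus-cosine learning-rate rule implemented in `set_schedulers` above. This is a minimal standalone sketch: the `learning_rate`, `warmup_epochs`, and `epochs` values are illustrative assumptions for demonstration, not the repository's actual training configuration.

```python
import math

# Re-statement of the schedule in set_schedulers; the constants below are assumed values.
learning_rate, warmup_epochs, epochs = 1e-3, 5, 100

def lr_at(epoch: int) -> float:
    if epoch < warmup_epochs:
        return learning_rate * float(epoch + 1) / warmup_epochs       # linear warmup
    progress = (epoch - warmup_epochs) / (epochs - warmup_epochs)     # 0 -> 1 over the decay phase
    return learning_rate * 0.5 * (1.0 + math.cos(math.pi * progress))

print([round(lr_at(e), 6) for e in (0, 4, 5, 52, 99)])
# -> [0.0002, 0.001, 0.001, 0.000508, 0.0]
```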