├── mae_ast ├── data │ ├── __init__.py │ └── mae_ast_dataset.py ├── __init__.py ├── criterions │ ├── __init__.py │ └── mae_ast_criterion.py ├── tasks │ ├── __init__.py │ └── mae_ast_pretraining.py └── models │ ├── __init__.py │ └── mae_ast.py ├── s3prl ├── mae_ast │ ├── hubconf.py │ └── expert.py ├── config │ ├── speech_commands │ │ └── config.yaml │ ├── voxceleb1 │ │ └── config.yaml │ └── emotion │ │ └── config.yaml └── README.md ├── config └── pretrain │ └── mae_ast.yaml └── README.md /mae_ast/data/__init__.py: -------------------------------------------------------------------------------- 1 | from . mae_ast_dataset import * -------------------------------------------------------------------------------- /mae_ast/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import tasks 3 | from . import models 4 | from . import criterions 5 | -------------------------------------------------------------------------------- /s3prl/mae_ast/hubconf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from s3prl.utility.download import _urls_to_filepaths 4 | from .expert import UpstreamExpert as _UpstreamExpert 5 | 6 | 7 | def mae_ast(ckpt, *args, **kwargs): 8 | """ 9 | The model from local ckpt 10 | ckpt (str): PATH 11 | """ 12 | assert os.path.isfile(ckpt) 13 | return _UpstreamExpert(ckpt, *args, **kwargs) -------------------------------------------------------------------------------- /s3prl/config/speech_commands/config.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | total_steps: 10000 3 | gradient_clipping: 1 4 | gradient_accumulate_steps: 1 5 | 6 | log_step: 10 7 | eval_step: 100 8 | save_step: 100 9 | max_keep: 1 10 | eval_dataloaders: 11 | - dev 12 | - test 13 | 14 | optimizer: 15 | name: TorchOptim 16 | torch_optim_name: Adam 17 | lr: 1.0e-5 18 | 19 | downstream_expert: 20 | datarc: 21 | speech_commands_root: /groups/public/benchmark/speech_commands/train 22 | speech_commands_test_root: /groups/public/benchmark/speech_commands/test 23 | num_workers: 8 24 | batch_size: 512 25 | 26 | modelrc: 27 | projector_dim: 256 28 | select: UtteranceLevel 29 | UtteranceLevel: 30 | pooling: MeanPooling 31 | 32 | -------------------------------------------------------------------------------- /mae_ast/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import importlib 8 | import os 9 | 10 | from fairseq import registry 11 | from fairseq.criterions.fairseq_criterion import ( # noqa 12 | FairseqCriterion, 13 | LegacyFairseqCriterion, 14 | ) 15 | from omegaconf import DictConfig 16 | 17 | 18 | # automatically import any Python files in the criterions/ directory 19 | for file in sorted(os.listdir(os.path.dirname(__file__))): 20 | if file.endswith(".py") and not file.startswith("_"): 21 | file_name = file[: file.find(".py")] 22 | importlib.import_module("mae_ast.criterions." 
+ file_name) 23 | -------------------------------------------------------------------------------- /s3prl/config/voxceleb1/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # this voxceleb1 is doing speaker classification task! 3 | runner: 4 | total_steps: 60000 5 | gradient_clipping: 1 6 | gradient_accumulate_steps: 4 7 | 8 | log_step: 50 9 | eval_step: 500 10 | save_step: 1000 11 | max_keep: 1 12 | eval_dataloaders: 13 | - dev 14 | - test 15 | 16 | optimizer: 17 | name: TorchOptim 18 | torch_optim_name: Adam 19 | lr: 1.0e-4 20 | 21 | # # comment the whole scheduler config block to disable learning rate scheduling 22 | # scheduler: 23 | # name: linear_schedule_with_warmup 24 | # num_warmup_steps: 5000 25 | 26 | downstream_expert: 27 | datarc: 28 | file_path: /path/to/VoxCeleb1 29 | meta_data: ./downstream/voxceleb1/veri_test_class.txt 30 | num_workers: 12 31 | train_batch_size: 4 32 | eval_batch_size: 16 33 | max_timestep: 128000 34 | 35 | modelrc: 36 | projector_dim: 256 37 | select: UtteranceLevel 38 | UtteranceLevel: 39 | pooling: MeanPooling 40 | -------------------------------------------------------------------------------- /s3prl/config/emotion/config.yaml: -------------------------------------------------------------------------------- 1 | runner: 2 | total_steps: 3000 3 | gradient_clipping: 1 4 | gradient_accumulate_steps: 5 5 | 6 | log_step: 30 7 | eval_step: 150 8 | save_step: 150 9 | max_keep: 1 10 | eval_dataloaders: 11 | - dev 12 | - test 13 | 14 | optimizer: 15 | name: TorchOptim 16 | torch_optim_name: Adam 17 | lr: 1.0e-4 18 | 19 | # comment the whole scheduler config block 20 | # to disable learning rate scheduling 21 | # scheduler: 22 | # name: linear_schedule_with_warmup 23 | # num_warmup_steps: 1400 24 | 25 | downstream_expert: 26 | datarc: 27 | root: /groups/public/benchmark/IEMOCAP 28 | test_fold: fold1 29 | pre_load: True 30 | train_batch_size: 6 31 | eval_batch_size: 16 32 | num_workers: 12 33 | valid_ratio: 0.2 34 | 35 | modelrc: 36 | projector_dim: 256 37 | select: UtteranceLevel 38 | 39 | UtteranceLevel: 40 | pooling: MeanPooling 41 | 42 | DeepModel: 43 | model_type: CNNSelfAttention 44 | hidden_dim: 80 45 | kernel_size: 5 46 | padding: 2 47 | pooling: 5 48 | dropout: 0.4 49 | -------------------------------------------------------------------------------- /config/pretrain/mae_ast.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | common: 4 | fp16: true 5 | log_format: json 6 | log_interval: 200 7 | seed: 1337 8 | tensorboard_logdir: tblog 9 | 10 | checkpoint: 11 | save_interval_updates: 25000 12 | keep_interval_updates: 1 13 | no_epoch_checkpoints: true 14 | 15 | 16 | distributed_training: 17 | ddp_backend: no_c10d 18 | distributed_backend: 'nccl' 19 | distributed_world_size: 32 20 | nprocs_per_node: 8 21 | find_unused_parameters: true 22 | 23 | task: 24 | _name: mae_ast_pretraining 25 | data: ??? 
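  # directory containing the train.tsv / valid.tsv / test.tsv manifests (see "Input files" in the top-level README); set on the command line via task.data=...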
26 | feature_type: "fbank" 27 | mask_type: "random_mask" 28 | sample_rate: 16000 29 | max_sample_size: 250000 30 | min_sample_size: 32000 31 | feature_rate: 100 32 | feature_dim: 128 33 | pad_audio: false 34 | random_crop: true 35 | normalize: false # must be consistent with extractor 36 | deltas: false 37 | 38 | dataset: 39 | num_workers: 6 40 | max_tokens: 1400000 41 | skip_invalid_size_inputs_valid_test: true 42 | validate_interval: 5 43 | validate_interval_updates: 10000 44 | 45 | criterion: 46 | _name: mae_ast 47 | classification_weight: 1 48 | reconstruction_weight: 10 49 | 50 | optimization: 51 | max_update: 400000 52 | lr: [0.0001] 53 | clip_norm: 10.0 54 | 55 | optimizer: 56 | _name: adam 57 | adam_betas: (0.9,0.98) 58 | adam_eps: 1e-06 59 | weight_decay: 0.01 60 | 61 | lr_scheduler: 62 | _name: polynomial_decay 63 | warmup_updates: 32000 64 | 65 | model: 66 | _name: mae_ast 67 | dropout_input: 0.1 68 | dropout: 0.1 69 | attention_dropout: 0.1 70 | activation_dropout: 0.0 71 | feature_grad_mult: 0.1 72 | encoder_layerdrop: 0.05 73 | decoder_layerdrop: 0.0 74 | encoder_layers: 12 75 | decoder_layers: 2 76 | random_mask_prob: 0.75 77 | enc_conv_pos: false 78 | enc_sine_pos: true 79 | dec_conv_pos: false 80 | dec_sine_pos: true 81 | ast_kernel_size_chan: 16 82 | ast_kernel_size_time: 16 83 | ast_kernel_stride_chan: 16 84 | ast_kernel_stride_time: 16 85 | 86 | hydra: 87 | job: 88 | config: 89 | override_dirname: 90 | kv_sep: '-' 91 | item_sep: '__' 92 | exclude_keys: 93 | - run 94 | - task.data 95 | run: 96 | dir: ??? 97 | sweep: 98 | dir: ??? 99 | subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} 100 | -------------------------------------------------------------------------------- /s3prl/README.md: -------------------------------------------------------------------------------- 1 | # Fine-Tuning MAE-AST on Superb with S3prl 2 | 3 | This folder (mae_ast in s3prl) is the upstream module for fine-tuning MAE-AST on Superb with S3prl, and acts as a demo for how to use the MAE-AST in downstream applications. 4 | 5 | ## Setting up an S3prl environment 6 | Here's an example of the environment setup used, given that fairseq and s3prl are cloned as separate folders in the home directory. 7 | 8 | ``` 9 | conda create -n s3prl_mae_ast python=3.7 10 | conda activate s3prl_mae_ast 11 | conda install pathlib 12 | cd ~/s3prl 13 | pip install -e ./ 14 | cd ~/fairseq 15 | pip install -e ./ 16 | pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 17 | pip install --upgrade protobuf==3.20.0 18 | ``` 19 | 20 | ## Fine-tuning 21 | 22 | Fine-tuning commands should be run from ``~/s3prl/s3prl`` with ``s3prl_mae_ast`` activated. 23 | All downstream configuration files used in the paper can be accessed in s3prl/config. These contain either the default settings or slight modifications to gradient accumulation steps and batch size that mimic the default effective batch size on our hardware. Datasets will need to be downloaded as specified in s3prl.
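For reference, the effective batch size under these configs is the per-step batch size multiplied by the number of gradient accumulation steps: the provided voxceleb1 config pairs ``train_batch_size: 4`` with ``gradient_accumulate_steps: 4`` for an effective batch of 16 per update, and the emotion config pairs ``train_batch_size: 6`` with ``gradient_accumulate_steps: 5`` for an effective batch of 30.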
24 | 25 | What follows are the exact run commands used alongside these config files for fine-tuning (disregarding template paths): 26 | #### Command to sanity check imports 27 | ``` 28 | python run_downstream.py -m train -n S3prl_mae_ast_sanity -u npc -d example -f 29 | ``` 30 | 31 | #### Keyword Spotting 1 (KS1) (Speech Commands) 32 | ``` 33 | python run_downstream.py -m train -n MAE_AST_Template_Name -u mae_ast -d speech_commands -f \ 34 | -k /path/to/checkpoint.pt \ 35 | -s last_hidden_state \ 36 | -o "config.downstream_expert.datarc.speech_commands_root='/path/to/speech_commands_v0.01/',,\ 37 | config.downstream_expert.datarc.speech_commands_test_root='/path/to/speech_commands_test_set_v0.01/',,\ 38 | config.optimizer.lr=1.0e-5" 39 | ``` 40 | 41 | #### Speaker Identification (SID) (VoxCeleb1) 42 | ``` 43 | python run_downstream.py -m train -n MAE_AST_Template_Name -u mae_ast -d voxceleb1 -f \ 44 | -k /path/to/checkpoint.pt \ 45 | -s hidden_states \ 46 | -o "config.downstream_expert.datarc.file_path='/path/to/VoxCeleb1/',,\ 47 | config.optimizer.lr=1.0e-4" 48 | ``` 49 | 50 | #### Emotion Recognition IEMOCAP (ER) 51 | Recall that ER is evaluated over five folds, with the resulting test score being the average of the per-fold test scores. 52 | ``` 53 | for test_fold in fold1 fold2 fold3 fold4 fold5; 54 | do 55 | python run_downstream.py -m train -n MAE_AST_Template_Name$test_fold -u mae_ast -d emotion -f \ 56 | -k /path/to/checkpoint.pt \ 57 | -s last_hidden_state \ 58 | -o "config.downstream_expert.datarc.root='/path/to/IEMOCAP_full_release',,\ 59 | config.downstream_expert.datarc.test_fold=$test_fold,,\ 60 | config.optimizer.lr=1.0e-4" 61 | done 62 | ``` 63 | -------------------------------------------------------------------------------- /s3prl/mae_ast/expert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- # 2 | """*********************************************************************************************""" 3 | # Adapted from the Fast VGS Wrapper by Puyuan Peng 4 | """*********************************************************************************************""" 5 | # FileName [ upstream/mae_ast/expert.py ] 6 | # Synopsis [ Upstream MAE-AST Wrapper ] 7 | # Author [ Alan Baade ] 8 | # Copyright [ Copyleft(c), Alan Baade ] 9 | """*********************************************************************************************""" 10 | 11 | import argparse 12 | from typing import List 13 | from packaging import version 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.nn.utils.rnn import pad_sequence 18 | 19 | import torchaudio 20 | 21 | from ..interfaces import UpstreamBase 22 | 23 | import sys 24 | sys.path.insert(1, '/home/abaade/MAE-AST-Public') 25 | from mae_ast.models.mae_ast import MAE_AST, MAE_AST_Config 26 | from mae_ast.tasks.mae_ast_pretraining import MAE_AST_Pretraining_Config, MAE_AST_Pretraining_Task 27 | 28 | from types import SimpleNamespace 29 | 30 | class UpstreamExpert(UpstreamBase): 31 | def __init__(self, ckpt, **kwargs): 32 | super().__init__(**kwargs) 33 | 34 | checkpoint = torch.load(ckpt) 35 | 36 | self.cfg = checkpoint["cfg"]["model"] 37 | self.task_cfg = checkpoint["cfg"]["task"] 38 | 39 | self.model = MAE_AST(SimpleNamespace(**checkpoint["cfg"]["model"]), SimpleNamespace(**checkpoint["cfg"]["task"])) 40 | 41 | self.model.load_state_dict(checkpoint["model"], strict=True) 42 | 43 | # Required for hidden states to have defined indices.
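        # (Encoder layerdrop stochastically skips transformer layers while the module is in training
        # mode, which could change which layers contribute to the returned hidden_states from one
        # forward pass to the next.)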
44 | self.model.encoder.layerdrop = 0 45 | 46 | self.sample_rate = self.task_cfg['sample_rate'] 47 | self.feature_dim = self.task_cfg['feature_dim'] 48 | self.feature_rate = self.task_cfg['feature_rate'] 49 | 50 | self.is_decoder_finetune = False # TODO bad way of passing in info 51 | 52 | def get_downsample_rates(self, key: str) -> int: 53 | return 320 54 | # self.downsample_rate = round(self.sample_rate / self.feature_rate * self.feature_dim/(16*16)) 55 | 56 | def wav_to_spectrogram(self, wav): 57 | return torchaudio.compliance.kaldi.fbank( # Frame shift and length are standard at 10, 25 58 | waveform=wav, 59 | sample_frequency=self.sample_rate, 60 | use_energy=False, 61 | num_mel_bins=self.feature_dim 62 | ) 63 | 64 | def forward(self, wavs): 65 | device = wavs[0].device 66 | 67 | features = [self.wav_to_spectrogram(wav.unsqueeze(0)) for wav in wavs] 68 | feature_lengths = torch.LongTensor([len(feature) for feature in features]).to(device) 69 | feature_padding_mask = ~torch.lt( 70 | torch.arange(max(feature_lengths)).unsqueeze(0).to(device), 71 | feature_lengths.unsqueeze(1), 72 | ) 73 | padded_features = pad_sequence(features, batch_first=True) 74 | 75 | results = self.model(padded_features, padding_mask=feature_padding_mask, mask=False, features_only=True, is_decoder_finetune=self.is_decoder_finetune) 76 | 77 | return {"last_hidden_state": results["x"], "hidden_states": results["hidden_states"]} 78 | -------------------------------------------------------------------------------- /mae_ast/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import argparse 8 | import importlib 9 | import os 10 | 11 | from fairseq.dataclass import FairseqDataclass 12 | from fairseq.dataclass.utils import merge_with_parent 13 | from hydra.core.config_store import ConfigStore 14 | 15 | from fairseq.tasks.fairseq_task import FairseqTask, LegacyFairseqTask # noqa 16 | 17 | 18 | # register dataclass 19 | TASK_DATACLASS_REGISTRY = {} 20 | TASK_REGISTRY = {} 21 | TASK_CLASS_NAMES = set() 22 | 23 | 24 | def setup_task(cfg: FairseqDataclass, **kwargs): 25 | task = None 26 | task_name = getattr(cfg, "task", None) 27 | 28 | if isinstance(task_name, str): 29 | # legacy tasks 30 | task = TASK_REGISTRY[task_name] 31 | if task_name in TASK_DATACLASS_REGISTRY: 32 | dc = TASK_DATACLASS_REGISTRY[task_name] 33 | cfg = dc.from_namespace(cfg) 34 | else: 35 | task_name = getattr(cfg, "_name", None) 36 | 37 | if task_name and task_name in TASK_DATACLASS_REGISTRY: 38 | dc = TASK_DATACLASS_REGISTRY[task_name] 39 | cfg = merge_with_parent(dc(), cfg) 40 | task = TASK_REGISTRY[task_name] 41 | 42 | assert ( 43 | task is not None 44 | ), f"Could not infer task type from {cfg}. Available argparse tasks: {TASK_REGISTRY.keys()}. Available hydra tasks: {TASK_DATACLASS_REGISTRY.keys()}" 45 | 46 | return task.setup_task(cfg, **kwargs) 47 | 48 | 49 | def register_task(name, dataclass=None): 50 | """ 51 | New tasks can be added to fairseq with the 52 | :func:`~fairseq.tasks.register_task` function decorator. 53 | 54 | For example:: 55 | 56 | @register_task('classification') 57 | class ClassificationTask(FairseqTask): 58 | (...) 59 | 60 | .. note:: 61 | 62 | All Tasks must implement the :class:`~fairseq.tasks.FairseqTask` 63 | interface. 
64 | 65 | Args: 66 | name (str): the name of the task 67 | """ 68 | 69 | def register_task_cls(cls): 70 | if name in TASK_REGISTRY: 71 | raise ValueError("Cannot register duplicate task ({})".format(name)) 72 | if not issubclass(cls, FairseqTask): 73 | raise ValueError( 74 | "Task ({}: {}) must extend FairseqTask".format(name, cls.__name__) 75 | ) 76 | if cls.__name__ in TASK_CLASS_NAMES: 77 | raise ValueError( 78 | "Cannot register task with duplicate class name ({})".format( 79 | cls.__name__ 80 | ) 81 | ) 82 | TASK_REGISTRY[name] = cls 83 | TASK_CLASS_NAMES.add(cls.__name__) 84 | 85 | if dataclass is not None and not issubclass(dataclass, FairseqDataclass): 86 | raise ValueError( 87 | "Dataclass {} must extend FairseqDataclass".format(dataclass) 88 | ) 89 | 90 | cls.__dataclass = dataclass 91 | if dataclass is not None: 92 | TASK_DATACLASS_REGISTRY[name] = dataclass 93 | 94 | cs = ConfigStore.instance() 95 | node = dataclass() 96 | node._name = name 97 | cs.store(name=name, group="task", node=node, provider="fairseq") 98 | 99 | return cls 100 | 101 | return register_task_cls 102 | 103 | 104 | def get_task(name): 105 | return TASK_REGISTRY[name] 106 | 107 | 108 | def import_tasks(tasks_dir, namespace): 109 | for file in os.listdir(tasks_dir): 110 | path = os.path.join(tasks_dir, file) 111 | if ( 112 | not file.startswith("_") 113 | and not file.startswith(".") 114 | and (file.endswith(".py") or os.path.isdir(path)) 115 | ): 116 | task_name = file[: file.find(".py")] if file.endswith(".py") else file 117 | importlib.import_module(namespace + "." + task_name) 118 | 119 | # expose `task_parser` for sphinx 120 | if task_name in TASK_REGISTRY: 121 | parser = argparse.ArgumentParser(add_help=False) 122 | group_task = parser.add_argument_group("Task name") 123 | # fmt: off 124 | group_task.add_argument('--task', metavar=task_name, 125 | help='Enable this task with: ``--task=' + task_name + '``') 126 | # fmt: on 127 | group_args = parser.add_argument_group( 128 | "Additional command-line arguments" 129 | ) 130 | TASK_REGISTRY[task_name].add_args(group_args) 131 | globals()[task_name + "_parser"] = parser 132 | 133 | 134 | # automatically import any Python files in the tasks/ directory 135 | tasks_dir = os.path.dirname(__file__) 136 | import_tasks(tasks_dir, "mae_ast.tasks") 137 | -------------------------------------------------------------------------------- /mae_ast/criterions/mae_ast_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import re 8 | from dataclasses import dataclass, field 9 | from typing import List, Optional 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from fairseq import metrics, utils 14 | from fairseq.criterions import FairseqCriterion, register_criterion 15 | from fairseq.dataclass import FairseqDataclass 16 | 17 | 18 | @dataclass 19 | class MAE_AST_Criterion_Config(FairseqDataclass): 20 | reconstruction_weight: float = field( 21 | default=10.0, 22 | metadata={"help": "weight for reconstruction in SSAST Type Model. Equals Lambda. Default 10. Set to 0 to not use"}, 23 | ) 24 | classification_weight: float = field( 25 | default=1.0, 26 | metadata={"help": "weight for classification in SSAST Type Model. Default 1. 
Set to 0 to not use"}, 27 | ) 28 | 29 | 30 | @register_criterion("mae_ast", dataclass=MAE_AST_Criterion_Config) 31 | class MAE_AST_Criterion(FairseqCriterion): 32 | def __init__( 33 | self, 34 | task, 35 | reconstruction_weight, 36 | classification_weight, 37 | # log_keys=None, 38 | ): 39 | super().__init__(task) 40 | self.reconstruction_weight = reconstruction_weight 41 | self.classification_weight = classification_weight 42 | 43 | def forward(self, model, sample, reduce=True, log_pred=False): 44 | """Compute the loss for the given sample. 45 | Returns a tuple with three elements: 46 | 1) the loss 47 | 2) the sample size, which is used as the denominator for the gradient 48 | 3) logging outputs to display while training 49 | """ 50 | 51 | net_output = model(**sample["net_input"]) 52 | 53 | loss = 0.0 54 | logging_output = {} 55 | 56 | logp_m_list_recon, logp_m_list_class = model.get_logits(net_output) 57 | targ_m_list = model.get_targets(net_output, True) 58 | assert (self.reconstruction_weight > 0 or self.classification_weight > 0) and len(logp_m_list_recon) > 0 59 | 60 | if self.reconstruction_weight > 0: 61 | loss_recon = F.mse_loss(logp_m_list_recon, targ_m_list) 62 | logging_output["loss_recon"] = loss_recon.detach().item() 63 | loss += self.reconstruction_weight * loss_recon 64 | 65 | if self.classification_weight > 0: 66 | all_dots = torch.matmul(logp_m_list_class, targ_m_list.transpose(-1, -2)) 67 | log_softmax = torch.log_softmax(all_dots, dim=-1) 68 | loss_info_nce = -torch.mean(torch.diagonal(log_softmax, dim1=-2, dim2=-1)) 69 | 70 | logging_output["loss_info_nce"] = loss_info_nce.detach().item() 71 | 72 | loss += self.classification_weight * loss_info_nce 73 | 74 | sample_size = 1 75 | 76 | logging_output = { 77 | "loss": loss.item() if reduce else loss, 78 | "ntokens": targ_m_list.size(1), 79 | "nsentences": sample["id"].numel(), 80 | "sample_size": sample_size, 81 | **logging_output, 82 | } 83 | 84 | return loss, sample_size, logging_output 85 | 86 | @staticmethod 87 | def reduce_metrics(logging_outputs) -> None: 88 | """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" 89 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 90 | ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) 91 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 92 | 93 | metrics.log_scalar( 94 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 95 | ) 96 | if sample_size != ntokens: 97 | metrics.log_scalar( 98 | "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 99 | ) 100 | metrics.log_derived( 101 | "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) 102 | ) 103 | else: 104 | metrics.log_derived( 105 | "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) 106 | ) 107 | 108 | counts = {} 109 | for lk in logging_outputs[0].keys(): 110 | if lk.startswith("count_"): 111 | val = sum(log[lk] for log in logging_outputs) 112 | metrics.log_scalar(lk, val) 113 | counts[lk] = val 114 | 115 | for lk in logging_outputs[0].keys(): 116 | if lk.startswith("loss_"): 117 | val = sum(log[lk] for log in logging_outputs) 118 | metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) 119 | elif lk.startswith("correct_"): 120 | val = sum(log[lk] for log in logging_outputs) 121 | metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) 122 | 123 | @staticmethod 124 | def aggregate_logging_outputs(logging_outputs): 125 | """Aggregate logging outputs from data 
parallel training.""" 126 | raise NotImplementedError() 127 | 128 | @staticmethod 129 | def logging_outputs_can_be_summed() -> bool: 130 | """ 131 | Whether the logging outputs returned by `forward` can be summed 132 | across workers prior to calling `reduce_metrics`. Setting this 133 | to True will improves distributed training speed. 134 | """ 135 | return False 136 | -------------------------------------------------------------------------------- /mae_ast/tasks/mae_ast_pretraining.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the LICENSE file in 5 | # the root directory of this source tree. An additional grant of patent rights 6 | # can be found in the PATENTS file in the same directory. 7 | 8 | import logging 9 | import os 10 | import sys 11 | from typing import Dict, List, Optional, Tuple 12 | 13 | import numpy as np 14 | 15 | from dataclasses import dataclass, field 16 | from fairseq.data import Dictionary 17 | from mae_ast.data import MAE_AST_Dataset 18 | from fairseq.dataclass import ChoiceEnum 19 | from fairseq.dataclass.configs import FairseqDataclass 20 | from fairseq.tasks import register_task 21 | from fairseq.tasks.fairseq_task import FairseqTask 22 | from omegaconf import MISSING 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | MASK_TYPE_CHOICES = ChoiceEnum(["retain_spans", "random_mask", "random_mask_batched", "chunk_mask"]) 27 | 28 | 29 | @dataclass 30 | class MAE_AST_Pretraining_Config(FairseqDataclass): 31 | data: str = field(default=MISSING, metadata={"help": "path to data directory"}) 32 | 33 | sample_rate: int = field( 34 | default=16_000, 35 | metadata={ 36 | "help": "target sample rate. audio files will be up/down " 37 | "sampled to this rate" 38 | }, 39 | ) 40 | normalize: bool = field( 41 | default=False, 42 | metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, 43 | ) 44 | enable_padding: bool = field( 45 | default=False, 46 | metadata={"help": "pad shorter samples instead of cropping"}, 47 | ) 48 | max_keep_size: Optional[int] = field( 49 | default=None, 50 | metadata={"help": "exclude sample longer than this"}, 51 | ) 52 | max_sample_size: Optional[int] = field( 53 | default=None, 54 | metadata={"help": "max sample size to crop to for batching"}, 55 | ) 56 | min_sample_size: Optional[int] = field( 57 | default=None, 58 | metadata={"help": "min sample size to crop to for batching"}, 59 | ) 60 | random_crop: Optional[bool] = field( 61 | default=True, 62 | metadata={"help": "always crop from the beginning if false"}, 63 | ) 64 | pad_audio: Optional[bool] = field( 65 | default=False, 66 | metadata={"help": "pad audio to the longest one in the batch if true"}, 67 | ) 68 | 69 | feature_type: Optional[str] = field( 70 | default='wav', 71 | metadata={"help": "choose from ['wav', 'spectrogram', 'fbank', 'mfcc']"} 72 | ) 73 | 74 | feature_rate: Optional[int] = field( 75 | default=100, 76 | metadata={ 77 | "help": "rate of feature input to the transformer, if use wav, this arg is omited, else if use spectrogram/fbank/mfcc, the default is 100, i.e. 1s audio gives 100 frames. 
the label rate of using MFCC is also 100"} 78 | ) 79 | 80 | feature_dim: Optional[int] = field( 81 | default=100, 82 | metadata={ 83 | "help": "dim feature input to the transformer, if use wav, this arg is omited, else if use spectrogram/fbank/mfcc, the default is 80"} 84 | ) 85 | 86 | deltas: Optional[bool] = field( 87 | default=True, 88 | metadata={ 89 | "help": "whether or not add delta and delta-delta to the feature, only effective for spectrogram/fbank/mfcc"} 90 | ) 91 | 92 | mask_spans: Optional[bool] = field( 93 | default=False, 94 | metadata={"help": "mask random spans, same as that is used in HuBERT and w2v2"} 95 | ) 96 | 97 | mask_type: MASK_TYPE_CHOICES = field( 98 | default='random_mask', 99 | metadata={"help": 100 | """Determine type of mask for MAE pretraining. 101 | -retain_spans: Only for frame data. Wav2Vec2 like masking. 102 | -random_mask: Perform masking on completely random tokens. No chunking. Used in MAE 103 | -random_mask_batched: random_mask with the same mask across the batch. 104 | -chunk_mask: Perform masking on chunks until mask_spans hit. From SSAST. Same across batch for speed. 105 | """} 106 | ) 107 | 108 | 109 | @register_task("mae_ast_pretraining", dataclass=MAE_AST_Pretraining_Config) 110 | class MAE_AST_Pretraining_Task(FairseqTask): 111 | cfg: MAE_AST_Pretraining_Config 112 | 113 | def __init__( 114 | self, 115 | cfg: MAE_AST_Pretraining_Config, 116 | ) -> None: 117 | super().__init__(cfg) 118 | 119 | logger.info(f"current directory is {os.getcwd()}") 120 | logger.info(f"MAEPretrainingTask Config {cfg}") 121 | 122 | self.cfg = cfg 123 | 124 | @property 125 | def source_dictionary(self) -> Optional[Dictionary]: 126 | return None 127 | 128 | @property 129 | def target_dictionary(self) -> Optional[Dictionary]: 130 | return None 131 | 132 | @property 133 | def dictionaries(self) -> List[Dictionary]: 134 | return None 135 | 136 | @classmethod 137 | def setup_task( 138 | cls, cfg: MAE_AST_Pretraining_Config, **kwargs 139 | ) -> "MAE_AST_Pretraining_Task": 140 | return cls(cfg) 141 | 142 | def load_dataset(self, split: str, **kwargs) -> None: 143 | manifest = f"{self.cfg.data}/{split}.tsv" 144 | 145 | self.datasets[split] = MAE_AST_Dataset( 146 | manifest, 147 | sample_rate=self.cfg.sample_rate, 148 | max_keep_sample_size=self.cfg.max_keep_size, 149 | min_keep_sample_size=self.cfg.min_sample_size, 150 | max_sample_size=self.cfg.max_sample_size, 151 | pad_audio=self.cfg.pad_audio, 152 | normalize=self.cfg.normalize, 153 | random_crop=self.cfg.random_crop, 154 | feature_type=self.cfg.feature_type, 155 | feature_dim=self.cfg.feature_dim, 156 | deltas=self.cfg.deltas, 157 | feature_rate=self.cfg.feature_rate 158 | ) 159 | 160 | def max_positions(self) -> Tuple[int, int]: 161 | return (sys.maxsize, sys.maxsize) 162 | 163 | def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: 164 | return indices 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MAE-AST 2 | This repository contains the code for the paper [MAE-AST: Masked Autoencoding Audio Spectrogram Transformer](https://arxiv.org/abs/2203.16691). Pretrained checkpoints to be hosted in the coming few days. 3 | 4 | This repository contains three folders: config, mae_ast, and s3prl. Config contains a default pre-training config for the mae-ast. 
The mae_ast folder contains the main code for the model and runs under [fairseq](https://github.com/facebookresearch/fairseq). This includes a criterion, task, data loading, and models. The s3prl folder provides the upstream model and configuration for fine-tuning the MAE-AST on Superb tasks under the [S3prl repository](https://github.com/s3prl/s3prl). This repository does not include fine-tuning code for AudioSet, LibriSpeech, and KS2, which are instead evaluated under the [SSAST library](https://github.com/YuanGongND/ssast) with no settings changed. 5 | 6 | Please email abaade@utexas.edu for questions. 7 | 8 | ## Pretrained Model Download 9 | Below are the two 12-layer models used in the overall results section of the paper, with a masking ratio of 75%. Clicking the link attempts to display the model checkpoint as a text file; use wget or open the link in a new tab and save it. 10 | | Download | Model | Layers | Masking | AS | ESC-50 | KS2 | KS1 | SID | ER | 11 | |-----------|---------------|--------|---------|-------|--------|-------|-------|-------|-------| 12 | | [Checkpoint](https://www.cs.utexas.edu/~harwath/model_checkpoints/mae_ast/chunk_patch_75_12LayerEncoder.pt) | MAE-AST Patch | 12 | Chunked | 0.306 | 0.900 | 0.979 | 0.958 | - | 0.598 | 13 | | [Checkpoint](https://www.cs.utexas.edu/~harwath/model_checkpoints/mae_ast/random_frame_75_12LayerEncoder.pt) | MAE-AST Frame | 12 | Random | 0.230 | 0.889 | 0.980 | 0.973 | 0.633 | 0.621 | 14 | 15 | ## Pre-Training 16 | Pretraining on fairseq is done as follows. 17 | 18 | ### Environment Setup 19 | Run the following commands with conda to set up an environment for pretraining. This assumes that fairseq is downloaded to the home directory. 20 | ``` 21 | conda create -n fairseq_mae_ast python=3.9 22 | conda activate fairseq_mae_ast 23 | pip install soundfile 24 | cd ~/fairseq 25 | pip install -e ./ 26 | conda install tensorboardX 27 | conda install av -c conda-forge 28 | pip install sortedcontainers 29 | pip install tensorboard 30 | ``` 31 | 32 | ### Input files 33 | The dataset code takes in a directory which contains the files train.tsv, valid.tsv, and test.tsv, containing paths to the train, valid, and test data respectively. Each of train.tsv, valid.tsv, and test.tsv is a tab-separated value file with a ``/`` on the first line, followed by one line per audio file containing the audio file path, a tab, and the length of that audio file in frames. For example, train.tsv starts with: 34 | ``` 35 | / 36 | /path/to/AudioSet/unbalanced/6XUF56FlKvg.mkv 479232 37 | /path/to/data/AudioSet/unbalanced/eJS_911G6ps.mkv 477696 38 | ``` 39 | and test.tsv starts with: 40 | ``` 41 | / 42 | /path/to/LibriSpeech/data/test-other/3331/159609/3331-159609-0002.flac 225600 43 | /path/to/LibriSpeech/data/test-other/3331/159609/3331-159609-0021.flac 165920 44 | ``` 45 | The dataset expects either mkv or flac files as input. A minimal sketch for generating these manifests is shown under Pretraining commands below. 46 | 47 | ### Environment Variables 48 | Let MAE-AST-Public be the base directory of this repository. 49 | 50 | Run the following to set up environment variables: 51 | ``` 52 | conda activate fairseq_mae_ast 53 | cd ~/MAE-AST-Public 54 | export HYDRA_FULL_ERROR=1 55 | data_dir=/path/to/directory_with_train_valid_test_tsv_input_files 56 | config_dir=/path/to/MAE-AST-Public/config/pretrain 57 | user_dir=/path/to/MAE-AST-Public/mae_ast 58 | ``` 59 | 60 | ### Pretraining commands 61 | The following run commands override the default pretrain configuration and contain the most important settings to change.
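These commands assume that the train.tsv, valid.tsv, and test.tsv manifests pointed to by ``data_dir`` already exist. Below is a minimal sketch for generating one from a directory of flac files (``write_manifest`` is a hypothetical helper, not a script shipped with this repository, and it assumes soundfile can read the audio; mkv lengths would instead need to be computed with av):
```
import os
import soundfile as sf

def write_manifest(audio_dir, out_tsv, ext=".flac"):
    # First line is the manifest root; absolute paths are written below it.
    with open(out_tsv, "w") as f:
        f.write("/\n")
        for dirpath, _, files in os.walk(audio_dir):
            for name in sorted(files):
                if name.endswith(ext):
                    path = os.path.join(dirpath, name)
                    # Length is the number of audio frames (samples) in the file.
                    f.write(f"{path}\t{sf.info(path).frames}\n")

write_manifest("/path/to/LibriSpeech/data/test-other", "test.tsv")
```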
62 | 63 | The code for configuration settings is at the top of ``mae_ast/models/mae_ast.py`` and ``mae_ast/tasks/mae_ast_pretraining.py``. The main model logic (model forward pass) is in the middle of ``mae_ast/models/mae_ast.py`` 64 | 65 | #### Patched, Chunked Masking (SSAST), 12 Layer Encoder, 75% masking ratio 66 | Default Model Patch (12 Layer). 67 | ``` 68 | fairseq-hydra-train \ 69 | --config-dir ${config_dir} --config-name mae_ast common.user_dir=${user_dir} task.data=${data_dir} model._name=mae_ast criterion._name=mae_ast \ 70 | model.encoder_layers=12 model.decoder_layers=2 \ 71 | model.random_mask_prob=0.75 task.mask_type="chunk_mask" \ 72 | model.ast_kernel_size_chan=16 model.ast_kernel_size_time=16 model.ast_kernel_stride_chan=16 model.ast_kernel_stride_time=16 \ 73 | criterion.classification_weight=1 criterion.reconstruction_weight=10 \ 74 | distributed_training.distributed_world_size=1 distributed_training.nprocs_per_node=1 \ 75 | common.log_interval=200 checkpoint.save_interval_updates=25000 \ 76 | optimization.max_update=550000 dataset.max_tokens=8388608 optimization.lr=[0.0001]\ 77 | hydra.run.dir=/path/to/output_model_directory 78 | ``` 79 | 80 | #### Frame, Random Masking, 12 Layer Encoder, 75% masking ratio 81 | Default Model Frame (12 Layer). 82 | Changing the kernel sizes and strides determines frame vs patch models. 83 | ``` 84 | fairseq-hydra-train \ 85 | --config-dir ${config_dir} --config-name mae_ast common.user_dir=${user_dir} task.data=${data_dir} model._name=mae_ast criterion._name=mae_ast \ 86 | model.encoder_layers=12 model.decoder_layers=2 \ 87 | model.random_mask_prob=0.75 task.mask_type="random_mask" \ 88 | model.ast_kernel_size_chan=128 model.ast_kernel_size_time=2 model.ast_kernel_stride_chan=128 model.ast_kernel_stride_time=2 \ 89 | criterion.classification_weight=1 criterion.reconstruction_weight=10 \ 90 | distributed_training.distributed_world_size=1 distributed_training.nprocs_per_node=1 \ 91 | common.log_interval=200 checkpoint.save_interval_updates=25000 \ 92 | optimization.max_update=550000 dataset.max_tokens=8388608 optimization.lr=[0.0001]\ 93 | hydra.run.dir=/path/to/output_model_directory 94 | ``` 95 | 96 | #### Frame, Chunked Masking (Wav2Vec2), 12 Layer Encoder, 75% masking ratio 97 | The random mask probability is 1.45 due to overlap in Wav2Vec2-style masking (specified by task.mask_type="retain_spans"), which creates an average 75% masking ratio. 98 | Set the random mask probability to 0.74 for an average of 50% masking. For all other mask types, the random mask probability directly corresponds to the amount of tokens masked. 
99 | ``` 100 | fairseq-hydra-train \ 101 | --config-dir ${config_dir} --config-name mae_ast common.user_dir=${user_dir} task.data=${data_dir} model._name=mae_ast criterion._name=mae_ast \ 102 | model.encoder_layers=12 model.decoder_layers=2 \ 103 | model.random_mask_prob=1.45 task.mask_type="retain_spans" \ 104 | model.ast_kernel_size_chan=128 model.ast_kernel_size_time=2 model.ast_kernel_stride_chan=128 model.ast_kernel_stride_time=2 \ 105 | criterion.classification_weight=1 criterion.reconstruction_weight=10 \ 106 | distributed_training.distributed_world_size=1 distributed_training.nprocs_per_node=1 \ 107 | common.log_interval=200 checkpoint.save_interval_updates=25000 \ 108 | optimization.max_update=550000 dataset.max_tokens=8388608 optimization.lr=[0.0001]\ 109 | hydra.run.dir=/path/to/output_model_directory 110 | ``` 111 | 112 | ## Fine-Tuning 113 | The s3prl directory contains an example for fine-tuning the MAE-AST on superb, plus a readme with specific fine-tuning settings. s3prl/mae_ast/hubconf.py takes in a checkpoint generated during pretraining and uses it on downstream tasks. 114 | -------------------------------------------------------------------------------- /mae_ast/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """isort:skip_file""" 6 | 7 | import argparse 8 | import importlib 9 | import os 10 | from contextlib import ExitStack 11 | 12 | from fairseq.dataclass import FairseqDataclass 13 | from fairseq.dataclass.utils import merge_with_parent 14 | from hydra.core.config_store import ConfigStore 15 | from omegaconf import open_dict, OmegaConf 16 | 17 | from fairseq.models.composite_encoder import CompositeEncoder 18 | from fairseq.models.distributed_fairseq_model import DistributedFairseqModel 19 | from fairseq.models.fairseq_decoder import FairseqDecoder 20 | from fairseq.models.fairseq_encoder import FairseqEncoder 21 | from fairseq.models.fairseq_incremental_decoder import FairseqIncrementalDecoder 22 | from fairseq.models.fairseq_model import ( 23 | BaseFairseqModel, 24 | FairseqEncoderDecoderModel, 25 | FairseqEncoderModel, 26 | FairseqLanguageModel, 27 | FairseqModel, 28 | FairseqMultiModel, 29 | ) 30 | 31 | 32 | MODEL_REGISTRY = {} 33 | MODEL_DATACLASS_REGISTRY = {} 34 | ARCH_MODEL_REGISTRY = {} 35 | ARCH_MODEL_NAME_REGISTRY = {} 36 | ARCH_MODEL_INV_REGISTRY = {} 37 | ARCH_CONFIG_REGISTRY = {} 38 | 39 | 40 | __all__ = [ 41 | "BaseFairseqModel", 42 | "CompositeEncoder", 43 | "DistributedFairseqModel", 44 | "FairseqDecoder", 45 | "FairseqEncoder", 46 | "FairseqEncoderDecoderModel", 47 | "FairseqEncoderModel", 48 | "FairseqIncrementalDecoder", 49 | "FairseqLanguageModel", 50 | "FairseqModel", 51 | "FairseqMultiModel", 52 | ] 53 | 54 | 55 | def build_model(cfg: FairseqDataclass, task): 56 | 57 | model = None 58 | model_type = getattr(cfg, "_name", None) or getattr(cfg, "arch", None) 59 | 60 | if not model_type and len(cfg) == 1: 61 | # this is hit if config object is nested in directory that is named after model type 62 | 63 | model_type = next(iter(cfg)) 64 | if model_type in MODEL_DATACLASS_REGISTRY: 65 | cfg = cfg[model_type] 66 | else: 67 | raise Exception( 68 | "Could not infer model type from directory. Please add _name field to indicate model type. 
" 69 | "Available models: " 70 | + str(MODEL_DATACLASS_REGISTRY.keys()) 71 | + " Requested model type: " 72 | + model_type 73 | ) 74 | 75 | if model_type in ARCH_MODEL_REGISTRY: 76 | # case 1: legacy models 77 | model = ARCH_MODEL_REGISTRY[model_type] 78 | elif model_type in MODEL_DATACLASS_REGISTRY: 79 | # case 2: config-driven models 80 | model = MODEL_REGISTRY[model_type] 81 | 82 | if model_type in MODEL_DATACLASS_REGISTRY: 83 | # set defaults from dataclass. note that arch name and model name can be the same 84 | dc = MODEL_DATACLASS_REGISTRY[model_type] 85 | 86 | if isinstance(cfg, argparse.Namespace): 87 | cfg = dc.from_namespace(cfg) 88 | else: 89 | cfg = merge_with_parent(dc(), cfg) 90 | else: 91 | if model_type in ARCH_CONFIG_REGISTRY: 92 | with open_dict(cfg) if OmegaConf.is_config(cfg) else ExitStack(): 93 | # this calls the different "arch" functions (like base_architecture()) that you indicate 94 | # if you specify --arch on the command line. this is only applicable to the old argparse based models 95 | # hydra models should expose different architectures via different config files 96 | # it will modify the cfg object and default parameters according to the arch 97 | ARCH_CONFIG_REGISTRY[model_type](cfg) 98 | 99 | assert model is not None, ( 100 | f"Could not infer model type from {cfg}. " 101 | "Available models: {}".format(MODEL_DATACLASS_REGISTRY.keys()) 102 | + f" Requested model type: {model_type}" 103 | ) 104 | 105 | return model.build_model(cfg, task) 106 | 107 | 108 | def register_model(name, dataclass=None): 109 | """ 110 | New model types can be added to fairseq with the :func:`register_model` 111 | function decorator. 112 | 113 | For example:: 114 | 115 | @register_model('lstm') 116 | class LSTM(FairseqEncoderDecoderModel): 117 | (...) 118 | 119 | .. note:: All models must implement the :class:`BaseFairseqModel` interface. 120 | Typically you will extend :class:`FairseqEncoderDecoderModel` for 121 | sequence-to-sequence tasks or :class:`FairseqLanguageModel` for 122 | language modeling tasks. 123 | 124 | Args: 125 | name (str): the name of the model 126 | """ 127 | 128 | def register_model_cls(cls): 129 | if name in MODEL_REGISTRY: 130 | raise ValueError("Cannot register duplicate model ({})".format(name)) 131 | if not issubclass(cls, BaseFairseqModel): 132 | raise ValueError( 133 | "Model ({}: {}) must extend BaseFairseqModel".format(name, cls.__name__) 134 | ) 135 | MODEL_REGISTRY[name] = cls 136 | if dataclass is not None and not issubclass(dataclass, FairseqDataclass): 137 | raise ValueError( 138 | "Dataclass {} must extend FairseqDataclass".format(dataclass) 139 | ) 140 | 141 | cls.__dataclass = dataclass 142 | if dataclass is not None: 143 | MODEL_DATACLASS_REGISTRY[name] = dataclass 144 | 145 | cs = ConfigStore.instance() 146 | node = dataclass() 147 | node._name = name 148 | cs.store(name=name, group="model", node=node, provider="fairseq") 149 | 150 | @register_model_architecture(name, name) 151 | def noop(_): 152 | pass 153 | 154 | return cls 155 | 156 | return register_model_cls 157 | 158 | 159 | def register_model_architecture(model_name, arch_name): 160 | """ 161 | New model architectures can be added to fairseq with the 162 | :func:`register_model_architecture` function decorator. After registration, 163 | model architectures can be selected with the ``--arch`` command-line 164 | argument. 
165 | 166 | For example:: 167 | 168 | @register_model_architecture('lstm', 'lstm_luong_wmt_en_de') 169 | def lstm_luong_wmt_en_de(cfg): 170 | args.encoder_embed_dim = getattr(cfg.model, 'encoder_embed_dim', 1000) 171 | (...) 172 | 173 | The decorated function should take a single argument *cfg*, which is a 174 | :class:`omegaconf.DictConfig`. The decorated function should modify these 175 | arguments in-place to match the desired architecture. 176 | 177 | Args: 178 | model_name (str): the name of the Model (Model must already be 179 | registered) 180 | arch_name (str): the name of the model architecture (``--arch``) 181 | """ 182 | 183 | def register_model_arch_fn(fn): 184 | if model_name not in MODEL_REGISTRY: 185 | raise ValueError( 186 | "Cannot register model architecture for unknown model type ({})".format( 187 | model_name 188 | ) 189 | ) 190 | if arch_name in ARCH_MODEL_REGISTRY: 191 | raise ValueError( 192 | "Cannot register duplicate model architecture ({})".format(arch_name) 193 | ) 194 | if not callable(fn): 195 | raise ValueError( 196 | "Model architecture must be callable ({})".format(arch_name) 197 | ) 198 | ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] 199 | ARCH_MODEL_NAME_REGISTRY[arch_name] = model_name 200 | ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) 201 | ARCH_CONFIG_REGISTRY[arch_name] = fn 202 | return fn 203 | 204 | return register_model_arch_fn 205 | 206 | 207 | def import_models(models_dir, namespace): 208 | for file in os.listdir(models_dir): 209 | path = os.path.join(models_dir, file) 210 | if ( 211 | not file.startswith("_") 212 | and not file.startswith(".") 213 | and (file.endswith(".py") or os.path.isdir(path)) 214 | ): 215 | model_name = file[: file.find(".py")] if file.endswith(".py") else file 216 | importlib.import_module(namespace + "." + model_name) 217 | 218 | # extra `model_parser` for sphinx 219 | if model_name in MODEL_REGISTRY: 220 | parser = argparse.ArgumentParser(add_help=False) 221 | group_archs = parser.add_argument_group("Named architectures") 222 | group_archs.add_argument( 223 | "--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name] 224 | ) 225 | group_args = parser.add_argument_group( 226 | "Additional command-line arguments" 227 | ) 228 | MODEL_REGISTRY[model_name].add_args(group_args) 229 | globals()[model_name + "_parser"] = parser 230 | 231 | 232 | # automatically import any Python files in the models/ directory 233 | models_dir = os.path.dirname(__file__) 234 | import_models(models_dir, "mae_ast.models") 235 | -------------------------------------------------------------------------------- /mae_ast/data/mae_ast_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
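# MAE_AST_Dataset below reads a tab-separated manifest of (audio path, length in frames), decodes
# flac files with soundfile and mkv files with av, and returns either raw waveforms or
# kaldi-compliance spectrogram/fbank/mfcc features (optionally with deltas) for pretraining.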
5 | 6 | import itertools 7 | import logging 8 | import os 9 | import sys 10 | from typing import Any, List, Optional, Union 11 | 12 | import numpy as np 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from fairseq.data import data_utils 17 | from fairseq.data.fairseq_dataset import FairseqDataset 18 | 19 | import torchaudio 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def load_audio(manifest_path, max_keep, min_keep): 25 | n_long, n_short = 0, 0 26 | names, inds, sizes = [], [], [] 27 | with open(manifest_path) as f: 28 | root = f.readline().strip() 29 | for ind, line in enumerate(f): 30 | items = line.strip().split("\t") 31 | assert len(items) == 2, line 32 | sz = int(items[1]) 33 | if min_keep is not None and sz < min_keep: 34 | n_short += 1 35 | elif max_keep is not None and sz > max_keep: 36 | n_long += 1 37 | else: 38 | names.append(items[0]) 39 | inds.append(ind) 40 | sizes.append(sz) 41 | tot = ind + 1 42 | logger.info( 43 | ( 44 | f"max_keep={max_keep}, min_keep={min_keep}, " 45 | f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " 46 | f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" 47 | ) 48 | ) 49 | return root, names, inds, tot, sizes 50 | 51 | 52 | class MAE_AST_Dataset(FairseqDataset): 53 | def __init__( 54 | self, 55 | manifest_path: str, 56 | sample_rate: float, 57 | max_keep_sample_size: Optional[int] = None, 58 | min_keep_sample_size: Optional[int] = None, 59 | max_sample_size: Optional[int] = None, 60 | shuffle: bool = True, 61 | pad_audio: bool = False, 62 | normalize: bool = False, 63 | random_crop: bool = False, 64 | feature_type: str = "wav", 65 | feature_dim: int = 36, 66 | deltas: bool = True, 67 | feature_rate: int = 100, 68 | ): 69 | self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio( 70 | manifest_path, max_keep_sample_size, min_keep_sample_size 71 | ) 72 | assert feature_type in ['wav', 'spectrogram', 'fbank', 'mfcc'], feature_type 73 | self.feature_rate = feature_rate 74 | self.feature_type = feature_type 75 | self.feature_dim = feature_dim 76 | self.deltas = deltas 77 | self.sample_rate = sample_rate 78 | self.shuffle = shuffle 79 | self.random_crop = random_crop 80 | 81 | self.max_sample_size = ( 82 | max_sample_size if max_sample_size is not None else sys.maxsize 83 | ) 84 | self.pad_audio = pad_audio 85 | self.normalize = normalize 86 | logger.info( 87 | f"pad_audio={pad_audio}, random_crop={random_crop}, " 88 | f"normalize={normalize}, max_sample_size={self.max_sample_size}" 89 | ) 90 | 91 | def get_audio(self, index): 92 | import soundfile as sf 93 | import av 94 | 95 | wav_path = os.path.join(self.audio_root, self.audio_names[index]) 96 | if (wav_path.endswith(".mkv")): 97 | with av.open(wav_path, metadata_errors="ignore") as container: 98 | decode = container.decode(audio=0) 99 | first_frame = next(decode) 100 | cur_sample_rate = first_frame.sample_rate 101 | aframes_list = [first_frame.to_ndarray()] 102 | for frame in decode: 103 | aframes_list.append(frame.to_ndarray()) 104 | aframes = np.concatenate(aframes_list, 1) 105 | wav = torch.as_tensor(aframes).mean(dim=0) 106 | else: 107 | wav, cur_sample_rate = sf.read(wav_path) 108 | wav = torch.from_numpy(wav).float() 109 | if self.feature_type == "wav": 110 | feat = self.postprocess_wav(wav, cur_sample_rate) 111 | else: 112 | feat = self.postprocess_spec(wav, cur_sample_rate) 113 | return feat 114 | 115 | def __getitem__(self, index): 116 | wav = self.get_audio(index) 117 | return {"id": index, "source": wav} # , 
"label_list": labels} 118 | 119 | def __len__(self): 120 | return len(self.sizes) 121 | 122 | def crop_to_max_size(self, wav, target_size): 123 | size = len(wav) 124 | diff = size - target_size 125 | if diff <= 0: 126 | return wav, 0 127 | 128 | start, end = 0, target_size 129 | if self.random_crop: 130 | start = np.random.randint(0, diff + 1) 131 | end = size - diff + start 132 | return wav[start:end], start 133 | 134 | def collater(self, samples): 135 | samples = [s for s in samples if s["source"] is not None] 136 | if len(samples) == 0: 137 | return {} 138 | 139 | audios = [s["source"] for s in samples] 140 | audio_sizes = [len(s) for s in audios] 141 | if self.pad_audio: 142 | audio_size = min(max(audio_sizes), self.max_sample_size) 143 | else: 144 | audio_size = min(min(audio_sizes), self.max_sample_size) 145 | collated_audios, padding_mask, audio_starts = self.collater_audio( 146 | audios, audio_size 147 | ) 148 | 149 | net_input = {"source": collated_audios, "padding_mask": padding_mask} 150 | batch = { 151 | "id": torch.LongTensor([s["id"] for s in samples]), 152 | "net_input": net_input, 153 | } 154 | 155 | return batch 156 | 157 | def collater_audio(self, audios, audio_size): 158 | if self.feature_type == "wav": 159 | collated_audios = audios[0].new_zeros(len(audios), audio_size) 160 | else: 161 | feat_dim = self.feature_dim * 3 if self.deltas else self.feature_dim 162 | collated_audios = audios[0].new_zeros(len(audios), audio_size, feat_dim) 163 | 164 | padding_mask = ( 165 | torch.BoolTensor(collated_audios.shape[:2]).fill_(False) 166 | ) 167 | audio_starts = [0 for _ in audios] 168 | for i, audio in enumerate(audios): 169 | diff = len(audio) - audio_size 170 | if diff == 0: 171 | collated_audios[i] = audio 172 | elif diff < 0: 173 | assert self.pad_audio 174 | collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) 175 | padding_mask[i, diff:] = True 176 | else: 177 | collated_audios[i], audio_starts[i] = self.crop_to_max_size( 178 | audio, audio_size 179 | ) 180 | return collated_audios, padding_mask, audio_starts 181 | 182 | def num_tokens(self, index): 183 | return self.size(index) 184 | 185 | def size(self, index): 186 | if self.pad_audio: 187 | return self.sizes[index] 188 | return min(self.sizes[index], self.max_sample_size) 189 | 190 | def ordered_indices(self): 191 | if self.shuffle: 192 | order = [np.random.permutation(len(self))] 193 | else: 194 | order = [np.arange(len(self))] 195 | 196 | order.append(self.sizes) 197 | return np.lexsort(order)[::-1] 198 | 199 | def postprocess_wav(self, wav, cur_sample_rate): 200 | if wav.dim() == 2: 201 | wav = wav.mean(-1) 202 | assert wav.dim() == 1, wav.dim() 203 | 204 | if cur_sample_rate != self.sample_rate: 205 | raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") 206 | 207 | if self.normalize: 208 | with torch.no_grad(): 209 | wav = F.layer_norm(wav, wav.shape) 210 | return wav 211 | 212 | def postprocess_spec(self, wav, cur_sample_rate): 213 | if wav.dim() == 2: 214 | wav = wav.mean(-1) 215 | assert wav.dim() == 1, wav.dim() 216 | 217 | if cur_sample_rate != self.sample_rate: 218 | wav = torchaudio.functional.resample(wav, cur_sample_rate, self.sample_rate) 219 | 220 | wav = wav.view(1, -1) 221 | if self.feature_type == "spectrogram": 222 | feat = torchaudio.compliance.kaldi.spectrogram( 223 | waveform=wav, 224 | sample_frequency=self.sample_rate 225 | ) # (time, freq) 226 | elif self.feature_type == "fbank": 227 | feat = torchaudio.compliance.kaldi.fbank( 228 | waveform=wav, 229 | 
sample_frequency=self.sample_rate, 230 | use_energy=False, 231 | num_mel_bins=self.feature_dim 232 | ) # (time, freq) 233 | else: 234 | feat = torchaudio.compliance.kaldi.mfcc( 235 | waveform=wav, 236 | sample_frequency=self.sample_rate, 237 | use_energy=False, 238 | ) # (time, freq) 239 | feat = feat[:, :self.feature_dim] 240 | if self.deltas: 241 | feat = feat.transpose(0, 1) # (freq, time) 242 | deltas = torchaudio.functional.compute_deltas(feat) 243 | ddeltas = torchaudio.functional.compute_deltas(deltas) 244 | concat = torch.cat([feat, deltas, ddeltas], dim=0) 245 | concat = concat.transpose(0, 1).contiguous() 246 | return concat 247 | else: 248 | return feat 249 | -------------------------------------------------------------------------------- /mae_ast/models/mae_ast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import logging 7 | from typing import Dict, List, Optional, Tuple 8 | import random 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | from dataclasses import dataclass, field 14 | from fairseq import utils 15 | from fairseq.data.data_utils import compute_mask_indices 16 | from fairseq.data.dictionary import Dictionary 17 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 18 | from fairseq.models import BaseFairseqModel, register_model 19 | from fairseq.models.wav2vec.wav2vec2 import ( 20 | ConvFeatureExtractionModel, 21 | # TransformerEncoder, 22 | ) 23 | # from fairseq.modules import ( 24 | # SinusoidalPositionalEmbedding 25 | # ) 26 | from fairseq.modules import GradMultiply, LayerNorm 27 | from mae_ast.tasks.mae_ast_pretraining import ( 28 | MAE_AST_Pretraining_Config, 29 | MAE_AST_Pretraining_Task, 30 | ) 31 | from omegaconf import II 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"]) 36 | 37 | 38 | @dataclass 39 | class MAE_AST_Config(FairseqDataclass): 40 | # Patching settings (frame vs patch based model) 41 | ast_kernel_size_chan: int = field( 42 | default=16, 43 | metadata={ 44 | "help": "When using reconstruction on image data, this sets the kernel size for channels (size and stride must be identical for reconstruction) used for masked features. Default 16, see ast_models.py"} 45 | ) 46 | ast_kernel_size_time: int = field( 47 | default=16, 48 | metadata={ 49 | "help": "When using reconstruction on image data, this sets the kernel size for time (size and stride must be identical for reconstruction) used for masked features. Default 16, see ast_models.py."} 50 | ) 51 | ast_kernel_stride_chan: int = field( 52 | default=16, 53 | metadata={ 54 | "help": "When using reconstruction on image data, this sets the kernel stride for channels (size and stride must be identical for reconstruction) used for masked features. Default 16, see ast_models.py"} 55 | ) 56 | ast_kernel_stride_time: int = field( 57 | default=16, 58 | metadata={ 59 | "help": "When using reconstruction on image data, this sets the kernel stride for time (size and stride must be identical for reconstruction) used for masked features. 
Default 16, see ast_models.py."} 60 | ) 61 | 62 | # Encoder and general transformer settings 63 | encoder_layers: int = field( 64 | default=12, metadata={"help": "num encoder layers in the encoder transformer"} 65 | ) 66 | encoder_embed_dim: int = field( 67 | default=768, metadata={"help": "encoder embedding dimension"} 68 | ) 69 | encoder_ffn_embed_dim: int = field( 70 | default=3072, metadata={"help": "encoder embedding dimension for FFN"} 71 | ) 72 | encoder_attention_heads: int = field( 73 | default=12, metadata={"help": "num encoder attention heads"} 74 | ) 75 | activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( 76 | default="gelu", metadata={"help": "activation function to use"} 77 | ) 78 | 79 | layer_norm_first: bool = field( 80 | default=False, 81 | metadata={"help": "apply layernorm first in the transformer"}, 82 | ) 83 | feature_grad_mult: float = field( 84 | default=1.0, 85 | metadata={"help": "multiply feature extractor var grads by this"}, 86 | ) 87 | 88 | # Decoder settings 89 | use_post_enc_proj: bool = field( 90 | default=False, 91 | metadata={ 92 | "help": "Linear projection on the encoder output. Required if decoder embed dim != encoder embed dim"} 93 | ) 94 | decoder_embed_dim: int = field( 95 | default=768, metadata={"help": "decoder embedding dimension"} 96 | ) 97 | decoder_layers: int = field( 98 | default=2, metadata={"help": "num encoder layers in the decoder transformer"} 99 | ) 100 | decoder_layerdrop: float = field( 101 | default=0.0, 102 | metadata={"help": "probability of dropping a decoder transformer layer. The decoder is shallow, so this should typically be 0"}, 103 | ) 104 | 105 | # Dropouts 106 | dropout: float = field( 107 | default=0.1, 108 | metadata={"help": "dropout probability for the transformer"}, 109 | ) 110 | attention_dropout: float = field( 111 | default=0.1, 112 | metadata={"help": "dropout probability for attention weights"}, 113 | ) 114 | activation_dropout: float = field( 115 | default=0.0, 116 | metadata={"help": "dropout probability after activation in FFN"}, 117 | ) 118 | encoder_layerdrop: float = field( 119 | default=0.0, 120 | metadata={"help": "probability of dropping a transformer layer"}, 121 | ) 122 | dropout_input: float = field( 123 | default=0.0, 124 | metadata={"help": "dropout to apply to the input (after feat extr)"}, 125 | ) 126 | 127 | # Overall Masking Settings 128 | random_mask_prob: float = field( 129 | default=0.75, 130 | metadata={"help": "Probability of a given token being masked. 
Exact use depends on mask type"} 131 | ) 132 | 133 | # Wav2Vec2-like Masking settings 134 | mask_length: int = field(default=10, metadata={"help": "mask length"}) 135 | 136 | mask_selection: MASKING_DISTRIBUTION_CHOICES = field( 137 | default="static", metadata={"help": "how to choose mask length"} 138 | ) 139 | mask_other: float = field( 140 | default=0, 141 | metadata={ 142 | "help": "secondary mask argument " 143 | "(used for more complex distributions), " 144 | "see help in compute_mask_indicesh" 145 | }, 146 | ) 147 | no_mask_overlap: bool = field( 148 | default=False, metadata={"help": "whether to allow masks to overlap"} 149 | ) 150 | mask_min_space: int = field( 151 | default=0, 152 | metadata={"help": "min space between spans (if no overlap is enabled)"}, 153 | ) 154 | 155 | # Convolutional positional embeddings (not used for MAE-AST) 156 | conv_pos: int = field( 157 | default=128, 158 | metadata={"help": "number of filters for convolutional positional embeddings"}, 159 | ) 160 | conv_pos_groups: int = field( 161 | default=16, 162 | metadata={"help": "number of groups for convolutional positional embedding"}, 163 | ) 164 | 165 | # loss computation 166 | checkpoint_activations: bool = field( 167 | default=False, 168 | metadata={"help": "recompute activations and save memory for extra compute"}, 169 | ) 170 | 171 | # positional embeddings 172 | max_token_length: int = field( 173 | default=48000, 174 | metadata={"help": "the longest input sequence length, used for sinusoidal positional embedding"} 175 | ) 176 | enc_sine_pos: bool = field( 177 | default=False, 178 | metadata={"help": "sinusoidal positional embeddings for encoder input"} 179 | ) 180 | enc_conv_pos: bool = field( 181 | default=False, 182 | metadata={"help": "convnet positional embeddings for encoder input"} 183 | ) 184 | dec_sine_pos: bool = field( 185 | default=False, 186 | metadata={"help": "sinusoidal positional embeddings for decoder input"} 187 | ) 188 | dec_conv_pos: bool = field( 189 | default=False, 190 | metadata={"help": "convnet positional embeddings for decoder input"} 191 | ) 192 | 193 | 194 | @register_model("mae_ast", dataclass=MAE_AST_Config) 195 | class MAE_AST(BaseFairseqModel): 196 | def __init__( 197 | self, 198 | cfg: MAE_AST_Config, 199 | task_cfg: MAE_AST_Pretraining_Task, 200 | ) -> None: 201 | super().__init__() 202 | logger.info(f"MAEModel Config: {cfg}") 203 | self.cfg = cfg 204 | self.task_cfg = task_cfg 205 | 206 | self.feature_extractor = nn.Identity() 207 | self.post_extract_proj = nn.Linear(cfg.ast_kernel_size_time * cfg.ast_kernel_size_chan, cfg.encoder_embed_dim) 208 | self.layer_norm = LayerNorm(task_cfg.feature_dim) 209 | self.batch_norm = nn.BatchNorm2d(num_features=1, affine=False) 210 | self.unfold = nn.Unfold(kernel_size=(cfg.ast_kernel_size_time, cfg.ast_kernel_size_chan), 211 | stride=(cfg.ast_kernel_stride_time, cfg.ast_kernel_stride_chan)) 212 | 213 | self.is_batched_mask = self.task_cfg.mask_type == 'random_mask_batched' or self.task_cfg.mask_type == 'chunk_patch' 214 | 215 | self.mask_selection = cfg.mask_selection 216 | self.mask_other = cfg.mask_other 217 | self.mask_length = cfg.mask_length 218 | self.no_mask_overlap = cfg.no_mask_overlap 219 | self.mask_min_space = cfg.mask_min_space 220 | 221 | self.dropout_input = nn.Dropout(cfg.dropout_input) 222 | 223 | self.feature_grad_mult = cfg.feature_grad_mult 224 | 225 | self.encoder_mask_emb = nn.Parameter(torch.FloatTensor(cfg.encoder_embed_dim).uniform_()) 226 | 227 | if self.cfg.enc_conv_pos: 228 | self.enc_conv_pos_embed 
= ConvPosEmbed(cfg) 229 | if self.cfg.enc_sine_pos: 230 | self.enc_sine_pos_embed = SinusoidalPositionalEncoding(d_model=self.cfg.encoder_embed_dim, max_len=self.cfg.max_token_length) 231 | 232 | self.encoder = TransformerEncoder(cfg) 233 | 234 | if self.cfg.use_post_enc_proj: 235 | self.post_enc_proj = nn.Linear(self.cfg.encoder_embed_dim, self.cfg.decoder_embed_dim) 236 | else: 237 | assert self.cfg.decoder_embed_dim == self.cfg.encoder_embed_dim, "Need post_enc_projection if encoder and decoder embed dims differ" 238 | 239 | self.decoder_mask_emb = nn.Parameter( 240 | torch.FloatTensor(cfg.decoder_embed_dim).uniform_()) 241 | if self.cfg.dec_conv_pos: 242 | self.dec_conv_pos_embed = ConvPosEmbed(cfg) 243 | if self.cfg.dec_sine_pos: 244 | self.dec_sine_pos_embed = SinusoidalPositionalEncoding(d_model=self.cfg.decoder_embed_dim, max_len=self.cfg.max_token_length) 245 | # decoder uses the same set of params as the encoders, except for the number of layers and width 246 | self.decoder = TransformerEncoder(cfg, decoder=True) 247 | 248 | self.final_proj_reconstruction = nn.Linear(cfg.decoder_embed_dim, 249 | cfg.ast_kernel_size_time * cfg.ast_kernel_size_chan) 250 | self.final_proj_classification = nn.Linear(cfg.decoder_embed_dim, 251 | cfg.ast_kernel_size_time * cfg.ast_kernel_size_chan) 252 | 253 | assert self.task_cfg.feature_type != "wav" 254 | 255 | def upgrade_state_dict_named(self, state_dict, name): 256 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 257 | 258 | super().upgrade_state_dict_named(state_dict, name) 259 | return state_dict 260 | 261 | @classmethod 262 | def build_model(cls, cfg: MAE_AST_Config, task: MAE_AST_Pretraining_Task): 263 | """Build a new model instance.""" 264 | 265 | model = MAE_AST(cfg, task.cfg) 266 | return model 267 | 268 | def forward_features(self, source: torch.Tensor) -> torch.Tensor: 269 | if self.feature_grad_mult > 0: 270 | features = self.feature_extractor(source) 271 | if self.feature_grad_mult != 1.0: 272 | features = GradMultiply.apply(features, self.feature_grad_mult) 273 | else: 274 | with torch.no_grad(): 275 | features = self.feature_extractor(source) 276 | return features 277 | 278 | def forward_padding_mask( 279 | self, 280 | features: torch.Tensor, 281 | padding_mask: torch.Tensor, 282 | feature_dim=128, 283 | ) -> torch.Tensor: 284 | if (padding_mask[:, -1].sum() == 0): # Fast exit during training if not necessary 285 | return padding_mask.new_zeros(features.shape[:2]) 286 | 287 | non_zero_count = padding_mask.size(-1) - padding_mask.sum(dim=-1) 288 | num_patches_over_channel = feature_dim // self.cfg.ast_kernel_size_chan 289 | padding_mask_indices = (((non_zero_count - 1) // self.cfg.ast_kernel_size_time) + 1) * num_patches_over_channel 290 | new_padding_mask = padding_mask.new_zeros(features.shape[:2]) 291 | 292 | for i in range(new_padding_mask.size(0)): 293 | if padding_mask_indices[i] < new_padding_mask.size(-1): 294 | new_padding_mask[i, padding_mask_indices[i]:] = True 295 | 296 | return new_padding_mask 297 | 298 | def forward_mask(self, features, padding_mask): 299 | B, T, C = features.shape 300 | num_retained_tokens = int((1 - self.cfg.random_mask_prob) * T) 301 | num_retained_tokens = max(1, num_retained_tokens) 302 | retained_idx = [] 303 | masked_idx = [] 304 | 305 | if self.task_cfg.mask_type == 'retain_spans': 306 | num_retained_tokens = 0 307 | while num_retained_tokens == 0: # This loop will almost never run more than once for any reasonable mask ratio. 
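# Under 'retain_spans', the boolean output of compute_mask_indices is used in the
# opposite sense of wav2vec2: True positions are the tokens *retained* for the
# encoder, and the loop re-samples whenever a draw would retain zero tokens.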
308 | retained_indices = compute_mask_indices( 309 | (B, T), 310 | padding_mask, 311 | self.cfg.random_mask_prob, 312 | self.mask_length, 313 | self.mask_selection, 314 | self.mask_other, 315 | min_masks=2, 316 | no_overlap=self.no_mask_overlap, 317 | min_space=self.mask_min_space, 318 | ) 319 | num_retained_tokens = retained_indices[0].sum() 320 | for i in range(B): 321 | cur_retained_idx = np.where(retained_indices[i])[0] 322 | cur_masked_idx = np.where(~retained_indices[i])[0] 323 | retained_idx.append(cur_retained_idx) 324 | features[i, cur_masked_idx] = self.encoder_mask_emb 325 | masked_idx.append(cur_masked_idx) 326 | elif self.task_cfg.mask_type == 'random_mask': 327 | for i in range(B): 328 | idx = list(range(T)) 329 | random.shuffle(idx) 330 | cur_retained_idx = idx[:num_retained_tokens] 331 | retained_idx.append(cur_retained_idx) 332 | cur_masked_idx = idx[num_retained_tokens:] 333 | masked_idx.append(cur_masked_idx) 334 | features[i, cur_masked_idx] = self.encoder_mask_emb 335 | elif self.task_cfg.mask_type == 'random_mask_batched': 336 | idx = list(range(T)) 337 | random.shuffle(idx) 338 | cur_retained_idx = idx[:num_retained_tokens] 339 | retained_idx = [cur_retained_idx] 340 | cur_masked_idx = idx[num_retained_tokens:] 341 | masked_idx = [cur_masked_idx] 342 | features[:, cur_masked_idx] = self.encoder_mask_emb 343 | elif self.task_cfg.mask_type == 'chunk_mask': # Copies SSAST code and goes to bottom right from uniform starting index 344 | cur_masked_idx = set() 345 | chunk_size = random.randrange(3, 5 + 1) 346 | chan_adjust = self.task_cfg.feature_dim // self.cfg.ast_kernel_stride_chan 347 | num_masked_tokens = T - num_retained_tokens 348 | while len(cur_masked_idx) <= num_masked_tokens: 349 | t_topleft = random.randrange(T) 350 | for t_offset in range(0, chunk_size): 351 | for c_offset in range(0, chunk_size): 352 | mask_cand = t_topleft + t_offset + chan_adjust * c_offset 353 | if (mask_cand < T): 354 | cur_masked_idx.add(mask_cand) 355 | cur_masked_idx = list(cur_masked_idx) 356 | cur_masked_idx = cur_masked_idx[:num_masked_tokens] 357 | cur_retained_idx = list(set(range(T)).difference(cur_masked_idx)) 358 | for i in range(B): # Using same mask for whole batch because SSAST code is very slow 359 | retained_idx.append(cur_retained_idx) 360 | masked_idx.append(cur_masked_idx) 361 | features[i, cur_masked_idx] = self.encoder_mask_emb 362 | 363 | return retained_idx, masked_idx, T - num_retained_tokens 364 | 365 | def forward( 366 | self, 367 | source: torch.Tensor, 368 | padding_mask: Optional[torch.Tensor] = None, 369 | mask: bool = True, 370 | features_only: bool = False, 371 | output_layer: Optional[int] = None, 372 | is_decoder_finetune: bool = False, 373 | is_input_prepatched: bool = False, 374 | ) -> Dict[str, torch.Tensor]: 375 | 376 | # Checks whether the dataset was patched and normalized before-hand. is_input_prepatched == True for speed profiling during training. 377 | if is_input_prepatched: 378 | source_patch = source 379 | else: 380 | # Batch normalization ('Mimics' AST dataset normalization) 381 | source = source.unsqueeze(1) 382 | source = self.batch_norm(source) * 0.5 # Mean 0, St dev 0.5 383 | # Create image patches for masking via Unfold. 
BTC input shape BTC output shape 384 | # Output continues to be unsqueezed from batch norm 385 | source_patch = self.unfold(source).transpose(-1, -2) 386 | 387 | features = self.forward_features(source_patch) 388 | if padding_mask is not None: # Reshape padding mask 389 | padding_mask = self.forward_padding_mask(features, padding_mask) 390 | if self.post_extract_proj is not None: # Project patches to vectors of size encoder_dim 391 | features = self.post_extract_proj(features) 392 | if mask: # additional regularization adopted from hubert 393 | features = self.dropout_input(features) 394 | 395 | B, T, C = features.shape 396 | 397 | # Calculate retained_idx and masked_idx. Uses safe assumption that nothing is padded during pretraining 398 | if mask: 399 | retained_idx, masked_idx, num_masked_tokens = self.forward_mask(features, padding_mask) 400 | else: 401 | retained_idx = [] 402 | masked_idx = [] 403 | num_masked_tokens = 0 404 | 405 | 406 | # Pre-Encoder Positional Embeddings 407 | if self.cfg.enc_conv_pos: 408 | conv_pos = self.enc_conv_pos_embed(features, padding_mask) 409 | features = conv_pos + features 410 | if self.cfg.enc_sine_pos: 411 | sine_pos = self.enc_sine_pos_embed(features, padding_mask) 412 | features = sine_pos + features 413 | 414 | 415 | # Remove masked tokens from features 416 | if mask: 417 | if self.is_batched_mask: 418 | x = features[:, retained_idx[0]] 419 | retained_padding_mask = padding_mask[:, retained_idx[0]] 420 | else: 421 | x = [] 422 | retained_padding_mask = [] 423 | for i in range(B): 424 | x.append(features[i, retained_idx[i]]) 425 | retained_padding_mask.append(padding_mask[i, retained_idx[i]]) 426 | x = torch.stack(x, dim=0) 427 | retained_padding_mask = torch.stack(retained_padding_mask, dim=0) 428 | else: 429 | x = features 430 | retained_padding_mask = padding_mask 431 | 432 | 433 | # Encoder forward pass + Early return for features 434 | x, encoder_hidden_states = self.encoder( 435 | x, 436 | padding_mask=retained_padding_mask, 437 | layer=None if output_layer is None else output_layer - 1, 438 | ) 439 | 440 | 441 | if not is_decoder_finetune and (features_only or not mask): 442 | return {"x": x, "padding_mask": retained_padding_mask, "features": features, "hidden_states": encoder_hidden_states} 443 | 444 | if self.cfg.use_post_enc_proj: 445 | x = self.post_enc_proj(x) 446 | 447 | 448 | # Add masked tokens back 449 | if mask: 450 | full_x = torch.empty((B, T, C), device=x.device, dtype=x.dtype) 451 | mask_indices = torch.zeros(torch.Size([B, T]), device=padding_mask.device, dtype=torch.bool) 452 | if self.is_batched_mask: 453 | full_x[:, retained_idx[0]] = x 454 | full_x[:, masked_idx[0]] = self.decoder_mask_emb 455 | mask_indices[:, masked_idx[0]] = True 456 | else: 457 | for i, (cur_feat, ridx, midx) in enumerate(zip(x, retained_idx, masked_idx)): 458 | full_x[i, ridx] = cur_feat 459 | full_x[i, midx] = self.decoder_mask_emb 460 | mask_indices[i, midx] = True 461 | else: 462 | full_x = x 463 | 464 | 465 | # Pre decoder positional embeddings 466 | if self.cfg.dec_conv_pos: 467 | conv_pos = self.dec_conv_pos_embed(full_x, padding_mask) 468 | full_x = conv_pos + full_x 469 | if self.cfg.dec_sine_pos: 470 | # Concerning that magnitudes of layer-normed encoder outputs are similar to decoder positional embeddings. 
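# Positional information is re-injected here because masked tokens were dropped
# before the encoder; the decoder sees every position again (retained outputs plus
# mask embeddings), so it needs its own positional signal.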
471 | sine_pos = self.dec_sine_pos_embed(full_x, padding_mask) 472 | full_x = sine_pos + full_x 473 | 474 | 475 | # Decoder forward pass 476 | x, decoder_hidden_states = self.decoder(full_x, padding_mask=padding_mask, layer=None) 477 | 478 | if is_decoder_finetune: 479 | return {"x": x, "padding_mask": padding_mask, "features": features, "hidden_states": encoder_hidden_states + decoder_hidden_states} 480 | 481 | 482 | # Construct linear projection logits and masked reconstruction targets 483 | x_masked_indices = x[mask_indices].view(B, num_masked_tokens, -1) 484 | 485 | logit_m_list_recon = self.final_proj_reconstruction(x_masked_indices) 486 | logit_m_list_class = self.final_proj_classification(x_masked_indices) 487 | 488 | target_m_list = source_patch[mask_indices].view(B, num_masked_tokens, -1) 489 | 490 | 491 | result = { 492 | "logit_m_list_recon": logit_m_list_recon, 493 | "logit_m_list_class": logit_m_list_class, 494 | "target_m_list": target_m_list, 495 | "padding_mask": padding_mask, 496 | } 497 | return result 498 | 499 | 500 | def extract_features( 501 | self, 502 | source: torch.Tensor, 503 | padding_mask: Optional[torch.Tensor] = None, 504 | mask: bool = False, 505 | ret_conv: bool = False, 506 | output_layer: Optional[int] = None, 507 | ) -> Tuple[torch.Tensor, torch.Tensor]: 508 | res = self.forward( 509 | source, 510 | padding_mask=padding_mask, 511 | mask=mask, 512 | features_only=True, 513 | output_layer=output_layer, 514 | ) 515 | feature = res["features"] if ret_conv else res["x"] 516 | return feature, res["padding_mask"] 517 | 518 | def get_logits(self, net_output): 519 | logits_list = net_output["logit_m_list_recon"], net_output["logit_m_list_class"] 520 | logits_list = [x.float() for x in logits_list if x is not None] 521 | return logits_list 522 | 523 | def get_targets(self, net_output, is_masked=True): 524 | return net_output["target_m_list"].float() 525 | 526 | def get_extra_losses(self, net_output): 527 | extra_losses = [] 528 | names = [] 529 | 530 | if "features_pen" in net_output: 531 | extra_losses.append(net_output["features_pen"]) 532 | names.append("features_pen") 533 | 534 | return extra_losses, names 535 | 536 | def remove_pretraining_modules(self): 537 | self.final_proj_reconstruction = None 538 | self.final_proj_classification = None 539 | 540 | 541 | import math 542 | from dataclasses import dataclass, field 543 | from typing import List, Tuple 544 | 545 | import numpy as np 546 | import torch 547 | import torch.nn as nn 548 | import torch.nn.functional as F 549 | from fairseq import utils 550 | from fairseq.data.data_utils import compute_mask_indices 551 | from fairseq.dataclass import ChoiceEnum, FairseqDataclass 552 | from fairseq.models import BaseFairseqModel, register_model 553 | from fairseq.modules import ( 554 | Fp32GroupNorm, 555 | Fp32LayerNorm, 556 | GradMultiply, 557 | GumbelVectorQuantizer, 558 | LayerNorm, 559 | MultiheadAttention, 560 | SamePad, 561 | TransposeLast, 562 | ) 563 | from fairseq.modules.checkpoint_activations import checkpoint_wrapper 564 | from fairseq.modules.transformer_sentence_encoder import init_bert_params 565 | from fairseq.utils import buffered_arange, index_put, is_xla_tensor 566 | from fairseq.distributed import fsdp_wrap 567 | 568 | 569 | class ConvPosEmbed(nn.Module): 570 | def __init__(self, args, decoder=False): 571 | super().__init__() 572 | self.embedding_dim = args.encoder_embed_dim 573 | self.pos_conv = nn.Conv1d( 574 | self.embedding_dim, 575 | self.embedding_dim, 576 | kernel_size=args.conv_pos, 577 | 
padding=args.conv_pos // 2, 578 | groups=args.conv_pos_groups, 579 | ) 580 | dropout = 0 581 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) 582 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std) 583 | nn.init.constant_(self.pos_conv.bias, 0) 584 | 585 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) 586 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) 587 | 588 | def forward(self, x, padding_mask=None): 589 | if padding_mask is not None: 590 | x = index_put(x, padding_mask, 0) 591 | 592 | x_conv = self.pos_conv(x.transpose(1, 2)) # [b,t,c] -> [b,c,t] 593 | x_conv = x_conv.transpose(1, 2) # [b,c,t] -> [b,t,c] 594 | return x_conv 595 | 596 | 597 | class TransformerEncoder(nn.Module): 598 | def __init__(self, args, decoder=False): 599 | super().__init__() 600 | 601 | self.dropout = args.dropout 602 | 603 | layers = [] 604 | if decoder: 605 | # the onnly difference between encoder and decoder is the number of layers 606 | num_layers = args.decoder_layers 607 | self.embedding_dim = args.decoder_embed_dim 608 | self.layerdrop = args.decoder_layerdrop 609 | else: 610 | num_layers = args.encoder_layers 611 | self.embedding_dim = args.encoder_embed_dim 612 | self.layerdrop = args.encoder_layerdrop 613 | for _ in range(num_layers): 614 | layer = TransformerSentenceEncoderLayer( 615 | embedding_dim=self.embedding_dim, 616 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 617 | num_attention_heads=args.encoder_attention_heads, 618 | dropout=self.dropout, 619 | attention_dropout=args.attention_dropout, 620 | activation_dropout=args.activation_dropout, 621 | activation_fn=args.activation_fn, 622 | layer_norm_first=args.layer_norm_first, 623 | ) 624 | if args.checkpoint_activations: 625 | layer = fsdp_wrap(layer) 626 | layer = checkpoint_wrapper(layer) 627 | layers.append(layer) 628 | self.layers = nn.ModuleList(layers) 629 | 630 | self.layer_norm_first = args.layer_norm_first 631 | self.layer_norm = LayerNorm(self.embedding_dim) 632 | 633 | self.apply(init_bert_params) 634 | 635 | def forward(self, x, padding_mask=None, layer=None): 636 | # print(f"shape of input the transformer: {x.shape}") 637 | # print(f"shape of padding mask: {padding_mask.shape}") 638 | x, layer_results = self.extract_features(x, padding_mask, layer) 639 | 640 | if self.layer_norm_first and layer is None: 641 | x = self.layer_norm(x) 642 | 643 | return x, layer_results 644 | 645 | def extract_features(self, x, padding_mask=None, tgt_layer=None): 646 | 647 | if not self.layer_norm_first: 648 | x = self.layer_norm(x) 649 | 650 | x = F.dropout(x, p=self.dropout, training=self.training) 651 | 652 | # B x T x C -> T x B x C 653 | x = x.transpose(0, 1) 654 | 655 | hidden_states = [] 656 | r = None 657 | for i, layer in enumerate(self.layers): 658 | dropout_probability = np.random.random() 659 | if not self.training or (dropout_probability > self.layerdrop): 660 | x, z = layer(x, self_attn_padding_mask=padding_mask, need_weights=False) 661 | hidden_states.append(x.transpose(0, 1)) 662 | if i == tgt_layer: 663 | r = x 664 | break 665 | 666 | if r is not None: 667 | x = r 668 | 669 | # T x B x C -> B x T x C 670 | x = x.transpose(0, 1) 671 | 672 | return x, hidden_states 673 | 674 | def max_positions(self): 675 | """Maximum output length supported by the encoder.""" 676 | return self.cfg.max_positions 677 | 678 | def upgrade_state_dict_named(self, state_dict, name): 679 | """Upgrade a (possibly old) state dict for new versions of fairseq.""" 680 
| return state_dict 681 | 682 | 683 | class TransformerSentenceEncoderLayer(nn.Module): 684 | """ 685 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 686 | models. 687 | """ 688 | 689 | def __init__( 690 | self, 691 | embedding_dim: int = 768, 692 | ffn_embedding_dim: int = 3072, 693 | num_attention_heads: int = 8, 694 | dropout: float = 0.1, 695 | attention_dropout: float = 0.1, 696 | activation_dropout: float = 0.1, 697 | activation_fn: str = "relu", 698 | layer_norm_first: bool = False, 699 | ) -> None: 700 | 701 | super().__init__() 702 | # Initialize parameters 703 | self.embedding_dim = embedding_dim 704 | self.dropout = dropout 705 | self.activation_dropout = activation_dropout 706 | 707 | # Initialize blocks 708 | self.activation_fn = utils.get_activation_fn(activation_fn) 709 | self.self_attn = MultiheadAttention( 710 | self.embedding_dim, 711 | num_attention_heads, 712 | dropout=attention_dropout, 713 | self_attention=True, 714 | ) 715 | 716 | self.dropout1 = nn.Dropout(dropout) 717 | self.dropout2 = nn.Dropout(self.activation_dropout) 718 | self.dropout3 = nn.Dropout(dropout) 719 | 720 | self.layer_norm_first = layer_norm_first 721 | 722 | # layer norm associated with the self attention layer 723 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 724 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 725 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 726 | 727 | # layer norm associated with the position wise feed-forward NN 728 | self.final_layer_norm = LayerNorm(self.embedding_dim) 729 | 730 | def forward( 731 | self, 732 | x: torch.Tensor, 733 | self_attn_mask: torch.Tensor = None, 734 | self_attn_padding_mask: torch.Tensor = None, 735 | need_weights: bool = False, 736 | att_args=None, 737 | ): 738 | """ 739 | LayerNorm is applied either before or after the self-attention/ffn 740 | modules similar to the original Transformer implementation.
741 | """ 742 | residual = x 743 | 744 | if self.layer_norm_first: 745 | x = self.self_attn_layer_norm(x) 746 | x, attn = self.self_attn( 747 | query=x, 748 | key=x, 749 | value=x, 750 | key_padding_mask=self_attn_padding_mask, 751 | attn_mask=self_attn_mask, 752 | ) 753 | x = self.dropout1(x) 754 | x = residual + x 755 | 756 | residual = x 757 | x = self.final_layer_norm(x) 758 | x = self.activation_fn(self.fc1(x)) 759 | x = self.dropout2(x) 760 | x = self.fc2(x) 761 | x = self.dropout3(x) 762 | x = residual + x 763 | else: 764 | x, attn = self.self_attn( 765 | query=x, 766 | key=x, 767 | value=x, 768 | key_padding_mask=self_attn_padding_mask, 769 | ) 770 | 771 | x = self.dropout1(x) 772 | x = residual + x 773 | 774 | x = self.self_attn_layer_norm(x) 775 | 776 | residual = x 777 | x = self.activation_fn(self.fc1(x)) 778 | x = self.dropout2(x) 779 | x = self.fc2(x) 780 | x = self.dropout3(x) 781 | x = residual + x 782 | x = self.final_layer_norm(x) 783 | 784 | return x, attn 785 | 786 | 787 | class SinusoidalPositionalEncoding(nn.Module): 788 | 789 | def __init__(self, d_model: int, max_len: int = 480000): 790 | super().__init__() 791 | position = torch.arange(max_len).unsqueeze(1) 792 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 793 | pe = torch.zeros(1, max_len, d_model) 794 | pe[0, :, 0::2] = torch.sin(position * div_term) 795 | pe[0, :, 1::2] = torch.cos(position * div_term) 796 | self.register_buffer('pe', pe) 797 | 798 | def forward(self, x, padding_mask): 799 | """ 800 | Args: 801 | x: Tensor, shape [bsz, seq_len, embedding_dim] 802 | """ 803 | pe = self.pe[:, :x.shape[1]].repeat((padding_mask.shape[0], 1, 1)) 804 | pe[padding_mask] = 0. 805 | return pe 806 | --------------------------------------------------------------------------------