├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── beats.yaml
│   └── tagging.yaml
├── datasets
│   ├── beats.py
│   └── tagging.py
├── harmonicstft.py
├── inference.py
├── models
│   ├── base.py
│   ├── beats.py
│   └── tagging.py
├── networks.py
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | runs/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 MWM
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpecTNT
2 | 
3 | Repository exploring the SpecTNT architecture applied to:
4 | - music tagging as described in the paper ["SpecTNT: a time-frequency transformer for music audio"](https://arxiv.org/abs/2110.09127)
5 | - beats and downbeats estimation as described in the paper ["Modeling beats and downbeats with a time-frequency transformer"](https://arxiv.org/abs/2205.14701)
6 | 
7 | ### Disclaimer
8 | 
9 | This is by no means an official implementation. Any comments and suggestions regarding fixes and enhancements are highly encouraged.
10 | 
11 | ## Datasets
12 | 
13 | Dataset constitution and the associated dataset class implementations are excluded from this repo. Dummy classes are provided instead.
14 | 
15 | :warning: as is, the implementation of the validation steps in [models](/models/) requires validation datamodules to stack entire tracks in batches of 1.
16 | 
17 | Datasets used in the papers for training, validation and testing, for each task:
18 | - music tagging: a Million Song Dataset subset split [this way](https://github.com/minzwon/semi-supervised-music-tagging-transformer/blob/master/data/splits/msd_splits.tsv) and intersected with the [LastFM dataset](http://millionsongdataset.com/lastfm/) for tags
19 | - beats and downbeats estimation: Beatles, Ballroom, SMC, Hainsworth, Simac, HJDB, RWC-Popular, Harmonix Set
20 | 
21 | ## Configuration
22 | 
23 | ### Datamodule
24 | - `input_length`: models are trained on audio chunks whose length is defined here, in seconds.
25 | - `hop_length`: number of samples between successive frames used to construct features (mel spectrograms or harmonic filters).
26 | - `time_shrinking` *(specific to beat estimation)*: time pooling shrinks the time dimension after inference through the front-end model (see below for more details). Target label tensors should be constructed taking this shrinking into account. It has a crucial impact on inference duration but is not clearly specified in the paper.
27 | 
28 | ### Front-end model
29 | - `freq_pooling`, `time_pooling` *(specific to beat estimation)*: pooling along the frequency and time axes applied at the end of the front-end module. These have a crucial impact on inference duration but are not clearly specified in the paper.
30 | 
31 | ### Network
32 | - `n_channels`, `n_frequencies`, `n_times`: shape of the tensors fed to the SpecTNT module. Should be consistent with the audio chunk shape and the feature extractor parameters.
33 | - `use_tct`: whether to use a Temporal Class Token, which acts as an aggregator along the temporal axis. Set to `true` for a track-wise prediction task (e.g. tagging) and to `false` for a frame-wise prediction task (e.g. beat estimation).
34 | - `n_classes`: number of output classes.
35 | 
36 | ## Usage
37 | 
38 | ### Inference
39 | Inference functions are presented in `inference.py` and should guide users unfamiliar with the `pytorch-lightning` and `hydra` libraries.
40 | 
41 | ### Training
42 | :warning: Real training should be preceded by dataset constitution following the guidelines presented in each paper.
43 | 
44 | Users can nevertheless test the training pipelines with the dummy dataset classes presented in this repo, using the following commands in a terminal:
45 | - beats and downbeats estimation:
46 | ```bash
47 | python train.py --config-name beats
48 | ```
49 | - music tagging:
50 | ```bash
51 | python train.py --config-name tagging
52 | ```
53 | 
--------------------------------------------------------------------------------
/configs/beats.yaml:
--------------------------------------------------------------------------------
1 | experiment: baseline
2 | 
3 | datamodule:
4 |   _target_: datasets.beats.DummyBeatDataModule
5 |   batch_size: 2
6 |   n_workers: 4
7 |   pin_memory: False
8 |   sample_rate: 16000
9 |   input_length: 5
10 |   hop_length: 256
11 |   time_shrinking: 4
12 | 
13 | features:
14 |   _target_: harmonicstft.HarmonicSTFT
15 |   sample_rate: 16000
16 |   n_fft: 512
17 |   n_harmonic: 6
18 |   semitone_scale: 2
19 |   learn_bw: "only_Q"
20 |   # checkpoint: ""
21 | 
22 | fe_model:
23 |   _target_: networks.ResFrontEnd
24 |   in_channels: 6
25 |   out_channels: 256
26 |   freq_pooling: [2, 2, 2]
27 |   time_pooling: [2, 2, 1]
28 | 
29 | net:
30 |   _target_: networks.SpecTNT
31 |   n_channels: 256
32 |   n_frequencies: 16
33 |   n_times: 78
34 |   embed_dim: 128
35 |   spectral_dmodel: 64
36 |   spectral_nheads: 4
37 |   spectral_dimff: 64
38 |   temporal_dmodel: 256
39 |   temporal_nheads: 8
40 |   temporal_dimff: 256
41 |   n_blocks: 5
42 |   dropout: 0.15
43 |   use_tct: false
44 |   n_classes: 3
45 | 
46 | model:
47 |   _target_: models.beats.BeatEstimator
48 |   activation_fn: softmax
49 | 
50 | trainer:
51 |   _target_: pytorch_lightning.Trainer
52 |   fast_dev_run: false # for debugging
53 |   # gpus:
54 |   #   - 0
55 |   precision: 32
56 |   accumulate_grad_batches: 16
57 |   check_val_every_n_epoch: 5
58 |   max_steps: 1000000
59 | 
60 | criterion:
61 |   _target_: torch.nn.CrossEntropyLoss
62 |   weight: null
63 | 
64 | callbacks:
65 |   lr_monitor:
66 |     _target_: pytorch_lightning.callbacks.LearningRateMonitor
67 |     logging_interval: epoch
68 | 
69 | optim:
70 |   _target_: torch.optim.AdamW
71 |   betas:
72 |     - 0.9
73 |     - 0.999
74 |   lr: 0.0005
75 | weight_decay: 0.005 76 | 77 | model_checkpoint: 78 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 79 | dirpath: checkpoints/ 80 | filename: "{best_epoch:02d}" 81 | monitor: val_loss 82 | mode: min 83 | save_last: true 84 | save_top_k: 1 85 | verbose: true 86 | 87 | logger: 88 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 89 | name: "" 90 | save_dir: ${now:%Y-%m-%d}_${now:%H-%M-%S}/tensorboard/ 91 | default_hp_metric: false 92 | 93 | hydra: 94 | run: 95 | dir: runs/beats/${experiment} 96 | 97 | ignore_warning: True 98 | seed: 42 99 | -------------------------------------------------------------------------------- /configs/tagging.yaml: -------------------------------------------------------------------------------- 1 | experiment: baseline 2 | 3 | datamodule: 4 | _target_: datasets.tagging.DummyTaggingDataModule 5 | sample_rate: 22050 6 | input_length: 4.54 7 | batch_size: 2 8 | n_workers: 4 9 | pin_memory: false 10 | 11 | features: 12 | _target_: torchaudio.transforms.MelSpectrogram 13 | sample_rate: 22050 14 | n_fft: 1024 15 | hop_length: 512 16 | n_mels: 128 17 | 18 | fe_model: 19 | _target_: networks.Res2DMaxPoolModule 20 | in_channels: 1 21 | out_channels: 128 22 | pooling: [1, 4] 23 | 24 | net: 25 | _target_: networks.SpecTNT 26 | n_channels: 128 27 | n_frequencies: 128 28 | n_times: 49 29 | embed_dim: 128 30 | spectral_dmodel: 96 31 | spectral_nheads: 4 32 | spectral_dimff: 96 33 | temporal_dmodel: 96 34 | temporal_nheads: 8 35 | temporal_dimff: 96 36 | dropout: 0.15 37 | n_blocks: 3 38 | use_tct: true 39 | n_classes: 50 40 | 41 | model: 42 | _target_: models.tagging.MusicTagger 43 | activation_fn: sigmoid 44 | 45 | trainer: 46 | _target_: pytorch_lightning.Trainer 47 | fast_dev_run: false # for debugging 48 | # gpus: 49 | # - 0 50 | precision: 32 51 | accumulate_grad_batches: 1 52 | check_val_every_n_epoch: 10 53 | max_steps: 1000000 54 | 55 | criterion: 56 | _target_: torch.nn.CrossEntropyLoss 57 | weight: null 58 | 59 | callbacks: 60 | lr_monitor: 61 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 62 | logging_interval: epoch 63 | 64 | optim: 65 | _target_: torch.optim.AdamW 66 | betas: 67 | - 0.9 68 | - 0.999 69 | lr: 0.0005 70 | weight_decay: 0.005 71 | 72 | model_checkpoint: 73 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 74 | dirpath: checkpoints/ 75 | filename: "{best_epoch:02d}" 76 | monitor: val_loss 77 | mode: min 78 | save_last: true 79 | save_top_k: 1 80 | verbose: true 81 | 82 | logger: 83 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 84 | name: "" 85 | save_dir: ${now:%Y-%m-%d}_${now:%H-%M-%S}/tensorboard/ 86 | default_hp_metric: false 87 | 88 | hydra: 89 | run: 90 | dir: runs/tagging/${experiment} 91 | 92 | ignore_warning: True 93 | seed: 42 94 | -------------------------------------------------------------------------------- /datasets/beats.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.utils.data as tud 3 | import pytorch_lightning as pl 4 | 5 | 6 | class DummyBeatDataset(tud.Dataset): 7 | 8 | def __init__(self, sample_rate, input_length, hop_length, time_shrinking, mode): 9 | self.sample_rate = sample_rate 10 | self.input_length = input_length 11 | 12 | self.target_fps = sample_rate / (hop_length * time_shrinking) 13 | self.target_nframes = int(input_length * self.target_fps) 14 | 15 | assert mode in ["train", "validation", "test"] 16 | self.mode = mode 17 | 18 | def __len__(self): 19 | if self.mode == 
"train": 20 | return 80 21 | elif self.mode == "validation": 22 | return 10 23 | elif self.mode == "test": 24 | return 10 25 | 26 | def __getitem__(self, i): 27 | if self.mode == "train": 28 | return { 29 | 'audio': th.zeros(self.input_length * self.sample_rate), 30 | 'targets': th.zeros(self.target_nframes, 3) 31 | } 32 | elif self.mode in ["validation", "test"]: 33 | return { 34 | 'audio': th.zeros(10 * self.input_length * self.sample_rate), 35 | 'targets': th.zeros(10 * self.target_nframes, 3), 36 | 'beats': th.arange(0, 50, 0.5), 37 | 'downbeats': th.arange(0, 50, 2.) 38 | } 39 | 40 | 41 | class DummyBeatDataModule(pl.LightningDataModule): 42 | def __init__(self, batch_size, n_workers, pin_memory, sample_rate, input_length, hop_length, time_shrinking): 43 | self.batch_size = batch_size 44 | self.n_workers = n_workers 45 | self.pin_memory = pin_memory 46 | self.sample_rate = sample_rate 47 | self.input_length = input_length 48 | self.hop_length = hop_length 49 | self.time_shrinking = time_shrinking 50 | 51 | def setup(self, stage): 52 | self.train_set = DummyBeatDataset( 53 | self.sample_rate, 54 | self.input_length, 55 | self.hop_length, 56 | self.time_shrinking, 57 | "train" 58 | ) 59 | self.val_set = DummyBeatDataset( 60 | self.sample_rate, 61 | self.input_length, 62 | self.hop_length, 63 | self.time_shrinking, 64 | "validation" 65 | ) 66 | self.test_set = DummyBeatDataset( 67 | self.sample_rate, 68 | self.input_length, 69 | self.hop_length, 70 | self.time_shrinking, 71 | "test" 72 | ) 73 | 74 | def train_dataloader(self): 75 | return tud.DataLoader(self.train_set, 76 | batch_size=self.batch_size, 77 | pin_memory=self.pin_memory, 78 | shuffle=True, 79 | num_workers=self.n_workers) 80 | 81 | def val_dataloader(self): 82 | return tud.DataLoader(self.val_set, 83 | batch_size=1, 84 | pin_memory=self.pin_memory, 85 | shuffle=False, 86 | num_workers=self.n_workers) 87 | 88 | def test_dataloader(self): 89 | return tud.DataLoader(self.test_set, 90 | batch_size=1, 91 | pin_memory=self.pin_memory, 92 | shuffle=False, 93 | num_workers=self.n_workers) 94 | -------------------------------------------------------------------------------- /datasets/tagging.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import pytorch_lightning as pl 3 | import torch.utils.data as tud 4 | 5 | 6 | class DummyTaggingDataset(tud.Dataset): 7 | def __init__( 8 | self, 9 | sample_rate, 10 | input_length, 11 | mode 12 | ): 13 | self.num_samples = int(input_length * sample_rate) 14 | self.dummy_tags = th.zeros(50) 15 | self.dummy_tags[0] = 1 16 | 17 | assert mode in ["train", "validation", "test"] 18 | self.mode = mode 19 | 20 | def __len__(self): 21 | if self.mode == "train": 22 | return 80 23 | elif self.mode == "validation": 24 | return 10 25 | elif self.mode == "test": 26 | return 10 27 | 28 | def __getitem__(self, index): 29 | if self.mode == "train": 30 | return { 31 | "audio": th.zeros(self.num_samples), 32 | "targets": self.dummy_tags 33 | } 34 | elif self.mode in ["validation", "test"]: 35 | return { 36 | "audio": th.zeros(10 * self.num_samples), 37 | "targets": self.dummy_tags 38 | } 39 | 40 | 41 | class DummyTaggingDataModule(pl.LightningDataModule): 42 | def __init__( 43 | self, 44 | sample_rate, 45 | input_length, 46 | batch_size, 47 | n_workers, 48 | pin_memory 49 | ): 50 | self.sample_rate = sample_rate 51 | self.input_length = input_length 52 | self.batch_size = batch_size 53 | self.n_workers = n_workers 54 | self.pin_memory = pin_memory 55 | 
56 | def setup(self, stage): 57 | self.train_set = DummyTaggingDataset( 58 | sample_rate=self.sample_rate, 59 | input_length=self.input_length, 60 | mode="train" 61 | ) 62 | self.val_set = DummyTaggingDataset( 63 | sample_rate=self.sample_rate, 64 | input_length=self.input_length, 65 | mode="validation" 66 | ) 67 | self.test_set = DummyTaggingDataset( 68 | sample_rate=self.sample_rate, 69 | input_length=self.input_length, 70 | mode="test" 71 | ) 72 | 73 | def train_dataloader(self): 74 | return tud.DataLoader(self.train_set, 75 | batch_size=self.batch_size, 76 | pin_memory=self.pin_memory, 77 | shuffle=True, 78 | num_workers=self.n_workers) 79 | 80 | def val_dataloader(self): 81 | return tud.DataLoader(self.val_set, 82 | batch_size=1, 83 | pin_memory=self.pin_memory, 84 | shuffle=False, 85 | num_workers=self.n_workers) 86 | 87 | def test_dataloader(self): 88 | return tud.DataLoader(self.test_set, 89 | batch_size=1, 90 | pin_memory=self.pin_memory, 91 | shuffle=False, 92 | num_workers=self.n_workers) 93 | -------------------------------------------------------------------------------- /harmonicstft.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torchaudio 3 | import numpy as np 4 | import torch as th 5 | import torch.nn as nn 6 | 7 | 8 | def hz_to_midi(hz): 9 | return 12 * (th.log2(hz) - np.log2(440.0)) + 69 10 | 11 | 12 | def midi_to_hz(midi): 13 | return 440.0 * (2.0 ** ((midi - 69.0)/12.0)) 14 | 15 | 16 | def note_to_midi(note): 17 | return librosa.core.note_to_midi(note) 18 | 19 | 20 | def hz_to_note(hz): 21 | return librosa.core.hz_to_note(hz) 22 | 23 | 24 | def initialize_filterbank(sample_rate, n_harmonic, semitone_scale): 25 | # MIDI 26 | # lowest note 27 | low_midi = note_to_midi('C1') 28 | # highest note 29 | high_note = hz_to_note(sample_rate / (2 * n_harmonic)) 30 | high_midi = note_to_midi(high_note) 31 | # number of scales 32 | level = (high_midi - low_midi) * semitone_scale 33 | midi = np.linspace(low_midi, high_midi, level + 1) 34 | hz = midi_to_hz(midi[:-1]) 35 | # stack harmonics 36 | harmonic_hz = [] 37 | for i in range(n_harmonic): 38 | harmonic_hz = np.concatenate((harmonic_hz, hz * (i+1))) 39 | return harmonic_hz, level 40 | 41 | 42 | class HarmonicSTFT(nn.Module): 43 | """ 44 | Trainable harmonic filters as implemented by Minz Won. 45 | 46 | Paper: https://ccrma.stanford.edu/~urinieto/MARL/publications/ICASSP2020_Won.pdf 47 | Code: https://github.com/minzwon/data-driven-harmonic-filters 48 | Pretrained: https://github.com/minzwon/sota-music-tagging-models/tree/master/training 49 | """ 50 | 51 | def __init__(self, 52 | sample_rate=16000, 53 | n_fft=513, 54 | win_length=None, 55 | hop_length=None, 56 | pad=0, 57 | power=2, 58 | normalized=False, 59 | n_harmonic=6, 60 | semitone_scale=2, 61 | bw_Q=1.0, 62 | learn_bw=None, 63 | checkpoint=None): 64 | super(HarmonicSTFT, self).__init__() 65 | 66 | # Parameters 67 | self.sample_rate = sample_rate 68 | self.n_harmonic = n_harmonic 69 | self.bw_alpha = 0.1079 70 | self.bw_beta = 24.7 71 | 72 | # Spectrogram 73 | self.spec = torchaudio.transforms.Spectrogram(n_fft=n_fft, win_length=win_length, 74 | hop_length=hop_length, pad=pad, 75 | window_fn=th.hann_window, 76 | power=power, normalized=normalized, wkwargs=None) 77 | self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() 78 | 79 | # Initialize the filterbank. Equally spaced in MIDI scale. 
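        # initialize_filterbank() returns `level` center frequencies per harmonic
        # (semitone_scale bands per semitone between C1 and sample_rate / (2 * n_harmonic)),
        # stacked over the n_harmonic integer multiples: len(harmonic_hz) == n_harmonic * level.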
80 | harmonic_hz, self.level = initialize_filterbank( 81 | sample_rate, n_harmonic, semitone_scale) 82 | 83 | # Center frequncies to tensor 84 | self.f0 = th.tensor(harmonic_hz.astype('float32')) 85 | 86 | # Bandwidth parameters 87 | if learn_bw == 'only_Q': 88 | self.bw_Q = nn.Parameter(th.tensor( 89 | np.array([bw_Q]).astype('float32'))) 90 | elif learn_bw == 'fix': 91 | self.bw_Q = th.tensor(np.array([bw_Q]).astype('float32')) 92 | 93 | if checkpoint is not None: 94 | state_dict = th.load(checkpoint) 95 | hstft_state_dict = {k.replace('hstft.', ''): v for k, 96 | v in state_dict.items() if 'hstft.' in k} 97 | self.load_state_dict(hstft_state_dict) 98 | 99 | def get_harmonic_fb(self): 100 | # bandwidth 101 | bw = (self.bw_alpha * self.f0 + self.bw_beta) / self.bw_Q 102 | bw = bw.unsqueeze(0) # (1, n_band) 103 | f0 = self.f0.unsqueeze(0) # (1, n_band) 104 | fft_bins = self.fft_bins.unsqueeze(1) # (n_bins, 1) 105 | 106 | up_slope = th.matmul(fft_bins, (2/bw)) + 1 - (2 * f0 / bw) 107 | down_slope = th.matmul(fft_bins, (-2/bw)) + 1 + (2 * f0 / bw) 108 | fb = th.max(self.zero, th.min(down_slope, up_slope)) 109 | return fb 110 | 111 | def to_device(self, device, n_bins): 112 | self.f0 = self.f0.to(device) 113 | self.bw_Q = self.bw_Q.to(device) 114 | # fft bins 115 | self.fft_bins = th.linspace(0, self.sample_rate//2, n_bins) 116 | self.fft_bins = self.fft_bins.to(device) 117 | self.zero = th.zeros(1) 118 | self.zero = self.zero.to(device) 119 | 120 | def forward(self, waveform): 121 | # stft 122 | spectrogram = self.spec(waveform) 123 | # to device 124 | self.to_device(waveform.device, spectrogram.size(1)) 125 | # triangle filter 126 | harmonic_fb = self.get_harmonic_fb() 127 | harmonic_spec = th.matmul( 128 | spectrogram.transpose(1, 2), harmonic_fb).transpose(1, 2) 129 | # (batch, channel, length) -> (batch, harmonic, f0, length) 130 | b, c, l = harmonic_spec.size() 131 | harmonic_spec = harmonic_spec.view(b, self.n_harmonic, self.level, l) 132 | # amplitude to db 133 | harmonic_spec = self.amplitude_to_db(harmonic_spec) 134 | return harmonic_spec 135 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import torch as th 3 | import torch.nn as nn 4 | import hydra.utils as hu 5 | from omegaconf import OmegaConf 6 | 7 | 8 | def predict(audio_path: str, cfg_path: str, ckpt_path: str) -> th.Tensor: 9 | """ 10 | Args: 11 | audio_path: string path to audio file to be analyzed 12 | cfg_path: string path to config 13 | ckpt_path: string path to checkpoint 14 | 15 | Return: 16 | probs_list: torch.Tensor of estimated probability distribution over output classes for each output frame 17 | """ 18 | # Load config and params 19 | cfg = OmegaConf.load(cfg_path) 20 | input_length, sample_rate, batch_size = ( 21 | cfg.datamodule.input_length, 22 | cfg.datamodule.sample_rate, 23 | cfg.datamodule.batch_size 24 | ) 25 | # Load audio 26 | audio, _ = librosa.load(audio_path, sr=sample_rate, mono=True) 27 | audio = th.from_numpy(audio) 28 | # Load modules 29 | feature_extractor = hu.instantiate(cfg.features) 30 | fe_model = hu.instantiate(cfg.fe_model) 31 | net = hu.instantiate(cfg.net, fe_model=fe_model) 32 | # Load weights 33 | if ckpt_path is not None: 34 | ckpt = th.load(ckpt_path, map_location="cpu") 35 | net_state_dict = {k.replace("net.", ""): v for k, 36 | v in ckpt["state_dict"].items() if "feature_extractor" not in k} 37 | 
net.load_state_dict(net_state_dict) 38 | features_state_dict = {k.replace("feature_extractor.", ""): v for k, 39 | v in ckpt["state_dict"].items() if "feature_extractor" in k} 40 | feature_extractor.load_state_dict(features_state_dict) 41 | _ = net.eval() 42 | _ = feature_extractor.eval() 43 | # Inference loop 44 | audio_chunks = th.cat([el.unsqueeze(0) for el in audio.split( 45 | split_size=int(input_length*sample_rate))[:-1]], dim=0) 46 | probs_list = th.tensor([]) 47 | for batch_audio in audio_chunks.split(batch_size): 48 | with th.no_grad(): 49 | features = feature_extractor(batch_audio) 50 | logits = net(features) 51 | if cfg.model.activation_fn == "softmax": 52 | probs = th.softmax(logits, dim=2) 53 | elif cfg.model.activation_fn == "sigmoid": 54 | probs = th.sigmoid(logits) 55 | probs_list = th.cat( 56 | [probs_list, probs.flatten(end_dim=1).cpu()], dim=0) 57 | return probs_list 58 | 59 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import pytorch_lightning as pl 3 | 4 | 5 | class BaseModel(pl.LightningModule): 6 | def __init__(self, feature_extractor, net, optimizer, lr_scheduler, criterion, datamodule, activation_fn): 7 | super().__init__() 8 | 9 | self.feature_extractor = feature_extractor 10 | self.net = net 11 | self.optimizer = optimizer 12 | self.lr_scheduler = lr_scheduler 13 | self.criterion = criterion 14 | self.datamodule = datamodule 15 | 16 | if activation_fn == "softmax": 17 | self.activation = nn.Softmax(dim=2) 18 | elif activation_fn == "sigmoid": 19 | self.activation = nn.Sigmoid() 20 | 21 | def configure_optimizers(self): 22 | if self.lr_scheduler is None: 23 | return {"optimizer": self.optimizer} 24 | else: 25 | return {"optimizer": self.optimizer, "lr_scheduler": self.lr_scheduler, "monitor": "val_loss"} 26 | 27 | @staticmethod 28 | def _classname(obj, lower=True): 29 | if hasattr(obj, '__name__'): 30 | name = obj.__name__ 31 | else: 32 | name = obj.__class__.__name__ 33 | return name.lower() if lower else name 34 | -------------------------------------------------------------------------------- /models/beats.py: -------------------------------------------------------------------------------- 1 | import mir_eval 2 | import torch as th 3 | from .base import BaseModel 4 | 5 | class BeatEstimator(BaseModel): 6 | def __init__(self, feature_extractor, net, optimizer, lr_scheduler, criterion, datamodule, activation_fn): 7 | super().__init__( 8 | feature_extractor, 9 | net, 10 | optimizer, 11 | lr_scheduler, 12 | criterion, 13 | datamodule, 14 | activation_fn 15 | ) 16 | 17 | self.target_fps = datamodule.sample_rate / \ 18 | (datamodule.hop_length * datamodule.time_shrinking) 19 | 20 | def training_step(self, batch, batch_idx): 21 | losses = {} 22 | x, y = batch['audio'], batch['targets'] 23 | features = self.feature_extractor(x) 24 | logits = self.net(features) 25 | losses['train_loss'] = self.criterion( 26 | logits.flatten(end_dim=1), y.flatten(end_dim=1)) 27 | self.log_dict(losses, on_step=False, on_epoch=True) 28 | return losses['train_loss'] 29 | 30 | def validation_step(self, batch, batch_idx): 31 | losses = {} 32 | audio, targets, ref_beats, ref_downbeats = ( 33 | batch['audio'][0], 34 | batch['targets'][0].cpu(), 35 | batch['beats'][0].cpu(), 36 | batch['downbeats'][0].cpu() 37 | ) 38 | input_length, sample_rate, batch_size = ( 39 | self.datamodule.input_length, 40 | self.datamodule.sample_rate, 41 | 
self.datamodule.batch_size 42 | ) 43 | audio_chunks = th.cat([el.unsqueeze(0) for el in audio.split( 44 | split_size=int(input_length*sample_rate))[:-1]], dim=0) 45 | # Inference loop 46 | logits_list, probs_list = th.tensor([]), th.tensor([]) 47 | for batch_audio in audio_chunks.split(batch_size): 48 | with th.no_grad(): 49 | features = self.feature_extractor(batch_audio) 50 | logits = self.net(features) 51 | probs = self.activation(logits) 52 | logits_list = th.cat( 53 | [logits_list, logits.flatten(end_dim=1).cpu()], dim=0) 54 | probs_list = th.cat( 55 | [probs_list, probs.flatten(end_dim=1).cpu()], dim=0) 56 | # Postprocessing 57 | beats_data = probs_list.argmax(dim=1) 58 | est_beats = th.where(beats_data == 0)[0] / self.target_fps 59 | est_downbeats = th.where(beats_data == 1)[0] / self.target_fps 60 | # Eval 61 | losses['val_loss'] = self.criterion( 62 | logits_list, targets[:len(logits_list)]) 63 | losses['beats_f_measure'] = mir_eval.beat.f_measure( 64 | ref_beats, est_beats) 65 | losses['downbeats_f_measure'] = mir_eval.beat.f_measure( 66 | ref_downbeats, est_downbeats) 67 | self.log_dict(losses, on_step=False, on_epoch=True) 68 | return losses['val_loss'] 69 | -------------------------------------------------------------------------------- /models/tagging.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from sklearn import metrics 3 | from .base import BaseModel 4 | 5 | 6 | class MusicTagger(BaseModel): 7 | def __init__(self, feature_extractor, net, optimizer, lr_scheduler, criterion, datamodule, activation_fn): 8 | super().__init__( 9 | feature_extractor, 10 | net, 11 | optimizer, 12 | lr_scheduler, 13 | criterion, 14 | datamodule, 15 | activation_fn 16 | ) 17 | 18 | def training_step(self, batch, batch_idx): 19 | loss_dict = {} 20 | x, y = batch['audio'], batch['targets'] 21 | features = self.feature_extractor(x) 22 | logits = self.net(features) 23 | loss_dict['train_loss'] = self.criterion(logits, y) 24 | self.log_dict(loss_dict, on_step=False, on_epoch=True) 25 | return loss_dict['train_loss'] 26 | 27 | def validation_step(self, batch, batch_idx): 28 | loss_dict = {} 29 | x, y = batch['audio'][0], batch['targets'].cpu() 30 | sample_rate, input_length, batch_size = ( 31 | self.datamodule.sample_rate, 32 | self.datamodule.input_length, 33 | self.datamodule.batch_size 34 | ) 35 | # Process whole track as batches of chunks 36 | audio_chunks = th.cat([el.unsqueeze(0) for el in x.split( 37 | split_size=int(input_length*sample_rate))[:-1]], dim=0) 38 | logits_list, probs_list = th.tensor([]), th.tensor([]) 39 | for audio_batch in audio_chunks.split(batch_size): 40 | with th.no_grad(): 41 | features = self.feature_extractor(audio_batch) 42 | logits = self.net(features) 43 | probs = self.activation(logits) 44 | logits_list = th.cat([logits_list, logits.cpu()], dim=0) 45 | probs_list = th.cat([probs_list, probs.cpu()], dim=0) 46 | # Aggregate along track and then compute metrics 47 | logits_agg, probs_agg = logits_list.mean(dim=0).unsqueeze( 48 | 0), probs_list.mean(dim=0).unsqueeze(0) 49 | loss_dict['val_loss'] = self.criterion(logits_agg, y).item() 50 | loss_dict['val_roc_auc'] = metrics.roc_auc_score( 51 | y.T, probs_agg.T, average="macro") 52 | loss_dict['val_pr_auc'] = metrics.average_precision_score( 53 | y.T, probs_agg.T, average="macro") 54 | self.log_dict(loss_dict, on_step=False, on_epoch=True) 55 | return loss_dict['val_loss'] 56 | -------------------------------------------------------------------------------- 
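Note on targets: `BeatEstimator.validation_step` above converts frame indices back to seconds through `target_fps = sample_rate / (hop_length * time_shrinking)`; a real beat dataset has to apply the inverse mapping when building its `targets` tensors. The sketch below illustrates that rasterization under stated assumptions — the class layout implied by the post-processing (0 = beat, 1 = downbeat, 2 = neither), downbeat frames labeled only as downbeats, and no label widening; the helper `beat_times_to_targets` is illustrative and not part of this repo.

```python
import torch as th


def beat_times_to_targets(beats, downbeats, n_frames, target_fps):
    """Rasterize beat/downbeat times (seconds) into a [n_frames, 3] one-hot target tensor."""
    targets = th.zeros(n_frames, 3)
    targets[:, 2] = 1.0  # default class: neither beat nor downbeat
    for t in beats:
        i = int(round(float(t) * target_fps))
        if i < n_frames:
            targets[i] = th.tensor([1.0, 0.0, 0.0])
    for t in downbeats:  # downbeats overwrite plain beats falling on the same frame
        i = int(round(float(t) * target_fps))
        if i < n_frames:
            targets[i] = th.tensor([0.0, 1.0, 0.0])
    return targets


# With the values from configs/beats.yaml: 16000 / (256 * 4) = 15.625 frames per second,
# so a 5-second chunk yields int(5 * 15.625) = 78 frames, consistent with `n_times: 78`.
target_fps = 16000 / (256 * 4)
targets = beat_times_to_targets(beats=[0.0, 0.5, 1.0, 1.5], downbeats=[0.0, 2.0],
                                n_frames=78, target_fps=target_fps)
```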
/networks.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | 5 | class Res2DMaxPoolModule(nn.Module): 6 | def __init__(self, in_channels, out_channels, pooling=2): 7 | super(Res2DMaxPoolModule, self).__init__() 8 | self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1) 9 | self.bn_1 = nn.BatchNorm2d(out_channels) 10 | self.conv_2 = nn.Conv2d(out_channels, out_channels, 3, padding=1) 11 | self.bn_2 = nn.BatchNorm2d(out_channels) 12 | self.relu = nn.ReLU() 13 | self.mp = nn.MaxPool2d(tuple(pooling)) 14 | 15 | # residual 16 | self.diff = False 17 | if in_channels != out_channels: 18 | self.conv_3 = nn.Conv2d( 19 | in_channels, out_channels, 3, padding=1) 20 | self.bn_3 = nn.BatchNorm2d(out_channels) 21 | self.diff = True 22 | 23 | def forward(self, x): 24 | out = self.bn_2(self.conv_2(self.relu(self.bn_1(self.conv_1(x))))) 25 | if self.diff: 26 | x = self.bn_3(self.conv_3(x)) 27 | out = x + out 28 | out = self.mp(self.relu(out)) 29 | return out 30 | 31 | 32 | class ResFrontEnd(nn.Module): 33 | """ 34 | Adapted from Minz Won ResNet implementation. 35 | 36 | Original code: https://github.com/minzwon/semi-supervised-music-tagging-transformer/blob/master/src/modules.py 37 | """ 38 | def __init__(self, in_channels, out_channels, freq_pooling, time_pooling): 39 | super(ResFrontEnd, self).__init__() 40 | self.input_bn = nn.BatchNorm2d(in_channels) 41 | self.layer1 = Res2DMaxPoolModule( 42 | in_channels, out_channels, pooling=(freq_pooling[0], time_pooling[0])) 43 | self.layer2 = Res2DMaxPoolModule( 44 | out_channels, out_channels, pooling=(freq_pooling[1], time_pooling[1])) 45 | self.layer3 = Res2DMaxPoolModule( 46 | out_channels, out_channels, pooling=(freq_pooling[2], time_pooling[2])) 47 | 48 | def forward(self, hcqt): 49 | """ 50 | Inputs: 51 | hcqt: [B, F, K, T] 52 | 53 | Outputs: 54 | out: [B, ^F, ^K, ^T] 55 | """ 56 | # batch normalization 57 | out = self.input_bn(hcqt) 58 | 59 | # CNN 60 | out = self.layer1(out) 61 | out = self.layer2(out) 62 | out = self.layer3(out) 63 | 64 | return out 65 | 66 | 67 | class SpecTNTBlock(nn.Module): 68 | def __init__( 69 | self, n_channels, n_frequencies, n_times, 70 | spectral_dmodel, spectral_nheads, spectral_dimff, 71 | temporal_dmodel, temporal_nheads, temporal_dimff, 72 | embed_dim, dropout, use_tct 73 | ): 74 | super().__init__() 75 | 76 | self.D = embed_dim 77 | self.F = n_frequencies 78 | self.K = n_channels 79 | self.T = n_times 80 | 81 | # TCT: Temporal Class Token 82 | if use_tct: 83 | self.T += 1 84 | 85 | # Shared frequency-time linear layers 86 | self.D_to_K = nn.Linear(self.D, self.K) 87 | self.K_to_D = nn.Linear(self.K, self.D) 88 | 89 | # Spectral Transformer Encoder 90 | self.spectral_linear_in = nn.Linear(self.F+1, spectral_dmodel) 91 | self.spectral_encoder_layer = nn.TransformerEncoderLayer( 92 | d_model=spectral_dmodel, nhead=spectral_nheads, dim_feedforward=spectral_dimff, dropout=dropout, batch_first=True, activation="gelu", norm_first=True) 93 | self.spectral_linear_out = nn.Linear(spectral_dmodel, self.F+1) 94 | 95 | # Temporal Transformer Encoder 96 | self.temporal_linear_in = nn.Linear(self.T, temporal_dmodel) 97 | self.temporal_encoder_layer = nn.TransformerEncoderLayer( 98 | d_model=temporal_dmodel, nhead=temporal_nheads, dim_feedforward=temporal_dimff, dropout=dropout, batch_first=True, activation="gelu", norm_first=True) 99 | self.temporal_linear_out = nn.Linear(temporal_dmodel, self.T) 100 | 101 | def forward(self, spec_in, 
temp_in): 102 | """ 103 | Inputs: 104 | spec_in: spectral embedding input [B, T, F+1, K] 105 | temp_in: temporal embedding input [B, T, 1, D] 106 | 107 | Outputs: 108 | spec_out: spectral embedding output [B, T, F+1, K] 109 | temp_out: temporal embedding output [B, T, 1, D] 110 | """ 111 | # Element-wise addition between TE and FCT 112 | spec_in = spec_in + \ 113 | nn.functional.pad(self.D_to_K(temp_in), (0, 0, 0, self.F)) 114 | 115 | # Spectral Transformer 116 | spec_in = spec_in.flatten(0, 1).transpose(1, 2) # [B*T, K, F+1] 117 | emb = self.spectral_linear_in(spec_in) # [B*T, K, spectral_dmodel] 118 | spec_enc_out = self.spectral_encoder_layer( 119 | emb) # [B*T, K, spectral_dmodel] 120 | spec_out = self.spectral_linear_out(spec_enc_out) # [B*T, K, F+1] 121 | spec_out = spec_out.view(-1, self.T, self.K, 122 | self.F+1).transpose(2, 3) # [B, T, F+1, K] 123 | 124 | # FCT slicing (first raw) + back to D 125 | temp_in = temp_in + self.K_to_D(spec_out[:, :, :1, :]) # [B, T, 1, D] 126 | 127 | # Temporal Transformer 128 | temp_in = temp_in.permute(0, 2, 3, 1).flatten(0, 1) # [B, D, T] 129 | emb = self.temporal_linear_in(temp_in) # [B, D, temporal_dmodel] 130 | temp_enc_out = self.temporal_encoder_layer( 131 | emb) # [B, D, temporal_dmodel] 132 | temp_out = self.temporal_linear_out(temp_enc_out) # [B, D, T] 133 | temp_out = temp_out.unsqueeze(1).permute(0, 3, 1, 2) # [B, T, 1, D] 134 | 135 | return spec_out, temp_out 136 | 137 | 138 | class SpecTNTModule(nn.Module): 139 | def __init__( 140 | self, n_channels, n_frequencies, n_times, 141 | spectral_dmodel, spectral_nheads, spectral_dimff, 142 | temporal_dmodel, temporal_nheads, temporal_dimff, 143 | embed_dim, n_blocks, dropout, use_tct 144 | ): 145 | super().__init__() 146 | 147 | D = embed_dim 148 | F = n_frequencies 149 | K = n_channels 150 | T = n_times 151 | 152 | # Frequency Class Token 153 | self.fct = nn.Parameter(th.zeros(1, T, 1, K)) 154 | 155 | # Frequency Positional Encoding 156 | self.fpe = nn.Parameter(th.zeros(1, 1, F+1, K)) 157 | 158 | # TCT: Temporal Class Token 159 | if use_tct: 160 | self.tct = nn.Parameter(th.zeros(1, 1, 1, D)) 161 | else: 162 | self.tct = None 163 | 164 | # Temporal Embedding 165 | self.te = nn.Parameter(th.rand(1, T, 1, D)) 166 | 167 | # SpecTNT blocks 168 | self.spectnt_blocks = nn.ModuleList([ 169 | SpecTNTBlock( 170 | n_channels, 171 | n_frequencies, 172 | n_times, 173 | spectral_dmodel, 174 | spectral_nheads, 175 | spectral_dimff, 176 | temporal_dmodel, 177 | temporal_nheads, 178 | temporal_dimff, 179 | embed_dim, 180 | dropout, 181 | use_tct 182 | ) 183 | for _ in range(n_blocks) 184 | ]) 185 | 186 | def forward(self, x): 187 | """ 188 | Input: 189 | x: [B, T, F, K] 190 | 191 | Output: 192 | spec_emb: [B, T, F+1, K] 193 | temp_emb: [B, T, 1, D] 194 | """ 195 | batch_size = len(x) 196 | 197 | # Initialize spectral embedding - concat FCT (first raw) + add FPE 198 | fct = th.repeat_interleave(self.fct, batch_size, 0) # [B, T, 1, K] 199 | spec_emb = th.cat([fct, x], dim=2) # [B, T, F+1, K] 200 | spec_emb = spec_emb + self.fpe 201 | if self.tct is not None: 202 | spec_emb = nn.functional.pad( 203 | spec_emb, (0, 0, 0, 0, 1, 0)) # [B, T+1, F+1, K] 204 | 205 | # Initialize temporal embedding 206 | temp_emb = th.repeat_interleave(self.te, batch_size, 0) # [B, T, 1, D] 207 | if self.tct is not None: 208 | tct = th.repeat_interleave(self.tct, batch_size, 0) # [B, 1, 1, D] 209 | temp_emb = th.cat([tct, temp_emb], dim=1) # [B, T+1, 1, D] 210 | 211 | # SpecTNT blocks inference 212 | for block in self.spectnt_blocks: 
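            # Each block cross-updates the spectral and temporal embeddings and preserves
            # their shapes ([B, T(+1), F+1, K] and [B, T(+1), 1, D]), so blocks stack freely.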
213 | spec_emb, temp_emb = block(spec_emb, temp_emb) 214 | 215 | return spec_emb, temp_emb 216 | 217 | 218 | class SpecTNT(nn.Module): 219 | def __init__( 220 | self, fe_model, 221 | n_channels, n_frequencies, n_times, 222 | spectral_dmodel, spectral_nheads, spectral_dimff, 223 | temporal_dmodel, temporal_nheads, temporal_dimff, 224 | embed_dim, n_blocks, dropout, use_tct, n_classes 225 | ): 226 | super().__init__() 227 | 228 | # TCT: Temporal Class Token 229 | self.use_tct = use_tct 230 | 231 | # Front-end model 232 | self.fe_model = fe_model 233 | 234 | # Main model 235 | self.main_model = SpecTNTModule( 236 | n_channels, 237 | n_frequencies, 238 | n_times, 239 | spectral_dmodel, 240 | spectral_nheads, 241 | spectral_dimff, 242 | temporal_dmodel, 243 | temporal_nheads, 244 | temporal_dimff, 245 | embed_dim, 246 | n_blocks, 247 | dropout, 248 | use_tct 249 | ) 250 | 251 | # Linear layer 252 | self.linear_out = nn.Linear(embed_dim, n_classes) 253 | 254 | def forward(self, features): 255 | """ 256 | Input: 257 | features: [B, K, F, T] 258 | 259 | Output: 260 | logits: 261 | - [B, n_classes] if use_tct 262 | - [B, T, n_classes] otherwise 263 | """ 264 | # Add channel dimension if None 265 | if len(features.size()) == 3: 266 | features = features.unsqueeze(1) 267 | # Front-end model 268 | fe_out = self.fe_model(features) # [B, ^K, ^F, ^T] 269 | fe_out = fe_out.permute(0, 3, 2, 1) # [B, T, F, K] 270 | # Main model 271 | _, temp_emb = self.main_model(fe_out) # [B, T, 1, D] 272 | # Linear layer 273 | if self.use_tct: 274 | return self.linear_out(temp_emb[:, 0, 0, :]) # [B, n_classes] 275 | else: 276 | return self.linear_out(temp_emb[:, :, 0, :]) # [B, T, n_classes] 277 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_lightning==1.5.10 2 | omegaconf==2.1.1 3 | torch==1.11.0 4 | numpy>=1.22 5 | librosa==0.8.1 6 | hydra_core==1.1.2 7 | mir_eval==0.6 8 | torchaudio==0.11.0 9 | hydra==2.5 10 | scikit_learn==1.1.2 11 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import hydra.utils as hu 3 | import pytorch_lightning as pl 4 | 5 | 6 | @hydra.main(config_path="configs/", config_name="beats") 7 | def main(cfg): 8 | if "seed" in cfg: 9 | pl.seed_everything(cfg.seed) 10 | 11 | feature_extractor = hu.instantiate(cfg.features) 12 | fe_model = hu.instantiate(cfg.fe_model) 13 | net = hu.instantiate( 14 | cfg.net, 15 | fe_model=fe_model 16 | ) 17 | optimizer = hu.instantiate(cfg.optim, params=net.parameters()) 18 | lr_scheduler = hu.instantiate( 19 | cfg.lr_scheduler, optimizer) if "lr_scheduler" in cfg else None 20 | criterion = hu.instantiate(cfg.criterion) 21 | 22 | datamodule = hu.instantiate(cfg.datamodule) 23 | model = hu.instantiate( 24 | cfg.model, 25 | net=net, 26 | feature_extractor=feature_extractor, 27 | optimizer=optimizer, 28 | lr_scheduler=lr_scheduler, 29 | criterion=criterion, 30 | datamodule=datamodule 31 | ) 32 | 33 | model_ckpt = hu.instantiate(cfg.model_checkpoint) 34 | logger, callbacks = [], [] 35 | profiler = None 36 | if "profiler" in cfg.trainer and cfg.trainer.profiler: 37 | profiler = pl.profiler.AdvancedProfiler(dirpath=cfg.logger.save_dir, 38 | filename=cfg.experiment) 39 | if "logger" in cfg: 40 | logger = hu.instantiate(cfg.logger) 41 | if "callbacks" in cfg: 42 | for _, cb_cfg in 
cfg.callbacks.items(): 43 | callbacks.append(hu.instantiate(cb_cfg)) 44 | 45 | if "resume" in cfg: 46 | trainer = hu.instantiate(cfg.trainer, 47 | checkpoint_callback=model_ckpt, 48 | callbacks=callbacks, 49 | logger=logger, 50 | resume_from_checkpoint=cfg.resume.ckpt_path, 51 | profiler=profiler) 52 | print("Resuming model checkpoint..") 53 | else: 54 | trainer = hu.instantiate(cfg.trainer, 55 | checkpoint_callback=model_ckpt, 56 | callbacks=callbacks, 57 | logger=logger, 58 | profiler=profiler) 59 | 60 | trainer.fit(model=model, datamodule=datamodule) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | --------------------------------------------------------------------------------
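End-to-end shape check (not part of the repo): the sketch below wires the modules manually with the values from `configs/beats.yaml`, bypassing Hydra and Lightning, and runs a forward pass on silent audio. With these settings the `HarmonicSTFT` default hop (`n_fft // 2 = 256`) matches `datamodule.hop_length`, and the front-end pooling reduces the features to the 16 frequency bins and 78 time frames expected by `net`.

```python
import torch as th

from harmonicstft import HarmonicSTFT
from networks import ResFrontEnd, SpecTNT

# Modules instantiated with the parameters of configs/beats.yaml
features = HarmonicSTFT(sample_rate=16000, n_fft=512, n_harmonic=6,
                        semitone_scale=2, learn_bw="only_Q")
fe_model = ResFrontEnd(in_channels=6, out_channels=256,
                       freq_pooling=[2, 2, 2], time_pooling=[2, 2, 1])
net = SpecTNT(fe_model=fe_model,
              n_channels=256, n_frequencies=16, n_times=78,
              spectral_dmodel=64, spectral_nheads=4, spectral_dimff=64,
              temporal_dmodel=256, temporal_nheads=8, temporal_dimff=256,
              embed_dim=128, n_blocks=5, dropout=0.15,
              use_tct=False, n_classes=3)
net.eval()

audio = th.zeros(2, 5 * 16000)      # batch of two 5-second chunks at 16 kHz
with th.no_grad():
    spec = features(audio)          # [2, n_harmonic, n_bands, n_frames]
    logits = net(spec)              # [2, 78, 3]: frame-wise logits since use_tct is false
print(spec.shape, logits.shape)
```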