├── configs ├── optimizer.yaml ├── data │ ├── jamendo.yaml │ ├── medley+cambridge-16.yaml │ ├── medley+cambridge-8.yaml │ ├── medley+cambridge+jamendo-16.yaml │ ├── medley+cambridge+jamendo-8.yaml │ └── medley+cambridge-4.yaml ├── models │ ├── unpaired+feat.yaml │ ├── naive.yaml │ └── naive+feat.yaml └── config.yaml ├── Assets └── diffmst-main_modified.jpg ├── stability.sh ├── .gitignore ├── scripts ├── online.sh ├── info.py ├── compare.py ├── run.py ├── datasets.py ├── gain_testing.py ├── eval_listen.py ├── eval_ablation.py ├── eval_all_combo.py └── online.py ├── tests ├── test_panner.py ├── test_encoder.py ├── test_bus.py ├── test_remix.py ├── test_mix.py ├── test_loss.py ├── test_mst.py ├── test_ke.py ├── test_comp.py ├── test_profile.py ├── test_reverb.py ├── test_dataset.py ├── test_sepremix.py └── test_peq.py ├── main.py ├── mst ├── callbacks │ ├── acc.py │ ├── metrics.py │ ├── plotting.py │ ├── audio.py │ └── mix.py ├── filter.py ├── param_system.py ├── panns.py ├── loss.py ├── fx_encoder.py ├── utils.py └── system.py ├── data ├── console_ranges.yaml └── instrument_name2id.json ├── setup.py ├── requirements.txt └── README.md /configs/optimizer.yaml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | class_path: torch.optim.Adam 3 | init_args: 4 | lr: 0.00001 -------------------------------------------------------------------------------- /Assets/diffmst-main_modified.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sai-soum/Diff-MST/HEAD/Assets/diffmst-main_modified.jpg -------------------------------------------------------------------------------- /stability.sh: -------------------------------------------------------------------------------- 1 | # 2 | 3 | cd /scratch 4 | mkdir medleydb 5 | cd medleydb 6 | aws s3 sync s3://stability-aws/MedleyDB ./ 7 | tar -xvf MedleyDB_v1.tar 8 | tar -xvf MedleyDB_v2.tar -------------------------------------------------------------------------------- /configs/data/jamendo.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: mst.dataloader.MixDataModule 3 | init_args: 4 | root_dir: /import/c4dm-datasets-ext/mtg-jamendo 5 | length: 262144 6 | batch_size: 4 7 | num_workers: 4 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/** 2 | env/** 3 | mst/callbacks/trail.ipynb 4 | model_saved/** 5 | data/.ipynb_checkpoints/** 6 | data/*.ipynb 7 | 8 | __pycache__ 9 | *.egg-info 10 | 11 | mix_KE_adv/** 12 | .vscode/ 13 | logs/** 14 | checkpoints/ 15 | debug 16 | *.wav 17 | *.png 18 | data/FXencoder_ps.pt 19 | outputs/** -------------------------------------------------------------------------------- /scripts/online.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=4 python scripts/online.py \ 3 | --track_dir "/import/c4dm-datasets-ext/test-multitracks/Kat Wright_By My Side" \ 4 | --ref_mix "/import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav" \ 5 | --use_gpu \ 6 | --n_iters 1000 \ 7 | --loss "feat" \ 8 | #--stem_separation \ -------------------------------------------------------------------------------- /tests/test_panner.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | from dasp_pytorch.functional import stereo_panner 3 | 4 | tracks = torch.ones(1, 4, 1) 5 | print(tracks) 6 | print(tracks.shape) 7 | 8 | param_dict = {"stereo_panner": {"pan": torch.tensor([0.0, 0.5, 1.0, 0.0])}} 9 | 10 | tracks = stereo_panner(tracks, **param_dict["stereo_panner"]) 11 | 12 | print(tracks) 13 | print(tracks.shape) 14 | tracks.sum(dim=2) 15 | -------------------------------------------------------------------------------- /tests/test_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mst.panns import TCN 3 | from mst.modules import WaveformTransformerEncoder 4 | 5 | 6 | def count_parameters(model): 7 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 8 | 9 | 10 | # encoder = TCN(n_inputs=2) 11 | encoder = WaveformTransformerEncoder(n_inputs=2) 12 | 13 | print(count_parameters(encoder) / 1e6) 14 | 15 | x = torch.randn(4, 2, 262144) 16 | 17 | 18 | y = encoder(x) 19 | print(y.shape) 20 | -------------------------------------------------------------------------------- /configs/data/medley+cambridge-16.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: mst.dataloader.MultitrackDataModule 3 | init_args: 4 | track_root_dirs: 5 | - /import/c4dm-datasets-ext/mixing-secrets/ 6 | - /import/c4dm-datasets/ 7 | metadata_files: 8 | - ./data/cambridge.yaml 9 | - ./data/medley.yaml 10 | length: 262144 11 | min_tracks: 2 12 | max_tracks: 16 13 | batch_size: 1 14 | num_workers: 4 15 | num_train_passes: 20 16 | num_val_passes: 1 17 | train_buffer_size_gb: 4.0 18 | val_buffer_size_gb: 0.2 19 | target_track_lufs_db: -48.0 -------------------------------------------------------------------------------- /configs/data/medley+cambridge-8.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: mst.dataloader.MultitrackDataModule 3 | init_args: 4 | track_root_dirs: 5 | - /import/c4dm-datasets-ext/mixing-secrets/ 6 | - /import/c4dm-datasets/ 7 | metadata_files: 8 | - ./data/cambridge.yaml 9 | - ./data/medley.yaml 10 | length: 262144 11 | min_tracks: 8 12 | max_tracks: 8 13 | batch_size: 4 14 | num_workers: 4 15 | num_train_passes: 20 16 | num_val_passes: 1 17 | train_buffer_size_gb: 4.0 18 | val_buffer_size_gb: 1.0 19 | target_track_lufs_db: -48.0 -------------------------------------------------------------------------------- /configs/data/medley+cambridge+jamendo-16.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: mst.dataloader.MultitrackDataModule 3 | init_args: 4 | track_root_dirs: 5 | - /import/c4dm-datasets-ext/mixing-secrets/ 6 | - /import/c4dm-datasets/ 7 | mix_root_dirs: 8 | - /import/c4dm-datasets-ext/mtg-jamendo 9 | metadata_files: 10 | - ./data/cambridge.yaml 11 | - ./data/medley.yaml 12 | length: 262144 13 | min_tracks: 2 14 | max_tracks: 16 15 | batch_size: 1 16 | num_workers: 4 17 | num_train_passes: 20 18 | num_val_passes: 1 19 | train_buffer_size_gb: 4.0 20 | val_buffer_size_gb: 0.2 21 | target_track_lufs_db: -48.0 22 | randomize_ref_mix_gain: true -------------------------------------------------------------------------------- /configs/data/medley+cambridge+jamendo-8.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: mst.dataloader.MultitrackDataModule 3 | init_args: 4 | track_root_dirs: 5 | - /import/c4dm-datasets-ext/mixing-secrets/ 6 | - 
/import/c4dm-datasets/ 7 | mix_root_dirs: 8 | - /import/c4dm-datasets-ext/mtg-jamendo 9 | metadata_files: 10 | - ./data/cambridge.yaml 11 | - ./data/medley.yaml 12 | length: 262144 13 | min_tracks: 2 14 | max_tracks: 8 15 | batch_size: 4 16 | num_workers: 8 17 | num_train_passes: 20 18 | num_val_passes: 1 19 | train_buffer_size_gb: 2.0 20 | val_buffer_size_gb: 0.2 21 | target_mix_lufs_db: -16.0 22 | target_track_lufs_db: -48.0 23 | randomize_ref_mix_gain: false 24 | -------------------------------------------------------------------------------- /scripts/info.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torchaudio 4 | from tqdm import tqdm 5 | 6 | root_dir = "/import/c4dm-datasets-ext/mixing-secrets/" 7 | 8 | # find all directories containing tracks 9 | song_dirs = glob.glob(os.path.join(root_dir, "*")) 10 | 11 | files = {"stereo": 0, "mono": 0} 12 | 13 | for song_dir in tqdm(song_dirs): 14 | # get all tracks in song dir 15 | track_filepaths = glob.glob(os.path.join(song_dir, "tracks", "**", "*.wav")) 16 | 17 | for track_filepath in track_filepaths: 18 | # get info 19 | md = torchaudio.info(track_filepath) 20 | 21 | if md.num_channels == 2: 22 | files["stereo"] += 1 23 | elif md.num_channels == 1: 24 | files["mono"] += 1 25 | 26 | print(files) 27 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # class MyLightningCLI(pl.LightningCLI): 2 | # def add_arguments_to_parser(self, parser): 3 | # parser.link_arguments( 4 | # "model.model.mix_console.num_control_params", 5 | # "model.model.controller.num_control_params", 6 | # apply_on="instantiate", 7 | # ) 8 | 9 | import torch 10 | from pytorch_lightning.cli import LightningCLI 11 | from pytorch_lightning.strategies import DDPStrategy 12 | 13 | 14 | def cli_main(): 15 | cli = LightningCLI( 16 | save_config_callback=None, 17 | trainer_defaults={ 18 | "accelerator": "gpu", 19 | # "strategy": DDPStrategy(find_unused_parameters=True), 20 | "log_every_n_steps": 50, 21 | }, 22 | ) 23 | 24 | 25 | if __name__ == "__main__": 26 | 27 | cli_main() 28 | -------------------------------------------------------------------------------- /tests/test_bus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import matplotlib.pyplot as plt 4 | 5 | from dasp_pytorch.functional import stereo_bus 6 | 7 | 8 | bs = 2 9 | chs = 2 10 | seq_len = 262144 11 | sample_rate = 44100 12 | 13 | # x = torch.randn(bs, chs, seq_len) 14 | # x = x / x.abs().max().clamp(min=1e-8) 15 | # x *= 10 ** (-24 / 20) 16 | 17 | x, sr = torchaudio.load("tests/target-gtr.wav") 18 | x = x.unsqueeze(0) 19 | 20 | x = torch.randn(bs, 8, chs, seq_len) 21 | sends_db = torch.tensor([0.0, -3.0, -6.0, -9.0, -12.0, -15.0, -18.0, -21.0]).view(1, 8, 1) 22 | sends_db = sends_db.repeat(bs, 1, 1) 23 | 24 | print(x.shape) 25 | 26 | y = stereo_bus( 27 | x, 28 | sends_db, 29 | ) 30 | 31 | print(y.shape) 32 | y /= y.abs().max().clamp(min=1e-8) 33 | torchaudio.save("tests/reverb.wav", y.view(2, -1), sample_rate) 34 | -------------------------------------------------------------------------------- /tests/test_remix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mst.modules import Remixer, AdvancedMixConsole 3 | from mst.dataloader import MixDataset 4 | 5 | if __name__ == "__main__": 6 | root_dir = "/import/c4dm-datasets-ext/mtg-jamendo" 7 | mix_dataset = MixDataset(root_dir, length=262144) 8 | mix_dataloader = torch.utils.data.DataLoader(mix_dataset, batch_size=4) 9 | 10 | mix_console = AdvancedMixConsole(44100) 11 | remixer = Remixer(44100) 12 | 13 | remixer.cuda() 14 | 15 | for batch_idx, batch in enumerate(mix_dataloader): 16 | mix = batch 17 | 18 | mix = mix.cuda() 19 | 20 | # create remix 21 | remix, track_params, fx_bus_params, master_bus_params = remixer( 22 | mix, mix_console 23 | ) 24 | 25 | print(batch_idx, mix.abs().max(), remix.abs().max()) 26 | -------------------------------------------------------------------------------- /tests/test_mix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.resnet import resnet18 3 | from mst.modules import ( 4 | MixStyleTransferModel, 5 | SpectrogramResNetEncoder, 6 | TransformerController, 7 | BasicMixConsole, 8 | AdvancedMixConsole, 9 | ) 10 | from tqdm import tqdm 11 | from mst.mixing import naive_random_mix, knowledge_engineering_mix 12 | 13 | sample_rate = 44100 14 | embed_dim = 128 15 | num_control_params = 26 16 | 17 | 18 | mix_console = AdvancedMixConsole(sample_rate) 19 | 20 | for n in tqdm(range(100)): 21 | bs = 8 22 | num_tracks = 4 23 | seq_len = 262144 24 | 25 | tracks = torch.randn(bs, num_tracks, seq_len) * 0.1 26 | 27 | mix, params = naive_random_mix(tracks, mix_console) 28 | 29 | if torch.isnan(mix).any(): 30 | print("NAN") 31 | print(mix.shape) 32 | print(torch.isnan(mix).any()) 33 | 34 | break 35 | -------------------------------------------------------------------------------- /configs/data/medley+cambridge-4.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | class_path: mst.dataloader.MultitrackDataModule 3 | init_args: 4 | #update the root dirs to the location of the datasets for Cambridge.mt and MedleyDB. You can use just one dataset if you want. Multiple datasets should be provided as a list of directories. 5 | #corresponding metadata files should be provided as a list of YAML files which contain train, val, and test splits. We have default splits for Cambridge and MedleyDB in the data folder.
6 | track_root_dirs: 7 | - /import/c4dm-datasets-ext/mixing-secrets/ 8 | - /import/c4dm-datasets/ 9 | metadata_files: 10 | - ./data/cambridge.yaml 11 | - ./data/medley.yaml 12 | length: 262144 13 | #supports different values for min and max tracks 14 | min_tracks: 4 15 | max_tracks: 4 16 | batch_size: 2 17 | num_workers: 4 18 | num_train_passes: 20 19 | num_val_passes: 1 20 | train_buffer_size_gb: 4.0 21 | val_buffer_size_gb: 1.0 22 | target_track_lufs_db: -48.0 23 | randomize_ref_mix_gain: false 24 | 25 | -------------------------------------------------------------------------------- /mst/callbacks/acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | 5 | from mst.callbacks.plotting import plot_confusion_matrix 6 | 7 | 8 | class ConfusionMatrixCallback(pl.callbacks.Callback): 9 | def __init__(self): 10 | super().__init__() 11 | self.targets = [] 12 | self.estimates = [] 13 | 14 | def on_validation_batch_end( 15 | self, 16 | trainer, 17 | pl_module, 18 | outputs, 19 | batch, 20 | batch_idx, 21 | dataloader_idx, 22 | ): 23 | """Called when the validation batch ends.""" 24 | if outputs is not None: 25 | self.targets.append(outputs["e"]) 26 | self.estimates.append(outputs["e_hat"].max(1).indices) 27 | 28 | def on_validation_end(self, trainer, pl_module): 29 | 30 | e = torch.cat(self.targets, dim=0) 31 | e_hat = torch.cat(self.estimates, dim=0) 32 | 33 | trainer.logger.experiment.add_image( 34 | f"confusion_matrix", 35 | plot_confusion_matrix( 36 | e_hat, 37 | e, 38 | labels=pl_module.hparams.effects, 39 | ), 40 | trainer.global_step, 41 | ) 42 | 43 | # clear outputs 44 | self.targets = [] 45 | self.estimates = [] 46 | -------------------------------------------------------------------------------- /tests/test_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | from mst.loss import AudioFeatureLoss, ParameterEstimatorLoss 5 | from mst.loss import ( 6 | compute_crest_factor, 7 | compute_melspectrum, 8 | compute_barkspectrum, 9 | compute_rms, 10 | compute_stereo_imbalance, 11 | compute_stereo_width, 12 | ) 13 | 14 | transforms = [ 15 | compute_rms, 16 | compute_crest_factor, 17 | compute_stereo_width, 18 | compute_stereo_imbalance, 19 | compute_barkspectrum, 20 | ] 21 | weights = [10.0, 0.1, 10.0, 100.0, 0.1] 22 | 23 | sample_rate = 44100 24 | 25 | # loss = AudioFeatureLoss(weights, sample_rate, stem_separation=False) 26 | 27 | ckpt_path = "/import/c4dm-datasets-ext/Diff-MST/DiffMST-Param/0ymfi1pp/checkpoints/epoch=5-step=10842.ckpt" 28 | loss = ParameterEstimatorLoss(ckpt_path) 29 | 30 | # test with audio examples 31 | input, _ = torchaudio.load("outputs/output/pred_mix.wav") 32 | target, _ = torchaudio.load("outputs/output/ref_mix.wav") 33 | 34 | 35 | input = input.unsqueeze(0) 36 | target = target.unsqueeze(0) 37 | 38 | input = input.repeat(4, 1, 1) 39 | target = target.repeat(4, 1, 1) 40 | 41 | # input[0, ...] = 0.0001 * torch.randn_like(input[0, ...]) 42 | # target[0, ...] 
= 0.0001 * torch.randn_like(input[0, ...]) 43 | 44 | 45 | loss_val = loss(input, target) 46 | 47 | print(loss_val) 48 | -------------------------------------------------------------------------------- /tests/test_mst.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.resnet import resnet18 3 | from mst.modules import ( 4 | MixStyleTransferModel, 5 | SpectrogramResNetEncoder, 6 | TransformerController, 7 | BasicMixConsole, 8 | AdvancedMixConsole, 9 | ) 10 | 11 | sample_rate = 44100 12 | embed_dim = 128 13 | num_track_control_params = 27 14 | num_fx_bus_control_params = 25 15 | num_master_bus_control_params = 24 16 | use_fx_bus = True 17 | use_master_bus = True 18 | 19 | track_encoder = SpectrogramResNetEncoder() 20 | mix_encoder = SpectrogramResNetEncoder() 21 | controller = TransformerController( 22 | embed_dim=embed_dim, 23 | num_track_control_params=num_track_control_params, 24 | num_fx_bus_control_params=num_fx_bus_control_params, 25 | num_master_bus_control_params=num_master_bus_control_params, 26 | ) 27 | 28 | 29 | mix_console = AdvancedMixConsole(sample_rate) 30 | 31 | model = MixStyleTransferModel(track_encoder, mix_encoder, controller, mix_console) 32 | 33 | bs = 8 34 | num_tracks = 4 35 | seq_len = 262144 36 | 37 | tracks = torch.randn(bs, num_tracks, seq_len) 38 | ref_mix = torch.randn(bs, 2, seq_len) 39 | 40 | ( 41 | mixed_tracks, 42 | mix, 43 | track_param_dict, 44 | fx_bus_param_dict, 45 | master_bus_param_dict, 46 | ) = model(tracks, ref_mix) 47 | print(mix.shape) 48 | print(mix) 49 | -------------------------------------------------------------------------------- /tests/test_ke.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | 5 | from mst.mixing import knowledge_engineering_mix 6 | from mst.modules import BasicMixConsole, AdvancedMixConsole 7 | from mst.dataloaders.medley import MedleyDBDataset 8 | 9 | dataset = MedleyDBDataset(root_dirs=["/scratch/medleydb/V1"], subset="train") 10 | print(len(dataset)) 11 | 12 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) 13 | # mix_console = mst.modules.BasicMixConsole(sample_rate=44100.0) 14 | mix_console = AdvancedMixConsole(sample_rate=44100.0) 15 | for i, (track, instrument_id, stereo) in enumerate(dataloader): 16 | print("\n\n mixing") 17 | print("track", track.size()) 18 | batch_size, num_tracks, seq_len = track.size() 19 | 20 | print(instrument_id) 21 | print(stereo) 22 | 23 | mix, param_dict = knowledge_engineering_mix( 24 | track, mix_console, instrument_id, stereo 25 | ) 26 | print(param_dict) 27 | sum_mix = torch.sum(track, dim=1) 28 | print("mix", mix.size()) 29 | 30 | save_dir = "debug" 31 | os.makedirs(save_dir, exist_ok=True) 32 | 33 | # export audio 34 | for j in range(batch_size): 35 | torchaudio.save(os.path.join(save_dir, "mix_" + str(j) + ".wav"), mix[j], 44100) 36 | torchaudio.save( 37 | os.path.join(save_dir, "sum" + str(j) + ".wav"), sum_mix[j], 44100 38 | ) 39 | 40 | print("mix", mix.size()) 41 | if i == 0: 42 | break 43 | -------------------------------------------------------------------------------- /data/console_ranges.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | #silence 3 | [instrument_category]: 4 | instruments: 5 | - instrument 6 | gain : 7 | - -24.0 8 | - -24.0 9 | pan : 10 | - 0.0 11 | - 1.0 12 | eq: 13 | eq_lowshelf_gain : 14 | - -3.0 15 | - -1.0 16 
| eq_lowshelf_freq : 17 | - 20 18 | - 2000 19 | eq_lowshelf_q : 20 | - 0.1 21 | - 5.0 22 | eq_band0_gain : 23 | - -24 24 | - 24 25 | eq_band0_freq : 26 | - 80 27 | - 2000 28 | eq_band0_q : 29 | - 0.1 30 | - 5.0 31 | eq_band1_gain : 32 | - -24 33 | - 24 34 | eq_band1_freq : 35 | - 2000 36 | - 8000 37 | eq_band1_q : 38 | - 0.1 39 | - 5.0 40 | eq_band2_gain : 41 | - -24 42 | - 24 43 | eq_band2_freq : 44 | - 8000 45 | - 12000 46 | eq_band2_q : 47 | - 0.1 48 | - 5.0 49 | eq_band3_gain : 50 | - 0.0 51 | - 0.0 52 | eq_band3_freq : 53 | - 12000 54 | - 20000 55 | eq_band3_q : 56 | - 0.1 57 | - 5.0 58 | eq_highshelf_gain : 59 | - -24 60 | - 24 61 | eq_highshelf_freq : 62 | - 6000 63 | - 20000 64 | eq_highshelf_q : 65 | - 0.1 66 | - 5.0 67 | compressor: 68 | threshold_db : 69 | - -60.0 70 | - 0.0 71 | ratio: 72 | - 1.0 73 | - 10.0 74 | attack_ms: 75 | - 1.0 76 | - 1000.0 77 | release_ms: 78 | - 1.0 79 | - 1000.0 80 | knee_db: 81 | - 3.0 82 | - 24.0 83 | makeup_gain_db: 84 | - 0.0 85 | - 24.0 86 | 87 | -------------------------------------------------------------------------------- /tests/test_comp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import matplotlib.pyplot as plt 4 | 5 | from dasp_pytorch.functional import compressor 6 | 7 | 8 | bs = 2 9 | chs = 2 10 | seq_len = 262144 11 | sample_rate = 44100 12 | 13 | # x = torch.randn(bs, chs, seq_len) 14 | # x = x / x.abs().max().clamp(min=1e-8) 15 | # x *= 10 ** (-24 / 20) 16 | 17 | 18 | x = torch.zeros(bs, chs, seq_len) 19 | 20 | x[..., 0, 4096:131072] = 1.0 21 | x[..., 1, 16384:65536] = 1.0 22 | 23 | 24 | threshold_db = torch.tensor([-12.0, -6.0]) 25 | ratio = torch.tensor([4.0, 4.0]) 26 | attack_ms = torch.tensor([100.0, 1000.0]) 27 | release_ms = torch.tensor([0.0, 0.0]) # dummy parameter 28 | knee_db = torch.tensor([6.0, 6.0]) 29 | makeup_gain_db = torch.tensor([0.0, 0.0]) 30 | 31 | threshold_db = threshold_db.view(1, chs).repeat(bs, 1) 32 | ratio = ratio.view(1, chs).repeat(bs, 1) 33 | attack_ms = attack_ms.view(1, chs).repeat(bs, 1) 34 | release_ms = release_ms.view(1, chs).repeat(bs, 1) 35 | knee_db = knee_db.view(1, chs).repeat(bs, 1) 36 | makeup_gain_db = makeup_gain_db.view(1, chs).repeat(bs, 1) 37 | 38 | print(threshold_db.shape) 39 | y = compressor( 40 | x, 41 | sample_rate, 42 | threshold_db, 43 | ratio, 44 | attack_ms, 45 | release_ms, 46 | knee_db, 47 | makeup_gain_db, 48 | ) 49 | 50 | print(y.shape) 51 | 52 | fig, axs = plt.subplots(2, 1, figsize=(10, 6)) 53 | axs[0].plot(x[0, 0, :].numpy(), label="input") 54 | axs[0].plot(y[0, 0, :].numpy(), label="output") 55 | axs[0].legend() 56 | 57 | axs[1].plot(x[0, 1, :].numpy(), label="input") 58 | axs[1].plot(y[0, 1, :].numpy(), label="output") 59 | axs[1].legend() 60 | plt.grid(c="lightgray") 61 | plt.savefig("compressor.png", dpi=300) 62 | -------------------------------------------------------------------------------- /configs/models/unpaired+feat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: mst.system.System 3 | init_args: 4 | generate_mix: false 5 | active_eq_epoch: 0 6 | active_compressor_epoch: 0 7 | active_fx_bus_epoch: 1000 8 | active_master_bus_epoch: 0 9 | mix_fn: mst.mixing.naive_random_mix 10 | mix_console: 11 | class_path: mst.modules.AdvancedMixConsole 12 | init_args: 13 | sample_rate: 44100 14 | input_min_gain_db: -48.0 15 | input_max_gain_db: 48.0 16 | output_min_gain_db: -48.0 17 | output_max_gain_db: 48.0 18 | eq_min_gain_db: 
-12.0 19 | eq_max_gain_db: 12.0 20 | min_pan: 0.0 21 | max_pan: 1.0 22 | model: 23 | class_path: mst.modules.MixStyleTransferModel 24 | init_args: 25 | track_encoder: 26 | class_path: mst.modules.SpectrogramEncoder 27 | init_args: 28 | embed_dim: 512 29 | n_fft: 2048 30 | hop_length: 512 31 | input_batchnorm: false 32 | mix_encoder: 33 | class_path: mst.modules.SpectrogramEncoder 34 | init_args: 35 | embed_dim: 512 36 | n_fft: 2048 37 | hop_length: 512 38 | input_batchnorm: false 39 | controller: 40 | class_path: mst.modules.TransformerController 41 | init_args: 42 | embed_dim: 512 43 | num_track_control_params: 27 44 | num_fx_bus_control_params: 25 45 | num_master_bus_control_params: 26 46 | num_layers: 12 47 | nhead: 8 48 | 49 | loss: 50 | class_path: mst.loss.AudioFeatureLoss 51 | init_args: 52 | sample_rate: 44100 53 | stem_separation: false 54 | use_clap: false 55 | weights: 56 | - 0.1 # rms 57 | - 0.001 # crest factor 58 | - 1.0 # stereo width 59 | - 1.0 # stereo imbalance 60 | - 0.1 # bark spectrum 61 | 62 | -------------------------------------------------------------------------------- /mst/callbacks/metrics.py: -------------------------------------------------------------------------------- 1 | from re import X 2 | import numpy as np 3 | import pytorch_lightning as pl 4 | 5 | 6 | 7 | 8 | class LogAudioMetricsCallback(pl.callbacks.Callback): 9 | def __init__( 10 | self, 11 | sample_rate: int = 44100, 12 | ): 13 | super().__init__() 14 | self.sample_rate = sample_rate 15 | 16 | self.metrics = { 17 | "PESQi": PESQi(sample_rate), 18 | "MRSTFTi": MRSTFTi(), 19 | "SISDRi": SISDRi(), 20 | } 21 | 22 | self.outputs = [] 23 | 24 | def on_validation_batch_end( 25 | self, 26 | trainer, 27 | pl_module, 28 | outputs, 29 | batch, 30 | batch_idx, 31 | dataloader_idx, 32 | ): 33 | """Called when the validation batch ends.""" 34 | 35 | if outputs is not None: 36 | self.outputs.append(outputs) 37 | 38 | def on_validation_end(self, trainer, pl_module): 39 | y_hat_metrics = { 40 | "PESQi": [], 41 | "MRSTFTi": [], 42 | "SISDRi": [], 43 | } 44 | for output in self.outputs: 45 | for metric_name, metric in self.metrics.items(): 46 | for batch_idx in range(len(output["y_hat"].shape)): 47 | y_hat = output["y_hat"][batch_idx, ...] 48 | x = output["x"][batch_idx, ...] 49 | y = output["y"][batch_idx, ...] 
50 | 51 | try: 52 | val = metric(y_hat, x, y) 53 | y_hat_metrics[metric_name].append(val) 54 | except Exception as e: 55 | print(e) 56 | 57 | # log final mean metrics 58 | for metric_name, metric in y_hat_metrics.items(): 59 | val = np.mean(metric) 60 | trainer.logger.experiment.add_scalar( 61 | f"metrics/estimated-{metric_name}", val, trainer.global_step 62 | ) 63 | 64 | # clear outputs 65 | self.outputs = [] 66 | -------------------------------------------------------------------------------- /configs/models/naive.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: mst.system.System 3 | init_args: 4 | generate_mix: true 5 | active_eq_epoch: 0 6 | active_compressor_epoch: 0 7 | active_fx_bus_epoch: 1000 8 | active_master_bus_epoch: 0 9 | 10 | use_track_loss : false 11 | use_mix_loss : true 12 | use_param_loss : false 13 | 14 | mix_fn: mst.mixing.naive_random_mix 15 | mix_console: 16 | class_path: mst.modules.AdvancedMixConsole 17 | init_args: 18 | sample_rate: 44100 19 | input_min_gain_db: -48.0 20 | input_max_gain_db: 48.0 21 | output_min_gain_db: -48.0 22 | output_max_gain_db: 48.0 23 | eq_min_gain_db: -12.0 24 | eq_max_gain_db: 12.0 25 | min_pan: 0.0 26 | max_pan: 1.0 27 | model: 28 | class_path: mst.modules.MixStyleTransferModel 29 | init_args: 30 | track_encoder: 31 | class_path: mst.modules.SpectrogramEncoder 32 | init_args: 33 | embed_dim: 512 34 | n_fft: 2048 35 | hop_length: 512 36 | input_batchnorm: false 37 | mix_encoder: 38 | class_path: mst.modules.SpectrogramEncoder 39 | init_args: 40 | embed_dim: 512 41 | n_fft: 2048 42 | hop_length: 512 43 | input_batchnorm: false 44 | controller: 45 | class_path: mst.modules.TransformerController 46 | init_args: 47 | embed_dim: 512 48 | num_track_control_params: 27 49 | num_fx_bus_control_params: 25 50 | num_master_bus_control_params: 26 51 | num_layers: 12 52 | nhead: 8 53 | 54 | loss: 55 | class_path: auraloss.freq.MultiResolutionSTFTLoss 56 | init_args: 57 | fft_sizes: 58 | - 512 59 | - 2048 60 | - 8192 61 | hop_sizes: 62 | - 256 63 | - 1024 64 | - 4096 65 | win_lengths: 66 | - 512 67 | - 2048 68 | - 8192 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /tests/test_profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | from torch.profiler import profile, record_function, ProfilerActivity 4 | 5 | from mst.modules import ( 6 | MixStyleTransferModel, 7 | SpectrogramResNetEncoder, 8 | TransformerController, 9 | BasicMixConsole, 10 | AdvancedMixConsole, 11 | ) 12 | 13 | if False: 14 | sample_rate = 44100 15 | embed_dim = 128 16 | num_track_control_params = 27 17 | num_fx_bus_control_params = 25 18 | num_master_bus_control_params = 24 19 | use_fx_bus = True 20 | use_master_bus = True 21 | 22 | track_encoder = SpectrogramResNetEncoder() 23 | mix_encoder = SpectrogramResNetEncoder() 24 | controller = TransformerController( 25 | embed_dim=embed_dim, 26 | num_track_control_params=num_track_control_params, 27 | num_fx_bus_control_params=num_fx_bus_control_params, 28 | num_master_bus_control_params=num_master_bus_control_params, 29 | ) 30 | 31 | mix_console = AdvancedMixConsole(sample_rate) 32 | 33 | model = MixStyleTransferModel(track_encoder, mix_encoder, controller, mix_console) 34 | model.cuda() 35 | 36 | bs = 1 37 | num_tracks = 8 38 | seq_len = 262144 39 | 40 | tracks = torch.randn(bs, num_tracks, seq_len) 41 | ref_mix = 
torch.randn(bs, 2, seq_len) 42 | 43 | tracks = tracks.cuda() 44 | ref_mix = ref_mix.cuda() 45 | 46 | with profile( 47 | activities=[ 48 | ProfilerActivity.CUDA, 49 | ProfilerActivity.CPU, 50 | ], 51 | profile_memory=True, 52 | record_shapes=True, 53 | with_modules=False, 54 | with_stack=True, 55 | ) as prof: 56 | ( 57 | mixed_tracks, 58 | mix, 59 | track_param_dict, 60 | fx_bus_param_dict, 61 | master_bus_param_dict, 62 | ) = model(tracks, ref_mix) 63 | 64 | print( 65 | prof.key_averages(group_by_stack_n=5).table( 66 | sort_by="self_cuda_memory_usage", row_limit=25 67 | ) 68 | ) 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup, find_packages 3 | 4 | NAME = "DiffMST" 5 | DESCRIPTION = "Mix style transfer with differentiable signal processing" 6 | URL = "https://github.com/sai-soum/DiffMST.git" 7 | EMAIL = "s.s.vanka@qmul.ac.uk" 8 | AUTHOR = "Soumya Sai Vanka" 9 | REQUIRES_PYTHON = ">=3.7.11" 10 | VERSION = "0.0.1" 11 | 12 | HERE = Path(__file__).parent 13 | 14 | try: 15 | with open(HERE / "README.md", encoding="utf-8") as f: 16 | long_description = "\n" + f.read() 17 | except FileNotFoundError: 18 | long_description = DESCRIPTION 19 | 20 | setup( 21 | name=NAME, 22 | version=VERSION, 23 | description=DESCRIPTION, 24 | long_description=long_description, 25 | long_description_content_type="text/markdown", 26 | author=AUTHOR, 27 | author_email=EMAIL, 28 | python_requires=REQUIRES_PYTHON, 29 | url=URL, 30 | packages=[ 31 | "mst", 32 | ], 33 | install_requires=[ 34 | "auraloss==0.4.0", 35 | "dasp-pytorch==0.0.1", 36 | "librosa", 37 | "matplotlib", 38 | "numpy", 39 | "pedalboard==0.8.7", 40 | "pyloudnorm", 41 | "pytorch_lightning[extra]==2.1.4", 42 | "scipy==1.12.0", 43 | "tensorboard", 44 | "torch==2.2.0", 45 | "torchaudio==2.2.0", 46 | "torchvision==0.17.0", 47 | "tqdm", 48 | "wandb", 49 | ], 50 | extras_require={ 51 | "asteroid": ["asteroid-filterbanks>=0.3.2"], 52 | "tests": [ 53 | "pytest", 54 | "musdb>=0.4.0", 55 | "museval>=0.4.0", 56 | "asteroid-filterbanks>=0.3.2", 57 | "onnx", 58 | "tqdm", 59 | ], 60 | "stempeg": ["stempeg"], 61 | "evaluation": ["musdb>=0.4.0", "museval>=0.4.0"], 62 | }, 63 | # entry_points={"console_scripts": ["umx=openunmix.cli:separate"]}, 64 | # packages=find_packages(), 65 | include_package_data=True, 66 | classifiers=[ 67 | "Topic :: Multimedia :: Sound/Audio", 68 | "Topic :: Scientific/Engineering", 69 | ], 70 | ) 71 | -------------------------------------------------------------------------------- /tests/test_reverb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import matplotlib.pyplot as plt 4 | 5 | from dasp_pytorch.functional import noise_shaped_reverberation 6 | 7 | 8 | bs = 2 9 | chs = 2 10 | seq_len = 262144 11 | sample_rate = 44100 12 | 13 | # x = torch.randn(bs, chs, seq_len) 14 | # x = x / x.abs().max().clamp(min=1e-8) 15 | # x *= 10 ** (-24 / 20) 16 | 17 | x, sr = torchaudio.load("tests/target-gtr.wav") 18 | x = x.repeat(2, 1) 19 | x = x.unsqueeze(0) 20 | print(x.shape) 21 | 22 | band_gains = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) 23 | band_decays = torch.tensor( 24 | [0.8, 0.7, 0.8, 0.6, 0.6, 0.7, 0.8, 0.8, 0.99, 0.8, 0.9, 1.0] 25 | ) 26 | 27 | band0_gain = torch.rand(1) 28 | band1_gain = torch.rand(1) 29 | band2_gain = torch.rand(1) 30 | 
band3_gain = torch.rand(1) 31 | band4_gain = torch.rand(1) 32 | band5_gain = torch.rand(1) 33 | band6_gain = torch.rand(1) 34 | band7_gain = torch.rand(1) 35 | band8_gain = torch.rand(1) 36 | band9_gain = torch.rand(1) 37 | band10_gain = torch.rand(1) 38 | band11_gain = torch.rand(1) 39 | band0_decay = torch.rand(1) 40 | band1_decay = torch.rand(1) 41 | band2_decay = torch.rand(1) 42 | band3_decay = torch.rand(1) 43 | band4_decay = torch.rand(1) 44 | band5_decay = torch.rand(1) 45 | band6_decay = torch.rand(1) 46 | band7_decay = torch.rand(1) 47 | band8_decay = torch.rand(1) 48 | band9_decay = torch.rand(1) 49 | band10_decay = torch.rand(1) 50 | band11_decay = torch.rand(1) 51 | 52 | mix = torch.tensor([0.05]) 53 | 54 | y = noise_shaped_reverberation( 55 | x, 56 | sample_rate, 57 | band0_gain, 58 | band1_gain, 59 | band2_gain, 60 | band3_gain, 61 | band4_gain, 62 | band5_gain, 63 | band6_gain, 64 | band7_gain, 65 | band8_gain, 66 | band9_gain, 67 | band10_gain, 68 | band11_gain, 69 | band0_decay, 70 | band1_decay, 71 | band2_decay, 72 | band3_decay, 73 | band4_decay, 74 | band5_decay, 75 | band6_decay, 76 | band7_decay, 77 | band8_decay, 78 | band9_decay, 79 | band10_decay, 80 | band11_decay, 81 | mix, 82 | num_samples=88200, 83 | num_bandpass_taps=1023, 84 | ) 85 | 86 | print(y.shape) 87 | y /= y.abs().max().clamp(min=1e-8) 88 | torchaudio.save("tests/reverb.wav", y.view(2, -1), sample_rate) 89 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchaudio 4 | 5 | from mst.mixing import naive_random_mix 6 | from mst.modules import AdvancedMixConsole 7 | from mst.dataloader import MultitrackDataModule 8 | 9 | datamodule = MultitrackDataModule( 10 | root_dirs=["/import/c4dm-datasets-ext/mixing-secrets/", "/import/c4dm-datasets/"], 11 | metadata_files=["data/cambridge.yaml", "data/medley.yaml"], 12 | length=262144, 13 | min_tracks=8, 14 | max_tracks=8, 15 | batch_size=2, 16 | num_workers=4, 17 | num_train_passes=20, 18 | num_val_passes=1, 19 | train_buffer_size_gb=0.1, 20 | val_buffer_size_gb=0.1, 21 | target_track_lufs_db=-48.0, 22 | ) 23 | datamodule.setup("fit") 24 | 25 | root = "./" 26 | os.makedirs(f"{root}/debug", exist_ok=True) 27 | 28 | train_loader = datamodule.train_dataloader() 29 | mix_fn = naive_random_mix 30 | use_gpu = True 31 | 32 | generate_mix_console = AdvancedMixConsole(44100) 33 | 34 | for idx, batch in enumerate(train_loader): 35 | tracks, instrument_id, stereo_info = batch 36 | 37 | if use_gpu: 38 | tracks = tracks.cuda() 39 | 40 | sum_mix = tracks.sum(dim=1) 41 | # create a random mix (on GPU, if applicable) 42 | ( 43 | mixed_tracks, 44 | mix, 45 | track_param_dict, 46 | fx_bus_param_dict, 47 | master_bus_param_dict, 48 | ) = mix_fn( 49 | tracks, 50 | generate_mix_console, 51 | use_track_input_fader=False, 52 | use_output_fader=False, 53 | use_fx_bus=True, 54 | use_master_bus=True, 55 | ) 56 | 57 | mix = mix[0, ...] 58 | sum_mix = sum_mix[0, ...] 
59 | print(mix.shape, sum_mix.shape) 60 | print(mix.abs().max(), sum_mix.abs().max()) 61 | mix = mix.view(2, -1) 62 | sum_mix = sum_mix.repeat(2, 1) 63 | 64 | mix /= mix.abs().max() 65 | sum_mix /= sum_mix.abs().max() 66 | 67 | # split mix into a and b sections 68 | mix_a = mix[:, 0 : mix.shape[1] // 2] 69 | mix_b = mix[:, mix.shape[1] // 2 :] 70 | 71 | torchaudio.save(f"{root}/debug/{idx}_ref_mix_a.wav", mix_a.cpu(), 44100) 72 | torchaudio.save(f"{root}/debug/{idx}_ref_mix_b.wav", mix_b.cpu(), 44100) 73 | torchaudio.save(f"{root}/debug/{idx}_sum_mix.wav", sum_mix.cpu(), 44100) 74 | if idx > 25: 75 | break 76 | -------------------------------------------------------------------------------- /scripts/compare.py: -------------------------------------------------------------------------------- 1 | # script to compare two mixes based on their features 2 | import os 3 | import torch 4 | import argparse 5 | import torchaudio 6 | import matplotlib.pyplot as plt 7 | 8 | from mst.loss import ( 9 | compute_barkspectrum, 10 | compute_crest_factor, 11 | compute_rms, 12 | compute_stereo_imbalance, 13 | compute_stereo_width, 14 | ) 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("input_a", type=str) 19 | parser.add_argument("input_b", type=str) 20 | parser.add_argument("--output_dir", type=str, default="outputs/compare") 21 | args = parser.parse_args() 22 | 23 | input_a_filename = os.path.basename(args.input_a).split(".")[0] 24 | input_b_filename = os.path.basename(args.input_b).split(".")[0] 25 | run_name = f"{input_a_filename}-{input_b_filename}" 26 | output_dir = os.path.join(args.output_dir, run_name) 27 | os.makedirs(output_dir, exist_ok=True) 28 | 29 | # load audio files 30 | input_a, input_a_sample_rate = torchaudio.load(args.input_a) 31 | input_b, input_b_sample_rate = torchaudio.load(args.input_b) 32 | 33 | # -------------- compute features ---------------- 34 | a_barkspectrum = compute_barkspectrum(input_a, sample_rate=44100) 35 | b_barkspectrum = compute_barkspectrum(input_b, sample_rate=44100) 36 | 37 | a_crest_factor = compute_crest_factor(input_a) 38 | b_crest_factor = compute_crest_factor(input_b) 39 | 40 | a_rms = compute_rms(input_a) 41 | b_rms = compute_rms(input_b) 42 | 43 | a_stereo_imbalance = compute_stereo_imbalance(input_a) 44 | b_stereo_imbalance = compute_stereo_imbalance(input_b) 45 | 46 | a_stereo_width = compute_stereo_width(input_a) 47 | b_stereo_width = compute_stereo_width(input_b) 48 | 49 | # -------------- plot features ---------------- 50 | 51 | fig, axs = plt.subplots(2, 1, sharex=True, sharey=True) 52 | axs[0].plot(a_barkspectrum[0, :, 0], label="A-mid", color="tab:orange") 53 | axs[0].plot(b_barkspectrum[0, :, 0], label="B-mid", color="tab:blue") 54 | axs[1].plot(a_barkspectrum[0, :, 1], label="A-side", color="tab:orange") 55 | axs[1].plot(b_barkspectrum[0, :, 1], label="B-side", color="tab:blue") 56 | axs[0].legend() 57 | axs[1].legend() 58 | plt.savefig(os.path.join(output_dir, f"bark_spectrum.png")) 59 | plt.close("all") 60 | -------------------------------------------------------------------------------- /configs/models/naive+feat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | class_path: mst.system.System 3 | init_args: 4 | #generate random mixes as specified in Method 1 of the paper; True for Method 1, False for Method 2 5 | #set the epoch values very high to disable any of the effects during training; fx_bus corresponds to reverb
module. 7 | active_eq_epoch: 0 8 | active_compressor_epoch: 0 9 | active_fx_bus_epoch: 1000 10 | active_master_bus_epoch: 0 11 | use_track_loss : false 12 | use_mix_loss : true 13 | use_param_loss : false 14 | 15 | #We generate random mixes using the naive random mix function 16 | 17 | mix_fn: mst.mixing.naive_random_mix 18 | mix_console: 19 | class_path: mst.modules.AdvancedMixConsole 20 | init_args: 21 | sample_rate: 44100 22 | input_min_gain_db: -48.0 23 | input_max_gain_db: 48.0 24 | output_min_gain_db: -48.0 25 | output_max_gain_db: 48.0 26 | eq_min_gain_db: -12.0 27 | eq_max_gain_db: 12.0 28 | min_pan: 0.0 29 | max_pan: 1.0 30 | model: 31 | class_path: mst.modules.MixStyleTransferModel 32 | init_args: 33 | track_encoder: 34 | class_path: mst.modules.SpectrogramEncoder 35 | init_args: 36 | embed_dim: 512 37 | n_fft: 2048 38 | hop_length: 512 39 | input_batchnorm: false 40 | mix_encoder: 41 | class_path: mst.modules.SpectrogramEncoder 42 | init_args: 43 | embed_dim: 512 44 | n_fft: 2048 45 | hop_length: 512 46 | input_batchnorm: false 47 | controller: 48 | class_path: mst.modules.TransformerController 49 | init_args: 50 | embed_dim: 512 51 | num_track_control_params: 27 52 | num_fx_bus_control_params: 25 53 | num_master_bus_control_params: 26 54 | num_layers: 12 55 | nhead: 8 56 | 57 | loss: 58 | class_path: mst.loss.AudioFeatureLoss 59 | init_args: 60 | sample_rate: 44100 61 | stem_separation: false 62 | use_clap: False 63 | weights: 64 | - 0.1 # rms 65 | - 0.001 # crest factor 66 | - 1.0 # stereo width 67 | - 1.0 # stereo imbalance 68 | - 0.1 # bark spectrum 69 | - 1.0 # CLAP 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | # use ckpt_path to load a model from a checkpoint 3 | # ckpt_path: /import/c4dm-datasets-ext/diffmst_logs_soum/2021-10-06/14-00-00/checkpoints/epoch=0-step=0.ckpt 4 | trainer: 5 | logger: 6 | class_path: pytorch_lightning.loggers.WandbLogger 7 | init_args: 8 | project: DiffMST 9 | #change to the directory where you want to save wandb logs 10 | save_dir: /import/c4dm-datasets-ext/diffmst_logs_soum 11 | enable_checkpointing: true 12 | callbacks: 13 | - class_path: mst.callbacks.audio.LogAudioCallback 14 | - class_path: pytorch_lightning.callbacks.ModelSummary 15 | init_args: 16 | max_depth: 2 17 | #uncomment if you want to run validation on custom examples during training 18 | # - class_path: mst.callbacks.mix.LogReferenceMix 19 | # init_args: 20 | # root_dirs: 21 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full 22 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full 23 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full 24 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full 25 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/SaturnSyndicate_CatchTheWave_Full 26 | # ref_mixes: 27 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav 28 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Poom - Les Voiles (Official Audio).wav 29 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Justin Timberlake - Can't Stop The Feeling! 
[Lyrics].wav 30 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Taylor Swift - Shake It Off.wav 31 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/ref/Miley Cyrus - Wrecking Ball (Lyrics).wav 32 | default_root_dir: null 33 | gradient_clip_val: 10.0 34 | devices: 1 35 | check_val_every_n_epoch: 1 36 | max_epochs: 800 37 | #change to log less or more often to wandb 38 | log_every_n_steps: 500 39 | accelerator: gpu 40 | strategy: ddp_find_unused_parameters_true 41 | sync_batchnorm: true 42 | precision: 32 43 | enable_model_summary: true 44 | num_sanity_val_steps: 2 45 | benchmark: true 46 | accumulate_grad_batches: 1 47 | #reload_dataloaders_every_n_epochs: 1 48 | 49 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | aiohttp==3.9.5 3 | aiosignal==1.3.1 4 | antlr4-python3-runtime==4.9.3 5 | attrs==23.2.0 6 | audioread==3.0.1 7 | auraloss==0.4.0 8 | bitsandbytes==0.41.1 9 | certifi==2024.7.4 10 | cffi==1.16.0 11 | charset-normalizer==3.3.2 12 | click==8.1.7 13 | contourpy==1.2.1 14 | cycler==0.12.1 15 | dasp-pytorch==0.0.1 16 | decorator==5.1.1 17 | -e git+ssh://git@github.com/sai-soum/Diff-MST.git@f65df9e178b7e4af33c6fb3e728452024fd8169b#egg=DiffMST 18 | docker-pycreds==0.4.0 19 | docstring_parser==0.16 20 | filelock==3.15.4 21 | fonttools==4.53.1 22 | frozenlist==1.4.1 23 | fsspec==2024.6.1 24 | future==1.0.0 25 | gitdb==4.0.11 26 | GitPython==3.1.43 27 | grpcio==1.65.0 28 | hydra-core==1.3.2 29 | idna==3.7 30 | importlib_resources==6.4.0 31 | Jinja2==3.1.4 32 | joblib==1.4.2 33 | jsonargparse==4.31.0 34 | kiwisolver==1.4.5 35 | lazy_loader==0.4 36 | librosa==0.10.2.post1 37 | lightning-utilities==0.11.3.post0 38 | llvmlite==0.43.0 39 | Markdown==3.6 40 | markdown-it-py==3.0.0 41 | MarkupSafe==2.1.5 42 | matplotlib==3.9.1 43 | mdurl==0.1.2 44 | mpmath==1.3.0 45 | msgpack==1.0.8 46 | multidict==6.0.5 47 | networkx==3.3 48 | numba==0.60.0 49 | numpy==1.26.4 50 | nvidia-cublas-cu12==12.1.3.1 51 | nvidia-cuda-cupti-cu12==12.1.105 52 | nvidia-cuda-nvrtc-cu12==12.1.105 53 | nvidia-cuda-runtime-cu12==12.1.105 54 | nvidia-cudnn-cu12==8.9.2.26 55 | nvidia-cufft-cu12==11.0.2.54 56 | nvidia-curand-cu12==10.3.2.106 57 | nvidia-cusolver-cu12==11.4.5.107 58 | nvidia-cusparse-cu12==12.1.0.106 59 | nvidia-nccl-cu12==2.19.3 60 | nvidia-nvjitlink-cu12==12.5.82 61 | nvidia-nvtx-cu12==12.1.105 62 | omegaconf==2.3.0 63 | packaging==24.1 64 | pedalboard==0.8.7 65 | pillow==10.4.0 66 | platformdirs==4.2.2 67 | pooch==1.8.2 68 | protobuf==4.25.3 69 | psutil==6.0.0 70 | pycparser==2.22 71 | Pygments==2.18.0 72 | pyloudnorm==0.1.1 73 | pyparsing==3.1.2 74 | python-dateutil==2.9.0.post0 75 | pytorch-lightning==2.1.4 76 | PyYAML==6.0.1 77 | requests==2.32.3 78 | rich==13.7.1 79 | scikit-learn==1.5.1 80 | scipy==1.12.0 81 | sentry-sdk==2.9.0 82 | setproctitle==1.3.3 83 | six==1.16.0 84 | smmap==5.0.1 85 | soundfile==0.12.1 86 | soxr==0.3.7 87 | sympy==1.13.0 88 | tensorboard==2.17.0 89 | tensorboard-data-server==0.7.2 90 | tensorboardX==2.6.2.2 91 | threadpoolctl==3.5.0 92 | torch==2.2.0 93 | torchaudio==2.2.0 94 | torchmetrics==1.4.0.post0 95 | torchvision==0.17.0 96 | tqdm==4.66.4 97 | triton==2.2.0 98 | typeshed_client==2.5.1 99 | typing_extensions==4.12.2 100 | urllib3==2.2.2 101 | wandb==0.17.4 102 | Werkzeug==3.0.3 103 | yarl==1.9.4 104 | 
-------------------------------------------------------------------------------- /tests/test_sepremix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS 5 | from mst.modules import AdvancedMixConsole 6 | 7 | 8 | class Remixer(torch.nn.Module): 9 | def __init__(self, sample_rate: int) -> None: 10 | super().__init__() 11 | self.sample_rate = sample_rate 12 | 13 | # load source separation model 14 | bundle = HDEMUCS_HIGH_MUSDB_PLUS 15 | self.stem_separator = bundle.get_model() 16 | self.stem_separator.eval() 17 | # get sources list 18 | self.sources_list = list(self.stem_separator.sources) 19 | 20 | # load mix console 21 | self.mix_console = AdvancedMixConsole(sample_rate) 22 | 23 | def forward(self, x: torch.Tensor): 24 | """Take a tensor of mixes, separate, and then remix. 25 | 26 | Args: 27 | x (torch.Tensor): Tensor of mixes with shape (batch, 2, samples) 28 | 29 | Returns: 30 | remix (torch.Tensor): Tensor of remixes with shape (batch, 2, samples) 31 | sum_mix (torch.Tensor): Tensor of mixes from separeted outputs with shape (batch, 2, samples) 32 | track_params (torch.Tensor): Tensor of track params with shape (batch, 8, num_track_control_params) 33 | fx_bus_params (torch.Tensor): Tensor of fx bus params with shape (batch, num_fx_bus_control_params) 34 | master_bus_params (torch.Tensor): Tensor of master bus params with shape (batch, num_master_bus_control_params) 35 | """ 36 | bs, chs, seq_len = x.size() 37 | 38 | # separate 39 | sources = self.stem_separator(x) # bs, 4, 2, seq_len 40 | sum_mix = sources.sum(dim=1) # bs, 2, seq_len 41 | 42 | # convert sources to mono tracks 43 | tracks = sources.view(bs, 8, -1) 44 | 45 | # provide some headroom before mixing 46 | tracks *= 10 ** (-32.0 / 20.0) 47 | 48 | # generate random mix parameters 49 | track_params = torch.rand(bs, 8, self.mix_console.num_track_control_params) 50 | fx_bus_params = torch.rand(bs, self.mix_console.num_fx_bus_control_params) 51 | master_bus_params = torch.rand( 52 | bs, self.mix_console.num_master_bus_control_params 53 | ) 54 | 55 | # the forward expects params in range of (0,1) 56 | result = self.mix_console( 57 | tracks, 58 | track_params, 59 | fx_bus_params, 60 | master_bus_params, 61 | ) 62 | 63 | # get the remix 64 | remix = result[1] 65 | 66 | return remix, sum_mix, track_params, fx_bus_params, master_bus_params 67 | 68 | 69 | if __name__ == "__main__": 70 | # create the remixer 71 | remixer = Remixer(44100) 72 | 73 | # get a mix 74 | mix, sample_rate = torchaudio.load("outputs/output/ref_mix.wav") 75 | 76 | mix = mix.unsqueeze(0) 77 | 78 | # remix 79 | remix, sum_mix = remixer(mix) 80 | 81 | # peak normalize 82 | remix = remix / remix.abs().max() 83 | 84 | torchaudio.save("outputs/output/remix.wav", remix.squeeze(0), sample_rate=44100) 85 | torchaudio.save( 86 | "outputs/output/separated_sum_mix.wav", sum_mix.squeeze(0), sample_rate=44100 87 | ) 88 | -------------------------------------------------------------------------------- /data/instrument_name2id.json: -------------------------------------------------------------------------------- 1 | { 2 | "silence": 0, 3 | "accordion": 1, 4 | "acoustic guitar": 2, 5 | "alto saxophone": 3, 6 | "auxiliary percussion": 4, 7 | "bamboo flute": 5, 8 | "banjo": 6, 9 | "baritone saxophone": 7, 10 | "bass clarinet": 8, 11 | "bass drum": 9, 12 | "bassoon": 10, 13 | "bongo": 11, 14 | "brass section": 12, 15 | "cabasa": 13, 16 | "castanet": 
14, 17 | "cello": 15, 18 | "cello section": 16, 19 | "chimes": 17, 20 | "claps": 18, 21 | "clarinet": 19, 22 | "clarinet section": 20, 23 | "clean electric guitar": 21, 24 | "cornet": 22, 25 | "cowbell": 23, 26 | "cymbal": 24, 27 | "darbuka": 25, 28 | "distorted electric guitar": 26, 29 | "dizi": 27, 30 | "double bass": 28, 31 | "doumbek": 29, 32 | "drum machine": 30, 33 | "drum set": 31, 34 | "electric bass": 32, 35 | "electric piano": 33, 36 | "electronic organ": 34, 37 | "erhu": 35, 38 | "euphonium": 36, 39 | "female singer": 37, 40 | "flute": 38, 41 | "flute section": 39, 42 | "french horn": 40, 43 | "french horn section": 41, 44 | "fx/processed sound": 42, 45 | "glockenspiel": 43, 46 | "gong": 44, 47 | "gu": 45, 48 | "guiro": 46, 49 | "guzheng": 47, 50 | "harmonica": 48, 51 | "harp": 49, 52 | "high hat": 50, 53 | "horn section": 51, 54 | "kick drum": 52, 55 | "lap steel guitar": 53, 56 | "liuqin": 54, 57 | "male rapper": 55, 58 | "male singer": 56, 59 | "male speaker": 57, 60 | "mandolin": 58, 61 | "melodica": 59, 62 | "oboe": 60, 63 | "oud": 61, 64 | "piano": 62, 65 | "piccolo": 63, 66 | "sampler": 64, 67 | "scratches": 65, 68 | "shaker": 66, 69 | "sleigh bells": 68, 70 | "snare drum": 69, 71 | "soprano saxophone": 70, 72 | "string section": 71, 73 | "synthesizer": 72, 74 | "tabla": 73, 75 | "tack piano": 74, 76 | "tambourine": 75, 77 | "tenor saxophone": 76, 78 | "timpani": 77, 79 | "toms": 78, 80 | "trombone": 79, 81 | "trombone section": 80, 82 | "trumpet": 81, 83 | "trumpet section": 82, 84 | "tuba": 83, 85 | "vibraphone": 84, 86 | "viola": 85, 87 | "viola section": 86, 88 | "violin": 87, 89 | "violin section": 88, 90 | "vocalists": 89, 91 | "whistle": 90, 92 | "yangqin": 91, 93 | "zhongruan": 92, 94 | "Main system": 93, 95 | "backingvox": 94, 96 | "vox": 95, 97 | "synth": 96, 98 | "loop": 97, 99 | "percussion": 98, 100 | "kick": 99, 101 | "sfx": 100, 102 | "drum": 101, 103 | "overhead": 102, 104 | "snare": 103, 105 | "bass": 104, 106 | "conga": 105, 107 | "guitar": 106, 108 | "hi hat": 107, 109 | "clap": 108, 110 | "tom": 109, 111 | "misc": 110, 112 | "roommic": 111, 113 | "organ": 112, 114 | "keys": 113, 115 | "vocals": 114, 116 | "horn": 115, 117 | "bell": 116, 118 | "fiddle": 117, 119 | "string": 118, 120 | "stick": 119, 121 | "brass": 120, 122 | "choir": 121, 123 | "crash": 122, 124 | "rhodes": 123, 125 | "noise": 124, 126 | "hit": 125, 127 | "ukelele": 126, 128 | "bagpipe": 127, 129 | "saxophone": 128, 130 | "vocoder": 129, 131 | "marimba": 130, 132 | "cajon": 131, 133 | "soprano": 132, 134 | "alto": 133, 135 | "tenor": 134, 136 | "triangle": 135, 137 | "chorus": 136, 138 | " gong": 137, 139 | "ukulele": 138, 140 | "bongos": 139, 141 | "quartet": 140, 142 | "xylophone": 141, 143 | "woodwind": 142, 144 | "sitar": 143, 145 | "Main System": 144 146 | } -------------------------------------------------------------------------------- /tests/test_peq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import matplotlib.pyplot as plt 4 | from scipy.signal import savgol_filter 5 | from dasp_pytorch.functional import parametric_eq 6 | 7 | 8 | bs = 2 9 | chs = 2 10 | seq_len = 131072 11 | sample_rate = 44100 12 | 13 | x = torch.zeros(bs, chs, seq_len) 14 | x = x / x.abs().max().clamp(min=1e-8) 15 | x *= 10 ** (-24 / 20) 16 | 17 | low_shelf_gain_db = torch.tensor([0.0, 0.0]) 18 | low_shelf_cutoff_freq = torch.tensor([20.0, 20.0]) 19 | low_shelf_q_factor = torch.tensor([0.707, 0.707]) 20 | band0_gain_db = 
torch.tensor([0.0, 12.0]) 21 | band0_cutoff_freq = torch.tensor([100.0, 100.0]) 22 | band0_q_factor = torch.tensor([0.707, 0.707]) 23 | band1_gain_db = torch.tensor([12.0, 0.0]) 24 | band1_cutoff_freq = torch.tensor([1000.0, 1000.0]) 25 | band1_q_factor = torch.tensor([6.0, 0.707]) 26 | band2_gain_db = torch.tensor([-6.0, 0.0]) 27 | band2_cutoff_freq = torch.tensor([10000.0, 10000.0]) 28 | band2_q_factor = torch.tensor([3.0, 0.707]) 29 | band3_gain_db = torch.tensor([0.0, 0.0]) 30 | band3_cutoff_freq = torch.tensor([12000.0, 12000.0]) 31 | band3_q_factor = torch.tensor([0.707, 0.707]) 32 | high_shelf_gain_db = torch.tensor([0.0, 0.0]) 33 | high_shelf_cutoff_freq = torch.tensor([12000.0, 12000.0]) 34 | high_shelf_q_factor = torch.tensor([0.707, 0.707]) 35 | 36 | # reshape and repeat to match batch size 37 | low_shelf_gain_db = low_shelf_gain_db.view(1, chs).repeat(bs, 1) 38 | low_shelf_cutoff_freq = low_shelf_cutoff_freq.view(1, chs).repeat(bs, 1) 39 | low_shelf_q_factor = low_shelf_q_factor.view(1, chs).repeat(bs, 1) 40 | band0_gain_db = band0_gain_db.view(1, chs).repeat(bs, 1) 41 | band0_cutoff_freq = band0_cutoff_freq.view(1, chs).repeat(bs, 1) 42 | band0_q_factor = band0_q_factor.view(1, chs).repeat(bs, 1) 43 | band1_gain_db = band1_gain_db.view(1, chs).repeat(bs, 1) 44 | band1_cutoff_freq = band1_cutoff_freq.view(1, chs).repeat(bs, 1) 45 | band1_q_factor = band1_q_factor.view(1, chs).repeat(bs, 1) 46 | band2_gain_db = band2_gain_db.view(1, chs).repeat(bs, 1) 47 | band2_cutoff_freq = band2_cutoff_freq.view(1, chs).repeat(bs, 1) 48 | band2_q_factor = band2_q_factor.view(1, chs).repeat(bs, 1) 49 | band3_gain_db = band3_gain_db.view(1, chs).repeat(bs, 1) 50 | band3_cutoff_freq = band3_cutoff_freq.view(1, chs).repeat(bs, 1) 51 | band3_q_factor = band3_q_factor.view(1, chs).repeat(bs, 1) 52 | high_shelf_gain_db = high_shelf_gain_db.view(1, chs).repeat(bs, 1) 53 | high_shelf_cutoff_freq = high_shelf_cutoff_freq.view(1, chs).repeat(bs, 1) 54 | high_shelf_q_factor = high_shelf_q_factor.view(1, chs).repeat(bs, 1) 55 | 56 | y = parametric_eq( 57 | x, 58 | sample_rate, 59 | low_shelf_gain_db, 60 | low_shelf_cutoff_freq, 61 | low_shelf_q_factor, 62 | band0_gain_db, 63 | band0_cutoff_freq, 64 | band0_q_factor, 65 | band1_gain_db, 66 | band1_cutoff_freq, 67 | band1_q_factor, 68 | band2_gain_db, 69 | band2_cutoff_freq, 70 | band2_q_factor, 71 | band3_gain_db, 72 | band3_cutoff_freq, 73 | band3_q_factor, 74 | high_shelf_gain_db, 75 | high_shelf_cutoff_freq, 76 | high_shelf_q_factor, 77 | ) 78 | 79 | print(y) 80 | print(y.shape) 81 | 82 | fig, axs = plt.subplots(chs, 1, figsize=(10, 6)) 83 | for ch in range(chs): 84 | h_in = 20 * torch.log10(torch.fft.rfft(x[0, ch, :], dim=-1).abs() + 1e-8) 85 | h_out = 20 * torch.log10(torch.fft.rfft(y[0, ch, :], dim=-1).abs() + 1e-8) 86 | 87 | h_in_sm = savgol_filter(h_in.squeeze().numpy(), 255, 3) 88 | h_out_sm = savgol_filter(h_out.squeeze().numpy(), 255, 3) 89 | 90 | axs[ch].plot(h_out_sm - h_in_sm, label="input") 91 | axs[ch].set_xscale("log") 92 | axs[ch].legend() 93 | 94 | plt.savefig("test-peq.png", dpi=300) 95 | 96 | # torchaudio.save("test-peq-in.wav", x.view(1, -1), sample_rate) 97 | # torchaudio.save("test-peq-out.wav", y.view(1, -1), sample_rate) 98 | -------------------------------------------------------------------------------- /mst/callbacks/plotting.py: -------------------------------------------------------------------------------- 1 | import io 2 | import torch 3 | import librosa 4 | import PIL.Image 5 | import numpy as np 6 | import librosa.display 
7 | import matplotlib.pyplot as plt 8 | 9 | from typing import Any 10 | from torch.functional import Tensor 11 | from torchvision.transforms import ToTensor 12 | from sklearn.metrics import ConfusionMatrixDisplay 13 | 14 | 15 | def plot_spectrograms( 16 | input: torch.Tensor, 17 | target: torch.Tensor, 18 | estimate: torch.Tensor, 19 | n_fft: int = 4096, 20 | hop_length: int = 1024, 21 | sample_rate: float = 48000, 22 | filename: Any = None, 23 | ): 24 | """Create a side-by-side plot of the attention weights and the spectrogram. 25 | Args: 26 | input (torch.Tensor): Input audio tensor with shape [1 x samples]. 27 | target (torch.Tensor): Target audio tensor with shape [1 x samples]. 28 | estimate (torch.Tensor): Estimate of the target audio with shape [1 x samples]. 29 | n_fft (int, optional): Analysis FFT size. 30 | hop_length (int, optional): Analysis hop length. 31 | sample_rate (float, optional): Audio sample rate. 32 | filename (str, optional): If a filename is supplied, the plot is saved to disk. 33 | """ 34 | # use librosa to take stft 35 | x_stft = librosa.stft( 36 | input.view(-1).numpy(), 37 | n_fft=n_fft, 38 | hop_length=hop_length, 39 | ) 40 | x_D = librosa.amplitude_to_db( 41 | np.abs(x_stft), 42 | ref=np.max, 43 | ) 44 | 45 | y_stft = librosa.stft( 46 | target.view(-1).numpy(), 47 | n_fft=n_fft, 48 | hop_length=hop_length, 49 | ) 50 | y_D = librosa.amplitude_to_db( 51 | np.abs(y_stft), 52 | ref=np.max, 53 | ) 54 | 55 | y_hat_stft = librosa.stft( 56 | estimate.view(-1).numpy(), 57 | n_fft=n_fft, 58 | hop_length=hop_length, 59 | ) 60 | y_hat_D = librosa.amplitude_to_db( 61 | np.abs(y_hat_stft), 62 | ref=np.max, 63 | ) 64 | 65 | fig, axs = plt.subplots( 66 | nrows=3, 67 | sharex=True, 68 | figsize=(7, 6), 69 | ) 70 | 71 | x_img = librosa.display.specshow( 72 | x_D, 73 | y_axis="log", 74 | x_axis="time", 75 | sr=sample_rate, 76 | hop_length=hop_length, 77 | ax=axs[0], 78 | ) 79 | 80 | y_img = librosa.display.specshow( 81 | y_D, 82 | y_axis="log", 83 | x_axis="time", 84 | sr=sample_rate, 85 | hop_length=hop_length, 86 | ax=axs[1], 87 | ) 88 | 89 | y_hat_img = librosa.display.specshow( 90 | y_hat_D, 91 | y_axis="log", 92 | x_axis="time", 93 | sr=sample_rate, 94 | hop_length=hop_length, 95 | ax=axs[2], 96 | ) 97 | 98 | plt.tight_layout() 99 | 100 | if filename is not None: 101 | plt.savefig(filename, dpi=300) 102 | 103 | return fig2img(fig) 104 | 105 | 106 | def plot_confusion_matrix(e_hat, e, labels=None, filename=None): 107 | fig, ax = plt.subplots(figsize=(10, 10)) 108 | cm = ConfusionMatrixDisplay.from_predictions( 109 | e, 110 | e_hat, 111 | labels=np.arange(len(labels)), 112 | display_labels=labels, 113 | ) 114 | cm.plot(ax=ax, xticks_rotation="vertical") 115 | 116 | plt.tight_layout() 117 | if filename is not None: 118 | plt.savefig(filename, dpi=300) 119 | 120 | return fig2img(fig) 121 | 122 | 123 | def fig2img(fig, dpi=120): 124 | """Convert a matplotlib figure to JPEG to be show in Tensorboard.""" 125 | buf = io.BytesIO() 126 | fig.savefig(buf, format="jpeg", dpi=dpi) 127 | buf.seek(0) 128 | image = PIL.Image.open(buf) 129 | image = ToTensor()(image) 130 | plt.close("all") 131 | return image 132 | -------------------------------------------------------------------------------- /mst/callbacks/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import wandb 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | 6 | 7 | from mst.callbacks.plotting import plot_spectrograms 8 | 9 | 10 | class 
LogAudioCallback(pl.callbacks.Callback): 11 | def __init__( 12 | self, 13 | num_batches: int = 8, 14 | peak_normalize: bool = True, 15 | sample_rate: int = 44100, 16 | ): 17 | super().__init__() 18 | self.num_batches = num_batches 19 | self.peak_normalize = peak_normalize 20 | self.sample_rate = sample_rate 21 | 22 | def on_validation_batch_end( 23 | self, 24 | trainer, 25 | pl_module, 26 | outputs, 27 | batch, 28 | batch_idx, 29 | ): 30 | """Called when the validation batch ends.""" 31 | if outputs is not None: 32 | num_examples = outputs["ref_mix_a"].shape[0] 33 | if batch_idx < self.num_batches: 34 | for sample_idx in range(num_examples): 35 | self.log_audio( 36 | outputs, 37 | batch_idx, 38 | sample_idx, 39 | pl_module.mix_console.sample_rate, 40 | trainer.global_step, 41 | trainer.logger, 42 | f"Epoch {trainer.current_epoch}", 43 | ) 44 | 45 | def log_audio( 46 | self, 47 | outputs, 48 | batch_idx: int, 49 | sample_idx: int, 50 | sample_rate: int, 51 | global_step: int, 52 | logger, 53 | caption: str, 54 | n_fft: int = 4096, 55 | hop_length: int = 1024, 56 | ): 57 | audio_files = [] 58 | audio_keys = [] 59 | total_samples = 0 60 | # put all audio in file 61 | for key, audio in outputs.items(): 62 | if "dict" in key: # skip parameters 63 | continue 64 | 65 | x = audio[sample_idx, ...].float() 66 | x = x.permute(1, 0) 67 | # x /= x.abs().max() 68 | audio_files.append(x) 69 | audio_keys.append(key) 70 | total_samples += x.shape[0] 71 | 72 | y = torch.zeros(total_samples + int(len(audio_keys) * sample_rate), 2) 73 | name = f"{batch_idx}_{sample_idx}" 74 | start = 0 75 | for x, key in zip(audio_files, audio_keys): 76 | end = start + x.shape[0] 77 | y[start:end, :] = x 78 | start = end + int(sample_rate) 79 | name += key + "-" 80 | 81 | logger.experiment.log( 82 | { 83 | f"{name}": wandb.Audio( 84 | y.numpy(), 85 | caption=caption, 86 | sample_rate=int(sample_rate), 87 | ) 88 | } 89 | ) 90 | 91 | # now try to log parameters 92 | pred_track_param_dict = outputs["pred_track_param_dict"] 93 | ref_track_param_dict = outputs["ref_track_param_dict"] 94 | 95 | pred_fx_bus_param_dict = outputs["pred_fx_bus_param_dict"] 96 | ref_fx_bus_param_dict = outputs["ref_fx_bus_param_dict"] 97 | 98 | pred_master_bus_param_dict = outputs["pred_master_bus_param_dict"] 99 | ref_master_bus_param_dict = outputs["ref_master_bus_param_dict"] 100 | 101 | effect_names = list(pred_track_param_dict.keys()) 102 | 103 | column_names = None 104 | rows = [] 105 | for effect_name in effect_names: 106 | param_names = list(pred_track_param_dict[effect_name].keys()) 107 | for param_name in param_names: 108 | pred_param_val = pred_track_param_dict[effect_name][param_name] 109 | ref_param_val = ref_track_param_dict[effect_name][param_name] 110 | 111 | row = [] 112 | row_name = f"{effect_name}.{param_name}" 113 | row.append(row_name) 114 | 115 | if column_names is None: 116 | column_names = ["parameter"] 117 | for i in range(pred_param_val.shape[1]): 118 | column_names.append(f"{i}_pred") 119 | column_names.append(f"{i}_ref") 120 | # column_names.append("master_bus_pred") 121 | # column_names.append("master_bus_ref") 122 | 123 | for i in range(pred_param_val.shape[1]): 124 | row.append(pred_param_val[sample_idx, i].item()) 125 | row.append(ref_param_val[sample_idx, i].item()) 126 | 127 | # row.append(pred_master_bus_param_dict[effect_name][batch_idx].item()) 128 | 129 | rows.append(row) 130 | 131 | wandb_table = wandb.Table(data=rows, columns=column_names) 132 | logger.experiment.log( 133 | 
{f"batch={batch_idx}_sample={sample_idx}_parameters": wandb_table} 134 | ) 135 | -------------------------------------------------------------------------------- /scripts/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import torchaudio 5 | import pyloudnorm as pyln 6 | import pytorch_lightning as pl 7 | 8 | from mst.system import System 9 | from mst.utils import load_model 10 | 11 | # load a pretrained model and create a mix 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "config_path", 17 | type=str, 18 | help="Path to config.yaml for pretrained model checkpoint.", 19 | ) 20 | parser.add_argument( 21 | "ckpt_path", 22 | type=str, 23 | help="Path to pretrained model checkpoint.", 24 | ) 25 | parser.add_argument( 26 | "track_dir", 27 | type=str, 28 | help="Path to directory containing tracks.", 29 | ) 30 | parser.add_argument( 31 | "ref_mix", 32 | type=str, 33 | help="Path to reference mix.", 34 | ) 35 | parser.add_argument( 36 | "--output_dir", 37 | type=str, 38 | help="Path to output directory.", 39 | default="output", 40 | ) 41 | parser.add_argument( 42 | "--use_gpu", 43 | action="store_true", 44 | help="Whether to use GPU.", 45 | ) 46 | parser.add_argument( 47 | "--target_track_lufs_db", 48 | type=float, 49 | default=-32.0, 50 | ) 51 | args = parser.parse_args() 52 | 53 | # load model 54 | model = load_model( 55 | args.config_path, 56 | args.ckpt_path, 57 | map_location="gpu" if args.use_gpu else "cpu", 58 | ) 59 | sample_rate = 44100 60 | meter = pyln.Meter(sample_rate) 61 | 62 | print(f"Loaded model: {os.path.basename(args.ckpt_path)}\n") 63 | 64 | # load multitracks (wav files only) 65 | track_paths = [ 66 | os.path.join(args.track_dir, f) 67 | for f in os.listdir(args.track_dir) 68 | if ".wav" in f 69 | ] 70 | 71 | tracks = [] 72 | max_track_len = 0 73 | print("Loading tracks...") 74 | for idx, track_path in enumerate(track_paths): 75 | track, track_sr = torchaudio.load( 76 | track_path, frame_offset=262144, num_frames=262144 * 2 77 | ) 78 | if track_sr != sample_rate: 79 | track = torchaudio.functional.resample(track, track_sr, sample_rate) 80 | 81 | track = track[:, :262144] 82 | 83 | # loudness normalization 84 | track_lufs_db = meter.integrated_loudness(track.permute(1, 0).numpy()) 85 | delta_lufs_db = torch.tensor( 86 | [args.target_track_lufs_db - track_lufs_db] 87 | ).float() 88 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0) 89 | track = gain_lin * track 90 | 91 | tracks.append(track) 92 | print( 93 | f"({idx+1}/{len(track_paths)}): {os.path.basename(track_path)} {track.shape}" 94 | ) 95 | 96 | # correct length of tracks to be the same 97 | max_track_len = max([t.shape[1] for t in tracks]) 98 | for idx, track in enumerate(tracks): 99 | chs = track.shape[0] 100 | if track.shape[1] < max_track_len: 101 | pad = torch.zeros((chs, max_track_len - track.shape[1])) 102 | tracks[idx] = torch.cat([track, pad], dim=1) 103 | 104 | tracks = torch.cat(tracks, dim=0) 105 | 106 | # load reference track 107 | ref_mix, ref_sr = torchaudio.load(args.ref_mix) 108 | if ref_sr != sample_rate: 109 | ref_mix = torchaudio.functional.resample(ref_mix, ref_sr, sample_rate) 110 | ref_mix = ref_mix[:, :262144] 111 | print(f"\nLoaded reference mix: {os.path.basename(args.ref_mix)}.") 112 | print(f"tracks: {tracks.shape} ref_mix: {ref_mix.shape}\n") 113 | 114 | # create sum mix 115 | sum_mix = torch.sum(tracks, dim=0, keepdim=True) 116 
| 117 | # create mix with model 118 | with torch.no_grad(): 119 | result = model( 120 | tracks.unsqueeze(0), 121 | ref_mix.unsqueeze(0), 122 | use_track_gain=True, 123 | use_track_panner=True, 124 | use_track_eq=False, 125 | use_track_compressor=False, 126 | use_fx_bus=False, 127 | use_master_bus=False, 128 | ) 129 | ( 130 | mixed_tracks, 131 | mix, 132 | track_param_dict, 133 | fx_bus_param_dict, 134 | master_bus_param_dict, 135 | ) = result 136 | mix = mix.squeeze(0) 137 | 138 | mix /= torch.max(torch.abs(mix)) # peak normalize 139 | sum_mix /= torch.max(torch.abs(sum_mix)) # peak normalize 140 | 141 | # save mix 142 | os.makedirs(args.output_dir, exist_ok=True) 143 | torchaudio.save(os.path.join(args.output_dir, "pred_mix.wav"), mix, sample_rate) 144 | torchaudio.save(os.path.join(args.output_dir, "ref_mix.wav"), ref_mix, sample_rate) 145 | torchaudio.save(os.path.join(args.output_dir, "sum_mix.wav"), sum_mix, sample_rate) 146 | print(f"Saved mixes to {args.output_dir}.\n") 147 | -------------------------------------------------------------------------------- /scripts/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import torchaudio 5 | from tqdm import tqdm 6 | 7 | 8 | def process_mixing_secrets(): 9 | out_dir = "/import/c4dm-datasets-ext/mixing-secrets-mono/" 10 | # process LibriSpeech dataset to 48 kHz sample rate 11 | root_dir = "/import/c4dm-datasets-ext/mixing-secrets/" 12 | 13 | song_dirs = glob.glob(os.path.join(root_dir, "*")) 14 | 15 | for song_dir in song_dirs: 16 | song_out_dir = os.path.join(out_dir, os.path.basename(song_dir)) 17 | if os.path.exists(song_out_dir): 18 | continue 19 | os.makedirs(song_out_dir, exist_ok=True) 20 | song_track_dir = os.path.join(song_dir, "tracks") 21 | song_out_track_dir = os.path.join(song_out_dir, "tracks") 22 | os.makedirs(song_out_track_dir, exist_ok=True) 23 | print(song_track_dir) 24 | 25 | song_track_sub_dirs = glob.glob(os.path.join(song_track_dir, "*")) 26 | 27 | track_load_dir = None 28 | for song_track_sub_dir in song_track_sub_dirs: 29 | if "_Full" in song_track_sub_dir: 30 | track_load_dir = song_track_sub_dir 31 | 32 | if track_load_dir is None: 33 | track_load_dir = song_track_sub_dirs[0] 34 | 35 | track_filepaths = glob.glob(os.path.join(track_load_dir, "*.wav")) 36 | 37 | for filepath in tqdm(track_filepaths): 38 | out_filepath = os.path.join(song_out_track_dir, os.path.basename(filepath)) 39 | # convert the sample rate to 48 kHz using ffmpeg 40 | try: 41 | x, sr = torchaudio.load(filepath) 42 | except Exception as e: 43 | print(e) 44 | continue 45 | 46 | if sr != 48000: 47 | x = torchaudio.functional.resample(x, sr, 48000) 48 | 49 | if x.shape[0] == 2: 50 | torchaudio.save( 51 | out_filepath.replace(".wav", "_L.wav"), 52 | x[0:1, :], 53 | 48000, 54 | encoding="PCM_S", 55 | bits_per_sample=16, 56 | ) 57 | torchaudio.save( 58 | out_filepath.replace(".wav", "_R.wav"), 59 | x[1:2, :], 60 | 48000, 61 | encoding="PCM_S", 62 | bits_per_sample=16, 63 | ) 64 | else: 65 | torchaudio.save( 66 | out_filepath, 67 | x, 68 | 48000, 69 | encoding="PCM_S", 70 | bits_per_sample=16, 71 | ) 72 | 73 | 74 | def process_medleydb(): 75 | out_dir = "/import/c4dm-datasets-ext/medleydb-mono/" 76 | # process LibriSpeech dataset to 48 kHz sample rate 77 | root_dirs = [ 78 | "/import/c4dm-datasets/MedleyDB_V1/V1", 79 | "/import/c4dm-datasets/MedleyDB_V2/V2", 80 | ] 81 | 82 | for root_dir in root_dirs: 83 | song_dirs = glob.glob(os.path.join(root_dir, "*")) 84 | 
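        # each MedleyDB song folder contains a <song>_RAW directory of raw stems;
        # below we resample them to 48 kHz if needed and split stereo stems into _L/_R mono files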
85 | for song_dir in song_dirs: 86 | print(song_dir) 87 | song_out_dir = os.path.join(out_dir, os.path.basename(song_dir)) 88 | # if os.path.exists(song_out_dir): 89 | # continue 90 | os.makedirs(song_out_dir, exist_ok=True) 91 | song_track_dir = os.path.join(song_dir, f"{os.path.basename(song_dir)}_RAW") 92 | song_out_track_dir = os.path.join(song_out_dir, f"tracks") 93 | os.makedirs(song_out_track_dir, exist_ok=True) 94 | print(song_track_dir) 95 | 96 | track_filepaths = glob.glob(os.path.join(song_track_dir, "*.wav")) 97 | 98 | for filepath in tqdm(track_filepaths): 99 | out_filepath = os.path.join( 100 | song_out_track_dir, os.path.basename(filepath) 101 | ) 102 | # convert the sample rate to 48 kHz using ffmpeg 103 | try: 104 | x, sr = torchaudio.load(filepath) 105 | except Exception as e: 106 | print(e) 107 | continue 108 | 109 | if sr != 48000: 110 | x = torchaudio.functional.resample(x, sr, 48000) 111 | 112 | if x.shape[0] == 2: 113 | torchaudio.save( 114 | out_filepath.replace(".wav", "_L.wav"), 115 | x[0:1, :], 116 | 48000, 117 | encoding="PCM_S", 118 | bits_per_sample=16, 119 | ) 120 | torchaudio.save( 121 | out_filepath.replace(".wav", "_R.wav"), 122 | x[1:2, :], 123 | 48000, 124 | encoding="PCM_S", 125 | bits_per_sample=16, 126 | ) 127 | else: 128 | torchaudio.save( 129 | out_filepath, 130 | x, 131 | 48000, 132 | encoding="PCM_S", 133 | bits_per_sample=16, 134 | ) 135 | 136 | 137 | if __name__ == "__main__": 138 | dataset = "medleydb" 139 | 140 | if dataset == "mixing-secrets": 141 | process_mixing_secrets() 142 | elif dataset == "medleydb": 143 | process_medleydb() 144 | else: 145 | raise NotImplementedError 146 | -------------------------------------------------------------------------------- /mst/filter.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import warnings 4 | 5 | # https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/functional/functional.py#L495 6 | 7 | 8 | def _create_triangular_filterbank( 9 | all_freqs: torch.Tensor, 10 | f_pts: torch.Tensor, 11 | ) -> torch.Tensor: 12 | """Create a triangular filter bank. 13 | 14 | Args: 15 | all_freqs (Tensor): STFT freq points of size (`n_freqs`). 16 | f_pts (Tensor): Filter mid points of size (`n_filter`). 17 | 18 | Returns: 19 | fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`). 20 | """ 21 | # Adopted from Librosa 22 | # calculate the difference between each filter mid point and each stft freq point in hertz 23 | f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1) 24 | slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2) 25 | # create overlapping triangles 26 | zero = torch.zeros(1) 27 | down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter) 28 | up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter) 29 | fb = torch.max(zero, torch.min(down_slopes, up_slopes)) 30 | 31 | return fb 32 | 33 | 34 | # https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/prototype/functional/functional.py#L6 35 | 36 | 37 | def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float: 38 | r"""Convert Hz to Barks. 39 | 40 | Args: 41 | freqs (float): Frequencies in Hz 42 | bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. 
(Default: ``traunmuller``) 43 | 44 | Returns: 45 | barks (float): Frequency in Barks 46 | """ 47 | 48 | if bark_scale not in ["schroeder", "traunmuller", "wang"]: 49 | raise ValueError( 50 | 'bark_scale should be one of "schroeder", "traunmuller" or "wang".' 51 | ) 52 | 53 | if bark_scale == "wang": 54 | return 6.0 * math.asinh(freqs / 600.0) 55 | elif bark_scale == "schroeder": 56 | return 7.0 * math.asinh(freqs / 650.0) 57 | # Traunmuller Bark scale 58 | barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53 59 | # Bark value correction 60 | if barks < 2: 61 | barks += 0.15 * (2 - barks) 62 | elif barks > 20.1: 63 | barks += 0.22 * (barks - 20.1) 64 | 65 | return barks 66 | 67 | 68 | def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor: 69 | """Convert bark bin numbers to frequencies. 70 | 71 | Args: 72 | barks (torch.Tensor): Bark frequencies 73 | bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``) 74 | 75 | Returns: 76 | freqs (torch.Tensor): Barks converted in Hz 77 | """ 78 | 79 | if bark_scale not in ["schroeder", "traunmuller", "wang"]: 80 | raise ValueError( 81 | 'bark_scale should be one of "traunmuller", "schroeder" or "wang".' 82 | ) 83 | 84 | if bark_scale == "wang": 85 | return 600.0 * torch.sinh(barks / 6.0) 86 | elif bark_scale == "schroeder": 87 | return 650.0 * torch.sinh(barks / 7.0) 88 | # Bark value correction 89 | if any(barks < 2): 90 | idx = barks < 2 91 | barks[idx] = (barks[idx] - 0.3) / 0.85 92 | elif any(barks > 20.1): 93 | idx = barks > 20.1 94 | barks[idx] = (barks[idx] + 4.422) / 1.22 95 | 96 | # Traunmuller Bark scale 97 | freqs = 1960 * ((barks + 0.53) / (26.28 - barks)) 98 | 99 | return freqs 100 | 101 | 102 | def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12): 103 | a440 = 440.0 * 2.0 ** (tuning / bins_per_octave) 104 | return torch.log2(freqs / (a440 / 16)) 105 | 106 | 107 | def barkscale_fbanks( 108 | n_freqs: int, 109 | f_min: float, 110 | f_max: float, 111 | n_barks: int, 112 | sample_rate: int, 113 | bark_scale: str = "traunmuller", 114 | ) -> torch.Tensor: 115 | r"""Create a frequency bin conversion matrix. 116 | 117 | .. devices:: CPU 118 | 119 | .. properties:: TorchScript 120 | 121 | .. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png 122 | :alt: Visualization of generated filter bank 123 | 124 | Args: 125 | n_freqs (int): Number of frequencies to highlight/apply 126 | f_min (float): Minimum frequency (Hz) 127 | f_max (float): Maximum frequency (Hz) 128 | n_barks (int): Number of mel filterbanks 129 | sample_rate (int): Sample rate of the audio waveform 130 | bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``) 131 | 132 | Returns: 133 | torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``) 134 | meaning number of frequencies to highlight/apply to x the number of filterbanks. 135 | Each column is a filterbank so that assuming there is a matrix A of 136 | size (..., ``n_freqs``), the applied result would be 137 | ``A * barkscale_fbanks(A.size(-1), ...)``. 
138 | 139 | """ 140 | 141 | # freq bins 142 | all_freqs = torch.linspace(0, sample_rate // 2, n_freqs) 143 | 144 | # calculate bark freq bins 145 | m_min = _hz_to_bark(f_min, bark_scale=bark_scale) 146 | m_max = _hz_to_bark(f_max, bark_scale=bark_scale) 147 | 148 | m_pts = torch.linspace(m_min, m_max, n_barks + 2) 149 | f_pts = _bark_to_hz(m_pts, bark_scale=bark_scale) 150 | 151 | # create filterbank 152 | fb = _create_triangular_filterbank(all_freqs, f_pts) 153 | 154 | if (fb.max(dim=0).values == 0.0).any(): 155 | warnings.warn( 156 | "At least one bark filterbank has all zero values. " 157 | f"The value for `n_barks` ({n_barks}) may be set too high. " 158 | f"Or, the value for `n_freqs` ({n_freqs}) may be set too low." 159 | ) 160 | 161 | return fb 162 | -------------------------------------------------------------------------------- /mst/param_system.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import itertools 4 | import pytorch_lightning as pl 5 | 6 | from typing import Callable 7 | from mst.utils import batch_stereo_peak_normalize 8 | 9 | import warnings 10 | 11 | warnings.filterwarnings( 12 | "ignore" 13 | ) # fix this later to catch warnings about reading mp3 files 14 | 15 | 16 | class ParameterEstimationSystem(pl.LightningModule): 17 | def __init__( 18 | self, 19 | encoder: torch.nn.Module, 20 | projector: torch.nn.Module, 21 | mix_console: torch.nn.Module, 22 | remixer: torch.nn.Module, 23 | schedule: str = "step", 24 | lr: float = 3e-4, 25 | max_epochs: int = 500, 26 | **kwargs, 27 | ) -> None: 28 | super().__init__() 29 | self.encoder = encoder 30 | self.mix_console = mix_console 31 | self.remixer = remixer 32 | self.projector = projector 33 | self.save_hyperparameters( 34 | ignore=["encoder", "mix_console", "remixer", "projector"] 35 | ) 36 | 37 | def forward( 38 | self, 39 | input_mix: torch.Tensor, 40 | output_mix: torch.Tensor, 41 | ) -> torch.Tensor: 42 | # could consider masking different parts of input and output 43 | # so the model cannot rely on perfectly aligned inputs 44 | 45 | z_in_left = self.encoder(input_mix[:, 0:1, :]) 46 | z_in_right = self.encoder(input_mix[:, 1:2, :]) 47 | 48 | z_out_left = self.encoder(output_mix[:, 0:1, :]) 49 | z_out_right = self.encoder(output_mix[:, 1:2, :]) 50 | 51 | # take difference between embeddings 52 | z_diff_left = z_out_left - z_in_left 53 | z_diff_right = z_out_right - z_in_right 54 | 55 | z_diff = torch.cat([z_diff_left, z_diff_right], dim=-1) 56 | 57 | # project to parameter space 58 | track_params, fx_bus_params, master_bus_params = self.projector(z_diff) 59 | 60 | return track_params, fx_bus_params, master_bus_params 61 | 62 | def common_step( 63 | self, 64 | batch: tuple, 65 | batch_idx: int, 66 | optimizer_idx: int = 0, 67 | train: bool = False, 68 | ): 69 | """Model step used for validation and training. 70 | Args: 71 | batch (Tuple[Tensor, Tensor]): Batch items containing rmix, stems and orig mix 72 | batch_idx (int): Index of the batch within the current epoch. 73 | optimizer_idx (int): Index of the optimizer, this step is called once for each optimizer. 74 | train (bool): Wether step is called during training (True) or validation (False). 
75 | """ 76 | input_mix = batch 77 | 78 | # create a remix 79 | output_mix, track_params, fx_bus_params, master_bus_params = self.remixer( 80 | input_mix, self.mix_console 81 | ) 82 | 83 | # estimate parameters 84 | track_params_hat, fx_bus_params_hat, master_bus_params_hat = self( 85 | input_mix, output_mix 86 | ) 87 | 88 | # calculate loss 89 | track_params_loss = torch.nn.functional.mse_loss( 90 | track_params_hat, 91 | track_params, 92 | ) 93 | fx_bus_params_loss = torch.nn.functional.mse_loss( 94 | fx_bus_params_hat, 95 | fx_bus_params, 96 | ) 97 | master_bus_params_loss = torch.nn.functional.mse_loss( 98 | master_bus_params_hat, 99 | master_bus_params, 100 | ) 101 | 102 | # scale by number of parameters 103 | track_params_loss *= track_params.shape[-1] + track_params.shape[-2] 104 | fx_bus_params_loss *= fx_bus_params.shape[-1] 105 | master_bus_params_loss *= master_bus_params.shape[-1] 106 | 107 | loss = track_params_loss + fx_bus_params_loss + master_bus_params_loss 108 | 109 | # log the losses 110 | self.log( 111 | ("train" if train else "val") + "/track_param_loss", 112 | track_params_loss, 113 | on_step=True, 114 | on_epoch=True, 115 | prog_bar=True, 116 | logger=True, 117 | sync_dist=True, 118 | ) 119 | 120 | self.log( 121 | ("train" if train else "val") + "/fx_bus_param_loss", 122 | fx_bus_params_loss, 123 | on_step=True, 124 | on_epoch=True, 125 | prog_bar=True, 126 | logger=True, 127 | sync_dist=True, 128 | ) 129 | 130 | self.log( 131 | ("train" if train else "val") + "/master_bus_param_loss", 132 | master_bus_params_loss, 133 | on_step=True, 134 | on_epoch=True, 135 | prog_bar=True, 136 | logger=True, 137 | sync_dist=True, 138 | ) 139 | 140 | # log the losses 141 | self.log( 142 | ("train" if train else "val") + "/loss", 143 | loss, 144 | on_step=True, 145 | on_epoch=True, 146 | prog_bar=True, 147 | logger=True, 148 | sync_dist=True, 149 | ) 150 | 151 | return loss 152 | 153 | def training_step(self, batch, batch_idx, optimizer_idx=0): 154 | loss = self.common_step(batch, batch_idx, train=True) 155 | return loss 156 | 157 | def validation_step(self, batch, batch_idx): 158 | loss = self.common_step(batch, batch_idx, train=False) 159 | return loss 160 | 161 | def configure_optimizers(self): 162 | optimizer = torch.optim.Adam( 163 | itertools.chain(self.encoder.parameters(), self.projector.parameters()), 164 | lr=self.hparams.lr, 165 | betas=(0.9, 0.999), 166 | ) 167 | 168 | if self.hparams.schedule == "cosine": 169 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 170 | optimizer, T_max=self.hparams.max_epochs 171 | ) 172 | elif self.hparams.schedule == "step": 173 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 174 | optimizer, 175 | [ 176 | int(self.hparams.max_epochs * 0.85), 177 | int(self.hparams.max_epochs * 0.95), 178 | ], 179 | ) 180 | else: 181 | return optimizer 182 | lr_schedulers = {"scheduler": scheduler, "interval": "epoch", "frequency": 1} 183 | 184 | return [optimizer], lr_schedulers 185 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 | # Diff-MST: Differentiable Mixing Style Transfer 5 | [Paper](https://sai-soum.github.io/assets/pdf/Differentiable_Mixing_Style_Transfer.pdf) | [Website](https://sai-soum.github.io/projects/diffmst/) | [Video](https://youtu.be/w90RGZ3IqQw) 6 | 7 | 8 | 9 | 10 |
11 | 12 | 20 | # Repository Structure 21 | 1. `configs` - Contains configuration files for training and inference. 22 | 2. `mst` - Contains the main codebase for the project. 23 | - `dataloaders` - Contains dataloaders for the project. 24 | - `modules` - Contains the modules for different components of the system. 25 | - `mixing` - Contains the mixing modules for creating mixes. 26 | - `loss` - Contains the loss functions for the project. 27 | - `panns` - Contains the basic encoder building blocks, such as Cnn14 and ResNet. 28 | - `utils` - Contains utility functions for the project. 29 | 3. `scripts` - Contains scripts for running inference. 30 | 31 | # Setup 32 | - Clone the repository 33 | ``` 34 | git clone https://github.com/sai-soum/Diff-MST.git 35 | cd Diff-MST 36 | ``` 37 | 38 | - Create a new Python environment 39 | ``` 40 | # for Linux/macOS 41 | python3 -m venv env 42 | source env/bin/activate 43 | ``` 44 | 45 | - Install the `mst` package from source 46 | ``` 47 | # Install as editable (for development) 48 | pip install -e . 49 | 50 | # Alternatively, do a regular install (read-only) 51 | pip install . 52 | ``` 53 | 54 | # Usage 55 | ## Train 56 | We use [LightningCLI](https://lightning.ai/docs/pytorch/stable/) for training and [Wandb](https://wandb.ai/site) for logging. 57 | 58 | ### Setup 59 | In the `configs` directory, you will find the configuration files for the project. 60 | - `config.yaml` - Contains the general configuration for the project. 61 | - `optimizer.yaml` - Contains the optimizer configuration for the project. 62 | - `data/` - Contains the data configuration for the project. 63 | - `models/` - Contains the model configuration for the project. 64 | We have provided instructions within the configuration files for setting up the project. 65 | 66 | A few important configuration parameters: 67 | - In `configs/data/`, change the following: 68 | - `track_root_dirs` - The root directory for the dataset needs to be set up. You can pass multiple dataset directories as a list; however, you will also need to provide corresponding metadata YAML files containing the train, test, and val splits. Check the `data/` directory for examples. 69 | - For Method 1: set `generate_mix` to `True` in the model configuration file. Use `medley+cambridge-8.yaml` for training with random mixes of the same song as reference. 70 | - For Method 2: set `generate_mix` to `False` in the model configuration file. Use `medley+cambridge+jamendo-8.yaml` for training with real unpaired songs as reference. 71 | - Update `mix_root_dirs` - The root directory for the mix dataset. This is used when training with real unpaired songs as reference. 72 | - You may benefit from setting smaller values for `train_buffer_size_gb` and `val_buffer_size_gb` in the data configuration file when first testing the code. 73 | - In `configs/models/`: 74 | - You can disable an audio effect by setting a very large value for the corresponding parameter. For example, to disable the compressor, set `active_compressor_epoch` to `1000`. 75 | - You can change the loss function used for training by setting the `loss` parameter. 76 | - In `optimizer.yaml`, you can change the learning rate parameters. 77 | - In `config.yaml`: 78 | - Update the directory for logging using `save_dir` under `trainer`. 79 | - You can use `ckpt_path` to load a pre-trained model for fine-tuning, resuming training, or testing. 80 | 81 | 82 | ### Method 1: Training with random mixes of the same song as reference using MRSTFT loss. 
83 | ``` 84 | CUDA_VISIBLE_DEVICES=0 python main.py fit \ 85 | -c configs/config.yaml \ 86 | -c configs/optimizer.yaml \ 87 | -c configs/data/medley+cambridge-8.yaml \ 88 | -c configs/models/naive.yaml 89 | ``` 90 | 91 | To run fine-tuning using AFLoss: 92 | ``` 93 | CUDA_VISIBLE_DEVICES=0 python main.py fit \ 94 | -c configs/config.yaml \ 95 | -c configs/optimizer.yaml \ 96 | -c configs/data/medley+cambridge-8.yaml \ 97 | -c configs/models/naive+feat.yaml 98 | ``` 99 | 100 | You can change the number of tracks, the size of the training data for an epoch, and the batch size in the data configuration file located in `configs/data/`. 101 | 102 | ### Method 2: Training with real unpaired songs as reference using AFLoss. 103 | 104 | ``` 105 | CUDA_VISIBLE_DEVICES=0 python main.py fit \ 106 | -c configs/config.yaml \ 107 | -c configs/optimizer.yaml \ 108 | -c configs/data/medley+cambridge+jamendo-8.yaml \ 109 | -c configs/models/unpaired+feat.yaml 110 | ``` 111 | 112 | ## Inference 113 | To evaluate the model on real-world data, run the `scripts/eval_all_combo.py` script. 114 | 115 | Update the model checkpoint paths and the inference examples directory in the script. To mix a single song with a pretrained checkpoint, see the `scripts/run.py` example sketched at the end of this README. 116 | 117 | `Python 3.10` was used for training. 118 | 119 | 120 | ## Acknowledgements 121 | This work is funded and supported by UK Research and Innovation [grant number EP/S022694/1] and Steinberg Media Technologies GmbH under the AI and Music Centre for Doctoral Training (AIM-CDT) at the Centre for Digital Music, Queen Mary University of London, London, UK. 122 | 123 | ## Citation 124 | If you find this work useful, please consider citing our paper: 125 | ``` 126 | @inproceedings{vanka2024diffmst, 127 | title={Diff-MST: Differentiable Mixing Style Transfer}, 128 | author={Vanka, Soumya and Steinmetz, Christian and Rolland, Jean-Baptiste and Reiss, Joshua and Fazekas, Gy{\"o}rgy}, 129 | booktitle={Proc. of the 25th Int. Society for Music Information Retrieval Conf. (ISMIR)}, 130 | year={2024}, 131 | organization={Int. Society for Music Information Retrieval (ISMIR)}, 132 | abbr = {ISMIR}, 133 | address = {San Francisco, USA}, 134 | } 135 | ``` 136 | 137 | ## License 138 | The code is licensed under the terms of the CC-BY-NC-SA 4.0 license. For a human-readable summary of the license, see https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en . 
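## Example: mixing a single song with `scripts/run.py`
As a minimal sketch of the single-song inference path referenced above, the invocation below loads a pretrained checkpoint, reads the `.wav` tracks in a multitrack directory, and renders a mix guided by the reference. All paths are placeholders and should be replaced with your own config, checkpoint, track directory, and reference mix.
```
python scripts/run.py \
    path/to/config.yaml \
    path/to/checkpoint.ckpt \
    path/to/track_dir \
    path/to/reference_mix.wav \
    --output_dir output \
    --use_gpu
```
The script writes `pred_mix.wav`, `ref_mix.wav`, and `sum_mix.wav` to the output directory.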
139 | -------------------------------------------------------------------------------- /mst/panns.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py 2 | # Under MIT License 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from typing import List 7 | 8 | # from torchlibrosa.stft import Spectrogram, LogmelFilterBank 9 | # from torchlibrosa.augmentation import SpecAugmentation 10 | 11 | 12 | def init_layer(layer): 13 | """Initialize a Linear or Convolutional layer.""" 14 | nn.init.xavier_uniform_(layer.weight) 15 | 16 | if hasattr(layer, "bias"): 17 | if layer.bias is not None: 18 | layer.bias.data.fill_(0.0) 19 | 20 | 21 | def init_bn(bn): 22 | """Initialize a Batchnorm layer.""" 23 | bn.bias.data.fill_(0.0) 24 | bn.weight.data.fill_(1.0) 25 | 26 | 27 | class ConvBlock(nn.Module): 28 | def __init__(self, in_channels, out_channels, use_batchnorm: bool = True, pool_type: str = 'avg'): 29 | super(ConvBlock, self).__init__() 30 | self.use_batchnorm = use_batchnorm 31 | 32 | self.conv1 = nn.Conv2d( 33 | in_channels=in_channels, 34 | out_channels=out_channels, 35 | kernel_size=(3, 3), 36 | stride=(1, 1), 37 | padding=(1, 1), 38 | bias=False, 39 | ) 40 | 41 | self.conv2 = nn.Conv2d( 42 | in_channels=out_channels, 43 | out_channels=out_channels, 44 | kernel_size=(3, 3), 45 | stride=(1, 1), 46 | padding=(1, 1), 47 | bias=False, 48 | ) 49 | 50 | if use_batchnorm: 51 | self.bn1 = nn.BatchNorm2d(out_channels) 52 | self.bn2 = nn.BatchNorm2d(out_channels) 53 | else: 54 | self.bn1 = nn.Identity() 55 | self.bn2 = nn.Identity() 56 | 57 | if pool_type == "max": 58 | self.pool_fn = F.max_pool2d 59 | elif pool_type == "avg": 60 | self.pool_fn = F.avg_pool2d 61 | elif pool_type == "avg+max": 62 | def pool_avg_max(x: torch.Tensor, kernel_size: List[int]): 63 | return F.avg_pool2d(x, kernel_size) + F.max_pool2d(x, kernel_size) 64 | self.pool_fn = pool_avg_max 65 | else: 66 | raise Exception("Incorrect argument for `pool_type`!") 67 | 68 | 69 | self.init_weight() 70 | 71 | def init_weight(self): 72 | init_layer(self.conv1) 73 | init_layer(self.conv2) 74 | 75 | if self.use_batchnorm: 76 | init_bn(self.bn1) 77 | init_bn(self.bn2) 78 | 79 | def forward(self, input: torch.Tensor, pool_size: List[int]): 80 | x = input 81 | x = F.relu_(self.bn1(self.conv1(x))) 82 | x = F.relu_(self.bn2(self.conv2(x))) 83 | x = self.pool_fn(x, pool_size) 84 | 85 | return x 86 | 87 | 88 | class ConvBlock5x5(nn.Module): 89 | def __init__(self, in_channels, out_channels): 90 | super(ConvBlock5x5, self).__init__() 91 | 92 | self.conv1 = nn.Conv2d( 93 | in_channels=in_channels, 94 | out_channels=out_channels, 95 | kernel_size=(5, 5), 96 | stride=(1, 1), 97 | padding=(2, 2), 98 | bias=False, 99 | ) 100 | 101 | self.bn1 = nn.BatchNorm2d(out_channels) 102 | 103 | self.init_weight() 104 | 105 | def init_weight(self): 106 | init_layer(self.conv1) 107 | init_bn(self.bn1) 108 | 109 | def forward(self, input, pool_size=(2, 2), pool_type="avg"): 110 | x = input 111 | x = F.relu_(self.bn1(self.conv1(x))) 112 | if pool_type == "max": 113 | x = F.max_pool2d(x, kernel_size=pool_size) 114 | elif pool_type == "avg": 115 | x = F.avg_pool2d(x, kernel_size=pool_size) 116 | elif pool_type == "avg+max": 117 | x1 = F.avg_pool2d(x, kernel_size=pool_size) 118 | x2 = F.max_pool2d(x, kernel_size=pool_size) 119 | x = x1 + x2 120 | else: 121 | raise Exception("Incorrect argument!") 122 | 123 | return x 
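# --- Usage sketch (illustrative; not part of the original file) ---
# ConvBlock applies two 3x3 convolutions (each followed by optional batch norm
# and a ReLU) and then pools by `pool_size` with the pooling mode chosen at
# construction time. Assuming a (batch, channels, bins, frames) spectrogram:
#
#   block = ConvBlock(in_channels=1, out_channels=64, pool_type="avg")
#   x = torch.randn(2, 1, 128, 64)      # (bs, chs, bins, frames)
#   y = block(x, pool_size=[2, 2])      # -> (2, 64, 64, 32)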
124 | 125 | 126 | class Cnn14(nn.Module): 127 | def __init__( 128 | self, 129 | num_classes: int, 130 | n_inputs: int = 1, 131 | use_batchnorm: bool = True, 132 | ): 133 | super(Cnn14, self).__init__() 134 | 135 | self.conv_block1 = ConvBlock( 136 | in_channels=n_inputs, 137 | out_channels=64, 138 | use_batchnorm=use_batchnorm, 139 | ) 140 | self.conv_block2 = ConvBlock( 141 | in_channels=64, 142 | out_channels=128, 143 | use_batchnorm=use_batchnorm, 144 | ) 145 | self.conv_block3 = ConvBlock( 146 | in_channels=128, 147 | out_channels=256, 148 | use_batchnorm=use_batchnorm, 149 | ) 150 | self.conv_block4 = ConvBlock( 151 | in_channels=256, 152 | out_channels=512, 153 | use_batchnorm=use_batchnorm, 154 | ) 155 | self.conv_block5 = ConvBlock( 156 | in_channels=512, 157 | out_channels=1024, 158 | use_batchnorm=use_batchnorm, 159 | ) 160 | self.conv_block6 = ConvBlock( 161 | in_channels=1024, 162 | out_channels=2048, 163 | use_batchnorm=use_batchnorm, 164 | ) 165 | 166 | self.fc = nn.Linear(2048, num_classes, bias=True) 167 | self.init_weight() 168 | 169 | def init_weight(self): 170 | # init_bn(self.bn0) 171 | init_layer(self.fc) 172 | 173 | def forward(self, x: torch.Tensor): 174 | """ 175 | input (torch.Tensor): Spectrogram tensor with shape (bs, chs, bins, frames) 176 | """ 177 | batch_size, chs, bins, frames = x.size() 178 | 179 | # x = x.view(batch_size, -1) 180 | # x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins) 181 | # x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) 182 | # x = x.transpose(1, 3) 183 | # x = self.bn0(x) 184 | # x = x.transpose(1, 3) 185 | # if self.training: 186 | # x = self.spec_augmenter(x) 187 | 188 | x = self.conv_block1(x, pool_size=(2, 2)) 189 | # x = F.dropout(x, p=0.2, training=self.training) 190 | x = self.conv_block2(x, pool_size=(4, 4)) 191 | # x = F.dropout(x, p=0.2, training=self.training) 192 | x = self.conv_block3(x, pool_size=(4, 2)) 193 | # x = F.dropout(x, p=0.2, training=self.training) 194 | x = self.conv_block4(x, pool_size=(4, 2)) 195 | # x = F.dropout(x, p=0.2, training=self.training) 196 | x = self.conv_block5(x, pool_size=(4, 2)) 197 | # x = F.dropout(x, p=0.2, training=self.training) 198 | x = self.conv_block6(x, pool_size=(2, 2)) 199 | # x = F.dropout(x, p=0.2, training=self.training) 200 | x = torch.mean(x, dim=2) # mean across stft bins 201 | 202 | (x1, _) = torch.max(x, dim=2) 203 | x2 = torch.mean(x, dim=2) 204 | x = x1 + x2 205 | # x = F.dropout(x, p=0.5, training=self.training) 206 | x_out = self.fc(x) 207 | clipwise_output = x_out 208 | 209 | return clipwise_output 210 | -------------------------------------------------------------------------------- /mst/loss.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import librosa 4 | 5 | from typing import List 6 | from mst.filter import barkscale_fbanks 7 | 8 | 9 | def compute_mid_side(x: torch.Tensor): 10 | x_mid = x[:, 0, :] + x[:, 1, :] 11 | x_side = x[:, 0, :] - x[:, 1, :] 12 | return x_mid, x_side 13 | 14 | 15 | from mst.filter import barkscale_fbanks 16 | 17 | import yaml 18 | from mst.fx_encoder import FXencoder 19 | 20 | from mst.modules import SpectrogramEncoder 21 | 22 | 23 | def compute_mid_side(x: torch.Tensor): 24 | x_mid = x[:, 0, :] + x[:, 1, :] 25 | x_side = x[:, 0, :] - x[:, 1, :] 26 | return x_mid, x_side 27 | 28 | 29 | def compute_melspectrum( 30 | x: torch.Tensor, 31 | sample_rate: int = 44100, 32 | fft_size: int = 32768, 33 | n_bins: int = 128, 34 | 
**kwargs, 35 | ): 36 | """Compute mel-spectrogram. 37 | 38 | Args: 39 | x: (bs, 2, seq_len) 40 | sample_rate: sample rate of audio 41 | fft_size: size of fft 42 | n_bins: number of mel bins 43 | 44 | Returns: 45 | X: (bs, n_bins) 46 | 47 | """ 48 | fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) 49 | fb = torch.tensor(fb).unsqueeze(0).type_as(x) 50 | 51 | x = x.mean(dim=1, keepdim=True) 52 | X = torch.fft.rfft(x, n=fft_size, dim=-1) 53 | X = torch.abs(X) 54 | X = torch.mean(X, dim=1, keepdim=True) # take mean over time 55 | X = X.permute(0, 2, 1) # swap time and freq dims 56 | X = torch.matmul(fb, X) 57 | X = torch.log(X + 1e-8) 58 | 59 | return X 60 | 61 | 62 | def compute_barkspectrum( 63 | x: torch.Tensor, 64 | fft_size: int = 32768, 65 | n_bands: int = 24, 66 | sample_rate: int = 44100, 67 | f_min: float = 20.0, 68 | f_max: float = 20000.0, 69 | mode: str = "mid-side", 70 | **kwargs, 71 | ): 72 | """Compute bark-spectrogram. 73 | 74 | Args: 75 | x: (bs, 2, seq_len) 76 | fft_size: size of fft 77 | n_bands: number of bark bins 78 | sample_rate: sample rate of audio 79 | f_min: minimum frequency 80 | f_max: maximum frequency 81 | mode: "mono", "stereo", or "mid-side" 82 | 83 | Returns: 84 | X: (bs, 24) 85 | 86 | """ 87 | # compute filterbank 88 | fb = barkscale_fbanks((fft_size // 2) + 1, f_min, f_max, n_bands, sample_rate) 89 | fb = fb.unsqueeze(0).type_as(x) 90 | fb = fb.permute(0, 2, 1) 91 | 92 | if mode == "mono": 93 | x = x.mean(dim=1) # average over channels 94 | signals = [x] 95 | elif mode == "stereo": 96 | signals = [x[:, 0, :], x[:, 1, :]] 97 | elif mode == "mid-side": 98 | x_mid = x[:, 0, :] + x[:, 1, :] 99 | x_side = x[:, 0, :] - x[:, 1, :] 100 | signals = [x_mid, x_side] 101 | else: 102 | raise ValueError(f"Invalid mode {mode}") 103 | 104 | outputs = [] 105 | for signal in signals: 106 | X = torch.stft( 107 | signal, 108 | n_fft=fft_size, 109 | hop_length=fft_size // 4, 110 | return_complex=True, 111 | window=torch.hann_window(fft_size).to(x.device), 112 | ) # compute stft 113 | X = torch.abs(X) # take magnitude 114 | X = torch.mean(X, dim=-1, keepdim=True) # take mean over time 115 | # X = X.permute(0, 2, 1) # swap time and freq dims 116 | X = torch.matmul(fb, X) # apply filterbank 117 | X = torch.log(X + 1e-8) 118 | # X = torch.cat([X, X_log], dim=-1) 119 | outputs.append(X) 120 | 121 | # stack into tensor 122 | X = torch.cat(outputs, dim=-1) 123 | 124 | return X 125 | 126 | 127 | def compute_rms(x: torch.Tensor, **kwargs): 128 | """Compute root mean square energy. 129 | 130 | Args: 131 | x: (bs, 1, seq_len) 132 | 133 | Returns: 134 | rms: (bs, ) 135 | """ 136 | rms = torch.sqrt(torch.mean(x**2, dim=-1).clamp(min=1e-8)) 137 | return rms 138 | 139 | 140 | def compute_crest_factor(x: torch.Tensor, **kwargs): 141 | """Compute crest factor as ratio of peak to rms energy in dB. 142 | 143 | Args: 144 | x: (bs, 2, seq_len) 145 | 146 | """ 147 | num = torch.max(torch.abs(x), dim=-1)[0] 148 | den = compute_rms(x).clamp(min=1e-8) 149 | cf = 20 * torch.log10((num / den).clamp(min=1e-8)) 150 | return cf 151 | 152 | 153 | def compute_stereo_width(x: torch.Tensor, **kwargs): 154 | """Compute stereo width as ratio of energy in sum and difference signals. 
155 | 156 | Args: 157 | x: (bs, 2, seq_len) 158 | 159 | """ 160 | bs, chs, seq_len = x.size() 161 | 162 | assert chs == 2, "Input must be stereo" 163 | 164 | # compute sum and diff of stereo channels 165 | x_sum = x[:, 0, :] + x[:, 1, :] 166 | x_diff = x[:, 0, :] - x[:, 1, :] 167 | 168 | # compute power of sum and diff 169 | sum_energy = torch.mean(x_sum**2, dim=-1) 170 | diff_energy = torch.mean(x_diff**2, dim=-1) 171 | 172 | # compute stereo width as ratio 173 | stereo_width = diff_energy / sum_energy.clamp(min=1e-8) 174 | 175 | return stereo_width 176 | 177 | 178 | def compute_stereo_imbalance(x: torch.Tensor, **kwargs): 179 | """Compute stereo imbalance as ratio of energy in left and right channels. 180 | 181 | Args: 182 | x: (bs, 2, seq_len) 183 | 184 | Returns: 185 | stereo_imbalance: (bs, ) 186 | 187 | """ 188 | left_energy = torch.mean(x[:, 0, :] ** 2, dim=-1) 189 | right_energy = torch.mean(x[:, 1, :] ** 2, dim=-1) 190 | 191 | stereo_imbalance = (right_energy - left_energy) / ( 192 | right_energy + left_energy 193 | ).clamp(min=1e-8) 194 | 195 | return stereo_imbalance 196 | 197 | 198 | class AudioFeatureLoss(torch.nn.Module): 199 | def __init__( 200 | self, 201 | weights: List[float], 202 | sample_rate: int, 203 | stem_separation: bool = False, 204 | use_clap: bool = False, 205 | ) -> None: 206 | """Compute loss using a set of differentiable audio features. 207 | 208 | Args: 209 | weights: weights for each feature 210 | sample_rate: sample rate of audio 211 | stem_separation: whether to compute loss on stems or mix 212 | 213 | Based on features proposed in: 214 | 215 | Man, B. D., et al. 216 | "An analysis and evaluation of audio features for multitrack music mixtures." 217 | (2014). 218 | 219 | """ 220 | super().__init__() 221 | self.weights = weights 222 | self.sample_rate = sample_rate 223 | self.stem_separation = stem_separation 224 | self.sources_list = ["mix"] 225 | self.source_weights = [1.0] 226 | self.use_clap = use_clap 227 | 228 | self.transforms = [ 229 | compute_rms, 230 | compute_crest_factor, 231 | compute_stereo_width, 232 | compute_stereo_imbalance, 233 | compute_barkspectrum, 234 | ] 235 | 236 | assert len(self.transforms) == len(weights) 237 | 238 | def forward(self, input: torch.Tensor, target: torch.Tensor): 239 | losses = {} 240 | 241 | # reshape for example stem dim 242 | input_stems = input.unsqueeze(1) 243 | target_stems = target.unsqueeze(1) 244 | 245 | n_stems = input_stems.shape[1] 246 | 247 | # iterate over each stem compute loss for each transform 248 | for stem_idx in range(n_stems): 249 | input_stem = input_stems[:, stem_idx, ...] 250 | target_stem = target_stems[:, stem_idx, ...] 
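            # evaluate each differentiable feature (RMS, crest factor, stereo width,
            # stereo imbalance, bark spectrum) on both signals and store the weighted
            # MSE between them under a per-feature key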
251 | 252 | for transform, weight in zip(self.transforms, self.weights): 253 | transform_name = "_".join(transform.__name__.split("_")[1:]) 254 | key = f"{self.sources_list[stem_idx]}-{transform_name}" 255 | input_transform = transform(input_stem, sample_rate=self.sample_rate) 256 | target_transform = transform(target_stem, sample_rate=self.sample_rate) 257 | val = torch.nn.functional.mse_loss(input_transform, target_transform) 258 | losses[key] = weight * val * self.source_weights[stem_idx] 259 | 260 | return losses 261 | -------------------------------------------------------------------------------- /scripts/gain_testing.py: -------------------------------------------------------------------------------- 1 | 2 | # run pretrained models over evaluation set to generate audio examples for the listening test 3 | import os 4 | import torch 5 | import torchaudio 6 | import pyloudnorm as pyln 7 | from mst.utils import load_diffmst, run_diffmst 8 | from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss 9 | import json 10 | import numpy as np 11 | import csv 12 | import glob 13 | import yaml 14 | 15 | 16 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs): 17 | 18 | meter = pyln.Meter(44100) 19 | target_lufs_db = -48.0 20 | 21 | norm_tracks = [] 22 | for track_idx in range(tracks.shape[1]): 23 | track = tracks[:, track_idx : track_idx + 1, :] 24 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy()) 25 | 26 | if lufs_db < -80.0: 27 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.") 28 | continue 29 | 30 | lufs_delta_db = target_lufs_db - lufs_db 31 | track *= 10 ** (lufs_delta_db / 20) 32 | norm_tracks.append(track) 33 | 34 | norm_tracks = torch.cat(norm_tracks, dim=1) 35 | # create a sum mix with equal loudness 36 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1) 37 | sum_mix /= sum_mix.abs().max() 38 | 39 | return sum_mix, None, None, None 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | meter = pyln.Meter(44100) 45 | target_mix_lufs_db = -16.0 46 | target_track_lufs_db = -48.0 47 | output_dir = "outputs/gain_testing_diff_song_individual_tracks" 48 | os.makedirs(output_dir, exist_ok=True) 49 | 50 | methods = { 51 | "diffmst-16": { 52 | "model": load_diffmst( 53 | "/Users/svanka/Downloads/b4naquji/config.yaml", 54 | "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt", 55 | ), 56 | "func": run_diffmst, 57 | }, 58 | # "sum": { 59 | # "model": (None, None), 60 | # "func": equal_loudness_mix, 61 | # }, 62 | } 63 | 64 | ref_dir = "/Users/svanka/Downloads/DSD100subset/sources/Dev/055 - Angels In Amplifiers - I'm Alright" 65 | #mix_dir = "/Users/svanka/Downloads/DSD100subset/sources/Dev/055 - Angels In Amplifiers - I'm Alright" 66 | mix_dir = "/Users/svanka/Downloads/DSD100subset/Sources/Test/049 - Young Griffo - Facade" 67 | 68 | ref_tracks = glob.glob(os.path.join(ref_dir, "*.wav")) 69 | mix_tracks = glob.glob(os.path.join(mix_dir, "*.wav")) 70 | 71 | print(len(ref_tracks), len(mix_tracks)) 72 | #order the tracks in ref_tracks to vocals, bass, other, drums 73 | ref_tracks_ordered = [""] * 4 74 | for track in ref_tracks: 75 | if "vocals" in track: 76 | ref_tracks_ordered[0] = track 77 | elif "bass" in track: 78 | ref_tracks_ordered[1] = track 79 | elif "other" in track: 80 | ref_tracks_ordered[2] = track 81 | elif "drums" in track: 82 | ref_tracks_ordered[3] = track 83 | ref_tracks = ref_tracks_ordered 84 | 85 | 
print(ref_tracks) 86 | # print(mix_tracks) 87 | 88 | 89 | #we will predict a mix for one track from reference, sum of two, sum of three, sum of four tracks from reference as the reference for model 90 | # and the mix as the input 91 | 92 | tracks = [] 93 | #info = torchaudio.info(mix_tracks[0]) 94 | 95 | 96 | track_instrument = [] 97 | for track in mix_tracks: 98 | #audio, sr = torchaudio.load(track, frame_offset = int((info.num_frames)/2 - 220500), num_frames = 441000, backend="soundfile") 99 | audio, sr = torchaudio.load(track,num_frames = 441000, backend="soundfile") 100 | if sr != 44100: 101 | audio = torchaudio.functional.resample(audio, sr, 44100) 102 | 103 | if audio.shape[0] == 2: 104 | audio = audio.mean(dim=0, keepdim=True) 105 | 106 | tracks.append(audio) 107 | track_instrument.append(os.path.basename(track).replace(".wav", "")) 108 | 109 | 110 | tracks = torch.cat(tracks, dim=0) 111 | print("tracks shape", tracks.shape) 112 | tracks = tracks.unsqueeze(0) 113 | print("tracks shape", tracks.shape) 114 | 115 | #create a sum mix 116 | sum_mix, _, _, _ = equal_loudness_mix(tracks) 117 | print("sum_mix shape", sum_mix.shape) 118 | save_path = os.path.join(output_dir, f"{os.path.basename(mix_dir)}-sum_mix.wav") 119 | torchaudio.save(save_path, sum_mix.view(2, -1), 44100) 120 | 121 | ref_mix_tracks = [] 122 | info = torchaudio.info(ref_tracks[0]) 123 | name = "ref_mix-16=" 124 | data = {} 125 | data["track_instrument"] = track_instrument 126 | for i , ref_track in enumerate(ref_tracks): 127 | instrument = name + "-" + os.path.basename(ref_track).replace(".wav", "") 128 | print(instrument) 129 | #name = instrument 130 | ref_audio, sr = torchaudio.load(ref_track, frame_offset = int((info.num_frames)/2 - 220500), num_frames = 441000, backend="soundfile") 131 | if sr != 44100: 132 | ref_audio = torchaudio.functional.resample(ref_audio, sr, 44100) 133 | 134 | #loudness normalize the reference mix to -48 LUFS 135 | ref_lufs_db = meter.integrated_loudness(ref_audio.squeeze().permute(1, 0).numpy()) 136 | lufs_delta_db = target_track_lufs_db - ref_lufs_db 137 | ref_audio = ref_audio * 10 ** (lufs_delta_db / 20) 138 | 139 | #ref_mix_tracks.append(ref_audio) 140 | ref_mix_tracks = [ref_audio] 141 | ref_mix = torch.cat(ref_mix_tracks, dim=0) 142 | #create a stereo sum mix 143 | ref_mix = ref_mix.sum(dim=0, keepdim=True).repeat(1, 2, 1) 144 | #normalise to -16 LUFS 145 | ref_mix_lufs_db = meter.integrated_loudness(ref_mix.squeeze().permute(1, 0).numpy()) 146 | lufs_delta_db = target_mix_lufs_db - ref_mix_lufs_db 147 | ref_mix = ref_mix * 10 ** (lufs_delta_db / 20) 148 | ref_save_path = os.path.join(output_dir, f"{os.path.basename(ref_dir)}-{instrument}.wav") 149 | torchaudio.save(ref_save_path, ref_mix.view(2, -1), 44100) 150 | 151 | yaml_path = os.path.join(output_dir, f"{os.path.basename(ref_dir)}-{instrument}.yaml") 152 | data["ref_mix"] = ref_save_path 153 | data["ref_instruments"] = instrument 154 | data["sum_mix"] = save_path 155 | #check if the json file exists 156 | print("tracks shape", tracks.shape) 157 | print("ref_mix shape", ref_mix.shape) 158 | 159 | 160 | 161 | 162 | for method_name, method in methods.items(): 163 | model, mix_console = method["model"] 164 | func = method["func"] 165 | with torch.no_grad(): 166 | result = func( 167 | tracks.clone(), 168 | ref_mix.clone(), 169 | model, 170 | mix_console, 171 | track_start_idx=0, 172 | ref_start_idx=0, 173 | ) 174 | 175 | ( 176 | pred_mix, 177 | pred_track_param_dict, 178 | pred_fx_bus_param_dict, 179 | pred_master_bus_param_dict, 180 
| ) = result 181 | 182 | 183 | bs, chs, seq_len = pred_mix.shape 184 | print("pred_mix shape", pred_mix.shape) 185 | # loudness normalize the output mix 186 | mix_lufs_db = meter.integrated_loudness( 187 | pred_mix.squeeze(0).permute(1, 0).numpy() 188 | ) 189 | print("pred_mix_lufs_db", mix_lufs_db) 190 | #print(mix_lufs_db) 191 | lufs_delta_db = target_mix_lufs_db - mix_lufs_db 192 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20) 193 | pred_mix_name = os.path.basename(mix_dir) + f"-pred_mix-ref_mix-16={instrument}.wav" 194 | mix_filepath = os.path.join(output_dir, pred_mix_name) 195 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100) 196 | # append to the json file param_dicts 197 | 198 | #print(pred_track_param_dict["input_gain"]) 199 | 200 | data["pred_mix"] = pred_mix_name 201 | data["gain_values"] = pred_track_param_dict['input_fader']['gain_db'].detach().cpu().numpy().tolist()[0] 202 | #print(type(pred_track_param_dict['input_fader']['gain_db'])) 203 | 204 | 205 | with open(yaml_path, "w") as f: 206 | yaml.dump(data, f) 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | -------------------------------------------------------------------------------- /mst/callbacks/mix.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import wandb 5 | import torchaudio 6 | import numpy as np 7 | import pyloudnorm as pyln 8 | import pytorch_lightning as pl 9 | from tqdm import tqdm 10 | from typing import List 11 | 12 | from mst.utils import batch_stereo_peak_normalize 13 | from mst.mixing import naive_random_mix 14 | 15 | 16 | class LogReferenceMix(pl.callbacks.Callback): 17 | def __init__( 18 | self, 19 | root_dirs: List[str], 20 | ref_mixes: List[str], 21 | peak_normalize: bool = True, 22 | sample_rate: int = 44100, 23 | length: int = 524288, 24 | target_track_lufs_db: float = -48.0, 25 | target_mix_lufs_db: float = -16.0, 26 | ): 27 | super().__init__() 28 | self.peak_normalize = peak_normalize 29 | self.sample_rate = sample_rate 30 | self.length = length 31 | self.target_track_lufs_db = target_track_lufs_db 32 | self.target_mix_lufs_db = target_mix_lufs_db 33 | self.meter = pyln.Meter(self.sample_rate) 34 | 35 | print(f"Initalizing reference mix logger with {len(root_dirs)} mixes.") 36 | 37 | self.songs = [] 38 | for root_dir, ref_mix in zip(root_dirs, ref_mixes): 39 | 40 | print(f"Loading {root_dir}...") 41 | song = {} 42 | song["name"] = os.path.basename(root_dir) 43 | 44 | 45 | # load reference mix 46 | x, sr = torchaudio.load(ref_mix) 47 | print(f"Reference mix sample rate: {sr}") 48 | 49 | 50 | # convert sample rate if needed 51 | if sr != sample_rate: 52 | x = torchaudio.functional.resample(x, sr, sample_rate) 53 | 54 | print(f"Reference mix sample rate after resampling: {sample_rate}") 55 | 56 | 57 | song["ref_mix"] = x 58 | 59 | # load tracks 60 | track_filepaths = glob.glob(os.path.join(root_dir, "*.wav")) 61 | tracks = [] 62 | print("Loading tracks...") 63 | for track_idx, track_filepath in enumerate(tqdm(track_filepaths)): 64 | x, sr = torchaudio.load(track_filepath) 65 | 66 | # convert sample rate if needed 67 | if sr != sample_rate: 68 | x = torchaudio.functional.resample(x, sr, sample_rate) 69 | 70 | # separate channels 71 | for ch_idx in range(x.shape[0]): 72 | x_ch = x[ch_idx : ch_idx + 1, :] 73 | 74 | # save 75 | tracks.append(x_ch) 76 | 77 | song["tracks"] = tracks 78 | self.songs.append(song) 79 | 80 | def on_validation_epoch_end( 81 | self, 82 | 
trainer, 83 | pl_module, 84 | ): 85 | """Called when the validation batch ends.""" 86 | for idx, song in enumerate(self.songs): 87 | ref_mix = song["ref_mix"] 88 | tracks = song["tracks"] 89 | name = song["name"] 90 | 91 | # take a chunk from the middle of the mix 92 | start_idx = (ref_mix.shape[-1] // 2) - (131072 // 2) 93 | stop_idx = start_idx + 131072 94 | ref_mix_chunk = ref_mix[..., start_idx:stop_idx] 95 | 96 | # loudness normalize the mix 97 | mix_lufs_db = self.meter.integrated_loudness( 98 | ref_mix_chunk.permute(1, 0).numpy() 99 | ) 100 | delta_lufs_db = torch.tensor( 101 | [self.target_mix_lufs_db - mix_lufs_db] 102 | ).float() 103 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0) 104 | ref_mix_chunk = gain_lin * ref_mix_chunk 105 | 106 | # move to gpu 107 | ref_mix_chunk = ref_mix_chunk.cuda() 108 | 109 | # make a mix of multiple sections of the tracks 110 | for n, start_idx in enumerate([0, 524288, 2 * 524288, 3 * 524288]): 111 | stop_idx = start_idx + 131072 112 | 113 | # loudness normalize tracks 114 | normalized_tracks = [] 115 | for track in tracks: 116 | track = track[..., start_idx:stop_idx] 117 | 118 | if len(normalized_tracks) > 16: 119 | break 120 | 121 | if track.shape[-1] < 131072: 122 | continue 123 | 124 | track_lufs_db = self.meter.integrated_loudness( 125 | track.permute(1, 0).numpy() 126 | ) 127 | 128 | if track_lufs_db < -48.0 or track_lufs_db == float("-inf"): 129 | continue 130 | 131 | delta_lufs_db = torch.tensor( 132 | [self.target_track_lufs_db - track_lufs_db] 133 | ).float() 134 | 135 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0) 136 | track = gain_lin * track 137 | normalized_tracks.append(track) 138 | 139 | if len(normalized_tracks) == 0: 140 | continue 141 | 142 | # cat tracks 143 | tracks_chunk = torch.cat(normalized_tracks, dim=0) 144 | tracks_chunk = tracks_chunk.cuda() 145 | 146 | with torch.no_grad(): 147 | # predict parameters using the chunks 148 | ( 149 | pred_track_params, 150 | pred_fx_bus_params, 151 | pred_master_bus_params, 152 | ) = pl_module.model( 153 | tracks_chunk.unsqueeze(0), ref_mix_chunk.unsqueeze(0) 154 | ) 155 | 156 | # generate a mix with full tracks using the predicted mix console parameters 157 | ( 158 | pred_mixed_tracks, 159 | pred_mix_chunk, 160 | pred_track_param_dict, 161 | pred_fx_bus_param_dict, 162 | pred_master_bus_param_dict, 163 | ) = pl_module.mix_console( 164 | tracks_chunk.unsqueeze(0), 165 | pred_track_params, 166 | pred_fx_bus_params, 167 | pred_master_bus_params, 168 | use_track_input_fader=pl_module.use_track_input_fader, 169 | use_track_panner=pl_module.use_track_panner, 170 | use_track_eq=pl_module.use_track_eq, 171 | use_track_compressor=pl_module.use_track_compressor, 172 | use_fx_bus=pl_module.use_fx_bus, 173 | use_master_bus=pl_module.use_master_bus, 174 | use_output_fader=pl_module.use_output_fader, 175 | ) 176 | 177 | # normalize predicted mix 178 | pred_mix_chunk = batch_stereo_peak_normalize(pred_mix_chunk) 179 | 180 | # move back to cpu 181 | pred_mix_chunk = pred_mix_chunk.squeeze(0).cpu() 182 | ref_mix_chunk_out = ref_mix_chunk.squeeze(0).cpu() 183 | 184 | # generate sum mix 185 | sum_mix = tracks_chunk.unsqueeze(0).sum(dim=1, keepdim=True).cpu() 186 | sum_mix = batch_stereo_peak_normalize(sum_mix) 187 | sum_mix = sum_mix.squeeze(0) 188 | 189 | # generate random mix 190 | results = naive_random_mix( 191 | tracks_chunk.unsqueeze(0), 192 | pl_module.mix_console, 193 | use_track_input_fader=pl_module.use_track_input_fader, 194 | 
use_track_panner=pl_module.use_track_panner, 195 | use_track_eq=pl_module.use_track_eq, 196 | use_track_compressor=pl_module.use_track_compressor, 197 | use_fx_bus=pl_module.use_fx_bus, 198 | use_master_bus=pl_module.use_master_bus, 199 | use_output_fader=pl_module.use_output_fader, 200 | ) 201 | rand_mix = results[1] 202 | rand_mix = batch_stereo_peak_normalize(rand_mix).cpu() 203 | rand_mix = rand_mix.squeeze(0) 204 | 205 | audios = { 206 | "ref_mix": ref_mix_chunk_out, 207 | "pred_mix": pred_mix_chunk, 208 | "sum_mix": sum_mix, 209 | "rand_mix": rand_mix, 210 | } 211 | 212 | total_samples = 0 213 | for x in audios.values(): 214 | total_samples += x.shape[-1] + int( 215 | pl_module.mix_console.sample_rate 216 | ) 217 | 218 | y = torch.zeros(total_samples, 2) 219 | log_name = f"{idx}_{n}{name}" 220 | start = 0 221 | for key, x in audios.items(): 222 | end = start + x.shape[-1] 223 | y[start:end, :] = x.T 224 | start = end + int(pl_module.mix_console.sample_rate) 225 | log_name += key + "-" 226 | 227 | trainer.logger.experiment.log( 228 | { 229 | f"{log_name}": wandb.Audio( 230 | y.numpy(), 231 | sample_rate=int(pl_module.mix_console.sample_rate), 232 | ) 233 | } 234 | ) 235 | -------------------------------------------------------------------------------- /mst/fx_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | import torchaudio 8 | import numpy as np 9 | 10 | import argparse 11 | import yaml 12 | 13 | 14 | currentdir = os.path.dirname(os.path.realpath(__file__)) 15 | sys.path.append(os.path.dirname(currentdir)) 16 | 17 | # 1-dimensional convolutional layer 18 | # in the order of conv -> norm -> activation 19 | class Conv1d_layer(nn.Module): 20 | def __init__(self, in_channels, out_channels, kernel_size, \ 21 | stride=1, \ 22 | padding="SAME", dilation=1, bias=True, \ 23 | norm="batch", activation="relu", \ 24 | mode="conv"): 25 | super(Conv1d_layer, self).__init__() 26 | 27 | self.conv1d = nn.Sequential() 28 | 29 | ''' padding ''' 30 | if mode=="deconv": 31 | padding = int(dilation * (kernel_size-1) / 2) 32 | out_padding = 0 if stride==1 else 1 33 | elif mode=="conv" or "alias_free" in mode: 34 | if padding == "SAME": 35 | pad = int((kernel_size-1) * dilation) 36 | l_pad = int(pad//2) 37 | r_pad = pad - l_pad 38 | padding_area = (l_pad, r_pad) 39 | elif padding == "VALID": 40 | padding_area = (0, 0) 41 | else: 42 | pass 43 | 44 | ''' convolutional layer ''' 45 | if mode=="deconv": 46 | self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(in_channels, out_channels, kernel_size, \ 47 | stride=stride, padding=padding, output_padding=out_padding, \ 48 | dilation=dilation, \ 49 | bias=bias)) 50 | elif mode=="conv": 51 | self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area)) 52 | self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \ 53 | stride=stride, padding=0, \ 54 | dilation=dilation, \ 55 | bias=bias)) 56 | elif "alias_free" in mode: 57 | if "up" in mode: 58 | up_factor = stride * 2 59 | down_factor = 2 60 | elif "down" in mode: 61 | up_factor = 2 62 | down_factor = stride * 2 63 | else: 64 | raise ValueError("choose alias-free method : 'up' or 'down'") 65 | # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample 66 | # the torchaudio.transforms.Resample's default resampling_method is 'sinc_interpolation' which performs low-pass 
filter during the process 67 | # details at https://pytorch.org/audio/stable/transforms.html 68 | self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area)) 69 | self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \ 70 | stride=1, padding=0, \ 71 | dilation=dilation, \ 72 | bias=bias)) 73 | self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor)) 74 | self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU()) 75 | self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1)) 76 | 77 | ''' normalization ''' 78 | if norm=="batch": 79 | self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels)) 80 | # self.conv1d.add_module("batch_norm", nn.SyncBatchNorm(out_channels)) 81 | 82 | ''' activation ''' 83 | if 'alias_free' not in mode: 84 | if activation=="relu": 85 | self.conv1d.add_module("relu", nn.ReLU()) 86 | elif activation=="lrelu": 87 | self.conv1d.add_module("lrelu", nn.LeakyReLU()) 88 | 89 | 90 | def forward(self, input): 91 | # input shape should be : batch x channel x height x width 92 | #print(input.shape) 93 | output = self.conv1d(input) 94 | return output 95 | 96 | 97 | 98 | 99 | # Residual Block 100 | # the input is added after the first convolutional layer, retaining its original channel size 101 | # therefore, the second convolutional layer's output channel may differ 102 | class Res_ConvBlock(nn.Module): 103 | def __init__(self, dimension, \ 104 | in_channels, out_channels, \ 105 | kernel_size, \ 106 | stride=1, padding="SAME", \ 107 | dilation=1, \ 108 | bias=True, \ 109 | norm="batch", \ 110 | activation="relu", last_activation="relu", \ 111 | mode="conv"): 112 | super(Res_ConvBlock, self).__init__() 113 | 114 | if dimension==1: 115 | self.conv1 = Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation) 116 | self.conv2 = Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode) 117 | elif dimension==2: 118 | self.conv1 = Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation) 119 | self.conv2 = Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode) 120 | 121 | 122 | def forward(self, input): 123 | #print("before c1_out in conv1d",input.shape) 124 | c1_out = self.conv1(input) + input 125 | c2_out = self.conv2(c1_out) 126 | return c2_out 127 | 128 | 129 | 130 | # Convoluaionl Block 131 | # consists of multiple (number of layer_num) convolutional layers 132 | # only the final convoluational layer outputs the desired 'out_channels' 133 | class ConvBlock(nn.Module): 134 | def __init__(self, dimension, layer_num, \ 135 | in_channels, out_channels, \ 136 | kernel_size, \ 137 | stride=1, padding="SAME", \ 138 | dilation=1, \ 139 | bias=True, \ 140 | norm="batch", \ 141 | activation="relu", last_activation="relu", \ 142 | mode="conv"): 143 | super(ConvBlock, self).__init__() 144 | 145 | conv_block = [] 146 | if dimension==1: 147 | for i in range(layer_num-1): 148 | conv_block.append(Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)) 149 | conv_block.append(Conv1d_layer(in_channels, 
out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)) 150 | elif dimension==2: 151 | for i in range(layer_num-1): 152 | conv_block.append(Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)) 153 | conv_block.append(Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)) 154 | self.conv_block = nn.Sequential(*conv_block) 155 | 156 | 157 | def forward(self, input): 158 | return self.conv_block(input) 159 | 160 | 161 | # FXencoder that extracts audio effects from music recordings trained with a contrastive objective 162 | class FXencoder(nn.Module): 163 | def __init__(self, config): 164 | super(FXencoder, self).__init__() 165 | # input is stereo channeled audio 166 | config["channels"].insert(0, 2) 167 | 168 | # encoder layers 169 | encoder = [] 170 | for i in range(len(config["kernels"])): 171 | if config["conv_block"]=='res': 172 | encoder.append(Res_ConvBlock(dimension=1, \ 173 | in_channels=config["channels"][i], \ 174 | out_channels=config["channels"][i+1], \ 175 | kernel_size=config["kernels"][i], \ 176 | stride=config["strides"][i], \ 177 | padding="SAME", \ 178 | dilation=config["dilation"][i], \ 179 | norm=config["norm"], \ 180 | activation=config["activation"], \ 181 | last_activation=config["activation"])) 182 | elif config["conv_block"]=='conv': 183 | encoder.append(ConvBlock(dimension=1, \ 184 | layer_num=1, \ 185 | in_channels=config["channels"][i], \ 186 | out_channels=config["channels"][i+1], \ 187 | kernel_size=config["kernels"][i], \ 188 | stride=config["strides"][i], \ 189 | padding="VALID", \ 190 | dilation=config["dilation"][i], \ 191 | norm=config["norm"], \ 192 | activation=config["activation"], \ 193 | last_activation=config["activation"], \ 194 | mode='conv')) 195 | self.encoder = nn.Sequential(*encoder) 196 | 197 | # pooling method 198 | self.glob_pool = nn.AdaptiveAvgPool1d(1) 199 | 200 | # network forward operation 201 | def forward(self, input): 202 | #print("in resnet",input.shape) 203 | enc_output = self.encoder(input) 204 | glob_pooled = self.glob_pool(enc_output).squeeze(-1) 205 | 206 | # outputs c feature 207 | return glob_pooled 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /scripts/eval_listen.py: -------------------------------------------------------------------------------- 1 | # run pretrained models over evaluation set to generate audio examples for the listening test 2 | import os 3 | import torch 4 | import torchaudio 5 | import pyloudnorm as pyln 6 | from mst.utils import load_diffmst, run_diffmst 7 | 8 | 9 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs): 10 | 11 | meter = pyln.Meter(44100) 12 | target_lufs_db = -48.0 13 | 14 | norm_tracks = [] 15 | for track_idx in range(tracks.shape[1]): 16 | track = tracks[:, track_idx : track_idx + 1, :] 17 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy()) 18 | 19 | if lufs_db < -80.0: 20 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.") 21 | continue 22 | 23 | lufs_delta_db = target_lufs_db - lufs_db 24 | track *= 10 ** (lufs_delta_db / 20) 25 | norm_tracks.append(track) 26 | 27 | norm_tracks = torch.cat(norm_tracks, dim=1) 28 | # create a sum mix with equal loudness 29 | sum_mix = torch.sum(norm_tracks, dim=1, 
keepdim=True).repeat(1, 2, 1) 30 | sum_mix /= sum_mix.abs().max() 31 | 32 | return sum_mix, None, None, None 33 | 34 | 35 | if __name__ == "__main__": 36 | meter = pyln.Meter(44100) 37 | target_lufs_db = -22.0 38 | output_dir = "outputs/listen_1" 39 | os.makedirs(output_dir, exist_ok=True) 40 | 41 | methods = { 42 | "diffmst-16": { 43 | "model": load_diffmst( 44 | "/import/c4dm-datasets-ext/Diff-MST/DiffMST/b4naquji/config.yaml", 45 | "/import/c4dm-datasets-ext/Diff-MST/DiffMST/b4naquji/checkpoints/epoch=191-step=626608.ckpt", 46 | ), 47 | "func": run_diffmst, 48 | }, 49 | "sum": { 50 | "model": (None, None), 51 | "func": equal_loudness_mix, 52 | }, 53 | } 54 | 55 | # get the validation examples 56 | examples = { 57 | "ecstasy": { 58 | "tracks": "/import/c4dm-datasets-ext/diffmst-examples/song1/BenFlowers_Ecstasy_Full/", 59 | "track_verse_start_idx": 1190700, 60 | "track_chorus_start_idx": 2381400, 61 | "ref": "/import/c4dm-datasets-ext/diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme)_01.wav", 62 | "ref_verse_start_idx": 970200, 63 | "ref_chorus_start_idx": 198450, 64 | }, 65 | "by-my-side": { 66 | "tracks": "/import/c4dm-datasets-ext/diffmst-examples/song2/Kat Wright_By My Side/", 67 | "track_verse_start_idx": 1146600, 68 | "track_chorus_start_idx": 7144200, 69 | "ref": "/import/c4dm-datasets-ext/diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video)_01.wav", 70 | "ref_verse_start_idx": 661500, 71 | "ref_chorus_start_idx": 2028600, 72 | }, 73 | "haunted-aged": { 74 | "tracks": "/import/c4dm-datasets-ext/diffmst-examples/song3/Titanium_HauntedAge_Full/", 75 | "track_verse_start_idx": 1675800, 76 | "track_chorus_start_idx": 3439800, 77 | "ref": "/import/c4dm-datasets-ext/diffmst-examples/song3/ref/Architects - _Doomsday__01.wav", 78 | "ref_verse_start_idx": 4630500, 79 | "ref_chorus_start_idx": 6570900, 80 | }, 81 | } 82 | 83 | for example_name, example in examples.items(): 84 | print(example_name) 85 | example_dir = os.path.join(output_dir, example_name) 86 | os.makedirs(example_dir, exist_ok=True) 87 | # load reference mix 88 | ref_audio, ref_sr = torchaudio.load(example["ref"], backend="soundfile") 89 | if ref_sr != 44100: 90 | ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100) 91 | print(ref_audio.shape, ref_sr) 92 | 93 | # first find all the tracks 94 | track_filepaths = [] 95 | for root, dirs, files in os.walk(example["tracks"]): 96 | for filepath in files: 97 | if filepath.endswith(".wav"): 98 | track_filepaths.append(os.path.join(root, filepath)) 99 | 100 | print(f"Found {len(track_filepaths)} tracks.") 101 | 102 | # load the tracks 103 | tracks = [] 104 | lengths = [] 105 | for track_idx, track_filepath in enumerate(track_filepaths): 106 | audio, sr = torchaudio.load(track_filepath, backend="soundfile") 107 | 108 | if sr != 44100: 109 | audio = torchaudio.functional.resample(audio, sr, 44100) 110 | 111 | # loudness normalize the tracks to -48 LUFS 112 | lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy()) 113 | # lufs_delta_db = -48 - lufs_db 114 | # audio = audio * 10 ** (lufs_delta_db / 20) 115 | 116 | #print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db) 117 | 118 | if audio.shape[0] == 2: 119 | audio = audio.mean(dim=0, keepdim=True) 120 | 121 | chs, seq_len = audio.shape 122 | 123 | for ch_idx in range(chs): 124 | tracks.append(audio[ch_idx : ch_idx + 1, :]) 125 | lengths.append(audio.shape[-1]) 126 | print("Loaded tracks.") 127 | # find max length and pad if shorter 128 | 
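# Note: torch.nn.functional.pad with a (0, n) spec right-pads the last dimension with n zeros,
# so every mono track below is zero-padded out to the longest one before stacking,
# e.g. torch.nn.functional.pad(torch.ones(1, 3), (0, 2)) -> tensor([[1., 1., 1., 0., 0.]]).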
max_length = max(lengths) 129 | for track_idx in range(len(tracks)): 130 | tracks[track_idx] = torch.nn.functional.pad( 131 | tracks[track_idx], (0, max_length - lengths[track_idx]) 132 | ) 133 | print("Padded tracks.") 134 | # stack into a tensor 135 | tracks = torch.cat(tracks, dim=0) 136 | tracks = tracks.view(1, -1, max_length) 137 | ref_audio = ref_audio.view(1, 2, -1) 138 | 139 | # crop tracks to max of 60 seconds or so 140 | # tracks = tracks[..., :4194304] 141 | 142 | print(tracks.shape) 143 | 144 | # create a simple sum mix of the tracks 145 | sum_mix = torch.sum(tracks, dim=1, keepdim=True).squeeze(0) 146 | sum_filepath = os.path.join(example_dir, f"{example_name}-sum.wav") 147 | os.makedirs(os.path.dirname(sum_filepath), exist_ok=True) 148 | print("sum_mix path created") 149 | 150 | # loudness normalize the sum mix 151 | sum_lufs_db = meter.integrated_loudness(sum_mix.permute(1, 0).numpy()) 152 | lufs_delta_db = target_lufs_db - sum_lufs_db 153 | sum_mix = sum_mix * 10 ** (lufs_delta_db / 20) 154 | 155 | torchaudio.save(sum_filepath, sum_mix.view(1, -1), 44100) 156 | print("Sum mix saved.") 157 | 158 | # save the reference mix 159 | ref_filepath = os.path.join(example_dir, "ref-full.wav") 160 | torchaudio.save(ref_filepath, ref_audio.squeeze(), 44100) 161 | print("Reference mix saved.") 162 | 163 | for song_section in ["verse", "chorus"]: 164 | print("Mixing", song_section) 165 | if song_section == "verse": 166 | track_start_idx = example["track_verse_start_idx"] 167 | ref_start_idx = example["ref_verse_start_idx"] 168 | else: 169 | track_start_idx = example["track_chorus_start_idx"] 170 | ref_start_idx = example["ref_chorus_start_idx"] 171 | 172 | if track_start_idx + 262144 > tracks.shape[-1]: 173 | print("Tracks too short for this section.") 174 | if ref_start_idx + 262144 > ref_audio.shape[-1]: 175 | print("Reference too short for this section.") 176 | 177 | # crop the tracks to create a mix twice the size of the reference section 178 | mix_tracks = tracks 179 | # [..., track_start_idx : track_start_idx + (262144 * 2)] 180 | mix_tracks = tracks[..., track_start_idx : track_start_idx + (262144 * 2)] 181 | track_start_idx = 0 182 | print("mix_tracks", mix_tracks.shape) 183 | 184 | # save the reference mix section for analysis 185 | ref_analysis = ref_audio[..., ref_start_idx : ref_start_idx + 262144] 186 | 187 | # create mixes varying the loudness of the reference 188 | for ref_loudness_target in [-24, -16, -14.0, -12, -6]: 189 | print("Ref loudness", ref_loudness_target) 190 | ref_filepath = os.path.join( 191 | example_dir, 192 | f"ref-analysis-{song_section}-lufs-{ref_loudness_target:0.0f}.wav", 193 | ) 194 | 195 | # loudness normalize the reference mix section to the current loudness target 196 | ref_lufs_db = meter.integrated_loudness( 197 | ref_analysis.squeeze().permute(1, 0).numpy() 198 | ) 199 | lufs_delta_db = ref_loudness_target - ref_lufs_db 200 | ref_analysis = ref_analysis * 10 ** (lufs_delta_db / 20) 201 | 202 | torchaudio.save(ref_filepath, ref_analysis.squeeze(), 44100) 203 | 204 | for method_name, method in methods.items(): 205 | print(method_name) 206 | # tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, seq_len) 207 | # ref_audio (torch.Tensor): Reference mix with shape (bs, 2, seq_len) 208 | 209 | if method_name == "sum": 210 | if ref_loudness_target != -16: 211 | continue 212 | 213 | if method_name == "sum" and song_section == "chorus": 214 | continue 215 | 216 | model, mix_console = method["model"] 217 | func = method["func"] 218 | 219 | print(tracks.shape, ref_audio.shape) 220 | 221 |
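# The loudness alignment applied to the predicted mix below uses the standard dB-to-linear
# conversion: gain_lin = 10 ** ((target_lufs_db - measured_lufs_db) / 20). For example, a mix
# measured at -30 LUFS with the -22 LUFS target receives a +8 dB trim, i.e. a linear gain of
# 10 ** (8 / 20), approx. 2.51.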
with torch.no_grad(): 222 | result = func( 223 | mix_tracks.clone(), 224 | ref_analysis.clone(), 225 | model, 226 | mix_console, 227 | track_start_idx=track_start_idx, 228 | ref_start_idx=ref_start_idx, 229 | ) 230 | 231 | ( 232 | pred_mix, 233 | pred_track_param_dict, 234 | pred_fx_bus_param_dict, 235 | pred_master_bus_param_dict, 236 | ) = result 237 | 238 | bs, chs, seq_len = pred_mix.shape 239 | 240 | # loudness normalize the output mix 241 | mix_lufs_db = meter.integrated_loudness( 242 | pred_mix.squeeze(0).permute(1, 0).numpy() 243 | ) 244 | print(mix_lufs_db) 245 | lufs_delta_db = target_lufs_db - mix_lufs_db 246 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20) 247 | 248 | # save resulting audio and parameters 249 | mix_filepath = os.path.join( 250 | example_dir, 251 | f"{example_name}-{method_name}-ref={song_section}-lufs-{ref_loudness_target:0.0f}.wav", 252 | ) 253 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100) 254 | 255 | # also save only the analysis section 256 | mix_analysis = pred_mix[ 257 | ..., track_start_idx : track_start_idx + (2 * 262144) 258 | ] 259 | 260 | # loudness normalize the output mix 261 | mix_lufs_db = meter.integrated_loudness( 262 | mix_analysis.squeeze(0).permute(1, 0).numpy() 263 | ) 264 | print(mix_lufs_db) 265 | mix_analysis = mix_analysis * 10 ** (lufs_delta_db / 20) 266 | 267 | mix_filepath = os.path.join( 268 | example_dir, 269 | f"{example_name}-{method_name}-analysis-{song_section}-lufs-{ref_loudness_target:0.0f}.wav", 270 | ) 271 | torchaudio.save(mix_filepath, mix_analysis.view(chs, -1), 44100) 272 | 273 | print() 274 | -------------------------------------------------------------------------------- /mst/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import random 5 | import numpy as np 6 | import pyloudnorm as pyln 7 | 8 | from tqdm import tqdm 9 | from typing import Optional 10 | from importlib import import_module 11 | from mst.modules import MixStyleTransferModel 12 | 13 | 14 | def batch_stereo_peak_normalize(x: torch.Tensor): 15 | """Normalize a batch of stereo mixes by their peak value. 16 | 17 | Args: 18 | x (Tensor): 1-d tensor with shape (bs, 2, seq_len). 19 | 20 | Returns: 21 | x (Tensor): Normalized signal withs shape (vs, 2, seq_len). 22 | """ 23 | # first find the peaks in each channel 24 | gain_lin = x.abs().max(dim=-1, keepdim=True)[0] 25 | # then find the maximum peak across left and right per batch item 26 | gain_lin = gain_lin.max(dim=-2, keepdim=True)[0] 27 | # normalize by the maximum peak 28 | x_norm = x / gain_lin.clamp(1e-8) # avoid division by zero 29 | return x_norm 30 | 31 | 32 | def run_diffmst( 33 | tracks: torch.Tensor, 34 | ref: torch.Tensor, 35 | model: torch.nn.Module, 36 | mix_console: torch.nn.Module, 37 | track_start_idx: int = 0, 38 | ref_start_idx: int = 0, 39 | ): 40 | """Run the differentiable mix style transfer model. 41 | 42 | Args: 43 | tracks (Tensor): Set of input tracks with shape (bs, num_tracks, 1, seq_len). 44 | ref (Tensor): Reference mix with shape (bs, 2, seq_len). 45 | model (torch.nn.Module): MixStyleTransferModel instance. 46 | mix_console (torch.nn.Module): MixConsole instance. 47 | track_start_idx (int, optional): Start index of the track to use. Default: 0. 48 | ref_start_idx (int, optional): Start index of the reference mix to use. Default: 0. 49 | 50 | Returns: 51 | pred_mix (Tensor): Predicted mix with shape (bs, 2, seq_len). 
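            Assembled by overlap-add: the predicted console parameters are applied to
            262144-sample windows with a 50% hop, each window weighted by a Hann window
            (the first half of the very first window is set to 1.0) and summed into the
            full-length mix.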
52 | pred_track_param_dict (dict): Dictionary with predicted track parameters. 53 | pred_fx_bus_param_dict (dict): Dictionary with predicted fx bus parameters. 54 | pred_master_bus_param_dict (dict): Dictionary with predicted master bus parameters. 55 | """ 56 | # ------ defaults ------ 57 | use_track_input_fader = True 58 | use_track_panner = True 59 | use_track_eq = True 60 | use_track_compressor = True 61 | use_fx_bus = False 62 | use_master_bus = True 63 | use_output_fader = True 64 | 65 | analysis_len = 262144 66 | meter = pyln.Meter(44100) 67 | 68 | # crop the input tracks and reference mix to the analysis length 69 | if tracks.shape[-1] >= analysis_len: 70 | analysis_tracks = tracks[ 71 | ..., track_start_idx : track_start_idx + analysis_len 72 | ].clone() 73 | else: 74 | analysis_tracks = tracks.clone() 75 | 76 | if ref.shape[-1] >= analysis_len: 77 | analysis_ref = ref[..., ref_start_idx : ref_start_idx + analysis_len] 78 | else: 79 | analysis_ref = ref.clone() 80 | 81 | # loudness normalize the tracks to -48 LUFS 82 | norm_tracks = [] 83 | norm_analysis_tracks = [] 84 | track_padding = [] 85 | for track_idx in range(analysis_tracks.shape[1]): 86 | analysis_track = analysis_tracks[:, track_idx : track_idx + 1, :] 87 | track = tracks[:, track_idx : track_idx + 1, :] 88 | lufs_db = meter.integrated_loudness( 89 | analysis_track.squeeze(0).permute(1, 0).numpy() 90 | ) 91 | if lufs_db < -80.0: 92 | print(f"Skipping track {track_idx} due to low loudness {lufs_db}.") 93 | continue 94 | 95 | lufs_delta_db = -48 - lufs_db 96 | analysis_track *= 10 ** (lufs_delta_db / 20) 97 | track *= 10 ** (lufs_delta_db / 20) 98 | 99 | norm_analysis_tracks.append(analysis_track) 100 | norm_tracks.append(track) 101 | track_padding.append(False) 102 | 103 | norm_analysis_tracks = torch.cat(norm_analysis_tracks, dim=1) 104 | norm_tracks = torch.cat(norm_tracks, dim=1) 105 | print(norm_analysis_tracks.shape, norm_tracks.shape) 106 | 107 | # take only first 16 tracks 108 | # norm_tracks = norm_tracks[:, :16, :] 109 | # norm_analysis_tracks = norm_analysis_tracks[:, :16, :] 110 | print(norm_analysis_tracks.shape, norm_tracks.shape) 111 | 112 | # make tensor contiguous 113 | norm_analysis_tracks = norm_analysis_tracks.contiguous() 114 | norm_tracks = norm_tracks.contiguous() 115 | 116 | # ---- run model to estimate mix parmaeters using analysis audio ---- 117 | pred_track_params, pred_fx_bus_params, pred_master_bus_params = model( 118 | norm_analysis_tracks, analysis_ref 119 | ) 120 | 121 | # ------- generate a mix using the predicted mix console parameters ------- 122 | # apply with sliding window of 262144 samples with overlap 123 | pred_mix = torch.zeros(1, 2, norm_tracks.shape[-1]) 124 | 125 | for i in tqdm(range(0, norm_tracks.shape[-1], analysis_len // 2)): 126 | norm_tracks_window = norm_tracks[..., i : i + analysis_len] 127 | ( 128 | pred_mixed_tracks, 129 | pred_mix_window, 130 | pred_track_param_dict, 131 | pred_fx_bus_param_dict, 132 | pred_master_bus_param_dict, 133 | ) = mix_console( 134 | norm_tracks_window, 135 | pred_track_params, 136 | pred_fx_bus_params, 137 | pred_master_bus_params, 138 | use_track_input_fader=use_track_input_fader, 139 | use_track_panner=use_track_panner, 140 | use_track_eq=use_track_eq, 141 | use_track_compressor=use_track_compressor, 142 | use_fx_bus=use_fx_bus, 143 | use_master_bus=use_master_bus, 144 | use_output_fader=use_output_fader, 145 | ) 146 | if pred_mix_window.shape[-1] < analysis_len: 147 | pred_mix_window = torch.nn.functional.pad( 148 | pred_mix_window, (0, 
analysis_len - pred_mix_window.shape[-1]) 149 | ) 150 | 151 | window = torch.hann_window(pred_mix_window.shape[-1]) 152 | # apply hann window 153 | if i == 0: 154 | # set the first half of the window to 1 155 | window[: window.shape[-1] // 2] = 1.0 156 | 157 | pred_mix_window *= window 158 | 159 | # check length of the mix window 160 | output_len = pred_mix[..., i : i + analysis_len].shape[-1] 161 | 162 | # overlap add 163 | pred_mix[..., i : i + analysis_len] += pred_mix_window[..., :output_len] 164 | 165 | # crop the mix to the original length 166 | pred_mix = pred_mix[..., : norm_tracks.shape[-1]] 167 | 168 | return ( 169 | pred_mix, 170 | pred_track_param_dict, 171 | pred_fx_bus_param_dict, 172 | pred_master_bus_param_dict, 173 | ) 174 | 175 | 176 | def load_diffmst(config_path: str, ckpt_path: str, map_location: str = "cpu"): 177 | with open(config_path) as f: 178 | config = yaml.safe_load(f) 179 | 180 | core_model_configs = config["model"]["init_args"]["model"] 181 | 182 | module_path, class_name = core_model_configs["class_path"].rsplit(".", 1) 183 | module = import_module(module_path) 184 | model = getattr(module, class_name)(**core_model_configs["init_args"]) 185 | 186 | submodule_configs = core_model_configs["init_args"] 187 | 188 | # create track encoder module 189 | module_path, class_name = submodule_configs["track_encoder"]["class_path"].rsplit( 190 | ".", 1 191 | ) 192 | module = import_module(module_path) 193 | track_encoder = getattr(module, class_name)( 194 | **submodule_configs["track_encoder"]["init_args"] 195 | ) 196 | 197 | # create mix encoder module 198 | module_path, class_name = submodule_configs["mix_encoder"]["class_path"].rsplit( 199 | ".", 1 200 | ) 201 | module = import_module(module_path) 202 | mix_encoder = getattr(module, class_name)( 203 | **submodule_configs["mix_encoder"]["init_args"] 204 | ) 205 | 206 | # create controller module 207 | module_path, class_name = submodule_configs["controller"]["class_path"].rsplit( 208 | ".", 1 209 | ) 210 | module = import_module(module_path) 211 | controller = getattr(module, class_name)( 212 | **submodule_configs["controller"]["init_args"] 213 | ) 214 | 215 | # create mix console module 216 | module_path, class_name = config["model"]["init_args"]["mix_console"][ 217 | "class_path" 218 | ].rsplit(".", 1) 219 | module = import_module(module_path) 220 | mix_console = getattr(module, class_name)( 221 | **config["model"]["init_args"]["mix_console"]["init_args"] 222 | ) 223 | 224 | checkpoint = torch.load(ckpt_path, map_location=map_location) 225 | 226 | # load state dicts 227 | state_dict = {} 228 | for k, v in checkpoint["state_dict"].items(): 229 | if k.startswith("model.track_encoder"): 230 | state_dict[k.replace("model.track_encoder.", "", 1)] = v 231 | track_encoder.load_state_dict(state_dict) 232 | 233 | state_dict = {} 234 | for k, v in checkpoint["state_dict"].items(): 235 | if k.startswith("model.mix_encoder"): 236 | state_dict[k.replace("model.mix_encoder.", "", 1)] = v 237 | mix_encoder.load_state_dict(state_dict) 238 | 239 | state_dict = {} 240 | for k, v in checkpoint["state_dict"].items(): 241 | if k.startswith("model.controller"): 242 | state_dict[k.replace("model.controller.", "", 1)] = v 243 | controller.load_state_dict(state_dict) 244 | 245 | state_dict = {} 246 | for k, v in checkpoint["state_dict"].items(): 247 | if k.startswith("model.mix_console"): 248 | state_dict[k.replace("model.mix_console.", "", 1)] = v 249 | mix_console.load_state_dict(state_dict) 250 | 251 | model = MixStyleTransferModel( 252 
| track_encoder, 253 | mix_encoder, 254 | controller, 255 | ) 256 | model.eval() 257 | 258 | return model, mix_console 259 | 260 | 261 | def denorm(p, p_min=0.0, p_max=1.0): 262 | return (p * (p_max - p_min)) + p_min 263 | 264 | 265 | def norm(p, p_min=0.0, p_max=1.0): 266 | return (p - p_min) / (p_max - p_min) 267 | 268 | 269 | def seed_worker(worker_id): 270 | worker_seed = torch.initial_seed() % 2**32 271 | np.random.seed(worker_seed) 272 | random.seed(worker_seed) 273 | 274 | 275 | def center_crop(x, length: int): 276 | start = (x.shape[-1] - length) // 2 277 | stop = start + length 278 | return x[..., start:stop] 279 | 280 | 281 | def causal_crop(x, length: int): 282 | stop = x.shape[-1] - 1 283 | start = stop - length 284 | return x[..., start:stop] 285 | 286 | 287 | def count_parameters(model): 288 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 289 | 290 | 291 | def rand(low=0, high=1): 292 | return (torch.rand(1).numpy()[0] * (high - low)) + low 293 | 294 | 295 | def randint(low=0, high=1): 296 | return torch.randint(low, high + 1, (1,)).numpy()[0] 297 | 298 | 299 | def center_crop(x, length: int): 300 | if x.shape[-1] != length: 301 | start = (x.shape[-1] - length) // 2 302 | stop = start + length 303 | x = x[..., start:stop] 304 | return x 305 | 306 | 307 | def causal_crop(x, length: int): 308 | if x.shape[-1] != length: 309 | stop = x.shape[-1] - 1 310 | start = stop - length 311 | x = x[..., start:stop] 312 | return x 313 | 314 | 315 | def find_first_peak(x, threshold_dB=-36, sample_rate=44100): 316 | """Find the first peak of the input signal. 317 | 318 | Args: 319 | x (Tensor): 1-d tensor with signal. 320 | threshold_dB (float, optional): Minimum peak treshold in dB. Default: -36.0 321 | sample_rate (float, optional): Sample rate of the input signal. Default: 44100 322 | 323 | Returns: 324 | first_peak_sample (int): Sample index of the first peak. 325 | first_peak_sec (float): Location of the first peak in seconds. 326 | """ 327 | signal = 20 * torch.log10(x.view(-1).abs() + 1e-8) 328 | peaks = torch.where(signal > threshold_dB)[0] 329 | first_peak_sample = peaks[0] 330 | first_peak_sec = first_peak_sample / sample_rate 331 | 332 | return first_peak_sample, first_peak_sec 333 | 334 | 335 | def fade_in_and_fade_out(x, fade_ms=10.0, sample_rate=44100): 336 | """Apply a linear fade in and fade out to the last dim of a signal. 337 | 338 | Args: 339 | x (Tensor): Tensor with signal(s). 340 | fade_ms (float, optional): Length of the fade in milliseconds. Default: 10.0 341 | sample_rate (int, optional): Sample rate. Default: 44100 342 | 343 | Returns: 344 | x (Tensor): Faded signal(s). 
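    Example:
        A minimal sketch, assuming a 1-second stereo signal at 44.1 kHz; the first and last
        samples land exactly on the fade endpoints (the fade is applied in place on the input):

        >>> x = torch.ones(2, 44100)
        >>> y = fade_in_and_fade_out(x.clone(), fade_ms=10.0, sample_rate=44100)
        >>> float(y[0, 0]), float(y[0, -1])
        (0.0, 0.0)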
345 | """ 346 | fade_samples = int(fade_ms * 1e-3 * sample_rate) 347 | fade_in = torch.linspace(0, 1.0, fade_samples) 348 | fade_out = torch.linspace(1.0, 0, fade_samples) 349 | x[..., :fade_samples] *= fade_in 350 | x[..., x.shape[-1] - fade_samples :] *= fade_out 351 | 352 | return x 353 | 354 | 355 | def common_member(a, b): 356 | a_set = set(a) 357 | b_set = set(b) 358 | if a_set & b_set: 359 | return True 360 | else: 361 | return False 362 | -------------------------------------------------------------------------------- /scripts/eval_ablation.py: -------------------------------------------------------------------------------- 1 | # run pretrained models over evaluation set to generate audio examples for the listening test 2 | import os 3 | import torch 4 | import torchaudio 5 | import pyloudnorm as pyln 6 | from mst.utils import load_diffmst, run_diffmst 7 | from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss 8 | import json 9 | import numpy as np 10 | import csv 11 | import glob 12 | 13 | 14 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs): 15 | 16 | meter = pyln.Meter(44100) 17 | target_lufs_db = -48.0 18 | 19 | norm_tracks = [] 20 | for track_idx in range(tracks.shape[1]): 21 | track = tracks[:, track_idx : track_idx + 1, :] 22 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy()) 23 | 24 | if lufs_db < -80.0: 25 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.") 26 | continue 27 | 28 | lufs_delta_db = target_lufs_db - lufs_db 29 | track *= 10 ** (lufs_delta_db / 20) 30 | norm_tracks.append(track) 31 | 32 | norm_tracks = torch.cat(norm_tracks, dim=1) 33 | # create a sum mix with equal loudness 34 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1) 35 | sum_mix /= sum_mix.abs().max() 36 | 37 | return sum_mix, None, None, None 38 | 39 | class NumpyEncoder(json.JSONEncoder): 40 | """ Special json encoder for numpy types """ 41 | def default(self, obj): 42 | if isinstance(obj, np.integer): 43 | return int(obj) 44 | elif isinstance(obj, np.floating): 45 | return float(obj) 46 | elif isinstance(obj, np.ndarray): 47 | return obj.tolist() 48 | return json.JSONEncoder.default(self, obj) 49 | 50 | if __name__ == "__main__": 51 | meter = pyln.Meter(44100) 52 | target_lufs_db = -22.0 53 | output_dir = "outputs/ablation" 54 | os.makedirs(output_dir, exist_ok=True) 55 | 56 | methods = { 57 | "diffmst-16": { 58 | "model": load_diffmst( 59 | "/Users/svanka/Downloads/b4naquji/config.yaml", 60 | "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt", 61 | ), 62 | "func": run_diffmst, 63 | }, 64 | "sum": { 65 | "model": (None, None), 66 | "func": equal_loudness_mix, 67 | }, 68 | } 69 | 70 | # get the validation examples 71 | examples = { 72 | "ecstasy": { 73 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song1/BenFlowers_Ecstasy_Full/", 74 | "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/_Feel it all Around_ by Washed Out (Portlandia Theme)_01/", 75 | }, 76 | "by-my-side": { 77 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song2/Kat Wright_By My Side/", 78 | "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/The Dip - Paddle To The Stars (Lyric Video)_01/", 79 | }, 80 | "haunted-aged": { 81 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song3/Titanium_HauntedAge_Full/", 82 | "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/Architects - 
_Doomsday__01/", 83 | }, 84 | } 85 | 86 | 87 | loss = AudioFeatureLoss([0.1,0.001,1.0,1.0,0.1], 44100) 88 | AF = {} 89 | #initialise to negative infinity 90 | 91 | for example_name, example in examples.items(): 92 | 93 | 94 | AF[example_name] = {} 95 | print(example_name) 96 | example_dir = os.path.join(output_dir, example_name) 97 | os.makedirs(example_dir, exist_ok=True) 98 | json_dir = os.path.join(output_dir, "AF") 99 | if not os.path.exists(json_dir): 100 | os.makedirs(json_dir, exist_ok=True) 101 | csv_path = os.path.join(json_dir,f"{example_name}.csv") 102 | # if not os.path.exists(csv_path): 103 | # os.makedirs(csv_path) 104 | with open(csv_path, 'w') as f: 105 | writer = csv.writer(f) 106 | writer.writerow(["method", "audio_type","ablation","start_idx", "stop_idx", "rms", "crest_factor", "stereo_width", "stereo_imbalance", "barkspectrum", "net_AF_loss"]) 107 | f.close() 108 | ref_loudness_target = -16.0 109 | 110 | # --------------first find all the tracks---------------- 111 | track_filepaths = [] 112 | for root, dirs, files in os.walk(example["tracks"]): 113 | for filepath in files: 114 | if filepath.endswith(".wav"): 115 | track_filepaths.append(os.path.join(root, filepath)) 116 | 117 | print(f"Found {len(track_filepaths)} tracks.") 118 | 119 | # ----------------load the tracks---------------------------- 120 | tracks = [] 121 | lengths = [] 122 | for track_idx, track_filepath in enumerate(track_filepaths): 123 | audio, sr = torchaudio.load(track_filepath, backend="soundfile") 124 | 125 | if sr != 44100: 126 | audio = torchaudio.functional.resample(audio, sr, 44100) 127 | 128 | # loudness normalize the tracks to -48 LUFS 129 | lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy()) 130 | # lufs_delta_db = -48 - lufs_db 131 | # audio = audio * 10 ** (lufs_delta_db / 20) 132 | 133 | print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db) 134 | 135 | if audio.shape[0] == 2: 136 | audio = audio.mean(dim=0, keepdim=True) 137 | 138 | chs, seq_len = audio.shape 139 | 140 | for ch_idx in range(chs): 141 | tracks.append(audio[ch_idx : ch_idx + 1, :]) 142 | lengths.append(audio.shape[-1]) 143 | 144 | # find max length and pad if shorter 145 | max_length = max(lengths) 146 | min_length = min(lengths) 147 | for track_idx in range(len(tracks)): 148 | tracks[track_idx] = torch.nn.functional.pad( 149 | tracks[track_idx], (0, max_length - lengths[track_idx]) 150 | ) 151 | 152 | # stack into a tensor 153 | tracks = torch.cat(tracks, dim=0) 154 | tracks = tracks.view(1, -1, max_length) 155 | tracks_length = max_length 156 | refs = glob.glob(os.path.join(example["ref"],"*.wav")) 157 | print("found refs", len(refs)) 158 | for ref in refs: 159 | ref_name = os.path.basename(ref).replace(".wav", "") 160 | test_type = ref_name.split("_")[-2] + "_" + ref_name.split("_")[-1] 161 | print(test_type) 162 | 163 | print(ref_name) 164 | AF[example_name]["ref"] = {} 165 | AF[example_name]["pred_mix"] = {} 166 | ref_audio, ref_sr = torchaudio.load(ref, backend="soundfile") 167 | if ref_sr != 44100: 168 | ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100) 169 | print(ref_audio.shape, ref_sr) 170 | ref_length = ref_audio.shape[-1] 171 | ref_audio = ref_audio.view(1, 2, -1) 172 | 173 | #loudness normalize the reference mix to -16 LUFS 174 | ref_lufs_db = meter.integrated_loudness(ref_audio.squeeze().permute(1, 0).numpy()) 175 | lufs_delta_db = ref_loudness_target - ref_lufs_db 176 | ref_audio = ref_audio * 10 ** (lufs_delta_db / 20) 177 | 178 | 179 | # 
--------------run inference---------------- 180 | #print(tracks.shape) 181 | track_idx = int(tracks_length / 2) 182 | ref_idx = int(ref_length / 2) 183 | mix_tracks = tracks[..., track_idx - 220500 : track_idx + 220500] 184 | ref_analysis = ref_audio[..., ref_idx - 220500 : ref_idx + 220500] 185 | 186 | ref_path = os.path.join(example_dir, os.path.basename(ref).replace(".wav", "-ref-16.wav")) 187 | torchaudio.save(ref_path, ref_analysis.squeeze(), 44100) 188 | 189 | for method_name, method in methods.items(): 190 | AF[example_name]["ref"] [method_name] = {} 191 | AF[example_name]["pred_mix"] [method_name] = {} 192 | 193 | print(method_name) 194 | model, mix_console = method["model"] 195 | func = method["func"] 196 | 197 | with torch.no_grad(): 198 | result = func( 199 | mix_tracks.clone(), 200 | ref_analysis.clone(), 201 | model, 202 | mix_console, 203 | track_start_idx=0, 204 | ref_start_idx=0, 205 | ) 206 | 207 | ( 208 | pred_mix, 209 | pred_track_param_dict, 210 | pred_fx_bus_param_dict, 211 | pred_master_bus_param_dict, 212 | ) = result 213 | 214 | bs, chs, seq_len = pred_mix.shape 215 | print("pred_mix shape", pred_mix.shape) 216 | # loudness normalize the output mix 217 | mix_lufs_db = meter.integrated_loudness( 218 | pred_mix.squeeze(0).permute(1, 0).numpy() 219 | ) 220 | print("pred_mix_lufs_db", mix_lufs_db) 221 | #print(mix_lufs_db) 222 | lufs_delta_db = target_lufs_db - mix_lufs_db 223 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20) 224 | name = os.path.basename(ref).replace(".wav", "-pred_mix-16.wav") 225 | mix_filepath = os.path.join(example_dir, f"{method_name}_{name}") 226 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100) 227 | 228 | # compute audio features 229 | 230 | AF[example_name]["pred_mix"][method_name]["mix-rms"] = 0.1*compute_rms(pred_mix, sample_rate = sr).mean().detach().cpu().numpy() 231 | AF[example_name]["pred_mix"][method_name]["mix-crest_factor"] = 0.001*compute_crest_factor(pred_mix, sample_rate = sr).mean().detach().cpu().numpy() 232 | AF[example_name]["pred_mix"][method_name]["mix-stereo_width"] = 1.0*compute_stereo_width(pred_mix, sample_rate = sr).detach().cpu().numpy() 233 | AF[example_name]["pred_mix"][method_name]["mix-stereo_imbalance"] = 1.0*compute_stereo_imbalance(pred_mix, sample_rate = sr).detach().cpu().numpy() 234 | AF[example_name]["pred_mix"][method_name]["mix-barkspectrum"] = 0.1*compute_barkspectrum(pred_mix, sample_rate = sr).mean().detach().cpu().numpy() 235 | 236 | AF[example_name]["ref"][method_name]["mix-rms"] = 0.1*compute_rms(ref_analysis, sample_rate = sr).mean().detach().cpu().numpy() 237 | AF[example_name]["ref"][method_name]["mix-crest_factor"] = 0.001*compute_crest_factor(ref_analysis, sample_rate = sr).mean().detach().cpu().numpy() 238 | AF[example_name]["ref"][method_name]["mix-stereo_width"] = 1.0*compute_stereo_width(ref_analysis, sample_rate = sr).detach().cpu().numpy() 239 | AF[example_name]["ref"][method_name]["mix-stereo_imbalance"] = 1.0*compute_stereo_imbalance(ref_analysis, sample_rate = sr).detach().cpu().numpy() 240 | AF[example_name]["ref"][method_name]["mix-barkspectrum"] = 0.1*compute_barkspectrum(ref_analysis, sample_rate = sr).mean().detach().cpu().numpy() 241 | 242 | AF_loss = loss(pred_mix, ref_analysis) 243 | AF[example_name]["pred_mix"][method_name]["net_AF_loss"] = sum(AF_loss.values()).detach().cpu().numpy() 244 | AF[example_name]["ref"][method_name]["net_AF_loss"] = AF[example_name]["pred_mix"][method_name]["net_AF_loss"] 245 | 246 | 247 | # save resulting audio and parameters 248 | 
#append to csv the method name, audio section, audio features values and net loss on different columns 249 | with open(csv_path, 'a') as f: 250 | writer = csv.writer(f) 251 | writer.writerow([method_name, "pred_mix", test_type, track_idx - 220500, track_idx + 220500,AF[example_name]["pred_mix"][method_name]["mix-rms"], AF[example_name]["pred_mix"][method_name]["mix-crest_factor"], AF[example_name]["pred_mix"][method_name]["mix-stereo_width"], AF[example_name]["pred_mix"][method_name]["mix-stereo_imbalance"], AF[example_name]["pred_mix"][method_name]["mix-barkspectrum"], AF[example_name]["pred_mix"][method_name]["net_AF_loss"]]) 252 | writer.writerow([method_name, "ref", test_type, ref_idx - 220500, ref_idx + 220500,AF[example_name]["ref"][method_name]["mix-rms"], AF[example_name]["ref"][method_name]["mix-crest_factor"], AF[example_name]["ref"][method_name]["mix-stereo_width"], AF[example_name]["ref"][method_name]["mix-stereo_imbalance"], AF[example_name]["ref"][method_name]["mix-barkspectrum"], AF[example_name]["ref"][method_name]["net_AF_loss"]]) 253 | f.close() 254 | 255 | 256 | 257 | #write disctionary to json 258 | 259 | 260 | -------------------------------------------------------------------------------- /scripts/eval_all_combo.py: -------------------------------------------------------------------------------- 1 | # run pretrained models over evaluation set to generate audio examples for the listening test 2 | import os 3 | import torch 4 | import torchaudio 5 | import pyloudnorm as pyln 6 | from mst.utils import load_diffmst, run_diffmst 7 | from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss 8 | import json 9 | import numpy as np 10 | import csv 11 | 12 | 13 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs): 14 | 15 | meter = pyln.Meter(44100) 16 | target_lufs_db = -48.0 17 | 18 | norm_tracks = [] 19 | for track_idx in range(tracks.shape[1]): 20 | track = tracks[:, track_idx : track_idx + 1, :] 21 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy()) 22 | 23 | if lufs_db < -80.0: 24 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.") 25 | continue 26 | 27 | lufs_delta_db = target_lufs_db - lufs_db 28 | track *= 10 ** (lufs_delta_db / 20) 29 | norm_tracks.append(track) 30 | 31 | norm_tracks = torch.cat(norm_tracks, dim=1) 32 | # create a sum mix with equal loudness 33 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1) 34 | sum_mix /= sum_mix.abs().max() 35 | 36 | return sum_mix, None, None, None 37 | 38 | class NumpyEncoder(json.JSONEncoder): 39 | """ Special json encoder for numpy types """ 40 | def default(self, obj): 41 | if isinstance(obj, np.integer): 42 | return int(obj) 43 | elif isinstance(obj, np.floating): 44 | return float(obj) 45 | elif isinstance(obj, np.ndarray): 46 | return obj.tolist() 47 | return json.JSONEncoder.default(self, obj) 48 | 49 | if __name__ == "__main__": 50 | meter = pyln.Meter(44100) 51 | target_lufs_db = -22.0 52 | output_dir = "outputs/listen" 53 | os.makedirs(output_dir, exist_ok=True) 54 | 55 | methods = { 56 | "diffmst-16": { 57 | "model": load_diffmst( 58 | "/Users/svanka/Downloads/b4naquji/config.yaml", 59 | "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt", 60 | ), 61 | "func": run_diffmst, 62 | }, 63 | "sum": { 64 | "model": (None, None), 65 | "func": equal_loudness_mix, 66 | }, 67 | } 68 | 69 | # get the validation examples 70 | examples = { 71 | # 
"ecstasy": { 72 | # "tracks": "/Users/svanka/Downloads//diffmst-examples/song1/BenFlowers_Ecstasy_Full/", 73 | # "ref": "/Users/svanka/Downloads//diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme)_01.wav", 74 | # }, 75 | # "by-my-side": { 76 | # "tracks": "/Users/svanka/Downloads//diffmst-examples/song2/Kat Wright_By My Side/", 77 | # "ref": "/Users/svanka/Downloads//diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video)_01.wav", 78 | # }, 79 | "haunted-aged": { 80 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song3/Titanium_HauntedAge_Full/", 81 | "ref": "/Users/svanka/Downloads//diffmst-examples/song3/ref/Architects - _Doomsday__01.wav", 82 | }, 83 | } 84 | loss = AudioFeatureLoss([0.1,0.001,1.0,1.0,0.1], 44100) 85 | AF = {} 86 | #initialise to negative infinity 87 | 88 | for example_name, example in examples.items(): 89 | 90 | AF[example_name] = {} 91 | print(example_name) 92 | example_dir = os.path.join(output_dir, example_name) 93 | os.makedirs(example_dir, exist_ok=True) 94 | json_dir = os.path.join(output_dir, "AF") 95 | if not os.path.exists(json_dir): 96 | os.makedirs(json_dir, exist_ok=True) 97 | csv_path = os.path.join(json_dir,f"{example_name}.csv") 98 | # if not os.path.exists(csv_path): 99 | # os.makedirs(csv_path) 100 | with open(csv_path, 'w') as f: 101 | writer = csv.writer(f) 102 | writer.writerow(["method", "audio_section","track_start_idx", "track_stop_idx", "ref_start_idx", "ref_stop_idx", "rms", "crest_factor", "stereo_width", "stereo_imbalance", "barkspectrum", "net_AF_loss"]) 103 | f.close() 104 | 105 | # ----------load reference mix--------------- 106 | ref_audio, ref_sr = torchaudio.load(example["ref"], backend="soundfile") 107 | if ref_sr != 44100: 108 | ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100) 109 | print(ref_audio.shape, ref_sr) 110 | ref_length = ref_audio.shape[-1] 111 | # --------------first find all the tracks---------------- 112 | track_filepaths = [] 113 | for root, dirs, files in os.walk(example["tracks"]): 114 | for filepath in files: 115 | if filepath.endswith(".wav"): 116 | track_filepaths.append(os.path.join(root, filepath)) 117 | 118 | print(f"Found {len(track_filepaths)} tracks.") 119 | 120 | # ----------------load the tracks---------------------------- 121 | tracks = [] 122 | lengths = [] 123 | for track_idx, track_filepath in enumerate(track_filepaths): 124 | audio, sr = torchaudio.load(track_filepath, backend="soundfile") 125 | 126 | if sr != 44100: 127 | audio = torchaudio.functional.resample(audio, sr, 44100) 128 | 129 | # loudness normalize the tracks to -48 LUFS 130 | lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy()) 131 | # lufs_delta_db = -48 - lufs_db 132 | # audio = audio * 10 ** (lufs_delta_db / 20) 133 | 134 | print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db) 135 | 136 | if audio.shape[0] == 2: 137 | audio = audio.mean(dim=0, keepdim=True) 138 | 139 | chs, seq_len = audio.shape 140 | 141 | for ch_idx in range(chs): 142 | tracks.append(audio[ch_idx : ch_idx + 1, :]) 143 | lengths.append(audio.shape[-1]) 144 | 145 | # find max length and pad if shorter 146 | max_length = max(lengths) 147 | min_length = min(lengths) 148 | for track_idx in range(len(tracks)): 149 | tracks[track_idx] = torch.nn.functional.pad( 150 | tracks[track_idx], (0, max_length - lengths[track_idx]) 151 | ) 152 | 153 | # stack into a tensor 154 | tracks = torch.cat(tracks, dim=0) 155 | tracks = tracks.view(1, -1, max_length) 156 | ref_audio = 
ref_audio.view(1, 2, -1) 157 | 158 | # crop tracks to max of 60 seconds or so 159 | # tracks = tracks[..., :4194304] 160 | tracks_length = max_length 161 | 162 | #print(tracks.shape) 163 | track_start_idx = int(tracks_length / 4) 164 | ref_start_idx = int(ref_length / 4) 165 | track_stop_idx = int(3*tracks_length / 4) 166 | ref_stop_idx = int(3*ref_length / 4) 167 | #find the number of sets of track samples of 10 sec duration each 168 | track_num_sets = int((track_stop_idx - track_start_idx) / 441000) 169 | ref_num_sets = int((ref_stop_idx - ref_start_idx) / 441000) 170 | print("track_num_sets", track_num_sets) 171 | print("ref_num_sets", ref_num_sets) 172 | min_AF_loss = float('inf') 173 | min_AF_loss_example = None 174 | for i in range(track_num_sets): 175 | for j in range(ref_num_sets): 176 | print(f"track-{i}-ref-{j}") 177 | #run inference for every combination of track and ref samples and calculate audio features. 178 | # We will save the audio features to a csv and audio files in the output directory 179 | mix_tracks = tracks[..., track_start_idx + i*441000 : track_start_idx + (i+1)*441000] 180 | ref_analysis = ref_audio[..., ref_start_idx + j*441000 : ref_start_idx + (j+1)*441000] 181 | 182 | # create mixes varying the loudness of the reference 183 | for ref_loudness_target in [-16.0]: 184 | print("Ref loudness", ref_loudness_target) 185 | ref_filepath = os.path.join( 186 | example_dir, 187 | f"ref-analysis-track-{i}-ref-{j}-lufs-{ref_loudness_target:0.0f}.wav", 188 | ) 189 | 190 | # loudness normalize the reference mix section to -14 LUFS 191 | ref_lufs_db = meter.integrated_loudness( 192 | ref_analysis.squeeze().permute(1, 0).numpy() 193 | ) 194 | print("ref_lufs_db", ref_lufs_db) 195 | lufs_delta_db = ref_loudness_target - ref_lufs_db 196 | ref_analysis = ref_analysis * 10 ** (lufs_delta_db / 20) 197 | 198 | torchaudio.save(ref_filepath, ref_analysis.squeeze(), 44100) 199 | 200 | AF_loss = 0 201 | for method_name, method in methods.items(): 202 | AF[example_name][method_name] = {} 203 | print(method_name) 204 | # tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, seq_len) 205 | # ref_audio (torch.Tensor): Reference mix with shape (bs, 2, seq_len) 206 | 207 | if method_name == "sum": 208 | if ref_loudness_target != -16: 209 | continue 210 | 211 | 212 | model, mix_console = method["model"] 213 | func = method["func"] 214 | 215 | #print(tracks.shape, ref_audio.shape) 216 | audio_section = f"track-{i}-ref-{j}-lufs-{ref_loudness_target:0.0f}" 217 | AF[example_name][method_name][audio_section] = {} 218 | AF[example_name][method_name][audio_section]["track_start_idx"] = track_start_idx + i*441000 219 | AF[example_name][method_name][audio_section]["track_stop_idx"] = track_start_idx + (i+1)*441000 220 | AF[example_name][method_name][audio_section]["ref_start_idx"] = ref_start_idx + j*441000 221 | AF[example_name][method_name][audio_section]["ref_stop_idx"] = ref_start_idx + (j+1)*441000 222 | with torch.no_grad(): 223 | result = func( 224 | mix_tracks.clone(), 225 | ref_analysis.clone(), 226 | model, 227 | mix_console, 228 | track_start_idx=0, 229 | ref_start_idx=0, 230 | ) 231 | 232 | ( 233 | pred_mix, 234 | pred_track_param_dict, 235 | pred_fx_bus_param_dict, 236 | pred_master_bus_param_dict, 237 | ) = result 238 | 239 | bs, chs, seq_len = pred_mix.shape 240 | print("pred_mix shape", pred_mix.shape) 241 | # loudness normalize the output mix 242 | mix_lufs_db = meter.integrated_loudness( 243 | pred_mix.squeeze(0).permute(1, 0).numpy() 244 | ) 245 | 
print("pred_mix_lufs_db", mix_lufs_db) 246 | #print(mix_lufs_db) 247 | lufs_delta_db = target_lufs_db - mix_lufs_db 248 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20) 249 | mix_filepath = os.path.join( 250 | example_dir, 251 | f"{example_name}-{method_name}-tracks-{i}-ref={j}-lufs-{ref_loudness_target:0.0f}.wav", 252 | ) 253 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100) 254 | 255 | # compute audio features 256 | AF_loss = loss(pred_mix, ref_analysis) 257 | 258 | for key, value in AF_loss.items(): 259 | AF[example_name][method_name][audio_section][key] = value.detach().cpu().numpy() 260 | AF[example_name][method_name][audio_section]["net_AF_loss"] = sum(AF_loss.values()).detach().cpu().numpy() 261 | print(AF[example_name][method_name][audio_section]) 262 | 263 | if AF[example_name][method_name][audio_section]["net_AF_loss"] < min_AF_loss: 264 | min_AF_loss = AF[example_name][method_name][audio_section]["net_AF_loss"] 265 | min_AF_loss_example = f"{example_name}-{method_name}-{audio_section}" 266 | print("min_AF_loss", min_AF_loss) 267 | print("min_AF_loss_example", min_AF_loss_example) 268 | # save resulting audio and parameters 269 | #append to csv the method name, audio section, audio features values and net loss on different columns 270 | 271 | with open(csv_path, 'a') as f: 272 | writer = csv.writer(f) 273 | writer.writerow([method_name, audio_section, AF[example_name][method_name][audio_section]["track_start_idx"], AF[example_name][method_name][audio_section]["track_stop_idx"], AF[example_name][method_name][audio_section]["ref_start_idx"], AF[example_name][method_name][audio_section]["ref_stop_idx"], AF[example_name][method_name][audio_section]["mix-rms"], AF[example_name][method_name][audio_section]["mix-crest_factor"], AF[example_name][method_name][audio_section]["mix-stereo_width"], AF[example_name][method_name][audio_section]["mix-stereo_imbalance"], AF[example_name][method_name][audio_section]["mix-barkspectrum"], AF[example_name][method_name][audio_section]["net_AF_loss"]]) 274 | f.close() 275 | 276 | 277 | print(f"for {example_name} min loss is {min_AF_loss} corresponding to {min_AF_loss_example}") 278 | print() 279 | 280 | #write disctionary to json 281 | 282 | 283 | -------------------------------------------------------------------------------- /scripts/online.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import argparse 5 | import torchaudio 6 | import numpy as np 7 | import pyloudnorm as pyln 8 | import matplotlib.pyplot as plt 9 | from tqdm import tqdm 10 | 11 | from mst.loss import AudioFeatureLoss, StereoCLAPLoss, compute_barkspectrum 12 | from mst.modules import AdvancedMixConsole 13 | 14 | 15 | def optimize( 16 | tracks: torch.Tensor, 17 | ref_mix: torch.Tensor, 18 | mix_console: torch.nn.Module, 19 | loss_function: torch.nn.Module, 20 | init_scale: float = 0.001, 21 | lr: float = 1e-3, 22 | n_iters: int = 100, 23 | ): 24 | """Create a mix from the tracks that is as close as possible to the reference mixture. 25 | 26 | Args: 27 | tracks (torch.Tensor): Tensor of shape (n_tracks, n_samples). 28 | ref_mix (torch.Tensor): Tensor of shape (2, n_samples). 29 | mix_console (torch.nn.Module): Mix console instance. (e.g. AdvancedMixConsole) 30 | loss_function (torch.nn.Module): Loss function instance. (e.g. AudioFeatureLoss) 31 | n_iters (int): Number of iterations for the optimization. 
32 | 33 | Returns: 34 | torch.Tensor: Tensor of shape (2, n_samples) that is as close as possible to the reference mixture. 35 | """ 36 | loss_history = {"loss": []} # lists to store loss values 37 | 38 | # initialize the mix console parameters to optimize 39 | track_params = init_scale * torch.randn( 40 | tracks.shape[0], mix_console.num_track_control_params 41 | ) 42 | fx_bus_params = init_scale * torch.randn(1, mix_console.num_fx_bus_control_params) 43 | master_bus_params = init_scale * torch.randn( 44 | 1, mix_console.num_master_bus_control_params 45 | ) 46 | 47 | # move parameters to same device as tracks 48 | track_params = track_params.type_as(tracks) 49 | fx_bus_params = fx_bus_params.type_as(tracks) 50 | master_bus_params = master_bus_params.type_as(tracks) 51 | 52 | # require gradients for the parameters 53 | track_params.requires_grad = True 54 | fx_bus_params.requires_grad = True 55 | master_bus_params.requires_grad = True 56 | 57 | # create optimizer and link to console parameters 58 | optimizer = torch.optim.Adam( 59 | [track_params, fx_bus_params, master_bus_params], lr=lr 60 | ) 61 | 62 | pbar = tqdm(range(n_iters)) 63 | 64 | # reshape 65 | tracks = tracks.unsqueeze(0) 66 | ref_mix = ref_mix.unsqueeze(0) 67 | track_params = track_params.unsqueeze(0) 68 | fx_bus_params = fx_bus_params.unsqueeze(0) 69 | master_bus_params = master_bus_params.unsqueeze(0) 70 | 71 | for n in pbar: 72 | optimizer.zero_grad() 73 | 74 | # mix the tracks using the mix console 75 | # the mix console parameters are sigmoided to ensure they are in the range [0, 1] 76 | result = mix_console( 77 | tracks, 78 | torch.sigmoid(track_params), 79 | torch.sigmoid(fx_bus_params), 80 | torch.sigmoid(master_bus_params), 81 | use_fx_bus=False, 82 | ) 83 | mix = result[1] 84 | track_param_dict = result[2] 85 | fx_bus_param_dict = result[3] 86 | master_bus_param_dict = result[4] 87 | 88 | # compute loss 89 | loss = 0 90 | losses = loss_function(mix, ref_mix) 91 | for loss_name, loss_value in losses.items(): 92 | loss += loss_value 93 | 94 | # compute gradients and update parameters 95 | loss.backward() 96 | optimizer.step() 97 | 98 | # update progress bar 99 | pbar.set_description(f"Loss: {loss.item():.4f}") 100 | 101 | # store loss values 102 | loss_history["loss"].append(loss.item()) 103 | for loss_name, loss_value in losses.items(): 104 | if loss_name not in loss_history: 105 | loss_history[loss_name] = [] 106 | loss_history[loss_name].append(loss_value.item()) 107 | 108 | # reshape 109 | mix = mix.squeeze(0) 110 | track_params = track_params # .squeeze(0) 111 | fx_bus_params = fx_bus_params # .squeeze(0) 112 | master_bus_params = master_bus_params # .squeeze(0) 113 | 114 | return ( 115 | mix, 116 | track_params, 117 | track_param_dict, 118 | fx_bus_params, 119 | fx_bus_param_dict, 120 | master_bus_params, 121 | master_bus_param_dict, 122 | loss_history, 123 | ) 124 | 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument( 129 | "--track_dir", 130 | type=str, 131 | help="Path to directory with tracks.", 132 | ) 133 | parser.add_argument( 134 | "--ref_mix", 135 | type=str, 136 | help="Path to reference mixture.", 137 | ) 138 | parser.add_argument( 139 | "--n_iters", 140 | type=int, 141 | help="Number of iterations.", 142 | default=250, 143 | ) 144 | parser.add_argument( 145 | "--lr", 146 | type=float, 147 | help="Learning rate.", 148 | default=1e-3, 149 | ) 150 | parser.add_argument( 151 | "--loss", 152 | type=str, 153 | help="Loss function.", 154 | 
choices=["feat", "clap"], 155 | default="feat", 156 | ) 157 | parser.add_argument( 158 | "--output_dir", 159 | type=str, 160 | help="Path to output directory.", 161 | default="outputs", 162 | ) 163 | parser.add_argument( 164 | "--use_gpu", 165 | action="store_true", 166 | help="Use GPU for optimization.", 167 | ) 168 | parser.add_argument( 169 | "--sample_rate", 170 | type=int, 171 | help="Sample rate of tracks.", 172 | default=44100, 173 | ) 174 | parser.add_argument( 175 | "--block_size", 176 | type=int, 177 | default=524288, 178 | help="Analysis block size.", 179 | ) 180 | parser.add_argument( 181 | "--start_time_s", 182 | type=float, 183 | default=32.0, 184 | help="Analysis block start time.", 185 | ) 186 | parser.add_argument( 187 | "--target_track_lufs_db", 188 | type=float, 189 | default=-48.0, 190 | ) 191 | parser.add_argument( 192 | "--target_mix_lufs_db", 193 | type=float, 194 | default=-14.0, 195 | ) 196 | parser.add_argument( 197 | "--stem_separation", 198 | action="store_true", 199 | ) 200 | parser 201 | args = parser.parse_args() 202 | 203 | meter = pyln.Meter(args.sample_rate) 204 | run_name = f"{os.path.basename(args.track_dir)}-->{os.path.basename(args.ref_mix).split('.')[0]}" 205 | output_dir = os.path.join(args.output_dir, run_name) 206 | os.makedirs(output_dir, exist_ok=True) 207 | os.makedirs(os.path.join(output_dir, "plots"), exist_ok=True) 208 | 209 | # -------------------------- data loading -------------------------- # 210 | # load tracks 211 | tracks = [] 212 | print(f"Loading tracks for current run: {run_name}...") 213 | track_filepaths = sorted(glob.glob(os.path.join(args.track_dir, "*.wav"))) 214 | num_tracks = len(track_filepaths) 215 | for track_idx, track_filepath in enumerate(track_filepaths): 216 | track, track_sample_rate = torchaudio.load(os.path.join(track_filepath)) 217 | print( 218 | f"{track_idx+1}/{num_tracks}: {track.shape} {os.path.basename(track_filepath)}" 219 | ) 220 | 221 | # check if track has same sample rate as reference mixture 222 | # if not, resample 223 | if track_sample_rate != args.sample_rate: 224 | track = torchaudio.transforms.Resample(track_sample_rate, args.sample_rate)( 225 | track 226 | ) 227 | 228 | # check if the track is silent 229 | for ch_idx in range(track.shape[0]): 230 | # measure loudness 231 | track_lufs_db = meter.integrated_loudness(track[ch_idx, :].numpy()) 232 | 233 | if track_lufs_db < -60.0: 234 | print(f"Track is inactive at {track_lufs_db:0.2f} dB. 
Skipping...") 235 | continue 236 | else: 237 | # loudness normalize 238 | delta_lufs_db = args.target_track_lufs_db - track_lufs_db 239 | delta_lufs_lin = 10 ** (delta_lufs_db / 20) 240 | tracks.append(delta_lufs_lin * track[ch_idx, :]) 241 | 242 | tracks = torch.stack(tracks) # shape: (n_tracks, n_samples) 243 | 244 | # load reference mixture 245 | ref_mix, ref_sample_rate = torchaudio.load(args.ref_mix) 246 | 247 | if ref_sample_rate != args.sample_rate: 248 | ref_mix = torchaudio.transforms.Resample(ref_sample_rate, args.sample_rate)( 249 | ref_mix 250 | ) 251 | mix_lufs_db = meter.integrated_loudness(ref_mix.permute(1, 0).numpy()) 252 | delta_lufs_db = args.target_mix_lufs_db - mix_lufs_db 253 | delta_lufs_lin = 10 ** (delta_lufs_db / 20) 254 | ref_mix = delta_lufs_lin * ref_mix 255 | 256 | # use only a subsection of the reference mixture and tracks 257 | # this is to speed up the optimization 258 | start_time_s = args.start_time_s 259 | start_sample = int(start_time_s * args.sample_rate) 260 | 261 | ref_mix_section = ref_mix[:, start_sample : start_sample + args.block_size] 262 | tracks_section = tracks[:, start_sample : start_sample + args.block_size] 263 | 264 | print(ref_mix.shape, tracks.shape) 265 | print(ref_mix_section.shape, tracks_section.shape) 266 | 267 | if args.use_gpu: 268 | ref_mix_section = ref_mix_section.cuda() 269 | tracks_section = tracks_section.cuda() 270 | 271 | # -------------------------- setup -------------------------- # 272 | # mix console will use the same sample rate as the tracks 273 | mix_console = AdvancedMixConsole(args.sample_rate) 274 | 275 | weights = [ 276 | 0.1, # rms 277 | 0.001, # crest factor 278 | 1.0, # stereo width 279 | 1.0, # stereo imbalance 280 | 1.00, # bark spectrum 281 | 100.0, # clap 282 | ] 283 | 284 | if args.loss == "feat": 285 | loss_function = AudioFeatureLoss( 286 | weights, 287 | args.sample_rate, 288 | stem_separation=args.stem_separation, 289 | ) 290 | elif args.loss == "clap": 291 | loss_function = StereoCLAPLoss() 292 | else: 293 | raise ValueError(f"Unknown loss: {args.loss}") 294 | 295 | if args.use_gpu: 296 | loss_function.cuda() 297 | 298 | # -------------------------- optimization -------------------------- # 299 | result = optimize( 300 | tracks_section, 301 | ref_mix_section, 302 | mix_console, 303 | loss_function, 304 | n_iters=args.n_iters, 305 | ) 306 | 307 | mix = result[0] 308 | track_params = result[1] 309 | track_param_dict = result[2] 310 | fx_bus_parms = result[3] 311 | fx_bus_param_dict = result[4] 312 | master_bus_params = result[5] 313 | master_bus_param_dict = result[6] 314 | loss_history = result[7] 315 | 316 | ref_mix = ref_mix.squeeze(0).cpu() 317 | mono_mix = tracks.sum(dim=0).repeat(2, 1).cpu() 318 | 319 | # print(track_param_dict) 320 | # print(fx_bus_param_dict) 321 | # print(master_bus_param_dict) 322 | print(mix.abs().max()) 323 | print(mono_mix.abs().max()) 324 | 325 | # ----------------------- full mix generation ---------------------- # 326 | # iterate over the tracks in blocks to mix the entire song 327 | # this is to avoid memory issues 328 | block_size = args.block_size 329 | n_blocks = tracks.shape[-1] // block_size 330 | full_mix = torch.zeros(2, tracks.shape[-1]) 331 | for block_idx in tqdm(range(n_blocks)): 332 | tracks_block = tracks[:, block_idx * block_size : (block_idx + 1) * block_size] 333 | tracks_block = tracks_block.type_as(tracks_section) 334 | 335 | with torch.no_grad(): 336 | result = mix_console( 337 | tracks_block.unsqueeze(0), 338 | torch.sigmoid(track_params), 339 | 
torch.sigmoid(fx_bus_parms), 340 | torch.sigmoid(master_bus_params), 341 | use_fx_bus=False, 342 | ) 343 | mix_block = result[1].squeeze(0).cpu() 344 | full_mix[ 345 | :, block_idx * block_size : (block_idx + 1) * block_size 346 | ] = mix_block 347 | 348 | # peak normalize the full mix and the mono reference to avoid clipping 349 | full_mix /= full_mix.abs().max() 350 | mono_mix /= mono_mix.abs().max() 351 | 352 | mono_mix_section = mono_mix[:, start_sample : start_sample + args.block_size] 353 | 354 | # -------------------------- analyze mixes -------------------------- # 355 | 356 | ref_spec = compute_barkspectrum(ref_mix.unsqueeze(0), sample_rate=args.sample_rate) 357 | pred_spec = compute_barkspectrum( 358 | full_mix.unsqueeze(0), sample_rate=args.sample_rate 359 | ) 360 | 361 | fig, axs = plt.subplots(2, 1, sharex=True, sharey=True) 362 | axs[0].plot(ref_spec[0, :, 0], label="ref-mid", color="tab:orange") 363 | axs[0].plot(pred_spec[0, :, 0], label="pred-mid", color="tab:blue") 364 | axs[1].plot(ref_spec[0, :, 1], label="ref-side", color="tab:orange") 365 | axs[1].plot(pred_spec[0, :, 1], label="pred-side", color="tab:blue") 366 | axs[0].legend() 367 | axs[1].legend() 368 | plt.savefig(os.path.join(output_dir, "plots", "bark_spectra.png")) 369 | plt.close("all") 370 | 371 | for idx, (loss_name, loss_vals) in enumerate(loss_history.items()): 372 | fig, ax = plt.subplots(1, 1) 373 | ax.plot(loss_vals, label=loss_name) 374 | ax.set_xlabel("Iteration") 375 | ax.set_ylabel(f"{loss_name}") 376 | plt.savefig(os.path.join(output_dir, "plots", f"{loss_name}.png")) 377 | plt.close("all") 378 | 379 | # -------------------------- save results -------------------------- # 380 | # save mix 381 | torchaudio.save( 382 | os.path.join(output_dir, "pred_mix_section.wav"), 383 | mix.squeeze(0).cpu(), 384 | args.sample_rate, 385 | ) 386 | torchaudio.save( 387 | os.path.join(output_dir, "ref_mix_section.wav"), 388 | ref_mix_section.cpu(), 389 | args.sample_rate, 390 | ) 391 | torchaudio.save( 392 | os.path.join(output_dir, "mono_mix_section.wav"), 393 | mono_mix_section, 394 | args.sample_rate, 395 | ) 396 | torchaudio.save( 397 | os.path.join(output_dir, "pred_mix.wav"), 398 | full_mix, 399 | args.sample_rate, 400 | ) 401 | torchaudio.save( 402 | os.path.join(output_dir, "mono_mix.wav"), 403 | mono_mix, 404 | args.sample_rate, 405 | ) 406 | torchaudio.save( 407 | os.path.join(output_dir, "ref_mix.wav"), 408 | ref_mix, 409 | args.sample_rate, 410 | ) 411 | -------------------------------------------------------------------------------- /mst/system.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import yaml 4 | import torch 5 | import auraloss 6 | import pytorch_lightning as pl 7 | 8 | import time 9 | from typing import Callable 10 | from mst.mixing import knowledge_engineering_mix 11 | from mst.utils import batch_stereo_peak_normalize 12 | from mst.fx_encoder import FXencoder 13 | import pyloudnorm as pyln 14 | 15 | 16 | class System(pl.LightningModule): 17 | def __init__( 18 | self, 19 | model: torch.nn.Module, 20 | mix_console: torch.nn.Module, 21 | mix_fn: Callable, 22 | loss: torch.nn.Module, 23 | generate_mix: bool = True, 24 | use_track_loss: bool = False, 25 | use_mix_loss: bool = True, 26 | use_param_loss: bool = False, 27 | instrument_id_json: str = "data/instrument_name2id.json", 28 | knowledge_engineering_yaml: str = "data/knowledge_engineering.yaml", 29 | active_eq_epoch: int = 0, 30 | active_compressor_epoch: int = 0, 31 | active_fx_bus_epoch: int = 0, 32 | 
active_master_bus_epoch: int = 0, 33 | lr: float = 1e-4, 34 | max_epochs: int = 500, 35 | schedule: str = "step", 36 | **kwargs, 37 | ) -> None: 38 | super().__init__() 39 | self.model = model 40 | self.mix_console = mix_console 41 | self.mix_fn = mix_fn 42 | self.loss = loss 43 | self.generate_mix = generate_mix 44 | self.use_track_loss = use_track_loss 45 | self.use_mix_loss = use_mix_loss 46 | self.use_param_loss = use_param_loss 47 | self.active_eq_epoch = active_eq_epoch 48 | self.active_compressor_epoch = active_compressor_epoch 49 | self.active_fx_bus_epoch = active_fx_bus_epoch 50 | self.active_master_bus_epoch = active_master_bus_epoch 51 | 52 | self.meter = pyln.Meter(44100) 53 | 54 | 55 | 56 | self.save_hyperparameters(ignore=["model", "mix_console", "mix_fn", "loss"]) 57 | 58 | 59 | # losses for evaluation 60 | self.sisdr = auraloss.time.SISDRLoss() 61 | self.mrstft = auraloss.freq.MultiResolutionSTFTLoss( 62 | fft_sizes=[512, 2048, 8192], 63 | hop_sizes=[256, 1024, 4096], 64 | win_lengths=[512, 2048, 8192], 65 | w_sc=0.0, 66 | w_phs=0.0, 67 | w_lin_mag=1.0, 68 | w_log_mag=1.0, 69 | ) 70 | 71 | # load configuration files 72 | if mix_fn is knowledge_engineering_mix: 73 | with open(instrument_id_json, "r") as f: 74 | self.instrument_number_lookup = json.load(f) 75 | 76 | with open(knowledge_engineering_yaml, "r") as f: 77 | self.knowledge_engineering_dict = yaml.safe_load(f) 78 | else: 79 | self.instrument_number_lookup = None 80 | self.knowledge_engineering_dict = None 81 | 82 | # default mix console configuration; EQ, compression, and bus processing are enabled later according to the active_*_epoch settings 83 | self.use_track_input_fader = True 84 | self.use_track_panner = True 85 | self.use_track_eq = False 86 | self.use_track_compressor = False 87 | self.use_fx_bus = False 88 | self.use_master_bus = False 89 | self.use_output_fader = True 90 | 91 | def forward(self, tracks: torch.Tensor, ref_mix: torch.Tensor) -> torch.Tensor: 92 | """Apply model to audio waveform tracks. 93 | Args: 94 | tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, 1, seq_len) 95 | ref_mix (torch.Tensor): Reference mix with shape (bs, 2, seq_len) 96 | 97 | Returns: 98 | pred_mix (torch.Tensor): Predicted mix with shape (bs, 2, seq_len) 99 | """ 100 | return self.model(tracks, ref_mix) 101 | 102 | def common_step( 103 | self, 104 | batch: tuple, 105 | batch_idx: int, 106 | train: bool = False, 107 | ): 108 | """Model step used for validation and training. 109 | Args: 110 | batch (tuple): Batch items containing the input tracks, instrument ids, stereo info, track padding mask, reference mix, and song name. 111 | batch_idx (int): Index of the batch within the current epoch. 112 | train (bool): Whether the step is called during training (True) or validation (False). 
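Returns: loss (torch.Tensor): Scalar loss for the step. data_dict (dict): Detached mixes and mix console parameter dictionaries, kept for the plotting and audio logging callbacks.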
114 | """ 115 | 116 | tracks, instrument_id, stereo_info, track_padding, ref_mix, song_name = batch 117 | #print("song_names from this batch: ", song_name) 118 | 119 | # split into A and B sections 120 | middle_idx = tracks.shape[-1] // 2 121 | 122 | # enable parts of the mix console once their scheduled epoch is reached 123 | if self.current_epoch >= self.active_eq_epoch: 124 | self.use_track_eq = True 125 | 126 | if self.current_epoch >= self.active_compressor_epoch: 127 | self.use_track_compressor = True 128 | 129 | if self.current_epoch >= self.active_fx_bus_epoch: 130 | self.use_fx_bus = True 131 | 132 | if self.current_epoch >= self.active_master_bus_epoch: 133 | self.use_master_bus = True 134 | 135 | bs, num_tracks, seq_len = tracks.shape 136 | 137 | # apply random gain to input tracks 138 | # tracks *= 10 ** ((torch.rand(bs, num_tracks, 1).type_as(tracks) * -12.0) / 20.0) 139 | ref_track_param_dict = None 140 | ref_fx_bus_param_dict = None 141 | ref_master_bus_param_dict = None 142 | 143 | # if tracks[...,middle_idx:].sum() == 0: 144 | # print("tracks are zero") 145 | # print(tracks[...,middle_idx:]) 146 | # raise ValueError("input tracks are zero") 147 | 
221 | # --------- create a random mix (on GPU, if applicable) --------- 222 | if self.generate_mix: 223 | ( 224 | ref_mix_tracks, 225 | ref_mix, 226 | ref_track_param_dict, 227 | ref_fx_bus_param_dict, 228 | ref_master_bus_param_dict, 229 | ref_mix_params, 230 | ref_fx_bus_params, 231 | ref_master_bus_params, 232 | ) = self.mix_fn( 233 | tracks, 234 | self.mix_console, 235 | use_track_input_fader=False, # do not use track input fader for training 236 | use_track_panner=self.use_track_panner, 237 | use_track_eq=self.use_track_eq, 238 | use_track_compressor=self.use_track_compressor, 239 | use_fx_bus=self.use_fx_bus, 240 | use_master_bus=self.use_master_bus, 241 | use_output_fader=False, # not used because we normalize output mixes 242 | instrument_id=instrument_id, 243 | stereo_id=stereo_info, 244 | instrument_number_file=self.instrument_number_lookup, 245 | ke_dict=self.knowledge_engineering_dict, 246 | ) 247 | 248 | # normalize the reference mix 249 | ref_mix = batch_stereo_peak_normalize(ref_mix) 250 | 251 | if torch.isnan(ref_mix).any(): 252 | print(ref_track_param_dict) 253 | raise ValueError("Found nan in ref_mix") 254 | 255 | ref_mix_a = ref_mix[..., :middle_idx] # this is passed to the model 256 | ref_mix_b = ref_mix[..., middle_idx:] # this is used for loss computation 257 | # tracks_a = tracks[..., :input_middle_idx] # not used currently 258 | tracks_b = tracks[..., middle_idx:] # this is passed to the model 259 | else: 260 | # when using a real mix, pass the same mix to model and loss 261 | ref_mix_a = ref_mix 262 | ref_mix_b = ref_mix 263 | tracks_b = tracks 264 | 265 | 266 | # ---- run model with tracks from section A using reference mix from section B ---- 267 | ( 268 | pred_track_params, 269 | pred_fx_bus_params, 270 | pred_master_bus_params, 271 | ) = self.model(tracks_b, ref_mix_a, track_padding_mask=track_padding) 272 | 273 | # ------- generate a mix using the predicted mix console parameters ------- 274 | ( 275 | pred_mixed_tracks_b, 276 | pred_mix_b, 277 | pred_track_param_dict, 278 | pred_fx_bus_param_dict, 279 | pred_master_bus_param_dict, 280 | ) = self.mix_console( 281 | tracks_b, 282 | pred_track_params, 283 | pred_fx_bus_params, 284 | pred_master_bus_params, 285 | use_track_input_fader=self.use_track_input_fader, 286 | use_track_panner=self.use_track_panner, 287 | use_track_eq=self.use_track_eq, 288 | use_track_compressor=self.use_track_compressor, 289 | use_fx_bus=self.use_fx_bus, 290 | use_master_bus=self.use_master_bus, 291 | use_output_fader=self.use_output_fader, 292 | ) 293 | 294 | # normalize the predicted mix before computing the loss 295 | # pred_mix_b = batch_stereo_peak_normalize(pred_mix_b) 296 | 297 | 298 | if ref_track_param_dict is None: 299 | ref_track_param_dict = pred_track_param_dict 300 | ref_fx_bus_param_dict = pred_fx_bus_param_dict 301 | ref_master_bus_param_dict = pred_master_bus_param_dict 302 | 303 | # ---------------------------- compute and log loss ------------------------------ 304 | 305 | 306 | #print("pred_mix: ", pred_mix_b) 307 | # if pred_mix_b.sum() == 0: 308 | 309 | #print("pred_track_params: ", pred_track_params) 310 | #print("pred_fx_bus_params: ", pred_fx_bus_params) 311 | #print("pred_master_bus_params: ", pred_master_bus_params) 312 | #print("ref_mix: ",ref_mix_b) 313 | 314 | loss = 0 315 | 316 | #if parameter_loss is being used to train model, no need to generate mix 317 | if self.use_param_loss: 318 | track_param_loss = self.loss(pred_track_params, ref_mix_params) 319 | loss += track_param_loss 320 | if self.use_fx_bus: 
321 | fx_bus_param_loss = self.loss(pred_fx_bus_params, ref_fx_bus_params) 322 | loss += fx_bus_param_loss 323 | if self.use_master_bus: 324 | master_bus_param_loss = self.loss(pred_master_bus_params, ref_master_bus_params) 325 | loss += master_bus_param_loss 326 | 327 | 328 | # add the mix loss on top of any parameter loss computed above 329 | 331 | if self.use_mix_loss: 332 | mix_loss = self.loss(pred_mix_b, ref_mix_b) 333 | 334 | if type(mix_loss) == dict: 335 | for key, val in mix_loss.items(): 336 | loss += val.mean() 337 | else: 338 | loss += mix_loss 339 | 340 | 341 | if type(mix_loss) == dict: 342 | for key, value in mix_loss.items(): 343 | self.log( 344 | ("train" if train else "val") + "/" + key, 345 | value, 346 | on_step=True, 347 | on_epoch=True, 348 | prog_bar=False, 349 | logger=True, 350 | sync_dist=True, 351 | ) 352 | #print(loss) 353 | 354 | 355 | # log the losses 356 | self.log( 357 | ("train" if train else "val") + "/loss", 358 | loss, 359 | on_step=True, 360 | on_epoch=True, 361 | prog_bar=True, 362 | logger=True, 363 | sync_dist=True, 364 | ) 365 | 366 | 367 | # sisdr_error = -self.sisdr(pred_mix_b, ref_mix_b) 368 | # log the SI-SDR error 369 | # self.log( 370 | #     ("train" if train else "val") + "/si-sdr", 371 | #     sisdr_error, 372 | #     on_step=False, 373 | #     on_epoch=True, 374 | #     prog_bar=False, 375 | #     logger=True, 376 | #     sync_dist=True, 377 | # ) 378 | 379 | # mrstft_error = self.mrstft(pred_mix_b, ref_mix_b) 380 | ## log the MR-STFT error 381 | # self.log( 382 | #     ("train" if train else "val") + "/mrstft", 383 | #     mrstft_error, 384 | #     on_step=False, 385 | #     on_epoch=True, 386 | #     prog_bar=False, 387 | #     logger=True, 388 | #     sync_dist=True, 389 | # ) 390 | 391 | # for plotting down the line 392 | sum_mix_b = tracks_b.sum(dim=1, keepdim=True).detach().float().cpu() 393 | sum_mix_b = batch_stereo_peak_normalize(sum_mix_b) 394 | data_dict = { 395 | "ref_mix_a": ref_mix_a.detach().float().cpu(), 396 | "ref_mix_b_norm": ref_mix_b.detach().float().cpu(), 397 | "pred_mix_b_norm": pred_mix_b.detach().float().cpu(), 398 | "sum_mix_b": sum_mix_b, 399 | "ref_track_param_dict": ref_track_param_dict, 400 | "pred_track_param_dict": pred_track_param_dict, 401 | "ref_fx_bus_param_dict": ref_fx_bus_param_dict, 402 | "pred_fx_bus_param_dict": pred_fx_bus_param_dict, 403 | "ref_master_bus_param_dict": ref_master_bus_param_dict, 404 | "pred_master_bus_param_dict": pred_master_bus_param_dict, 405 | } 406 | 407 | return loss, data_dict 408 | 409 | def training_step(self, batch, batch_idx): 410 | loss, data_dict = self.common_step(batch, batch_idx, train=True) 411 | 412 | #print(loss) 413 | return loss 414 | 415 | def validation_step(self, batch, batch_idx): 416 | loss, data_dict = self.common_step(batch, batch_idx, train=False) 417 | return data_dict 418 | 419 | def configure_optimizers(self): 420 | optimizer = torch.optim.Adam( 421 | self.model.parameters(), 422 | lr=self.hparams.lr, 423 | betas=(0.9, 0.999), 424 | ) 425 | 426 | if self.hparams.schedule == "cosine": 427 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 428 | optimizer, T_max=self.hparams.max_epochs 429 | ) 430 | elif self.hparams.schedule == "step": 431 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 432 | optimizer, 433 | [ 434 | int(self.hparams.max_epochs * 0.85), 435 | int(self.hparams.max_epochs * 0.95), 436 | ], 437 | ) 438 | else: 439 | #print(optimizer) 440 | return optimizer 441 | lr_schedulers = {"scheduler": scheduler, "interval": "epoch", "frequency": 1} 442 | 
443 | return [optimizer], lr_schedulers 444 | --------------------------------------------------------------------------------