├── configs
│   ├── optimizer.yaml
│   ├── data
│   │   ├── jamendo.yaml
│   │   ├── medley+cambridge-16.yaml
│   │   ├── medley+cambridge-8.yaml
│   │   ├── medley+cambridge+jamendo-16.yaml
│   │   ├── medley+cambridge+jamendo-8.yaml
│   │   └── medley+cambridge-4.yaml
│   ├── models
│   │   ├── unpaired+feat.yaml
│   │   ├── naive.yaml
│   │   └── naive+feat.yaml
│   └── config.yaml
├── Assets
│   └── diffmst-main_modified.jpg
├── stability.sh
├── .gitignore
├── scripts
│   ├── online.sh
│   ├── info.py
│   ├── compare.py
│   ├── run.py
│   ├── datasets.py
│   ├── gain_testing.py
│   ├── eval_listen.py
│   ├── eval_ablation.py
│   ├── eval_all_combo.py
│   └── online.py
├── tests
│   ├── test_panner.py
│   ├── test_encoder.py
│   ├── test_bus.py
│   ├── test_remix.py
│   ├── test_mix.py
│   ├── test_loss.py
│   ├── test_mst.py
│   ├── test_ke.py
│   ├── test_comp.py
│   ├── test_profile.py
│   ├── test_reverb.py
│   ├── test_dataset.py
│   ├── test_sepremix.py
│   └── test_peq.py
├── main.py
├── mst
│   ├── callbacks
│   │   ├── acc.py
│   │   ├── metrics.py
│   │   ├── plotting.py
│   │   ├── audio.py
│   │   └── mix.py
│   ├── filter.py
│   ├── param_system.py
│   ├── panns.py
│   ├── loss.py
│   ├── fx_encoder.py
│   ├── utils.py
│   └── system.py
├── data
│   ├── console_ranges.yaml
│   └── instrument_name2id.json
├── setup.py
├── requirements.txt
└── README.md
/configs/optimizer.yaml:
--------------------------------------------------------------------------------
1 | optimizer:
2 | class_path: torch.optim.Adam
3 | init_args:
4 | lr: 0.00001
--------------------------------------------------------------------------------
/Assets/diffmst-main_modified.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sai-soum/Diff-MST/HEAD/Assets/diffmst-main_modified.jpg
--------------------------------------------------------------------------------
/stability.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd /scratch
4 | mkdir medleydb
5 | cd medleydb
6 | aws s3 sync s3://stability-aws/MedleyDB ./
7 | tar -xvf MedleyDB_v1.tar
8 | tar -xvf MedleyDB_v2.tar
--------------------------------------------------------------------------------
/configs/data/jamendo.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: mst.dataloader.MixDataModule
3 | init_args:
4 | root_dir: /import/c4dm-datasets-ext/mtg-jamendo
5 | length: 262144
6 | batch_size: 4
7 | num_workers: 4
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | venv/**
2 | env/**
3 | mst/callbacks/trail.ipynb
4 | model_saved/**
5 | data/.ipynb_checkpoints/**
6 | data/*.ipynb
7 |
8 | __pycache__
9 | *.egg-info
10 |
11 | mix_KE_adv/**
12 | .vscode/
13 | logs/**
14 | checkpoints/
15 | debug
16 | *.wav
17 | *.png
18 | data/FXencoder_ps.pt
19 | outputs/**
--------------------------------------------------------------------------------
/scripts/online.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | CUDA_VISIBLE_DEVICES=4 python scripts/online.py \
3 | --track_dir "/import/c4dm-datasets-ext/test-multitracks/Kat Wright_By My Side" \
4 | --ref_mix "/import/c4dm-datasets-ext/diffmst_validation/listening/diffmst-examples_wavref/The Dip - Paddle To The Stars (Lyric Video).wav" \
5 | --use_gpu \
6 | --n_iters 1000 \
7 |     --loss "feat"
8 |     # --stem_separation
--------------------------------------------------------------------------------
/tests/test_panner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dasp_pytorch.functional import stereo_panner
3 |
4 | tracks = torch.ones(1, 4, 1)
5 | print(tracks)
6 | print(tracks.shape)
7 |
8 | param_dict = {"stereo_panner": {"pan": torch.tensor([0.0, 0.5, 1.0, 0.0])}}
9 |
10 | tracks = stereo_panner(tracks, **param_dict["stereo_panner"])
11 |
12 | print(tracks)
13 | print(tracks.shape)
14 | print(tracks.sum(dim=2))
15 |
--------------------------------------------------------------------------------
/tests/test_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mst.panns import TCN
3 | from mst.modules import WaveformTransformerEncoder
4 |
5 |
6 | def count_parameters(model):
7 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
8 |
9 |
10 | # encoder = TCN(n_inputs=2)
11 | encoder = WaveformTransformerEncoder(n_inputs=2)
12 |
13 | print(count_parameters(encoder) / 1e6)
14 |
15 | x = torch.randn(4, 2, 262144)
16 |
17 |
18 | y = encoder(x)
19 | print(y.shape)
20 |
--------------------------------------------------------------------------------
/configs/data/medley+cambridge-16.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: mst.dataloader.MultitrackDataModule
3 | init_args:
4 | track_root_dirs:
5 | - /import/c4dm-datasets-ext/mixing-secrets/
6 | - /import/c4dm-datasets/
7 | metadata_files:
8 | - ./data/cambridge.yaml
9 | - ./data/medley.yaml
10 | length: 262144
11 | min_tracks: 2
12 | max_tracks: 16
13 | batch_size: 1
14 | num_workers: 4
15 | num_train_passes: 20
16 | num_val_passes: 1
17 | train_buffer_size_gb: 4.0
18 | val_buffer_size_gb: 0.2
19 | target_track_lufs_db: -48.0
--------------------------------------------------------------------------------
/configs/data/medley+cambridge-8.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: mst.dataloader.MultitrackDataModule
3 | init_args:
4 | track_root_dirs:
5 | - /import/c4dm-datasets-ext/mixing-secrets/
6 | - /import/c4dm-datasets/
7 | metadata_files:
8 | - ./data/cambridge.yaml
9 | - ./data/medley.yaml
10 | length: 262144
11 | min_tracks: 8
12 | max_tracks: 8
13 | batch_size: 4
14 | num_workers: 4
15 | num_train_passes: 20
16 | num_val_passes: 1
17 | train_buffer_size_gb: 4.0
18 | val_buffer_size_gb: 1.0
19 | target_track_lufs_db: -48.0
--------------------------------------------------------------------------------
/configs/data/medley+cambridge+jamendo-16.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: mst.dataloader.MultitrackDataModule
3 | init_args:
4 | track_root_dirs:
5 | - /import/c4dm-datasets-ext/mixing-secrets/
6 | - /import/c4dm-datasets/
7 | mix_root_dirs:
8 | - /import/c4dm-datasets-ext/mtg-jamendo
9 | metadata_files:
10 | - ./data/cambridge.yaml
11 | - ./data/medley.yaml
12 | length: 262144
13 | min_tracks: 2
14 | max_tracks: 16
15 | batch_size: 1
16 | num_workers: 4
17 | num_train_passes: 20
18 | num_val_passes: 1
19 | train_buffer_size_gb: 4.0
20 | val_buffer_size_gb: 0.2
21 | target_track_lufs_db: -48.0
22 | randomize_ref_mix_gain: true
--------------------------------------------------------------------------------
/configs/data/medley+cambridge+jamendo-8.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: mst.dataloader.MultitrackDataModule
3 | init_args:
4 | track_root_dirs:
5 | - /import/c4dm-datasets-ext/mixing-secrets/
6 | - /import/c4dm-datasets/
7 | mix_root_dirs:
8 | - /import/c4dm-datasets-ext/mtg-jamendo
9 | metadata_files:
10 | - ./data/cambridge.yaml
11 | - ./data/medley.yaml
12 | length: 262144
13 | min_tracks: 2
14 | max_tracks: 8
15 | batch_size: 4
16 | num_workers: 8
17 | num_train_passes: 20
18 | num_val_passes: 1
19 | train_buffer_size_gb: 2.0
20 | val_buffer_size_gb: 0.2
21 | target_mix_lufs_db: -16.0
22 | target_track_lufs_db: -48.0
23 | randomize_ref_mix_gain: false
24 |
--------------------------------------------------------------------------------
/scripts/info.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torchaudio
4 | from tqdm import tqdm
5 |
6 | root_dir = "/import/c4dm-datasets-ext/mixing-secrets/"
7 |
8 | # find all directories containing tracks
9 | song_dirs = glob.glob(os.path.join(root_dir, "*"))
10 |
11 | files = {"stereo": 0, "mono": 0}
12 |
13 | for song_dir in tqdm(song_dirs):
14 | # get all tracks in song dir
15 | track_filepaths = glob.glob(os.path.join(song_dir, "tracks", "**", "*.wav"))
16 |
17 | for track_filepath in track_filepaths:
18 |         # get track info
19 | md = torchaudio.info(track_filepath)
20 |
21 | if md.num_channels == 2:
22 | files["stereo"] += 1
23 | elif md.num_channels == 1:
24 | files["mono"] += 1
25 |
26 | print(files)
27 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # class MyLightningCLI(pl.LightningCLI):
2 | # def add_arguments_to_parser(self, parser):
3 | # parser.link_arguments(
4 | # "model.model.mix_console.num_control_params",
5 | # "model.model.controller.num_control_params",
6 | # apply_on="instantiate",
7 | # )
8 |
9 | import torch
10 | from pytorch_lightning.cli import LightningCLI
11 | from pytorch_lightning.strategies import DDPStrategy
12 |
13 |
14 | def cli_main():
15 | cli = LightningCLI(
16 | save_config_callback=None,
17 | trainer_defaults={
18 | "accelerator": "gpu",
19 | # "strategy": DDPStrategy(find_unused_parameters=True),
20 | "log_every_n_steps": 50,
21 | },
22 | )
23 |
24 |
25 | if __name__ == "__main__":
26 |
27 | cli_main()
28 |
--------------------------------------------------------------------------------
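main.py wires training together with LightningCLI, so a run is launched by composing the YAML files in configs/ (trainer, data, model, optimizer). A minimal sketch of an equivalent programmatic launch, assuming the config paths shipped in this repo; the particular data/model combination is illustrative:

import sys

from main import cli_main

# LightningCLI merges --config files in order, with later files overriding earlier ones.
sys.argv = [
    "main.py",
    "fit",
    "--config", "configs/config.yaml",
    "--config", "configs/data/medley+cambridge-8.yaml",
    "--config", "configs/models/naive+feat.yaml",
    "--config", "configs/optimizer.yaml",
]
cli_main()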
/tests/test_bus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import matplotlib.pyplot as plt
4 |
5 | from dasp_pytorch.functional import stereo_bus
6 |
7 |
8 | bs = 2
9 | chs = 2
10 | seq_len = 262144
11 | sample_rate = 44100
12 |
13 | # x = torch.randn(bs, chs, seq_len)
14 | # x = x / x.abs().max().clamp(min=1e-8)
15 | # x *= 10 ** (-24 / 20)
16 |
17 | # x, sr = torchaudio.load("tests/target-gtr.wav")  # unused: overwritten by the random tensor below
18 | # x = x.unsqueeze(0)
19 |
20 | x = torch.randn(bs, 8, chs, seq_len)
21 | sends_db = torch.tensor([0.0, -3.0, -6.0, -9.0, -12.0, -15.0, -18.0, -21.0]).view(1, 8, 1)
22 | sends_db = sends_db.repeat(bs, 1, 1)
23 |
24 | print(x.shape)
25 |
26 | y = stereo_bus(
27 | x,
28 | sends_db,
29 | )
30 |
31 | print(y.shape)
32 | y /= y.abs().max().clamp(min=1e-8)
33 | torchaudio.save("tests/bus.wav", y[0], sample_rate)  # save the first batch item
34 |
--------------------------------------------------------------------------------
/tests/test_remix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mst.modules import Remixer, AdvancedMixConsole
3 | from mst.dataloader import MixDataset
4 |
5 | if __name__ == "__main__":
6 | root_dir = "/import/c4dm-datasets-ext/mtg-jamendo"
7 | mix_dataset = MixDataset(root_dir, length=262144)
8 | mix_dataloader = torch.utils.data.DataLoader(mix_dataset, batch_size=4)
9 |
10 | mix_console = AdvancedMixConsole(44100)
11 | remixer = Remixer(44100)
12 |
13 | remixer.cuda()
14 |
15 | for batch_idx, batch in enumerate(mix_dataloader):
16 | mix = batch
17 |
18 | mix = mix.cuda()
19 |
20 | # create remix
21 | remix, track_params, fx_bus_params, master_bus_params = remixer(
22 | mix, mix_console
23 | )
24 |
25 | print(batch_idx, mix.abs().max(), remix.abs().max())
26 |
--------------------------------------------------------------------------------
/tests/test_mix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.models.resnet import resnet18
3 | from mst.modules import (
4 | MixStyleTransferModel,
5 | SpectrogramResNetEncoder,
6 | TransformerController,
7 | BasicMixConsole,
8 | AdvancedMixConsole,
9 | )
10 | from tqdm import tqdm
11 | from mst.mixing import naive_random_mix, knowledge_engineering_mix
12 |
13 | sample_rate = 44100
14 | embed_dim = 128
15 | num_control_params = 26
16 |
17 |
18 | mix_console = AdvancedMixConsole(sample_rate)
19 |
20 | for n in tqdm(range(100)):
21 | bs = 8
22 | num_tracks = 4
23 | seq_len = 262144
24 |
25 | tracks = torch.randn(bs, num_tracks, seq_len) * 0.1
26 |
27 |     mixed_tracks, mix, track_param_dict, fx_bus_param_dict, master_bus_param_dict = naive_random_mix(tracks, mix_console)
28 |
29 |     if torch.isnan(mix).any():
30 |         print("NAN")
31 |         print(mix.shape)
32 |         print(torch.isnan(mix).any())
33 | 
34 |         break
35 |
--------------------------------------------------------------------------------
/configs/data/medley+cambridge-4.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | class_path: mst.dataloader.MultitrackDataModule
3 | init_args:
4 | #update the root dirs to the location of the dataset for Cambridge.mt and MEdleyDB. You can use just one dataset if you want. Multiple datasets should be provided as list of diectories.
5 | #corresponding metadata files should be provided as list of yaml files which contain train, val and test splits. We have default splits for Cambridge and MedleyDB in the data folder.
6 | track_root_dirs:
7 | - /import/c4dm-datasets-ext/mixing-secrets/
8 | - /import/c4dm-datasets/
9 | metadata_files:
10 | - ./data/cambridge.yaml
11 | - ./data/medley.yaml
12 | length: 262144
13 | #supports different values for min and max tracks
14 | min_tracks: 4
15 | max_tracks: 4
16 | batch_size: 2
17 | num_workers: 4
18 | num_train_passes: 20
19 | num_val_passes: 1
20 | train_buffer_size_gb: 4.0
21 | val_buffer_size_gb: 1.0
22 | target_track_lufs_db: -48.0
23 | randomize_ref_mix_gain: false
24 |
25 |
--------------------------------------------------------------------------------
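The comments above describe how this data config is meant to be edited. As a quick sanity check, the same class_path/init_args layout can be instantiated directly; a minimal sketch, assuming mst.dataloader.MultitrackDataModule accepts the keyword arguments listed above:

import importlib

import yaml

with open("configs/data/medley+cambridge-4.yaml") as f:
    cfg = yaml.safe_load(f)["data"]

# resolve "mst.dataloader.MultitrackDataModule" and pass init_args as keyword arguments
module_name, class_name = cfg["class_path"].rsplit(".", 1)
datamodule_cls = getattr(importlib.import_module(module_name), class_name)
datamodule = datamodule_cls(**cfg["init_args"])
datamodule.setup("fit")
print(datamodule.train_dataloader())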
/mst/callbacks/acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pytorch_lightning as pl
4 |
5 | from mst.callbacks.plotting import plot_confusion_matrix
6 |
7 |
8 | class ConfusionMatrixCallback(pl.callbacks.Callback):
9 | def __init__(self):
10 | super().__init__()
11 | self.targets = []
12 | self.estimates = []
13 |
14 | def on_validation_batch_end(
15 | self,
16 | trainer,
17 | pl_module,
18 | outputs,
19 | batch,
20 | batch_idx,
21 | dataloader_idx,
22 | ):
23 | """Called when the validation batch ends."""
24 | if outputs is not None:
25 | self.targets.append(outputs["e"])
26 | self.estimates.append(outputs["e_hat"].max(1).indices)
27 |
28 | def on_validation_end(self, trainer, pl_module):
29 |
30 | e = torch.cat(self.targets, dim=0)
31 | e_hat = torch.cat(self.estimates, dim=0)
32 |
33 | trainer.logger.experiment.add_image(
34 | f"confusion_matrix",
35 | plot_confusion_matrix(
36 | e_hat,
37 | e,
38 | labels=pl_module.hparams.effects,
39 | ),
40 | trainer.global_step,
41 | )
42 |
43 | # clear outputs
44 | self.targets = []
45 | self.estimates = []
46 |
--------------------------------------------------------------------------------
/tests/test_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 |
4 | from mst.loss import AudioFeatureLoss, ParameterEstimatorLoss
5 | from mst.loss import (
6 | compute_crest_factor,
7 | compute_melspectrum,
8 | compute_barkspectrum,
9 | compute_rms,
10 | compute_stereo_imbalance,
11 | compute_stereo_width,
12 | )
13 |
14 | transforms = [
15 | compute_rms,
16 | compute_crest_factor,
17 | compute_stereo_width,
18 | compute_stereo_imbalance,
19 | compute_barkspectrum,
20 | ]
21 | weights = [10.0, 0.1, 10.0, 100.0, 0.1]
22 |
23 | sample_rate = 44100
24 |
25 | # loss = AudioFeatureLoss(weights, sample_rate, stem_separation=False)
26 |
27 | ckpt_path = "/import/c4dm-datasets-ext/Diff-MST/DiffMST-Param/0ymfi1pp/checkpoints/epoch=5-step=10842.ckpt"
28 | loss = ParameterEstimatorLoss(ckpt_path)
29 |
30 | # test with audio examples
31 | input, _ = torchaudio.load("outputs/output/pred_mix.wav")
32 | target, _ = torchaudio.load("outputs/output/ref_mix.wav")
33 |
34 |
35 | input = input.unsqueeze(0)
36 | target = target.unsqueeze(0)
37 |
38 | input = input.repeat(4, 1, 1)
39 | target = target.repeat(4, 1, 1)
40 |
41 | # input[0, ...] = 0.0001 * torch.randn_like(input[0, ...])
42 | # target[0, ...] = 0.0001 * torch.randn_like(input[0, ...])
43 |
44 |
45 | loss_val = loss(input, target)
46 |
47 | print(loss_val)
48 |
--------------------------------------------------------------------------------
/tests/test_mst.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.models.resnet import resnet18
3 | from mst.modules import (
4 | MixStyleTransferModel,
5 | SpectrogramResNetEncoder,
6 | TransformerController,
7 | BasicMixConsole,
8 | AdvancedMixConsole,
9 | )
10 |
11 | sample_rate = 44100
12 | embed_dim = 128
13 | num_track_control_params = 27
14 | num_fx_bus_control_params = 25
15 | num_master_bus_control_params = 24
16 | use_fx_bus = True
17 | use_master_bus = True
18 |
19 | track_encoder = SpectrogramResNetEncoder()
20 | mix_encoder = SpectrogramResNetEncoder()
21 | controller = TransformerController(
22 | embed_dim=embed_dim,
23 | num_track_control_params=num_track_control_params,
24 | num_fx_bus_control_params=num_fx_bus_control_params,
25 | num_master_bus_control_params=num_master_bus_control_params,
26 | )
27 |
28 |
29 | mix_console = AdvancedMixConsole(sample_rate)
30 |
31 | model = MixStyleTransferModel(track_encoder, mix_encoder, controller, mix_console)
32 |
33 | bs = 8
34 | num_tracks = 4
35 | seq_len = 262144
36 |
37 | tracks = torch.randn(bs, num_tracks, seq_len)
38 | ref_mix = torch.randn(bs, 2, seq_len)
39 |
40 | (
41 | mixed_tracks,
42 | mix,
43 | track_param_dict,
44 | fx_bus_param_dict,
45 | master_bus_param_dict,
46 | ) = model(tracks, ref_mix)
47 | print(mix.shape)
48 | print(mix)
49 |
--------------------------------------------------------------------------------
/tests/test_ke.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 |
5 | from mst.mixing import knowledge_engineering_mix
6 | from mst.modules import BasicMixConsole, AdvancedMixConsole
7 | from mst.dataloaders.medley import MedleyDBDataset
8 |
9 | dataset = MedleyDBDataset(root_dirs=["/scratch/medleydb/V1"], subset="train")
10 | print(len(dataset))
11 |
12 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
13 | # mix_console = mst.modules.BasicMixConsole(sample_rate=44100.0)
14 | mix_console = AdvancedMixConsole(sample_rate=44100.0)
15 | for i, (track, instrument_id, stereo) in enumerate(dataloader):
16 | print("\n\n mixing")
17 | print("track", track.size())
18 | batch_size, num_tracks, seq_len = track.size()
19 |
20 | print(instrument_id)
21 | print(stereo)
22 |
23 | mix, param_dict = knowledge_engineering_mix(
24 | track, mix_console, instrument_id, stereo
25 | )
26 | print(param_dict)
27 | sum_mix = torch.sum(track, dim=1)
28 | print("mix", mix.size())
29 |
30 | save_dir = "debug"
31 | os.makedirs(save_dir, exist_ok=True)
32 |
33 | # export audio
34 | for j in range(batch_size):
35 | torchaudio.save(os.path.join(save_dir, "mix_" + str(j) + ".wav"), mix[j], 44100)
36 | torchaudio.save(
37 | os.path.join(save_dir, "sum" + str(j) + ".wav"), sum_mix[j], 44100
38 | )
39 |
40 | print("mix", mix.size())
41 | if i == 0:
42 | break
43 |
--------------------------------------------------------------------------------
/data/console_ranges.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | #silence
3 | [instrument_category]:
4 | instruments:
5 | - instrument
6 | gain :
7 | - -24.0
8 | - -24.0
9 | pan :
10 | - 0.0
11 | - 1.0
12 | eq:
13 | eq_lowshelf_gain :
14 | - -3.0
15 | - -1.0
16 | eq_lowshelf_freq :
17 | - 20
18 | - 2000
19 | eq_lowshelf_q :
20 | - 0.1
21 | - 5.0
22 | eq_band0_gain :
23 | - -24
24 | - 24
25 | eq_band0_freq :
26 | - 80
27 | - 2000
28 | eq_band0_q :
29 | - 0.1
30 | - 5.0
31 | eq_band1_gain :
32 | - -24
33 | - 24
34 | eq_band1_freq :
35 | - 2000
36 | - 8000
37 | eq_band1_q :
38 | - 0.1
39 | - 5.0
40 | eq_band2_gain :
41 | - -24
42 | - 24
43 | eq_band2_freq :
44 | - 8000
45 | - 12000
46 | eq_band2_q :
47 | - 0.1
48 | - 5.0
49 | eq_band3_gain :
50 | - 0.0
51 | - 0.0
52 | eq_band3_freq :
53 | - 12000
54 | - 20000
55 | eq_band3_q :
56 | - 0.1
57 | - 5.0
58 | eq_highshelf_gain :
59 | - -24
60 | - 24
61 | eq_highshelf_freq :
62 | - 6000
63 | - 20000
64 | eq_highshelf_q :
65 | - 0.1
66 | - 5.0
67 | compressor:
68 | threshold_db :
69 | - -60.0
70 | - 0.0
71 | ratio:
72 | - 1.0
73 | - 10.0
74 | attack_ms:
75 | - 1.0
76 | - 1000.0
77 | release_ms:
78 | - 1.0
79 | - 1000.0
80 | knee_db:
81 | - 3.0
82 | - 24.0
83 | makeup_gain_db:
84 | - 0.0
85 | - 24.0
86 |
87 |
--------------------------------------------------------------------------------
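Each block in this file gives [min, max] bounds per console parameter for one instrument category (the [instrument_category] entry above is a template). A minimal sketch of turning such bounds into random parameter draws; the dict literal stands in for the parsed YAML, and the actual sampling logic in mst.mixing may differ:

import torch

# [min, max] bounds copied from the template above (gains in dB, pan in 0..1)
ranges = {
    "gain": (-24.0, -24.0),
    "pan": (0.0, 1.0),
    "eq_lowshelf_gain": (-3.0, -1.0),
}

def sample_uniform(lo: float, hi: float) -> torch.Tensor:
    # draw a single value uniformly from [lo, hi]
    return lo + (hi - lo) * torch.rand(1)

params = {name: sample_uniform(lo, hi) for name, (lo, hi) in ranges.items()}
print(params)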
/tests/test_comp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import matplotlib.pyplot as plt
4 |
5 | from dasp_pytorch.functional import compressor
6 |
7 |
8 | bs = 2
9 | chs = 2
10 | seq_len = 262144
11 | sample_rate = 44100
12 |
13 | # x = torch.randn(bs, chs, seq_len)
14 | # x = x / x.abs().max().clamp(min=1e-8)
15 | # x *= 10 ** (-24 / 20)
16 |
17 |
18 | x = torch.zeros(bs, chs, seq_len)
19 |
20 | x[..., 0, 4096:131072] = 1.0
21 | x[..., 1, 16384:65536] = 1.0
22 |
23 |
24 | threshold_db = torch.tensor([-12.0, -6.0])
25 | ratio = torch.tensor([4.0, 4.0])
26 | attack_ms = torch.tensor([100.0, 1000.0])
27 | release_ms = torch.tensor([0.0, 0.0]) # dummy parameter
28 | knee_db = torch.tensor([6.0, 6.0])
29 | makeup_gain_db = torch.tensor([0.0, 0.0])
30 |
31 | threshold_db = threshold_db.view(1, chs).repeat(bs, 1)
32 | ratio = ratio.view(1, chs).repeat(bs, 1)
33 | attack_ms = attack_ms.view(1, chs).repeat(bs, 1)
34 | release_ms = release_ms.view(1, chs).repeat(bs, 1)
35 | knee_db = knee_db.view(1, chs).repeat(bs, 1)
36 | makeup_gain_db = makeup_gain_db.view(1, chs).repeat(bs, 1)
37 |
38 | print(threshold_db.shape)
39 | y = compressor(
40 | x,
41 | sample_rate,
42 | threshold_db,
43 | ratio,
44 | attack_ms,
45 | release_ms,
46 | knee_db,
47 | makeup_gain_db,
48 | )
49 |
50 | print(y.shape)
51 |
52 | fig, axs = plt.subplots(2, 1, figsize=(10, 6))
53 | axs[0].plot(x[0, 0, :].numpy(), label="input")
54 | axs[0].plot(y[0, 0, :].numpy(), label="output")
55 | axs[0].legend()
56 |
57 | axs[1].plot(x[0, 1, :].numpy(), label="input")
58 | axs[1].plot(y[0, 1, :].numpy(), label="output")
59 | axs[1].legend()
60 | plt.grid(c="lightgray")
61 | plt.savefig("compressor.png", dpi=300)
62 |
--------------------------------------------------------------------------------
/configs/models/unpaired+feat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | class_path: mst.system.System
3 | init_args:
4 | generate_mix: false
5 | active_eq_epoch: 0
6 | active_compressor_epoch: 0
7 | active_fx_bus_epoch: 1000
8 | active_master_bus_epoch: 0
9 | mix_fn: mst.mixing.naive_random_mix
10 | mix_console:
11 | class_path: mst.modules.AdvancedMixConsole
12 | init_args:
13 | sample_rate: 44100
14 | input_min_gain_db: -48.0
15 | input_max_gain_db: 48.0
16 | output_min_gain_db: -48.0
17 | output_max_gain_db: 48.0
18 | eq_min_gain_db: -12.0
19 | eq_max_gain_db: 12.0
20 | min_pan: 0.0
21 | max_pan: 1.0
22 | model:
23 | class_path: mst.modules.MixStyleTransferModel
24 | init_args:
25 | track_encoder:
26 | class_path: mst.modules.SpectrogramEncoder
27 | init_args:
28 | embed_dim: 512
29 | n_fft: 2048
30 | hop_length: 512
31 | input_batchnorm: false
32 | mix_encoder:
33 | class_path: mst.modules.SpectrogramEncoder
34 | init_args:
35 | embed_dim: 512
36 | n_fft: 2048
37 | hop_length: 512
38 | input_batchnorm: false
39 | controller:
40 | class_path: mst.modules.TransformerController
41 | init_args:
42 | embed_dim: 512
43 | num_track_control_params: 27
44 | num_fx_bus_control_params: 25
45 | num_master_bus_control_params: 26
46 | num_layers: 12
47 | nhead: 8
48 |
49 | loss:
50 | class_path: mst.loss.AudioFeatureLoss
51 | init_args:
52 | sample_rate: 44100
53 | stem_separation: false
54 | use_clap: false
55 | weights:
56 | - 0.1 # rms
57 | - 0.001 # crest factor
58 | - 1.0 # stereo width
59 | - 1.0 # stereo imbalance
60 | - 0.1 # bark spectrum
61 |
62 |
--------------------------------------------------------------------------------
/mst/callbacks/metrics.py:
--------------------------------------------------------------------------------
1 | from re import X
2 | import numpy as np
3 | import pytorch_lightning as pl
4 |
5 |
6 |
7 |
8 | class LogAudioMetricsCallback(pl.callbacks.Callback):
9 | def __init__(
10 | self,
11 | sample_rate: int = 44100,
12 | ):
13 | super().__init__()
14 | self.sample_rate = sample_rate
15 |
16 | self.metrics = {
17 | "PESQi": PESQi(sample_rate),
18 | "MRSTFTi": MRSTFTi(),
19 | "SISDRi": SISDRi(),
20 | }
21 |
22 | self.outputs = []
23 |
24 | def on_validation_batch_end(
25 | self,
26 | trainer,
27 | pl_module,
28 | outputs,
29 | batch,
30 | batch_idx,
31 | dataloader_idx,
32 | ):
33 | """Called when the validation batch ends."""
34 |
35 | if outputs is not None:
36 | self.outputs.append(outputs)
37 |
38 | def on_validation_end(self, trainer, pl_module):
39 | y_hat_metrics = {
40 | "PESQi": [],
41 | "MRSTFTi": [],
42 | "SISDRi": [],
43 | }
44 | for output in self.outputs:
45 | for metric_name, metric in self.metrics.items():
46 | for batch_idx in range(len(output["y_hat"].shape)):
47 | y_hat = output["y_hat"][batch_idx, ...]
48 | x = output["x"][batch_idx, ...]
49 | y = output["y"][batch_idx, ...]
50 |
51 | try:
52 | val = metric(y_hat, x, y)
53 | y_hat_metrics[metric_name].append(val)
54 | except Exception as e:
55 | print(e)
56 |
57 | # log final mean metrics
58 | for metric_name, metric in y_hat_metrics.items():
59 | val = np.mean(metric)
60 | trainer.logger.experiment.add_scalar(
61 | f"metrics/estimated-{metric_name}", val, trainer.global_step
62 | )
63 |
64 | # clear outputs
65 | self.outputs = []
66 |
--------------------------------------------------------------------------------
/configs/models/naive.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | class_path: mst.system.System
3 | init_args:
4 | generate_mix: true
5 | active_eq_epoch: 0
6 | active_compressor_epoch: 0
7 | active_fx_bus_epoch: 1000
8 | active_master_bus_epoch: 0
9 |
10 | use_track_loss : false
11 | use_mix_loss : true
12 | use_param_loss : false
13 |
14 | mix_fn: mst.mixing.naive_random_mix
15 | mix_console:
16 | class_path: mst.modules.AdvancedMixConsole
17 | init_args:
18 | sample_rate: 44100
19 | input_min_gain_db: -48.0
20 | input_max_gain_db: 48.0
21 | output_min_gain_db: -48.0
22 | output_max_gain_db: 48.0
23 | eq_min_gain_db: -12.0
24 | eq_max_gain_db: 12.0
25 | min_pan: 0.0
26 | max_pan: 1.0
27 | model:
28 | class_path: mst.modules.MixStyleTransferModel
29 | init_args:
30 | track_encoder:
31 | class_path: mst.modules.SpectrogramEncoder
32 | init_args:
33 | embed_dim: 512
34 | n_fft: 2048
35 | hop_length: 512
36 | input_batchnorm: false
37 | mix_encoder:
38 | class_path: mst.modules.SpectrogramEncoder
39 | init_args:
40 | embed_dim: 512
41 | n_fft: 2048
42 | hop_length: 512
43 | input_batchnorm: false
44 | controller:
45 | class_path: mst.modules.TransformerController
46 | init_args:
47 | embed_dim: 512
48 | num_track_control_params: 27
49 | num_fx_bus_control_params: 25
50 | num_master_bus_control_params: 26
51 | num_layers: 12
52 | nhead: 8
53 |
54 | loss:
55 | class_path: auraloss.freq.MultiResolutionSTFTLoss
56 | init_args:
57 | fft_sizes:
58 | - 512
59 | - 2048
60 | - 8192
61 | hop_sizes:
62 | - 256
63 | - 1024
64 | - 4096
65 | win_lengths:
66 | - 512
67 | - 2048
68 | - 8192
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/tests/test_profile.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models as models
3 | from torch.profiler import profile, record_function, ProfilerActivity
4 |
5 | from mst.modules import (
6 | MixStyleTransferModel,
7 | SpectrogramResNetEncoder,
8 | TransformerController,
9 | BasicMixConsole,
10 | AdvancedMixConsole,
11 | )
12 |
13 | if False:
14 | sample_rate = 44100
15 | embed_dim = 128
16 | num_track_control_params = 27
17 | num_fx_bus_control_params = 25
18 | num_master_bus_control_params = 24
19 | use_fx_bus = True
20 | use_master_bus = True
21 |
22 | track_encoder = SpectrogramResNetEncoder()
23 | mix_encoder = SpectrogramResNetEncoder()
24 | controller = TransformerController(
25 | embed_dim=embed_dim,
26 | num_track_control_params=num_track_control_params,
27 | num_fx_bus_control_params=num_fx_bus_control_params,
28 | num_master_bus_control_params=num_master_bus_control_params,
29 | )
30 |
31 | mix_console = AdvancedMixConsole(sample_rate)
32 |
33 | model = MixStyleTransferModel(track_encoder, mix_encoder, controller, mix_console)
34 | model.cuda()
35 |
36 | bs = 1
37 | num_tracks = 8
38 | seq_len = 262144
39 |
40 | tracks = torch.randn(bs, num_tracks, seq_len)
41 | ref_mix = torch.randn(bs, 2, seq_len)
42 |
43 | tracks = tracks.cuda()
44 | ref_mix = ref_mix.cuda()
45 |
46 | with profile(
47 | activities=[
48 | ProfilerActivity.CUDA,
49 | ProfilerActivity.CPU,
50 | ],
51 | profile_memory=True,
52 | record_shapes=True,
53 | with_modules=False,
54 | with_stack=True,
55 | ) as prof:
56 | (
57 | mixed_tracks,
58 | mix,
59 | track_param_dict,
60 | fx_bus_param_dict,
61 | master_bus_param_dict,
62 | ) = model(tracks, ref_mix)
63 |
64 | print(
65 | prof.key_averages(group_by_stack_n=5).table(
66 | sort_by="self_cuda_memory_usage", row_limit=25
67 | )
68 | )
69 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from setuptools import setup, find_packages
3 |
4 | NAME = "DiffMST"
5 | DESCRIPTION = "Mix style transfer with differentiable signal processing"
6 | URL = "https://github.com/sai-soum/DiffMST.git"
7 | EMAIL = "s.s.vanka@qmul.ac.uk"
8 | AUTHOR = "Soumya Sai Vanka"
9 | REQUIRES_PYTHON = ">=3.7.11"
10 | VERSION = "0.0.1"
11 |
12 | HERE = Path(__file__).parent
13 |
14 | try:
15 | with open(HERE / "README.md", encoding="utf-8") as f:
16 | long_description = "\n" + f.read()
17 | except FileNotFoundError:
18 | long_description = DESCRIPTION
19 |
20 | setup(
21 | name=NAME,
22 | version=VERSION,
23 | description=DESCRIPTION,
24 | long_description=long_description,
25 | long_description_content_type="text/markdown",
26 | author=AUTHOR,
27 | author_email=EMAIL,
28 | python_requires=REQUIRES_PYTHON,
29 | url=URL,
30 | packages=[
31 | "mst",
32 | ],
33 | install_requires=[
34 | "auraloss==0.4.0",
35 | "dasp-pytorch==0.0.1",
36 | "librosa",
37 | "matplotlib",
38 | "numpy",
39 | "pedalboard==0.8.7",
40 | "pyloudnorm",
41 | "pytorch_lightning[extra]==2.1.4",
42 | "scipy==1.12.0",
43 | "tensorboard",
44 | "torch==2.2.0",
45 | "torchaudio==2.2.0",
46 | "torchvision==0.17.0",
47 | "tqdm",
48 | "wandb",
49 | ],
50 | extras_require={
51 | "asteroid": ["asteroid-filterbanks>=0.3.2"],
52 | "tests": [
53 | "pytest",
54 | "musdb>=0.4.0",
55 | "museval>=0.4.0",
56 | "asteroid-filterbanks>=0.3.2",
57 | "onnx",
58 | "tqdm",
59 | ],
60 | "stempeg": ["stempeg"],
61 | "evaluation": ["musdb>=0.4.0", "museval>=0.4.0"],
62 | },
63 | # entry_points={"console_scripts": ["umx=openunmix.cli:separate"]},
64 | # packages=find_packages(),
65 | include_package_data=True,
66 | classifiers=[
67 | "Topic :: Multimedia :: Sound/Audio",
68 | "Topic :: Scientific/Engineering",
69 | ],
70 | )
71 |
--------------------------------------------------------------------------------
/tests/test_reverb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import matplotlib.pyplot as plt
4 |
5 | from dasp_pytorch.functional import noise_shaped_reverberation
6 |
7 |
8 | bs = 2
9 | chs = 2
10 | seq_len = 262144
11 | sample_rate = 44100
12 |
13 | # x = torch.randn(bs, chs, seq_len)
14 | # x = x / x.abs().max().clamp(min=1e-8)
15 | # x *= 10 ** (-24 / 20)
16 |
17 | x, sr = torchaudio.load("tests/target-gtr.wav")
18 | x = x.repeat(2, 1)
19 | x = x.unsqueeze(0)
20 | print(x.shape)
21 |
22 | band_gains = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
23 | band_decays = torch.tensor(
24 | [0.8, 0.7, 0.8, 0.6, 0.6, 0.7, 0.8, 0.8, 0.99, 0.8, 0.9, 1.0]
25 | )
26 |
27 | band0_gain = torch.rand(1)
28 | band1_gain = torch.rand(1)
29 | band2_gain = torch.rand(1)
30 | band3_gain = torch.rand(1)
31 | band4_gain = torch.rand(1)
32 | band5_gain = torch.rand(1)
33 | band6_gain = torch.rand(1)
34 | band7_gain = torch.rand(1)
35 | band8_gain = torch.rand(1)
36 | band9_gain = torch.rand(1)
37 | band10_gain = torch.rand(1)
38 | band11_gain = torch.rand(1)
39 | band0_decay = torch.rand(1)
40 | band1_decay = torch.rand(1)
41 | band2_decay = torch.rand(1)
42 | band3_decay = torch.rand(1)
43 | band4_decay = torch.rand(1)
44 | band5_decay = torch.rand(1)
45 | band6_decay = torch.rand(1)
46 | band7_decay = torch.rand(1)
47 | band8_decay = torch.rand(1)
48 | band9_decay = torch.rand(1)
49 | band10_decay = torch.rand(1)
50 | band11_decay = torch.rand(1)
51 |
52 | mix = torch.tensor([0.05])
53 |
54 | y = noise_shaped_reverberation(
55 | x,
56 | sample_rate,
57 | band0_gain,
58 | band1_gain,
59 | band2_gain,
60 | band3_gain,
61 | band4_gain,
62 | band5_gain,
63 | band6_gain,
64 | band7_gain,
65 | band8_gain,
66 | band9_gain,
67 | band10_gain,
68 | band11_gain,
69 | band0_decay,
70 | band1_decay,
71 | band2_decay,
72 | band3_decay,
73 | band4_decay,
74 | band5_decay,
75 | band6_decay,
76 | band7_decay,
77 | band8_decay,
78 | band9_decay,
79 | band10_decay,
80 | band11_decay,
81 | mix,
82 | num_samples=88200,
83 | num_bandpass_taps=1023,
84 | )
85 |
86 | print(y.shape)
87 | y /= y.abs().max().clamp(min=1e-8)
88 | torchaudio.save("tests/reverb.wav", y.view(2, -1), sample_rate)
89 |
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchaudio
4 |
5 | from mst.mixing import naive_random_mix
6 | from mst.modules import AdvancedMixConsole
7 | from mst.dataloader import MultitrackDataModule
8 |
9 | datamodule = MultitrackDataModule(
10 | root_dirs=["/import/c4dm-datasets-ext/mixing-secrets/", "/import/c4dm-datasets/"],
11 | metadata_files=["data/cambridge.yaml", "data/medley.yaml"],
12 | length=262144,
13 | min_tracks=8,
14 | max_tracks=8,
15 | batch_size=2,
16 | num_workers=4,
17 | num_train_passes=20,
18 | num_val_passes=1,
19 | train_buffer_size_gb=0.1,
20 | val_buffer_size_gb=0.1,
21 | target_track_lufs_db=-48.0,
22 | )
23 | datamodule.setup("fit")
24 |
25 | root = "./"
26 | os.makedirs(f"{root}/debug", exist_ok=True)
27 |
28 | train_loader = datamodule.train_dataloader()
29 | mix_fn = naive_random_mix
30 | use_gpu = True
31 |
32 | generate_mix_console = AdvancedMixConsole(44100)
33 |
34 | for idx, batch in enumerate(train_loader):
35 | tracks, instrument_id, stereo_info = batch
36 |
37 | if use_gpu:
38 | tracks = tracks.cuda()
39 |
40 | sum_mix = tracks.sum(dim=1)
41 | # create a random mix (on GPU, if applicable)
42 | (
43 | mixed_tracks,
44 | mix,
45 | track_param_dict,
46 | fx_bus_param_dict,
47 | master_bus_param_dict,
48 | ) = mix_fn(
49 | tracks,
50 | generate_mix_console,
51 | use_track_input_fader=False,
52 | use_output_fader=False,
53 | use_fx_bus=True,
54 | use_master_bus=True,
55 | )
56 |
57 | mix = mix[0, ...]
58 | sum_mix = sum_mix[0, ...]
59 | print(mix.shape, sum_mix.shape)
60 | print(mix.abs().max(), sum_mix.abs().max())
61 | mix = mix.view(2, -1)
62 | sum_mix = sum_mix.repeat(2, 1)
63 |
64 | mix /= mix.abs().max()
65 | sum_mix /= sum_mix.abs().max()
66 |
67 | # split mix into a and b sections
68 | mix_a = mix[:, 0 : mix.shape[1] // 2]
69 | mix_b = mix[:, mix.shape[1] // 2 :]
70 |
71 | torchaudio.save(f"{root}/debug/{idx}_ref_mix_a.wav", mix_a.cpu(), 44100)
72 | torchaudio.save(f"{root}/debug/{idx}_ref_mix_b.wav", mix_b.cpu(), 44100)
73 | torchaudio.save(f"{root}/debug/{idx}_sum_mix.wav", sum_mix.cpu(), 44100)
74 | if idx > 25:
75 | break
76 |
--------------------------------------------------------------------------------
/scripts/compare.py:
--------------------------------------------------------------------------------
1 | # script to compare two mixes based on their features
2 | import os
3 | import torch
4 | import argparse
5 | import torchaudio
6 | import matplotlib.pyplot as plt
7 |
8 | from mst.loss import (
9 | compute_barkspectrum,
10 | compute_crest_factor,
11 | compute_rms,
12 | compute_stereo_imbalance,
13 | compute_stereo_width,
14 | )
15 |
16 | if __name__ == "__main__":
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("input_a", type=str)
19 | parser.add_argument("input_b", type=str)
20 | parser.add_argument("--output_dir", type=str, default="outputs/compare")
21 | args = parser.parse_args()
22 |
23 | input_a_filename = os.path.basename(args.input_a).split(".")[0]
24 | input_b_filename = os.path.basename(args.input_b).split(".")[0]
25 | run_name = f"{input_a_filename}-{input_b_filename}"
26 | output_dir = os.path.join(args.output_dir, run_name)
27 | os.makedirs(output_dir, exist_ok=True)
28 |
29 | # load audio files
30 | input_a, input_a_sample_rate = torchaudio.load(args.input_a)
31 | input_b, input_b_sample_rate = torchaudio.load(args.input_b)
32 |
33 | # -------------- compute features ----------------
34 | a_barkspectrum = compute_barkspectrum(input_a, sample_rate=44100)
35 | b_barkspectrum = compute_barkspectrum(input_b, sample_rate=44100)
36 |
37 | a_crest_factor = compute_crest_factor(input_a)
38 | b_crest_factor = compute_crest_factor(input_b)
39 |
40 | a_rms = compute_rms(input_a)
41 | b_rms = compute_rms(input_b)
42 |
43 | a_stereo_imbalance = compute_stereo_imbalance(input_a)
44 | b_stereo_imbalance = compute_stereo_imbalance(input_b)
45 |
46 | a_stereo_width = compute_stereo_width(input_a)
47 | b_stereo_width = compute_stereo_width(input_b)
48 |
49 | # -------------- plot features ----------------
50 |
51 | fig, axs = plt.subplots(2, 1, sharex=True, sharey=True)
52 | axs[0].plot(a_barkspectrum[0, :, 0], label="A-mid", color="tab:orange")
53 | axs[0].plot(b_barkspectrum[0, :, 0], label="B-mid", color="tab:blue")
54 | axs[1].plot(a_barkspectrum[0, :, 1], label="A-side", color="tab:orange")
55 | axs[1].plot(b_barkspectrum[0, :, 1], label="B-side", color="tab:blue")
56 | axs[0].legend()
57 | axs[1].legend()
58 | plt.savefig(os.path.join(output_dir, f"bark_spectrum.png"))
59 | plt.close("all")
60 |
--------------------------------------------------------------------------------
/configs/models/naive+feat.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | class_path: mst.system.System
3 | init_args:
4 |     # generate random mixes as specified in Method 1 of the paper; True for Method 1, False for Method 2
5 |     generate_mix: True
6 |     # set an epoch value very high to disable that effect during training; fx_bus corresponds to the reverb module
7 | active_eq_epoch: 0
8 | active_compressor_epoch: 0
9 | active_fx_bus_epoch: 1000
10 | active_master_bus_epoch: 0
11 | use_track_loss : false
12 | use_mix_loss : true
13 | use_param_loss : false
14 |
15 | #We generate random mixes using the naive random mix function
16 |
17 | mix_fn: mst.mixing.naive_random_mix
18 | mix_console:
19 | class_path: mst.modules.AdvancedMixConsole
20 | init_args:
21 | sample_rate: 44100
22 | input_min_gain_db: -48.0
23 | input_max_gain_db: 48.0
24 | output_min_gain_db: -48.0
25 | output_max_gain_db: 48.0
26 | eq_min_gain_db: -12.0
27 | eq_max_gain_db: 12.0
28 | min_pan: 0.0
29 | max_pan: 1.0
30 | model:
31 | class_path: mst.modules.MixStyleTransferModel
32 | init_args:
33 | track_encoder:
34 | class_path: mst.modules.SpectrogramEncoder
35 | init_args:
36 | embed_dim: 512
37 | n_fft: 2048
38 | hop_length: 512
39 | input_batchnorm: false
40 | mix_encoder:
41 | class_path: mst.modules.SpectrogramEncoder
42 | init_args:
43 | embed_dim: 512
44 | n_fft: 2048
45 | hop_length: 512
46 | input_batchnorm: false
47 | controller:
48 | class_path: mst.modules.TransformerController
49 | init_args:
50 | embed_dim: 512
51 | num_track_control_params: 27
52 | num_fx_bus_control_params: 25
53 | num_master_bus_control_params: 26
54 | num_layers: 12
55 | nhead: 8
56 |
57 | loss:
58 | class_path: mst.loss.AudioFeatureLoss
59 | init_args:
60 | sample_rate: 44100
61 | stem_separation: false
62 | use_clap: False
63 | weights:
64 | - 0.1 # rms
65 | - 0.001 # crest factor
66 | - 1.0 # stereo width
67 | - 1.0 # stereo imbalance
68 | - 0.1 # bark spectrum
69 | - 1.0 # CLAP
70 |
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
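This config pairs the naive random-mix training strategy (Method 1) with the audio feature loss, whose weights cover the RMS, crest factor, stereo width, stereo imbalance, and Bark spectrum terms. A minimal sketch of calling the loss on a pair of stereo mixes, mirroring tests/test_loss.py; the tensor shapes are illustrative and the keyword names are assumed to match the init_args above:

import torch

from mst.loss import AudioFeatureLoss

loss_fn = AudioFeatureLoss(
    weights=[0.1, 0.001, 1.0, 1.0, 0.1],  # rms, crest factor, width, imbalance, bark
    sample_rate=44100,
    stem_separation=False,
)

pred_mix = torch.randn(1, 2, 262144) * 0.1  # (batch, channels, samples)
ref_mix = torch.randn(1, 2, 262144) * 0.1
print(loss_fn(pred_mix, ref_mix))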
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 42
2 | # use ckpt_path to load a model from a checkpoint
3 | # ckpt_path: /import/c4dm-datasets-ext/diffmst_logs_soum/2021-10-06/14-00-00/checkpoints/epoch=0-step=0.ckpt
4 | trainer:
5 | logger:
6 | class_path: pytorch_lightning.loggers.WandbLogger
7 | init_args:
8 | project: DiffMST
9 | #change to the directory where you want to save wandb logs
10 | save_dir: /import/c4dm-datasets-ext/diffmst_logs_soum
11 | enable_checkpointing: true
12 | callbacks:
13 | - class_path: mst.callbacks.audio.LogAudioCallback
14 | - class_path: pytorch_lightning.callbacks.ModelSummary
15 | init_args:
16 | max_depth: 2
17 | #uncomment if you want to run validation on custom examples during training
18 | # - class_path: mst.callbacks.mix.LogReferenceMix
19 | # init_args:
20 | # root_dirs:
21 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full
22 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/Soren_ALittleLate_Full
23 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full
24 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/MR0903_Moosmusic_Full
25 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/SaturnSyndicate_CatchTheWave_Full
26 | # ref_mixes:
27 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Harry Styles - Late Night Talking (Official Video).wav
28 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song1/ref/Poom - Les Voiles (Official Audio).wav
29 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Justin Timberlake - Can't Stop The Feeling! [Lyrics].wav
30 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song2/ref/Taylor Swift - Shake It Off.wav
31 | # - /import/c4dm-datasets-ext/diffmst_validation/validation_set/song3/ref/Miley Cyrus - Wrecking Ball (Lyrics).wav
32 | default_root_dir: null
33 | gradient_clip_val: 10.0
34 | devices: 1
35 | check_val_every_n_epoch: 1
36 | max_epochs: 800
37 | #change to log less or more often to wandb
38 | log_every_n_steps: 500
39 | accelerator: gpu
40 | strategy: ddp_find_unused_parameters_true
41 | sync_batchnorm: true
42 | precision: 32
43 | enable_model_summary: true
44 | num_sanity_val_steps: 2
45 | benchmark: true
46 | accumulate_grad_batches: 1
47 | #reload_dataloaders_every_n_epochs: 1
48 |
49 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | antlr4-python3-runtime==4.9.3
5 | attrs==23.2.0
6 | audioread==3.0.1
7 | auraloss==0.4.0
8 | bitsandbytes==0.41.1
9 | certifi==2024.7.4
10 | cffi==1.16.0
11 | charset-normalizer==3.3.2
12 | click==8.1.7
13 | contourpy==1.2.1
14 | cycler==0.12.1
15 | dasp-pytorch==0.0.1
16 | decorator==5.1.1
17 | -e git+ssh://git@github.com/sai-soum/Diff-MST.git@f65df9e178b7e4af33c6fb3e728452024fd8169b#egg=DiffMST
18 | docker-pycreds==0.4.0
19 | docstring_parser==0.16
20 | filelock==3.15.4
21 | fonttools==4.53.1
22 | frozenlist==1.4.1
23 | fsspec==2024.6.1
24 | future==1.0.0
25 | gitdb==4.0.11
26 | GitPython==3.1.43
27 | grpcio==1.65.0
28 | hydra-core==1.3.2
29 | idna==3.7
30 | importlib_resources==6.4.0
31 | Jinja2==3.1.4
32 | joblib==1.4.2
33 | jsonargparse==4.31.0
34 | kiwisolver==1.4.5
35 | lazy_loader==0.4
36 | librosa==0.10.2.post1
37 | lightning-utilities==0.11.3.post0
38 | llvmlite==0.43.0
39 | Markdown==3.6
40 | markdown-it-py==3.0.0
41 | MarkupSafe==2.1.5
42 | matplotlib==3.9.1
43 | mdurl==0.1.2
44 | mpmath==1.3.0
45 | msgpack==1.0.8
46 | multidict==6.0.5
47 | networkx==3.3
48 | numba==0.60.0
49 | numpy==1.26.4
50 | nvidia-cublas-cu12==12.1.3.1
51 | nvidia-cuda-cupti-cu12==12.1.105
52 | nvidia-cuda-nvrtc-cu12==12.1.105
53 | nvidia-cuda-runtime-cu12==12.1.105
54 | nvidia-cudnn-cu12==8.9.2.26
55 | nvidia-cufft-cu12==11.0.2.54
56 | nvidia-curand-cu12==10.3.2.106
57 | nvidia-cusolver-cu12==11.4.5.107
58 | nvidia-cusparse-cu12==12.1.0.106
59 | nvidia-nccl-cu12==2.19.3
60 | nvidia-nvjitlink-cu12==12.5.82
61 | nvidia-nvtx-cu12==12.1.105
62 | omegaconf==2.3.0
63 | packaging==24.1
64 | pedalboard==0.8.7
65 | pillow==10.4.0
66 | platformdirs==4.2.2
67 | pooch==1.8.2
68 | protobuf==4.25.3
69 | psutil==6.0.0
70 | pycparser==2.22
71 | Pygments==2.18.0
72 | pyloudnorm==0.1.1
73 | pyparsing==3.1.2
74 | python-dateutil==2.9.0.post0
75 | pytorch-lightning==2.1.4
76 | PyYAML==6.0.1
77 | requests==2.32.3
78 | rich==13.7.1
79 | scikit-learn==1.5.1
80 | scipy==1.12.0
81 | sentry-sdk==2.9.0
82 | setproctitle==1.3.3
83 | six==1.16.0
84 | smmap==5.0.1
85 | soundfile==0.12.1
86 | soxr==0.3.7
87 | sympy==1.13.0
88 | tensorboard==2.17.0
89 | tensorboard-data-server==0.7.2
90 | tensorboardX==2.6.2.2
91 | threadpoolctl==3.5.0
92 | torch==2.2.0
93 | torchaudio==2.2.0
94 | torchmetrics==1.4.0.post0
95 | torchvision==0.17.0
96 | tqdm==4.66.4
97 | triton==2.2.0
98 | typeshed_client==2.5.1
99 | typing_extensions==4.12.2
100 | urllib3==2.2.2
101 | wandb==0.17.4
102 | Werkzeug==3.0.3
103 | yarl==1.9.4
104 |
--------------------------------------------------------------------------------
/tests/test_sepremix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 |
4 | from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
5 | from mst.modules import AdvancedMixConsole
6 |
7 |
8 | class Remixer(torch.nn.Module):
9 | def __init__(self, sample_rate: int) -> None:
10 | super().__init__()
11 | self.sample_rate = sample_rate
12 |
13 | # load source separation model
14 | bundle = HDEMUCS_HIGH_MUSDB_PLUS
15 | self.stem_separator = bundle.get_model()
16 | self.stem_separator.eval()
17 | # get sources list
18 | self.sources_list = list(self.stem_separator.sources)
19 |
20 | # load mix console
21 | self.mix_console = AdvancedMixConsole(sample_rate)
22 |
23 | def forward(self, x: torch.Tensor):
24 | """Take a tensor of mixes, separate, and then remix.
25 |
26 | Args:
27 | x (torch.Tensor): Tensor of mixes with shape (batch, 2, samples)
28 |
29 | Returns:
30 | remix (torch.Tensor): Tensor of remixes with shape (batch, 2, samples)
31 |             sum_mix (torch.Tensor): Tensor of mixes from separated outputs with shape (batch, 2, samples)
32 | track_params (torch.Tensor): Tensor of track params with shape (batch, 8, num_track_control_params)
33 | fx_bus_params (torch.Tensor): Tensor of fx bus params with shape (batch, num_fx_bus_control_params)
34 | master_bus_params (torch.Tensor): Tensor of master bus params with shape (batch, num_master_bus_control_params)
35 | """
36 | bs, chs, seq_len = x.size()
37 |
38 | # separate
39 | sources = self.stem_separator(x) # bs, 4, 2, seq_len
40 | sum_mix = sources.sum(dim=1) # bs, 2, seq_len
41 |
42 | # convert sources to mono tracks
43 | tracks = sources.view(bs, 8, -1)
44 |
45 | # provide some headroom before mixing
46 | tracks *= 10 ** (-32.0 / 20.0)
47 |
48 | # generate random mix parameters
49 | track_params = torch.rand(bs, 8, self.mix_console.num_track_control_params)
50 | fx_bus_params = torch.rand(bs, self.mix_console.num_fx_bus_control_params)
51 | master_bus_params = torch.rand(
52 | bs, self.mix_console.num_master_bus_control_params
53 | )
54 |
55 | # the forward expects params in range of (0,1)
56 | result = self.mix_console(
57 | tracks,
58 | track_params,
59 | fx_bus_params,
60 | master_bus_params,
61 | )
62 |
63 | # get the remix
64 | remix = result[1]
65 |
66 | return remix, sum_mix, track_params, fx_bus_params, master_bus_params
67 |
68 |
69 | if __name__ == "__main__":
70 | # create the remixer
71 | remixer = Remixer(44100)
72 |
73 | # get a mix
74 | mix, sample_rate = torchaudio.load("outputs/output/ref_mix.wav")
75 |
76 | mix = mix.unsqueeze(0)
77 |
78 | # remix
79 |     remix, sum_mix, track_params, fx_bus_params, master_bus_params = remixer(mix)
80 |
81 | # peak normalize
82 | remix = remix / remix.abs().max()
83 |
84 | torchaudio.save("outputs/output/remix.wav", remix.squeeze(0), sample_rate=44100)
85 | torchaudio.save(
86 | "outputs/output/separated_sum_mix.wav", sum_mix.squeeze(0), sample_rate=44100
87 | )
88 |
--------------------------------------------------------------------------------
/data/instrument_name2id.json:
--------------------------------------------------------------------------------
1 | {
2 | "silence": 0,
3 | "accordion": 1,
4 | "acoustic guitar": 2,
5 | "alto saxophone": 3,
6 | "auxiliary percussion": 4,
7 | "bamboo flute": 5,
8 | "banjo": 6,
9 | "baritone saxophone": 7,
10 | "bass clarinet": 8,
11 | "bass drum": 9,
12 | "bassoon": 10,
13 | "bongo": 11,
14 | "brass section": 12,
15 | "cabasa": 13,
16 | "castanet": 14,
17 | "cello": 15,
18 | "cello section": 16,
19 | "chimes": 17,
20 | "claps": 18,
21 | "clarinet": 19,
22 | "clarinet section": 20,
23 | "clean electric guitar": 21,
24 | "cornet": 22,
25 | "cowbell": 23,
26 | "cymbal": 24,
27 | "darbuka": 25,
28 | "distorted electric guitar": 26,
29 | "dizi": 27,
30 | "double bass": 28,
31 | "doumbek": 29,
32 | "drum machine": 30,
33 | "drum set": 31,
34 | "electric bass": 32,
35 | "electric piano": 33,
36 | "electronic organ": 34,
37 | "erhu": 35,
38 | "euphonium": 36,
39 | "female singer": 37,
40 | "flute": 38,
41 | "flute section": 39,
42 | "french horn": 40,
43 | "french horn section": 41,
44 | "fx/processed sound": 42,
45 | "glockenspiel": 43,
46 | "gong": 44,
47 | "gu": 45,
48 | "guiro": 46,
49 | "guzheng": 47,
50 | "harmonica": 48,
51 | "harp": 49,
52 | "high hat": 50,
53 | "horn section": 51,
54 | "kick drum": 52,
55 | "lap steel guitar": 53,
56 | "liuqin": 54,
57 | "male rapper": 55,
58 | "male singer": 56,
59 | "male speaker": 57,
60 | "mandolin": 58,
61 | "melodica": 59,
62 | "oboe": 60,
63 | "oud": 61,
64 | "piano": 62,
65 | "piccolo": 63,
66 | "sampler": 64,
67 | "scratches": 65,
68 | "shaker": 66,
69 | "sleigh bells": 68,
70 | "snare drum": 69,
71 | "soprano saxophone": 70,
72 | "string section": 71,
73 | "synthesizer": 72,
74 | "tabla": 73,
75 | "tack piano": 74,
76 | "tambourine": 75,
77 | "tenor saxophone": 76,
78 | "timpani": 77,
79 | "toms": 78,
80 | "trombone": 79,
81 | "trombone section": 80,
82 | "trumpet": 81,
83 | "trumpet section": 82,
84 | "tuba": 83,
85 | "vibraphone": 84,
86 | "viola": 85,
87 | "viola section": 86,
88 | "violin": 87,
89 | "violin section": 88,
90 | "vocalists": 89,
91 | "whistle": 90,
92 | "yangqin": 91,
93 | "zhongruan": 92,
94 | "Main system": 93,
95 | "backingvox": 94,
96 | "vox": 95,
97 | "synth": 96,
98 | "loop": 97,
99 | "percussion": 98,
100 | "kick": 99,
101 | "sfx": 100,
102 | "drum": 101,
103 | "overhead": 102,
104 | "snare": 103,
105 | "bass": 104,
106 | "conga": 105,
107 | "guitar": 106,
108 | "hi hat": 107,
109 | "clap": 108,
110 | "tom": 109,
111 | "misc": 110,
112 | "roommic": 111,
113 | "organ": 112,
114 | "keys": 113,
115 | "vocals": 114,
116 | "horn": 115,
117 | "bell": 116,
118 | "fiddle": 117,
119 | "string": 118,
120 | "stick": 119,
121 | "brass": 120,
122 | "choir": 121,
123 | "crash": 122,
124 | "rhodes": 123,
125 | "noise": 124,
126 | "hit": 125,
127 | "ukelele": 126,
128 | "bagpipe": 127,
129 | "saxophone": 128,
130 | "vocoder": 129,
131 | "marimba": 130,
132 | "cajon": 131,
133 | "soprano": 132,
134 | "alto": 133,
135 | "tenor": 134,
136 | "triangle": 135,
137 | "chorus": 136,
138 | " gong": 137,
139 | "ukulele": 138,
140 | "bongos": 139,
141 | "quartet": 140,
142 | "xylophone": 141,
143 | "woodwind": 142,
144 | "sitar": 143,
145 | "Main System": 144
146 | }
--------------------------------------------------------------------------------
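This table maps instrument names found in the multitrack metadata to the integer ids consumed by the knowledge-engineering mix. A minimal lookup sketch, assuming unknown names fall back to the "silence" id; the repo's actual matching logic in the dataloader may be more involved:

import json

with open("data/instrument_name2id.json") as f:
    name2id = json.load(f)

def instrument_id(name: str) -> int:
    # unknown instrument names fall back to "silence" (id 0)
    return name2id.get(name, name2id["silence"])

print(instrument_id("electric bass"))  # 32
print(instrument_id("theremin"))  # 0, not in the table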
/tests/test_peq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | import matplotlib.pyplot as plt
4 | from scipy.signal import savgol_filter
5 | from dasp_pytorch.functional import parametric_eq
6 |
7 |
8 | bs = 2
9 | chs = 2
10 | seq_len = 131072
11 | sample_rate = 44100
12 |
13 | x = torch.randn(bs, chs, seq_len)  # use noise so the EQ response is visible in the spectra below
14 | x = x / x.abs().max().clamp(min=1e-8)
15 | x *= 10 ** (-24 / 20)
16 |
17 | low_shelf_gain_db = torch.tensor([0.0, 0.0])
18 | low_shelf_cutoff_freq = torch.tensor([20.0, 20.0])
19 | low_shelf_q_factor = torch.tensor([0.707, 0.707])
20 | band0_gain_db = torch.tensor([0.0, 12.0])
21 | band0_cutoff_freq = torch.tensor([100.0, 100.0])
22 | band0_q_factor = torch.tensor([0.707, 0.707])
23 | band1_gain_db = torch.tensor([12.0, 0.0])
24 | band1_cutoff_freq = torch.tensor([1000.0, 1000.0])
25 | band1_q_factor = torch.tensor([6.0, 0.707])
26 | band2_gain_db = torch.tensor([-6.0, 0.0])
27 | band2_cutoff_freq = torch.tensor([10000.0, 10000.0])
28 | band2_q_factor = torch.tensor([3.0, 0.707])
29 | band3_gain_db = torch.tensor([0.0, 0.0])
30 | band3_cutoff_freq = torch.tensor([12000.0, 12000.0])
31 | band3_q_factor = torch.tensor([0.707, 0.707])
32 | high_shelf_gain_db = torch.tensor([0.0, 0.0])
33 | high_shelf_cutoff_freq = torch.tensor([12000.0, 12000.0])
34 | high_shelf_q_factor = torch.tensor([0.707, 0.707])
35 |
36 | # reshape and repeat to match batch size
37 | low_shelf_gain_db = low_shelf_gain_db.view(1, chs).repeat(bs, 1)
38 | low_shelf_cutoff_freq = low_shelf_cutoff_freq.view(1, chs).repeat(bs, 1)
39 | low_shelf_q_factor = low_shelf_q_factor.view(1, chs).repeat(bs, 1)
40 | band0_gain_db = band0_gain_db.view(1, chs).repeat(bs, 1)
41 | band0_cutoff_freq = band0_cutoff_freq.view(1, chs).repeat(bs, 1)
42 | band0_q_factor = band0_q_factor.view(1, chs).repeat(bs, 1)
43 | band1_gain_db = band1_gain_db.view(1, chs).repeat(bs, 1)
44 | band1_cutoff_freq = band1_cutoff_freq.view(1, chs).repeat(bs, 1)
45 | band1_q_factor = band1_q_factor.view(1, chs).repeat(bs, 1)
46 | band2_gain_db = band2_gain_db.view(1, chs).repeat(bs, 1)
47 | band2_cutoff_freq = band2_cutoff_freq.view(1, chs).repeat(bs, 1)
48 | band2_q_factor = band2_q_factor.view(1, chs).repeat(bs, 1)
49 | band3_gain_db = band3_gain_db.view(1, chs).repeat(bs, 1)
50 | band3_cutoff_freq = band3_cutoff_freq.view(1, chs).repeat(bs, 1)
51 | band3_q_factor = band3_q_factor.view(1, chs).repeat(bs, 1)
52 | high_shelf_gain_db = high_shelf_gain_db.view(1, chs).repeat(bs, 1)
53 | high_shelf_cutoff_freq = high_shelf_cutoff_freq.view(1, chs).repeat(bs, 1)
54 | high_shelf_q_factor = high_shelf_q_factor.view(1, chs).repeat(bs, 1)
55 |
56 | y = parametric_eq(
57 | x,
58 | sample_rate,
59 | low_shelf_gain_db,
60 | low_shelf_cutoff_freq,
61 | low_shelf_q_factor,
62 | band0_gain_db,
63 | band0_cutoff_freq,
64 | band0_q_factor,
65 | band1_gain_db,
66 | band1_cutoff_freq,
67 | band1_q_factor,
68 | band2_gain_db,
69 | band2_cutoff_freq,
70 | band2_q_factor,
71 | band3_gain_db,
72 | band3_cutoff_freq,
73 | band3_q_factor,
74 | high_shelf_gain_db,
75 | high_shelf_cutoff_freq,
76 | high_shelf_q_factor,
77 | )
78 |
79 | print(y)
80 | print(y.shape)
81 |
82 | fig, axs = plt.subplots(chs, 1, figsize=(10, 6))
83 | for ch in range(chs):
84 | h_in = 20 * torch.log10(torch.fft.rfft(x[0, ch, :], dim=-1).abs() + 1e-8)
85 | h_out = 20 * torch.log10(torch.fft.rfft(y[0, ch, :], dim=-1).abs() + 1e-8)
86 |
87 | h_in_sm = savgol_filter(h_in.squeeze().numpy(), 255, 3)
88 | h_out_sm = savgol_filter(h_out.squeeze().numpy(), 255, 3)
89 |
90 |     axs[ch].plot(h_out_sm - h_in_sm, label="EQ response (output - input)")
91 | axs[ch].set_xscale("log")
92 | axs[ch].legend()
93 |
94 | plt.savefig("test-peq.png", dpi=300)
95 |
96 | # torchaudio.save("test-peq-in.wav", x.view(1, -1), sample_rate)
97 | # torchaudio.save("test-peq-out.wav", y.view(1, -1), sample_rate)
98 |
--------------------------------------------------------------------------------
/mst/callbacks/plotting.py:
--------------------------------------------------------------------------------
1 | import io
2 | import torch
3 | import librosa
4 | import PIL.Image
5 | import numpy as np
6 | import librosa.display
7 | import matplotlib.pyplot as plt
8 |
9 | from typing import Any
10 | from torch.functional import Tensor
11 | from torchvision.transforms import ToTensor
12 | from sklearn.metrics import ConfusionMatrixDisplay
13 |
14 |
15 | def plot_spectrograms(
16 | input: torch.Tensor,
17 | target: torch.Tensor,
18 | estimate: torch.Tensor,
19 | n_fft: int = 4096,
20 | hop_length: int = 1024,
21 | sample_rate: float = 48000,
22 | filename: Any = None,
23 | ):
24 |     """Create a stacked plot of the input, target, and estimate spectrograms.
25 | Args:
26 | input (torch.Tensor): Input audio tensor with shape [1 x samples].
27 | target (torch.Tensor): Target audio tensor with shape [1 x samples].
28 | estimate (torch.Tensor): Estimate of the target audio with shape [1 x samples].
29 | n_fft (int, optional): Analysis FFT size.
30 | hop_length (int, optional): Analysis hop length.
31 | sample_rate (float, optional): Audio sample rate.
32 | filename (str, optional): If a filename is supplied, the plot is saved to disk.
33 | """
34 | # use librosa to take stft
35 | x_stft = librosa.stft(
36 | input.view(-1).numpy(),
37 | n_fft=n_fft,
38 | hop_length=hop_length,
39 | )
40 | x_D = librosa.amplitude_to_db(
41 | np.abs(x_stft),
42 | ref=np.max,
43 | )
44 |
45 | y_stft = librosa.stft(
46 | target.view(-1).numpy(),
47 | n_fft=n_fft,
48 | hop_length=hop_length,
49 | )
50 | y_D = librosa.amplitude_to_db(
51 | np.abs(y_stft),
52 | ref=np.max,
53 | )
54 |
55 | y_hat_stft = librosa.stft(
56 | estimate.view(-1).numpy(),
57 | n_fft=n_fft,
58 | hop_length=hop_length,
59 | )
60 | y_hat_D = librosa.amplitude_to_db(
61 | np.abs(y_hat_stft),
62 | ref=np.max,
63 | )
64 |
65 | fig, axs = plt.subplots(
66 | nrows=3,
67 | sharex=True,
68 | figsize=(7, 6),
69 | )
70 |
71 | x_img = librosa.display.specshow(
72 | x_D,
73 | y_axis="log",
74 | x_axis="time",
75 | sr=sample_rate,
76 | hop_length=hop_length,
77 | ax=axs[0],
78 | )
79 |
80 | y_img = librosa.display.specshow(
81 | y_D,
82 | y_axis="log",
83 | x_axis="time",
84 | sr=sample_rate,
85 | hop_length=hop_length,
86 | ax=axs[1],
87 | )
88 |
89 | y_hat_img = librosa.display.specshow(
90 | y_hat_D,
91 | y_axis="log",
92 | x_axis="time",
93 | sr=sample_rate,
94 | hop_length=hop_length,
95 | ax=axs[2],
96 | )
97 |
98 | plt.tight_layout()
99 |
100 | if filename is not None:
101 | plt.savefig(filename, dpi=300)
102 |
103 | return fig2img(fig)
104 |
105 |
106 | def plot_confusion_matrix(e_hat, e, labels=None, filename=None):
107 | fig, ax = plt.subplots(figsize=(10, 10))
108 | cm = ConfusionMatrixDisplay.from_predictions(
109 | e,
110 | e_hat,
111 | labels=np.arange(len(labels)),
112 | display_labels=labels,
113 | )
114 | cm.plot(ax=ax, xticks_rotation="vertical")
115 |
116 | plt.tight_layout()
117 | if filename is not None:
118 | plt.savefig(filename, dpi=300)
119 |
120 | return fig2img(fig)
121 |
122 |
123 | def fig2img(fig, dpi=120):
124 |     """Convert a matplotlib figure to an image tensor so it can be shown in TensorBoard."""
125 | buf = io.BytesIO()
126 | fig.savefig(buf, format="jpeg", dpi=dpi)
127 | buf.seek(0)
128 | image = PIL.Image.open(buf)
129 | image = ToTensor()(image)
130 | plt.close("all")
131 | return image
132 |
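# Example usage (sketch): the tensors below are random placeholders; real calls pass
# [1 x samples] waveforms from the validation loop.
# x = torch.randn(1, 48000)
# img = plot_spectrograms(x, x, x, sample_rate=48000, filename="spectrograms.png")
# `img` is an image tensor (via fig2img) ready for TensorBoard/W&B logging.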
--------------------------------------------------------------------------------
/mst/callbacks/audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import wandb
3 | import numpy as np
4 | import pytorch_lightning as pl
5 |
6 |
7 | from mst.callbacks.plotting import plot_spectrograms
8 |
9 |
10 | class LogAudioCallback(pl.callbacks.Callback):
11 | def __init__(
12 | self,
13 | num_batches: int = 8,
14 | peak_normalize: bool = True,
15 | sample_rate: int = 44100,
16 | ):
17 | super().__init__()
18 | self.num_batches = num_batches
19 | self.peak_normalize = peak_normalize
20 | self.sample_rate = sample_rate
21 |
22 | def on_validation_batch_end(
23 | self,
24 | trainer,
25 | pl_module,
26 | outputs,
27 | batch,
28 | batch_idx,
29 | ):
30 | """Called when the validation batch ends."""
31 | if outputs is not None:
32 | num_examples = outputs["ref_mix_a"].shape[0]
33 | if batch_idx < self.num_batches:
34 | for sample_idx in range(num_examples):
35 | self.log_audio(
36 | outputs,
37 | batch_idx,
38 | sample_idx,
39 | pl_module.mix_console.sample_rate,
40 | trainer.global_step,
41 | trainer.logger,
42 | f"Epoch {trainer.current_epoch}",
43 | )
44 |
45 | def log_audio(
46 | self,
47 | outputs,
48 | batch_idx: int,
49 | sample_idx: int,
50 | sample_rate: int,
51 | global_step: int,
52 | logger,
53 | caption: str,
54 | n_fft: int = 4096,
55 | hop_length: int = 1024,
56 | ):
57 | audio_files = []
58 | audio_keys = []
59 | total_samples = 0
60 | # put all audio in file
61 | for key, audio in outputs.items():
62 | if "dict" in key: # skip parameters
63 | continue
64 |
65 | x = audio[sample_idx, ...].float()
66 | x = x.permute(1, 0)
67 | # x /= x.abs().max()
68 | audio_files.append(x)
69 | audio_keys.append(key)
70 | total_samples += x.shape[0]
71 |
72 | y = torch.zeros(total_samples + int(len(audio_keys) * sample_rate), 2)
73 | name = f"{batch_idx}_{sample_idx}"
74 | start = 0
75 | for x, key in zip(audio_files, audio_keys):
76 | end = start + x.shape[0]
77 | y[start:end, :] = x
78 | start = end + int(sample_rate)
79 | name += key + "-"
80 |
81 | logger.experiment.log(
82 | {
83 | f"{name}": wandb.Audio(
84 | y.numpy(),
85 | caption=caption,
86 | sample_rate=int(sample_rate),
87 | )
88 | }
89 | )
90 |
91 | # now try to log parameters
92 | pred_track_param_dict = outputs["pred_track_param_dict"]
93 | ref_track_param_dict = outputs["ref_track_param_dict"]
94 |
95 | pred_fx_bus_param_dict = outputs["pred_fx_bus_param_dict"]
96 | ref_fx_bus_param_dict = outputs["ref_fx_bus_param_dict"]
97 |
98 | pred_master_bus_param_dict = outputs["pred_master_bus_param_dict"]
99 | ref_master_bus_param_dict = outputs["ref_master_bus_param_dict"]
100 |
101 | effect_names = list(pred_track_param_dict.keys())
102 |
103 | column_names = None
104 | rows = []
105 | for effect_name in effect_names:
106 | param_names = list(pred_track_param_dict[effect_name].keys())
107 | for param_name in param_names:
108 | pred_param_val = pred_track_param_dict[effect_name][param_name]
109 | ref_param_val = ref_track_param_dict[effect_name][param_name]
110 |
111 | row = []
112 | row_name = f"{effect_name}.{param_name}"
113 | row.append(row_name)
114 |
115 | if column_names is None:
116 | column_names = ["parameter"]
117 | for i in range(pred_param_val.shape[1]):
118 | column_names.append(f"{i}_pred")
119 | column_names.append(f"{i}_ref")
120 | # column_names.append("master_bus_pred")
121 | # column_names.append("master_bus_ref")
122 |
123 | for i in range(pred_param_val.shape[1]):
124 | row.append(pred_param_val[sample_idx, i].item())
125 | row.append(ref_param_val[sample_idx, i].item())
126 |
127 | # row.append(pred_master_bus_param_dict[effect_name][batch_idx].item())
128 |
129 | rows.append(row)
130 |
131 | wandb_table = wandb.Table(data=rows, columns=column_names)
132 | logger.experiment.log(
133 | {f"batch={batch_idx}_sample={sample_idx}_parameters": wandb_table}
134 | )
135 |
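# Example usage (sketch): attach the callback when building the Lightning Trainer.
# A WandbLogger is assumed here since wandb.Audio and wandb.Table are used above;
# the project name is a placeholder.
# trainer = pl.Trainer(
#     logger=pl.loggers.WandbLogger(project="diff-mst"),
#     callbacks=[LogAudioCallback(num_batches=4, sample_rate=44100)],
# )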
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import torchaudio
5 | import pyloudnorm as pyln
6 | import pytorch_lightning as pl
7 |
8 | from mst.system import System
9 | from mst.utils import load_model
10 |
11 | # load a pretrained model and create a mix
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "config_path",
17 | type=str,
18 | help="Path to config.yaml for pretrained model checkpoint.",
19 | )
20 | parser.add_argument(
21 | "ckpt_path",
22 | type=str,
23 | help="Path to pretrained model checkpoint.",
24 | )
25 | parser.add_argument(
26 | "track_dir",
27 | type=str,
28 | help="Path to directory containing tracks.",
29 | )
30 | parser.add_argument(
31 | "ref_mix",
32 | type=str,
33 | help="Path to reference mix.",
34 | )
35 | parser.add_argument(
36 | "--output_dir",
37 | type=str,
38 | help="Path to output directory.",
39 | default="output",
40 | )
41 | parser.add_argument(
42 | "--use_gpu",
43 | action="store_true",
44 | help="Whether to use GPU.",
45 | )
46 | parser.add_argument(
47 | "--target_track_lufs_db",
48 | type=float,
49 | default=-32.0,
50 | )
51 | args = parser.parse_args()
52 |
53 | # load model
54 | model = load_model(
55 | args.config_path,
56 | args.ckpt_path,
57 |         map_location="cuda" if args.use_gpu else "cpu",
58 | )
59 | sample_rate = 44100
60 | meter = pyln.Meter(sample_rate)
61 |
62 | print(f"Loaded model: {os.path.basename(args.ckpt_path)}\n")
63 |
64 | # load multitracks (wav files only)
65 | track_paths = [
66 | os.path.join(args.track_dir, f)
67 | for f in os.listdir(args.track_dir)
68 | if ".wav" in f
69 | ]
70 |
71 | tracks = []
72 | max_track_len = 0
73 | print("Loading tracks...")
74 | for idx, track_path in enumerate(track_paths):
75 | track, track_sr = torchaudio.load(
76 | track_path, frame_offset=262144, num_frames=262144 * 2
77 | )
78 | if track_sr != sample_rate:
79 | track = torchaudio.functional.resample(track, track_sr, sample_rate)
80 |
81 | track = track[:, :262144]
82 |
83 | # loudness normalization
84 | track_lufs_db = meter.integrated_loudness(track.permute(1, 0).numpy())
85 | delta_lufs_db = torch.tensor(
86 | [args.target_track_lufs_db - track_lufs_db]
87 | ).float()
88 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
89 | track = gain_lin * track
90 |
91 | tracks.append(track)
92 | print(
93 | f"({idx+1}/{len(track_paths)}): {os.path.basename(track_path)} {track.shape}"
94 | )
95 |
96 | # correct length of tracks to be the same
97 | max_track_len = max([t.shape[1] for t in tracks])
98 | for idx, track in enumerate(tracks):
99 | chs = track.shape[0]
100 | if track.shape[1] < max_track_len:
101 | pad = torch.zeros((chs, max_track_len - track.shape[1]))
102 | tracks[idx] = torch.cat([track, pad], dim=1)
103 |
104 | tracks = torch.cat(tracks, dim=0)
105 |
106 | # load reference track
107 | ref_mix, ref_sr = torchaudio.load(args.ref_mix)
108 | if ref_sr != sample_rate:
109 | ref_mix = torchaudio.functional.resample(ref_mix, ref_sr, sample_rate)
110 | ref_mix = ref_mix[:, :262144]
111 | print(f"\nLoaded reference mix: {os.path.basename(args.ref_mix)}.")
112 | print(f"tracks: {tracks.shape} ref_mix: {ref_mix.shape}\n")
113 |
114 | # create sum mix
115 | sum_mix = torch.sum(tracks, dim=0, keepdim=True)
116 |
117 | # create mix with model
118 | with torch.no_grad():
119 | result = model(
120 | tracks.unsqueeze(0),
121 | ref_mix.unsqueeze(0),
122 | use_track_gain=True,
123 | use_track_panner=True,
124 | use_track_eq=False,
125 | use_track_compressor=False,
126 | use_fx_bus=False,
127 | use_master_bus=False,
128 | )
129 | (
130 | mixed_tracks,
131 | mix,
132 | track_param_dict,
133 | fx_bus_param_dict,
134 | master_bus_param_dict,
135 | ) = result
136 | mix = mix.squeeze(0)
137 |
138 | mix /= torch.max(torch.abs(mix)) # peak normalize
139 | sum_mix /= torch.max(torch.abs(sum_mix)) # peak normalize
140 |
141 | # save mix
142 | os.makedirs(args.output_dir, exist_ok=True)
143 | torchaudio.save(os.path.join(args.output_dir, "pred_mix.wav"), mix, sample_rate)
144 | torchaudio.save(os.path.join(args.output_dir, "ref_mix.wav"), ref_mix, sample_rate)
145 | torchaudio.save(os.path.join(args.output_dir, "sum_mix.wav"), sum_mix, sample_rate)
146 | print(f"Saved mixes to {args.output_dir}.\n")
147 |
--------------------------------------------------------------------------------
/scripts/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import torchaudio
5 | from tqdm import tqdm
6 |
7 |
8 | def process_mixing_secrets():
9 | out_dir = "/import/c4dm-datasets-ext/mixing-secrets-mono/"
10 |     # resample the Mixing Secrets multitracks to 48 kHz and split stereo files into mono
11 | root_dir = "/import/c4dm-datasets-ext/mixing-secrets/"
12 |
13 | song_dirs = glob.glob(os.path.join(root_dir, "*"))
14 |
15 | for song_dir in song_dirs:
16 | song_out_dir = os.path.join(out_dir, os.path.basename(song_dir))
17 | if os.path.exists(song_out_dir):
18 | continue
19 | os.makedirs(song_out_dir, exist_ok=True)
20 | song_track_dir = os.path.join(song_dir, "tracks")
21 | song_out_track_dir = os.path.join(song_out_dir, "tracks")
22 | os.makedirs(song_out_track_dir, exist_ok=True)
23 | print(song_track_dir)
24 |
25 | song_track_sub_dirs = glob.glob(os.path.join(song_track_dir, "*"))
26 |
27 | track_load_dir = None
28 | for song_track_sub_dir in song_track_sub_dirs:
29 | if "_Full" in song_track_sub_dir:
30 | track_load_dir = song_track_sub_dir
31 |
32 | if track_load_dir is None:
33 | track_load_dir = song_track_sub_dirs[0]
34 |
35 | track_filepaths = glob.glob(os.path.join(track_load_dir, "*.wav"))
36 |
37 | for filepath in tqdm(track_filepaths):
38 | out_filepath = os.path.join(song_out_track_dir, os.path.basename(filepath))
39 |             # load the track and resample to 48 kHz with torchaudio
40 | try:
41 | x, sr = torchaudio.load(filepath)
42 | except Exception as e:
43 | print(e)
44 | continue
45 |
46 | if sr != 48000:
47 | x = torchaudio.functional.resample(x, sr, 48000)
48 |
49 | if x.shape[0] == 2:
50 | torchaudio.save(
51 | out_filepath.replace(".wav", "_L.wav"),
52 | x[0:1, :],
53 | 48000,
54 | encoding="PCM_S",
55 | bits_per_sample=16,
56 | )
57 | torchaudio.save(
58 | out_filepath.replace(".wav", "_R.wav"),
59 | x[1:2, :],
60 | 48000,
61 | encoding="PCM_S",
62 | bits_per_sample=16,
63 | )
64 | else:
65 | torchaudio.save(
66 | out_filepath,
67 | x,
68 | 48000,
69 | encoding="PCM_S",
70 | bits_per_sample=16,
71 | )
72 |
73 |
74 | def process_medleydb():
75 | out_dir = "/import/c4dm-datasets-ext/medleydb-mono/"
76 |     # resample the MedleyDB raw tracks to 48 kHz and split stereo files into mono
77 | root_dirs = [
78 | "/import/c4dm-datasets/MedleyDB_V1/V1",
79 | "/import/c4dm-datasets/MedleyDB_V2/V2",
80 | ]
81 |
82 | for root_dir in root_dirs:
83 | song_dirs = glob.glob(os.path.join(root_dir, "*"))
84 |
85 | for song_dir in song_dirs:
86 | print(song_dir)
87 | song_out_dir = os.path.join(out_dir, os.path.basename(song_dir))
88 | # if os.path.exists(song_out_dir):
89 | # continue
90 | os.makedirs(song_out_dir, exist_ok=True)
91 | song_track_dir = os.path.join(song_dir, f"{os.path.basename(song_dir)}_RAW")
92 |             song_out_track_dir = os.path.join(song_out_dir, "tracks")
93 | os.makedirs(song_out_track_dir, exist_ok=True)
94 | print(song_track_dir)
95 |
96 | track_filepaths = glob.glob(os.path.join(song_track_dir, "*.wav"))
97 |
98 | for filepath in tqdm(track_filepaths):
99 | out_filepath = os.path.join(
100 | song_out_track_dir, os.path.basename(filepath)
101 | )
102 |                 # load the track and resample to 48 kHz with torchaudio
103 | try:
104 | x, sr = torchaudio.load(filepath)
105 | except Exception as e:
106 | print(e)
107 | continue
108 |
109 | if sr != 48000:
110 | x = torchaudio.functional.resample(x, sr, 48000)
111 |
112 | if x.shape[0] == 2:
113 | torchaudio.save(
114 | out_filepath.replace(".wav", "_L.wav"),
115 | x[0:1, :],
116 | 48000,
117 | encoding="PCM_S",
118 | bits_per_sample=16,
119 | )
120 | torchaudio.save(
121 | out_filepath.replace(".wav", "_R.wav"),
122 | x[1:2, :],
123 | 48000,
124 | encoding="PCM_S",
125 | bits_per_sample=16,
126 | )
127 | else:
128 | torchaudio.save(
129 | out_filepath,
130 | x,
131 | 48000,
132 | encoding="PCM_S",
133 | bits_per_sample=16,
134 | )
135 |
136 |
137 | if __name__ == "__main__":
138 | dataset = "medleydb"
139 |
140 | if dataset == "mixing-secrets":
141 | process_mixing_secrets()
142 | elif dataset == "medleydb":
143 | process_medleydb()
144 | else:
145 | raise NotImplementedError
146 |
--------------------------------------------------------------------------------
/mst/filter.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import warnings
4 |
5 | # https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/functional/functional.py#L495
6 |
7 |
8 | def _create_triangular_filterbank(
9 | all_freqs: torch.Tensor,
10 | f_pts: torch.Tensor,
11 | ) -> torch.Tensor:
12 | """Create a triangular filter bank.
13 |
14 | Args:
15 | all_freqs (Tensor): STFT freq points of size (`n_freqs`).
16 | f_pts (Tensor): Filter mid points of size (`n_filter`).
17 |
18 | Returns:
19 | fb (Tensor): The filter bank of size (`n_freqs`, `n_filter`).
20 | """
21 |     # Adapted from Librosa
22 | # calculate the difference between each filter mid point and each stft freq point in hertz
23 | f_diff = f_pts[1:] - f_pts[:-1] # (n_filter + 1)
24 | slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_filter + 2)
25 | # create overlapping triangles
26 | zero = torch.zeros(1)
27 | down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_filter)
28 | up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_filter)
29 | fb = torch.max(zero, torch.min(down_slopes, up_slopes))
30 |
31 | return fb
32 |
33 |
34 | # https://github.com/pytorch/audio/blob/d9942bae249329bd8c8bf5c92f0f108595fcb84f/torchaudio/prototype/functional/functional.py#L6
35 |
36 |
37 | def _hz_to_bark(freqs: float, bark_scale: str = "traunmuller") -> float:
38 | r"""Convert Hz to Barks.
39 |
40 | Args:
41 | freqs (float): Frequencies in Hz
42 | bark_scale (str, optional): Scale to use: ``traunmuller``, ``schroeder`` or ``wang``. (Default: ``traunmuller``)
43 |
44 | Returns:
45 | barks (float): Frequency in Barks
46 | """
47 |
48 | if bark_scale not in ["schroeder", "traunmuller", "wang"]:
49 | raise ValueError(
50 | 'bark_scale should be one of "schroeder", "traunmuller" or "wang".'
51 | )
52 |
53 | if bark_scale == "wang":
54 | return 6.0 * math.asinh(freqs / 600.0)
55 | elif bark_scale == "schroeder":
56 | return 7.0 * math.asinh(freqs / 650.0)
57 | # Traunmuller Bark scale
58 | barks = ((26.81 * freqs) / (1960.0 + freqs)) - 0.53
59 | # Bark value correction
60 | if barks < 2:
61 | barks += 0.15 * (2 - barks)
62 | elif barks > 20.1:
63 | barks += 0.22 * (barks - 20.1)
64 |
65 | return barks
66 |
67 |
68 | def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.Tensor:
69 | """Convert bark bin numbers to frequencies.
70 |
71 | Args:
72 | barks (torch.Tensor): Bark frequencies
73 | bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
74 |
75 | Returns:
76 | freqs (torch.Tensor): Barks converted in Hz
77 | """
78 |
79 | if bark_scale not in ["schroeder", "traunmuller", "wang"]:
80 | raise ValueError(
81 | 'bark_scale should be one of "traunmuller", "schroeder" or "wang".'
82 | )
83 |
84 | if bark_scale == "wang":
85 | return 600.0 * torch.sinh(barks / 6.0)
86 | elif bark_scale == "schroeder":
87 | return 650.0 * torch.sinh(barks / 7.0)
88 | # Bark value correction
89 | if any(barks < 2):
90 | idx = barks < 2
91 | barks[idx] = (barks[idx] - 0.3) / 0.85
92 | elif any(barks > 20.1):
93 | idx = barks > 20.1
94 | barks[idx] = (barks[idx] + 4.422) / 1.22
95 |
96 | # Traunmuller Bark scale
97 | freqs = 1960 * ((barks + 0.53) / (26.28 - barks))
98 |
99 | return freqs
100 |
101 |
102 | def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
103 | a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
104 | return torch.log2(freqs / (a440 / 16))
105 |
106 |
107 | def barkscale_fbanks(
108 | n_freqs: int,
109 | f_min: float,
110 | f_max: float,
111 | n_barks: int,
112 | sample_rate: int,
113 | bark_scale: str = "traunmuller",
114 | ) -> torch.Tensor:
115 | r"""Create a frequency bin conversion matrix.
116 |
117 | .. devices:: CPU
118 |
119 | .. properties:: TorchScript
120 |
121 | .. image:: https://download.pytorch.org/torchaudio/doc-assets/bark_fbanks.png
122 | :alt: Visualization of generated filter bank
123 |
124 | Args:
125 | n_freqs (int): Number of frequencies to highlight/apply
126 | f_min (float): Minimum frequency (Hz)
127 | f_max (float): Maximum frequency (Hz)
128 |         n_barks (int): Number of Bark filterbanks
129 | sample_rate (int): Sample rate of the audio waveform
130 | bark_scale (str, optional): Scale to use: ``traunmuller``,``schroeder`` or ``wang``. (Default: ``traunmuller``)
131 |
132 | Returns:
133 | torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_barks``)
134 | meaning number of frequencies to highlight/apply to x the number of filterbanks.
135 | Each column is a filterbank so that assuming there is a matrix A of
136 | size (..., ``n_freqs``), the applied result would be
137 | ``A * barkscale_fbanks(A.size(-1), ...)``.
138 |
139 | """
140 |
141 | # freq bins
142 | all_freqs = torch.linspace(0, sample_rate // 2, n_freqs)
143 |
144 | # calculate bark freq bins
145 | m_min = _hz_to_bark(f_min, bark_scale=bark_scale)
146 | m_max = _hz_to_bark(f_max, bark_scale=bark_scale)
147 |
148 | m_pts = torch.linspace(m_min, m_max, n_barks + 2)
149 | f_pts = _bark_to_hz(m_pts, bark_scale=bark_scale)
150 |
151 | # create filterbank
152 | fb = _create_triangular_filterbank(all_freqs, f_pts)
153 |
154 | if (fb.max(dim=0).values == 0.0).any():
155 | warnings.warn(
156 | "At least one bark filterbank has all zero values. "
157 | f"The value for `n_barks` ({n_barks}) may be set too high. "
158 | f"Or, the value for `n_freqs` ({n_freqs}) may be set too low."
159 | )
160 |
161 | return fb
162 |
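# Example usage (sketch): build a 24-band Bark filterbank for a 4096-point FFT at
# 44.1 kHz and apply it to a magnitude spectrogram (the `magnitude_spec` variable is
# a placeholder with shape (..., n_freqs)).
# fb = barkscale_fbanks(n_freqs=2049, f_min=20.0, f_max=20000.0, n_barks=24, sample_rate=44100)
# bark_spec = torch.matmul(magnitude_spec, fb)  # (..., 2049) -> (..., 24)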
--------------------------------------------------------------------------------
/mst/param_system.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import itertools
4 | import pytorch_lightning as pl
5 |
6 | from typing import Callable
7 | from mst.utils import batch_stereo_peak_normalize
8 |
9 | import warnings
10 |
11 | warnings.filterwarnings(
12 | "ignore"
13 | ) # fix this later to catch warnings about reading mp3 files
14 |
15 |
16 | class ParameterEstimationSystem(pl.LightningModule):
17 | def __init__(
18 | self,
19 | encoder: torch.nn.Module,
20 | projector: torch.nn.Module,
21 | mix_console: torch.nn.Module,
22 | remixer: torch.nn.Module,
23 | schedule: str = "step",
24 | lr: float = 3e-4,
25 | max_epochs: int = 500,
26 | **kwargs,
27 | ) -> None:
28 | super().__init__()
29 | self.encoder = encoder
30 | self.mix_console = mix_console
31 | self.remixer = remixer
32 | self.projector = projector
33 | self.save_hyperparameters(
34 | ignore=["encoder", "mix_console", "remixer", "projector"]
35 | )
36 |
37 | def forward(
38 | self,
39 | input_mix: torch.Tensor,
40 | output_mix: torch.Tensor,
41 | ) -> torch.Tensor:
42 | # could consider masking different parts of input and output
43 | # so the model cannot rely on perfectly aligned inputs
44 |
45 | z_in_left = self.encoder(input_mix[:, 0:1, :])
46 | z_in_right = self.encoder(input_mix[:, 1:2, :])
47 |
48 | z_out_left = self.encoder(output_mix[:, 0:1, :])
49 | z_out_right = self.encoder(output_mix[:, 1:2, :])
50 |
51 | # take difference between embeddings
52 | z_diff_left = z_out_left - z_in_left
53 | z_diff_right = z_out_right - z_in_right
54 |
55 | z_diff = torch.cat([z_diff_left, z_diff_right], dim=-1)
56 |
57 | # project to parameter space
58 | track_params, fx_bus_params, master_bus_params = self.projector(z_diff)
59 |
60 | return track_params, fx_bus_params, master_bus_params
61 |
62 | def common_step(
63 | self,
64 | batch: tuple,
65 | batch_idx: int,
66 | optimizer_idx: int = 0,
67 | train: bool = False,
68 | ):
69 | """Model step used for validation and training.
70 | Args:
71 |             batch (Tensor): Batch of input mixes.
72 |             batch_idx (int): Index of the batch within the current epoch.
73 |             optimizer_idx (int): Index of the optimizer; this step is called once for each optimizer.
74 |             train (bool): Whether the step is called during training (True) or validation (False).
75 | """
76 | input_mix = batch
77 |
78 | # create a remix
79 | output_mix, track_params, fx_bus_params, master_bus_params = self.remixer(
80 | input_mix, self.mix_console
81 | )
82 |
83 | # estimate parameters
84 | track_params_hat, fx_bus_params_hat, master_bus_params_hat = self(
85 | input_mix, output_mix
86 | )
87 |
88 | # calculate loss
89 | track_params_loss = torch.nn.functional.mse_loss(
90 | track_params_hat,
91 | track_params,
92 | )
93 | fx_bus_params_loss = torch.nn.functional.mse_loss(
94 | fx_bus_params_hat,
95 | fx_bus_params,
96 | )
97 | master_bus_params_loss = torch.nn.functional.mse_loss(
98 | master_bus_params_hat,
99 | master_bus_params,
100 | )
101 |
102 | # scale by number of parameters
103 | track_params_loss *= track_params.shape[-1] + track_params.shape[-2]
104 | fx_bus_params_loss *= fx_bus_params.shape[-1]
105 | master_bus_params_loss *= master_bus_params.shape[-1]
106 |
107 | loss = track_params_loss + fx_bus_params_loss + master_bus_params_loss
108 |
109 | # log the losses
110 | self.log(
111 | ("train" if train else "val") + "/track_param_loss",
112 | track_params_loss,
113 | on_step=True,
114 | on_epoch=True,
115 | prog_bar=True,
116 | logger=True,
117 | sync_dist=True,
118 | )
119 |
120 | self.log(
121 | ("train" if train else "val") + "/fx_bus_param_loss",
122 | fx_bus_params_loss,
123 | on_step=True,
124 | on_epoch=True,
125 | prog_bar=True,
126 | logger=True,
127 | sync_dist=True,
128 | )
129 |
130 | self.log(
131 | ("train" if train else "val") + "/master_bus_param_loss",
132 | master_bus_params_loss,
133 | on_step=True,
134 | on_epoch=True,
135 | prog_bar=True,
136 | logger=True,
137 | sync_dist=True,
138 | )
139 |
140 | # log the losses
141 | self.log(
142 | ("train" if train else "val") + "/loss",
143 | loss,
144 | on_step=True,
145 | on_epoch=True,
146 | prog_bar=True,
147 | logger=True,
148 | sync_dist=True,
149 | )
150 |
151 | return loss
152 |
153 | def training_step(self, batch, batch_idx, optimizer_idx=0):
154 | loss = self.common_step(batch, batch_idx, train=True)
155 | return loss
156 |
157 | def validation_step(self, batch, batch_idx):
158 | loss = self.common_step(batch, batch_idx, train=False)
159 | return loss
160 |
161 | def configure_optimizers(self):
162 | optimizer = torch.optim.Adam(
163 | itertools.chain(self.encoder.parameters(), self.projector.parameters()),
164 | lr=self.hparams.lr,
165 | betas=(0.9, 0.999),
166 | )
167 |
168 | if self.hparams.schedule == "cosine":
169 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
170 | optimizer, T_max=self.hparams.max_epochs
171 | )
172 | elif self.hparams.schedule == "step":
173 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
174 | optimizer,
175 | [
176 | int(self.hparams.max_epochs * 0.85),
177 | int(self.hparams.max_epochs * 0.95),
178 | ],
179 | )
180 | else:
181 | return optimizer
182 | lr_schedulers = {"scheduler": scheduler, "interval": "epoch", "frequency": 1}
183 |
184 | return [optimizer], lr_schedulers
185 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Diff-MST: Differentiable Mixing Style Transfer
5 | [Paper](https://sai-soum.github.io/assets/pdf/Differentiable_Mixing_Style_Transfer.pdf) | [Website](https://sai-soum.github.io/projects/diffmst/) | [Video](https://youtu.be/w90RGZ3IqQw)
6 |
7 |
8 |

9 |
10 |
11 |
12 |
20 | # Repository Structure
21 | 1. `configs` - Contains configuration files for training and inference.
22 | 2. `mst` - Contains the main codebase for the project.
23 | - `dataloaders` - Contains dataloaders for the project.
24 | - `modules` - Contains the modules for different components of the system.
25 | - `mixing` - Contains the mixing modules for creating mixes.
26 | - `loss` - Contains the loss functions for the project.
27 |     - `panns` - Contains the most basic components, such as Cnn14, ResNet, etc.
28 | - `utils` - Contains utility functions for the project.
29 | 3. `scripts` - Contains scripts for running inference.
30 |
31 | # Setup
32 | - Clone the repository
33 | ```
34 | git clone https://github.com/sai-soum/Diff-MST.git
35 | cd Diff-MST
36 | ```
37 |
38 | - Create new Python environment
39 | ```
40 | # for Linux/macOS
41 | python3 -m venv env
42 | source env/bin/activate
43 | ```
44 |
45 | - Install the `mst` package from source
46 | ```
47 | # Install as editable (for development)
48 | pip install -e .
49 |
50 | # Alternatively, do a regular install (read-only)
51 | pip install .
52 | ```
53 |
54 | # Usage
55 | ## Train
56 | We use [LightningCLI](https://lightning.ai/docs/pytorch/stable/) for training and [Wandb](https://wandb.ai/site) for logging.
57 |
58 | ### Setup
59 | In the `configs` directory, you will find the configuration files for the project.
60 | - `config.yaml` - Contains the general configuration for the project.
61 | - `optimizer.yaml` - Contains the optimizer configuration for the project.
62 | - `data/` - Contains the data configuration for the project.
63 | - `models/` - Contains the model configuration for the project.
64 | We have provided instructions within the configuration files for setting up the project.
65 |
66 | A few important configuration parameters:
67 | - In `configs/data/` change the following:
68 |   - `track_root_dirs` - The root directory for each dataset needs to be set up. You can pass multiple dataset directories as a list; however, you will also need to provide corresponding metadata YAML files containing the train, test, and val splits. Check the `data/` directory for examples. A minimal sketch of these settings is shown after this list.
69 | - For method 1: set `generate_mix` to `True` in the model configuration file. Use `medley+cambridge-8.yaml` for training with random mixes of the same song as reference.
70 | - For method 2: set `generate_mix` to `False` in the model configuration file. Use `medley+cambridge+jamendo-8.yaml` for training with real unpaired songs as reference.
71 |   - Update `mix_root_dirs` - The root directory for the mix dataset. This is used for training with real unpaired songs as reference.
72 | - You may benefit from setting a smaller value for `train_buffer_size_gb` and `val_buffer_size_gb` in the data configuration file for initial testing of the code.
73 | - In `configs/models/`
74 |   - You can disable individual audio effects by setting a very large value for the corresponding activation-epoch parameter. For example, to disable the compressor, set `active_compressor_epoch` to `1000`.
75 | - You can change the loss function used for training by setting the `loss` parameter.
76 | - In `optimizer.yaml` you can change the learning rate parameters.
77 | - In `config.yaml`
78 | - Update the directory for logging using `save_dir` under `trainer`.
79 | - You can use `ckpt_path` to load a pre-trained model for fine-tuning, resuming training, or testing.
80 |
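A minimal sketch of the data settings described above, using the key names from the bullet points (the exact schema, nesting, and defaults may differ, so check the shipped `configs/data/*.yaml` files):

```
track_root_dirs:
  - /path/to/medleydb
  - /path/to/cambridge-mt
mix_root_dirs:
  - /path/to/mtg-jamendo
batch_size: 4
train_buffer_size_gb: 2.0
val_buffer_size_gb: 0.5
```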
81 |
82 | ### Method 1: Training with random mixes of the same song as reference using MRSTFT loss.
83 | ```
84 | CUDA_VISIBLE_DEVICES=0 python main.py fit \
85 | -c configs/config.yaml \
86 | -c configs/optimizer.yaml \
87 | -c configs/data/medley+cambridge-8.yaml \
88 | -c configs/models/naive.yaml
89 | ```
90 |
91 | To run fine-tuning using AFLoss:
92 | ```
93 | CUDA_VISIBLE_DEVICES=0 python main.py fit \
94 | -c configs/config.yaml \
95 | -c configs/optimizer.yaml \
96 | -c configs/data/medley+cambridge-8.yaml \
97 | -c configs/models/naive+feat.yaml
98 | ```
99 |
100 | You can change the number of tracks, the amount of training data per epoch, and the batch size in the data configuration files under `configs/data/`.
101 |
102 | ### Method 2: Training with real unpaired songs as reference using AFLoss.
103 |
104 | ```
105 | CUDA_VISIBLE_DEVICES=0 python main.py fit \
106 | -c configs/config.yaml \
107 | -c configs/optimizer.yaml \
108 | -c configs/data/medley+cambridge+jamendo-8.yaml \
109 | -c configs/models/unpaired+feat.yaml
110 | ```
111 |
112 | ## Inference
113 | To evaluate the model on real-world data, run the `scripts/eval_all_combo.py` script.
114 |
115 | Update the model checkpoints and the inference examples directory in the script.
116 |
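For single-song inference, `scripts/run.py` loads a pretrained checkpoint, a directory of WAV stems, and a reference mix, then writes the predicted, reference, and sum mixes to the output directory. A typical invocation looks like the following (all paths are placeholders):

```
python scripts/run.py \
    path/to/config.yaml \
    path/to/checkpoint.ckpt \
    path/to/track_dir \
    path/to/ref_mix.wav \
    --output_dir output \
    --use_gpu
```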
117 | `Python 3.10` was used for training.
118 |
119 |
120 | ## Acknowledgements
121 | This work is funded and supported by UK Research and Innovation [grant number EP/S022694/1] and Steinberg Media Technologies GmbH under the AI and Music Centre for Doctoral Training (AIM-CDT) at the Centre for Digital Music, Queen Mary University of London, London, UK.
122 |
123 | ## Citation
124 | If you find this work useful, please consider citing our paper:
125 | ```
126 | @inproceedings{vanka2024diffmst,
127 | title={Diff-MST: Differentiable Mixing Style Transfer},
128 | author={Vanka, Soumya and Steinmetz, Christian and Rolland, Jean-Baptiste and Reiss, Joshua and Fazekas, Gy{\"o}rgy},
129 | booktitle={Proc. of the 25th Int. Society for Music Information Retrieval Conf. (ISMIR)},
130 | year={2024},
131 | organization={Int. Society for Music Information Retrieval (ISMIR)},
132 | abbr = {ISMIR},
133 | address = {San Francisco, USA},
134 | }
135 | ```
136 |
137 | ## License
138 | The code is licensed under the terms of the CC-BY-NC-SA 4.0 license. For a human-readable summary of the license, see https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en .
139 |
--------------------------------------------------------------------------------
/mst/panns.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/models.py
2 | # Under MIT License
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from typing import List
7 |
8 | # from torchlibrosa.stft import Spectrogram, LogmelFilterBank
9 | # from torchlibrosa.augmentation import SpecAugmentation
10 |
11 |
12 | def init_layer(layer):
13 | """Initialize a Linear or Convolutional layer."""
14 | nn.init.xavier_uniform_(layer.weight)
15 |
16 | if hasattr(layer, "bias"):
17 | if layer.bias is not None:
18 | layer.bias.data.fill_(0.0)
19 |
20 |
21 | def init_bn(bn):
22 | """Initialize a Batchnorm layer."""
23 | bn.bias.data.fill_(0.0)
24 | bn.weight.data.fill_(1.0)
25 |
26 |
27 | class ConvBlock(nn.Module):
28 | def __init__(self, in_channels, out_channels, use_batchnorm: bool = True, pool_type: str = 'avg'):
29 | super(ConvBlock, self).__init__()
30 | self.use_batchnorm = use_batchnorm
31 |
32 | self.conv1 = nn.Conv2d(
33 | in_channels=in_channels,
34 | out_channels=out_channels,
35 | kernel_size=(3, 3),
36 | stride=(1, 1),
37 | padding=(1, 1),
38 | bias=False,
39 | )
40 |
41 | self.conv2 = nn.Conv2d(
42 | in_channels=out_channels,
43 | out_channels=out_channels,
44 | kernel_size=(3, 3),
45 | stride=(1, 1),
46 | padding=(1, 1),
47 | bias=False,
48 | )
49 |
50 | if use_batchnorm:
51 | self.bn1 = nn.BatchNorm2d(out_channels)
52 | self.bn2 = nn.BatchNorm2d(out_channels)
53 | else:
54 | self.bn1 = nn.Identity()
55 | self.bn2 = nn.Identity()
56 |
57 | if pool_type == "max":
58 | self.pool_fn = F.max_pool2d
59 | elif pool_type == "avg":
60 | self.pool_fn = F.avg_pool2d
61 | elif pool_type == "avg+max":
62 | def pool_avg_max(x: torch.Tensor, kernel_size: List[int]):
63 | return F.avg_pool2d(x, kernel_size) + F.max_pool2d(x, kernel_size)
64 | self.pool_fn = pool_avg_max
65 | else:
66 | raise Exception("Incorrect argument for `pool_type`!")
67 |
68 |
69 | self.init_weight()
70 |
71 | def init_weight(self):
72 | init_layer(self.conv1)
73 | init_layer(self.conv2)
74 |
75 | if self.use_batchnorm:
76 | init_bn(self.bn1)
77 | init_bn(self.bn2)
78 |
79 | def forward(self, input: torch.Tensor, pool_size: List[int]):
80 | x = input
81 | x = F.relu_(self.bn1(self.conv1(x)))
82 | x = F.relu_(self.bn2(self.conv2(x)))
83 | x = self.pool_fn(x, pool_size)
84 |
85 | return x
86 |
87 |
88 | class ConvBlock5x5(nn.Module):
89 | def __init__(self, in_channels, out_channels):
90 | super(ConvBlock5x5, self).__init__()
91 |
92 | self.conv1 = nn.Conv2d(
93 | in_channels=in_channels,
94 | out_channels=out_channels,
95 | kernel_size=(5, 5),
96 | stride=(1, 1),
97 | padding=(2, 2),
98 | bias=False,
99 | )
100 |
101 | self.bn1 = nn.BatchNorm2d(out_channels)
102 |
103 | self.init_weight()
104 |
105 | def init_weight(self):
106 | init_layer(self.conv1)
107 | init_bn(self.bn1)
108 |
109 | def forward(self, input, pool_size=(2, 2), pool_type="avg"):
110 | x = input
111 | x = F.relu_(self.bn1(self.conv1(x)))
112 | if pool_type == "max":
113 | x = F.max_pool2d(x, kernel_size=pool_size)
114 | elif pool_type == "avg":
115 | x = F.avg_pool2d(x, kernel_size=pool_size)
116 | elif pool_type == "avg+max":
117 | x1 = F.avg_pool2d(x, kernel_size=pool_size)
118 | x2 = F.max_pool2d(x, kernel_size=pool_size)
119 | x = x1 + x2
120 | else:
121 |             raise Exception(f"Incorrect argument for pool_type: {pool_type}!")
122 |
123 | return x
124 |
125 |
126 | class Cnn14(nn.Module):
127 | def __init__(
128 | self,
129 | num_classes: int,
130 | n_inputs: int = 1,
131 | use_batchnorm: bool = True,
132 | ):
133 | super(Cnn14, self).__init__()
134 |
135 | self.conv_block1 = ConvBlock(
136 | in_channels=n_inputs,
137 | out_channels=64,
138 | use_batchnorm=use_batchnorm,
139 | )
140 | self.conv_block2 = ConvBlock(
141 | in_channels=64,
142 | out_channels=128,
143 | use_batchnorm=use_batchnorm,
144 | )
145 | self.conv_block3 = ConvBlock(
146 | in_channels=128,
147 | out_channels=256,
148 | use_batchnorm=use_batchnorm,
149 | )
150 | self.conv_block4 = ConvBlock(
151 | in_channels=256,
152 | out_channels=512,
153 | use_batchnorm=use_batchnorm,
154 | )
155 | self.conv_block5 = ConvBlock(
156 | in_channels=512,
157 | out_channels=1024,
158 | use_batchnorm=use_batchnorm,
159 | )
160 | self.conv_block6 = ConvBlock(
161 | in_channels=1024,
162 | out_channels=2048,
163 | use_batchnorm=use_batchnorm,
164 | )
165 |
166 | self.fc = nn.Linear(2048, num_classes, bias=True)
167 | self.init_weight()
168 |
169 | def init_weight(self):
170 | # init_bn(self.bn0)
171 | init_layer(self.fc)
172 |
173 | def forward(self, x: torch.Tensor):
174 | """
175 | input (torch.Tensor): Spectrogram tensor with shape (bs, chs, bins, frames)
176 | """
177 | batch_size, chs, bins, frames = x.size()
178 |
179 | # x = x.view(batch_size, -1)
180 | # x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
181 | # x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
182 | # x = x.transpose(1, 3)
183 | # x = self.bn0(x)
184 | # x = x.transpose(1, 3)
185 | # if self.training:
186 | # x = self.spec_augmenter(x)
187 |
188 | x = self.conv_block1(x, pool_size=(2, 2))
189 | # x = F.dropout(x, p=0.2, training=self.training)
190 | x = self.conv_block2(x, pool_size=(4, 4))
191 | # x = F.dropout(x, p=0.2, training=self.training)
192 | x = self.conv_block3(x, pool_size=(4, 2))
193 | # x = F.dropout(x, p=0.2, training=self.training)
194 | x = self.conv_block4(x, pool_size=(4, 2))
195 | # x = F.dropout(x, p=0.2, training=self.training)
196 | x = self.conv_block5(x, pool_size=(4, 2))
197 | # x = F.dropout(x, p=0.2, training=self.training)
198 | x = self.conv_block6(x, pool_size=(2, 2))
199 | # x = F.dropout(x, p=0.2, training=self.training)
200 | x = torch.mean(x, dim=2) # mean across stft bins
201 |
202 | (x1, _) = torch.max(x, dim=2)
203 | x2 = torch.mean(x, dim=2)
204 | x = x1 + x2
205 | # x = F.dropout(x, p=0.5, training=self.training)
206 | x_out = self.fc(x)
207 | clipwise_output = x_out
208 |
209 | return clipwise_output
210 |
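# Example usage (sketch): shapes are illustrative. The bins/frames dimensions must be
# large enough to survive the fixed pooling factors in `forward` (1024 and 128 here),
# and `num_classes` here is a placeholder output size.
# model = Cnn14(num_classes=128, n_inputs=1)
# spec = torch.randn(2, 1, 1024, 128)  # (bs, chs, bins, frames)
# out = model(spec)                    # (2, 128)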
--------------------------------------------------------------------------------
/mst/loss.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import librosa
4 | 
5 | from typing import List
6 | from mst.filter import barkscale_fbanks
7 | from mst.fx_encoder import FXencoder
8 | from mst.modules import SpectrogramEncoder
9 | 
10 | 
11 | def compute_mid_side(x: torch.Tensor):
12 |     x_mid = x[:, 0, :] + x[:, 1, :]
13 |     x_side = x[:, 0, :] - x[:, 1, :]
14 |     return x_mid, x_side
15 | 
16 | 
29 | def compute_melspectrum(
30 | x: torch.Tensor,
31 | sample_rate: int = 44100,
32 | fft_size: int = 32768,
33 | n_bins: int = 128,
34 | **kwargs,
35 | ):
36 | """Compute mel-spectrogram.
37 |
38 | Args:
39 | x: (bs, 2, seq_len)
40 | sample_rate: sample rate of audio
41 | fft_size: size of fft
42 | n_bins: number of mel bins
43 |
44 | Returns:
45 | X: (bs, n_bins)
46 |
47 | """
48 | fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
49 | fb = torch.tensor(fb).unsqueeze(0).type_as(x)
50 |
51 | x = x.mean(dim=1, keepdim=True)
52 | X = torch.fft.rfft(x, n=fft_size, dim=-1)
53 | X = torch.abs(X)
54 | X = torch.mean(X, dim=1, keepdim=True) # take mean over time
55 | X = X.permute(0, 2, 1) # swap time and freq dims
56 | X = torch.matmul(fb, X)
57 | X = torch.log(X + 1e-8)
58 |
59 | return X
60 |
61 |
62 | def compute_barkspectrum(
63 | x: torch.Tensor,
64 | fft_size: int = 32768,
65 | n_bands: int = 24,
66 | sample_rate: int = 44100,
67 | f_min: float = 20.0,
68 | f_max: float = 20000.0,
69 | mode: str = "mid-side",
70 | **kwargs,
71 | ):
72 | """Compute bark-spectrogram.
73 |
74 | Args:
75 | x: (bs, 2, seq_len)
76 | fft_size: size of fft
77 | n_bands: number of bark bins
78 | sample_rate: sample rate of audio
79 | f_min: minimum frequency
80 | f_max: maximum frequency
81 | mode: "mono", "stereo", or "mid-side"
82 |
83 | Returns:
84 | X: (bs, 24)
85 |
86 | """
87 | # compute filterbank
88 | fb = barkscale_fbanks((fft_size // 2) + 1, f_min, f_max, n_bands, sample_rate)
89 | fb = fb.unsqueeze(0).type_as(x)
90 | fb = fb.permute(0, 2, 1)
91 |
92 | if mode == "mono":
93 | x = x.mean(dim=1) # average over channels
94 | signals = [x]
95 | elif mode == "stereo":
96 | signals = [x[:, 0, :], x[:, 1, :]]
97 | elif mode == "mid-side":
98 | x_mid = x[:, 0, :] + x[:, 1, :]
99 | x_side = x[:, 0, :] - x[:, 1, :]
100 | signals = [x_mid, x_side]
101 | else:
102 | raise ValueError(f"Invalid mode {mode}")
103 |
104 | outputs = []
105 | for signal in signals:
106 | X = torch.stft(
107 | signal,
108 | n_fft=fft_size,
109 | hop_length=fft_size // 4,
110 | return_complex=True,
111 | window=torch.hann_window(fft_size).to(x.device),
112 | ) # compute stft
113 | X = torch.abs(X) # take magnitude
114 | X = torch.mean(X, dim=-1, keepdim=True) # take mean over time
115 | # X = X.permute(0, 2, 1) # swap time and freq dims
116 | X = torch.matmul(fb, X) # apply filterbank
117 | X = torch.log(X + 1e-8)
118 | # X = torch.cat([X, X_log], dim=-1)
119 | outputs.append(X)
120 |
121 | # stack into tensor
122 | X = torch.cat(outputs, dim=-1)
123 |
124 | return X
125 |
126 |
127 | def compute_rms(x: torch.Tensor, **kwargs):
128 | """Compute root mean square energy.
129 |
130 | Args:
131 | x: (bs, 1, seq_len)
132 |
133 | Returns:
134 | rms: (bs, )
135 | """
136 | rms = torch.sqrt(torch.mean(x**2, dim=-1).clamp(min=1e-8))
137 | return rms
138 |
139 |
140 | def compute_crest_factor(x: torch.Tensor, **kwargs):
141 | """Compute crest factor as ratio of peak to rms energy in dB.
142 |
143 | Args:
144 | x: (bs, 2, seq_len)
145 |
146 | """
147 | num = torch.max(torch.abs(x), dim=-1)[0]
148 | den = compute_rms(x).clamp(min=1e-8)
149 | cf = 20 * torch.log10((num / den).clamp(min=1e-8))
150 | return cf
151 |
152 |
153 | def compute_stereo_width(x: torch.Tensor, **kwargs):
154 | """Compute stereo width as ratio of energy in sum and difference signals.
155 |
156 | Args:
157 | x: (bs, 2, seq_len)
158 |
159 | """
160 | bs, chs, seq_len = x.size()
161 |
162 | assert chs == 2, "Input must be stereo"
163 |
164 | # compute sum and diff of stereo channels
165 | x_sum = x[:, 0, :] + x[:, 1, :]
166 | x_diff = x[:, 0, :] - x[:, 1, :]
167 |
168 | # compute power of sum and diff
169 | sum_energy = torch.mean(x_sum**2, dim=-1)
170 | diff_energy = torch.mean(x_diff**2, dim=-1)
171 |
172 | # compute stereo width as ratio
173 | stereo_width = diff_energy / sum_energy.clamp(min=1e-8)
174 |
175 | return stereo_width
176 |
177 |
178 | def compute_stereo_imbalance(x: torch.Tensor, **kwargs):
179 | """Compute stereo imbalance as ratio of energy in left and right channels.
180 |
181 | Args:
182 | x: (bs, 2, seq_len)
183 |
184 | Returns:
185 | stereo_imbalance: (bs, )
186 |
187 | """
188 | left_energy = torch.mean(x[:, 0, :] ** 2, dim=-1)
189 | right_energy = torch.mean(x[:, 1, :] ** 2, dim=-1)
190 |
191 | stereo_imbalance = (right_energy - left_energy) / (
192 | right_energy + left_energy
193 | ).clamp(min=1e-8)
194 |
195 | return stereo_imbalance
196 |
197 |
198 | class AudioFeatureLoss(torch.nn.Module):
199 | def __init__(
200 | self,
201 | weights: List[float],
202 | sample_rate: int,
203 | stem_separation: bool = False,
204 | use_clap: bool = False,
205 | ) -> None:
206 | """Compute loss using a set of differentiable audio features.
207 |
208 | Args:
209 | weights: weights for each feature
210 | sample_rate: sample rate of audio
211 | stem_separation: whether to compute loss on stems or mix
212 |
213 | Based on features proposed in:
214 |
215 | Man, B. D., et al.
216 | "An analysis and evaluation of audio features for multitrack music mixtures."
217 | (2014).
218 |
219 | """
220 | super().__init__()
221 | self.weights = weights
222 | self.sample_rate = sample_rate
223 | self.stem_separation = stem_separation
224 | self.sources_list = ["mix"]
225 | self.source_weights = [1.0]
226 | self.use_clap = use_clap
227 |
228 | self.transforms = [
229 | compute_rms,
230 | compute_crest_factor,
231 | compute_stereo_width,
232 | compute_stereo_imbalance,
233 | compute_barkspectrum,
234 | ]
235 |
236 | assert len(self.transforms) == len(weights)
237 |
238 | def forward(self, input: torch.Tensor, target: torch.Tensor):
239 | losses = {}
240 |
241 | # reshape for example stem dim
242 | input_stems = input.unsqueeze(1)
243 | target_stems = target.unsqueeze(1)
244 |
245 | n_stems = input_stems.shape[1]
246 |
247 | # iterate over each stem compute loss for each transform
248 | for stem_idx in range(n_stems):
249 | input_stem = input_stems[:, stem_idx, ...]
250 | target_stem = target_stems[:, stem_idx, ...]
251 |
252 | for transform, weight in zip(self.transforms, self.weights):
253 | transform_name = "_".join(transform.__name__.split("_")[1:])
254 | key = f"{self.sources_list[stem_idx]}-{transform_name}"
255 | input_transform = transform(input_stem, sample_rate=self.sample_rate)
256 | target_transform = transform(target_stem, sample_rate=self.sample_rate)
257 | val = torch.nn.functional.mse_loss(input_transform, target_transform)
258 | losses[key] = weight * val * self.source_weights[stem_idx]
259 |
260 | return losses
261 |
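# Example usage (sketch): one weight per transform, in the order rms, crest factor,
# stereo width, stereo imbalance, bark spectrum; the weight values are placeholders
# and `pred_mix` / `ref_mix` are stereo tensors of shape (bs, 2, seq_len).
# loss_fn = AudioFeatureLoss(weights=[0.1, 0.001, 1.0, 1.0, 0.1], sample_rate=44100)
# losses = loss_fn(pred_mix, ref_mix)  # dict of weighted MSE terms, e.g. "mix-rms"
# total_loss = sum(losses.values())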
--------------------------------------------------------------------------------
/scripts/gain_testing.py:
--------------------------------------------------------------------------------
1 |
2 | # run pretrained models over evaluation set to generate audio examples for the listening test
3 | import os
4 | import torch
5 | import torchaudio
6 | import pyloudnorm as pyln
7 | from mst.utils import load_diffmst, run_diffmst
8 | from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss
9 | import json
10 | import numpy as np
11 | import csv
12 | import glob
13 | import yaml
14 |
15 |
16 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
17 |
18 | meter = pyln.Meter(44100)
19 | target_lufs_db = -48.0
20 |
21 | norm_tracks = []
22 | for track_idx in range(tracks.shape[1]):
23 | track = tracks[:, track_idx : track_idx + 1, :]
24 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
25 |
26 | if lufs_db < -80.0:
27 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
28 | continue
29 |
30 | lufs_delta_db = target_lufs_db - lufs_db
31 | track *= 10 ** (lufs_delta_db / 20)
32 | norm_tracks.append(track)
33 |
34 | norm_tracks = torch.cat(norm_tracks, dim=1)
35 | # create a sum mix with equal loudness
36 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
37 | sum_mix /= sum_mix.abs().max()
38 |
39 | return sum_mix, None, None, None
40 |
41 |
42 |
43 | if __name__ == "__main__":
44 | meter = pyln.Meter(44100)
45 | target_mix_lufs_db = -16.0
46 | target_track_lufs_db = -48.0
47 | output_dir = "outputs/gain_testing_diff_song_individual_tracks"
48 | os.makedirs(output_dir, exist_ok=True)
49 |
50 | methods = {
51 | "diffmst-16": {
52 | "model": load_diffmst(
53 | "/Users/svanka/Downloads/b4naquji/config.yaml",
54 | "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
55 | ),
56 | "func": run_diffmst,
57 | },
58 | # "sum": {
59 | # "model": (None, None),
60 | # "func": equal_loudness_mix,
61 | # },
62 | }
63 |
64 | ref_dir = "/Users/svanka/Downloads/DSD100subset/sources/Dev/055 - Angels In Amplifiers - I'm Alright"
65 | #mix_dir = "/Users/svanka/Downloads/DSD100subset/sources/Dev/055 - Angels In Amplifiers - I'm Alright"
66 | mix_dir = "/Users/svanka/Downloads/DSD100subset/Sources/Test/049 - Young Griffo - Facade"
67 |
68 | ref_tracks = glob.glob(os.path.join(ref_dir, "*.wav"))
69 | mix_tracks = glob.glob(os.path.join(mix_dir, "*.wav"))
70 |
71 | print(len(ref_tracks), len(mix_tracks))
72 | #order the tracks in ref_tracks to vocals, bass, other, drums
73 | ref_tracks_ordered = [""] * 4
74 | for track in ref_tracks:
75 | if "vocals" in track:
76 | ref_tracks_ordered[0] = track
77 | elif "bass" in track:
78 | ref_tracks_ordered[1] = track
79 | elif "other" in track:
80 | ref_tracks_ordered[2] = track
81 | elif "drums" in track:
82 | ref_tracks_ordered[3] = track
83 | ref_tracks = ref_tracks_ordered
84 |
85 | print(ref_tracks)
86 | # print(mix_tracks)
87 |
88 |
89 |     # we will predict a mix using one reference track, then the sum of two, three,
90 |     # and four reference tracks, as the reference for the model, with the mix tracks as the input
91 |
92 | tracks = []
93 | #info = torchaudio.info(mix_tracks[0])
94 |
95 |
96 | track_instrument = []
97 | for track in mix_tracks:
98 | #audio, sr = torchaudio.load(track, frame_offset = int((info.num_frames)/2 - 220500), num_frames = 441000, backend="soundfile")
99 | audio, sr = torchaudio.load(track,num_frames = 441000, backend="soundfile")
100 | if sr != 44100:
101 | audio = torchaudio.functional.resample(audio, sr, 44100)
102 |
103 | if audio.shape[0] == 2:
104 | audio = audio.mean(dim=0, keepdim=True)
105 |
106 | tracks.append(audio)
107 | track_instrument.append(os.path.basename(track).replace(".wav", ""))
108 |
109 |
110 | tracks = torch.cat(tracks, dim=0)
111 | print("tracks shape", tracks.shape)
112 | tracks = tracks.unsqueeze(0)
113 | print("tracks shape", tracks.shape)
114 |
115 | #create a sum mix
116 | sum_mix, _, _, _ = equal_loudness_mix(tracks)
117 | print("sum_mix shape", sum_mix.shape)
118 | save_path = os.path.join(output_dir, f"{os.path.basename(mix_dir)}-sum_mix.wav")
119 | torchaudio.save(save_path, sum_mix.view(2, -1), 44100)
120 |
121 | ref_mix_tracks = []
122 | info = torchaudio.info(ref_tracks[0])
123 | name = "ref_mix-16="
124 | data = {}
125 | data["track_instrument"] = track_instrument
126 |     for i, ref_track in enumerate(ref_tracks):
127 | instrument = name + "-" + os.path.basename(ref_track).replace(".wav", "")
128 | print(instrument)
129 | #name = instrument
130 | ref_audio, sr = torchaudio.load(ref_track, frame_offset = int((info.num_frames)/2 - 220500), num_frames = 441000, backend="soundfile")
131 | if sr != 44100:
132 | ref_audio = torchaudio.functional.resample(ref_audio, sr, 44100)
133 |
134 | #loudness normalize the reference mix to -48 LUFS
135 | ref_lufs_db = meter.integrated_loudness(ref_audio.squeeze().permute(1, 0).numpy())
136 | lufs_delta_db = target_track_lufs_db - ref_lufs_db
137 | ref_audio = ref_audio * 10 ** (lufs_delta_db / 20)
138 |
139 | #ref_mix_tracks.append(ref_audio)
140 | ref_mix_tracks = [ref_audio]
141 | ref_mix = torch.cat(ref_mix_tracks, dim=0)
142 | #create a stereo sum mix
143 | ref_mix = ref_mix.sum(dim=0, keepdim=True).repeat(1, 2, 1)
144 | #normalise to -16 LUFS
145 | ref_mix_lufs_db = meter.integrated_loudness(ref_mix.squeeze().permute(1, 0).numpy())
146 | lufs_delta_db = target_mix_lufs_db - ref_mix_lufs_db
147 | ref_mix = ref_mix * 10 ** (lufs_delta_db / 20)
148 | ref_save_path = os.path.join(output_dir, f"{os.path.basename(ref_dir)}-{instrument}.wav")
149 | torchaudio.save(ref_save_path, ref_mix.view(2, -1), 44100)
150 |
151 | yaml_path = os.path.join(output_dir, f"{os.path.basename(ref_dir)}-{instrument}.yaml")
152 | data["ref_mix"] = ref_save_path
153 | data["ref_instruments"] = instrument
154 | data["sum_mix"] = save_path
155 | #check if the json file exists
156 | print("tracks shape", tracks.shape)
157 | print("ref_mix shape", ref_mix.shape)
158 | 
162 | for method_name, method in methods.items():
163 | model, mix_console = method["model"]
164 | func = method["func"]
165 | with torch.no_grad():
166 | result = func(
167 | tracks.clone(),
168 | ref_mix.clone(),
169 | model,
170 | mix_console,
171 | track_start_idx=0,
172 | ref_start_idx=0,
173 | )
174 |
175 | (
176 | pred_mix,
177 | pred_track_param_dict,
178 | pred_fx_bus_param_dict,
179 | pred_master_bus_param_dict,
180 | ) = result
181 |
182 |
183 | bs, chs, seq_len = pred_mix.shape
184 | print("pred_mix shape", pred_mix.shape)
185 | # loudness normalize the output mix
186 | mix_lufs_db = meter.integrated_loudness(
187 | pred_mix.squeeze(0).permute(1, 0).numpy()
188 | )
189 | print("pred_mix_lufs_db", mix_lufs_db)
190 | #print(mix_lufs_db)
191 | lufs_delta_db = target_mix_lufs_db - mix_lufs_db
192 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
193 | pred_mix_name = os.path.basename(mix_dir) + f"-pred_mix-ref_mix-16={instrument}.wav"
194 | mix_filepath = os.path.join(output_dir, pred_mix_name)
195 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
196 | # append to the json file param_dicts
197 |
198 | #print(pred_track_param_dict["input_gain"])
199 |
200 | data["pred_mix"] = pred_mix_name
201 | data["gain_values"] = pred_track_param_dict['input_fader']['gain_db'].detach().cpu().numpy().tolist()[0]
202 | #print(type(pred_track_param_dict['input_fader']['gain_db']))
203 |
204 |
205 | with open(yaml_path, "w") as f:
206 | yaml.dump(data, f)
207 |
--------------------------------------------------------------------------------
/mst/callbacks/mix.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import wandb
5 | import torchaudio
6 | import numpy as np
7 | import pyloudnorm as pyln
8 | import pytorch_lightning as pl
9 | from tqdm import tqdm
10 | from typing import List
11 |
12 | from mst.utils import batch_stereo_peak_normalize
13 | from mst.mixing import naive_random_mix
14 |
15 |
16 | class LogReferenceMix(pl.callbacks.Callback):
17 | def __init__(
18 | self,
19 | root_dirs: List[str],
20 | ref_mixes: List[str],
21 | peak_normalize: bool = True,
22 | sample_rate: int = 44100,
23 | length: int = 524288,
24 | target_track_lufs_db: float = -48.0,
25 | target_mix_lufs_db: float = -16.0,
26 | ):
27 | super().__init__()
28 | self.peak_normalize = peak_normalize
29 | self.sample_rate = sample_rate
30 | self.length = length
31 | self.target_track_lufs_db = target_track_lufs_db
32 | self.target_mix_lufs_db = target_mix_lufs_db
33 | self.meter = pyln.Meter(self.sample_rate)
34 |
35 |         print(f"Initializing reference mix logger with {len(root_dirs)} mixes.")
36 |
37 | self.songs = []
38 | for root_dir, ref_mix in zip(root_dirs, ref_mixes):
39 |
40 | print(f"Loading {root_dir}...")
41 | song = {}
42 | song["name"] = os.path.basename(root_dir)
43 |
44 |
45 | # load reference mix
46 | x, sr = torchaudio.load(ref_mix)
47 | print(f"Reference mix sample rate: {sr}")
48 |
49 |
50 | # convert sample rate if needed
51 | if sr != sample_rate:
52 | x = torchaudio.functional.resample(x, sr, sample_rate)
53 |
54 | print(f"Reference mix sample rate after resampling: {sample_rate}")
55 |
56 |
57 | song["ref_mix"] = x
58 |
59 | # load tracks
60 | track_filepaths = glob.glob(os.path.join(root_dir, "*.wav"))
61 | tracks = []
62 | print("Loading tracks...")
63 | for track_idx, track_filepath in enumerate(tqdm(track_filepaths)):
64 | x, sr = torchaudio.load(track_filepath)
65 |
66 | # convert sample rate if needed
67 | if sr != sample_rate:
68 | x = torchaudio.functional.resample(x, sr, sample_rate)
69 |
70 | # separate channels
71 | for ch_idx in range(x.shape[0]):
72 | x_ch = x[ch_idx : ch_idx + 1, :]
73 |
74 | # save
75 | tracks.append(x_ch)
76 |
77 | song["tracks"] = tracks
78 | self.songs.append(song)
79 |
80 | def on_validation_epoch_end(
81 | self,
82 | trainer,
83 | pl_module,
84 | ):
85 | """Called when the validation batch ends."""
86 | for idx, song in enumerate(self.songs):
87 | ref_mix = song["ref_mix"]
88 | tracks = song["tracks"]
89 | name = song["name"]
90 |
91 | # take a chunk from the middle of the mix
92 | start_idx = (ref_mix.shape[-1] // 2) - (131072 // 2)
93 | stop_idx = start_idx + 131072
94 | ref_mix_chunk = ref_mix[..., start_idx:stop_idx]
95 |
96 | # loudness normalize the mix
97 | mix_lufs_db = self.meter.integrated_loudness(
98 | ref_mix_chunk.permute(1, 0).numpy()
99 | )
100 | delta_lufs_db = torch.tensor(
101 | [self.target_mix_lufs_db - mix_lufs_db]
102 | ).float()
103 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
104 | ref_mix_chunk = gain_lin * ref_mix_chunk
105 |
106 | # move to gpu
107 | ref_mix_chunk = ref_mix_chunk.cuda()
108 |
109 | # make a mix of multiple sections of the tracks
110 | for n, start_idx in enumerate([0, 524288, 2 * 524288, 3 * 524288]):
111 | stop_idx = start_idx + 131072
112 |
113 | # loudness normalize tracks
114 | normalized_tracks = []
115 | for track in tracks:
116 | track = track[..., start_idx:stop_idx]
117 |
118 | if len(normalized_tracks) > 16:
119 | break
120 |
121 | if track.shape[-1] < 131072:
122 | continue
123 |
124 | track_lufs_db = self.meter.integrated_loudness(
125 | track.permute(1, 0).numpy()
126 | )
127 |
128 | if track_lufs_db < -48.0 or track_lufs_db == float("-inf"):
129 | continue
130 |
131 | delta_lufs_db = torch.tensor(
132 | [self.target_track_lufs_db - track_lufs_db]
133 | ).float()
134 |
135 | gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
136 | track = gain_lin * track
137 | normalized_tracks.append(track)
138 |
139 | if len(normalized_tracks) == 0:
140 | continue
141 |
142 | # cat tracks
143 | tracks_chunk = torch.cat(normalized_tracks, dim=0)
144 | tracks_chunk = tracks_chunk.cuda()
145 |
146 | with torch.no_grad():
147 | # predict parameters using the chunks
148 | (
149 | pred_track_params,
150 | pred_fx_bus_params,
151 | pred_master_bus_params,
152 | ) = pl_module.model(
153 | tracks_chunk.unsqueeze(0), ref_mix_chunk.unsqueeze(0)
154 | )
155 |
156 | # generate a mix of the track chunk using the predicted mix console parameters
157 | (
158 | pred_mixed_tracks,
159 | pred_mix_chunk,
160 | pred_track_param_dict,
161 | pred_fx_bus_param_dict,
162 | pred_master_bus_param_dict,
163 | ) = pl_module.mix_console(
164 | tracks_chunk.unsqueeze(0),
165 | pred_track_params,
166 | pred_fx_bus_params,
167 | pred_master_bus_params,
168 | use_track_input_fader=pl_module.use_track_input_fader,
169 | use_track_panner=pl_module.use_track_panner,
170 | use_track_eq=pl_module.use_track_eq,
171 | use_track_compressor=pl_module.use_track_compressor,
172 | use_fx_bus=pl_module.use_fx_bus,
173 | use_master_bus=pl_module.use_master_bus,
174 | use_output_fader=pl_module.use_output_fader,
175 | )
176 |
177 | # normalize predicted mix
178 | pred_mix_chunk = batch_stereo_peak_normalize(pred_mix_chunk)
179 |
180 | # move back to cpu
181 | pred_mix_chunk = pred_mix_chunk.squeeze(0).cpu()
182 | ref_mix_chunk_out = ref_mix_chunk.squeeze(0).cpu()
183 |
184 | # generate sum mix
185 | sum_mix = tracks_chunk.unsqueeze(0).sum(dim=1, keepdim=True).cpu()
186 | sum_mix = batch_stereo_peak_normalize(sum_mix)
187 | sum_mix = sum_mix.squeeze(0)
188 |
189 | # generate random mix
190 | results = naive_random_mix(
191 | tracks_chunk.unsqueeze(0),
192 | pl_module.mix_console,
193 | use_track_input_fader=pl_module.use_track_input_fader,
194 | use_track_panner=pl_module.use_track_panner,
195 | use_track_eq=pl_module.use_track_eq,
196 | use_track_compressor=pl_module.use_track_compressor,
197 | use_fx_bus=pl_module.use_fx_bus,
198 | use_master_bus=pl_module.use_master_bus,
199 | use_output_fader=pl_module.use_output_fader,
200 | )
201 | rand_mix = results[1]
202 | rand_mix = batch_stereo_peak_normalize(rand_mix).cpu()
203 | rand_mix = rand_mix.squeeze(0)
204 |
205 | audios = {
206 | "ref_mix": ref_mix_chunk_out,
207 | "pred_mix": pred_mix_chunk,
208 | "sum_mix": sum_mix,
209 | "rand_mix": rand_mix,
210 | }
211 |
212 | total_samples = 0
213 | for x in audios.values():
214 | total_samples += x.shape[-1] + int(
215 | pl_module.mix_console.sample_rate
216 | )
217 |
218 | y = torch.zeros(total_samples, 2)
219 | log_name = f"{idx}_{n}{name}"
220 | start = 0
221 | for key, x in audios.items():
222 | end = start + x.shape[-1]
223 | y[start:end, :] = x.T
224 | start = end + int(pl_module.mix_console.sample_rate)
225 | log_name += key + "-"
226 |
227 | trainer.logger.experiment.log(
228 | {
229 | f"{log_name}": wandb.Audio(
230 | y.numpy(),
231 | sample_rate=int(pl_module.mix_console.sample_rate),
232 | )
233 | }
234 | )
235 |
--------------------------------------------------------------------------------
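Note: the loudness handling in the callback above (and in the evaluation scripts later in this dump) follows one recurring pattern: measure integrated loudness with pyloudnorm, compute the dB offset to a target, and apply it as a linear gain. A minimal, self-contained sketch of that pattern follows; the target level, tensor shape, and function name are illustrative and not part of the codebase.

import torch
import pyloudnorm as pyln

def loudness_normalize(x: torch.Tensor, target_lufs_db: float = -16.0, sample_rate: int = 44100) -> torch.Tensor:
    """Gain a stereo tensor of shape (2, seq_len) to a target integrated loudness."""
    meter = pyln.Meter(sample_rate)
    # pyloudnorm expects (seq_len, channels) numpy arrays
    lufs_db = meter.integrated_loudness(x.permute(1, 0).numpy())
    delta_db = target_lufs_db - lufs_db
    return x * (10.0 ** (delta_db / 20.0))

# usage: raise one second of low-level noise to roughly -16 LUFS
mix = 0.01 * torch.randn(2, 44100)
mix_norm = loudness_normalize(mix, target_lufs_db=-16.0)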
/mst/fx_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.nn.init as init
7 | import torchaudio
8 | import numpy as np
9 |
10 | import argparse
11 | import yaml
12 |
13 |
14 | currentdir = os.path.dirname(os.path.realpath(__file__))
15 | sys.path.append(os.path.dirname(currentdir))
16 |
17 | # 1-dimensional convolutional layer
18 | # in the order of conv -> norm -> activation
19 | class Conv1d_layer(nn.Module):
20 | def __init__(self, in_channels, out_channels, kernel_size, \
21 | stride=1, \
22 | padding="SAME", dilation=1, bias=True, \
23 | norm="batch", activation="relu", \
24 | mode="conv"):
25 | super(Conv1d_layer, self).__init__()
26 |
27 | self.conv1d = nn.Sequential()
28 |
29 | ''' padding '''
30 | if mode=="deconv":
31 | padding = int(dilation * (kernel_size-1) / 2)
32 | out_padding = 0 if stride==1 else 1
33 | elif mode=="conv" or "alias_free" in mode:
34 | if padding == "SAME":
35 | pad = int((kernel_size-1) * dilation)
36 | l_pad = int(pad//2)
37 | r_pad = pad - l_pad
38 | padding_area = (l_pad, r_pad)
39 | elif padding == "VALID":
40 | padding_area = (0, 0)
41 | else:
42 | pass
43 |
44 | ''' convolutional layer '''
45 | if mode=="deconv":
46 | self.conv1d.add_module("deconv1d", nn.ConvTranspose1d(in_channels, out_channels, kernel_size, \
47 | stride=stride, padding=padding, output_padding=out_padding, \
48 | dilation=dilation, \
49 | bias=bias))
50 | elif mode=="conv":
51 | self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
52 | self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
53 | stride=stride, padding=0, \
54 | dilation=dilation, \
55 | bias=bias))
56 | elif "alias_free" in mode:
57 | if "up" in mode:
58 | up_factor = stride * 2
59 | down_factor = 2
60 | elif "down" in mode:
61 | up_factor = 2
62 | down_factor = stride * 2
63 | else:
64 | raise ValueError("choose alias-free method : 'up' or 'down'")
65 | # procedure : conv -> upsample -> lrelu -> low-pass filter -> downsample
66 | # the default resampling_method of torchaudio.transforms.Resample is 'sinc_interpolation', which applies a low-pass filter during the process
67 | # details at https://pytorch.org/audio/stable/transforms.html
68 | self.conv1d.add_module(f"{mode}1d_pad", nn.ReflectionPad1d(padding_area))
69 | self.conv1d.add_module(f"{mode}1d", nn.Conv1d(in_channels, out_channels, kernel_size, \
70 | stride=1, padding=0, \
71 | dilation=dilation, \
72 | bias=bias))
73 | self.conv1d.add_module(f"{mode}upsample", torchaudio.transforms.Resample(orig_freq=1, new_freq=up_factor))
74 | self.conv1d.add_module(f"{mode}lrelu", nn.LeakyReLU())
75 | self.conv1d.add_module(f"{mode}downsample", torchaudio.transforms.Resample(orig_freq=down_factor, new_freq=1))
76 |
77 | ''' normalization '''
78 | if norm=="batch":
79 | self.conv1d.add_module("batch_norm", nn.BatchNorm1d(out_channels))
80 | # self.conv1d.add_module("batch_norm", nn.SyncBatchNorm(out_channels))
81 |
82 | ''' activation '''
83 | if 'alias_free' not in mode:
84 | if activation=="relu":
85 | self.conv1d.add_module("relu", nn.ReLU())
86 | elif activation=="lrelu":
87 | self.conv1d.add_module("lrelu", nn.LeakyReLU())
88 |
89 |
90 | def forward(self, input):
91 | # input shape should be : batch x channel x height x width
92 | #print(input.shape)
93 | output = self.conv1d(input)
94 | return output
95 |
96 |
97 |
98 |
99 | # Residual Block
100 | # the input is added to the output of the first convolutional layer, which preserves the original channel size
101 | # therefore, only the second convolutional layer may change the number of output channels
102 | class Res_ConvBlock(nn.Module):
103 | def __init__(self, dimension, \
104 | in_channels, out_channels, \
105 | kernel_size, \
106 | stride=1, padding="SAME", \
107 | dilation=1, \
108 | bias=True, \
109 | norm="batch", \
110 | activation="relu", last_activation="relu", \
111 | mode="conv"):
112 | super(Res_ConvBlock, self).__init__()
113 |
114 | if dimension==1:
115 | self.conv1 = Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
116 | self.conv2 = Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
117 | elif dimension==2:
118 | self.conv1 = Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation)
119 | self.conv2 = Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode)
120 |
121 |
122 | def forward(self, input):
123 | #print("before c1_out in conv1d",input.shape)
124 | c1_out = self.conv1(input) + input
125 | c2_out = self.conv2(c1_out)
126 | return c2_out
127 |
128 |
129 |
130 | # Convolutional Block
131 | # consists of multiple (layer_num) convolutional layers
132 | # only the final convolutional layer outputs the desired 'out_channels'
133 | class ConvBlock(nn.Module):
134 | def __init__(self, dimension, layer_num, \
135 | in_channels, out_channels, \
136 | kernel_size, \
137 | stride=1, padding="SAME", \
138 | dilation=1, \
139 | bias=True, \
140 | norm="batch", \
141 | activation="relu", last_activation="relu", \
142 | mode="conv"):
143 | super(ConvBlock, self).__init__()
144 |
145 | conv_block = []
146 | if dimension==1:
147 | for i in range(layer_num-1):
148 | conv_block.append(Conv1d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
149 | conv_block.append(Conv1d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
150 | elif dimension==2:
151 | for i in range(layer_num-1):
152 | conv_block.append(Conv2d_layer(in_channels, in_channels, kernel_size, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=activation))
153 | conv_block.append(Conv2d_layer(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, norm=norm, activation=last_activation, mode=mode))
154 | self.conv_block = nn.Sequential(*conv_block)
155 |
156 |
157 | def forward(self, input):
158 | return self.conv_block(input)
159 |
160 |
161 | # FXencoder that extracts audio effects from music recordings trained with a contrastive objective
162 | class FXencoder(nn.Module):
163 | def __init__(self, config):
164 | super(FXencoder, self).__init__()
165 | # input is stereo channeled audio
166 | config["channels"].insert(0, 2)
167 |
168 | # encoder layers
169 | encoder = []
170 | for i in range(len(config["kernels"])):
171 | if config["conv_block"]=='res':
172 | encoder.append(Res_ConvBlock(dimension=1, \
173 | in_channels=config["channels"][i], \
174 | out_channels=config["channels"][i+1], \
175 | kernel_size=config["kernels"][i], \
176 | stride=config["strides"][i], \
177 | padding="SAME", \
178 | dilation=config["dilation"][i], \
179 | norm=config["norm"], \
180 | activation=config["activation"], \
181 | last_activation=config["activation"]))
182 | elif config["conv_block"]=='conv':
183 | encoder.append(ConvBlock(dimension=1, \
184 | layer_num=1, \
185 | in_channels=config["channels"][i], \
186 | out_channels=config["channels"][i+1], \
187 | kernel_size=config["kernels"][i], \
188 | stride=config["strides"][i], \
189 | padding="VALID", \
190 | dilation=config["dilation"][i], \
191 | norm=config["norm"], \
192 | activation=config["activation"], \
193 | last_activation=config["activation"], \
194 | mode='conv'))
195 | self.encoder = nn.Sequential(*encoder)
196 |
197 | # pooling method
198 | self.glob_pool = nn.AdaptiveAvgPool1d(1)
199 |
200 | # network forward operation
201 | def forward(self, input):
202 | #print("in resnet",input.shape)
203 | enc_output = self.encoder(input)
204 | glob_pooled = self.glob_pool(enc_output).squeeze(-1)
205 |
206 | # outputs the pooled C-dimensional feature embedding
207 | return glob_pooled
208 |
209 |
210 |
211 |
--------------------------------------------------------------------------------
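Note: FXencoder above takes a plain config dict, but the exact settings used for the pretrained encoder are defined elsewhere in the repo. The values below are placeholders chosen only to produce a small runnable example, not the published configuration.

import torch
from mst.fx_encoder import FXencoder

# illustrative configuration only, not the released FXencoder settings
config = {
    "channels": [16, 32, 64, 128],   # stereo input (2) is prepended inside FXencoder
    "kernels": [25, 25, 15, 15],
    "strides": [4, 4, 2, 2],
    "dilation": [1, 1, 1, 1],
    "norm": "batch",
    "activation": "relu",
    "conv_block": "res",
}

encoder = FXencoder(config).eval()
with torch.no_grad():
    emb = encoder(torch.randn(1, 2, 44100))  # (batch, stereo, samples)
print(emb.shape)  # (1, 128): one pooled embedding per item, sized by the last channel entry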
/scripts/eval_listen.py:
--------------------------------------------------------------------------------
1 | # run pretrained models over evaluation set to generate audio examples for the listening test
2 | import os
3 | import torch
4 | import torchaudio
5 | import pyloudnorm as pyln
6 | from mst.utils import load_diffmst, run_diffmst
7 |
8 |
9 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
10 |
11 | meter = pyln.Meter(44100)
12 | target_lufs_db = -48.0
13 |
14 | norm_tracks = []
15 | for track_idx in range(tracks.shape[1]):
16 | track = tracks[:, track_idx : track_idx + 1, :]
17 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
18 |
19 | if lufs_db < -80.0:
20 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
21 | continue
22 |
23 | lufs_delta_db = target_lufs_db - lufs_db
24 | track *= 10 ** (lufs_delta_db / 20)
25 | norm_tracks.append(track)
26 |
27 | norm_tracks = torch.cat(norm_tracks, dim=1)
28 | # create a sum mix with equal loudness
29 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
30 | sum_mix /= sum_mix.abs().max()
31 |
32 | return sum_mix, None, None, None
33 |
34 |
35 | if __name__ == "__main__":
36 | meter = pyln.Meter(44100)
37 | target_lufs_db = -22.0
38 | output_dir = "outputs/listen_1"
39 | os.makedirs(output_dir, exist_ok=True)
40 |
41 | methods = {
42 | "diffmst-16": {
43 | "model": load_diffmst(
44 | "/import/c4dm-datasets-ext/Diff-MST/DiffMST/b4naquji/config.yaml",
45 | "/import/c4dm-datasets-ext/Diff-MST/DiffMST/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
46 | ),
47 | "func": run_diffmst,
48 | },
49 | "sum": {
50 | "model": (None, None),
51 | "func": equal_loudness_mix,
52 | },
53 | }
54 |
55 | # get the validation examples
56 | examples = {
57 | "ecstasy": {
58 | "tracks": "/import/c4dm-datasets-ext/diffmst-examples/song1/BenFlowers_Ecstasy_Full/",
59 | "track_verse_start_idx": 1190700,
60 | "track_chorus_start_idx": 2381400,
61 | "ref": "/import/c4dm-datasets-ext/diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme)_01.wav",
62 | "ref_verse_start_idx": 970200,
63 | "ref_chorus_start_idx": 198450,
64 | },
65 | "by-my-side": {
66 | "tracks": "/import/c4dm-datasets-ext/diffmst-examples/song2/Kat Wright_By My Side/",
67 | "track_verse_start_idx": 1146600,
68 | "track_chorus_start_idx": 7144200,
69 | "ref": "/import/c4dm-datasets-ext/diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video)_01.wav",
70 | "ref_verse_start_idx": 661500,
71 | "ref_chorus_start_idx": 2028600,
72 | },
73 | "haunted-aged": {
74 | "tracks": "/import/c4dm-datasets-ext/diffmst-examples/song3/Titanium_HauntedAge_Full/",
75 | "track_verse_start_idx": 1675800,
76 | "track_chorus_start_idx": 3439800,
77 | "ref": "/import/c4dm-datasets-ext/diffmst-examples/song3/ref/Architects - _Doomsday__01.wav",
78 | "ref_verse_start_idx": 4630500,
79 | "ref_chorus_start_idx": 6570900,
80 | },
81 | }
82 |
83 | for example_name, example in examples.items():
84 | print(example_name)
85 | example_dir = os.path.join(output_dir, example_name)
86 | os.makedirs(example_dir, exist_ok=True)
87 | # load reference mix
88 | ref_audio, ref_sr = torchaudio.load(example["ref"], backend="soundfile")
89 | if ref_sr != 44100:
90 | ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100)
91 | print(ref_audio.shape, ref_sr)
92 |
93 | # first find all the tracks
94 | track_filepaths = []
95 | for root, dirs, files in os.walk(example["tracks"]):
96 | for filepath in files:
97 | if filepath.endswith(".wav"):
98 | track_filepaths.append(os.path.join(root, filepath))
99 |
100 | print(f"Found {len(track_filepaths)} tracks.")
101 |
102 | # load the tracks
103 | tracks = []
104 | lengths = []
105 | for track_idx, track_filepath in enumerate(track_filepaths):
106 | audio, sr = torchaudio.load(track_filepath, backend="soundfile")
107 |
108 | if sr != 44100:
109 | audio = torchaudio.functional.resample(audio, sr, 44100)
110 |
111 | # loudness normalize the tracks to -48 LUFS
112 | lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy())
113 | # lufs_delta_db = -48 - lufs_db
114 | # audio = audio * 10 ** (lufs_delta_db / 20)
115 |
116 | #print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
117 |
118 | if audio.shape[0] == 2:
119 | audio = audio.mean(dim=0, keepdim=True)
120 |
121 | chs, seq_len = audio.shape
122 |
123 | for ch_idx in range(chs):
124 | tracks.append(audio[ch_idx : ch_idx + 1, :])
125 | lengths.append(audio.shape[-1])
126 | print("Loaded tracks.")
127 | # find max length and pad if shorter
128 | max_length = max(lengths)
129 | for track_idx in range(len(tracks)):
130 | tracks[track_idx] = torch.nn.functional.pad(
131 | tracks[track_idx], (0, max_length - lengths[track_idx])
132 | )
133 | print("Padded tracks.")
134 | # stack into a tensor
135 | tracks = torch.cat(tracks, dim=0)
136 | tracks = tracks.view(1, -1, max_length)
137 | ref_audio = ref_audio.view(1, 2, -1)
138 |
139 | # crop tracks to max of 60 seconds or so
140 | # tracks = tracks[..., :4194304]
141 |
142 | print(tracks.shape)
143 |
144 | # create a sum mix with equal loudness
145 | sum_mix = torch.sum(tracks, dim=1, keepdim=True).squeeze(0)
146 | sum_filepath = os.path.join(example_dir, f"{example_name}-sum.wav")
147 | os.makedirs(os.path.dirname(sum_filepath), exist_ok=True)
148 | print("sum_mix path created")
149 |
150 | # loudness normalize the sum mix
151 | sum_lufs_db = meter.integrated_loudness(sum_mix.permute(1, 0).numpy())
152 | lufs_delta_db = target_lufs_db - sum_lufs_db
153 | sum_mix = sum_mix * 10 ** (lufs_delta_db / 20)
154 |
155 | torchaudio.save(sum_filepath, sum_mix.view(1, -1), 44100)
156 | print("Sum mix saved.")
157 |
158 | # save the reference mix
159 | ref_filepath = os.path.join(example_dir, "ref-full.wav")
160 | torchaudio.save(ref_filepath, ref_audio.squeeze(), 44100)
161 | print("Reference mix saved.")
162 |
163 | for song_section in ["verse", "chorus"]:
164 | print("Mixing", song_section)
165 | if song_section == "verse":
166 | track_start_idx = example["track_verse_start_idx"]
167 | ref_start_idx = example["ref_verse_start_idx"]
168 | else:
169 | track_start_idx = example["track_chorus_start_idx"]
170 | ref_start_idx = example["ref_chorus_start_idx"]
171 |
172 | if track_start_idx + 262144 > tracks.shape[-1]:
173 | print("Tracks too short for this section.")
174 | if ref_start_idx + 262144 > ref_audio.shape[-1]:
175 | print("Reference too short for this section.")
176 |
177 | # crop the tracks to create a mix twice the size of the reference section
178 | mix_tracks = tracks
179 | # [..., track_start_idx : track_start_idx + (262144 * 2)]
180 | mix_tracks = tracks[..., track_start_idx : track_start_idx + (262144 * 2)]
181 | track_start_idx = 0
182 | print("mix_tracks", mix_tracks.shape)
183 |
184 | # save the reference mix section for analysis
185 | ref_analysis = ref_audio[..., ref_start_idx : ref_start_idx + 262144]
186 |
187 | # create mixes varying the loudness of the reference
188 | for ref_loudness_target in [-24, -16, -14.0, -12, -6]:
189 | print("Ref loudness", ref_loudness_target)
190 | ref_filepath = os.path.join(
191 | example_dir,
192 | f"ref-analysis-{song_section}-lufs-{ref_loudness_target:0.0f}.wav",
193 | )
194 |
195 | # loudness normalize the reference mix section to the current target loudness
196 | ref_lufs_db = meter.integrated_loudness(
197 | ref_analysis.squeeze().permute(1, 0).numpy()
198 | )
199 | lufs_delta_db = ref_loudness_target - ref_lufs_db
200 | ref_analysis = ref_analysis * 10 ** (lufs_delta_db / 20)
201 |
202 | torchaudio.save(ref_filepath, ref_analysis.squeeze(), 44100)
203 |
204 | for method_name, method in methods.items():
205 | print(method_name)
206 | # tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, seq_len)
207 | # ref_audio (torch.Tensor): Reference mix with shape (bs, 2, seq_len)
208 |
209 | if method_name == "sum":
210 | if ref_loudness_target != -16:
211 | continue
212 |
213 | if method_name == "sum" and song_section == "chorus":
214 | continue
215 |
216 | model, mix_console = method["model"]
217 | func = method["func"]
218 |
219 | print(tracks.shape, ref_audio.shape)
220 |
221 | with torch.no_grad():
222 | result = func(
223 | mix_tracks.clone(),
224 | ref_analysis.clone(),
225 | model,
226 | mix_console,
227 | track_start_idx=track_start_idx,
228 | ref_start_idx=ref_start_idx,
229 | )
230 |
231 | (
232 | pred_mix,
233 | pred_track_param_dict,
234 | pred_fx_bus_param_dict,
235 | pred_master_bus_param_dict,
236 | ) = result
237 |
238 | bs, chs, seq_len = pred_mix.shape
239 |
240 | # loudness normalize the output mix
241 | mix_lufs_db = meter.integrated_loudness(
242 | pred_mix.squeeze(0).permute(1, 0).numpy()
243 | )
244 | print(mix_lufs_db)
245 | lufs_delta_db = target_lufs_db - mix_lufs_db
246 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
247 |
248 | # save resulting audio and parameters
249 | mix_filepath = os.path.join(
250 | example_dir,
251 | f"{example_name}-{method_name}-ref={song_section}-lufs-{ref_loudness_target:0.0f}.wav",
252 | )
253 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
254 |
255 | # also save only the analysis section
256 | mix_analysis = pred_mix[
257 | ..., track_start_idx : track_start_idx + (2 * 262144)
258 | ]
259 |
260 | # loudness normalize the output mix
261 | mix_lufs_db = meter.integrated_loudness(
262 | mix_analysis.squeeze(0).permute(1, 0).numpy()
263 | )
264 | print(mix_lufs_db)
265 | mix_analysis = mix_analysis * 10 ** (lufs_delta_db / 20)
266 |
267 | mix_filepath = os.path.join(
268 | example_dir,
269 | f"{example_name}-{method_name}-analysis-{song_section}-lufs-{ref_loudness_target:0.0f}.wav",
270 | )
271 | torchaudio.save(mix_filepath, mix_analysis.view(chs, -1), 44100)
272 |
273 | print()
274 |
--------------------------------------------------------------------------------
/mst/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import torch
4 | import random
5 | import numpy as np
6 | import pyloudnorm as pyln
7 |
8 | from tqdm import tqdm
9 | from typing import Optional
10 | from importlib import import_module
11 | from mst.modules import MixStyleTransferModel
12 |
13 |
14 | def batch_stereo_peak_normalize(x: torch.Tensor):
15 | """Normalize a batch of stereo mixes by their peak value.
16 |
17 | Args:
18 | x (Tensor): Tensor with shape (bs, 2, seq_len).
19 |
20 | Returns:
21 | x (Tensor): Normalized signal with shape (bs, 2, seq_len).
22 | """
23 | # first find the peaks in each channel
24 | gain_lin = x.abs().max(dim=-1, keepdim=True)[0]
25 | # then find the maximum peak across left and right per batch item
26 | gain_lin = gain_lin.max(dim=-2, keepdim=True)[0]
27 | # normalize by the maximum peak
28 | x_norm = x / gain_lin.clamp(1e-8) # avoid division by zero
29 | return x_norm
30 |
31 |
32 | def run_diffmst(
33 | tracks: torch.Tensor,
34 | ref: torch.Tensor,
35 | model: torch.nn.Module,
36 | mix_console: torch.nn.Module,
37 | track_start_idx: int = 0,
38 | ref_start_idx: int = 0,
39 | ):
40 | """Run the differentiable mix style transfer model.
41 |
42 | Args:
43 | tracks (Tensor): Set of input tracks with shape (bs, num_tracks, seq_len).
44 | ref (Tensor): Reference mix with shape (bs, 2, seq_len).
45 | model (torch.nn.Module): MixStyleTransferModel instance.
46 | mix_console (torch.nn.Module): MixConsole instance.
47 | track_start_idx (int, optional): Start index of the track to use. Default: 0.
48 | ref_start_idx (int, optional): Start index of the reference mix to use. Default: 0.
49 |
50 | Returns:
51 | pred_mix (Tensor): Predicted mix with shape (bs, 2, seq_len).
52 | pred_track_param_dict (dict): Dictionary with predicted track parameters.
53 | pred_fx_bus_param_dict (dict): Dictionary with predicted fx bus parameters.
54 | pred_master_bus_param_dict (dict): Dictionary with predicted master bus parameters.
55 | """
56 | # ------ defaults ------
57 | use_track_input_fader = True
58 | use_track_panner = True
59 | use_track_eq = True
60 | use_track_compressor = True
61 | use_fx_bus = False
62 | use_master_bus = True
63 | use_output_fader = True
64 |
65 | analysis_len = 262144
66 | meter = pyln.Meter(44100)
67 |
68 | # crop the input tracks and reference mix to the analysis length
69 | if tracks.shape[-1] >= analysis_len:
70 | analysis_tracks = tracks[
71 | ..., track_start_idx : track_start_idx + analysis_len
72 | ].clone()
73 | else:
74 | analysis_tracks = tracks.clone()
75 |
76 | if ref.shape[-1] >= analysis_len:
77 | analysis_ref = ref[..., ref_start_idx : ref_start_idx + analysis_len]
78 | else:
79 | analysis_ref = ref.clone()
80 |
81 | # loudness normalize the tracks to -48 LUFS
82 | norm_tracks = []
83 | norm_analysis_tracks = []
84 | track_padding = []
85 | for track_idx in range(analysis_tracks.shape[1]):
86 | analysis_track = analysis_tracks[:, track_idx : track_idx + 1, :]
87 | track = tracks[:, track_idx : track_idx + 1, :]
88 | lufs_db = meter.integrated_loudness(
89 | analysis_track.squeeze(0).permute(1, 0).numpy()
90 | )
91 | if lufs_db < -80.0:
92 | print(f"Skipping track {track_idx} due to low loudness {lufs_db}.")
93 | continue
94 |
95 | lufs_delta_db = -48 - lufs_db
96 | analysis_track *= 10 ** (lufs_delta_db / 20)
97 | track *= 10 ** (lufs_delta_db / 20)
98 |
99 | norm_analysis_tracks.append(analysis_track)
100 | norm_tracks.append(track)
101 | track_padding.append(False)
102 |
103 | norm_analysis_tracks = torch.cat(norm_analysis_tracks, dim=1)
104 | norm_tracks = torch.cat(norm_tracks, dim=1)
105 | print(norm_analysis_tracks.shape, norm_tracks.shape)
106 |
107 | # take only first 16 tracks
108 | # norm_tracks = norm_tracks[:, :16, :]
109 | # norm_analysis_tracks = norm_analysis_tracks[:, :16, :]
110 | print(norm_analysis_tracks.shape, norm_tracks.shape)
111 |
112 | # make tensor contiguous
113 | norm_analysis_tracks = norm_analysis_tracks.contiguous()
114 | norm_tracks = norm_tracks.contiguous()
115 |
116 | # ---- run model to estimate mix parameters using analysis audio ----
117 | pred_track_params, pred_fx_bus_params, pred_master_bus_params = model(
118 | norm_analysis_tracks, analysis_ref
119 | )
120 |
121 | # ------- generate a mix using the predicted mix console parameters -------
122 | # apply with sliding window of 262144 samples with overlap
123 | pred_mix = torch.zeros(1, 2, norm_tracks.shape[-1])
124 |
125 | for i in tqdm(range(0, norm_tracks.shape[-1], analysis_len // 2)):
126 | norm_tracks_window = norm_tracks[..., i : i + analysis_len]
127 | (
128 | pred_mixed_tracks,
129 | pred_mix_window,
130 | pred_track_param_dict,
131 | pred_fx_bus_param_dict,
132 | pred_master_bus_param_dict,
133 | ) = mix_console(
134 | norm_tracks_window,
135 | pred_track_params,
136 | pred_fx_bus_params,
137 | pred_master_bus_params,
138 | use_track_input_fader=use_track_input_fader,
139 | use_track_panner=use_track_panner,
140 | use_track_eq=use_track_eq,
141 | use_track_compressor=use_track_compressor,
142 | use_fx_bus=use_fx_bus,
143 | use_master_bus=use_master_bus,
144 | use_output_fader=use_output_fader,
145 | )
146 | if pred_mix_window.shape[-1] < analysis_len:
147 | pred_mix_window = torch.nn.functional.pad(
148 | pred_mix_window, (0, analysis_len - pred_mix_window.shape[-1])
149 | )
150 |
151 | window = torch.hann_window(pred_mix_window.shape[-1])
152 | # apply hann window
153 | if i == 0:
154 | # set the first half of the window to 1
155 | window[: window.shape[-1] // 2] = 1.0
156 |
157 | pred_mix_window *= window
158 |
159 | # check length of the mix window
160 | output_len = pred_mix[..., i : i + analysis_len].shape[-1]
161 |
162 | # overlap add
163 | pred_mix[..., i : i + analysis_len] += pred_mix_window[..., :output_len]
164 |
165 | # crop the mix to the original length
166 | pred_mix = pred_mix[..., : norm_tracks.shape[-1]]
167 |
168 | return (
169 | pred_mix,
170 | pred_track_param_dict,
171 | pred_fx_bus_param_dict,
172 | pred_master_bus_param_dict,
173 | )
174 |
175 |
176 | def load_diffmst(config_path: str, ckpt_path: str, map_location: str = "cpu"):
177 | with open(config_path) as f:
178 | config = yaml.safe_load(f)
179 |
180 | core_model_configs = config["model"]["init_args"]["model"]
181 |
182 | module_path, class_name = core_model_configs["class_path"].rsplit(".", 1)
183 | module = import_module(module_path)
184 | model = getattr(module, class_name)(**core_model_configs["init_args"])
185 |
186 | submodule_configs = core_model_configs["init_args"]
187 |
188 | # create track encoder module
189 | module_path, class_name = submodule_configs["track_encoder"]["class_path"].rsplit(
190 | ".", 1
191 | )
192 | module = import_module(module_path)
193 | track_encoder = getattr(module, class_name)(
194 | **submodule_configs["track_encoder"]["init_args"]
195 | )
196 |
197 | # create mix encoder module
198 | module_path, class_name = submodule_configs["mix_encoder"]["class_path"].rsplit(
199 | ".", 1
200 | )
201 | module = import_module(module_path)
202 | mix_encoder = getattr(module, class_name)(
203 | **submodule_configs["mix_encoder"]["init_args"]
204 | )
205 |
206 | # create controller module
207 | module_path, class_name = submodule_configs["controller"]["class_path"].rsplit(
208 | ".", 1
209 | )
210 | module = import_module(module_path)
211 | controller = getattr(module, class_name)(
212 | **submodule_configs["controller"]["init_args"]
213 | )
214 |
215 | # create mix console module
216 | module_path, class_name = config["model"]["init_args"]["mix_console"][
217 | "class_path"
218 | ].rsplit(".", 1)
219 | module = import_module(module_path)
220 | mix_console = getattr(module, class_name)(
221 | **config["model"]["init_args"]["mix_console"]["init_args"]
222 | )
223 |
224 | checkpoint = torch.load(ckpt_path, map_location=map_location)
225 |
226 | # load state dicts
227 | state_dict = {}
228 | for k, v in checkpoint["state_dict"].items():
229 | if k.startswith("model.track_encoder"):
230 | state_dict[k.replace("model.track_encoder.", "", 1)] = v
231 | track_encoder.load_state_dict(state_dict)
232 |
233 | state_dict = {}
234 | for k, v in checkpoint["state_dict"].items():
235 | if k.startswith("model.mix_encoder"):
236 | state_dict[k.replace("model.mix_encoder.", "", 1)] = v
237 | mix_encoder.load_state_dict(state_dict)
238 |
239 | state_dict = {}
240 | for k, v in checkpoint["state_dict"].items():
241 | if k.startswith("model.controller"):
242 | state_dict[k.replace("model.controller.", "", 1)] = v
243 | controller.load_state_dict(state_dict)
244 |
245 | state_dict = {}
246 | for k, v in checkpoint["state_dict"].items():
247 | if k.startswith("model.mix_console"):
248 | state_dict[k.replace("model.mix_console.", "", 1)] = v
249 | mix_console.load_state_dict(state_dict)
250 |
251 | model = MixStyleTransferModel(
252 | track_encoder,
253 | mix_encoder,
254 | controller,
255 | )
256 | model.eval()
257 |
258 | return model, mix_console
259 |
260 |
261 | def denorm(p, p_min=0.0, p_max=1.0):
262 | return (p * (p_max - p_min)) + p_min
263 |
264 |
265 | def norm(p, p_min=0.0, p_max=1.0):
266 | return (p - p_min) / (p_max - p_min)
267 |
268 |
269 | def seed_worker(worker_id):
270 | worker_seed = torch.initial_seed() % 2**32
271 | np.random.seed(worker_seed)
272 | random.seed(worker_seed)
273 |
274 |
275 | def center_crop(x, length: int):
276 | start = (x.shape[-1] - length) // 2
277 | stop = start + length
278 | return x[..., start:stop]
279 |
280 |
281 | def causal_crop(x, length: int):
282 | stop = x.shape[-1] - 1
283 | start = stop - length
284 | return x[..., start:stop]
285 |
286 |
287 | def count_parameters(model):
288 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
289 |
290 |
291 | def rand(low=0, high=1):
292 | return (torch.rand(1).numpy()[0] * (high - low)) + low
293 |
294 |
295 | def randint(low=0, high=1):
296 | return torch.randint(low, high + 1, (1,)).numpy()[0]
297 |
298 |
299 | def center_crop(x, length: int):
300 | if x.shape[-1] != length:
301 | start = (x.shape[-1] - length) // 2
302 | stop = start + length
303 | x = x[..., start:stop]
304 | return x
305 |
306 |
307 | def causal_crop(x, length: int):
308 | if x.shape[-1] != length:
309 | stop = x.shape[-1] - 1
310 | start = stop - length
311 | x = x[..., start:stop]
312 | return x
313 |
314 |
315 | def find_first_peak(x, threshold_dB=-36, sample_rate=44100):
316 | """Find the first peak of the input signal.
317 |
318 | Args:
319 | x (Tensor): 1-d tensor with signal.
320 | threshold_dB (float, optional): Minimum peak threshold in dB. Default: -36.0
321 | sample_rate (float, optional): Sample rate of the input signal. Default: 44100
322 |
323 | Returns:
324 | first_peak_sample (int): Sample index of the first peak.
325 | first_peak_sec (float): Location of the first peak in seconds.
326 | """
327 | signal = 20 * torch.log10(x.view(-1).abs() + 1e-8)
328 | peaks = torch.where(signal > threshold_dB)[0]
329 | first_peak_sample = peaks[0]
330 | first_peak_sec = first_peak_sample / sample_rate
331 |
332 | return first_peak_sample, first_peak_sec
333 |
334 |
335 | def fade_in_and_fade_out(x, fade_ms=10.0, sample_rate=44100):
336 | """Apply a linear fade in and fade out to the last dim of a signal.
337 |
338 | Args:
339 | x (Tensor): Tensor with signal(s).
340 | fade_ms (float, optional): Length of the fade in milliseconds. Default: 10.0
341 | sample_rate (int, optional): Sample rate. Default: 44100
342 |
343 | Returns:
344 | x (Tensor): Faded signal(s).
345 | """
346 | fade_samples = int(fade_ms * 1e-3 * sample_rate)
347 | fade_in = torch.linspace(0, 1.0, fade_samples)
348 | fade_out = torch.linspace(1.0, 0, fade_samples)
349 | x[..., :fade_samples] *= fade_in
350 | x[..., x.shape[-1] - fade_samples :] *= fade_out
351 |
352 | return x
353 |
354 |
355 | def common_member(a, b):
356 | a_set = set(a)
357 | b_set = set(b)
358 | if a_set & b_set:
359 | return True
360 | else:
361 | return False
362 |
--------------------------------------------------------------------------------
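Note: run_diffmst above applies the predicted console parameters in 262144-sample windows with 50% overlap, weights each window with a Hann window (the very first window is flattened to 1 over its first half), and overlap-adds the results. The sketch below only verifies the property this relies on: a periodic Hann window hopped by half its length sums to roughly constant weight away from the signal edges. The window length here is shortened for illustration.

import torch

win_len = 1024            # illustrative; run_diffmst uses 262144-sample windows
hop = win_len // 2
total = 8 * hop

window = torch.hann_window(win_len)  # periodic by default
coverage = torch.zeros(total)
for start in range(0, total, hop):
    end = min(start + win_len, total)
    coverage[start:end] += window[: end - start]

# interior samples are covered with total weight ~1, so overlap-add preserves level
interior = coverage[win_len : total - win_len]
print(interior.min().item(), interior.max().item())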
/scripts/eval_ablation.py:
--------------------------------------------------------------------------------
1 | # run pretrained models over evaluation set to generate audio examples for the listening test
2 | import os
3 | import torch
4 | import torchaudio
5 | import pyloudnorm as pyln
6 | from mst.utils import load_diffmst, run_diffmst
7 | from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss
8 | import json
9 | import numpy as np
10 | import csv
11 | import glob
12 |
13 |
14 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
15 |
16 | meter = pyln.Meter(44100)
17 | target_lufs_db = -48.0
18 |
19 | norm_tracks = []
20 | for track_idx in range(tracks.shape[1]):
21 | track = tracks[:, track_idx : track_idx + 1, :]
22 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
23 |
24 | if lufs_db < -80.0:
25 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
26 | continue
27 |
28 | lufs_delta_db = target_lufs_db - lufs_db
29 | track *= 10 ** (lufs_delta_db / 20)
30 | norm_tracks.append(track)
31 |
32 | norm_tracks = torch.cat(norm_tracks, dim=1)
33 | # create a sum mix with equal loudness
34 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
35 | sum_mix /= sum_mix.abs().max()
36 |
37 | return sum_mix, None, None, None
38 |
39 | class NumpyEncoder(json.JSONEncoder):
40 | """ Special json encoder for numpy types """
41 | def default(self, obj):
42 | if isinstance(obj, np.integer):
43 | return int(obj)
44 | elif isinstance(obj, np.floating):
45 | return float(obj)
46 | elif isinstance(obj, np.ndarray):
47 | return obj.tolist()
48 | return json.JSONEncoder.default(self, obj)
49 |
50 | if __name__ == "__main__":
51 | meter = pyln.Meter(44100)
52 | target_lufs_db = -22.0
53 | output_dir = "outputs/ablation"
54 | os.makedirs(output_dir, exist_ok=True)
55 |
56 | methods = {
57 | "diffmst-16": {
58 | "model": load_diffmst(
59 | "/Users/svanka/Downloads/b4naquji/config.yaml",
60 | "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
61 | ),
62 | "func": run_diffmst,
63 | },
64 | "sum": {
65 | "model": (None, None),
66 | "func": equal_loudness_mix,
67 | },
68 | }
69 |
70 | # get the validation examples
71 | examples = {
72 | "ecstasy": {
73 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song1/BenFlowers_Ecstasy_Full/",
74 | "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/_Feel it all Around_ by Washed Out (Portlandia Theme)_01/",
75 | },
76 | "by-my-side": {
77 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song2/Kat Wright_By My Side/",
78 | "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/The Dip - Paddle To The Stars (Lyric Video)_01/",
79 | },
80 | "haunted-aged": {
81 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song3/Titanium_HauntedAge_Full/",
82 | "ref": "/Users/svanka/Codes/Diff-MST/outputs/ablation_ref_examples/Architects - _Doomsday__01/",
83 | },
84 | }
85 |
86 |
87 | loss = AudioFeatureLoss([0.1,0.001,1.0,1.0,0.1], 44100)
88 | AF = {}
89 | #initialise to negative infinity
90 |
91 | for example_name, example in examples.items():
92 |
93 |
94 | AF[example_name] = {}
95 | print(example_name)
96 | example_dir = os.path.join(output_dir, example_name)
97 | os.makedirs(example_dir, exist_ok=True)
98 | json_dir = os.path.join(output_dir, "AF")
99 | if not os.path.exists(json_dir):
100 | os.makedirs(json_dir, exist_ok=True)
101 | csv_path = os.path.join(json_dir,f"{example_name}.csv")
102 | # if not os.path.exists(csv_path):
103 | # os.makedirs(csv_path)
104 | with open(csv_path, 'w') as f:
105 | writer = csv.writer(f)
106 | writer.writerow(["method", "audio_type","ablation","start_idx", "stop_idx", "rms", "crest_factor", "stereo_width", "stereo_imbalance", "barkspectrum", "net_AF_loss"])
107 | f.close()
108 | ref_loudness_target = -16.0
109 |
110 | # --------------first find all the tracks----------------
111 | track_filepaths = []
112 | for root, dirs, files in os.walk(example["tracks"]):
113 | for filepath in files:
114 | if filepath.endswith(".wav"):
115 | track_filepaths.append(os.path.join(root, filepath))
116 |
117 | print(f"Found {len(track_filepaths)} tracks.")
118 |
119 | # ----------------load the tracks----------------------------
120 | tracks = []
121 | lengths = []
122 | for track_idx, track_filepath in enumerate(track_filepaths):
123 | audio, sr = torchaudio.load(track_filepath, backend="soundfile")
124 |
125 | if sr != 44100:
126 | audio = torchaudio.functional.resample(audio, sr, 44100)
127 |
128 | # loudness normalize the tracks to -48 LUFS
129 | lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy())
130 | # lufs_delta_db = -48 - lufs_db
131 | # audio = audio * 10 ** (lufs_delta_db / 20)
132 |
133 | print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
134 |
135 | if audio.shape[0] == 2:
136 | audio = audio.mean(dim=0, keepdim=True)
137 |
138 | chs, seq_len = audio.shape
139 |
140 | for ch_idx in range(chs):
141 | tracks.append(audio[ch_idx : ch_idx + 1, :])
142 | lengths.append(audio.shape[-1])
143 |
144 | # find max length and pad if shorter
145 | max_length = max(lengths)
146 | min_length = min(lengths)
147 | for track_idx in range(len(tracks)):
148 | tracks[track_idx] = torch.nn.functional.pad(
149 | tracks[track_idx], (0, max_length - lengths[track_idx])
150 | )
151 |
152 | # stack into a tensor
153 | tracks = torch.cat(tracks, dim=0)
154 | tracks = tracks.view(1, -1, max_length)
155 | tracks_length = max_length
156 | refs = glob.glob(os.path.join(example["ref"],"*.wav"))
157 | print("found refs", len(refs))
158 | for ref in refs:
159 | ref_name = os.path.basename(ref).replace(".wav", "")
160 | test_type = ref_name.split("_")[-2] + "_" + ref_name.split("_")[-1]
161 | print(test_type)
162 |
163 | print(ref_name)
164 | AF[example_name]["ref"] = {}
165 | AF[example_name]["pred_mix"] = {}
166 | ref_audio, ref_sr = torchaudio.load(ref, backend="soundfile")
167 | if ref_sr != 44100:
168 | ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100)
169 | print(ref_audio.shape, ref_sr)
170 | ref_length = ref_audio.shape[-1]
171 | ref_audio = ref_audio.view(1, 2, -1)
172 |
173 | #loudness normalize the reference mix to -16 LUFS
174 | ref_lufs_db = meter.integrated_loudness(ref_audio.squeeze().permute(1, 0).numpy())
175 | lufs_delta_db = ref_loudness_target - ref_lufs_db
176 | ref_audio = ref_audio * 10 ** (lufs_delta_db / 20)
177 |
178 |
179 | # --------------run inference----------------
180 | #print(tracks.shape)
181 | track_idx = int(tracks_length / 2)
182 | ref_idx = int(ref_length / 2)
183 | mix_tracks = tracks[..., track_idx - 220500 : track_idx + 220500]
184 | ref_analysis = ref_audio[..., ref_idx - 220500 : ref_idx + 220500]
185 |
186 | ref_path = os.path.join(example_dir, os.path.basename(ref).replace(".wav", "-ref-16.wav"))
187 | torchaudio.save(ref_path, ref_analysis.squeeze(), 44100)
188 |
189 | for method_name, method in methods.items():
190 | AF[example_name]["ref"] [method_name] = {}
191 | AF[example_name]["pred_mix"] [method_name] = {}
192 |
193 | print(method_name)
194 | model, mix_console = method["model"]
195 | func = method["func"]
196 |
197 | with torch.no_grad():
198 | result = func(
199 | mix_tracks.clone(),
200 | ref_analysis.clone(),
201 | model,
202 | mix_console,
203 | track_start_idx=0,
204 | ref_start_idx=0,
205 | )
206 |
207 | (
208 | pred_mix,
209 | pred_track_param_dict,
210 | pred_fx_bus_param_dict,
211 | pred_master_bus_param_dict,
212 | ) = result
213 |
214 | bs, chs, seq_len = pred_mix.shape
215 | print("pred_mix shape", pred_mix.shape)
216 | # loudness normalize the output mix
217 | mix_lufs_db = meter.integrated_loudness(
218 | pred_mix.squeeze(0).permute(1, 0).numpy()
219 | )
220 | print("pred_mix_lufs_db", mix_lufs_db)
221 | #print(mix_lufs_db)
222 | lufs_delta_db = target_lufs_db - mix_lufs_db
223 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
224 | name = os.path.basename(ref).replace(".wav", "-pred_mix-16.wav")
225 | mix_filepath = os.path.join(example_dir, f"{method_name}_{name}")
226 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
227 |
228 | # compute audio features (all audio has already been resampled to 44.1 kHz above)
229 |
230 | AF[example_name]["pred_mix"][method_name]["mix-rms"] = 0.1*compute_rms(pred_mix, sample_rate=44100).mean().detach().cpu().numpy()
231 | AF[example_name]["pred_mix"][method_name]["mix-crest_factor"] = 0.001*compute_crest_factor(pred_mix, sample_rate=44100).mean().detach().cpu().numpy()
232 | AF[example_name]["pred_mix"][method_name]["mix-stereo_width"] = 1.0*compute_stereo_width(pred_mix, sample_rate=44100).detach().cpu().numpy()
233 | AF[example_name]["pred_mix"][method_name]["mix-stereo_imbalance"] = 1.0*compute_stereo_imbalance(pred_mix, sample_rate=44100).detach().cpu().numpy()
234 | AF[example_name]["pred_mix"][method_name]["mix-barkspectrum"] = 0.1*compute_barkspectrum(pred_mix, sample_rate=44100).mean().detach().cpu().numpy()
235 |
236 | AF[example_name]["ref"][method_name]["mix-rms"] = 0.1*compute_rms(ref_analysis, sample_rate=44100).mean().detach().cpu().numpy()
237 | AF[example_name]["ref"][method_name]["mix-crest_factor"] = 0.001*compute_crest_factor(ref_analysis, sample_rate=44100).mean().detach().cpu().numpy()
238 | AF[example_name]["ref"][method_name]["mix-stereo_width"] = 1.0*compute_stereo_width(ref_analysis, sample_rate=44100).detach().cpu().numpy()
239 | AF[example_name]["ref"][method_name]["mix-stereo_imbalance"] = 1.0*compute_stereo_imbalance(ref_analysis, sample_rate=44100).detach().cpu().numpy()
240 | AF[example_name]["ref"][method_name]["mix-barkspectrum"] = 0.1*compute_barkspectrum(ref_analysis, sample_rate=44100).mean().detach().cpu().numpy()
241 |
242 | AF_loss = loss(pred_mix, ref_analysis)
243 | AF[example_name]["pred_mix"][method_name]["net_AF_loss"] = sum(AF_loss.values()).detach().cpu().numpy()
244 | AF[example_name]["ref"][method_name]["net_AF_loss"] = AF[example_name]["pred_mix"][method_name]["net_AF_loss"]
245 |
246 |
247 | # save resulting audio and parameters
248 | #append to csv the method name, audio section, audio features values and net loss on different columns
249 | with open(csv_path, 'a') as f:
250 | writer = csv.writer(f)
251 | writer.writerow([method_name, "pred_mix", test_type, track_idx - 220500, track_idx + 220500,AF[example_name]["pred_mix"][method_name]["mix-rms"], AF[example_name]["pred_mix"][method_name]["mix-crest_factor"], AF[example_name]["pred_mix"][method_name]["mix-stereo_width"], AF[example_name]["pred_mix"][method_name]["mix-stereo_imbalance"], AF[example_name]["pred_mix"][method_name]["mix-barkspectrum"], AF[example_name]["pred_mix"][method_name]["net_AF_loss"]])
252 | writer.writerow([method_name, "ref", test_type, ref_idx - 220500, ref_idx + 220500,AF[example_name]["ref"][method_name]["mix-rms"], AF[example_name]["ref"][method_name]["mix-crest_factor"], AF[example_name]["ref"][method_name]["mix-stereo_width"], AF[example_name]["ref"][method_name]["mix-stereo_imbalance"], AF[example_name]["ref"][method_name]["mix-barkspectrum"], AF[example_name]["ref"][method_name]["net_AF_loss"]])
253 | f.close()
254 |
255 |
256 |
257 | # write dictionary to json
258 |
259 |
260 |
--------------------------------------------------------------------------------
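Note: eval_ablation.py defines NumpyEncoder and ends with a "# write dictionary to json" comment without actually serializing the AF dictionary. A minimal sketch of that missing step, assuming AF and json_dir are still in scope as in the script above (the output filename is hypothetical):

import json
import os

json_path = os.path.join(json_dir, "audio_features.json")  # hypothetical filename
with open(json_path, "w") as f:
    json.dump(AF, f, indent=2, cls=NumpyEncoder)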
/scripts/eval_all_combo.py:
--------------------------------------------------------------------------------
1 | # run pretrained models over evaluation set to generate audio examples for the listening test
2 | import os
3 | import torch
4 | import torchaudio
5 | import pyloudnorm as pyln
6 | from mst.utils import load_diffmst, run_diffmst
7 | from mst.loss import compute_barkspectrum, compute_rms, compute_crest_factor, compute_stereo_width, compute_stereo_imbalance, AudioFeatureLoss
8 | import json
9 | import numpy as np
10 | import csv
11 |
12 |
13 | def equal_loudness_mix(tracks: torch.Tensor, *args, **kwargs):
14 |
15 | meter = pyln.Meter(44100)
16 | target_lufs_db = -48.0
17 |
18 | norm_tracks = []
19 | for track_idx in range(tracks.shape[1]):
20 | track = tracks[:, track_idx : track_idx + 1, :]
21 | lufs_db = meter.integrated_loudness(track.squeeze(0).permute(1, 0).numpy())
22 |
23 | if lufs_db < -80.0:
24 | print(f"Skipping track {track_idx} with {lufs_db:.2f} LUFS.")
25 | continue
26 |
27 | lufs_delta_db = target_lufs_db - lufs_db
28 | track *= 10 ** (lufs_delta_db / 20)
29 | norm_tracks.append(track)
30 |
31 | norm_tracks = torch.cat(norm_tracks, dim=1)
32 | # create a sum mix with equal loudness
33 | sum_mix = torch.sum(norm_tracks, dim=1, keepdim=True).repeat(1, 2, 1)
34 | sum_mix /= sum_mix.abs().max()
35 |
36 | return sum_mix, None, None, None
37 |
38 | class NumpyEncoder(json.JSONEncoder):
39 | """ Special json encoder for numpy types """
40 | def default(self, obj):
41 | if isinstance(obj, np.integer):
42 | return int(obj)
43 | elif isinstance(obj, np.floating):
44 | return float(obj)
45 | elif isinstance(obj, np.ndarray):
46 | return obj.tolist()
47 | return json.JSONEncoder.default(self, obj)
48 |
49 | if __name__ == "__main__":
50 | meter = pyln.Meter(44100)
51 | target_lufs_db = -22.0
52 | output_dir = "outputs/listen"
53 | os.makedirs(output_dir, exist_ok=True)
54 |
55 | methods = {
56 | "diffmst-16": {
57 | "model": load_diffmst(
58 | "/Users/svanka/Downloads/b4naquji/config.yaml",
59 | "/Users/svanka/Downloads/b4naquji/checkpoints/epoch=191-step=626608.ckpt",
60 | ),
61 | "func": run_diffmst,
62 | },
63 | "sum": {
64 | "model": (None, None),
65 | "func": equal_loudness_mix,
66 | },
67 | }
68 |
69 | # get the validation examples
70 | examples = {
71 | # "ecstasy": {
72 | # "tracks": "/Users/svanka/Downloads//diffmst-examples/song1/BenFlowers_Ecstasy_Full/",
73 | # "ref": "/Users/svanka/Downloads//diffmst-examples/song1/ref/_Feel it all Around_ by Washed Out (Portlandia Theme)_01.wav",
74 | # },
75 | # "by-my-side": {
76 | # "tracks": "/Users/svanka/Downloads//diffmst-examples/song2/Kat Wright_By My Side/",
77 | # "ref": "/Users/svanka/Downloads//diffmst-examples/song2/ref/The Dip - Paddle To The Stars (Lyric Video)_01.wav",
78 | # },
79 | "haunted-aged": {
80 | "tracks": "/Users/svanka/Downloads//diffmst-examples/song3/Titanium_HauntedAge_Full/",
81 | "ref": "/Users/svanka/Downloads//diffmst-examples/song3/ref/Architects - _Doomsday__01.wav",
82 | },
83 | }
84 | loss = AudioFeatureLoss([0.1,0.001,1.0,1.0,0.1], 44100)
85 | AF = {}
86 | #initialise to negative infinity
87 |
88 | for example_name, example in examples.items():
89 |
90 | AF[example_name] = {}
91 | print(example_name)
92 | example_dir = os.path.join(output_dir, example_name)
93 | os.makedirs(example_dir, exist_ok=True)
94 | json_dir = os.path.join(output_dir, "AF")
95 | if not os.path.exists(json_dir):
96 | os.makedirs(json_dir, exist_ok=True)
97 | csv_path = os.path.join(json_dir,f"{example_name}.csv")
98 | # if not os.path.exists(csv_path):
99 | # os.makedirs(csv_path)
100 | with open(csv_path, 'w') as f:
101 | writer = csv.writer(f)
102 | writer.writerow(["method", "audio_section","track_start_idx", "track_stop_idx", "ref_start_idx", "ref_stop_idx", "rms", "crest_factor", "stereo_width", "stereo_imbalance", "barkspectrum", "net_AF_loss"])
103 | f.close()
104 |
105 | # ----------load reference mix---------------
106 | ref_audio, ref_sr = torchaudio.load(example["ref"], backend="soundfile")
107 | if ref_sr != 44100:
108 | ref_audio = torchaudio.functional.resample(ref_audio, ref_sr, 44100)
109 | print(ref_audio.shape, ref_sr)
110 | ref_length = ref_audio.shape[-1]
111 | # --------------first find all the tracks----------------
112 | track_filepaths = []
113 | for root, dirs, files in os.walk(example["tracks"]):
114 | for filepath in files:
115 | if filepath.endswith(".wav"):
116 | track_filepaths.append(os.path.join(root, filepath))
117 |
118 | print(f"Found {len(track_filepaths)} tracks.")
119 |
120 | # ----------------load the tracks----------------------------
121 | tracks = []
122 | lengths = []
123 | for track_idx, track_filepath in enumerate(track_filepaths):
124 | audio, sr = torchaudio.load(track_filepath, backend="soundfile")
125 |
126 | if sr != 44100:
127 | audio = torchaudio.functional.resample(audio, sr, 44100)
128 |
129 | # loudness normalize the tracks to -48 LUFS
130 | lufs_db = meter.integrated_loudness(audio.permute(1, 0).numpy())
131 | # lufs_delta_db = -48 - lufs_db
132 | # audio = audio * 10 ** (lufs_delta_db / 20)
133 |
134 | print(track_idx, os.path.basename(track_filepath), audio.shape, sr, lufs_db)
135 |
136 | if audio.shape[0] == 2:
137 | audio = audio.mean(dim=0, keepdim=True)
138 |
139 | chs, seq_len = audio.shape
140 |
141 | for ch_idx in range(chs):
142 | tracks.append(audio[ch_idx : ch_idx + 1, :])
143 | lengths.append(audio.shape[-1])
144 |
145 | # find max length and pad if shorter
146 | max_length = max(lengths)
147 | min_length = min(lengths)
148 | for track_idx in range(len(tracks)):
149 | tracks[track_idx] = torch.nn.functional.pad(
150 | tracks[track_idx], (0, max_length - lengths[track_idx])
151 | )
152 |
153 | # stack into a tensor
154 | tracks = torch.cat(tracks, dim=0)
155 | tracks = tracks.view(1, -1, max_length)
156 | ref_audio = ref_audio.view(1, 2, -1)
157 |
158 | # crop tracks to max of 60 seconds or so
159 | # tracks = tracks[..., :4194304]
160 | tracks_length = max_length
161 |
162 | #print(tracks.shape)
163 | track_start_idx = int(tracks_length / 4)
164 | ref_start_idx = int(ref_length / 4)
165 | track_stop_idx = int(3*tracks_length / 4)
166 | ref_stop_idx = int(3*ref_length / 4)
167 | #find the number of sets of track samples of 10 sec duration each
168 | track_num_sets = int((track_stop_idx - track_start_idx) / 441000)
169 | ref_num_sets = int((ref_stop_idx - ref_start_idx) / 441000)
170 | print("track_num_sets", track_num_sets)
171 | print("ref_num_sets", ref_num_sets)
172 | min_AF_loss = float('inf')
173 | min_AF_loss_example = None
174 | for i in range(track_num_sets):
175 | for j in range(ref_num_sets):
176 | print(f"track-{i}-ref-{j}")
177 | #run inference for every combination of track and ref samples and calculate audio features.
178 | # We will save the audio features to a csv and audio files in the output directory
179 | mix_tracks = tracks[..., track_start_idx + i*441000 : track_start_idx + (i+1)*441000]
180 | ref_analysis = ref_audio[..., ref_start_idx + j*441000 : ref_start_idx + (j+1)*441000]
181 |
182 | # create mixes varying the loudness of the reference
183 | for ref_loudness_target in [-16.0]:
184 | print("Ref loudness", ref_loudness_target)
185 | ref_filepath = os.path.join(
186 | example_dir,
187 | f"ref-analysis-track-{i}-ref-{j}-lufs-{ref_loudness_target:0.0f}.wav",
188 | )
189 |
190 | # loudness normalize the reference mix section to the current target loudness
191 | ref_lufs_db = meter.integrated_loudness(
192 | ref_analysis.squeeze().permute(1, 0).numpy()
193 | )
194 | print("ref_lufs_db", ref_lufs_db)
195 | lufs_delta_db = ref_loudness_target - ref_lufs_db
196 | ref_analysis = ref_analysis * 10 ** (lufs_delta_db / 20)
197 |
198 | torchaudio.save(ref_filepath, ref_analysis.squeeze(), 44100)
199 |
200 | AF_loss = 0
201 | for method_name, method in methods.items():
202 | AF[example_name][method_name] = {}
203 | print(method_name)
204 | # tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, seq_len)
205 | # ref_audio (torch.Tensor): Reference mix with shape (bs, 2, seq_len)
206 |
207 | if method_name == "sum":
208 | if ref_loudness_target != -16:
209 | continue
210 |
211 |
212 | model, mix_console = method["model"]
213 | func = method["func"]
214 |
215 | #print(tracks.shape, ref_audio.shape)
216 | audio_section = f"track-{i}-ref-{j}-lufs-{ref_loudness_target:0.0f}"
217 | AF[example_name][method_name][audio_section] = {}
218 | AF[example_name][method_name][audio_section]["track_start_idx"] = track_start_idx + i*441000
219 | AF[example_name][method_name][audio_section]["track_stop_idx"] = track_start_idx + (i+1)*441000
220 | AF[example_name][method_name][audio_section]["ref_start_idx"] = ref_start_idx + j*441000
221 | AF[example_name][method_name][audio_section]["ref_stop_idx"] = ref_start_idx + (j+1)*441000
222 | with torch.no_grad():
223 | result = func(
224 | mix_tracks.clone(),
225 | ref_analysis.clone(),
226 | model,
227 | mix_console,
228 | track_start_idx=0,
229 | ref_start_idx=0,
230 | )
231 |
232 | (
233 | pred_mix,
234 | pred_track_param_dict,
235 | pred_fx_bus_param_dict,
236 | pred_master_bus_param_dict,
237 | ) = result
238 |
239 | bs, chs, seq_len = pred_mix.shape
240 | print("pred_mix shape", pred_mix.shape)
241 | # loudness normalize the output mix
242 | mix_lufs_db = meter.integrated_loudness(
243 | pred_mix.squeeze(0).permute(1, 0).numpy()
244 | )
245 | print("pred_mix_lufs_db", mix_lufs_db)
246 | #print(mix_lufs_db)
247 | lufs_delta_db = target_lufs_db - mix_lufs_db
248 | pred_mix = pred_mix * 10 ** (lufs_delta_db / 20)
249 | mix_filepath = os.path.join(
250 | example_dir,
251 | f"{example_name}-{method_name}-tracks-{i}-ref={j}-lufs-{ref_loudness_target:0.0f}.wav",
252 | )
253 | torchaudio.save(mix_filepath, pred_mix.view(chs, -1), 44100)
254 |
255 | # compute audio features
256 | AF_loss = loss(pred_mix, ref_analysis)
257 |
258 | for key, value in AF_loss.items():
259 | AF[example_name][method_name][audio_section][key] = value.detach().cpu().numpy()
260 | AF[example_name][method_name][audio_section]["net_AF_loss"] = sum(AF_loss.values()).detach().cpu().numpy()
261 | print(AF[example_name][method_name][audio_section])
262 |
263 | if AF[example_name][method_name][audio_section]["net_AF_loss"] < min_AF_loss:
264 | min_AF_loss = AF[example_name][method_name][audio_section]["net_AF_loss"]
265 | min_AF_loss_example = f"{example_name}-{method_name}-{audio_section}"
266 | print("min_AF_loss", min_AF_loss)
267 | print("min_AF_loss_example", min_AF_loss_example)
268 | # save resulting audio and parameters
269 | # append one CSV row per method and audio section: indices, per-feature losses, and the net loss
270 |
271 | with open(csv_path, 'a', newline='') as f:
272 | writer = csv.writer(f)
273 | section_AF = AF[example_name][method_name][audio_section]
274 | writer.writerow([method_name, audio_section, section_AF["track_start_idx"], section_AF["track_stop_idx"], section_AF["ref_start_idx"], section_AF["ref_stop_idx"], section_AF["mix-rms"], section_AF["mix-crest_factor"], section_AF["mix-stereo_width"], section_AF["mix-stereo_imbalance"], section_AF["mix-barkspectrum"], section_AF["net_AF_loss"]])
275 |
276 |
277 | print(f"for {example_name} min loss is {min_AF_loss} corresponding to {min_AF_loss_example}")
278 | print()
279 |
280 | # write dictionary to JSON
281 |
282 |
283 |
--------------------------------------------------------------------------------
/scripts/online.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import torch
4 | import argparse
5 | import torchaudio
6 | import numpy as np
7 | import pyloudnorm as pyln
8 | import matplotlib.pyplot as plt
9 | from tqdm import tqdm
10 |
11 | from mst.loss import AudioFeatureLoss, StereoCLAPLoss, compute_barkspectrum
12 | from mst.modules import AdvancedMixConsole
13 |
14 |
15 | def optimize(
16 | tracks: torch.Tensor,
17 | ref_mix: torch.Tensor,
18 | mix_console: torch.nn.Module,
19 | loss_function: torch.nn.Module,
20 | init_scale: float = 0.001,
21 | lr: float = 1e-3,
22 | n_iters: int = 100,
23 | ):
24 | """Create a mix from the tracks that is as close as possible to the reference mixture.
25 |
26 | Args:
27 | tracks (torch.Tensor): Tensor of shape (n_tracks, n_samples).
28 | ref_mix (torch.Tensor): Tensor of shape (2, n_samples).
29 | mix_console (torch.nn.Module): Mix console instance. (e.g. AdvancedMixConsole)
30 | loss_function (torch.nn.Module): Loss function instance. (e.g. AudioFeatureLoss)
31 | n_iters (int): Number of iterations for the optimization.
32 |
33 | Returns:
34 | tuple: The predicted mix of shape (2, n_samples), the raw and dict-form track, fx bus, and master bus parameters, and the loss history.
35 | """
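# The console parameters are optimized directly: raw unconstrained tensors are updated by Adam
# and squashed with a sigmoid into [0, 1] before being handed to the mix console, which is
# assumed to map these normalized values onto its internal parameter ranges.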
36 | loss_history = {"loss": []} # lists to store loss values
37 |
38 | # initialize the mix console parameters to optimize
39 | track_params = init_scale * torch.randn(
40 | tracks.shape[0], mix_console.num_track_control_params
41 | )
42 | fx_bus_params = init_scale * torch.randn(1, mix_console.num_fx_bus_control_params)
43 | master_bus_params = init_scale * torch.randn(
44 | 1, mix_console.num_master_bus_control_params
45 | )
46 |
47 | # move parameters to same device as tracks
48 | track_params = track_params.type_as(tracks)
49 | fx_bus_params = fx_bus_params.type_as(tracks)
50 | master_bus_params = master_bus_params.type_as(tracks)
51 |
52 | # require gradients for the parameters
53 | track_params.requires_grad = True
54 | fx_bus_params.requires_grad = True
55 | master_bus_params.requires_grad = True
56 |
57 | # create optimizer and link to console parameters
58 | optimizer = torch.optim.Adam(
59 | [track_params, fx_bus_params, master_bus_params], lr=lr
60 | )
61 |
62 | pbar = tqdm(range(n_iters))
63 |
64 | # reshape
65 | tracks = tracks.unsqueeze(0)
66 | ref_mix = ref_mix.unsqueeze(0)
67 | track_params = track_params.unsqueeze(0)
68 | fx_bus_params = fx_bus_params.unsqueeze(0)
69 | master_bus_params = master_bus_params.unsqueeze(0)
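# add the batch dimension expected by the mix console; these unsqueezed tensors are views that
# share storage with the leaf parameters registered in the optimizer, so in-place optimizer
# updates are reflected here on every iteration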
70 |
71 | for n in pbar:
72 | optimizer.zero_grad()
73 |
74 | # mix the tracks using the mix console
75 | # the mix console parameters are sigmoided to ensure they are in the range [0, 1]
76 | result = mix_console(
77 | tracks,
78 | torch.sigmoid(track_params),
79 | torch.sigmoid(fx_bus_params),
80 | torch.sigmoid(master_bus_params),
81 | use_fx_bus=False,
82 | )
83 | mix = result[1]
84 | track_param_dict = result[2]
85 | fx_bus_param_dict = result[3]
86 | master_bus_param_dict = result[4]
87 |
88 | # compute loss
89 | loss = 0
90 | losses = loss_function(mix, ref_mix)
91 | for loss_name, loss_value in losses.items():
92 | loss += loss_value
93 |
94 | # compute gradients and update parameters
95 | loss.backward()
96 | optimizer.step()
97 |
98 | # update progress bar
99 | pbar.set_description(f"Loss: {loss.item():.4f}")
100 |
101 | # store loss values
102 | loss_history["loss"].append(loss.item())
103 | for loss_name, loss_value in losses.items():
104 | if loss_name not in loss_history:
105 | loss_history[loss_name] = []
106 | loss_history[loss_name].append(loss_value.item())
107 |
108 | # remove the batch dimension from the mix
109 | mix = mix.squeeze(0)
110 | # the parameter tensors keep their batch dimension so they can be passed
111 | # back into the mix console when rendering the full song
112 |
113 |
114 | return (
115 | mix,
116 | track_params,
117 | track_param_dict,
118 | fx_bus_params,
119 | fx_bus_param_dict,
120 | master_bus_params,
121 | master_bus_param_dict,
122 | loss_history,
123 | )
124 |
125 |
126 | if __name__ == "__main__":
127 | parser = argparse.ArgumentParser()
128 | parser.add_argument(
129 | "--track_dir",
130 | type=str,
131 | help="Path to directory with tracks.",
132 | )
133 | parser.add_argument(
134 | "--ref_mix",
135 | type=str,
136 | help="Path to reference mixture.",
137 | )
138 | parser.add_argument(
139 | "--n_iters",
140 | type=int,
141 | help="Number of iterations.",
142 | default=250,
143 | )
144 | parser.add_argument(
145 | "--lr",
146 | type=float,
147 | help="Learning rate.",
148 | default=1e-3,
149 | )
150 | parser.add_argument(
151 | "--loss",
152 | type=str,
153 | help="Loss function.",
154 | choices=["feat", "clap"],
155 | default="feat",
156 | )
157 | parser.add_argument(
158 | "--output_dir",
159 | type=str,
160 | help="Path to output directory.",
161 | default="outputs",
162 | )
163 | parser.add_argument(
164 | "--use_gpu",
165 | action="store_true",
166 | help="Use GPU for optimization.",
167 | )
168 | parser.add_argument(
169 | "--sample_rate",
170 | type=int,
171 | help="Sample rate of tracks.",
172 | default=44100,
173 | )
174 | parser.add_argument(
175 | "--block_size",
176 | type=int,
177 | default=524288,
178 | help="Analysis block size.",
179 | )
180 | parser.add_argument(
181 | "--start_time_s",
182 | type=float,
183 | default=32.0,
184 | help="Analysis block start time.",
185 | )
186 | parser.add_argument(
187 | "--target_track_lufs_db",
188 | type=float,
189 | default=-48.0,
190 | )
191 | parser.add_argument(
192 | "--target_mix_lufs_db",
193 | type=float,
194 | default=-14.0,
195 | )
196 | parser.add_argument(
197 | "--stem_separation",
198 | action="store_true",
199 | )
200 |
201 | args = parser.parse_args()
202 |
203 | meter = pyln.Meter(args.sample_rate)
204 | run_name = f"{os.path.basename(args.track_dir)}-->{os.path.basename(args.ref_mix).split('.')[0]}"
205 | output_dir = os.path.join(args.output_dir, run_name)
206 | os.makedirs(output_dir, exist_ok=True)
207 | os.makedirs(os.path.join(output_dir, "plots"), exist_ok=True)
208 |
209 | # -------------------------- data loading -------------------------- #
210 | # load tracks
211 | tracks = []
212 | print(f"Loading tracks for current run: {run_name}...")
213 | track_filepaths = sorted(glob.glob(os.path.join(args.track_dir, "*.wav")))
214 | num_tracks = len(track_filepaths)
215 | for track_idx, track_filepath in enumerate(track_filepaths):
216 | track, track_sample_rate = torchaudio.load(track_filepath)
217 | print(
218 | f"{track_idx+1}/{num_tracks}: {track.shape} {os.path.basename(track_filepath)}"
219 | )
220 |
221 | # check if track has same sample rate as reference mixture
222 | # if not, resample
223 | if track_sample_rate != args.sample_rate:
224 | track = torchaudio.transforms.Resample(track_sample_rate, args.sample_rate)(
225 | track
226 | )
227 |
228 | # check if the track is silent
229 | for ch_idx in range(track.shape[0]):
230 | # measure loudness
231 | track_lufs_db = meter.integrated_loudness(track[ch_idx, :].numpy())
232 |
233 | if track_lufs_db < -60.0:
234 | print(f"Track is inactive at {track_lufs_db:0.2f} dB. Skipping...")
235 | continue
236 | else:
237 | # loudness normalize
238 | delta_lufs_db = args.target_track_lufs_db - track_lufs_db
239 | delta_lufs_lin = 10 ** (delta_lufs_db / 20)
240 | tracks.append(delta_lufs_lin * track[ch_idx, :])
241 |
242 | tracks = torch.stack(tracks) # shape: (n_tracks, n_samples)
243 |
244 | # load reference mixture
245 | ref_mix, ref_sample_rate = torchaudio.load(args.ref_mix)
246 |
247 | if ref_sample_rate != args.sample_rate:
248 | ref_mix = torchaudio.transforms.Resample(ref_sample_rate, args.sample_rate)(
249 | ref_mix
250 | )
251 | mix_lufs_db = meter.integrated_loudness(ref_mix.permute(1, 0).numpy())
252 | delta_lufs_db = args.target_mix_lufs_db - mix_lufs_db
253 | delta_lufs_lin = 10 ** (delta_lufs_db / 20)
254 | ref_mix = delta_lufs_lin * ref_mix
255 |
256 | # use only a subsection of the reference mixture and tracks
257 | # this is to speed up the optimization
258 | start_time_s = args.start_time_s
259 | start_sample = int(start_time_s * args.sample_rate)
260 |
261 | ref_mix_section = ref_mix[:, start_sample : start_sample + args.block_size]
262 | tracks_section = tracks[:, start_sample : start_sample + args.block_size]
263 |
264 | print(ref_mix.shape, tracks.shape)
265 | print(ref_mix_section.shape, tracks_section.shape)
266 |
267 | if args.use_gpu:
268 | ref_mix_section = ref_mix_section.cuda()
269 | tracks_section = tracks_section.cuda()
270 |
271 | # -------------------------- setup -------------------------- #
272 | # mix console will use the same sample rate as the tracks
273 | mix_console = AdvancedMixConsole(args.sample_rate)
274 |
275 | weights = [
276 | 0.1, # rms
277 | 0.001, # crest factor
278 | 1.0, # stereo width
279 | 1.0, # stereo imbalance
280 | 1.00, # bark spectrum
281 | 100.0, # clap
282 | ]
283 |
284 | if args.loss == "feat":
285 | loss_function = AudioFeatureLoss(
286 | weights,
287 | args.sample_rate,
288 | stem_separation=args.stem_separation,
289 | )
290 | elif args.loss == "clap":
291 | loss_function = StereoCLAPLoss()
292 | else:
293 | raise ValueError(f"Unknown loss: {args.loss}")
294 |
295 | if args.use_gpu:
296 | loss_function.cuda()
297 |
298 | # -------------------------- optimization -------------------------- #
299 | result = optimize(
300 | tracks_section,
301 | ref_mix_section,
302 | mix_console,
303 | loss_function,
304 | n_iters=args.n_iters,
305 | )
306 |
307 | mix = result[0]
308 | track_params = result[1]
309 | track_param_dict = result[2]
310 | fx_bus_params = result[3]
311 | fx_bus_param_dict = result[4]
312 | master_bus_params = result[5]
313 | master_bus_param_dict = result[6]
314 | loss_history = result[7]
315 |
316 | ref_mix = ref_mix.squeeze(0).cpu()
317 | mono_mix = tracks.sum(dim=0).repeat(2, 1).cpu()
318 |
319 | # print(track_param_dict)
320 | # print(fx_bus_param_dict)
321 | # print(master_bus_param_dict)
322 | print(mix.abs().max())
323 | print(mono_mix.abs().max())
324 |
325 | # ----------------------- full mix generation ---------------------- #
326 | # iterate over the tracks in blocks to mix the entire song
327 | # this is to avoid memory issues
328 | block_size = args.block_size
329 | n_blocks = tracks.shape[-1] // block_size
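# note that the integer division drops any samples after the last full block, so the tail of
# the song (up to block_size - 1 samples) is left as silence in full_mix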
330 | full_mix = torch.zeros(2, tracks.shape[-1])
331 | for block_idx in tqdm(range(n_blocks)):
332 | tracks_block = tracks[:, block_idx * block_size : (block_idx + 1) * block_size]
333 | tracks_block = tracks_block.type_as(tracks_section)
334 |
335 | with torch.no_grad():
336 | result = mix_console(
337 | tracks_block.unsqueeze(0),
338 | torch.sigmoid(track_params),
339 | torch.sigmoid(fx_bus_params),
340 | torch.sigmoid(master_bus_params),
341 | use_fx_bus=False,
342 | )
343 | mix_block = result[1].squeeze(0).cpu()
344 | full_mix[
345 | :, block_idx * block_size : (block_idx + 1) * block_size
346 | ] = mix_block
347 |
348 | # peak normalize the full and mono mixes
349 | full_mix /= full_mix.abs().max()
350 | mono_mix /= mono_mix.abs().max()
351 |
352 | mono_mix_section = mono_mix[:, start_sample : start_sample + args.block_size]
353 |
354 | # -------------------------- analyze mixes -------------------------- #
355 |
356 | ref_spec = compute_barkspectrum(ref_mix.unsqueeze(0), sample_rate=args.sample_rate)
357 | pred_spec = compute_barkspectrum(
358 | full_mix.unsqueeze(0), sample_rate=args.sample_rate
359 | )
360 |
361 | fig, axs = plt.subplots(2, 1, sharex=True, sharey=True)
362 | axs[0].plot(ref_spec[0, :, 0], label="ref-mid", color="tab:orange")
363 | axs[0].plot(pred_spec[0, :, 0], label="pred-mid", color="tab:blue")
364 | axs[1].plot(ref_spec[0, :, 1], label="ref-side", color="tab:orange")
365 | axs[1].plot(pred_spec[0, :, 1], label="pred-side", color="tab:blue")
366 | axs[0].legend()
367 | axs[1].legend()
368 | plt.savefig(os.path.join(output_dir, "plots", "bark_spectra.png"))
369 | plt.close("all")
370 |
371 | for idx, (loss_name, loss_vals) in enumerate(loss_history.items()):
372 | fig, ax = plt.subplots(1, 1)
373 | ax.plot(loss_vals, label=loss_name)
374 | ax.set_xlabel("Iteration")
375 | ax.set_ylabel(f"{loss_name}")
376 | plt.savefig(os.path.join(output_dir, "plots", f"{loss_name}.png"))
377 | plt.close("all")
378 |
379 | # -------------------------- save results -------------------------- #
380 | # save mix
381 | torchaudio.save(
382 | os.path.join(output_dir, "pred_mix_section.wav"),
383 | mix.squeeze(0).cpu(),
384 | args.sample_rate,
385 | )
386 | torchaudio.save(
387 | os.path.join(output_dir, "ref_mix_section.wav"),
388 | ref_mix_section.cpu(),
389 | args.sample_rate,
390 | )
391 | torchaudio.save(
392 | os.path.join(output_dir, "mono_mix_section.wav"),
393 | mono_mix_section,
394 | args.sample_rate,
395 | )
396 | torchaudio.save(
397 | os.path.join(output_dir, "pred_mix.wav"),
398 | full_mix,
399 | args.sample_rate,
400 | )
401 | torchaudio.save(
402 | os.path.join(output_dir, "mono_mix.wav"),
403 | mono_mix,
404 | args.sample_rate,
405 | )
406 | torchaudio.save(
407 | os.path.join(output_dir, "ref_mix.wav"),
408 | ref_mix,
409 | args.sample_rate,
410 | )
411 |
--------------------------------------------------------------------------------
/mst/system.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import yaml
4 | import torch
5 | import auraloss
6 | import pytorch_lightning as pl
7 |
8 | import time
9 | from typing import Callable
10 | from mst.mixing import knowledge_engineering_mix
11 | from mst.utils import batch_stereo_peak_normalize
12 | from mst.fx_encoder import FXencoder
13 | import pyloudnorm as pyln
14 |
15 |
16 | class System(pl.LightningModule):
17 | def __init__(
18 | self,
19 | model: torch.nn.Module,
20 | mix_console: torch.nn.Module,
21 | mix_fn: Callable,
22 | loss: torch.nn.Module,
23 | generate_mix: bool = True,
24 | use_track_loss: bool = False,
25 | use_mix_loss: bool = True,
26 | use_param_loss: bool = False,
27 | instrument_id_json: str = "data/instrument_name2id.json",
28 | knowledge_engineering_yaml: str = "data/knowledge_engineering.yaml",
29 | active_eq_epoch: int = 0,
30 | active_compressor_epoch: int = 0,
31 | active_fx_bus_epoch: int = 0,
32 | active_master_bus_epoch: int = 0,
33 | lr: float = 1e-4,
34 | max_epochs: int = 500,
35 | schedule: str = "step",
36 | **kwargs,
37 | ) -> None:
38 | super().__init__()
39 | self.model = model
40 | self.mix_console = mix_console
41 | self.mix_fn = mix_fn
42 | self.loss = loss
43 | self.generate_mix = generate_mix
44 | self.use_track_loss = use_track_loss
45 | self.use_mix_loss = use_mix_loss
46 | self.use_param_loss = use_param_loss
47 | self.active_eq_epoch = active_eq_epoch
48 | self.active_compressor_epoch = active_compressor_epoch
49 | self.active_fx_bus_epoch = active_fx_bus_epoch
50 | self.active_master_bus_epoch = active_master_bus_epoch
51 |
52 | self.meter = pyln.Meter(44100)
53 | #self.warmup = warmup
54 |
55 |
56 | self.save_hyperparameters(ignore=["model", "mix_console", "mix_fn", "loss"])
57 |
58 |
59 | # losses for evaluation
60 | self.sisdr = auraloss.time.SISDRLoss()
61 | self.mrstft = auraloss.freq.MultiResolutionSTFTLoss(
62 | fft_sizes=[512, 2048, 8192],
63 | hop_sizes=[256, 1024, 4096],
64 | win_lengths=[512, 2048, 8192],
65 | w_sc=0.0,
66 | w_phs=0.0,
67 | w_lin_mag=1.0,
68 | w_log_mag=1.0,
69 | )
70 |
71 | # load configuration files
72 | if mix_fn is knowledge_engineering_mix:
73 | with open(instrument_id_json, "r") as f:
74 | self.instrument_number_lookup = json.load(f)
75 |
76 | with open(knowledge_engineering_yaml, "r") as f:
77 | self.knowledge_engineering_dict = yaml.safe_load(f)
78 | else:
79 | self.instrument_number_lookup = None
80 | self.knowledge_engineering_dict = None
81 |
82 | # default
83 | self.use_track_input_fader = True
84 | self.use_track_panner = True
85 | self.use_track_eq = False
86 | self.use_track_compressor = False
87 | self.use_fx_bus = False
88 | self.use_master_bus = False
89 | self.use_output_fader = True
90 |
91 | def forward(self, tracks: torch.Tensor, ref_mix: torch.Tensor) -> torch.Tensor:
92 | """Apply model to audio waveform tracks.
93 | Args:
94 | tracks (torch.Tensor): Set of input tracks with shape (bs, num_tracks, 1, seq_len)
95 | ref_mix (torch.Tensor): Reference mix with shape (bs, 2, seq_len)
96 |
97 | Returns:
98 | Predicted mix console parameters for the tracks, fx bus, and master bus.
99 | """
100 | return self.model(tracks, ref_mix)
101 |
102 | def common_step(
103 | self,
104 | batch: tuple,
105 | batch_idx: int,
106 | train: bool = False,
107 | ):
108 | """Model step used for validation and training.
109 | Args:
110 | batch (tuple): Batch items containing the tracks, instrument ids, stereo info,
111 | track padding mask, reference mix, and song names.
112 | batch_idx (int): Index of the batch within the current epoch.
113 | train (bool): Whether the step is called during training (True) or validation (False).
114 | """
115 |
116 | tracks, instrument_id, stereo_info, track_padding, ref_mix, song_name = batch
117 | #print("song_names from this batch: ", song_name)
118 |
119 | # split into A and B sections
120 | middle_idx = tracks.shape[-1] // 2
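# when a reference mix is generated on the fly, its first half conditions the model while the
# second half of the tracks is mixed and compared against the second half of the reference,
# so the model has to transfer the mixing style rather than copy the reference audio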
121 |
122 | # enable parts of the mix console once the configured epoch has been reached
123 | if self.current_epoch >= self.active_eq_epoch:
124 | self.use_track_eq = True
125 |
126 | if self.current_epoch >= self.active_compressor_epoch:
127 | self.use_track_compressor = True
128 |
129 | if self.current_epoch >= self.active_fx_bus_epoch:
130 | self.use_fx_bus = True
131 |
132 | if self.current_epoch >= self.active_master_bus_epoch:
133 | self.use_master_bus = True
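# the active_*_epoch thresholds default to 0, so every processor is enabled from the first
# epoch; raising them staggers when the EQ, compressor, fx bus, and master bus become active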
134 |
213 | bs, num_tracks, seq_len = tracks.shape
214 |
215 | # apply random gain to input tracks
216 | # tracks *= 10 ** ((torch.rand(bs, num_tracks, 1).type_as(tracks) * -12.0) / 20.0)
217 | ref_track_param_dict = None
218 | ref_fx_bus_param_dict = None
219 | ref_master_bus_param_dict = None
220 |
221 | # --------- create a random mix (on GPU, if applicable) ---------
222 | if self.generate_mix:
223 | (
224 | ref_mix_tracks,
225 | ref_mix,
226 | ref_track_param_dict,
227 | ref_fx_bus_param_dict,
228 | ref_master_bus_param_dict,
229 | ref_mix_params,
230 | ref_fx_bus_params,
231 | ref_master_bus_params,
232 | ) = self.mix_fn(
233 | tracks,
234 | self.mix_console,
235 | use_track_input_fader=False, # do not use track input fader for training
236 | use_track_panner=self.use_track_panner,
237 | use_track_eq=self.use_track_eq,
238 | use_track_compressor=self.use_track_compressor,
239 | use_fx_bus=self.use_fx_bus,
240 | use_master_bus=self.use_master_bus,
241 | use_output_fader=False, # not used because we normalize output mixes
242 | instrument_id=instrument_id,
243 | stereo_id=stereo_info,
244 | instrument_number_file=self.instrument_number_lookup,
245 | ke_dict=self.knowledge_engineering_dict,
246 | )
247 |
248 | # normalize the reference mix
249 | ref_mix = batch_stereo_peak_normalize(ref_mix)
250 |
251 | if torch.isnan(ref_mix).any():
252 | print(ref_track_param_dict)
253 | raise ValueError("Found nan in ref_mix")
254 |
255 | ref_mix_a = ref_mix[..., :middle_idx] # this is passed to the model
256 | ref_mix_b = ref_mix[..., middle_idx:] # this is used for loss computation
257 | # tracks_a = tracks[..., :input_middle_idx] # not used currently
258 | tracks_b = tracks[..., middle_idx:] # this is passed to the model
259 | else:
260 | # when using a real mix, pass the same mix to model and loss
261 | ref_mix_a = ref_mix
262 | ref_mix_b = ref_mix
263 | tracks_b = tracks
264 |
265 |
266 | # ---- run model with tracks from section B using the reference mix from section A ----
267 | (
268 | pred_track_params,
269 | pred_fx_bus_params,
270 | pred_master_bus_params,
271 | ) = self.model(tracks_b, ref_mix_a, track_padding_mask=track_padding)
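# track_padding_mask presumably flags padded track slots so the model can ignore them when a
# batch contains songs with different numbers of tracks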
272 |
273 | # ------- generate a mix using the predicted mix console parameters -------
274 | (
275 | pred_mixed_tracks_b,
276 | pred_mix_b,
277 | pred_track_param_dict,
278 | pred_fx_bus_param_dict,
279 | pred_master_bus_param_dict,
280 | ) = self.mix_console(
281 | tracks_b,
282 | pred_track_params,
283 | pred_fx_bus_params,
284 | pred_master_bus_params,
285 | use_track_input_fader=self.use_track_input_fader,
286 | use_track_panner=self.use_track_panner,
287 | use_track_eq=self.use_track_eq,
288 | use_track_compressor=self.use_track_compressor,
289 | use_fx_bus=self.use_fx_bus,
290 | use_master_bus=self.use_master_bus,
291 | use_output_fader=self.use_output_fader,
292 | )
293 |
294 | # normalize the predicted mix before computing the loss
295 | # pred_mix_b = batch_stereo_peak_normalize(pred_mix_b)
296 |
297 |
298 | if ref_track_param_dict is None:
299 | ref_track_param_dict = pred_track_param_dict
300 | ref_fx_bus_param_dict = pred_fx_bus_param_dict
301 | ref_master_bus_param_dict = pred_master_bus_param_dict
302 |
303 | # ---------------------------- compute and log loss ------------------------------
304 |
305 |
306 | #print("pred_mix: ", pred_mix_b)
307 | # if pred_mix_b.sum() == 0:
308 |
309 | #print("pred_track_params: ", pred_track_params)
310 | #print("pred_fx_bus_params: ", pred_fx_bus_params)
311 | #print("pred_master_bus_params: ", pred_master_bus_params)
312 | #print("ref_mix: ",ref_mix_b)
313 |
314 | loss = 0
315 |
316 | # if the parameter loss is used, compare the predicted console parameters against the reference parameters
317 | if self.use_param_loss:
318 | track_param_loss = self.loss(pred_track_params, ref_mix_params)
319 | loss += track_param_loss
320 | if self.use_fx_bus:
321 | fx_bus_param_loss = self.loss(pred_fx_bus_params, ref_fx_bus_params)
322 | loss += fx_bus_param_loss
323 | if self.use_master_bus:
324 | master_bus_param_loss = self.loss(pred_master_bus_params, ref_master_bus_params)
325 | loss += master_bus_param_loss
326 |
331 | if self.use_mix_loss:
332 | mix_loss = self.loss(pred_mix_b, ref_mix_b)
333 |
334 | if isinstance(mix_loss, dict):
335 | for key, val in mix_loss.items():
336 | loss += val.mean()
337 | else:
338 | loss += mix_loss
339 |
340 |
341 | if isinstance(mix_loss, dict):
342 | for key, value in mix_loss.items():
343 | self.log(
344 | ("train" if train else "val") + "/" + key,
345 | value,
346 | on_step=True,
347 | on_epoch=True,
348 | prog_bar=False,
349 | logger=True,
350 | sync_dist=True,
351 | )
352 | #print(loss)
353 |
354 |
355 | # log the losses
356 | self.log(
357 | ("train" if train else "val") + "/loss",
358 | loss,
359 | on_step=True,
360 | on_epoch=True,
361 | prog_bar=True,
362 | logger=True,
363 | sync_dist=True,
364 | )
365 |
366 |
367 | # sisdr_error = -self.sisdr(pred_mix_b, ref_mix_b)
368 | # log the SI-SDR error
369 | # self.log(
370 | # ("train" if train else "val") + "/si-sdr",
371 | # sisdr_error,
372 | # on_step=False,
373 | # on_epoch=True,
374 | # prog_bar=False,
375 | # logger=True,
376 | # sync_dist=True,
377 | # )
378 |
379 | # mrstft_error = self.mrstft(pred_mix_b, ref_mix_b)
380 | ## log the MR-STFT error
381 | # self.log(
382 | # ("train" if train else "val") + "/mrstft",
383 | # mrstft_error,
384 | # on_step=False,
385 | # on_epoch=True,
386 | # prog_bar=False,
387 | # logger=True,
388 | # sync_dist=True,
389 | # )
390 |
391 | # for plotting down the line
392 | sum_mix_b = tracks_b.sum(dim=1, keepdim=True).detach().float().cpu()
393 | sum_mix_b = batch_stereo_peak_normalize(sum_mix_b)
394 | data_dict = {
395 | "ref_mix_a": ref_mix_a.detach().float().cpu(),
396 | "ref_mix_b_norm": ref_mix_b.detach().float().cpu(),
397 | "pred_mix_b_norm": pred_mix_b.detach().float().cpu(),
398 | "sum_mix_b": sum_mix_b,
399 | "ref_track_param_dict": ref_track_param_dict,
400 | "pred_track_param_dict": pred_track_param_dict,
401 | "ref_fx_bus_param_dict": ref_fx_bus_param_dict,
402 | "pred_fx_bus_param_dict": pred_fx_bus_param_dict,
403 | "ref_master_bus_param_dict": ref_master_bus_param_dict,
404 | "pred_master_bus_param_dict": pred_master_bus_param_dict,
405 | }
406 |
407 | return loss, data_dict
408 |
409 | def training_step(self, batch, batch_idx):
410 | loss, data_dict = self.common_step(batch, batch_idx, train=True)
411 |
412 | #print(loss)
413 | return loss
414 |
415 | def validation_step(self, batch, batch_idx):
416 | loss, data_dict = self.common_step(batch, batch_idx, train=False)
417 | return data_dict
418 |
419 | def configure_optimizers(self):
420 | optimizer = torch.optim.Adam(
421 | self.model.parameters(),
422 | lr=self.hparams.lr,
423 | betas=(0.9, 0.999),
424 | )
425 |
426 | if self.hparams.schedule == "cosine":
427 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
428 | optimizer, T_max=self.hparams.max_epochs
429 | )
430 | elif self.hparams.schedule == "step":
431 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
432 | optimizer,
433 | [
434 | int(self.hparams.max_epochs * 0.85),
435 | int(self.hparams.max_epochs * 0.95),
436 | ],
437 | )
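# the step schedule lowers the learning rate at 85% and 95% of max_epochs
# (by MultiStepLR's default gamma of 0.1)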
438 | else:
439 | #print(optimizer)
440 | return optimizer
441 | lr_schedulers = {"scheduler": scheduler, "interval": "epoch", "frequency": 1}
442 |
443 | return [optimizer], [lr_schedulers]
444 |
--------------------------------------------------------------------------------